Commit 5699ccbf authored by Nianchen Deng

sync

parent 338ae906
import os
import torch
import torch.nn.functional as nn_f
from typing import Tuple, Union
from typing import Dict, Tuple, Union
from operator import itemgetter
from pathlib import Path
from utils import img
from utils import color
from utils import misc
from utils import sphere
from utils.mem_profiler import *
from utils.constants import *
@@ -27,8 +29,16 @@ class PanoDataset(object):
class Chunk(object):
def __init__(self, id, dataset, *,
indices: torch.Tensor, centers: torch.Tensor):
@property
def n_views(self):
return self.indices.size(0)
@property
def n_pixels_per_view(self):
return self.dataset.n_pixels_per_view
def __init__(self, id: int, dataset, chunk_data: Dict[str, torch.Tensor], *,
color: int, **kwargs):
"""
Initialize a chunk of the dataset, which holds a subset of views and loads their data lazily.
@@ -38,10 +48,9 @@ class PanoDataset(object):
"""
self.id = id
self.dataset = dataset
self.indices = indices
self.centers = centers
self.n_views = self.indices.size(0)
self.n_pixels_per_view = self.dataset.res[0] * self.dataset.res[1]
self.indices = chunk_data['indices']
self.centers = chunk_data['centers']
self.color = color
self.colors_cpu = None
self.colors = None
self.loaded = False
@@ -53,12 +62,12 @@ class PanoDataset(object):
def load(self):
if self.dataset.image_path is not None and self.colors_cpu is None:
images = color.cvt(
img.load(self.dataset.image_path % i for i in self.indices),
color.RGB, self.dataset.c)
if self.dataset.res != list(images.shape[-2:]):
images = color.cvt(img.load(self.dataset.image_path % i for i in self.indices),
color.RGB, self.color)
if self.dataset.res != tuple(images.shape[-2:]):
images = nn_f.interpolate(images, self.dataset.res)
self.colors_cpu = images.permute(0, 2, 3, 1).flatten(0, 2)
self.colors_cpu = images.permute(0, 2, 3, 1) \
[:, self.dataset.pixels[:, 0], self.dataset.pixels[:, 1]].flatten(0, 1)
if self.colors_cpu is not None:
self.colors = self.colors_cpu.to(self.dataset.device)
self.loaded = True
@@ -74,15 +83,27 @@ class PanoDataset(object):
self.load()
view_idx = idx // self.n_pixels_per_view
pix_idx = idx % self.n_pixels_per_view
global_idx = self.indices[view_idx] * self.n_pixels_per_view + pix_idx
extra_data = {}
if self.colors is not None:
extra_data['colors'] = self.colors[idx]
extra_data['color'] = self.colors[idx]
rays_o = self.centers[view_idx]
rays_d = self.dataset.pano_rays[pix_idx]
return idx, rays_o, rays_d, extra_data
rays_d = self.dataset.rays[pix_idx]
return global_idx, rays_o, rays_d, extra_data
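To make the index arithmetic above concrete, here is a toy trace (all numbers invented) of how a chunk-local flat index is decoded into a view/pixel pair and re-encoded against the dataset-wide view ids:
import torch  # not needed for this toy trace, shown for symmetry with the code above

n_pixels_per_view = 100
indices = [7, 9, 12]     # global ids of the views held by this chunk

idx = 215                # flat sample index, local to the chunk
view_idx = idx // n_pixels_per_view   # 2: third view in the chunk
pix_idx = idx % n_pixels_per_view     # 15: pixel within that view
global_idx = indices[view_idx] * n_pixels_per_view + pix_idx
print(global_idx)        # 1215: pixel 15 of global view 12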
@property
def n_views(self):
return self.centers.size(0)
@property
def n_pixels_per_view(self):
return self.pixels.size(0)
@property
def n_pixels(self):
return self.n_views * self.n_pixels_per_view
def __init__(self, desc: dict, *,
c: int = color.RGB,
def __init__(self, desc: dict, root: Path, name: str, *,
load_images: bool = True,
res: Tuple[int, int] = None,
views_to_load: Union[range, torch.Tensor] = None,
@@ -104,7 +125,8 @@ class PanoDataset(object):
:param c ```int```: color space to convert view images to
:param calculate_rays ```bool```: whether to calculate rays
"""
self.c = c
self.root = root
self.name = name
self.device = device
self._load_desc(desc, res, views_to_load, load_images)
@@ -119,26 +141,26 @@ class PanoDataset(object):
views_to_load: Union[range, torch.Tensor],
load_images: bool):
if load_images and desc.get('view_file_pattern'):
self.image_path = os.path.join(os.getcwd(), desc['view_file_pattern'])
file_pattern = desc['view_file_pattern']
if "/" not in file_pattern:
file_pattern = f"{self.name}/{file_pattern}"
self.image_path = str(self.root / file_pattern)
else:
self.image_path = None
self.res = res if res else misc.values(desc['view_res'], 'y', 'x')
self.depth_range = misc.values(desc['depth_range'], 'min', 'max') \
self.res = res if res else itemgetter("y", "x")(desc['view_res'])
self.depth_range = itemgetter("min", "max")(desc['depth_range']) \
if 'depth_range' in desc else None
self.range = misc.values(desc['range'], 'min', 'max') if 'range' in desc else None
self.bbox = None
self.samples = desc.get('samples')
self.centers = torch.tensor(desc['view_centers'], device=self.device) # (N, 3)
self.indices = torch.tensor(
desc['views'] if 'views' in desc else list(range(self.centers.size(0))),
self.indices = torch.tensor(desc.get('views') or [*range(self.centers.size(0))],
device=self.device)
if views_to_load is not None:
self.centers = self.centers[views_to_load]
self.indices = self.indices[views_to_load]
self.n_views = self.centers.size(0)
self.n_pixels = self.n_views * self.res[0] * self.res[1]
self.pano_rays = self._get_pano_rays() # [H*W, 3]
self.pixels, self.rays = self._get_pano_rays()
if desc.get('gl_coord'):
print('Convert from OpenGL coordinates to DirectX coordinates (i.e. flip the z-axis)')
@@ -148,12 +170,16 @@ class PanoDataset(object):
"""
Get unprojected rays of pixels on a panorama
:return `Tensor(H*W, 3)`: rays' directions of unit length
:return `Tensor(N, 2)`: rays' pixel coordinates in the pano image
:return `Tensor(N, 3)`: rays' directions of unit length
"""
spher_coords = torch.cat([
torch.ones(*self.res, 1),
((misc.meshgrid(*self.res, normalize=True)) *
torch.tensor([-2.0, 1.0]) + torch.tensor([1.5, 0.0])) * PI
], dim=-1).to(device=self.device)
coords = sphere.spherical2cartesian(spher_coords)
return coords.flatten(0, 1) # [H*W, 3]
phi = (torch.arange(self.res[0], device=self.device) + 0.5) / self.res[0] * PI # (H)
# Number of valid pixels in each row shrinks towards the poles (kept even)
length = (phi.sin() * self.res[1] * 0.5).ceil() * 2
cols = torch.arange(self.res[1], device=self.device)[None, :].expand(*self.res) # (H, W)
# Keep only pixels inside the centered valid band of each row
mask = torch.logical_and(cols >= (self.res[1] - length[:, None]) / 2,
cols < (self.res[1] + length[:, None]) / 2) # (H, W)
pixs = mask.nonzero() # (N, 2)
pixs_phi = (0.5 - (pixs[:, 0] + 0.5) / self.res[0]) * PI # latitude
pixs_theta = (pixs[:, 1] * 2 + 1 - self.res[1]) / length[pixs[:, 0]] * PI # longitude, scaled per row
spher_coords = torch.stack([torch.ones_like(pixs_phi), pixs_theta, pixs_phi], dim=-1)
return pixs, sphere.spherical2cartesian(spher_coords) # (N, 2), (N, 3)
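For reference, a standalone sketch of the masking logic above (the function name is illustrative, not part of the repository): it reproduces how each row of an equirectangular panorama keeps roughly sin(phi) * W centered pixels, so pixels near the poles are skipped.
import torch

PI = 3.141592653589793

def pano_valid_pixels(h: int, w: int) -> torch.Tensor:
    # Polar angle at the center of each row
    phi = (torch.arange(h) + 0.5) / h * PI
    # Even number of valid pixels per row, shrinking towards the poles
    length = (phi.sin() * w * 0.5).ceil() * 2
    cols = torch.arange(w)[None, :].expand(h, w)
    mask = (cols >= (w - length[:, None]) / 2) & (cols < (w + length[:, None]) / 2)
    return mask.nonzero()  # (N, 2) row/column coordinates

print(pano_valid_pixels(8, 16).shape[0])  # fewer than 8 * 16 = 128 pixels survive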
import os
import torch
import torch.nn.functional as nn_f
from typing import Tuple, Union
from typing import Dict, Tuple, Union
from operator import itemgetter
from pathlib import Path
from utils import img
from utils import view
from utils import color
from utils import misc
class ViewDataset(object):
@@ -25,20 +27,21 @@ class ViewDataset(object):
class Chunk(object):
def __init__(self, id, dataset, *,
indices: torch.Tensor, centers: torch.Tensor, rots: torch.Tensor):
def __init__(self, id: int, dataset, chunk_data: Dict[str, torch.Tensor], *,
color: int, **kwargs):
"""
Initialize a chunk of the dataset, which holds a subset of views and loads their data lazily.
:param dataset `PanoDataset`: dataset object
:param dataset `ViewDataset`: dataset object
:param indices `Tensor(N)`: indices of views
:param centers `Tensor(N, 3)`: centers of views
"""
self.id = id
self.dataset = dataset
self.indices = indices
self.centers = centers
self.rots = rots
self.indices = chunk_data['indices']
self.centers = chunk_data['centers']
self.rots = chunk_data['rots']
self.color = color
self.n_views = self.indices.size(0)
self.n_pixels_per_view = self.dataset.res[0] * self.dataset.res[1]
self.colors = self.depths = self.bins = None
@@ -50,10 +53,11 @@ class ViewDataset(object):
self.loaded = False
def load(self):
#print("chunk load")
try:
if self.dataset.image_path and self.colors_cpu is None:
images = color.cvt(
img.load(self.dataset.image_path % i for i in self.indices),
color.RGB, self.dataset.c)
images = color.cvt(img.load(self.dataset.image_path % i for i in self.indices),
color.RGB, self.color)
if self.dataset.res != tuple(images.shape[-2:]):  # res is a tuple, so compare against a tuple
images = nn_f.interpolate(images, self.dataset.res)
self.colors_cpu = images.permute(0, 2, 3, 1).flatten(0, 2)
@@ -79,6 +83,9 @@ class ViewDataset(object):
torch.cuda.current_stream(self.dataset.device).synchronize()
self.loaded = True
except Exception as ex:
print(ex)
exit(-1)
def __len__(self):
return self.n_views * self.n_pixels_per_view
@@ -88,21 +95,24 @@ class ViewDataset(object):
self.load()
view_idx = idx // self.n_pixels_per_view
pix_idx = idx % self.n_pixels_per_view
global_idx = self.indices[view_idx] * self.n_pixels_per_view + pix_idx
rays_o = self.centers[view_idx]
rays_d = self.dataset.cam_rays[pix_idx] # (N, 3)
rays_d = self.dataset.cam_rays[pix_idx][:, None] # (N, 1, 3)
r = self.rots[view_idx].movedim(-1, -2) # (N, 3, 3)
rays_d = torch.matmul(rays_d, r)
extra_data = {}
rays_d = torch.matmul(rays_d, r)[:, 0] # (N, 3)
extra_data = {
'view_idx': view_idx,
'pix_idx': pix_idx
} # TBR
if self.colors is not None:
extra_data['colors'] = self.colors[idx]
extra_data['color'] = self.colors[idx]
if self.depths is not None:
extra_data['depths'] = self.depths[idx]
extra_data['depth'] = self.depths[idx]
if self.bins is not None:
extra_data['bins'] = self.bins[idx]
return idx, rays_o, rays_d, extra_data
extra_data['bin'] = self.bins[idx]
return global_idx, rays_o, rays_d, extra_data
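A quick standalone check (not part of the repository) that the row-vector form used above, d @ R^T, matches the usual column-vector rotation R @ d:
import torch

r = torch.linalg.qr(torch.randn(3, 3))[0]       # random orthonormal matrix
d = torch.randn(3)
row_form = (d[None, :] @ r.movedim(-1, -2))[0]  # as in __getitem__ above
col_form = r @ d
print(torch.allclose(row_form, col_form))       # True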
def __init__(self, desc: dict, *,
c: int = color.RGB,
def __init__(self, desc: dict, root: Path, name: str, *,
load_images: bool = True,
load_depths: bool = False,
load_bins: bool = False,
@@ -127,7 +137,8 @@ class ViewDataset(object):
:param c ```int```: color space to convert view images to
:param calculate_rays ```bool```: whether to calculate rays
"""
self.c = c
self.root = root
self.name = name
self.device = device
self._load_desc(desc, res, views_to_load, load_images, load_depths, load_bins)
@@ -150,22 +161,32 @@ class ViewDataset(object):
load_depths: bool,
load_bins: bool):
if load_images and desc.get('view_file_pattern'):
self.image_path = os.path.join(self.data_dir, desc['view_file_pattern'])
file_pattern = desc['view_file_pattern']
if "/" not in file_pattern:
file_pattern = f"{self.name}/{file_pattern}"
self.image_path = str(self.root / file_pattern)
else:
self.image_path = None
if load_depths and desc.get('depth_file_pattern'):
self.depth_path = os.path.join(self.data_dir, desc['depth_file_pattern'])
file_pattern = desc['depth_file_pattern']
if "/" not in file_pattern:
file_pattern = f"{self.name}/{file_pattern}"
self.depth_path = str(self.root / file_pattern)
else:
self.depth_path = None
if load_bins and desc.get('bins_file_pattern'):
self.bins_path = os.path.join(self.data_dir, desc['bins_file_pattern'])
file_pattern = desc['bins_file_pattern']
if "/" not in file_pattern:
file_pattern = f"{self.name}/{file_pattern}"
self.bins_path = str(self.root / file_pattern)
else:
self.bins_path = None
self.res = res if res else misc.values(desc['view_res'], 'y', 'x')
self.res = res or itemgetter("y", "x")(desc['view_res'])
self.cam = view.CameraParam(desc['cam_params'], self.res, device=self.device)
self.depth_range = misc.values(desc['depth_range'], 'min', 'max') \
self.depth_range = itemgetter("min", "max")(desc['depth_range']) \
if 'depth_range' in desc else None
self.range = misc.values(desc['range'], 'min', 'max') if 'range' in desc else None
self.range = itemgetter("min", "max")(desc['range']) if 'range' in desc else None
self.bbox = desc.get('bbox')
self.samples = desc.get('samples')
self.centers = torch.tensor(desc['view_centers'], device=self.device) # (N, 3)
self.rots = torch.tensor(
@@ -175,8 +196,7 @@ class ViewDataset(object):
]
if len(desc['view_rots'][0]) == 2 else desc['view_rots'],
device=self.device).view(-1, 3, 3) # (N, 3, 3)
self.indices = torch.tensor(
desc['views'] if 'views' in desc else list(range(self.centers.size(0))),
self.indices = torch.tensor(desc.get('views') or [*range(self.centers.size(0))],
device=self.device)
if views_to_load is not None:
......
import os
import sys
import argparse
import json
import torch
sys.path.append(os.path.abspath(sys.path[0] + '/../'))
parser = argparse.ArgumentParser()
parser.add_argument('-m', '--model', type=str,
help='The model file to load for testing')
parser.add_argument('-r', '--output-rays', type=int, default=100,
help='How many rays to output')
parser.add_argument('-p', '--prompt', action='store_true',
help='Interactive prompt mode')
parser.add_argument('dataset', type=str,
help='Dataset description file')
args = parser.parse_args()
import model as mdl
from utils import misc
from utils import color
from utils import interact
from utils import device
from data.dataset_factory import *
from data.loader import DataLoader
from modules import Samples, Voxels
from model.nsvf import NSVF
model: NSVF
samples: Samples
DATA_LOADER_CHUNK_SIZE = 1e8
data_desc_path = args.dataset if args.dataset.endswith('.json') \
else os.path.join(args.dataset, 'train.json')
data_desc_name = os.path.splitext(os.path.basename(data_desc_path))[0]
data_dir = os.path.dirname(data_desc_path) + '/'
def get_model_files(datadir):
model_files = []
for root, _, files in os.walk(datadir):
model_files += [
os.path.join(root, file).replace(datadir, '')
for file in files if file.endswith('.tar') or file.endswith('.pth')
]
return model_files
if args.prompt: # Interactively prompt for the test model and number of output rays
model_files = get_model_files(data_dir)
args.model = interact.input_enum('Specify test model:', model_files,
err_msg='No such model file')
args.output_rays = interact.input_ex('Specify number of rays to output:',
interact.input_to_int(), default=10)
model_path = os.path.join(data_dir, args.model)
model_name = os.path.splitext(os.path.basename(model_path))[0]
model, iters = mdl.load(model_path, {"perturb_sample": False})
model.to(device.default()).eval()
model_class = model.__class__.__name__
model_args = model.args
print(f"model: {model_name} ({model_class})")
print("args:", json.dumps(model.args0))
dataset = DatasetFactory.load(data_desc_path)
print("Dataset loaded: " + data_desc_path)
run_dir = os.path.dirname(model_path) + '/'
output_dir = f"{run_dir}output_{int(model_name.split('_')[-1])}"
if __name__ == "__main__":
with torch.no_grad():
# 1. Initialize data loader
data_loader = DataLoader(dataset, args.output_rays, chunk_max_items=DATA_LOADER_CHUNK_SIZE,
shuffle=True, enable_preload=True,
color=color.from_str(model.args['color']))
sys.stdout.write("Export samples...\r")
for _, rays_o, rays_d, extra in data_loader:
samples, rays_mask = model.sampler(rays_o, rays_d, model.space)
invalid_rays_o = rays_o[torch.logical_not(rays_mask)]
invalid_rays_d = rays_d[torch.logical_not(rays_mask)]
rays_o = rays_o[rays_mask]
rays_d = rays_d[rays_mask]
break
print("Export samples...Done")
os.makedirs(output_dir, exist_ok=True)
export_data = {}
if model.space.bbox is not None:
export_data['bbox'] = model.space.bbox.tolist()
if isinstance(model.space, Voxels):
export_data['voxel_size'] = model.space.voxel_size.tolist()
export_data['voxels'] = model.space.voxels.tolist()
if False: # disabled: voxel access count export
voxel_access_counts = torch.zeros(model.space.n_voxels, dtype=torch.long,
device=device.default())
iters_in_epoch = 0
data_loader.batch_size = 2 ** 20
for _, rays_o1, rays_d1, _ in data_loader:
model(rays_o1, rays_d1,
raymarching_tolerance=0.5,
raymarching_chunk_size=0,
voxel_access_counts=voxel_access_counts)
iters_in_epoch += 1
percent = iters_in_epoch / len(data_loader) * 100
sys.stdout.write(f'Export voxel access counts...{percent:.1f}% \r')
export_data['voxel_access_counts'] = voxel_access_counts.tolist()
print("Export voxel access counts...Done ")
export_data.update({
'rays_o': rays_o.tolist(),
'rays_d': rays_d.tolist(),
'invalid_rays_o': invalid_rays_o.tolist(),
'invalid_rays_d': invalid_rays_d.tolist(),
'samples': {
'depths': samples.depths.tolist(),
'dists': samples.dists.tolist(),
'voxel_indices': samples.voxel_indices.tolist()
}
})
with open(f'{output_dir}/debug_voxel_sampler_export3d.json', 'w') as fp:
json.dump(export_data, fp)
print("Write JSON file...Done")
print(f"Rays: total {args.output_rays}, valid {rays_o.size(0)}")
print(f"Samples: average {samples.voxel_indices.ne(-1).sum(-1).float().mean().item()} per ray")
cdf = [2.2, 3.5, 3.6, 3.7, 4.0]
bins = []
part = 1
offset = 0
# Walk the (non-decreasing) CDF and count how many entries pass before
# each integer threshold is first reached
for i in range(len(cdf)):
    if cdf[i] >= part:
        bins.append(i + 1 - offset)
        offset = i + 1
        part = int(cdf[i]) + 1
print(bins)  # [1, 1, 3]
\ No newline at end of file
from torch.nn import L1Loss, MSELoss
from torch.nn.functional import l1_loss, mse_loss
from .ssim import SSIM
from .perc_loss import VGGPerceptualLoss
from .cauchy import cauchy_loss, CauchyLoss
\ No newline at end of file
import torch
def cauchy_loss(input: torch.Tensor, target: torch.Tensor = None, *, s=1.0):
# Cauchy (Lorentzian) robust loss: mean of log(1 + s * x^2 / 2)
x = input - target if target is not None else input
return (s * x * x * 0.5 + 1).log().mean()
class CauchyLoss(torch.nn.Module):
def __init__(self, s = 1.0):
super().__init__()
self.s = s
def forward(self, input: torch.Tensor, target: torch.Tensor = None):
return cauchy_loss(input, target, s=self.s)
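As a sanity check (a sketch, assuming this file is importable as losses.cauchy), the Cauchy loss grows only logarithmically with the residual, so a single outlier inflates it far less than MSE:
import torch
from losses.cauchy import cauchy_loss  # assumed package path

pred = torch.tensor([0.1, 0.2, 5.0])   # one large outlier
target = torch.zeros(3)
print(cauchy_loss(pred, target).item())                   # modest, log-scaled
print(torch.nn.functional.mse_loss(pred, target).item())  # dominated by the outlier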
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from math import exp
def gaussian(window_size, sigma):
......
import importlib
import os
import torch
from typing import Tuple, Union
from . import base
# Automatically import any Python modules in this directory
package_dir = os.path.dirname(__file__)
package = os.path.basename(package_dir)
for file in os.listdir(package_dir):
path = os.path.join(package_dir, file)
if file.startswith('_') or file.startswith('.'):
continue
if file.endswith('.py') or os.path.isdir(path):
model_name = file[:-3] if file.endswith('.py') else file
importlib.import_module(f'{package}.{model_name}')
def get_class(model_class_name: str) -> type:
return base.model_classes[model_class_name]
def create(model_class_name: str, args0: dict, **extra_args) -> base.BaseModel:
model_class = get_class(model_class_name)
return model_class(args0, extra_args)
def load(path: Union[str, os.PathLike], args0: dict = {}, **extra_args) -> Tuple[base.BaseModel, dict]:
states: dict = torch.load(path)
states['args'].update(args0)
model = create(states['model'], states['args'], **extra_args)
model.load_state_dict(states['states'])
return model, states
def save(path: Union[str, os.PathLike], model: base.BaseModel, **extra_states):
states = {
'model': model.__class__.__name__,
'args': model.args0,
'states': model.state_dict(),
**extra_states
}
torch.save(states, path)
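Since a checkpoint written by save() is an ordinary dict, it can be inspected directly; a minimal sketch (the file name is hypothetical):
import torch

states = torch.load("checkpoints/nerf_50000.pth")  # hypothetical checkpoint
print(states["model"])        # class name used by create(), e.g. "NeRF"
print(states["args"].keys())  # constructor arguments (args0)
print(len(states["states"]))  # entries of the saved state_dict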
import torch.nn as nn
from utils import color
model_classes = {}
class BaseModelMeta(type):
def __new__(cls, name, bases, attrs):
new_cls = type.__new__(cls, name, bases, attrs)
if name != 'BaseModel':
model_classes[name] = new_cls
return new_cls
class BaseModel(nn.Module, metaclass=BaseModelMeta):
trainer = "Train"
@property
def args(self):
return {**self.args0, **self.args1}
def __init__(self, args0: dict, args1: dict = {}):
super().__init__()
self.args0 = args0
self.args1 = args1
self._chns = {
"color": color.chns(color.from_str(self.args['color']))
}
def chns(self, name: str):
return self._chns.get(name, 1)
\ No newline at end of file
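The metaclass above registers every BaseModel subclass under its class name, which is what model.create() and model.load() look up; a toy sketch (the 'rgb' color string is an assumption about utils.color.from_str):
from model import base

class ToyModel(base.BaseModel):
    pass  # registered as "ToyModel" on class creation

print("ToyModel" in base.model_classes)  # True
toy = base.model_classes["ToyModel"]({"color": "rgb"})  # assumed color string
print(toy.chns("color"))                 # channel count resolved via utils.color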
import torch
import model
from .base import *
from modules import *
from utils.mem_profiler import MemProfiler
from utils.perf import perf
from utils.misc import masked_scatter
class NeRF(BaseModel):
trainer = "TrainWithSpace"
SamplerClass = Sampler
RendererClass = VolumnRenderer
def __init__(self, args0: dict, args1: dict = {}):
"""
Initialize a NeRF model
:param args0 `dict`: basic arguments
:param args1 `dict`: extra arguments, defaults to {}
"""
if "sample_step_ratio" in args0:
args1["sample_step"] = args0["voxel_size"] * args0["sample_step_ratio"]
super().__init__(args0, args1)
# Initialize components
self._init_space()
self._init_encoders()
self._init_core()
self.sampler = self.SamplerClass(**self.args)
self.rendering = self.RendererClass(**self.args)
def _init_encoders(self):
self.pot_encoder = InputEncoder.Get(self.args['n_pot_encode'],
self.args.get('n_featdim') or 3)
if self.args.get('n_dir_encode'):
self.dir_chns = 3
self.dir_encoder = InputEncoder.Get(self.args['n_dir_encode'], self.dir_chns)
else:
self.dir_chns = 0
self.dir_encoder = None
def _init_space(self):
if 'space' not in self.args:
self.space = Space(**self.args)
elif self.args['space'] == 'octree':
self.space = Octree(**self.args)
elif self.args['space'] == 'voxels':
self.space = Voxels(**self.args)
else:
self.space = model.load(self.args['space'])[0].space
if self.args.get('n_featdim'):
self.space.create_embedding(self.args['n_featdim'])
def _new_core_unit(self):
return NerfCore(coord_chns=self.pot_encoder.out_dim,
density_chns=self.chns('density'),
color_chns=self.chns('color'),
core_nf=self.args['fc_params']['nf'],
core_layers=self.args['fc_params']['n_layers'],
dir_chns=self.dir_encoder.out_dim if self.dir_encoder else 0,
dir_nf=self.args['fc_params']['nf'] // 2,
act=self.args['fc_params']['activation'],
skips=self.args['fc_params']['skips'])
def _create_core(self, n_nets=1):
return self._new_core_unit() if n_nets == 1 else nn.ModuleList([
self._new_core_unit() for _ in range(n_nets)
])
def _init_core(self):
if not self.args.get("net_bounds"):
self.core = self._create_core()
else:
self.register_buffer("net_bounds", torch.tensor(self.args["net_bounds"]), False)
self.cores = self._create_core(self.net_bounds.size(0))
def render(self, samples: Samples, *outputs: str, **kwargs) -> Dict[str, torch.Tensor]:
"""
Render colors, energies and other values (specified by `outputs`) of samples
(invalid items are filtered out)
:param samples `Samples(N)`: samples
:param outputs `str...`: which types of inferred data should be returned
:return `Dict[str, Tensor(N, *)]`: outputs of cores
"""
x = self.encode_x(samples)
d = self.encode_d(samples)
return self.infer(x, d, *outputs, pts=samples.pts, **kwargs)
def infer(self, x: torch.Tensor, d: torch.Tensor, *outputs, pts: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
"""
Infer colors, energies and other values (specified by `outputs`) of samples
(invalid items are filtered out) given their encoded positions and directions
:param x `Tensor(N, Ex)`: encoded positions
:param d `Tensor(N, Ed)`: encoded directions
:param outputs `str...`: which types of inferred data should be returned
:param pts `Tensor(N, 3)`: raw sample positions
:return `Dict[str, Tensor(N, *)]`: outputs of cores
"""
if getattr(self, "core", None):
return self.core(x, d, outputs)
ret = {}
for i, core in enumerate(self.cores):
selector = ((pts >= self.net_bounds[i, 0]) & (pts < self.net_bounds[i, 1])).all(-1)
partial_ret = core(x[selector], d[selector], outputs)
for key, value in partial_ret.items():
if value is None:
ret[key] = None
continue
if key not in ret:
ret[key] = torch.zeros(*x.shape[:-1], value.shape[-1], device=x.device)
ret[key] = masked_scatter(selector, value, ret[key])
return ret
def embed(self, samples: Samples) -> torch.Tensor:
return self.space.extract_embedding(samples.pts, samples.voxel_indices)
def encode_x(self, samples: Samples) -> torch.Tensor:
x = self.embed(samples) if self.args.get('n_featdim') else samples.pts
return self.pot_encoder(x)
def encode_d(self, samples: Samples) -> torch.Tensor:
return self.dir_encoder(samples.dirs) if self.dir_encoder is not None else None
@torch.no_grad()
def get_scores(self, sampled_points: torch.Tensor, sampled_voxel_indices: torch.Tensor) -> torch.Tensor:
densities = self.render(Samples(sampled_points, None, None, None, sampled_voxel_indices),
'density')['density']
return 1 - (-densities).exp()  # opacity: alpha = 1 - exp(-sigma)
@torch.no_grad()
def pruning(self, threshold: float = 0.5, train_stats=False):
return self.space.pruning(self.get_scores, threshold, train_stats)
@torch.no_grad()
def splitting(self):
ret = self.space.splitting()
if 'n_samples' in self.args0:
self.args0['n_samples'] *= 2
if 'voxel_size' in self.args0:
self.args0['voxel_size'] /= 2
if "sample_step_ratio" in self.args0:
self.args1["sample_step"] = self.args0["voxel_size"] \
* self.args0["sample_step_ratio"]
if 'sample_step' in self.args0:
self.args0['sample_step'] /= 2
self.sampler = self.SamplerClass(**self.args)
return ret
@torch.no_grad()
def double_samples(self):
pass
@perf
def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, *,
extra_outputs: List[str] = [], **kwargs) -> Dict[str, torch.Tensor]:
"""
Perform rendering for given rays.
:param rays_o `Tensor(N, 3)`: rays' origin
:param rays_d `Tensor(N, 3)`: rays' direction
:param extra_outputs `list[str]`: extra items should be contained in the rendering result,
defaults to []
:return `dict[str, Tensor]`: the rendering result, see corresponding Renderer implementation
"""
args = {**self.args, **kwargs}
with MemProfiler(f"{self.__class__}.forward: before sampling"):
samples, rays_mask = self.sampler(rays_o, rays_d, self.space, **args)
MemProfiler.print_memory_stats(f"{self.__class__}.forward: after sampling")
with MemProfiler(f"{self.__class__}.forward: rendering"):
if samples is None:
return None
return {
**self.rendering(self, samples, extra_outputs, **args),
'samples': samples,
'rays_mask': rays_mask
}
import torch
from modules import *
from .nerf import *
class NeRFAdvance(NeRF):
RendererClass = DensityFirstVolumnRenderer
def __init__(self, args0: dict, args1: dict = {}):
super().__init__(args0, args1)
def _new_core_unit(self):
return NerfAdvCore(
x_chns=self.pot_encoder.out_dim,
d_chns=self.dir_encoder.out_dim,
density_chns=self.chns('density'),
color_chns=self.chns('color'),
density_net_params=self.args["density_net"],
color_net_params=self.args["color_net"],
specular_net_params=self.args.get("specular_net"),
appearance=self.args.get("appearance", "decomposite"),
density_color_connection=self.args.get("density_color_connection", False)
)
def infer(self, x: torch.Tensor, d: torch.Tensor, *outputs, extras={}, **kwargs) -> Dict[str, torch.Tensor]:
"""
Infer colors, energies and other values (specified by `outputs`) of samples
(invalid items are filtered out) given their encoded positions and directions
:param x `Tensor(N, Ex)`: encoded positions
:param d `Tensor(N, Ed)`: encoded directions
:param outputs `str...`: which types of inferred data should be returned
:param extras `dict`: extra data needed by cores
:return `Dict[str, Tensor(N, *)]`: outputs of cores
"""
return self.core(x, d, outputs, **extras)
@@ -27,7 +27,7 @@ class NerfDepth(nn.Module):
color_chns=self.color_chns,
core_nf=fc_params['nf'],
core_layers=fc_params['n_layers'],
activation=fc_params['activation'],
act=fc_params['activation'],
skips=fc_params['skips'])
self.sampler = AdaptiveSampler(**sampler_params, n_bins=n_bins,
include_neighbor_bins=include_neighbor_bins)
......
from .nerf import *
from utils.geometry import *
class NSVF(NeRF):
SamplerClass = VoxelSampler
def __init__(self, args0: dict, args1: dict = {}):
"""
Initialize a NSVF model
:param args0 `dict`: basic arguments
:param args1 `dict`: extra arguments, defaults to {}
"""
super().__init__(args0, args1)
@@ -27,7 +27,7 @@ class Oracle(nn.Module):
self.net = nn.Sequential(
FcNet(in_chns=self.pos_encoder.out_dim * self.n_samples,
out_chns=0, nf=fc_params['nf'], n_layers=fc_params['n_layers'],
skips=[], activation=fc_params['activation']),
skips=[], act=fc_params['activation']),
FcLayer(fc_params['nf'], self.n_samples, out_activation)
)
......
import math
from .nerf import *
class SNeRF(NeRF):
SamplerClass = SphericalSampler
def __init__(self, args0: dict, args1: dict = {}):
"""
Initialize a multi-sphere-layer net
:param fc_params: parameters for the fully-connected network
:param sampler_params: parameters for the sampler
:param normalize_coord: whether to normalize the spherical coords to [0, 2pi] before encoding
:param c: color mode
:param encode_to_dim: number of dimensions to encode the input to
"""
sample_range = [1 / args0['depth_range'][0], 1 / args0['depth_range'][1]] \
if args0.get('depth_range') else [1, 0]
rot_range = [[-180, -90], [180, 90]]
args1['bbox'] = [
[sample_range[0], math.radians(rot_range[0][0]), math.radians(rot_range[0][1])],
[sample_range[1], math.radians(rot_range[1][0]), math.radians(rot_range[1][1])]
]
args1['sample_range'] = sample_range
super().__init__(args0, args1)
\ No newline at end of file
import math
from .nerf_advance import *
class SNeRFAdvance(NeRFAdvance):
SamplerClass = SphericalSampler
def __init__(self, args0: dict, args1: dict = {}):
"""
Initialize a multi-sphere-layer net
:param fc_params: parameters for the fully-connected network
:param sampler_params: parameters for the sampler
:param normalize_coord: whether to normalize the spherical coords to [0, 2pi] before encoding
:param c: color mode
:param encode_to_dim: number of dimensions to encode the input to
"""
sample_range = [1 / args0['depth_range'][0], 1 / args0['depth_range'][1]] \
if args0.get('depth_range') else [1, 0]
rot_range = [[-180, -90], [180, 90]]
args1['bbox'] = [
[sample_range[0], math.radians(rot_range[0][0]), math.radians(rot_range[0][1])],
[sample_range[1], math.radians(rot_range[1][0]), math.radians(rot_range[1][1])]
]
args1['sample_range'] = sample_range
if args0.get('multi_nets'):
n = args0['multi_nets']
step = (sample_range[1] - sample_range[0]) / n
args1['net_bounds'] = [[
[sample_range[0] + step * (i + 1), *args1['bbox'][0][1:]],
[sample_range[0] + step * i, *args1['bbox'][1][1:]]
] for i in range(n)]
super().__init__(args0, args1)
\ No newline at end of file
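A quick numeric check of the slab partition above (the depth range is invented): since sample_range is expressed in inverse depth it decreases, step is negative, and each [lower, upper] pair in net_bounds still comes out correctly ordered:
depth_range = (1, 50)                                    # hypothetical
sample_range = [1 / depth_range[0], 1 / depth_range[1]]  # [1.0, 0.02]
n = 2
step = (sample_range[1] - sample_range[0]) / n           # -0.49
net_bounds = [[sample_range[0] + step * (i + 1), sample_range[0] + step * i]
              for i in range(n)]
print(net_bounds)  # approx [[0.51, 1.0], [0.02, 0.51]]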
from utils.misc import print_and_log
from .snerf_advance import *
class SNeRFAdvanceX(SNeRFAdvance):
RendererClass = DensityFirstVolumnRenderer
def __init__(self, args0: dict, args1: dict = {}):
"""
Initialize a multi-sphere-layer net
:param fc_params: parameters for the fully-connected network
:param sampler_params: parameters for the sampler
:param normalize_coord: whether to normalize the spherical coords to [0, 2pi] before encoding
:param c: color mode
:param encode_to_dim: number of dimensions to encode the input to
"""
super().__init__(args0, args1)
def _init_core(self):
if "net_samples" not in self.args:
n_nets = self.args.get("multi_nets", 1)
k = self.args["n_samples"] // self.space.steps[0].item()
self.args0["net_samples"] = [val * k for val in self.space.balance_cut(0, n_nets)]
self.cores = self._create_core(len(self.args0["net_samples"]))
def infer(self, x: torch.Tensor, d: torch.Tensor, *outputs, chunk_id: int, extras={}, **kwargs) -> Dict[str, torch.Tensor]:
"""
Infer colors, energies and other values (specified by `outputs`) of samples
(invalid items are filtered out) given their encoded positions and directions
:param x `Tensor(N, Ex)`: encoded positions
:param d `Tensor(N, Ed)`: encoded directions
:param outputs `str...`: which types of inferred data should be returned
:param chunk_id `int`: current index of sample chunk in renderer
:param extras `dict`: extra data needed by cores
:return `Dict[str, Tensor(N, *)]`: outputs of cores
"""
return self.cores[chunk_id](x, d, outputs, **extras)
@torch.no_grad()
def get_scores(self, sampled_points: torch.Tensor, sampled_voxel_indices: torch.Tensor) -> torch.Tensor:
raise NotImplementedError()
@torch.no_grad()
def pruning(self, threshold: float = 0.5, train_stats=False):
raise NotImplementedError()
@torch.no_grad()
def splitting(self):
ret = super().splitting()
k = self.args["n_samples"] // self.space.steps[0].item()
net_samples = [val * k for val in self.space.balance_cut(0, len(self.cores))]
if len(net_samples) != len(self.cores):
print_and_log('Note: balance cut produced too few bins; keeping the original cut.')
net_samples = [val * 2 for val in self.args0["net_samples"]]
self.args0['net_samples'] = net_samples
return ret
@perf
def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, *,
extra_outputs: List[str] = [], **kwargs) -> Dict[str, torch.Tensor]:
"""
Perform rendering for given rays.
:param rays_o `Tensor(N, 3)`: rays' origin
:param rays_d `Tensor(N, 3)`: rays' direction
:param extra_outputs `list[str]`: extra items should be contained in the rendering result,
defaults to []
:return `dict[str, Tensor]`: the rendering result, see corresponding Renderer implementation
"""
return super().forward(rays_o, rays_d, extra_outputs=extra_outputs, **kwargs,
raymarching_chunk_size_or_sections=self.args["net_samples"])
@@ -48,7 +48,7 @@ class SnerfFast(nn.Module):
core_layers=fc_params['n_layers'],
dir_chns=self.dir_chns_per_part,
dir_nf=fc_params['nf'] // 2,
activation=fc_params['activation'])
act=fc_params['activation'])
for _ in range(self.n_parts)
]
for i in range(self.n_parts):
......
from utils.misc import print_and_log
from .snerf import *
class SNeRFX(SNeRF):
trainer = "TrainWithSpace"
SamplerClass = SphericalSampler
RendererClass = VolumnRenderer
def __init__(self, args0: dict, args1: dict = {}):
"""
Initialize a multi-sphere-layer net
:param fc_params: parameters for the fully-connected network
:param sampler_params: parameters for the sampler
:param normalize_coord: whether to normalize the spherical coords to [0, 2pi] before encoding
:param c: color mode
:param encode_to_dim: number of dimensions to encode the input to
"""
super().__init__(args0, args1)
def _init_core(self):
if "net_samples" not in self.args:
n_nets = self.args.get("multi_nets", 1)
k = self.args["n_samples"] // self.space.steps[0].item()
self.args0["net_samples"] = [val * k for val in self.space.balance_cut(0, n_nets)]
self.cores = self._create_core(len(self.args0["net_samples"]))
def render(self, samples: Samples, *outputs: str, chunk_id: int, **kwargs) -> Dict[str, torch.Tensor]:
"""
Infer colors, energies and other values (specified by `outputs`) of samples
(invalid items are filtered out)
:param samples `Samples(N)`: samples
:param outputs `str...`: which types of inferred data should be returned
:param chunk_id `int`: current index of sample chunk in renderer
:return `Dict[str, Tensor(N, *)]`: outputs of cores
"""
x = self.encode_x(samples)
d = self.encode_d(samples)
return self.cores[chunk_id](x, d, outputs)
@torch.no_grad()
def get_scores(self, sampled_points: torch.Tensor, sampled_voxel_indices: torch.Tensor) -> torch.Tensor:
raise NotImplementedError()
@torch.no_grad()
def pruning(self, threshold: float = 0.5, train_stats=False):
raise NotImplementedError()
@torch.no_grad()
def splitting(self):
ret = super().splitting()
k = self.args["n_samples"] // self.space.steps[0].item()
net_samples = [
val * k for val in self.space.balance_cut(0, len(self.cores))
]
if len(net_samples) != len(self.cores):
print_and_log('Note: balance cut produced too few bins; keeping the original cut.')
net_samples = [val * 2 for val in self.args0["net_samples"]]
self.args0['net_samples'] = net_samples
self.sampler = self.SamplerClass(**self.args)
return ret
@perf
def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, *,
extra_outputs: List[str] = [], **kwargs) -> Dict[str, torch.Tensor]:
"""
Perform rendering for given rays.
:param rays_o `Tensor(N, 3)`: rays' origin
:param rays_d `Tensor(N, 3)`: rays' direction
:param extra_outputs `list[str]`: extra items should be contained in the rendering result,
defaults to []
:return `dict[str, Tensor]`: the rendering result, see corresponding Renderer implementation
"""
return super().forward(rays_o, rays_d, extra_outputs=extra_outputs, **kwargs,
raymarching_chunk_size_or_sections=self.args["net_samples"])