Commit 6294701e authored by Nianchen Deng
Browse files

sync

parent 2824f796
......@@ -80,11 +80,11 @@ if __name__ == "__main__":
color=color.from_str(model.args['color']))
sys.stdout.write("Export samples...\r")
for _, rays_o, rays_d, extra in data_loader:
samples, rays_mask = model.sampler(rays_o, rays_d, model.space)
invalid_rays_o = rays_o[torch.logical_not(rays_mask)]
invalid_rays_d = rays_d[torch.logical_not(rays_mask)]
rays_o = rays_o[rays_mask]
rays_d = rays_d[rays_mask]
samples, rays_filter = model.sampler(rays_o, rays_d, model.space)
invalid_rays_o = rays_o[torch.logical_not(rays_filter)]
invalid_rays_d = rays_d[torch.logical_not(rays_filter)]
rays_o = rays_o[rays_filter]
rays_d = rays_d[rays_filter]
break
print("Export samples...Done")
......
import torch
from torch import Tensor
from typing import List, Dict, Union, Optional, Any
from modules import *
from utils.type import InputData, NetInput, NetOutput, ReturnData
from utils.perf import perf
from utils.samples import Samples
import importlib
import os
import torch
from typing import Tuple, Union
from . import base
from .utils import *
# Automatically import model files in this directory so that every model class
# registers itself (via BaseModelMeta) in `base.model_classes`.
package_dir = os.path.dirname(__file__)
package = os.path.basename(package_dir)
for file in os.listdir(package_dir):
    path = os.path.join(package_dir, file)
    # Skip private/hidden entries and the helper module, which defines no model
    if file.startswith('_') or file.startswith('.') or file == "utils.py":
        continue
    if file.endswith('.py') or os.path.isdir(path):
        model_name = file[:-3] if file.endswith('.py') else file
        importlib.import_module(f'{package}.{model_name}')
def get_class(model_class_name: str) -> type:
    """Look up a registered model class by its name."""
    registry = base.model_classes
    return registry[model_class_name]
def create(model_class_name: str, args0: dict, **extra_args) -> base.BaseModel:
    """Instantiate the model class registered under `model_class_name`."""
    model_class = get_class(model_class_name)
    instance = model_class(args0, extra_args)
    return instance
def load(path: Union[str, os.PathLike], args0: dict = None, **extra_args) -> Tuple[base.BaseModel, dict]:
    """
    Load a model checkpoint saved by `save`.

    :param path: checkpoint file path
    :param args0: overrides applied on top of the stored basic arguments
        (defaults to None; was a shared mutable `{}` default)
    :param extra_args: extra arguments forwarded to the model constructor
    :return: the restored model and the full checkpoint dictionary
    """
    states: dict = torch.load(path)
    states['args'].update(args0 or {})
    model = create(states['model'], states['args'], **extra_args)
    model.load_state_dict(states['states'])
    return model, states
def save(path: Union[str, os.PathLike], model: base.BaseModel, **extra_states):
    """
    Save a model checkpoint (class name, construction args and weights) plus
    any `extra_states` to `path` via `torch.save`.

    :param path: destination file path
    :param model: model to serialize; its class name is stored so `load`
        can re-create it through the registry
    :param extra_states: additional entries merged into the checkpoint
    """
    # `checkpoint` instead of `dict`: do not shadow the builtin
    checkpoint = {
        'model': model.__class__.__name__,
        'args': model.args0,
        'states': model.state_dict(),
        **extra_states
    }
    torch.save(checkpoint, path)
import torch.nn as nn
import json
from typing import Optional
from torch import Tensor
from utils import color
from utils.misc import print_and_log
from utils.samples import Samples
from utils.module import Module
from utils.type import NetInput, NetOutput, InputData, ReturnData
from utils.perf import perf
model_classes = {}
......@@ -14,21 +22,82 @@ class BaseModelMeta(type):
return new_cls
class BaseModel(nn.Module, metaclass=BaseModelMeta):
TrainerClass = "Train"
class BaseModel(Module, metaclass=BaseModelMeta):
@property
def args(self):
    """Merged arguments: extra arguments (`args1`) override basic ones (`args0`)."""
    merged = dict(self.args0)
    merged.update(self.args1)
    return merged
def __init__(self, args0: dict, args1: dict = {}):
@property
def color(self) -> int:
    """Color mode declared in the arguments; falls back to RGB."""
    default_mode = color.RGB
    return self.args.get("color", default_mode)
def __init__(self, args0: dict, args1: dict = None):
    """
    Initialize the model with basic and extra arguments.

    :param args0 `dict`: basic (persisted) arguments
    :param args1 `dict`: extra (derived) arguments, defaults to None (empty)
    """
    super().__init__()
    self.args0 = args0
    # `or {}` guards against the shared-mutable-default pitfall
    self.args1 = args1 or {}
    self._preprocess_args()
    self._init_chns()
def chns(self, name: str, value: int = None) -> Optional[int]:
    """
    Get or set the channel count registered under `name`.

    :param name `str`: channel entry name (e.g. "color", "density")
    :param value `int`: when given, store it as the new count and return None
    :return `int` or None: the stored count (1 for unknown names) in get mode
    """
    if value is not None:
        self._chns[name] = value
    else:
        return self._chns.get(name, 1)
def input(self, samples: Samples, *whats: str) -> NetInput:
    """
    Build a `NetInput` holding the requested fields of `samples`.

    :param samples `Samples(N)`: samples to read from
    :param whats `str...`: fields to include ("x", "d", "f"); all when omitted
    :return `NetInput(N)`: network input with the selected fields
    """
    all_keys = ["x", "d", "f"]  # renamed from `all`: do not shadow the builtin
    whats = whats or all_keys
    return NetInput(**{
        key: self._input(samples, key)
        for key in all_keys if key in whats
    })
def infer(self, *outputs, samples: Samples, inputs: NetInput = None, **kwargs) -> NetOutput:
    """
    Infer colors, energies or other values (specified by `outputs`) of samples
    (invalid items are filtered out) given their encoded positions and directions.

    Abstract: subclasses must override this method.

    :param outputs `str...`: which types of inferred data should be returned
    :param samples `Samples(N)`: samples
    :param inputs `NetInput(N)`: (optional) inputs to net
    :return `NetOutput`: data inferred by core net
    :raise NotImplementedError: always, in this base implementation
    """
    raise NotImplementedError()
@perf
def forward(self, data: InputData, *outputs: str, **extra_args) -> ReturnData:
    """
    Perform rendering for given rays.

    :param data `InputData`: input data
    :param outputs `str...`: items should be contained in the rendering result
    :param extra_args `{str:*}`: extra arguments for this forward process
    :return `ReturnData`: the rendering result, see corresponding Renderer implementation
    """
    samples = self._sample(data, **extra_args)  # (N, P)
    result = {"rays_filter": samples.filter_rays()}
    result.update(self._render(samples, *outputs, **extra_args))
    return result
def print_config(self):
    """Print and log this model's merged arguments as a JSON string."""
    serialized = json.dumps(self.args)
    print_and_log(serialized)
def _preprocess_args(self):
    """Hook for subclasses to derive extra arguments before channel setup; no-op by default."""
    pass
def _init_chns(self, **chns):
    """Initialize the channel table; derives the color channel count when "color" is configured."""
    channels = {}
    if "color" in self.args:
        channels["color"] = color.chns(self.color)
    channels.update(chns)
    self._chns = channels
def _input(self, samples: Samples, what: str) -> Optional[Tensor]:
    """Return the raw network input of kind `what` ("x", "d" or "f") for `samples`; subclasses must override."""
    raise NotImplementedError()
def _sample(self, data: InputData, **extra_args) -> Samples:
    """Generate samples along the rays described by `data`; subclasses must override."""
    raise NotImplementedError()
def _render(self, samples: Samples, *outputs: str, **extra_args) -> ReturnData:
    """Render the requested `outputs` from `samples`; subclasses must override."""
    raise NotImplementedError()
from utils.misc import dump_tensors_to_csv
from .__common__ import *
from .base import BaseModel
from typing import Callable
from .nerf import NeRF
from utils.voxels import trilinear_interp
class CNeRF(BaseModel):
TrainerClass = "TrainMultiScale"
class InterpSpace(object):
    """Corner features of selected voxels, evaluated once and then trilinearly interpolated."""

    def __init__(self, space: Voxels, vidxs: Tensor, feats_fn: Callable[[Any], Tensor]) -> None:
        """
        :param space `Voxels`: voxel space to interpolate in
        :param vidxs `Tensor`: indices of the voxels to prepare
        :param feats_fn: callback evaluating features at the corner positions
        """
        super().__init__()
        self.space = space
        # Corner indices and corner positions of the selected voxels
        self.corner_indices, self.corners = space.get_corners(vidxs)
        # Features evaluated once per corner; reused for every sample in `interp`
        self.feats_on_corners = feats_fn(self.corners)

    @perf
    def interp(self, samples: Samples) -> Tensor:
        """Trilinearly interpolate the prepared corner features at `samples.pts`."""
        with perf("Prepare for coarse interpolation"):
            voxels = self.space.voxels[samples.interp_vidxs]
            cidxs = self.corner_indices[samples.interp_vidxs]  # (N, 8)
            feats_on_corners = self.feats_on_corners[cidxs]  # (N, 8, X)
        # (N, 3) normed-coords in voxel
        p = (samples.pts - voxels) / self.space.voxel_size + .5
        with perf("Interpolate features"):
            return trilinear_interp(p, feats_on_corners)
@property
def stage(self):
return self.args.get("stage", 0)
def __init__(self, args0: dict, args1: dict = None):
super().__init__(args0, args1)
self.sub_models = []
args0_for_submodel = {
key: value for key, value in args0.items()
if key != "sub_models" and key != "interp_on_coarse"
}
for i in range(len(self.args["sub_models"])):
self.args["sub_models"][i] = {
**args0_for_submodel,
**self.args["sub_models"][i]
}
self.sub_models.append(NeRF(self.args["sub_models"][i], args1))
self.sub_models = torch.nn.ModuleList(self.sub_models)
for i in range(self.stage):
print(f"__init__: freeze model {i}")
self.model(i).freeze()
def model(self, level: int) -> NeRF:
return self.sub_models[level]
def trigger_stage(self, stage: int):
print(f"trigger_stage: freeze model {stage - 1}")
self.model(stage - 1).freeze()
self.model(stage).space = self.model(stage - 1).space.clone()
self.args0["stage"] = stage
@perf
def infer(self, *outputs: str, samples: Samples, inputs: NetInput = None, **kwargs) -> NetOutput:
inputs = inputs or self.input(samples)
return self.model(samples.level).infer(*outputs, samples=samples, inputs=inputs, **kwargs)
def print_config(self):
for i, model in enumerate(self.sub_models):
print(f"Model {i} =====>")
model.print_config()
@torch.no_grad()
def split(self):
return self.model(self.stage).split()
def _input(self, samples: Samples, what: str) -> Optional[Tensor]:
if what == "f":
if samples.level == 0:
return None
if samples.interp_space is None:
return self._infer_features(pts=samples.pts, level=samples.level - 1)
return samples.interp_space.interp(samples)
else:
return self.model(samples.level)._input(samples, what)
@perf
def _sample(self, data: InputData, **extra_args) -> Samples:
samples: Samples = self.model(data["level"])._sample(data, **extra_args)
samples.level = data["level"]
# TODO remove below
#dump_tensors_to_csv(f"/home/dengnc/dvs/data/classroom/_nets/ms_train_t0.8/_cnerf_ioc/{'train' if self.training else 'test'}.csv",
# samples.voxel_indices, data["rays_d"])
return samples
@perf
def _render(self, samples: Samples, *outputs: str, **extra_args) -> ReturnData:
self._prepare_interp(samples, on_coarse=self.args.get("interp_on_coarse"))
return self.model(samples.level).renderer(self, samples, *outputs, **{
**self.model(samples.level).args, **extra_args})
def _infer_features(self, samples: Samples = None, **sample_data) -> NetOutput:
samples = samples or Samples(**sample_data)
if self.args.get("interp_on_coarse"):
self._prepare_interp(samples, on_coarse=True)
inputs = self.input(samples, "x", "f")
return self.infer("features", samples=samples, inputs=inputs)["features"]
def _prepare_interp(self, samples: Samples, on_coarse: bool):
if samples.level == 0:
return
if on_coarse:
interp_space = self.model(samples.level - 1).space
samples.interp_vidxs = interp_space.get_voxel_indices(samples.pts)
else:
interp_space = self.model(samples.level).space
samples.interp_vidxs = samples.voxel_indices
samples.interp_space = CNeRF.InterpSpace(interp_space, samples.interp_vidxs,
lambda corners: self._infer_features(
pts=corners, level=samples.level - 1))
def _after_load_state_dict(self) -> None:
a: torch.Tensor = None
return
print(list(self.model(0).named_parameters())[2])
exit()
import torch
from .__common__ import *
from .nerf import NeRF
class MNeRF(NeRF):
    """
    Multi-scale NeRF
    """
    TrainerClass = "TrainMultiScale"

    def freeze(self, level: int):
        # Freeze the given level in every core network
        for core in self.cores:
            core.set_frozen(level, True)

    def _create_core_unit(self):
        # Chain one core per entry of 'core_params': each consumes the encoded
        # position plus the previous core's 'nf' feature channels
        nets = []
        in_chns = self.x_encoder.out_dim
        for core_params in self.args['core_params']:
            nets.append(super()._create_core_unit(core_params, x_chns=in_chns))
            in_chns = self.x_encoder.out_dim + core_params['nf']
        return MultiNerf(nets)

    @perf
    def _sample(self, data: InputData, **extra_args) -> Samples:
        # Attach the requested detail level to the samples from the base sampler
        samples = super()._sample(data, **extra_args)
        samples.level = data["level"]
        return samples
from .__common__ import *
from .mnerf import MNeRF
from utils.misc import merge
class MNeRFAdvance(MNeRF):
    """
    Advanced Multi-scale NeRF
    """
    TrainerClass = "TrainMultiScale"

    def n_samples(self, level: int = -1) -> int:
        # Configured samples-per-ray for `level` (last level by default)
        return self.args["n_samples_list"][level]

    def split(self):
        # Doubling the voxel resolution doubles every level's sample count
        self.args0["n_samples_list"] = [val * 2 for val in self.args["n_samples_list"]]
        return super().split()

    def _sample(self, data: InputData, **extra_args) -> Samples:
        # NOTE(review): the "+ 1" presumably adds a closing boundary sample — confirm
        return super()._sample(data, **merge(extra_args, n_samples=self.n_samples(data["level"]) + 1))

    def _render(self, samples: Samples, *outputs: str, **extra_args) -> ReturnData:
        # Coarse-to-fine rendering: start from a strided subset at level 0, then
        # repeatedly render features and interpolate them onto the denser sample
        # set of the next level until the target level L is reached.
        L = samples.level
        steps = self.args["n_samples_list"][L] // self.args["n_samples_list"][0]
        curr_samples = samples[:, ::steps]
        curr_samples.level = 0
        curr_samples.features = None
        for i in range(L):
            render_out = super()._render(curr_samples, 'features', **extra_args)
            next_steps = self.args["n_samples_list"][L] // self.args["n_samples_list"][i + 1]
            next_samples = samples[:, ::next_steps]
            features = curr_samples.interpolate(next_samples, render_out['features'])
            curr_samples = next_samples
            curr_samples.level = i + 1
            curr_samples.features = features
        return super()._render(curr_samples, *outputs, **extra_args)
import torch
from .__common__ import *
from .base import BaseModel
from operator import itemgetter
import model
from .base import *
from modules import *
from utils.mem_profiler import MemProfiler
from utils.perf import perf
from utils.misc import masked_scatter
from utils import math
from utils.misc import masked_scatter, merge
class NeRF(BaseModel):
TrainerClass = "TrainWithSpace"
SamplerClass = Sampler
RendererClass = VolumnRenderer
SamplerClass = None
RendererClass = None
space: Union[Space, Voxels, Octree]
def __init__(self, args0: dict, args1: dict = {}):
@property
def multi_nets(self) -> int:
return self.args.get("multi_nets", 1)
def __init__(self, args0: dict, args1: dict = None):
"""
Initialize a NeRF model
:param args0 `dict`: basic arguments
:param args1 `dict`: extra arguments, defaults to {}
"""
if "sample_step_ratio" in args0:
args1["sample_step"] = args0["voxel_size"] * args0["sample_step_ratio"]
super().__init__(args0, args1)
# Initialize components
self._init_space()
self._init_encoders()
self._init_core()
self.sampler = self.SamplerClass(**self.args)
self.rendering = self.RendererClass(**self.args)
self._init_sampler()
self._init_renderer()
def _init_encoders(self):
self.pot_encoder = InputEncoder.Get(self.args['n_pot_encode'],
self.args.get('n_featdim') or 3)
if self.args.get('n_dir_encode'):
self.dir_chns = 3
self.dir_encoder = InputEncoder.Get(self.args['n_dir_encode'], self.dir_chns)
@perf
def infer(self, *outputs: str, samples: Samples, inputs: NetInput = None, **kwargs) -> NetOutput:
inputs = inputs or self.input(samples)
if len(self.cores) == 1:
return self.cores[0](inputs, *outputs, samples=samples, **kwargs)
return self._multi_infer(inputs, *outputs, samples=samples, **kwargs)
@torch.no_grad()
def split(self):
    """Split the space's voxels and update the sampling arguments accordingly."""
    ret = self.space.split()
    # Halving the voxel size requires twice as many samples per ray
    if 'n_samples' in self.args0:
        self.args0['n_samples'] *= 2
    if 'voxel_size' in self.args0:
        self.args0['voxel_size'] /= 2
    if "sample_step_ratio" in self.args0:
        # Keep the derived sample step consistent with the new voxel size
        self.args1["sample_step"] = self.args0["voxel_size"] \
            * self.args0["sample_step_ratio"]
    if 'sample_step' in self.args0:
        self.args0['sample_step'] /= 2
    return ret
def _preprocess_args(self):
if "sample_step_ratio" in self.args0:
self.args1["sample_step"] = self.args0["voxel_size"] * self.args0["sample_step_ratio"]
if self.args0.get("spherical"):
sample_range = [1 / self.args0['depth_range'][0], 1 / self.args0['depth_range'][1]] \
if self.args0.get('depth_range') else [1, 0]
rot_range = [[-180, -90], [180, 90]]
self.args1['bbox'] = [
[sample_range[0], math.radians(rot_range[0][0]), math.radians(rot_range[0][1])],
[sample_range[1], math.radians(rot_range[1][0]), math.radians(rot_range[1][1])]
]
self.args1['sample_range'] = sample_range
if not self.args.get("net_bounds"):
self.register_temp("net_bounds", None)
self.args1["multi_nets"] = 1
else:
self.dir_chns = 0
self.dir_encoder = None
self.register_temp("net_bounds", torch.tensor(self.args["net_bounds"]))
self.args1["multi_nets"] = self.net_bounds.size(0)
def _init_chns(self, **chns):
super()._init_chns(**{
"x": self.args.get('n_featdim') or 3,
"d": 3 if self.args.get('encode_d') else 0,
**chns
})
def _init_space(self):
if 'space' not in self.args:
self.space = Space(**self.args)
elif self.args['space'] == 'octree':
self.space = Octree(**self.args)
elif self.args['space'] == 'voxels':
self.space = Voxels(**self.args)
else:
self.space = model.load(self.args['space'])[0].space
self.space = Space.create(self.args)
if self.args.get('n_featdim'):
self.space.create_embedding(self.args['n_featdim'])
def _new_core_unit(self):
return NerfCore(coord_chns=self.pot_encoder.out_dim,
density_chns=self.chns('density'),
color_chns=self.chns('color'),
core_nf=self.args['fc_params']['nf'],
core_layers=self.args['fc_params']['n_layers'],
dir_chns=self.dir_encoder.out_dim if self.dir_encoder else 0,
dir_nf=self.args['fc_params']['nf'] // 2,
act=self.args['fc_params']['activation'],
skips=self.args['fc_params']['skips'])
def _create_core(self, n_nets=1):
return self._new_core_unit() if n_nets == 1 else nn.ModuleList([
self._new_core_unit() for _ in range(n_nets)
])
def _init_encoders(self):
self.x_encoder = InputEncoder(self.chns("x"), self.args['encode_x'], cat_input=True)
self.d_encoder = InputEncoder(self.chns("d"), self.args['encode_d'])\
if self.chns("d") > 0 else None
def _init_core(self):
if not self.args.get("net_bounds"):
self.core = self._create_core()
self.cores = self.create_multiple(self._create_core_unit, self.args.get("multi_nets", 1))
def _init_sampler(self):
if self.SamplerClass is None:
SamplerClass = Sampler
else:
self.register_buffer("net_bounds", torch.tensor(self.args["net_bounds"]), False)
self.cores = self._create_core(self.net_bounds.size(0))
SamplerClass = self.SamplerClass
self.sampler = SamplerClass(**self.args)
def render(self, samples: Samples, *outputs: str, **kwargs) -> Dict[str, torch.Tensor]:
"""
Render colors, energies and other values (specified by `outputs`) of samples
(invalid items are filtered out)
def _init_renderer(self):
if self.RendererClass is None:
if self.args.get("core") == "nerfadv":
RendererClass = DensityFirstVolumnRenderer
else:
RendererClass = VolumnRenderer
else:
RendererClass = self.RendererClass
self.renderer = RendererClass(**self.args)
def _create_core_unit(self, core_params: dict = None, **args):
core_params = core_params or self.args["core_params"]
if self.args.get("core") == "nerfadv":
return NerfAdvCore(**{
"x_chns": self.x_encoder.out_dim,
"d_chns": self.d_encoder.out_dim,
"density_chns": self.chns('density'),
"color_chns": self.chns('color'),
**core_params,
**args
})
else:
return NerfCore(**{
"x_chns": self.x_encoder.out_dim,
"density_chns": self.chns('density'),
"color_chns": self.chns('color'),
"d_chns": self.d_encoder.out_dim if self.d_encoder else 0,
**core_params,
**args
})
:param samples `Samples(N)`: samples
:param outputs `str...`: which types of inferred data should be returned
:return `Dict[str, Tensor(N, *)]`: outputs of cores
"""
x = self.encode_x(samples)
d = self.encode_d(samples)
return self.infer(x, d, *outputs, pts=samples.pts, **kwargs)
@perf
def _sample(self, data: InputData, **extra_args) -> Samples:
return self.sampler(*itemgetter("rays_o", "rays_d")(data), self.space,
**merge(self.args, extra_args))
def infer(self, x: torch.Tensor, d: torch.Tensor, *outputs, pts: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
"""
Infer colors, energies and other values (specified by `outputs`) of samples
(invalid items are filtered out) given their encoded positions and directions
:param x `Tensor(N, Ex)`: encoded positions
:param d `Tensor(N, Ed)`: encoded directions
:param outputs `str...`: which types of inferred data should be returned
:param pts `Tensor(N, 3)`: raw sample positions
:return `Dict[str, Tensor(N, *)]`: outputs of cores
"""
if getattr(self, "core", None):
return self.core(x, d, outputs)
ret = {}
@perf
def _render(self, samples: Samples, *outputs: str, **extra_args) -> ReturnData:
if len(samples.size) == 1:
return self.infer(*outputs, samples=samples)
return self.renderer(self, samples, *outputs, **merge(self.args, extra_args))
def _input(self, samples: Samples, what: str) -> Optional[torch.Tensor]:
    """
    Build the raw network input of kind `what` from `samples`.

    :param samples `Samples(N)`: samples
    :param what `str`: "x" (position/embedding), "d" (direction) or "f" (feature)
    :return `Tensor(N, *)` or None when the field is unavailable
    :raise ValueError: for an unrecognized `what`
    """
    if what == "x":
        if self.args.get('n_featdim'):
            # Encode the per-voxel embedding instead of the raw position
            return self._encode("emb", self.space.extract_embedding(
                samples.pts, samples.voxel_indices))
        else:
            return self._encode("x", samples.pts)
    elif what == "d":
        if self.d_encoder and samples.dirs is not None:
            return self._encode("d", samples.dirs)
        else:
            return None
    elif what == "f":
        # Plain NeRF has no extra feature input
        return None
    else:
        # Bug fix: the original built `ValueError(...)` without raising it,
        # silently returning None for unknown inputs
        raise ValueError(f"Don't know how to process input \"{what}\"")
def _encode(self, what: str, val: torch.Tensor) -> torch.Tensor:
    """
    Encode raw values of kind `what` for the core networks.

    :param what `str`: "x" (position), "emb" (voxel embedding) or "d" (direction)
    :param val `Tensor(N, *)`: raw values
    :return `Tensor(N, E)`: encoded values
    :raise ValueError: for an unrecognized `what`
    """
    if what == "x":
        if self.args.get("spherical"):
            sr = self.args['sample_range']
            # scale val.r: [sr[0], sr[1]] -> [-PI/2, PI/2]
            val = val.clone()
            val[..., 0] = ((val[..., 0] - sr[0]) / (sr[1] - sr[0]) - .5) * math.pi
            return self.x_encoder(val)
        else:
            return self.x_encoder(val * math.pi)
    elif what == "emb":
        return self.x_encoder(val * math.pi)
    elif what == "d":
        return self.d_encoder(val, angular=True)
    else:
        # Bug fix: the original built `ValueError(...)` without raising it,
        # silently returning None for unknown kinds
        raise ValueError(f"Don't know how to encode \"{what}\"")
@perf
def _multi_infer(self, inputs: NetInput, *outputs: str, samples: Samples, **kwargs) -> NetOutput:
    """
    Run each core on the samples that fall inside its bounding box and
    scatter the partial results back into full-size output tensors.

    :param inputs `NetInput(N)`: encoded inputs
    :param outputs `str...`: which types of inferred data should be returned
    :param samples `Samples(N)`: samples (positions select the responsible core)
    :return `NetOutput`: merged outputs of all cores
    """
    ret: NetOutput = {}
    for i, core in enumerate(self.cores):
        # Bug fix: elementwise `&` instead of Python `and`, which is invalid
        # on multi-element tensors; also dropped diff residue referencing
        # undefined `pts`/`x`/`d`
        selector = ((samples.pts >= self.net_bounds[i, 0])
                    & (samples.pts < self.net_bounds[i, 1])).all(-1)
        partial_ret: NetOutput = core(inputs[selector], *outputs, samples=samples[selector],
                                      **kwargs)
        for key, value in partial_ret.items():
            if value is None:
                ret[key] = None
                continue
            if key not in ret:
                ret[key] = value.new_zeros(*inputs.shape, value.shape[-1])
            ret[key] = masked_scatter(selector, value, ret[key])
    return ret
def embed(self, samples: Samples) -> torch.Tensor:
return self.space.extract_embedding(samples.pts, samples.voxel_indices)
def encode_x(self, samples: Samples) -> torch.Tensor:
x = self.embed(samples) if self.args.get('n_featdim') else samples.pts
return self.pot_encoder(x)
def encode_d(self, samples: Samples) -> torch.Tensor:
return self.dir_encoder(samples.dirs) if self.dir_encoder else None
class NSVF(NeRF):
@torch.no_grad()
def split(self):
ret = self.space.split()
if 'n_samples' in self.args0:
self.args0['n_samples'] *= 2
if 'voxel_size' in self.args0:
self.args0['voxel_size'] /= 2
if "sample_step_ratio" in self.args0:
self.args1["sample_step"] = self.args0["voxel_size"] \
* self.args0["sample_step_ratio"]
if 'sample_step' in self.args0:
self.args0['sample_step'] /= 2
self.sampler = self.SamplerClass(**self.args)
if self.args.get('n_featdim') and hasattr(self, "trainer"):
self.trainer.reset_optimizer()
return ret
@perf
def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, *,
extra_outputs: List[str] = [], **kwargs) -> torch.Tensor:
"""
Perform rendering for given rays.
:param rays_o `Tensor(N, 3)`: rays' origin
:param rays_d `Tensor(N, 3)`: rays' direction
:param extra_outputs `list[str]`: extra items should be contained in the rendering result,
defaults to []
:return `dict[str, Tensor]`: the rendering result, see corresponding Renderer implementation
"""
args = {**self.args, **kwargs}
with MemProfiler(f"{self.__class__}.forward: before sampling"):
samples, rays_mask = self.sampler(rays_o, rays_d, self.space, **args)
MemProfiler.print_memory_stats(f"{self.__class__}.forward: after sampling")
with MemProfiler(f"{self.__class__}.forward: rendering"):
if samples is None:
return None
return {
**self.rendering(self, samples, extra_outputs, **args),
'samples': samples,
'rays_mask': rays_mask
}
SamplerClass = VoxelSampler
import torch
from modules import *
from .nerf import *
class NeRFAdvance(NeRF):
    """NeRF variant using NerfAdvCore cores and the DensityFirstVolumnRenderer."""
    RendererClass = DensityFirstVolumnRenderer

    def __init__(self, args0: dict, args1: dict = None):
        """
        Initialize a NeRFAdvance model.

        :param args0 `dict`: basic arguments
        :param args1 `dict`: extra arguments, defaults to None (treated as empty)
        """
        # `None` default avoids sharing one mutable dict across instances
        super().__init__(args0, args1 or {})

    def _new_core_unit(self):
        """Build a single NerfAdvCore from the configured sub-net parameters."""
        return NerfAdvCore(
            x_chns=self.pot_encoder.out_dim,
            d_chns=self.dir_encoder.out_dim,
            density_chns=self.chns('density'),
            color_chns=self.chns('color'),
            density_net_params=self.args["density_net"],
            color_net_params=self.args["color_net"],
            specular_net_params=self.args.get("specular_net"),
            appearance=self.args.get("appearance", "decomposite"),
            density_color_connection=self.args.get("density_color_connection", False)
        )

    def infer(self, x: torch.Tensor, d: torch.Tensor, *outputs, extras={}, **kwargs) -> Dict[str, torch.Tensor]:
        """
        Infer colors, energies and other values (specified by `outputs`) of samples
        (invalid items are filtered out) given their encoded positions and directions.

        :param x `Tensor(N, Ex)`: encoded positions
        :param d `Tensor(N, Ed)`: encoded directions
        :param outputs `str...`: which types of inferred data should be returned
        :param extras `dict`: extra data needed by cores (read-only here)
        :return `Dict[str, Tensor(N, *)]`: outputs of cores
        """
        return self.core(x, d, outputs, **extras)
from .nerf import *
from utils.geometry import *
class NSVF(NeRF):
    """NeRF over sparse voxels: rays are sampled with VoxelSampler."""
    SamplerClass = VoxelSampler

    def __init__(self, args0: dict, args1: dict = None):
        """
        Initialize a NSVF model

        :param args0 `dict`: basic arguments
        :param args1 `dict`: extra arguments, defaults to None (treated as empty)
        """
        # `None` default avoids the shared-mutable-default pitfall: the base
        # initializer may later write derived keys into this dict
        super().__init__(args0, args1 or {})
import math
from .nerf import *
class SNeRF(NeRF):
    """NeRF sampled in spherical coordinates via SphericalSampler."""
    SamplerClass = SphericalSampler

    def __init__(self, args0: dict, args1: dict = None):
        """
        Initialize a multi-sphere-layer net

        :param args0 `dict`: basic arguments; `depth_range` bounds the radial sampling
        :param args1 `dict`: extra arguments, defaults to None (treated as empty)
        """
        # Bug fix: the original default `args1={}` was mutated below
        # (`args1['bbox'] = ...`), leaking state across instances; copy instead
        args1 = dict(args1) if args1 else {}
        # Radial sample range in disparity (1/depth); [1, 0] when unbounded
        sample_range = [1 / args0['depth_range'][0], 1 / args0['depth_range'][1]] \
            if args0.get('depth_range') else [1, 0]
        rot_range = [[-180, -90], [180, 90]]
        args1['bbox'] = [
            [sample_range[0], math.radians(rot_range[0][0]), math.radians(rot_range[0][1])],
            [sample_range[1], math.radians(rot_range[1][0]), math.radians(rot_range[1][1])]
        ]
        args1['sample_range'] = sample_range
        super().__init__(args0, args1)
\ No newline at end of file
import math
from .nerf_advance import *
class SNeRFAdvance(NeRFAdvance):
    """Spherical-coordinate variant of NeRFAdvance, optionally split across multiple nets."""
    SamplerClass = SphericalSampler

    def __init__(self, args0: dict, args1: dict = None):
        """
        Initialize a multi-sphere-layer net

        :param args0 `dict`: basic arguments; `depth_range` bounds the radial
            sampling, `multi_nets` optionally partitions it across sub-nets
        :param args1 `dict`: extra arguments, defaults to None (treated as empty)
        """
        # Bug fix: the original default `args1={}` was mutated below, leaking
        # state across instances; copy instead
        args1 = dict(args1) if args1 else {}
        # Radial sample range in disparity (1/depth); [1, 0] when unbounded
        sample_range = [1 / args0['depth_range'][0], 1 / args0['depth_range'][1]] \
            if args0.get('depth_range') else [1, 0]
        rot_range = [[-180, -90], [180, 90]]
        args1['bbox'] = [
            [sample_range[0], math.radians(rot_range[0][0]), math.radians(rot_range[0][1])],
            [sample_range[1], math.radians(rot_range[1][0]), math.radians(rot_range[1][1])]
        ]
        args1['sample_range'] = sample_range
        if args0.get('multi_nets'):
            # Evenly partition the radial range: one bounding box per sub-net
            n = args0['multi_nets']
            step = (sample_range[1] - sample_range[0]) / n
            args1['net_bounds'] = [[
                [sample_range[0] + step * (i + 1), *args1['bbox'][0][1:]],
                [sample_range[0] + step * i, *args1['bbox'][1][1:]]
            ] for i in range(n)]
        super().__init__(args0, args1)
\ No newline at end of file
from utils.misc import print_and_log
from .snerf_advance import *
class SNeRFAdvanceX(SNeRFAdvance):
    """Multi-net SNeRFAdvance: each core handles its own chunk of samples along the ray."""
    RendererClass = DensityFirstVolumnRenderer

    def __init__(self, args0: dict, args1: dict = {}):
        """
        Initialize a multi-sphere-layer net
        :param fc_params: parameters for full-connection network
        :param sampler_params: parameters for sampler
        :param normalize_coord: whether normalize the spherical coords to [0, 2pi] before encode
        :param c: color mode
        :param encode_to_dim: encode input to number of dimensions
        """
        super().__init__(args0, args1)

    def _init_core(self):
        # Derive per-net sample counts by balance-cutting the space when not configured
        if "net_samples" not in self.args:
            n_nets = self.args.get("multi_nets", 1)
            k = self.args["n_samples"] // self.space.steps[0].item()
            self.args0["net_samples"] = [val * k for val in self.space.balance_cut(0, n_nets)]
        self.cores = self._create_core(len(self.args0["net_samples"]))

    def infer(self, x: torch.Tensor, d: torch.Tensor, *outputs, chunk_id: int, extras={}, **kwargs) -> Dict[str, torch.Tensor]:
        """
        Infer colors, energies and other values (specified by `outputs`) of samples
        (invalid items are filtered out) given their encoded positions and directions
        :param x `Tensor(N, Ex)`: encoded positions
        :param d `Tensor(N, Ed)`: encoded directions
        :param outputs `str...`: which types of inferred data should be returned
        :param chunk_id `int`: current index of sample chunk in renderer
        :param extras `dict`: extra data needed by cores
        :return `Dict[str, Tensor(N, *)]`: outputs of cores
        """
        return self.cores[chunk_id](x, d, outputs, **extras)

    @torch.no_grad()
    def split(self):
        # Re-balance the per-net sample counts after splitting the space
        ret = super().split()
        k = self.args["n_samples"] // self.space.steps[0].item()
        net_samples = [val * k for val in self.space.balance_cut(0, len(self.cores))]
        if len(net_samples) != len(self.cores):
            # Fall back to simply doubling the previous cut
            print_and_log('Note: the result of balance cut has no enough bins. Keep origin cut.')
            net_samples = [val * 2 for val in self.args0["net_samples"]]
        self.args0['net_samples'] = net_samples
        return ret

    @perf
    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, *,
                extra_outputs: List[str] = [], **kwargs) -> torch.Tensor:
        """
        Perform rendering for given rays.
        :param rays_o `Tensor(N, 3)`: rays' origin
        :param rays_d `Tensor(N, 3)`: rays' direction
        :param extra_outputs `list[str]`: extra items should be contained in the rendering result,
            defaults to []
        :return `dict[str, Tensor]`: the rendering result, see corresponding Renderer implementation
        """
        # Chunk the raymarching so each core only sees its own sample slab
        return super().forward(rays_o, rays_d, extra_outputs=extra_outputs, **kwargs,
                               raymarching_chunk_size_or_sections=self.args["net_samples"])
import torch
import torch.nn as nn
from modules import *
from utils import sphere
from utils import color
from .__common__ import *
from .nerf import NeRF
from utils.misc import merge
class SnerfFast(nn.Module):
def __init__(self, fc_params, sampler_params, *,
n_parts: int = 1,
c: int = color.RGB,
pos_encode: int = 0,
dir_encode: int = None,
spherical_dir: bool = False, **kwargs):
"""
Initialize a multi-sphere-layer net
class SnerfFast(NeRF):
:param fc_params: parameters for full-connection network
:param sampler_params: parameters for sampler
:param normalize_coord: whether normalize the spherical coords to [0, 2pi] before encode
:param c: color mode
:param encode_to_dim: encode input to number of dimensions
"""
super().__init__()
self.color = c
self.spherical_dir = spherical_dir
self.n_samples = sampler_params['n_samples']
self.n_parts = n_parts
self.samples_per_part = self.n_samples // self.n_parts
self.coord_chns = 2
self.color_chns = color.chns(self.color)
self.pos_encoder = InputEncoder.Get(pos_encode, self.coord_chns)
def infer(self, *outputs: str, samples: Samples, inputs: NetInput = None, chunk_id: int, **kwargs) -> NetOutput:
inputs = inputs or self.input(samples)
ret = self.cores[chunk_id](inputs, *outputs)
return {
key: value.reshape(*inputs.shape, -1)
for key, value in ret.items()
}
if dir_encode is not None:
self.dir_encoder = InputEncoder.Get(dir_encode, 2 if self.spherical_dir else 3)
self.dir_chns_per_part = self.dir_encoder.out_dim * \
(self.samples_per_part if self.spherical_dir else 1)
else:
self.dir_encoder = None
self.dir_chns_per_part = 0
def _preprocess_args(self):
self.args0["spherical"] = True
super()._preprocess_args()
self.samples_per_part = self.args['n_samples'] // self.multi_nets
self.nets = [
NerfCore(coord_chns=self.pos_encoder.out_dim * self.samples_per_part,
density_chns=self.samples_per_part,
color_chns=self.color_chns * self.samples_per_part,
core_nf=fc_params['nf'],
core_layers=fc_params['n_layers'],
dir_chns=self.dir_chns_per_part,
dir_nf=fc_params['nf'] // 2,
act=fc_params['activation'])
for _ in range(self.n_parts)
]
for i in range(self.n_parts):
self.add_module(f"mlp_{i:d}", self.nets[i])
sampler_params['spherical'] = True
self.sampler = Sampler(**sampler_params)
self.rendering = VolumnRenderer()
def _init_chns(self):
super()._init_chns(x=2)
def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor,
ret_depth=False, debug=False) -> torch.Tensor:
"""
rays -> colors
def _create_core_unit(self):
return super()._create_core_unit(
x_chns=self.x_encoder.out_dim * self.samples_per_part,
density_chns=self.chns('density') * self.samples_per_part,
color_chns=self.chns('color') * self.samples_per_part)
def _input(self, samples: Samples, what: str) -> Optional[torch.Tensor]:
if what == "x":
return self._encode("x", samples.pts[..., -self.chns("x"):]).flatten(1, 2)
elif what == "d":
return self._encode("d", samples.dirs[:, 0])\
if self.d_encoder and samples.dirs is not None else None
else:
return super()._input(samples, what)
:param rays_o `Tensor(B, 3)`: rays' origin
:param rays_d `Tensor(B, 3)`: rays' direction
:return: `Tensor(B, C)``, inferred images/pixels
"""
coords, depths, _, pts = self.sampler(rays_o, rays_d)
#print('NaN count: ', coords.isnan().sum().item(), depths.isnan().sum().item(), pts.isnan().sum().item())
coords_encoded = self.pos_encoder(coords[..., -self.coord_chns:])
dirs_encoded = self.dir_encoder(
sphere.calc_local_dir(rays_d, coords, pts) if self.spherical_dir else rays_d) \
if self.dir_encoder is not None else None
def _render(self, samples: Samples, *outputs: str, **extra_args) -> ReturnData:
    """
    Render `samples`, forcing the ray-marcher to process one part
    (`samples_per_part` samples) per chunk.

    :param samples `Samples(N)`: samples to render
    :param outputs `str...`: requested output entries
    :return `ReturnData`: rendering result of the base implementation
    """
    # BUG FIX: this previously called `self._render(...)`, recursing into
    # itself until RecursionError. Delegate to the base class instead,
    # mirroring SNeRFX._render.
    return super()._render(samples, *outputs,
                           **merge(extra_args,
                                   raymarching_chunk_size_or_sections=[self.samples_per_part]))
# NOTE(review): continuation of the old `forward` body (diff residue).
densities = torch.empty(rays_o.size(0), self.n_samples, device=device.default())
colors = torch.empty(rays_o.size(0), self.n_samples, self.color_chns,
                     device=device.default())
# Each sub-network predicts one contiguous slice of samples along the ray.
for i, net in enumerate(self.nets):
    s = slice(i * self.samples_per_part, (i + 1) * self.samples_per_part)
    c, d = net(coords_encoded[:, s].flatten(1, 2),
               dirs_encoded[:, s].flatten(1, 2) if self.spherical_dir else dirs_encoded)
    colors[:, s] = c.view(-1, self.samples_per_part, self.color_chns)
    densities[:, s] = d
# Volume-render per-sample colors/densities into per-ray outputs.
ret = self.rendering(colors.view(-1, self.n_samples, self.color_chns),
                     densities, depths, ret_depth=ret_depth, debug=debug)
if debug:
    # Expose raw per-sample values for inspection when debugging.
    ret['sample_densities'] = densities
    ret['sample_depths'] = depths
return ret
def _multi_infer(self, inputs: NetInput, *outputs: str, samples: Samples, chunk_id: int, **kwargs) -> NetOutput:
    """Run the chunk's core network and reshape every output to the samples' layout."""
    core_out = self.cores[chunk_id](inputs, *outputs)
    reshaped = {}
    for name, tensor in core_out.items():
        # Restore the samples' leading shape, keeping the feature axis last.
        reshaped[name] = tensor.reshape(*samples.size, -1)
    return reshaped
# NOTE(review): the next two class headers are old/new variants of the same
# declaration left by the diff view; only one should survive in the real file.
class SnerfFastExport(nn.Module):
class SnerfFastExport(torch.nn.Module):
    def __init__(self, net: SnerfFast):
        super().__init__()
        ......@@ -99,7 +59,7 @@class SnerfFastExport(nn.Module):
    def forward(self, coords_encoded, z_vals):
        colors = []
        densities = []
        # NOTE(review): the duplicated loop headers below are diff residue —
        # `n_parts` is the old attribute name, `multi_nets` the new one.
        for i in range(self.net.n_parts):
        for i in range(self.net.multi_nets):
            s = slice(i * self.net.samples_per_part, (i + 1) * self.net.samples_per_part)
            mlp = self.net.nets[i] if self.net.nets is not None else self.net.net
            c, d = mlp(coords_encoded[:, s].flatten(1, 2))
        ......
from utils.misc import print_and_log
from .snerf import *
from .__common__ import *
from .nerf import NeRF
from .utils import load
from utils.misc import merge
class SNeRFX(SNeRF):
    def __init__(self, args0: dict, args1: dict = None):
        """
        Initialize a multi-sphere-layer net.

        :param args0 `dict`: persistent model arguments (saved with the
            checkpoint), e.g. full-connection network parameters, sampler
            parameters, color mode, encoding dimensions
        :param args1 `dict`: runtime model arguments; defaults to an empty dict
        """
        # BUG FIX: the default for `args1` used to be a shared mutable `{}`.
        # `args1` is mutated later (e.g. `self.args1["multi_nets"] = ...` in
        # `_preprocess_args`), so all instances built without an explicit
        # `args1` silently shared — and corrupted — the same dict.
        super().__init__(args0, {} if args1 is None else args1)
# NOTE(review): `_init_core` and an old revision of `_preprocess_args` are
# interleaved here by the diff view — `_init_core`'s body is most likely the
# single `self.cores = ...` line at the bottom; verify against the repo.
def _init_core(self):
def _preprocess_args(self):
    self.args0["spherical"] = True
    super()._preprocess_args()
    # Derive per-network sample counts from a balanced cut of the space's
    # first dimension when they are not given explicitly.
    if "net_samples" not in self.args:
        n_nets = self.args.get("multi_nets", 1)
        k = self.args["n_samples"] // self.space.steps[0].item()
        self.args0["net_samples"] = [val * k for val in self.space.balance_cut(0, n_nets)]
    self.cores = self._create_core(len(self.args0["net_samples"]))
def render(self, samples: Samples, *outputs: str, chunk_id: int, **kwargs) -> Dict[str, torch.Tensor]:
    """
    Infer the values named by `outputs` (e.g. colors, energies) for `samples`;
    invalid items are assumed to be filtered out already.

    :param samples `Samples(N)`: samples to infer
    :param outputs `str...`: which types of inferred data should be returned
    :param chunk_id `int`: current index of the sample chunk in the renderer
    :return `Dict[str, Tensor(N, *)]`: outputs of the selected core network
    """
    core = self.cores[chunk_id]
    return core(self.encode_x(samples), self.encode_d(samples), outputs)
# NOTE(review): alternate body of `_preprocess_args` (a newer revision) left
# by the diff view — it cuts by another checkpoint's space when the "cut_by"
# argument is configured; `n_nets` is defined in the full method above.
cut_by_space = load(self.args['cut_by']).space if "cut_by" in self.args else self.space
k = self.args["n_samples"] // cut_by_space.steps[0].item()
self.args0["net_samples"] = [val * k for val in cut_by_space.balance_cut(0, n_nets)]
self.args1["multi_nets"] = len(self.args["net_samples"])
@torch.no_grad()
def split(self):
    # Split the space (handled by the base class), then recompute how many
    # samples each sub-network covers.
    ret = super().split()
    k = self.args["n_samples"] // self.space.steps[0].item()
    net_samples = [
        val * k for val in self.space.balance_cut(0, len(self.cores))
    ]
    if len(net_samples) != len(self.cores):
        print_and_log('Note: the result of balance cut has no enough bins. Keep origin cut.')
        # Fall back to doubling the previous per-network sample counts.
        net_samples = [val * 2 for val in self.args0["net_samples"]]
    self.args0['net_samples'] = net_samples
    # NOTE(review): the two lines below look like old-revision residue from
    # the diff view — they re-create the sampler and double `net_samples` a
    # second time, contradicting the assignment just above; verify.
    self.sampler = self.SamplerClass(**self.args)
    self.args0['net_samples'] = [val * 2 for val in self.args0['net_samples']]
    return ret
@perf
def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, *,
            extra_outputs: List[str] = [], **kwargs) -> torch.Tensor:
    """
    Perform rendering for given rays.

    :param rays_o `Tensor(N, 3)`: rays' origin
    :param rays_d `Tensor(N, 3)`: rays' direction
    :param extra_outputs `list[str]`: extra items that should be contained in
        the rendering result, defaults to []
    :return `dict[str, Tensor]`: the rendering result, see the corresponding
        Renderer implementation
    """
    # NOTE(review): a `_render` override (chunking by `net_samples`) was
    # interleaved into this docstring by the diff view and has been dropped
    # here; verify against the repository.
    # NOTE(review): `extra_outputs: List[str] = []` is a shared mutable
    # default — safe only if no callee mutates it; consider a None sentinel.
    # One ray-marching chunk per sub-network, sized by `net_samples`.
    return super().forward(rays_o, rays_d, extra_outputs=extra_outputs, **kwargs,
                           raymarching_chunk_size_or_sections=self.args["net_samples"])
@perf
def _multi_infer(self, inputs: NetInput, *outputs: str, chunk_id: int, **kwargs) -> NetOutput:
    """Dispatch inference for one chunk to the core network selected by `chunk_id`."""
    core = self.cores[chunk_id]
    return core(inputs, *outputs, **kwargs)
from pathlib import Path
from typing import Optional, Union
from .base import model_classes, BaseModel
from utils import netio
def get_class(model_class_name: str) -> Optional[type]:
    """Look up a registered model class by name; return None if unregistered."""
    try:
        return model_classes[model_class_name]
    except KeyError:
        return None
def deserialize(states: dict, **extra_args) -> BaseModel:
    """
    Build a model from a serialized state dictionary.

    :param states `dict`: serialized data with keys "model" (class name),
        "args" (constructor arguments) and, optionally, "states" (state dict)
    :param extra_args: additional constructor arguments (not persisted)
    :return `BaseModel`: the deserialized model
    :raise ValueError: if the model class is not registered
    """
    model_name: str = states["model"]
    model_args: dict = states["args"]
    Cls = get_class(model_name)
    if Cls is None and model_name.startswith("S"):
        # Compatibility with old "S*" class names: strip the prefix and
        # force spherical mode instead.
        Cls = get_class(model_name[1:])
        model_args["spherical"] = True
    if Cls is None:
        # BUG FIX: previously a missing class fell through to `None(...)`,
        # failing with a cryptic "'NoneType' object is not callable".
        raise ValueError(f"Unknown model class: {model_name}")
    model: BaseModel = Cls(model_args, extra_args)
    if "states" in states:
        model.load_state_dict(states['states'])
    return model
def serialize(model: BaseModel) -> dict:
    """Pack a model's class name, creation args and weights into a plain dict."""
    data = dict(model=model.cls, args=model.args0, states=model.state_dict())
    return data
def load(path: Union[str, Path]) -> BaseModel:
    """Load the checkpoint at `path` and deserialize the model stored in it."""
    states, *_ = netio.load_checkpoint(path)
    return deserialize(states)
from .__common__ import *
from .nerf import NeRF
class VNeRF(NeRF):
    """NeRF variant that augments encoded positions with learned per-voxel features."""

    def _init_chns(self):
        # Positions are full 3-D points here (no spherical reduction).
        super()._init_chns(x=3)

    def _init_space(self):
        # Build the sampling space, then attach a learnable embedding of
        # `n_featdim` features to each of its voxels.
        space = Space.create(self.args)
        space.create_voxel_embedding(self.args['n_featdim'])
        self.space = space

    def _create_core_unit(self):
        # Core input width = encoded position + voxel feature vector.
        in_chns = self.x_encoder.out_dim + self.args['n_featdim']
        return super()._create_core_unit(x_chns=in_chns)

    def _input(self, samples: Samples, what: str) -> Optional[torch.Tensor]:
        """Return the network input `what`; "x" concatenates voxel features with encoded points."""
        if what != "x":
            return super()._input(samples, what)
        voxel_feats = self.space.extract_voxel_embedding(samples.voxel_indices)
        encoded_pts = self._encode("x", samples.pts)
        return torch.cat([voxel_feats, encoded_pts], dim=-1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment