from pathlib import Path
from operator import itemgetter

from modules.input_encoder import FreqEncoder
from .__common__ import *
from .base import BaseModel
from utils import math
from utils.misc import masked_scatter


class NeRF(BaseModel):
    """
    NeRF model composed of a space, input encoders, one or more core networks,
    a sampler and a renderer.

    Subclasses may override `SamplerClass` / `RendererClass` to swap the sampler
    or renderer; when left as `None` a default is chosen at initialization.
    """

    TrainerClass = "TrainWithSpace"
    SamplerClass = None
    RendererClass = None

    space: Space | Voxels | Octree

    @property
    def multi_nets(self) -> int:
        """Number of core networks configured in the arguments (1 if not specified)."""
        return self.args.get("multi_nets", 1)

    def __init__(self, args0: dict, args1: dict = None):
        """
        Initialize a NeRF model

        :param args0 `dict`: basic arguments
        :param args1 `dict`: extra arguments, defaults to `None`
        """
        super().__init__(args0, args1)

        # Initialize components
        self._init_space()
        self._init_encoders()
        self._init_core()
        self._init_sampler()
        self._init_renderer()

    @profile
    def infer(self, *outputs: str, samples: Samples, inputs: NetInput = None, **kwargs) -> NetOutput:
        """
        Run the core network(s) on the given samples.

        :param outputs `str...`: names of the requested outputs, forwarded to the core(s)
        :param samples `Samples`: samples to infer
        :param inputs `NetInput`: prepared network input; built by `self.input(samples)` when falsy
        :return `NetOutput`: output of the single core, or the merged output of all cores
        """
        inputs = inputs or self.input(samples)
        if len(self.cores) == 1:
            return self.cores[0](inputs, *outputs, samples=samples, **kwargs)
        return self._multi_infer(inputs, *outputs, samples=samples, **kwargs)

    @torch.no_grad()
    def split(self):
        """
        Split the space and rescale the sampling-related arguments accordingly:
        `n_samples` is doubled while `voxel_size` and `sample_step` are halved.

        :return: whatever `self.space.split()` returns
        """
        ret = self.space.split()
        if 'n_samples' in self.args0:
            self.args0['n_samples'] *= 2
        if 'voxel_size' in self.args0:
            self.args0['voxel_size'] /= 2
            # Keep the derived sample step in sync with the new voxel size
            if "sample_step_ratio" in self.args0:
                self.args1["sample_step"] = self.args0["voxel_size"] \
                    * self.args0["sample_step_ratio"]
        if 'sample_step' in self.args0:
            self.args0['sample_step'] /= 2
        return ret

    def export_onnx(self, path: str | Path, batch_size: int = None):
        """
        Export the first core network to `<path>/core_0.onnx`.

        :param path `str | Path`: output directory
        :param batch_size `int`: forwarded to the exporter, defaults to `None`
        """
        # `path` may be a plain string, which does not support the `/` operator
        self.cores[0].get_exporter().export_onnx(Path(path) / "core_0.onnx", batch_size)

    def _preprocess_args(self):
        """
        Derive extra arguments (`args1`) from the basic arguments (`args0`):
        `sample_step`, the spherical `bbox`/`sample_range` and `multi_nets`/`net_bounds`.
        """
        if "sample_step_ratio" in self.args0:
            self.args1["sample_step"] = self.args0["voxel_size"] * self.args0["sample_step_ratio"]
        if self.args0.get("spherical"):
            # Sample range is the reciprocal of the depth range; [1, 0] when no depth range is given
            sample_range = [
                1 / self.args0['depth_range'][0],
                1 / self.args0['depth_range'][1]
            ] if 'depth_range' in self.args0 else [1, 0]
            rot_range = [[-180, -90], [180, 90]]
            self.args1['bbox'] = [
                [sample_range[0], math.radians(rot_range[0][0]), math.radians(rot_range[0][1])],
                [sample_range[1], math.radians(rot_range[1][0]), math.radians(rot_range[1][1])]
            ]
            self.args1['sample_range'] = sample_range
        if not self.args.get("multi_nets"):
            if not self.args.get("net_bounds"):
                self.register_temp("net_bounds", None)
                self.args1["multi_nets"] = 1
            else:
                # One network per entry of `net_bounds`
                self.register_temp("net_bounds", torch.tensor(self.args["net_bounds"]))
                self.args1["multi_nets"] = self.net_bounds.size(0)

    def _init_chns(self, **chns):
        """
        Register channel counts: "x" is `n_featdim` (or 3 when unset/zero) and "d" is 3
        only when `encode_d` is set. Entries in `chns` override these defaults.
        """
        super()._init_chns(**{
            "x": self.args.get('n_featdim') or 3,
            "d": 3 if self.args.get('encode_d') else 0,
            **chns
        })

    def _init_space(self):
        """Create the space and, when `n_featdim` is set, its embedding."""
        self.space = Space.create(self.args)
        if self.args.get('n_featdim'):
            self.space.create_embedding(self.args['n_featdim'])

    def _init_encoders(self):
        """
        Create the position encoder and the optional direction encoder.

        A list-valued `encode_x`/`encode_d` is unpacked into `InputEncoder.create`;
        any other value is passed to `FreqEncoder`.
        """
        if isinstance(self.args["encode_x"], list):
            self.x_encoder = InputEncoder.create(self.chns("x"), *self.args["encode_x"])
        else:
            self.x_encoder = FreqEncoder(self.chns("x"), self.args['encode_x'], cat_input=True)
        if self.args.get("encode_d"):
            if isinstance(self.args["encode_d"], list):
                self.d_encoder = InputEncoder.create(self.chns("d"), *self.args["encode_d"])
            else:
                self.d_encoder = FreqEncoder(self.chns("d"), self.args['encode_d'], angular=True)
        else:
            self.d_encoder = None

    def _init_core(self):
        """Create `multi_nets` core networks (1 if not specified)."""
        self.cores = self.create_multiple(self._create_core_unit, self.args.get("multi_nets", 1))

    def _init_sampler(self):
        """Create the sampler from `SamplerClass`, falling back to `Sampler`."""
        if self.SamplerClass is None:
            SamplerClass = Sampler
        else:
            SamplerClass = self.SamplerClass
        self.sampler = SamplerClass(**self.args)

    def _init_renderer(self):
        """
        Create the renderer from `RendererClass`; when it is `None`, use
        `DensityFirstVolumnRenderer` for the "nerfadv" core and `VolumnRenderer` otherwise.
        """
        if self.RendererClass is None:
            if self.args.get("core") == "nerfadv":
                RendererClass = DensityFirstVolumnRenderer
            else:
                RendererClass = VolumnRenderer
        else:
            RendererClass = self.RendererClass
        self.renderer = RendererClass(**self.args)

    def _create_core_unit(self, core_params: dict = None, **args):
        """
        Create a single core network.

        :param core_params `dict`: core parameters, defaults to `self.args["core_params"]`
        :param args: extra keyword arguments, which override entries of `core_params`
        :return: a `NerfAdvCore` when the "core" argument is "nerfadv", otherwise a `NerfCore`
        """
        core_params = core_params or self.args["core_params"]
        if self.args.get("core") == "nerfadv":
            # NOTE(review): `d_encoder` is None when `encode_d` is unset, in which case this
            # branch fails on `.out_dim` - confirm "nerfadv" is always used with `encode_d`.
            return NerfAdvCore(**{
                "x_chns": self.x_encoder.out_dim,
                "d_chns": self.d_encoder.out_dim,
                "density_chns": self.chns('density'),
                "color_chns": self.chns('color'),
                **core_params,
                **args
            })
        else:
            return NerfCore(**{
                "x_chns": self.x_encoder.out_dim,
                "density_chns": self.chns('density'),
                "color_chns": self.chns('color'),
                "d_chns": self.d_encoder.out_dim if self.d_encoder else 0,
                **core_params,
                **args
            })

    @profile
    def _sample(self, data: InputData, **extra_args) -> Samples:
        """
        Sample along the rays in `data` (keys "rays_o" and "rays_d") within the space.

        :param data `InputData`: input data containing "rays_o" and "rays_d"
        :param extra_args: keyword arguments overriding the model arguments for this call
        :return `Samples`: result of the sampler
        """
        return self.sampler(*itemgetter("rays_o", "rays_d")(data), self.space,
                            **self.args | extra_args)

    @profile
    def _render(self, samples: Samples, *outputs: str, **extra_args) -> ReturnData:
        """
        Render the requested outputs from samples.

        When `samples.size` has a single dimension the network output is returned
        directly without going through the renderer.

        :param samples `Samples`: samples to render
        :param outputs `str...`: names of the requested outputs
        :param extra_args: keyword arguments overriding the model arguments for this call
        :return `ReturnData`: rendered (or directly inferred) outputs
        """
        if len(samples.size) == 1:
            return self.infer(*outputs, samples=samples)
        return self.renderer(self, samples, *outputs, **self.args | extra_args)

    def _input(self, samples: Samples, what: str) -> torch.Tensor | None:
        """
        Build one component of the network input from samples.

        :param samples `Samples`: samples to build the input from
        :param what `str`: "x" (position or embedding), "d" (direction) or "f"
        :return `Tensor | None`: the encoded input; `None` for "f" and for "d" when there is
            no direction encoder or the samples carry no directions
        :raises ValueError: if `what` is not one of the supported names
        """
        if what == "x":
            if self.args.get('n_featdim'):
                return self._encode("emb", self.space.extract_embedding(
                    samples.pts, samples.voxel_indices))
            else:
                return self._encode("x", samples.pts)
        elif what == "d":
            if self.d_encoder and samples.dirs is not None:
                return self._encode("d", samples.dirs)
            else:
                return None
        elif what == "f":
            return None
        else:
            # The exception used to be constructed without being raised
            raise ValueError(f"Don't know how to process input \"{what}\"")

    def _encode(self, what: str, val: torch.Tensor) -> torch.Tensor:
        """
        Encode a raw input tensor with the matching encoder.

        :param what `str`: "x" (position), "emb" (embedding) or "d" (direction)
        :param val `Tensor`: values to encode
        :return `Tensor`: encoded values
        :raises ValueError: if `what` is not one of the supported names
        """
        if what == "x":
            # Normalize x according to the encoder's range requirement using space's bounding box
            if self.space.bbox is not None:
                val = (val - self.space.bbox[0]) / (self.space.bbox[1] - self.space.bbox[0])
                val = val * (self.x_encoder.in_range[1] - self.x_encoder.in_range[0])\
                    + self.x_encoder.in_range[0]
            return self.x_encoder(val)
        elif what == "emb":
            return self.x_encoder(val)
        elif what == "d":
            return self.d_encoder(val)
        else:
            # The exception used to be constructed without being raised
            raise ValueError(f"Don't know how to encode \"{what}\"")

    @profile
    def _multi_infer(self, inputs: NetInput, *outputs: str, samples: Samples, **kwargs) -> NetOutput:
        """
        Infer with multiple cores: each core handles the samples whose points fall inside
        its bounds (`net_bounds[i, 0] <= pts < net_bounds[i, 1]` on every component), and
        the partial results are scattered into zero-initialized full-size outputs.

        :param inputs `NetInput`: network input for all samples
        :param outputs `str...`: names of the requested outputs
        :param samples `Samples`: samples to infer
        :return `NetOutput`: merged outputs of all cores
        """
        ret: NetOutput = {}
        for i, core in enumerate(self.cores):
            # Element-wise `&` is required here: Python's `and` would ask for the truth
            # value of a whole tensor, which fails for tensors with more than one element
            selector = ((samples.pts >= self.net_bounds[i, 0])
                        & (samples.pts < self.net_bounds[i, 1])).all(-1)
            partial_ret: NetOutput = core(inputs[selector], *outputs, samples=samples[selector],
                                          **kwargs)
            for key, value in partial_ret.items():
                if key not in ret:
                    ret[key] = value.new_zeros(*inputs.shape, value.shape[-1])
                ret[key] = masked_scatter(selector, value, ret[key])
        return ret


class NSVF(NeRF):
    # Same model as NeRF except that samples are produced by `VoxelSampler`

    SamplerClass = VoxelSampler