from .__common__ import *
from .base import BaseModel
from operator import itemgetter
from utils import math
from utils.misc import masked_scatter, merge


class NeRF(BaseModel):
    """NeRF model: space + input encoders + one or more core networks + sampler + renderer."""

    TrainerClass = "TrainWithSpace"
    SamplerClass = None     # subclass hook; falls back to `Sampler`
    RendererClass = None    # subclass hook; chosen from args["core"] when None
    space: Union[Space, Voxels, Octree]

    @property
    def multi_nets(self) -> int:
        # Number of core networks (defaults to a single network).
        return self.args.get("multi_nets", 1)

    def __init__(self, args0: dict, args1: dict = None):
        """
        Initialize a NeRF model

        :param args0 `dict`: basic arguments
        :param args1 `dict`: extra arguments, defaults to {}
        """
        super().__init__(args0, args1)

        # Initialize components
        self._init_space()
        self._init_encoders()
        self._init_core()
        self._init_sampler()
        self._init_renderer()

    @perf
    def infer(self, *outputs: str, samples: Samples, inputs: NetInput = None,
              **kwargs) -> NetOutput:
        """
        Run the core network(s) on `samples`, producing the requested `outputs`.

        Builds `inputs` from `samples` when not given; dispatches to the single
        core directly or to `_multi_infer` when multiple cores are configured.
        """
        inputs = inputs or self.input(samples)
        if len(self.cores) == 1:
            return self.cores[0](inputs, *outputs, samples=samples, **kwargs)
        return self._multi_infer(inputs, *outputs, samples=samples, **kwargs)

    @torch.no_grad()
    def split(self):
        """
        Split the space (e.g. octree/voxel subdivision) and rescale the
        sampling arguments accordingly: twice the samples, half the voxel
        size and step.
        """
        ret = self.space.split()
        if 'n_samples' in self.args0:
            self.args0['n_samples'] *= 2
        if 'voxel_size' in self.args0:
            self.args0['voxel_size'] /= 2
        if "sample_step_ratio" in self.args0:
            # NOTE(review): recomputes the derived step from the already-halved
            # voxel size into args1, then the next branch halves args0's step —
            # args0/args1 hold different copies here; confirm this is intended.
            self.args1["sample_step"] = self.args0["voxel_size"] \
                * self.args0["sample_step_ratio"]
        if 'sample_step' in self.args0:
            self.args0['sample_step'] /= 2
        return ret

    def _preprocess_args(self):
        # Derive the sample step from voxel size when a ratio is given.
        if "sample_step_ratio" in self.args0:
            self.args1["sample_step"] = self.args0["voxel_size"] \
                * self.args0["sample_step_ratio"]

        if self.args0.get("spherical"):
            # Spherical parametrization: sample in inverse depth (disparity);
            # [1, 0] covers out to infinity when no depth range is given.
            sample_range = [1 / self.args0['depth_range'][0], 1 / self.args0['depth_range'][1]] \
                if self.args0.get('depth_range') else [1, 0]
            rot_range = [[-180, -90], [180, 90]]
            self.args1['bbox'] = [
                [sample_range[0], math.radians(rot_range[0][0]), math.radians(rot_range[0][1])],
                [sample_range[1], math.radians(rot_range[1][0]), math.radians(rot_range[1][1])]
            ]
            self.args1['sample_range'] = sample_range

        if not self.args.get("net_bounds"):
            self.register_temp("net_bounds", None)
            self.args1["multi_nets"] = 1
        else:
            # One core network per bounding box.
            self.register_temp("net_bounds", torch.tensor(self.args["net_bounds"]))
            self.args1["multi_nets"] = self.net_bounds.size(0)

    def _init_chns(self, **chns):
        super()._init_chns(**{
            "x": self.args.get('n_featdim') or 3,
            "d": 3 if self.args.get('encode_d') else 0,
            **chns
        })

    def _init_space(self):
        self.space = Space.create(self.args)
        if self.args.get('n_featdim'):
            self.space.create_embedding(self.args['n_featdim'])

    def _init_encoders(self):
        self.x_encoder = InputEncoder(self.chns("x"), self.args['encode_x'], cat_input=True)
        self.d_encoder = InputEncoder(self.chns("d"), self.args['encode_d'])\
            if self.chns("d") > 0 else None

    def _init_core(self):
        self.cores = self.create_multiple(self._create_core_unit,
                                          self.args.get("multi_nets", 1))

    def _init_sampler(self):
        SamplerClass = Sampler if self.SamplerClass is None else self.SamplerClass
        self.sampler = SamplerClass(**self.args)

    def _init_renderer(self):
        if self.RendererClass is None:
            if self.args.get("core") == "nerfadv":
                RendererClass = DensityFirstVolumnRenderer
            else:
                RendererClass = VolumnRenderer
        else:
            RendererClass = self.RendererClass
        self.renderer = RendererClass(**self.args)

    def _create_core_unit(self, core_params: dict = None, **args):
        """Create one core network unit; `core_params`/`args` override defaults."""
        core_params = core_params or self.args["core_params"]
        if self.args.get("core") == "nerfadv":
            return NerfAdvCore(**{
                "x_chns": self.x_encoder.out_dim,
                "d_chns": self.d_encoder.out_dim,
                "density_chns": self.chns('density'),
                "color_chns": self.chns('color'),
                **core_params,
                **args
            })
        else:
            return NerfCore(**{
                "x_chns": self.x_encoder.out_dim,
                "density_chns": self.chns('density'),
                "color_chns": self.chns('color'),
                "d_chns": self.d_encoder.out_dim if self.d_encoder else 0,
                **core_params,
                **args
            })

    @perf
    def _sample(self, data: InputData, **extra_args) -> Samples:
        return self.sampler(*itemgetter("rays_o", "rays_d")(data), self.space,
                            **merge(self.args, extra_args))

    @perf
    def _render(self, samples: Samples, *outputs: str, **extra_args) -> ReturnData:
        # 1-D samples: direct network inference; otherwise volume-render.
        if len(samples.size) == 1:
            return self.infer(*outputs, samples=samples)
        return self.renderer(self, samples, *outputs, **merge(self.args, extra_args))

    def _input(self, samples: Samples, what: str) -> Optional[torch.Tensor]:
        """
        Build one encoded network input channel from `samples`.

        :param what: "x" (positions), "d" (view directions) or "f" (features)
        :raises ValueError: for an unrecognized channel name
        """
        if what == "x":
            if self.args.get('n_featdim'):
                # Positions are replaced by learned per-voxel embeddings.
                return self._encode("emb", self.space.extract_embedding(
                    samples.pts, samples.voxel_indices))
            else:
                return self._encode("x", samples.pts)
        elif what == "d":
            if self.d_encoder and samples.dirs is not None:
                return self._encode("d", samples.dirs)
            else:
                return None
        elif what == "f":
            return None
        else:
            # BUGFIX: the exception was constructed but never raised, so
            # unknown channels silently returned None.
            raise ValueError(f"Don't know how to process input \"{what}\"")

    def _encode(self, what: str, val: torch.Tensor) -> torch.Tensor:
        """
        Encode a raw input tensor with the matching input encoder.

        :raises ValueError: for an unrecognized encoding name
        """
        if what == "x":
            if self.args.get("spherical"):
                sr = self.args['sample_range']
                # scale val.r: [sr[0], sr[1]] -> [-PI/2, PI/2]
                val = val.clone()
                val[..., 0] = ((val[..., 0] - sr[0]) / (sr[1] - sr[0]) - .5) * math.pi
                return self.x_encoder(val)
            else:
                return self.x_encoder(val * math.pi)
        elif what == "emb":
            return self.x_encoder(val * math.pi)
        elif what == "d":
            return self.d_encoder(val, angular=True)
        else:
            # BUGFIX: the exception was constructed but never raised.
            raise ValueError(f"Don't know how to encode \"{what}\"")

    @perf
    def _multi_infer(self, inputs: NetInput, *outputs: str, samples: Samples,
                     **kwargs) -> NetOutput:
        """
        Route each sample to the core whose bounding box contains it, then
        scatter the per-core partial outputs back into full-size tensors.
        """
        ret: NetOutput = {}
        for i, core in enumerate(self.cores):
            # BUGFIX: was `tensor_a and tensor_b`, which raises
            # "Boolean value of Tensor is ambiguous" on multi-element tensors;
            # elementwise `&` (with explicit parentheses) is required.
            selector = ((samples.pts >= self.net_bounds[i, 0])
                        & (samples.pts < self.net_bounds[i, 1])).all(-1)
            partial_ret: NetOutput = core(inputs[selector], *outputs,
                                          samples=samples[selector], **kwargs)
            for key, value in partial_ret.items():
                if key not in ret:
                    ret[key] = value.new_zeros(*inputs.shape, value.shape[-1])
                ret[key] = masked_scatter(selector, value, ret[key])
        return ret


class NSVF(NeRF):
    """NeRF variant with sparse-voxel sampling (NSVF)."""
    SamplerClass = VoxelSampler