import torch
import model
from .base import *
from modules import *
from utils.mem_profiler import MemProfiler
from utils.perf import perf
from utils.misc import masked_scatter


class NeRF(BaseModel):
    trainer = "TrainWithSpace"
    SamplerClass = Sampler
    RendererClass = VolumnRenderer

    def __init__(self, args0: dict, args1: dict = None):
        """
        Initialize a NeRF model.

        :param args0 `dict`: basic arguments
        :param args1 `dict`: extra arguments, defaults to None (treated as empty)
        """
        # NOTE: the previous default `args1={}` was a shared mutable default that
        # this method mutates below — state leaked across instances. `None` keeps
        # the call signature backward-compatible while avoiding the pitfall.
        args1 = {} if args1 is None else args1
        if "sample_step_ratio" in args0:
            args1["sample_step"] = args0["voxel_size"] * args0["sample_step_ratio"]
        super().__init__(args0, args1)

        # Initialize components
        self._init_space()
        self._init_encoders()
        self._init_core()
        self.sampler = self.SamplerClass(**self.args)
        self.rendering = self.RendererClass(**self.args)

    def _init_encoders(self):
        # Positional encoder: input is either raw xyz (3 chns) or a learned
        # per-point embedding of size `n_featdim`.
        self.pot_encoder = InputEncoder.Get(self.args['n_pot_encode'],
                                            self.args.get('n_featdim') or 3)
        if self.args.get('n_dir_encode'):
            self.dir_chns = 3
            self.dir_encoder = InputEncoder.Get(self.args['n_dir_encode'], self.dir_chns)
        else:
            # View direction is not fed to the core at all.
            self.dir_chns = 0
            self.dir_encoder = None

    def _init_space(self):
        # Choose the spatial acceleration structure; a path-like value loads the
        # space from a previously trained model instead of building a new one.
        if 'space' not in self.args:
            self.space = Space(**self.args)
        elif self.args['space'] == 'octree':
            self.space = Octree(**self.args)
        elif self.args['space'] == 'voxels':
            self.space = Voxels(**self.args)
        else:
            self.space = model.load(self.args['space'])[0].space
        if self.args.get('n_featdim'):
            self.space.create_embedding(self.args['n_featdim'])

    def _new_core_unit(self):
        # Build a single MLP core from the `fc_params` config section.
        return NerfCore(coord_chns=self.pot_encoder.out_dim,
                        density_chns=self.chns('density'),
                        color_chns=self.chns('color'),
                        core_nf=self.args['fc_params']['nf'],
                        core_layers=self.args['fc_params']['n_layers'],
                        dir_chns=self.dir_encoder.out_dim if self.dir_encoder else 0,
                        dir_nf=self.args['fc_params']['nf'] // 2,
                        act=self.args['fc_params']['activation'],
                        skips=self.args['fc_params']['skips'])

    def _create_core(self, n_nets=1):
        # A single core is returned bare; multiple cores are wrapped in a
        # ModuleList so parameters are registered properly.
        return self._new_core_unit() if n_nets == 1 else nn.ModuleList([
            self._new_core_unit() for _ in range(n_nets)
        ])

    def _init_core(self):
        if not self.args.get("net_bounds"):
            self.core = self._create_core()
        else:
            # One core per axis-aligned bounding box; `net_bounds` is saved as a
            # (non-persistent-gradient) buffer so it moves with the module.
            self.register_buffer("net_bounds", torch.tensor(self.args["net_bounds"]), False)
            self.cores = self._create_core(self.net_bounds.size(0))

    def render(self, samples: Samples, *outputs: str, **kwargs) -> Dict[str, torch.Tensor]:
        """
        Render colors, energies and other values (specified by `outputs`) of samples
        (invalid items are filtered out)

        :param samples `Samples(N)`: samples
        :param outputs `str...`: which types of inferred data should be returned
        :return `Dict[str, Tensor(N, *)]`: outputs of cores
        """
        x = self.encode_x(samples)
        d = self.encode_d(samples)
        return self.infer(x, d, *outputs, pts=samples.pts, **kwargs)

    def infer(self, x: torch.Tensor, d: torch.Tensor, *outputs, pts: torch.Tensor,
              **kwargs) -> Dict[str, torch.Tensor]:
        """
        Infer colors, energies and other values (specified by `outputs`) of samples
        (invalid items are filtered out) given their encoded positions and directions

        :param x `Tensor(N, Ex)`: encoded positions
        :param d `Tensor(N, Ed)`: encoded directions
        :param outputs `str...`: which types of inferred data should be returned
        :param pts `Tensor(N, 3)`: raw sample positions
        :return `Dict[str, Tensor(N, *)]`: outputs of cores
        """
        if getattr(self, "core", None):
            return self.core(x, d, outputs)
        ret = {}
        for i, core in enumerate(self.cores):
            # BUGFIX: the original used Python `and` between two boolean tensors,
            # which raises "Boolean value of Tensor is ambiguous" for N > 1.
            # Elementwise `&` selects the points inside this core's AABB.
            selector = ((pts >= self.net_bounds[i, 0])
                        & (pts < self.net_bounds[i, 1])).all(-1)
            partial_ret = core(x[selector], d[selector], outputs)
            for key, value in partial_ret.items():
                if value is None:
                    ret[key] = None
                    continue
                if key not in ret:
                    ret[key] = torch.zeros(*x.shape[:-1], value.shape[-1], device=x.device)
                # Scatter this core's partial results back to full-length tensors.
                ret[key] = masked_scatter(selector, value, ret[key])
        return ret

    def embed(self, samples: Samples) -> torch.Tensor:
        return self.space.extract_embedding(samples.pts, samples.voxel_indices)

    def encode_x(self, samples: Samples) -> torch.Tensor:
        # Use the learned embedding when enabled, otherwise raw positions.
        x = self.embed(samples) if self.args.get('n_featdim') else samples.pts
        return self.pot_encoder(x)

    def encode_d(self, samples: Samples) -> torch.Tensor:
        return self.dir_encoder(samples.dirs) if self.dir_encoder is not None else None

    @torch.no_grad()
    def get_scores(self, sampled_points: torch.Tensor,
                   sampled_voxel_indices: torch.Tensor) -> torch.Tensor:
        """
        Compute pruning scores (opacities) for sampled points.

        :param sampled_points `Tensor(N, 3)`: sample positions
        :param sampled_voxel_indices `Tensor(N)`: voxel index of each sample
        :return `Tensor`: alpha = 1 - exp(-density) per sample
        """
        # BUGFIX: `render` returns a dict keyed by output name; the original code
        # applied `.exp()` to the dict itself. Extract the density tensor first.
        densities = self.render(Samples(sampled_points, None, None, None,
                                        sampled_voxel_indices), 'density')['density']
        return 1 - (-densities).exp()

    @torch.no_grad()
    def pruning(self, threshold: float = 0.5, train_stats=False):
        return self.space.pruning(self.get_scores, threshold, train_stats)

    @torch.no_grad()
    def splitting(self):
        # Subdivide the space: sampling density doubles, voxel size halves, and
        # the dependent `sample_step` arguments are refreshed accordingly.
        ret = self.space.splitting()
        if 'n_samples' in self.args0:
            self.args0['n_samples'] *= 2
        if 'voxel_size' in self.args0:
            self.args0['voxel_size'] /= 2
            if "sample_step_ratio" in self.args0:
                self.args1["sample_step"] = self.args0["voxel_size"] \
                    * self.args0["sample_step_ratio"]
        if 'sample_step' in self.args0:
            self.args0['sample_step'] /= 2
        # Rebuild the sampler so it picks up the updated arguments.
        self.sampler = self.SamplerClass(**self.args)
        return ret

    @torch.no_grad()
    def double_samples(self):
        pass

    @perf
    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, *,
                extra_outputs: List[str] = None, **kwargs) -> torch.Tensor:
        """
        Perform rendering for given rays.

        :param rays_o `Tensor(N, 3)`: rays' origin
        :param rays_d `Tensor(N, 3)`: rays' direction
        :param extra_outputs `list[str]`: extra items should be contained in the
            rendering result, defaults to None (treated as empty)
        :return `dict[str, Tensor]`: the rendering result, see corresponding
            Renderer implementation
        """
        # Avoid a mutable default argument; `None` is backward-compatible.
        extra_outputs = [] if extra_outputs is None else extra_outputs
        args = {**self.args, **kwargs}
        with MemProfiler(f"{self.__class__}.forward: before sampling"):
            samples, rays_mask = self.sampler(rays_o, rays_d, self.space, **args)
            MemProfiler.print_memory_stats(f"{self.__class__}.forward: after sampling")
        with MemProfiler(f"{self.__class__}.forward: rendering"):
            if samples is None:
                return None
            return {
                **self.rendering(self, samples, extra_outputs, **args),
                'samples': samples,
                'rays_mask': rays_mask
            }