from .__common__ import *
from .nerf import NeRF


class SnerfFast(NeRF):

    def infer(self, *outputs: str, samples: Samples, inputs: NetInput = None, chunk_id: int,
              **kwargs) -> NetOutput:
        inputs = inputs or self.input(samples)
        # Each core unit predicts all of its samples in one pass, so reshape the flat
        # outputs back to (*samples.size, -1).
        return {
            key: value.reshape(*samples.size, -1)
            for key, value in self.cores[chunk_id](inputs, *outputs).items()
        }

    def _preprocess_args(self):
        self.args0["spherical"] = True
        super()._preprocess_args()
        self.samples_per_part = self.args["n_samples"] // self.multi_nets

    def _init_chns(self):
        super()._init_chns(x=2)

    def _create_core_unit(self):
        # Every core unit consumes `samples_per_part` encoded positions at once and
        # outputs that many densities and colors.
        return super()._create_core_unit(
            x_chns=self.x_encoder.out_dim * self.samples_per_part,
            density_chns=self.chns("density") * self.samples_per_part,
            color_chns=self.chns("color") * self.samples_per_part)

    def _input(self, samples: Samples, what: str) -> torch.Tensor | None:
        if what == "x":
            return self._encode("x", samples.pts[..., -self.chns("x"):]).flatten(1, 2)
        elif what == "d":
            return self._encode("d", samples.dirs[:, 0]) \
                if self.d_encoder and samples.dirs is not None else None
        else:
            return super()._input(samples, what)

    def _encode(self, what: str, val: torch.Tensor) -> torch.Tensor:
        if what == "x":
            # Normalize x to the encoder's input range using the space's bounding box
            bbox = self.space.bbox[:, -self.chns("x"):]
            val = (val - bbox[0]) / (bbox[1] - bbox[0])
            val = val * (self.x_encoder.in_range[1] - self.x_encoder.in_range[0]) \
                + self.x_encoder.in_range[0]
            return self.x_encoder(val)
        return super()._encode(what, val)

    def _render(self, samples: Samples, *outputs: str, **extra_args) -> ReturnData:
        # Force raymarching to be chunked by `samples_per_part` so each chunk maps to one core unit.
        return super()._render(samples, *outputs, **(extra_args | {
            "raymarching_chunk_size_or_sections": [self.samples_per_part]
        }))

    def _sample(self, data: InputData, **extra_args) -> Samples:
        samples = super()._sample(data, **extra_args)
        samples.voxel_indices = 0
        return samples


class SnerfFastExport(torch.nn.Module):

    def __init__(self, net: SnerfFast):
        super().__init__()
        self.net = net

    def forward(self, coords_encoded, z_vals):
        colors = []
        densities = []
        for i in range(self.net.multi_nets):
            # Each sub-network handles its own contiguous slice of samples along the ray.
            s = slice(i * self.net.samples_per_part, (i + 1) * self.net.samples_per_part)
            mlp = self.net.nets[i] if self.net.nets is not None else self.net.net
            c, d = mlp(coords_encoded[:, s].flatten(1, 2))
            colors.append(c.view(-1, self.net.samples_per_part, self.net.color_chns))
            densities.append(d)
        colors = torch.cat(colors, 1)
        densities = torch.cat(densities, 1)
        alphas = self.net.rendering.density2alpha(densities, z_vals)
        return torch.cat([colors, alphas[..., None]], -1)
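

# ---------------------------------------------------------------------------
# Hedged sketch (illustrative, not part of the original module): the export
# wrapper above relies on `self.net.rendering.density2alpha`, which is defined
# elsewhere in this package. Assuming it follows the standard NeRF volume-
# rendering rule alpha_i = 1 - exp(-relu(sigma_i) * delta_i), a minimal
# stand-alone reference for `densities` and `z_vals` of shape (N, S) could look
# like the function below. The name `_reference_density2alpha` and the 1e10
# far-plane padding are assumptions made for this sketch; `torch` comes from
# the star import at the top of the module.
def _reference_density2alpha(densities: torch.Tensor, z_vals: torch.Tensor) -> torch.Tensor:
    # Distances between consecutive depth samples along each ray; pad the last
    # sample with a large delta so it still contributes opacity.
    deltas = z_vals[..., 1:] - z_vals[..., :-1]
    deltas = torch.cat([deltas, torch.full_like(deltas[..., :1], 1e10)], dim=-1)
    return 1.0 - torch.exp(-torch.relu(densities) * deltas)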