from .__common__ import *
from .nerf import NeRF
from .utils import load


class SNeRFX(NeRF):
    """NeRF variant that forces spherical mode and splits samples across nets.

    Differences from ``NeRF`` that are visible in this class:

    * ``spherical`` is forced to ``True`` in ``args0`` before the parent
      preprocesses its arguments.
    * A per-network sample-count list, ``net_samples``, is derived (unless
      already given), and ``multi_nets`` is set to its length.
    * ``split`` doubles every ``net_samples`` entry after the parent split.
    * Rendering forwards ``net_samples`` to the parent as
      ``raymarching_chunk_size_or_sections``.
    * Per-chunk inference is dispatched to ``self.cores[chunk_id]``.
    """

    def _preprocess_args(self):
        """Force spherical mode, then derive ``net_samples`` and ``multi_nets``.

        When ``net_samples`` is not already present in ``self.args``, it is
        computed as follows:

        * The network count comes from ``multi_nets`` (default ``1``).
        * The reference space is ``load(args['cut_by']).space`` when
          ``cut_by`` is given, otherwise ``self.space``.
        * Each value returned by ``space.balance_cut(0, nets_count)`` is
          multiplied by ``n_samples // space.steps[0]`` (integer division).

        ``multi_nets`` (in ``args1``) is always reset to the number of
        ``net_samples`` entries.

        NOTE(review): the derived list is written to ``self.args0`` but read
        back through ``self.args``. This presumably relies on ``self.args``
        being a merged view over ``args0``/``args1``; confirm in the base
        class.
        """
        # Set before the parent runs, so its preprocessing sees spherical mode.
        self.args0["spherical"] = True
        super()._preprocess_args()

        if "net_samples" not in self.args:
            nets_count = self.args.get("multi_nets", 1)
            if "cut_by" in self.args:
                # Presumably loads another saved model and reuses its space.
                # TODO: confirm what `load` returns.
                space = load(self.args['cut_by']).space
            else:
                space = self.space
            # Number of samples per step along axis 0 of the space.
            samples_per_step = self.args["n_samples"] // space.steps[0].item()
            # NOTE(review): balance_cut(0, n) presumably partitions axis 0
            # into n balanced segments and returns their step counts. Confirm.
            self.args0["net_samples"] = [
                cut * samples_per_step
                for cut in space.balance_cut(0, nets_count)
            ]

        # Keep multi_nets consistent with the number of net_samples entries.
        self.args1["multi_nets"] = len(self.args["net_samples"])

    @torch.no_grad()
    def split(self):
        """Run the parent ``split`` (under ``torch.no_grad``), then double
        every entry of ``args0['net_samples']``.

        Returns:
            Whatever the parent ``split`` returns.

        NOTE(review): the doubling presumably mirrors a doubling of sampling
        resolution inside ``NeRF.split``; confirm.
        """
        result = super().split()
        doubled = [n * 2 for n in self.args0['net_samples']]
        self.args0['net_samples'] = doubled
        return result

    @profile
    def _render(self, samples: Samples, *outputs: str, **extra_args) -> ReturnData:
        """Delegate to the parent renderer, passing ``net_samples`` as the
        ``raymarching_chunk_size_or_sections`` keyword.

        A caller-supplied ``raymarching_chunk_size_or_sections`` is
        overridden, because the right-hand operand of ``|`` wins.
        """
        render_args = extra_args | {
            "raymarching_chunk_size_or_sections": self.args["net_samples"]
        }
        return super()._render(samples, *outputs, **render_args)

    @profile
    def _multi_infer(
        self, inputs: NetInput, *outputs: str, chunk_id: int, **kwargs
    ) -> NetOutput:
        """Run inference with the core network selected by ``chunk_id``.

        ``chunk_id`` indexes ``self.cores``. It presumably identifies the
        sub-network responsible for that ray-marching section; confirm
        against the parent's caller.
        """
        core = self.cores[chunk_id]
        return core(inputs, *outputs, **kwargs)