from utils.misc import dump_tensors_to_csv
from .__common__ import *
from .base import BaseModel
from typing import Callable
from .nerf import NeRF
from utils.voxels import trilinear_interp


class CNeRF(BaseModel):
    """Multi-level model made of one ``NeRF`` sub-model per level.

    Samples are routed to the sub-model selected by ``samples.level``. For levels
    above 0 the network input component "f" is produced from the sub-model one
    level coarser: either queried directly at the sample points, or trilinearly
    interpolated from features evaluated at voxel corners (see ``InterpSpace``).
    """

    TrainerClass = "TrainMultiScale"

    class InterpSpace(object):
        """Features evaluated once at voxel corners, interpolated at sample points."""

        def __init__(self, space: Voxels, vidxs: Tensor,
                     feats_fn: Callable[[Any], Tensor]) -> None:
            """
            :param space: voxel space providing ``get_corners``, ``voxels`` and ``voxel_size``
            :param vidxs: voxel indices whose corners are gathered via ``space.get_corners``
            :param feats_fn: callable evaluating features at the gathered corner positions
            """
            super().__init__()
            self.space = space
            self.corner_indices, self.corners = space.get_corners(vidxs)
            self.feats_on_corners = feats_fn(self.corners)

        @perf
        def interp(self, samples: Samples) -> Tensor:
            """Trilinearly interpolate the cached corner features at ``samples.pts``."""
            with perf("Prepare for coarse interpolation"):
                voxels = self.space.voxels[samples.interp_vidxs]
                cidxs = self.corner_indices[samples.interp_vidxs]  # (N, 8)
                feats_on_corners = self.feats_on_corners[cidxs]  # (N, 8, X)
                # (N, 3) normed-coords in voxel; the +.5 presumably maps
                # center-relative offsets to [0, 1] — TODO confirm voxels are centers
                p = (samples.pts - voxels) / self.space.voxel_size + .5
            with perf("Interpolate features"):
                return trilinear_interp(p, feats_on_corners)

    @property
    def stage(self) -> int:
        """Index of the current training stage; sub-models below it are frozen in ``__init__``."""
        return self.args.get("stage", 0)

    def __init__(self, args0: dict, args1: dict = None):
        """
        :param args0: top-level args; must contain ``"sub_models"``, a list of per-level arg dicts
        :param args1: extra args forwarded unchanged to ``BaseModel`` and every ``NeRF``
        """
        super().__init__(args0, args1)
        # Every top-level arg is inherited by the sub-models, except the keys that only
        # make sense for the cascade as a whole. Per-sub-model values take precedence.
        args0_for_submodel = {
            key: value for key, value in args0.items()
            if key not in ("sub_models", "interp_on_coarse")
        }
        sub_models = []
        for i, submodel_args in enumerate(self.args["sub_models"]):
            # The merged dict is written back into self.args so the stored config holds the
            # full settings of each sub-model.
            self.args["sub_models"][i] = {**args0_for_submodel, **submodel_args}
            sub_models.append(NeRF(self.args["sub_models"][i], args1))
        self.sub_models = torch.nn.ModuleList(sub_models)
        for i in range(self.stage):
            print(f"__init__: freeze model {i}")
            self.model(i).freeze()

    def model(self, level: int) -> NeRF:
        """Return the sub-model responsible for ``level``."""
        return self.sub_models[level]

    def trigger_stage(self, stage: int):
        """Move to ``stage``: freeze sub-model ``stage - 1`` and give sub-model ``stage``
        a clone of its voxel space."""
        # NOTE(review): stage == 0 would address model(-1), i.e. the last sub-model;
        # callers presumably only pass stage >= 1 — confirm.
        print(f"trigger_stage: freeze model {stage - 1}")
        self.model(stage - 1).freeze()
        self.model(stage).space = self.model(stage - 1).space.clone()
        # NOTE(review): the ``stage`` property reads ``self.args`` but this writes
        # ``self.args0``; confirm BaseModel keeps the two views in sync.
        self.args0["stage"] = stage

    @perf
    def infer(self, *outputs: str, samples: Samples, inputs: NetInput = None,
              **kwargs) -> NetOutput:
        """Run the sub-model of ``samples.level``.

        :param outputs: names of the requested outputs
        :param samples: samples to evaluate
        :param inputs: prebuilt network input; built with ``self.input(samples)`` when None
        """
        # Compare against None explicitly so a caller-supplied input is never discarded
        # just because the object happens to be falsy.
        if inputs is None:
            inputs = self.input(samples)
        return self.model(samples.level).infer(*outputs, samples=samples, inputs=inputs,
                                               **kwargs)

    def print_config(self):
        """Print the configuration of every sub-model, in level order."""
        for i, model in enumerate(self.sub_models):
            print(f"Model {i} =====>")
            model.print_config()

    @torch.no_grad()
    def split(self):
        """Delegate ``split`` to the sub-model of the current stage."""
        return self.model(self.stage).split()

    def _input(self, samples: Samples, what: str) -> Optional[Tensor]:
        """Build one network input component.

        "f" (features from the coarser level) is handled here: None at level 0;
        otherwise interpolated through ``samples.interp_space`` when one has been
        prepared, or queried directly from level ``samples.level - 1``. Every other
        component is delegated to the sub-model of ``samples.level``.
        """
        if what != "f":
            return self.model(samples.level)._input(samples, what)
        if samples.level == 0:
            return None
        if samples.interp_space is None:
            return self._infer_features(pts=samples.pts, level=samples.level - 1)
        return samples.interp_space.interp(samples)

    @perf
    def _sample(self, data: InputData, **extra_args) -> Samples:
        """Sample with the sub-model of ``data["level"]`` and tag the samples with that level."""
        samples: Samples = self.model(data["level"])._sample(data, **extra_args)
        samples.level = data["level"]
        return samples

    @perf
    def _render(self, samples: Samples, *outputs: str, **extra_args) -> ReturnData:
        """Render with the renderer of the sub-model of ``samples.level``.

        The renderer receives this CNeRF (not the sub-model) so that inference goes through
        ``CNeRF.infer`` / ``CNeRF._input``. The sub-model's args are passed as defaults and
        ``extra_args`` override them.
        """
        self._prepare_interp(samples, on_coarse=self.args.get("interp_on_coarse"))
        level_model = self.model(samples.level)
        return level_model.renderer(self, samples, *outputs,
                                    **{**level_model.args, **extra_args})

    def _infer_features(self, samples: Samples = None, **sample_data) -> Tensor:
        """Return the "features" output of the sub-model at ``samples.level``.

        :param samples: samples to evaluate; when None, built as ``Samples(**sample_data)``
        :param sample_data: fields for a new ``Samples`` (e.g. ``pts``, ``level``)
        """
        if samples is None:
            samples = Samples(**sample_data)
        if self.args.get("interp_on_coarse"):
            self._prepare_interp(samples, on_coarse=True)
        inputs = self.input(samples, "x", "f")
        return self.infer("features", samples=samples, inputs=inputs)["features"]

    def _prepare_interp(self, samples: Samples, on_coarse: bool):
        """Attach an ``InterpSpace`` to ``samples`` for interpolating coarse-level features.

        :param on_coarse: interpolate over the voxels of the coarser sub-model (samples are
            located in it via ``get_voxel_indices``); otherwise over the voxels of the
            current sub-model, reusing ``samples.voxel_indices``
        """
        if samples.level == 0:
            return  # level 0 has no coarser level to take features from
        if on_coarse:
            interp_space = self.model(samples.level - 1).space
            samples.interp_vidxs = interp_space.get_voxel_indices(samples.pts)
        else:
            interp_space = self.model(samples.level).space
            samples.interp_vidxs = samples.voxel_indices
        coarse_level = samples.level - 1
        samples.interp_space = CNeRF.InterpSpace(
            interp_space, samples.interp_vidxs,
            lambda corners: self._infer_features(pts=corners, level=coarse_level))

    def _after_load_state_dict(self) -> None:
        """No-op override of the BaseModel post-load hook; CNeRF needs no extra processing."""
        return