import json

from torch import Tensor

from utils import color
from utils.nn import Module
from utils.types import *
from utils.profile import profile

# Registry of model classes keyed by class name. Every class created through
# ``BaseModelMeta`` (i.e. ``BaseModel`` and its subclasses) is added here,
# except ``BaseModel`` itself. A later class with the same name replaces the
# earlier entry.
model_classes = {}


class BaseModelMeta(type):
    """Metaclass that records each class it creates in ``model_classes``."""

    def __new__(cls, name, bases, attrs, **kwargs):
        # ``super().__new__`` keeps this metaclass cooperative with other
        # metaclasses, and forwarding ``kwargs`` lets class definitions pass
        # keyword arguments through to ``__init_subclass__`` like plain ``type``.
        new_cls = super().__new__(cls, name, bases, attrs, **kwargs)
        if name != "BaseModel":
            model_classes[name] = new_cls
        return new_cls


class BaseModel(Module, metaclass=BaseModelMeta):
    """Base class for models driven by a merged configuration dict.

    ``forward`` runs ``_sample`` then ``_render``; subclasses are expected to
    implement ``_sample``, ``_render``, ``_input`` and ``infer`` (all raise
    ``NotImplementedError`` here).
    """

    @property
    def args(self):
        """Merged configuration; entries in ``args1`` override those in ``args0``."""
        return {**self.args0, **self.args1}

    @property
    def color(self) -> int:
        """Color mode taken from ``args["color"]``, defaulting to ``color.RGB``."""
        return self.args.get("color", color.RGB)

    def __init__(self, args0: dict, args1: dict | None = None):
        """
        :param args0 `dict`: base configuration
        :param args1 `dict`: (optional) entries that override ``args0`` in ``args``
        """
        super().__init__()
        self.args0 = args0
        self.args1 = args1 or {}
        # Subclass hooks: adjust args first, then derive channel counts from them.
        self._preprocess_args()
        self._init_chns()

    def chns(self, name: str, value: int | None = None) -> int:
        """Get, and optionally set, the channel count registered under ``name``.

        :param name `str`: channel entry name (e.g. ``"color"``)
        :param value `int`: (optional) if given, stored as the count for ``name``
        :return `int`: the count for ``name``, or 1 if none is registered
        """
        if value is not None:
            self._chns[name] = value
        return self._chns.get(name, 1)

    def input(self, samples: Samples, *whats: str) -> NetInput:
        """Build the net input for ``samples``.

        :param samples `Samples`: samples to encode
        :param whats `str...`: which of ``"x"``, ``"d"``, ``"f"`` to include;
            all three when omitted. Names outside that set are ignored.
        :return `NetInput`: constructed with ``key=self._input(samples, key)``
            for each selected key
        """
        # Candidate keys, in the order they are passed to NetInput.
        input_keys = ("x", "d", "f")
        selected = whats or input_keys
        return NetInput(**{
            key: self._input(samples, key)
            for key in input_keys
            if key in selected
        })

    def infer(self, *outputs, samples: Samples, inputs: NetInput = None, **kwargs) -> NetOutput:
        """
        Infer colors, energies or other values (specified by `outputs`) of samples
        (invalid items are filtered out) given their encoded positions and directions

        :param outputs `str...`: which types of inferred data should be returned
        :param samples `Samples(N)`: samples
        :param inputs `NetInput(N)`: (optional) inputs to net
        :return `NetOutput`: data inferred by core net
        """
        raise NotImplementedError()

    @profile
    def forward(self, data: InputData, *outputs: str, **extra_args) -> ReturnData:
        """
        Perform rendering for given rays.

        :param data `InputData`: input data
        :param outputs `str...`: items should be contained in the rendering result
        :param extra_args `{str:*}`: extra arguments for this forward process
        :return `ReturnData`: the rendering result, see corresponding Renderer implementation;
            always contains ``"rays_filter"`` from ``samples.filter_rays()``
        """
        ret = {}
        samples = self._sample(data, **extra_args)  # (N, P)
        ret["rays_filter"] = samples.filter_rays()
        ret.update(self._render(samples, *outputs, **extra_args))
        return ret

    def print_config(self):
        """Return (does not print) the merged ``args`` serialized as a JSON string.

        Raises ``TypeError`` if any config value is not JSON-serializable.
        """
        return json.dumps(self.args)

    def _preprocess_args(self):
        """Hook called by ``__init__`` after args are stored; no-op by default."""
        pass

    def _init_chns(self, **chns):
        """Reset the channel table.

        Registers ``"color"`` channels (via ``color.chns``) when ``args`` has a
        ``"color"`` entry, then applies ``chns`` overrides.
        """
        self._chns = {}
        if "color" in self.args:
            self._chns["color"] = color.chns(self.color)
        self._chns.update(chns)

    def _input(self, samples: Samples, what: str) -> Tensor | None:
        """Return the net input named ``what`` for ``samples``; subclasses implement."""
        raise NotImplementedError()

    def _sample(self, data: InputData, **extra_args) -> Samples:
        """Produce samples from ``data``; subclasses implement."""
        raise NotImplementedError()

    def _render(self, samples: Samples, *outputs: str, **extra_args) -> ReturnData:
        """Render ``outputs`` from ``samples``; subclasses implement."""
        raise NotImplementedError()