from operator import itemgetter

from .__common__ import *
from utils import netio
from utils.args import BaseArgs


# Registry mapping a model type name to its Model subclass; read by
# `Model.get_class`. Nothing in this module inserts entries -- presumably
# subclasses register themselves elsewhere (TODO confirm).
model_classes: dict[str, Type["Model"]] = {}


class Model(nn.Module):
    """Base class for models called on `Rays`, with name-based construction
    and checkpoint loading.

    Subclasses must implement `forward`.
    """

    class Args(BaseArgs):
        # Name used to look up `Color` in `__init__` (`Color[color]`).
        color: str = "rgb"
        # Not read by this base class; presumably consumed by subclasses -- confirm.
        coord: str = "gl"

    args: Args
    color: Color

    def __init__(self, args: Args):
        """Store `args` and resolve `args.color` through `Color[...]`."""
        super().__init__()
        self.args = args
        self.color = Color[self.args.color]

    def __call__(self, rays: Rays, *outputs: str, **args) -> ReturnData:
        """Run the model on `rays`. This is a typed wrapper around `nn.Module.__call__`.

        It exists only to give `model(...)` a precise signature and return
        type. It must delegate to `nn.Module.__call__`: a stub body here would
        shadow that method, so `model(...)` would return None without ever
        running hooks or `forward()`.
        """
        return super().__call__(rays, *outputs, **args)

    def forward(self, rays: Rays, *outputs: str, **args) -> ReturnData:
        """Compute the requested `outputs` for `rays`; must be overridden.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    @staticmethod
    def get_class(typename: str) -> Type["Model"] | None:
        """Return the class registered under `typename`, or None if unregistered."""
        return model_classes.get(typename)

    @staticmethod
    def create(typename: str, args: dict | Args) -> "Model":
        """Instantiate the registered model class named `typename`.

        Args:
            typename: key into the `model_classes` registry.
            args: either an instance of the target class's `Args`, or a dict
                that is expanded into `ModelCls.Args(**args)`.

        Raises:
            ValueError: if no class is registered under `typename`.
        """
        ModelCls = Model.get_class(typename)
        if ModelCls is None:
            raise ValueError(f"Model {typename} is not found")
        if isinstance(args, dict):
            args = ModelCls.Args(**args)
        return ModelCls(args)

    @staticmethod
    def load(path: PathLike) -> "Model":
        """Rebuild a model from a checkpoint file and restore its weights.

        The checkpoint (element 0 of `netio.load_checkpoint`'s return value)
        must provide `ckpt["args"]["model"]` (type name),
        `ckpt["args"]["model_args"]` (constructor args) and
        `ckpt["states"]["model"]` (the state dict).
        """
        ckpt = netio.load_checkpoint(Path(path))[0]
        model_type, model_args = itemgetter("model", "model_args")(ckpt["args"])
        model = Model.create(model_type, model_args)
        model.load_state_dict(ckpt["states"]["model"])
        return model