from typing import Optional

import torch.nn as nn

from utils import color

# Registry of model classes keyed by class name. BaseModelMeta adds every
# class it creates here, except the base class named 'BaseModel' itself.
model_classes = {}


class BaseModelMeta(type):
    """Metaclass that records each newly created model class in ``model_classes``."""

    def __new__(cls, name, bases, attrs, **kwargs):
        # Cooperative call so class keyword arguments (and any other
        # metaclass in the MRO) are handled correctly.
        new_cls = super().__new__(cls, name, bases, attrs, **kwargs)
        # Skip the base class itself; only its subclasses are registered.
        # NOTE(review): a later class with the same name silently replaces
        # the earlier registry entry.
        if name != 'BaseModel':
            model_classes[name] = new_cls
        return new_cls


class BaseModel(nn.Module, metaclass=BaseModelMeta):
    """Base ``nn.Module`` configured from two argument dicts.

    ``args0`` is the base configuration and ``args1`` holds entries that
    override it; the merged view is available as :attr:`args`.
    """

    # Class-level default string "Train"; presumably names the trainer used
    # for this model (subclasses may override) -- confirm against callers.
    trainer = "Train"

    @property
    def args(self):
        """Return a new dict merging ``args0`` and ``args1`` (``args1`` wins on clashes)."""
        return {**self.args0, **self.args1}

    def __init__(self, args0: dict, args1: Optional[dict] = None):
        """Store the configuration dicts and precompute per-name values.

        Args:
            args0: Base configuration. The merged configuration must contain
                a ``'color'`` key (from ``args0`` or ``args1``), otherwise a
                ``KeyError`` is raised.
            args1: Optional overriding entries. When omitted (or ``None``),
                a fresh empty dict is used for this instance.
        """
        super().__init__()
        self.args0 = args0
        # Fresh dict per instance: a shared ``{}`` default would make every
        # model constructed without ``args1`` alias (and mutate) one object.
        self.args1 = {} if args1 is None else args1
        # Value for "color" is derived from the configured color string via
        # the project's color helpers (presumably a channel count -- confirm).
        self._chns = {
            "color": color.chns(color.from_str(self.args['color'])),
        }

    def chns(self, name: str):
        """Return the value stored under *name* in ``_chns``, or 1 if absent."""
        return self._chns.get(name, 1)