from typing import Optional

import torch.nn as nn

from utils import color


# Registry of model classes keyed by class name. Filled in automatically by
# ``BaseModelMeta`` whenever a class using that metaclass is defined.
model_classes = {}


class BaseModelMeta(type):
    """Metaclass that records every class it creates in ``model_classes``.

    Each new class is stored under its name, except the class literally named
    ``'BaseModel'`` (the root), which is skipped. Defining a second class with
    a name that is already registered replaces the earlier entry.
    """

    def __new__(cls, name, bases, attrs, **kwargs):
        # Forward class keyword arguments (``class Foo(Base, key=val)``) so
        # they reach ``__init_subclass__``; without ``**kwargs`` such class
        # definitions fail with a TypeError.
        new_cls = super().__new__(cls, name, bases, attrs, **kwargs)
        if name != 'BaseModel':
            model_classes[name] = new_cls
        return new_cls


class BaseModel(nn.Module, metaclass=BaseModelMeta):
    """Common base class for models registered through ``BaseModelMeta``.

    Holds two layers of configuration, ``args0`` (base) and ``args1``
    (overrides), and a per-name channel table queried via :meth:`chns`.
    """

    # Presumably names the trainer used for this model (resolved elsewhere —
    # TODO confirm); subclasses may override this class attribute.
    trainer = "Train"

    @property
    def args(self):
        """Merged configuration, with ``args1`` entries overriding ``args0``.

        A new dict is built on each access, so mutating the returned dict
        does not modify ``args0`` or ``args1``.
        """
        return {**self.args0, **self.args1}

    def __init__(self, args0: dict, args1: Optional[dict] = None):
        """Store the configuration and build the channel table.

        Args:
            args0: Base configuration.
            args1: Optional overrides merged on top of ``args0``. When omitted,
                each instance gets its own fresh empty dict (a shared ``{}``
                default would be aliased across all such instances).

        Raises:
            KeyError: if the merged configuration has no ``'color'`` key.
        """
        super().__init__()
        self.args0 = args0
        self.args1 = {} if args1 is None else args1
        self._chns = {
            # Channel count derived from the configured color spec via the
            # project's ``color`` helpers (their semantics are not shown here).
            "color": color.chns(color.from_str(self.args['color']))
        }

    def chns(self, name: str):
        """Return the channel count registered under ``name``, or 1 if absent."""
        return self._chns.get(name, 1)