base.py 3.04 KB
Newer Older
Nianchen Deng's avatar
sync    
Nianchen Deng committed
1
2
3
import json
from torch import Tensor

Nianchen Deng's avatar
sync    
Nianchen Deng committed
4
from utils import color
Nianchen Deng's avatar
sync    
Nianchen Deng committed
5
6
7
from utils.nn import Module
from utils.types import *
from utils.profile import profile
Nianchen Deng's avatar
sync    
Nianchen Deng committed
8
9
10
11
12
13
14
15
16
17
18
19
20
21


# Registry of model classes keyed by class name. Filled in automatically by
# BaseModelMeta each time a class using that metaclass is created.
model_classes: dict[str, type] = {}


class BaseModelMeta(type):
    """
    Metaclass that records every class it creates in `model_classes`.

    The root class named 'BaseModel' is not registered. A later class with the same
    name replaces the earlier entry in the registry.
    """

    def __new__(cls, name, bases, attrs, **kwargs):
        # Forward class keyword arguments (PEP 487) so that subclasses may be declared
        # as `class Foo(BaseModel, key=value)` and receive them in `__init_subclass__`.
        new_cls = super().__new__(cls, name, bases, attrs, **kwargs)
        if name != 'BaseModel':
            model_classes[name] = new_cls
        return new_cls


Nianchen Deng's avatar
sync    
Nianchen Deng committed
22
class BaseModel(Module, metaclass=BaseModelMeta):
    """
    Base class for models.

    It holds two layers of configuration arguments and a table of channel counts, and
    defines the sample -> render pipeline used by `forward`. Subclasses provide
    `_input`, `_sample`, `_render` and `infer`, which raise NotImplementedError here.
    They may override `_preprocess_args` and `_init_chns` to adapt the configuration.
    """

    @property
    def args(self):
        """Merged configuration: entries in `args1` override same-named entries in `args0`."""
        return {**self.args0, **self.args1}

    @property
    def color(self) -> int:
        """Color mode from the configuration, falling back to `color.RGB` when unset."""
        return self.args.get("color", color.RGB)

    def __init__(self, args0: dict, args1: dict | None = None):
        """
        :param args0 `dict`: primary configuration arguments
        :param args1 `dict`: (optional) additional arguments; when merged by `args`
                             they take precedence over `args0`
        """
        super().__init__()
        self.args0 = args0
        self.args1 = args1 or {}
        # Order matters: `_init_chns` reads `self.args`, so subclasses get a chance to
        # adjust the arguments first.
        self._preprocess_args()
        self._init_chns()

    def chns(self, name: str, value: int = None) -> int:
        """
        Get, and optionally set, the channel count recorded under `name`.

        :param name `str`: key into the channel table
        :param value `int`: (optional) if given, stored under `name` before reading
        :return `int`: the recorded value, or 1 if nothing was recorded for `name`
        """
        if value is not None:
            self._chns[name] = value
        return self._chns.get(name, 1)

    def input(self, samples: Samples, *whats: str) -> NetInput:
        """
        Build the net input for `samples`.

        :param samples `Samples`: samples
        :param whats `str...`: which inputs to build, any of "x", "d", "f"; all three
                               when none are given. Unrecognized names are ignored.
        :return `NetInput`: the requested inputs, each produced by `_input`
        """
        keys = ("x", "d", "f")
        whats = whats or keys
        # Iterate the canonical key order, not the caller's order of `whats`
        return NetInput(**{
            key: self._input(samples, key)
            for key in keys if key in whats
        })

    def infer(self, *outputs, samples: Samples, inputs: NetInput = None, **kwargs) -> NetOutput:
        """
        Infer colors, energies or other values (specified by `outputs`) of samples 
        (invalid items are filtered out) given their encoded positions and directions

        :param outputs `str...`: which types of inferred data should be returned
        :param samples `Samples(N)`: samples
        :param inputs `NetInput(N)`: (optional) inputs to net
        :return `NetOutput`: data inferred by core net
        """
        raise NotImplementedError()

    @profile
    def forward(self, data: InputData, *outputs: str, **extra_args) -> ReturnData:
        """
        Perform rendering for given rays.

        :param data `InputData`: input data
        :param outputs `str...`: items should be contained in the rendering result
        :param extra_args `{str:*}`: extra arguments for this forward process
        :return `ReturnData`: the rendering result, see corresponding Renderer implementation
        """
        ret = {}
        samples = self._sample(data, **extra_args)  # (N, P)
        ret["rays_filter"] = samples.filter_rays()
        # Results from `_render` are merged last, so they win on any key collision
        ret.update(self._render(samples, *outputs, **extra_args))
        return ret

    def print_config(self):
        """Return the merged configuration (`args`) serialized as a JSON string (does not print)."""
        return json.dumps(self.args)

    def _preprocess_args(self):
        """Hook for subclasses to adjust arguments before `_init_chns` runs; no-op here."""
        pass

    def _init_chns(self, **chns):
        """
        Rebuild the channel table from scratch.

        If the configuration contains a "color" entry, the channel count for the color
        mode is recorded under "color". Keyword arguments are applied afterwards and
        override any entry set above.
        """
        self._chns = {}
        if "color" in self.args:
            self._chns["color"] = color.chns(self.color)
        self._chns.update(chns)

    def _input(self, samples: Samples, what: str) -> Tensor | None:
        """Produce the single net input named `what` for `samples`; subclasses must implement."""
        raise NotImplementedError()

    def _sample(self, data: InputData, **extra_args) -> Samples:
        """Generate samples from the input data; subclasses must implement."""
        raise NotImplementedError()

    def _render(self, samples: Samples, *outputs: str, **extra_args) -> ReturnData:
        """Render the requested `outputs` from `samples`; subclasses must implement."""
        raise NotImplementedError()