spherical_view_syn.py 8.78 KB
Newer Older
BobYeah's avatar
sync    
BobYeah committed
1
2
import os
import importlib
Nianchen Deng's avatar
Nianchen Deng committed
3
import re
Nianchen Deng's avatar
sync    
Nianchen Deng committed
4
5
from utils import color
from nets.snerf_fast import SnerfFast
Nianchen Deng's avatar
Nianchen Deng committed
6
7
from nets.snerf import Snerf
from nets.nerf import Nerf
Nianchen Deng's avatar
sync    
Nianchen Deng committed
8
9
10
from nets.nerf_depth import NerfDepth
from nets.bg_net import BgNet
from nets.oracle import Oracle
Nianchen Deng's avatar
sync    
Nianchen Deng committed
11

BobYeah's avatar
sync    
BobYeah committed
12
13
14

class SphericalViewSynConfig(object):
    """Configuration of a view-synthesis network.

    A configuration can be populated from a config module (``load`` /
    ``load_by_name``) or from a compact identifier string (``from_id``), and
    can be serialized back to such an identifier with ``to_id``.

    Attributes set by ``__init__``:
        name: Configuration name (``'default'`` until loaded/parsed).
        c: Color mode, one of the values defined in ``utils.color``.
        net: Net type string: letters optionally followed by digits
            (see ``split_net_type``).
        encode_x: Value passed to the nets as ``pos_encode`` / ``encode``.
        encode_d: Value passed to the nets as ``dir_encode`` (``None`` when
            not set).
        depth_ref: Flag serialized as the ``d`` positive flag.
        fc: Dict with keys ``nf``, ``n_layers``, ``skips`` and ``act``.
        sa: Dict with keys ``sample_range``, ``n_samples``,
            ``perturb_sample`` and ``lindisp``.
        nerf_coarse: ``None`` or a dict with keys ``nf``, ``n_layers`` and
            ``n_samples``.
    """

    def __init__(self, id=None):
        """Create a config with default values, optionally parsed from ``id``.

        Args:
            id: Optional identifier string in the format produced by
                ``to_id``; when given it is applied on top of the defaults.
        """
        self.name = 'default'
        self.c = color.RGB
        self.net = 'nerf'
        self.encode_x = 10
        self.encode_d = None
        self.depth_ref = False
        self.fc = {
            'nf': 256,
            'n_layers': 8,
            'skips': [],
            'act': 'relu'
        }
        self.sa = {
            'sample_range': (1, 50),
            'n_samples': 32,
            'perturb_sample': True,
            'lindisp': True
        }
        self.nerf_coarse = None

        if id is not None:
            self.from_id(id)

    def load(self, path):
        """Load the config module located at ``path`` and apply it.

        The extension is stripped and ``/`` is replaced by ``.`` to obtain a
        dotted module name; the module's ``update_config(self)`` is then
        called and ``self.name`` is set to the last module name component.

        Args:
            path: ``/``-separated path of the config module file.
        """
        module_name = os.path.splitext(path)[0].replace('/', '.')
        config_module = importlib.import_module(module_name)
        config_module.update_config(self)
        self.name = module_name.split('.')[-1]

    def load_by_name(self, name):
        """Import ``configs.<name>``, apply its ``update_config`` and set the name.

        Args:
            name: Module name inside the ``configs`` package.
        """
        config_module = importlib.import_module('configs.' + name)
        config_module.update_config(self)
        self.name = name

    def split_net_type(self):
        """Split ``self.net`` into its leading letters and trailing number.

        Returns:
            A tuple ``(letters, number)`` where ``number`` is an ``int`` or
            ``None`` when no digits follow the letters.

        Raises:
            ValueError: If ``self.net`` does not start with a letter.
        """
        match_res = re.match(r'([a-z]+)(\d*)', self.net, re.I)
        if match_res is None:
            raise ValueError(f'Invalid net type: {self.net}')
        base, digits = match_res.groups()
        return base, (int(digits) if digits else None)

    def to_id(self):
        """Serialize this config into an identifier string.

        The result has the form ``<name>@<net>[-<color>]`` followed by
        ``_``-separated segments. Segments holding a default value (RGB
        color, no direction encoding, no skips, ``relu`` activation, no
        coarse parameters, no flags) are omitted.

        Returns:
            The identifier string.
        """
        parts = [f"{self.name}@{self.net}"]
        if self.c != color.RGB:
            parts.append(f"-{color.to_str(self.c)}")
        parts.append(f"_e{self.encode_x}")
        if self.encode_d is not None:
            parts.append(f"_ed{self.encode_d}")
        parts.append(f"_fc{self.fc['nf']}x{self.fc['n_layers']}")
        if len(self.fc['skips']) > 0:
            parts.append("_^%s" % ','.join(f'{val}' for val in self.fc['skips']))
        if self.fc['act'] != 'relu':
            parts.append(f"_*{self.fc['act']}")
        parts.append("_d{0:.2f}-{1:.2f}".format(*self.sa['sample_range']))
        parts.append(f"_s{self.sa['n_samples']}")
        if self.nerf_coarse is not None:
            parts.append(
                f"_co{self.nerf_coarse['nf']}x{self.nerf_coarse['n_layers']}"
                f"x{self.nerf_coarse['n_samples']}")
        # Negative flags list the options that are switched OFF.
        neg_flags = '%s%s' % (
            'p' if not self.sa['perturb_sample'] else '',
            'l' if not self.sa['lindisp'] else ''
        )
        if neg_flags:
            parts.append(f'_~{neg_flags}')
        # Positive flags list the options that are switched ON.
        pos_flags = '%s' % (
            'd' if self.depth_ref else ''
        )
        if pos_flags:
            parts.append(f'_+{pos_flags}')
        return ''.join(parts)

    def from_id(self, id: str):
        """Update this config from an identifier string.

        Accepts the format produced by ``to_id``. The part before ``@`` (if
        exactly one ``@`` is present) becomes the name; the rest is split on
        ``_`` and every non-empty segment is matched, in order, against:

        * ``co<nf>x<n_layers>x<n_samples>``: coarse net parameters
        * ``fc<nf>x<n_layers>``: fully-connected network parameters
        * ``^<a>,<b>,...``: skip connections
        * ``*<act>``: activation
        * ``ed<n>``: direction encoding (``0`` means ``None``)
        * ``e<n>``: position encoding
        * ``d<min>-<max>``: sample range
        * ``s<n>``: number of samples
        * ``~<flags>`` / ``+<flags>``: flags ``p``, ``l``, ``d`` off / on

        A first segment that matches none of the above is read as
        ``<net>[-<color>]``. Unrecognized later segments are ignored.

        Args:
            id: The identifier string (name kept for interface compatibility
                although it shadows the builtin).

        Raises:
            ValueError: If a numeric field of a ``^``, ``ed`` or ``e``
                segment is not a valid integer.
        """
        id_splited = id.split('@')
        if len(id_splited) == 2:
            self.name = id_splited[0]
        segs = id_splited[-1].split('_')
        for i, seg in enumerate(segs):
            if not seg:
                continue
            m = re.match(r"^co(\d+)x(\d+)x(\d+)$", seg)
            if m is not None:  # Coarse net parameters
                self.nerf_coarse = {
                    'nf': int(m.group(1)),
                    'n_layers': int(m.group(2)),
                    'n_samples': int(m.group(3))
                }
                continue
            m = re.match(r"^fc(\d+)x(\d+)$", seg)
            if m is not None:  # Full-connected network parameters
                self.fc['nf'] = int(m.group(1))
                self.fc['n_layers'] = int(m.group(2))
                continue
            if seg.startswith('^'):  # Skip connection
                # ``to_id`` writes this segment as '^' + comma-joined values,
                # so only the single prefix character must be stripped.
                self.fc['skips'] = [int(val) for val in seg[1:].split(',')]
                continue
            if seg.startswith('*'):  # Activation
                self.fc['act'] = seg[1:]
                continue
            if seg.startswith('ed'):  # Encode direction
                self.encode_d = int(seg[2:])
                if self.encode_d == 0:
                    self.encode_d = None
                continue
            if seg.startswith('e'):  # Encode
                self.encode_x = int(seg[1:])
                continue
            if seg.startswith('d'):  # Depth range
                # A parse failure is not an error: the segment may be a net
                # type starting with 'd' (e.g. 'dnerf'), handled further down.
                try:
                    self.sa['sample_range'] = tuple(
                        float(val) for val in seg[1:].split('-'))
                    continue
                except ValueError:
                    pass
            if seg.startswith('s'):  # Number of samples
                # Same fall-through as above for net types starting with 's'.
                try:
                    self.sa['n_samples'] = int(seg[1:])
                    continue
                except ValueError:
                    pass
            if seg.startswith('~'):  # Negative flags
                if 'p' in seg:
                    self.sa['perturb_sample'] = False
                if 'l' in seg:
                    self.sa['lindisp'] = False
                if 'd' in seg:
                    self.depth_ref = False
                continue
            if seg.startswith('+'):  # Positive flags
                if 'p' in seg:
                    self.sa['perturb_sample'] = True
                if 'l' in seg:
                    self.sa['lindisp'] = True
                if 'd' in seg:
                    self.depth_ref = True
                continue
            if i == 0:  # Net & color
                seg_splited = seg.split('-')
                if len(seg_splited) == 1:
                    self.net = seg
                    self.c = color.RGB
                else:
                    self.net = seg_splited[0]
                    self.c = color.from_str(seg_splited[1])

    def print(self):
        """Print a human-readable summary of this config to stdout."""
        print('==== Config %s ====' % self.name)
        print('Net type: ', self.net)
        print('Encode dim: ', self.encode_x)
        print('Train with depth: ', self.depth_ref)
        print('Support direction: ', False if self.encode_d is None
              else f'encode to {self.encode_d}')
        print('Full-connected network parameters:', self.fc)
        print('Sample parameters', self.sa)
        if self.nerf_coarse:
            print('NeRF fine network parameters', self.nerf_coarse)
        print('==========================')

    def create_net(self):
        """Instantiate the network described by this config.

        The net class is selected by the letter part of ``self.net``; the
        numeric suffix (if any) is used as ``n_bins`` for ``dnerf`` /
        ``dnerfa`` (default 128) and as ``n_parts`` for ``snerf`` /
        ``snerffast`` (default 1).

        Returns:
            The constructed network object.

        Raises:
            ValueError: If the net type is not recognized.
        """
        net, multiple = self.split_net_type()
        if net == 'nerf':
            if self.nerf_coarse:
                # The coarse net reuses the main parameters, overriding only
                # the fields given in ``nerf_coarse``.
                coarse_fc = self.fc.copy()
                coarse_fc.update({
                    'nf': self.nerf_coarse['nf'],
                    'n_layers': self.nerf_coarse['n_layers']
                })
                coarse_sa = self.sa.copy()
                coarse_sa.update({
                    'n_samples': self.nerf_coarse['n_samples']
                })
                coarse_net = Nerf(coarse_fc, coarse_sa,
                                  c=self.c,
                                  pos_encode=self.encode_x,
                                  dir_encode=self.encode_d)
            else:
                coarse_net = None
            return Nerf(core_params=self.fc,
                        sampler_params=self.sa,
                        fine_params=self.nerf_coarse,
                        c=self.c,
                        pos_encode=self.encode_x,
                        dir_encode=self.encode_d,
                        coarse_net=coarse_net)
        if net == 'bgnet':
            return BgNet(core_params=self.fc,
                         encode=self.encode_x,
                         c=self.c)
        if net.startswith('oracle'):
            # Anything following 'oracle' in the full net string names the
            # output activation.
            return Oracle(core_params=self.fc,
                          sampler_params=self.sa,
                          pos_encode=self.encode_x,
                          out_activation=self.net[6:] if len(self.net) > 6 else 'sigmoid')
        if net == 'dnerfa':
            return NerfDepth(core_params=self.fc,
                             sampler_params=self.sa,
                             c=self.c,
                             pos_encode=self.encode_x,
                             n_bins=multiple or 128,
                             include_neighbor_bins=False)
        if net == 'dnerf':
            return NerfDepth(core_params=self.fc,
                             sampler_params=self.sa,
                             c=self.c,
                             pos_encode=self.encode_x,
                             n_bins=multiple or 128)
        if net == 'snerf':
            return Snerf(core_params=self.fc,
                         sampler_params=self.sa,
                         n_parts=multiple or 1,
                         c=self.c,
                         pos_encode=self.encode_x,
                         dir_encode=self.encode_d)
        if net == 'snerffast':
            return SnerfFast(core_params=self.fc,
                             sampler_params=self.sa,
                             n_parts=multiple or 1,
                             c=self.c,
                             pos_encode=self.encode_x,
                             dir_encode=self.encode_d)
        raise ValueError(f'Invalid net type: {net} - {multiple}')