spherical_view_syn.py 13.9 KB
Newer Older
BobYeah's avatar
sync    
BobYeah committed
1
2
import os
import importlib
Nianchen Deng's avatar
sync    
Nianchen Deng committed
3
4
from utils.constants import *
from utils import color
Nianchen Deng's avatar
Nianchen Deng committed
5
6
from nets.msl_net import MslNet
from nets.msl_net_new import NewMslNet
Nianchen Deng's avatar
sync    
Nianchen Deng committed
7
8
9
10
11
12
13
14
15
16
from nets.msl_ray import MslRay
from nets.msl_fast import MslFast
from nets.snerf_fast import SnerfFast
from nets.cnerf_v3 import CNerf
from nets.nerf import CascadeNerf
from nets.nerf import CascadeNerf2
from nets.nnerf import NNerf
from nets.nerf_depth import NerfDepth
from nets.bg_net import BgNet
from nets.oracle import Oracle
Nianchen Deng's avatar
sync    
Nianchen Deng committed
17

BobYeah's avatar
sync    
BobYeah committed
18
19
20
21
22
23

class SphericalViewSynConfig(object):
    """Hyper-parameters of a spherical view-synthesis network.

    A config can be filled from a Python config module (``load`` /
    ``load_by_name``), serialized to a compact id string (``to_id``), parsed
    back from such an id (``from_id``) and used to build the network selected
    by ``NET_TYPE`` (``create_net``).
    """

    def __init__(self):
        self.name = 'default'

        # Color space; passed to most networks as ``c`` in create_net
        self.COLOR = color.RGB

        # Net parameters
        self.NET_TYPE = 'msl'
        self.N_ENCODE_DIM = 10
        # None = no direction encoding (to_id omits it, from_id maps 'ed0' to None)
        self.N_DIR_ENCODE = None
        self.NORMALIZE = False
        self.DEPTH_REF = False
        self.FC_PARAMS = {
            'nf': 256,
            'n_layers': 8,
            'skips': [],
            'activation': 'relu'
        }
        self.SAMPLE_PARAMS = {
            'spherical': True,
            'depth_range': (1, 50),
            'n_samples': 32,
            'perturb_sample': True,
            'lindisp': True,
            'inverse_r': True,
        }
        # 'enable' decides whether to_id/print include these values; the dict is
        # passed as ``fine_params`` to the CascadeNerf-based nets in create_net
        self.NERF_FINE_NET_PARAMS = {
            'enable': False,
            'nf': 256,
            'n_layers': 8,
            'additional_samples': 64
        }

    @staticmethod
    def _module_name_from_path(path):
        """Return the dotted module name for a config file path.

        ``'configs/foo.py'`` and ``'./configs/foo.py'`` both map to
        ``'configs.foo'``. The path is normalized first so redundant ``./``
        components and OS-specific separators do not end up in the name.
        """
        stem = os.path.normpath(os.path.splitext(path)[0])
        return stem.replace(os.sep, '.').replace('/', '.')

    def load(self, path):
        """Import the config module at ``path`` and apply its ``update_config``.

        ``path`` is converted to a dotted module name, so it must be importable
        from the current import path. ``self.name`` becomes the module's last
        name component.
        """
        module_name = self._module_name_from_path(path)
        config_module = importlib.import_module(module_name)
        config_module.update_config(self)
        self.name = module_name.split('.')[-1]

    def load_by_name(self, name):
        """Import ``configs.<name>`` and apply its ``update_config``."""
        config_module = importlib.import_module('configs.' + name)
        config_module.update_config(self)
        self.name = name

    def to_id(self):
        """Serialize this config to a compact id string.

        Format (optional parts in brackets)::

            <name>@<NET_TYPE>-<color>_e<N_ENCODE_DIM>[_ed<N_DIR_ENCODE>]
            _fc<nf>x<n_layers>[_skip<i,j,...>][_*<activation>]
            _d<near>-<far>_s<n_samples>[_ffc<nf>x<n_layers>_fs<additional>]
            [_~<p|l|i>][_+<n|d>]

        ``~`` lists disabled sampling flags (p: perturb_sample, l: lindisp,
        i: inverse_r); ``+`` lists enabled flags (n: NORMALIZE, d: DEPTH_REF).
        Depth bounds are written with two decimals.
        """
        fc = self.FC_PARAMS
        sample = self.SAMPLE_PARAMS
        fine = self.NERF_FINE_NET_PARAMS

        net_type_id = f"{self.NET_TYPE}-{color.to_str(self.COLOR)}"
        encode_id = f"_e{self.N_ENCODE_DIM}"
        dir_encode_id = f"_ed{self.N_DIR_ENCODE}" if self.N_DIR_ENCODE else ''
        fc_id = f"_fc{fc['nf']}x{fc['n_layers']}"
        skip_id = "_skip%s" % ','.join('%d' % val for val in fc['skips']) \
            if len(fc['skips']) > 0 else ""
        act_id = f"_*{fc['activation']}" if fc['activation'] != 'relu' else ''
        depth_id = "_d%.2f-%.2f" % (sample['depth_range'][0], sample['depth_range'][1])
        samples_id = f"_s{sample['n_samples']}"
        ffc_id = f"_ffc{fine['nf']}x{fine['n_layers']}"
        fsamples_id = f"_fs{fine['additional_samples']}"
        fine_id = f"{ffc_id}{fsamples_id}" if fine['enable'] else ''

        neg_flags = ''.join([
            'p' if not sample['perturb_sample'] else '',
            'l' if not sample['lindisp'] else '',
            'i' if not sample['inverse_r'] else '',
        ])
        neg_flags = '_~' + neg_flags if neg_flags != '' else ''
        pos_flags = ''.join([
            'n' if self.NORMALIZE else '',
            'd' if self.DEPTH_REF else '',
        ])
        pos_flags = '_+' + pos_flags if pos_flags != '' else ''

        return (f"{self.name}@{net_type_id}{encode_id}{dir_encode_id}"
                f"{fc_id}{skip_id}{act_id}"
                f"{depth_id}{samples_id}"
                f"{fine_id}"
                f"{neg_flags}{pos_flags}")

    def from_id(self, id: str):
        """Update this config in place from an id string produced by ``to_id``.

        If ``id`` contains exactly one ``'@'``, the part before it becomes
        ``self.name``. The rest is split on ``'_'`` and every segment is matched
        against the prefixes below, in order (``ffc`` before ``fc``, ``ed``
        before ``e``, ``skip`` before ``s``). Fields missing from the id keep
        their current values. Unrecognized segments are ignored, except the
        first one, which is finally parsed as ``<NET_TYPE>-<color>``.

        Raises:
            ValueError: on malformed ``ffc``/``fc``/``skip``/``ed``/``e``
                segments, or a first segment that is not ``<type>-<color>``.
        """
        name_and_spec = id.split('@')
        if len(name_and_spec) == 2:
            self.name = name_and_spec[0]
        segs = name_and_spec[-1].split('_')
        for i, seg in enumerate(segs):
            if seg.startswith('ffc'):  # Fine network size: ffc<nf>x<n_layers>
                self.NERF_FINE_NET_PARAMS['nf'], self.NERF_FINE_NET_PARAMS['n_layers'] = (
                    int(v) for v in seg[3:].split('x'))
                self.NERF_FINE_NET_PARAMS['enable'] = True
                continue
            if seg.startswith('fs'):  # Fine network additional samples: fs<n>
                try:
                    n_fine_samples = int(seg[2:])
                except ValueError:
                    pass  # Not a sample count; let the remaining rules try it
                else:
                    self.NERF_FINE_NET_PARAMS['additional_samples'] = n_fine_samples
                    self.NERF_FINE_NET_PARAMS['enable'] = True
                    continue
            if seg.startswith('fc'):  # Full-connected network size: fc<nf>x<n_layers>
                self.FC_PARAMS['nf'], self.FC_PARAMS['n_layers'] = (
                    int(v) for v in seg[2:].split('x'))
                continue
            if seg.startswith('skip'):  # Skip connections: skip<i,j,...>
                self.FC_PARAMS['skips'] = [int(v) for v in seg[4:].split(',')]
                continue
            if seg.startswith('*'):  # Activation: *<name>
                self.FC_PARAMS['activation'] = seg[1:]
                continue
            if seg.startswith('ed'):  # Direction encoding: ed<n>, 0 means none
                self.N_DIR_ENCODE = int(seg[2:])
                if self.N_DIR_ENCODE == 0:
                    self.N_DIR_ENCODE = None
                continue
            if seg.startswith('e'):  # Coordinate encoding: e<n>
                self.N_ENCODE_DIM = int(seg[1:])
                continue
            if seg.startswith('d'):  # Depth range: d<near>-<far>
                try:
                    depth_range = tuple(float(v) for v in seg[1:].split('-'))
                except ValueError:
                    pass  # e.g. a net type such as 'dnerf-...' in segment 0
                else:
                    self.SAMPLE_PARAMS['depth_range'] = depth_range
                    continue
            if seg.startswith('s'):  # Number of samples: s<n>
                try:
                    n_samples = int(seg[1:])
                except ValueError:
                    pass  # e.g. a net type such as 'snerffast-...' in segment 0
                else:
                    self.SAMPLE_PARAMS['n_samples'] = n_samples
                    continue
            if seg.startswith(('~', '+')):  # Flags: '~' clears, '+' sets
                flag_value = seg.startswith('+')
                if 'p' in seg:
                    self.SAMPLE_PARAMS['perturb_sample'] = flag_value
                if 'l' in seg:
                    self.SAMPLE_PARAMS['lindisp'] = flag_value
                if 'i' in seg:
                    self.SAMPLE_PARAMS['inverse_r'] = flag_value
                if 'n' in seg:
                    self.NORMALIZE = flag_value
                if 'd' in seg:
                    self.DEPTH_REF = flag_value
                continue
            if i == 0:  # Net type and color space: <NET_TYPE>-<color>
                self.NET_TYPE, color_str = seg.split('-')
                self.COLOR = color.from_str(color_str)

    def print(self):
        """Print a human-readable summary of this config to stdout."""
        print('==== Config %s ====' % self.name)
        print('Net type: ', self.NET_TYPE)
        print('Encode dim: ', self.N_ENCODE_DIM)
        print('Normalize: ', self.NORMALIZE)
        print('Train with depth: ', self.DEPTH_REF)
        print('Support direction: ', False if self.N_DIR_ENCODE is None
              else f'encode to {self.N_DIR_ENCODE}')
        print('Full-connected network parameters:', self.FC_PARAMS)
        print('Sample parameters', self.SAMPLE_PARAMS)
        if self.NERF_FINE_NET_PARAMS['enable']:
            print('NeRF fine network parameters', self.NERF_FINE_NET_PARAMS)
        print('==========================')

    @staticmethod
    def _int_suffix(net_type, prefix, default):
        """Return the integer written directly after ``prefix`` in ``net_type``.

        ``_int_suffix('cnerf64', 'cnerf', 128) == 64``; ``default`` is returned
        when nothing follows the prefix. Raises ``ValueError`` if the suffix is
        not an integer.
        """
        suffix = net_type[len(prefix):]
        return int(suffix) if suffix else default

    def create_net(self):
        """Instantiate the network selected by ``self.NET_TYPE``.

        Exact names ('msl', 'mslray', 'mslfast', 'msl2fast', 'nerf', 'nerf2',
        'nerfbg', 'bgnet') choose a fixed architecture. Prefixed names carry a
        parameter written directly after the prefix, e.g. 'cnerf64' (n_bins=64),
        'dnerf64' / 'dnerfa64' (n_bins=64), 'nnerf4' (n_nets=4), 'snerffast2'
        (n_parts=2), 'nmsl4' (n_nets=4), 'oracletanh' (out_activation='tanh').
        Longer prefixes are tested before the shorter prefixes they extend
        ('dnerfa' before 'dnerf', 'snerffastx' before 'snerffast').

        Raises:
            ValueError: if the net type is unknown, if a numeric suffix is not
                an integer, or (for 'nmsl') if ``SAMPLE_PARAMS['n_samples']`` is
                not divisible by the number of nets.
        """
        net_type = self.NET_TYPE
        if net_type == 'msl':
            return MslNet(fc_params=self.FC_PARAMS,
                          sampler_params=self.SAMPLE_PARAMS,
                          normalize_coord=self.NORMALIZE,
                          c=self.COLOR,
                          encode_to_dim=self.N_ENCODE_DIM)
        if net_type == 'mslray':
            return MslRay(fc_params=self.FC_PARAMS,
                          sampler_params=self.SAMPLE_PARAMS,
                          normalize_coord=self.NORMALIZE,
                          c=self.COLOR,
                          encode_to_dim=self.N_ENCODE_DIM)
        if net_type == 'mslfast':
            return MslFast(fc_params=self.FC_PARAMS,
                           sampler_params=self.SAMPLE_PARAMS,
                           normalize_coord=self.NORMALIZE,
                           c=self.COLOR,
                           encode_to_dim=self.N_ENCODE_DIM)
        if net_type == 'msl2fast':
            return MslFast(fc_params=self.FC_PARAMS,
                           sampler_params=self.SAMPLE_PARAMS,
                           normalize_coord=self.NORMALIZE,
                           c=self.COLOR,
                           encode_to_dim=self.N_ENCODE_DIM,
                           include_r=True)
        if net_type == 'nerf':
            return CascadeNerf(fc_params=self.FC_PARAMS,
                               sampler_params=self.SAMPLE_PARAMS,
                               fine_params=self.NERF_FINE_NET_PARAMS,
                               normalize_coord=self.NORMALIZE,
                               c=self.COLOR,
                               coord_encode=self.N_ENCODE_DIM,
                               dir_encode=self.N_DIR_ENCODE)
        if net_type == 'nerf2':
            return CascadeNerf2(fc_params=self.FC_PARAMS,
                                sampler_params=self.SAMPLE_PARAMS,
                                normalize_coord=self.NORMALIZE,
                                c=self.COLOR,
                                coord_encode=self.N_ENCODE_DIM,
                                dir_encode=self.N_DIR_ENCODE)
        if net_type == 'nerfbg':
            return CascadeNerf(fc_params=self.FC_PARAMS,
                               sampler_params=self.SAMPLE_PARAMS,
                               fine_params=self.NERF_FINE_NET_PARAMS,
                               normalize_coord=self.NORMALIZE,
                               c=self.COLOR,
                               coord_encode=self.N_ENCODE_DIM,
                               dir_encode=self.N_DIR_ENCODE,
                               bg_layer=True)
        if net_type == 'bgnet':
            return BgNet(fc_params=self.FC_PARAMS,
                         encode=self.N_ENCODE_DIM,
                         c=self.COLOR)
        if net_type.startswith('oracle'):
            # Suffix is the output activation name, e.g. 'oracletanh'
            return Oracle(fc_params=self.FC_PARAMS,
                          sampler_params=self.SAMPLE_PARAMS,
                          normalize_coord=self.NORMALIZE,
                          coord_encode=self.N_ENCODE_DIM,
                          out_activation=net_type[6:] if len(net_type) > 6 else 'sigmoid')
        if net_type.startswith('cnerf'):
            return CNerf(fc_params=self.FC_PARAMS,
                         sampler_params=self.SAMPLE_PARAMS,
                         c=self.COLOR,
                         coord_encode=self.N_ENCODE_DIM,
                         n_bins=self._int_suffix(net_type, 'cnerf', 128))
        if net_type.startswith('dnerfa'):
            # The bin count starts right after the 6-char prefix; slicing from
            # index 7 would drop its first digit ('dnerfa64' -> 4 bins).
            return NerfDepth(fc_params=self.FC_PARAMS,
                             sampler_params=self.SAMPLE_PARAMS,
                             c=self.COLOR,
                             coord_encode=self.N_ENCODE_DIM,
                             n_bins=self._int_suffix(net_type, 'dnerfa', 128),
                             include_neighbor_bins=False)
        if net_type.startswith('dnerf'):
            # Same as above for the 5-char prefix 'dnerf'
            return NerfDepth(fc_params=self.FC_PARAMS,
                             sampler_params=self.SAMPLE_PARAMS,
                             c=self.COLOR,
                             coord_encode=self.N_ENCODE_DIM,
                             n_bins=self._int_suffix(net_type, 'dnerf', 128))
        if net_type.startswith('nnerf'):
            return NNerf(fc_params=self.FC_PARAMS,
                         sampler_params=self.SAMPLE_PARAMS,
                         n_nets=self._int_suffix(net_type, 'nnerf', 1),
                         normalize_coord=self.NORMALIZE,
                         c=self.COLOR,
                         coord_encode=self.N_ENCODE_DIM,
                         dir_encode=self.N_DIR_ENCODE)
        if net_type.startswith('snerffastx'):
            return SnerfFast(fc_params=self.FC_PARAMS,
                             sampler_params=self.SAMPLE_PARAMS,
                             n_parts=self._int_suffix(net_type, 'snerffastx', 1),
                             normalize_coord=self.NORMALIZE,
                             c=self.COLOR,
                             coord_encode=self.N_ENCODE_DIM,
                             dir_encode=self.N_DIR_ENCODE,
                             multiple_net=False)
        if net_type.startswith('snerffast'):
            return SnerfFast(fc_params=self.FC_PARAMS,
                             sampler_params=self.SAMPLE_PARAMS,
                             n_parts=self._int_suffix(net_type, 'snerffast', 1),
                             normalize_coord=self.NORMALIZE,
                             c=self.COLOR,
                             coord_encode=self.N_ENCODE_DIM,
                             dir_encode=self.N_DIR_ENCODE)
        if net_type.startswith('nmsl'):
            n_nets = self._int_suffix(net_type, 'nmsl', 2)
            # Each sub-net handles an equal share of the samples (assumed from
            # the divisibility requirement)
            if self.SAMPLE_PARAMS['n_samples'] % n_nets != 0:
                raise ValueError('n_samples should be divisible by n_nets')
            return NewMslNet(fc_params=self.FC_PARAMS,
                             sampler_params=self.SAMPLE_PARAMS,
                             normalize_coord=self.NORMALIZE,
                             n_nets=n_nets,
                             c=self.COLOR,
                             encode_to_dim=self.N_ENCODE_DIM)
        raise ValueError('Invalid net type')