spherical_view_syn.py 14.4 KB
Newer Older
BobYeah's avatar
sync    
BobYeah committed
1
2
import os
import importlib
Nianchen Deng's avatar
sync    
Nianchen Deng committed
3
4
from utils.constants import *
from utils import color
Nianchen Deng's avatar
Nianchen Deng committed
5
6
from nets.msl_net import MslNet
from nets.msl_net_new import NewMslNet
Nianchen Deng's avatar
sync    
Nianchen Deng committed
7
8
9
from nets.msl_ray import MslRay
from nets.msl_fast import MslFast
from nets.snerf_fast import SnerfFast
10
from nets.snerf_fast_new import SnerfFastNew
Nianchen Deng's avatar
sync    
Nianchen Deng committed
11
12
13
14
15
16
17
from nets.cnerf_v3 import CNerf
from nets.nerf import CascadeNerf
from nets.nerf import CascadeNerf2
from nets.nnerf import NNerf
from nets.nerf_depth import NerfDepth
from nets.bg_net import BgNet
from nets.oracle import Oracle
Nianchen Deng's avatar
sync    
Nianchen Deng committed
18

BobYeah's avatar
sync    
BobYeah committed
19
20
21
22
23
24

class SphericalViewSynConfig(object):
    """Hyper-parameters of a spherical view synthesis network.

    Defaults are set in ``__init__``. They can be overridden by a config module
    (``load`` / ``load_by_name``, which call the module's
    ``update_config(config)``), or decoded from an id string produced by
    ``to_id`` (``from_id``). ``create_net`` builds the network selected by
    ``NET_TYPE``.
    """

    def __init__(self):
        self.name = 'default'

        self.COLOR = color.RGB

        # Net parameters
        self.NET_TYPE = 'msl'
        self.N_ENCODE_DIM = 10
        # None means no direction encoding ("Support direction: False" in print)
        self.N_DIR_ENCODE = None
        self.NORMALIZE = False
        self.DEPTH_REF = False
        self.FC_PARAMS = {
            'nf': 256,
            'n_layers': 8,
            'skips': [],
            'activation': 'relu'
        }
        self.SAMPLE_PARAMS = {
            'spherical': True,
            'depth_range': (1, 50),
            'n_samples': 32,
            'perturb_sample': True,
            'lindisp': True,
            'inverse_r': True,
        }
        self.NERF_FINE_NET_PARAMS = {
            'enable': False,
            'nf': 256,
            'n_layers': 8,
            'additional_samples': 64
        }

    @staticmethod
    def _module_name_from_path(path: str) -> str:
        """Convert a config file path into a dotted module name.

        ``os.path.normpath`` drops a leading ``./`` and collapses duplicate
        separators. Otherwise they would yield a leading or doubled ``.``, and
        ``importlib`` would treat the name as a (failing) relative import.
        Both ``os.sep`` and ``/`` are mapped to ``.``.
        """
        module_path = os.path.splitext(os.path.normpath(path))[0]
        return module_path.replace(os.sep, '.').replace('/', '.')

    def load(self, path):
        """Apply the config module stored at file ``path`` to ``self``.

        ``path`` (e.g. ``configs/foo.py``) is turned into a dotted module name,
        which must be importable. The module's ``update_config(self)`` is
        called, and ``self.name`` becomes the last component of that module
        name.
        """
        module_name = self._module_name_from_path(path)
        config_module = importlib.import_module(module_name)
        config_module.update_config(self)
        self.name = module_name.split('.')[-1]

    def load_by_name(self, name):
        """Apply the config module ``configs.<name>`` to ``self``.

        The module's ``update_config(self)`` is called, and ``self.name`` is
        set to ``name``.
        """
        config_module = importlib.import_module(
            'configs.' + name)
        config_module.update_config(self)
        self.name = name

    def to_id(self):
        """Encode the configuration as an id string (decodable by ``from_id``).

        Layout: ``<name>@<net>-<color>_e<N>[_ed<N>]_fc<nf>x<layers>[_skip<i,..>]
        [_*<activation>]_d<near>-<far>_s<n>[_ffc<nf>x<layers>_fs<n>][_~<flags>]
        [_+<flags>]``. Optional parts are omitted when they hold their default.
        The depth range is written with 2 decimals, so finer precision is lost.
        """
        net_type_id = f"{self.NET_TYPE}-{color.to_str(self.COLOR)}"
        encode_id = f"_e{self.N_ENCODE_DIM}"
        dir_encode_id = f"_ed{self.N_DIR_ENCODE}" if self.N_DIR_ENCODE else ''
        fc_id = f"_fc{self.FC_PARAMS['nf']}x{self.FC_PARAMS['n_layers']}"
        skip_id = "_skip%s" % ','.join(['%d' % val for val in self.FC_PARAMS['skips']]) \
            if len(self.FC_PARAMS['skips']) > 0 else ""
        act_id = f"_*{self.FC_PARAMS['activation']}" if self.FC_PARAMS['activation'] != 'relu' else ''
        depth_id = "_d%.2f-%.2f" % (self.SAMPLE_PARAMS['depth_range'][0],
                                    self.SAMPLE_PARAMS['depth_range'][1])
        samples_id = f"_s{self.SAMPLE_PARAMS['n_samples']}"
        ffc_id = f"_ffc{self.NERF_FINE_NET_PARAMS['nf']}x{self.NERF_FINE_NET_PARAMS['n_layers']}"
        fsamples_id = f"_fs{self.NERF_FINE_NET_PARAMS['additional_samples']}"
        fine_id = f"{ffc_id}{fsamples_id}" if self.NERF_FINE_NET_PARAMS['enable'] else ''
        # Flags are only written when they differ from the __init__ defaults:
        # '~' lists sampler options switched off, '+' lists options switched on.
        neg_flags = '%s%s%s' % (
            'p' if not self.SAMPLE_PARAMS['perturb_sample'] else '',
            'l' if not self.SAMPLE_PARAMS['lindisp'] else '',
            'i' if not self.SAMPLE_PARAMS['inverse_r'] else ''
        )
        neg_flags = '_~' + neg_flags if neg_flags != '' else ''
        pos_flags = '%s%s' % (
            'n' if self.NORMALIZE else '',
            'd' if self.DEPTH_REF else ''
        )
        pos_flags = '_+' + pos_flags if pos_flags != '' else ''
        return "%s@%s%s%s%s%s%s%s%s%s%s%s" % (self.name, net_type_id, encode_id, dir_encode_id,
                                              fc_id, skip_id, act_id,
                                              depth_id, samples_id,
                                              fine_id,
                                              neg_flags, pos_flags)

    def from_id(self, id: str):
        """Decode an id string produced by ``to_id`` and update ``self``.

        The text before ``@`` (if present) becomes ``self.name``. The rest is
        split on ``_``, and each segment is recognised by its prefix. Only the
        fields that appear in ``id`` are modified.
        """
        id_splited = id.split('@')
        if len(id_splited) == 2:
            self.name = id_splited[0]
        segs = id_splited[-1].split('_')
        for i, seg in enumerate(segs):
            # Check order matters: 'ffc' must be tested before 'fc', and 'ed'
            # before 'e'.
            if seg.startswith('ffc'):  # Fine network: <nf>x<n_layers>
                self.NERF_FINE_NET_PARAMS['nf'], self.NERF_FINE_NET_PARAMS['n_layers'] = (
                    int(tok) for tok in seg[3:].split('x'))
                self.NERF_FINE_NET_PARAMS['enable'] = True
                continue
            if seg.startswith('fs'):  # Fine network: number of additional samples
                try:
                    self.NERF_FINE_NET_PARAMS['additional_samples'] = int(seg[2:])
                    self.NERF_FINE_NET_PARAMS['enable'] = True
                    continue
                except ValueError:
                    pass
            if seg.startswith('fc'):  # Full-connected network: <nf>x<n_layers>
                self.FC_PARAMS['nf'], self.FC_PARAMS['n_layers'] = (
                    int(tok) for tok in seg[2:].split('x'))
                continue
            if seg.startswith('skip'):  # Skip connections: comma-separated layer indices
                self.FC_PARAMS['skips'] = [int(tok)
                                           for tok in seg[4:].split(',')]
                continue
            if seg.startswith('*'):  # Activation
                self.FC_PARAMS['activation'] = seg[1:]
                continue
            if seg.startswith('ed'):  # Encode direction (0 means disabled)
                self.N_DIR_ENCODE = int(seg[2:])
                if self.N_DIR_ENCODE == 0:
                    self.N_DIR_ENCODE = None
                continue
            if seg.startswith('e'):  # Encode
                self.N_ENCODE_DIM = int(seg[1:])
                continue
            # 'd' and 's' segments are try-parsed because the leading net-type
            # segment may also start with these letters (e.g. 'dnerf-...',
            # 'snerffast-...'); such segments fall through to the i == 0 case.
            if seg.startswith('d'):  # Depth range: <near>-<far>
                try:
                    self.SAMPLE_PARAMS['depth_range'] = tuple(
                        float(tok) for tok in seg[1:].split('-'))
                    continue
                except ValueError:
                    pass
            if seg.startswith('s'):  # Number of samples
                try:
                    self.SAMPLE_PARAMS['n_samples'] = int(seg[1:])
                    continue
                except ValueError:
                    pass
            if seg.startswith('~'):  # Negative flags: switch options off
                if seg.find('p') >= 0:
                    self.SAMPLE_PARAMS['perturb_sample'] = False
                if seg.find('l') >= 0:
                    self.SAMPLE_PARAMS['lindisp'] = False
                if seg.find('i') >= 0:
                    self.SAMPLE_PARAMS['inverse_r'] = False
                if seg.find('n') >= 0:
                    self.NORMALIZE = False
                if seg.find('d') >= 0:
                    self.DEPTH_REF = False
                continue
            if seg.startswith('+'):  # Positive flags: switch options on
                if seg.find('p') >= 0:
                    self.SAMPLE_PARAMS['perturb_sample'] = True
                if seg.find('l') >= 0:
                    self.SAMPLE_PARAMS['lindisp'] = True
                if seg.find('i') >= 0:
                    self.SAMPLE_PARAMS['inverse_r'] = True
                if seg.find('n') >= 0:
                    self.NORMALIZE = True
                if seg.find('d') >= 0:
                    self.DEPTH_REF = True
                continue
            if i == 0:  # NetType: <net_type>-<color>
                self.NET_TYPE, color_str = seg.split('-')
                self.COLOR = color.from_str(color_str)

    def print(self):
        """Print a human-readable summary of the configuration to stdout."""
        print('==== Config %s ====' % self.name)
        print('Net type: ', self.NET_TYPE)
        print('Encode dim: ', self.N_ENCODE_DIM)
        print('Normalize: ', self.NORMALIZE)
        print('Train with depth: ', self.DEPTH_REF)
        print('Support direction: ', False if self.N_DIR_ENCODE is None
              else f'encode to {self.N_DIR_ENCODE}')
        print('Full-connected network parameters:', self.FC_PARAMS)
        print('Sample parameters', self.SAMPLE_PARAMS)
        if self.NERF_FINE_NET_PARAMS['enable']:
            print('NeRF fine network parameters', self.NERF_FINE_NET_PARAMS)
        print('==========================')

    @staticmethod
    def _net_type_suffix_int(net_type: str, prefix: str, default: int) -> int:
        """Parse the integer written right after ``prefix`` in ``net_type``.

        For example, ``('cnerf64', 'cnerf', 128)`` returns 64 and
        ``('cnerf', 'cnerf', 128)`` returns 128. Raises ``ValueError`` if the
        suffix is present but is not an integer.
        """
        suffix = net_type[len(prefix):]
        return int(suffix) if suffix else default

    def create_net(self):
        """Instantiate the network selected by ``NET_TYPE``.

        Several net types carry an integer suffix (e.g. ``cnerf64``, ``nmsl4``)
        that is parsed starting exactly after the type prefix. Prefix checks
        are ordered so that longer names sharing a prefix (``dnerfa`` vs
        ``dnerf``; ``snerffastx`` / ``snerffastnew`` vs ``snerffast``) match
        first.

        Raises:
            ValueError: if ``NET_TYPE`` is unknown, or for ``nmsl*`` when
                ``n_samples`` is not divisible by the number of nets.
        """
        if self.NET_TYPE == 'msl':
            return MslNet(fc_params=self.FC_PARAMS,
                          sampler_params=self.SAMPLE_PARAMS,
                          normalize_coord=self.NORMALIZE,
                          c=self.COLOR,
                          encode_to_dim=self.N_ENCODE_DIM)
        if self.NET_TYPE == 'mslray':
            return MslRay(fc_params=self.FC_PARAMS,
                          sampler_params=self.SAMPLE_PARAMS,
                          normalize_coord=self.NORMALIZE,
                          c=self.COLOR,
                          encode_to_dim=self.N_ENCODE_DIM)
        if self.NET_TYPE == 'mslfast':
            return MslFast(fc_params=self.FC_PARAMS,
                           sampler_params=self.SAMPLE_PARAMS,
                           normalize_coord=self.NORMALIZE,
                           c=self.COLOR,
                           encode_to_dim=self.N_ENCODE_DIM)
        if self.NET_TYPE == 'msl2fast':
            return MslFast(fc_params=self.FC_PARAMS,
                           sampler_params=self.SAMPLE_PARAMS,
                           normalize_coord=self.NORMALIZE,
                           c=self.COLOR,
                           encode_to_dim=self.N_ENCODE_DIM,
                           include_r=True)
        if self.NET_TYPE == 'nerf':
            return CascadeNerf(fc_params=self.FC_PARAMS,
                               sampler_params=self.SAMPLE_PARAMS,
                               fine_params=self.NERF_FINE_NET_PARAMS,
                               normalize_coord=self.NORMALIZE,
                               c=self.COLOR,
                               coord_encode=self.N_ENCODE_DIM,
                               dir_encode=self.N_DIR_ENCODE)
        if self.NET_TYPE == 'nerf2':
            return CascadeNerf2(fc_params=self.FC_PARAMS,
                                sampler_params=self.SAMPLE_PARAMS,
                                normalize_coord=self.NORMALIZE,
                                c=self.COLOR,
                                coord_encode=self.N_ENCODE_DIM,
                                dir_encode=self.N_DIR_ENCODE)
        if self.NET_TYPE == 'nerfbg':
            return CascadeNerf(fc_params=self.FC_PARAMS,
                               sampler_params=self.SAMPLE_PARAMS,
                               fine_params=self.NERF_FINE_NET_PARAMS,
                               normalize_coord=self.NORMALIZE,
                               c=self.COLOR,
                               coord_encode=self.N_ENCODE_DIM,
                               dir_encode=self.N_DIR_ENCODE,
                               bg_layer=True)
        if self.NET_TYPE == 'bgnet':
            return BgNet(fc_params=self.FC_PARAMS,
                         encode=self.N_ENCODE_DIM,
                         c=self.COLOR)
        if self.NET_TYPE.startswith('oracle'):
            # Suffix after 'oracle' names the output activation
            return Oracle(fc_params=self.FC_PARAMS,
                          sampler_params=self.SAMPLE_PARAMS,
                          normalize_coord=self.NORMALIZE,
                          coord_encode=self.N_ENCODE_DIM,
                          out_activation=self.NET_TYPE[6:] if len(self.NET_TYPE) > 6 else 'sigmoid')
        if self.NET_TYPE.startswith('cnerf'):
            return CNerf(fc_params=self.FC_PARAMS,
                         sampler_params=self.SAMPLE_PARAMS,
                         c=self.COLOR,
                         coord_encode=self.N_ENCODE_DIM,
                         n_bins=self._net_type_suffix_int(self.NET_TYPE, 'cnerf', 128))
        if self.NET_TYPE.startswith('dnerfa'):
            # Must precede the 'dnerf' check, which would also match.
            return NerfDepth(fc_params=self.FC_PARAMS,
                             sampler_params=self.SAMPLE_PARAMS,
                             c=self.COLOR,
                             coord_encode=self.N_ENCODE_DIM,
                             n_bins=self._net_type_suffix_int(self.NET_TYPE, 'dnerfa', 128),
                             include_neighbor_bins=False)
        if self.NET_TYPE.startswith('dnerf'):
            return NerfDepth(fc_params=self.FC_PARAMS,
                             sampler_params=self.SAMPLE_PARAMS,
                             c=self.COLOR,
                             coord_encode=self.N_ENCODE_DIM,
                             n_bins=self._net_type_suffix_int(self.NET_TYPE, 'dnerf', 128))
        if self.NET_TYPE.startswith('nnerf'):
            return NNerf(fc_params=self.FC_PARAMS,
                         sampler_params=self.SAMPLE_PARAMS,
                         n_nets=self._net_type_suffix_int(self.NET_TYPE, 'nnerf', 1),
                         normalize_coord=self.NORMALIZE,
                         c=self.COLOR,
                         coord_encode=self.N_ENCODE_DIM,
                         dir_encode=self.N_DIR_ENCODE)
        if self.NET_TYPE.startswith('snerffastx'):
            return SnerfFast(fc_params=self.FC_PARAMS,
                             sampler_params=self.SAMPLE_PARAMS,
                             n_parts=self._net_type_suffix_int(self.NET_TYPE, 'snerffastx', 1),
                             normalize_coord=self.NORMALIZE,
                             c=self.COLOR,
                             coord_encode=self.N_ENCODE_DIM,
                             dir_encode=self.N_DIR_ENCODE,
                             multiple_net=False)
        if self.NET_TYPE.startswith('snerffastnew'):
            return SnerfFastNew(fc_params=self.FC_PARAMS,
                                sampler_params=self.SAMPLE_PARAMS,
                                n_parts=self._net_type_suffix_int(self.NET_TYPE, 'snerffastnew', 1),
                                normalize_coord=self.NORMALIZE,
                                c=self.COLOR,
                                coord_encode=self.N_ENCODE_DIM)
        if self.NET_TYPE.startswith('snerffast'):
            return SnerfFast(fc_params=self.FC_PARAMS,
                             sampler_params=self.SAMPLE_PARAMS,
                             n_parts=self._net_type_suffix_int(self.NET_TYPE, 'snerffast', 1),
                             normalize_coord=self.NORMALIZE,
                             c=self.COLOR,
                             coord_encode=self.N_ENCODE_DIM,
                             dir_encode=self.N_DIR_ENCODE)
        if self.NET_TYPE.startswith('nmsl'):
            n_nets = self._net_type_suffix_int(self.NET_TYPE, 'nmsl', 2)
            if self.SAMPLE_PARAMS['n_samples'] % n_nets != 0:
                raise ValueError('n_samples should be divisible by n_nets')
            return NewMslNet(fc_params=self.FC_PARAMS,
                             sampler_params=self.SAMPLE_PARAMS,
                             normalize_coord=self.NORMALIZE,
                             n_nets=n_nets,
                             c=self.COLOR,
                             encode_to_dim=self.N_ENCODE_DIM)
        raise ValueError('Invalid net type')