# spherical_view_syn.py
import os
import importlib
from os.path import join

from ..my import color_mode


class SphericalViewSynConfig(object):
    """Container for the settings of a spherical view synthesis network.

    The settings can be overridden by a config module (``load`` /
    ``load_by_name``) and can be serialized to, and restored from, a compact
    string id (``to_id`` / ``from_id``).
    """

    def __init__(self):
        """Initialize every setting with its default value."""
        self.name = 'default'

        self.COLOR = color_mode.RGB

        # Net parameters
        self.NET_TYPE = 'msl'
        self.N_ENCODE_DIM = 10
        self.FC_PARAMS = {
            'nf': 256,
            'n_layers': 8,
            'skips': []
        }
        self.SAMPLE_PARAMS = {
            'spherical': True,
            'depth_range': (1, 50),
            'n_samples': 32,
            'perturb_sample': True,
            'lindisp': True,
            'inverse_r': True,
        }

    def load(self, path):
        """Apply the config module located at ``path``.

        The extension of ``path`` is stripped and every ``/`` is turned into
        ``.`` to form a module name, which is imported from the
        ``deep_view_syn`` package. The module's ``update_config(config)``
        function is then called with this instance, and ``name`` is set to
        the last component of the module name.
        """
        module_name = os.path.splitext(path)[0].replace('/', '.')
        config_module = importlib.import_module(
            'deep_view_syn.' + module_name)
        config_module.update_config(self)
        self.name = module_name.split('.')[-1]

    def load_by_name(self, name):
        """Apply the config module ``deep_view_syn.configs.<name>``.

        The module's ``update_config(config)`` function is called with this
        instance and ``name`` is stored as the config name.
        """
        config_module = importlib.import_module(
            'deep_view_syn.configs.' + name)
        config_module.update_config(self)
        self.name = name

    def to_id(self):
        """Serialize the settings to a string id.

        Layout (bracketed parts are emitted only when needed)::

            <name>@<net>-<color>_e<enc>_fc<nf>x<layers>[_skip<a,b,..>]
                _d<lo>-<hi>_s<samples>[_~<flags>]

        ``<flags>`` lists the sample flags that are switched OFF: ``p`` for
        ``perturb_sample``, ``l`` for ``lindisp``, ``i`` for ``inverse_r``.
        The depth bounds are written with ``%d``, i.e. as integers.
        """
        fc_params = self.FC_PARAMS
        sample_params = self.SAMPLE_PARAMS

        net_type_id = "%s-%s" % (self.NET_TYPE, color_mode.to_str(self.COLOR))
        encode_id = "_e%d" % self.N_ENCODE_DIM
        fc_id = "_fc%dx%d" % (fc_params['nf'], fc_params['n_layers'])
        if len(fc_params['skips']) > 0:
            skip_id = "_skip%s" % ','.join(
                '%d' % val for val in fc_params['skips'])
        else:
            skip_id = ""
        depth_id = "_d%d-%d" % (sample_params['depth_range'][0],
                                sample_params['depth_range'][1])
        samples_id = '_s%d' % sample_params['n_samples']

        # Only disabled flags are recorded, so an all-default config has no
        # trailing "_~..." part at all.
        neg_flags = ''.join(
            letter
            for letter, key in (('p', 'perturb_sample'),
                                ('l', 'lindisp'),
                                ('i', 'inverse_r'))
            if not sample_params[key]
        )
        if neg_flags != '':
            neg_flags = '_~' + neg_flags

        return "%s@%s%s%s%s%s%s%s" % (
            self.name, net_type_id, encode_id, fc_id, skip_id, depth_id,
            samples_id, neg_flags)

    def _apply_param_segment(self, seg):
        """Parse one ``_``-separated id segment other than the net type.

        Returns True when the segment's prefix was recognized and applied,
        False when the prefix is unknown (the instance is left untouched).
        Raises ValueError when the prefix is known but the payload is
        malformed; nothing is assigned in that case.
        """
        if seg.startswith('e'):  # Encode
            self.N_ENCODE_DIM = int(seg[1:])
        elif seg.startswith('fc'):  # Full-connected network parameters
            nf, n_layers = (int(val) for val in seg[2:].split('x'))
            self.FC_PARAMS['nf'] = nf
            self.FC_PARAMS['n_layers'] = n_layers
        elif seg.startswith('skip'):  # Skip connection (must precede 's')
            self.FC_PARAMS['skips'] = [int(val)
                                       for val in seg[4:].split(',')]
        elif seg.startswith('d'):  # Depth range
            self.SAMPLE_PARAMS['depth_range'] = tuple(
                float(val) for val in seg[1:].split('-'))
        elif seg.startswith('s'):  # Number of samples
            self.SAMPLE_PARAMS['n_samples'] = int(seg[1:])
        elif seg.startswith('~'):  # Negative flags
            self.SAMPLE_PARAMS['perturb_sample'] = 'p' not in seg
            self.SAMPLE_PARAMS['lindisp'] = 'l' not in seg
            self.SAMPLE_PARAMS['inverse_r'] = 'i' not in seg
        else:
            return False
        return True

    def from_id(self, id: str):
        """Restore settings from an id string produced by ``to_id``.

        If ``id`` contains exactly one ``@``, the part before it becomes
        ``name``. The part after the last ``@`` is split on ``_``. Settings
        whose segment is absent from ``id`` keep their current value, and
        unrecognized segments after the first one are ignored.

        The first segment is taken as a parameter segment when it parses as
        one; otherwise it must be ``<net>-<color>``. Trying the parameter
        form first keeps ids without a net-type segment working, while the
        fallback lets net types that begin with a parameter prefix (``e``,
        ``fc``, ``skip``, ``d``, ``s``, ``~``) be read back correctly.

        Raises:
            ValueError: if a recognized segment has a malformed payload, or
                the first segment is neither a parameter segment nor of the
                form ``<net>-<color>``.
        """
        id_parts = id.split('@')
        if len(id_parts) == 2:
            self.name = id_parts[0]
        segs = id_parts[-1].split('_')  # always holds at least one element

        head = segs[0]
        try:
            head_is_param = self._apply_param_segment(head)
        except ValueError:
            # Known prefix but unparsable payload, e.g. "snerf-rgb" looks
            # like a sample-count segment: treat it as the net type instead.
            head_is_param = False
        if not head_is_param:  # NetType
            net_type, color_str = head.split('-')
            self.NET_TYPE = net_type
            self.COLOR = color_mode.from_str(color_str)

        for seg in segs[1:]:
            self._apply_param_segment(seg)

    def print(self):
        """Print the config name and the net, encode, FC and sample settings.

        ``COLOR`` is not part of the printed summary.
        """
        print('==== Config %s ====' % self.name)
        print('Net type: ', self.NET_TYPE)
        print('Encode dim: ', self.N_ENCODE_DIM)
        print('Full-connected network parameters:', self.FC_PARAMS)
        print('Sample parameters', self.SAMPLE_PARAMS)
        print('==========================')