import os
import importlib
import re

from utils import color
from nets.snerf_fast import SnerfFast
from nets.snerf import Snerf
from nets.nerf import Nerf
from nets.nerf_depth import NerfDepth
from nets.bg_net import BgNet
from nets.oracle import Oracle


class SphericalViewSynConfig(object):
    """Configuration for the view-synthesis networks built in this module.

    A config can be filled from a Python config module (:meth:`load`,
    :meth:`load_by_name`) or from a compact string ID (:meth:`from_id`), and
    serialized back to that ID with :meth:`to_id`. The ID written by
    :meth:`to_id` has the form::

        <name>@<net>[-<color>]_e<X>[_ed<D>]_fc<nf>x<layers>[_^<s1,s2,...>]
            [_*<act>]_d<near>-<far>_s<n_samples>
            [_co<nf>x<layers>x<n_samples>][_~<neg flags>][_+<pos flags>]

    where negative flags are ``p`` (no perturb_sample) / ``l`` (no lindisp)
    and the positive flag ``d`` enables ``depth_ref``.
    """

    def __init__(self, id=None):
        """Create a config holding the default values.

        Args:
            id: optional config ID string; when given it is parsed with
                :meth:`from_id` to override the defaults.
        """
        self.name = 'default'
        self.c = color.RGB
        # Net type string, e.g. 'nerf', 'snerf8' (see split_net_type()).
        self.net = 'nerf'
        # Forwarded to the nets as pos_encode.
        self.encode_x = 10
        # Forwarded to the nets as dir_encode; None means no direction input.
        self.encode_d = None
        self.depth_ref = False
        # Fully-connected core parameters, forwarded as core_params.
        self.fc = {
            'nf': 256,
            'n_layers': 8,
            'skips': [],
            'act': 'relu'
        }
        # Sampler parameters, forwarded as sampler_params.
        self.sa = {
            'sample_range': (1, 50),
            'n_samples': 32,
            'perturb_sample': True,
            'lindisp': True
        }
        # Optional {'nf', 'n_layers', 'n_samples'} dict for a separate coarse
        # Nerf (only used by the 'nerf' net type in create_net()).
        self.nerf_coarse = None
        if id is not None:
            self.from_id(id)

    def load(self, path):
        """Apply the config module stored at ``path`` (e.g. ``'configs/foo.py'``).

        The path without its extension is converted into a dotted module
        name, imported, and its ``update_config(self)`` is called. The
        config's name becomes the last component of that module name.
        """
        # normpath removes './' and duplicate separators and, on Windows,
        # turns '/' into os.sep, so either separator style maps to dots.
        module_path = os.path.normpath(os.path.splitext(path)[0])
        module_name = module_path.replace(os.sep, '.')
        config_module = importlib.import_module(module_name)
        config_module.update_config(self)
        self.name = module_name.split('.')[-1]

    def load_by_name(self, name):
        """Apply the config module ``configs.<name>`` and adopt ``name``."""
        config_module = importlib.import_module('configs.' + name)
        config_module.update_config(self)
        self.name = name

    def split_net_type(self):
        """Split ``self.net`` into its alphabetic type and numeric suffix.

        Returns:
            ``(net_type, multiple)`` where ``multiple`` is an ``int``, or
            ``None`` when no digits follow the type; e.g. ``'snerf8'`` ->
            ``('snerf', 8)`` and ``'nerf'`` -> ``('nerf', None)``.

        Raises:
            ValueError: if ``self.net`` does not start with a letter.
        """
        match_res = re.match(r'([a-z]+)(\d*)', self.net, re.I)
        if match_res is None:
            raise ValueError(f'Invalid net type: {self.net}')
        net_type, digits = match_res.groups()
        return net_type, (int(digits) if digits else None)

    def to_id(self):
        """Serialize this config into its compact string ID (see class doc)."""
        ident = f"{self.name}@{self.net}"
        if self.c != color.RGB:
            ident += f"-{color.to_str(self.c)}"
        ident += f"_e{self.encode_x}"
        if self.encode_d is not None:
            ident += f"_ed{self.encode_d}"
        ident += f"_fc{self.fc['nf']}x{self.fc['n_layers']}"
        if self.fc['skips']:
            ident += "_^" + ','.join(f'{val}' for val in self.fc['skips'])
        if self.fc['act'] != 'relu':
            ident += f"_*{self.fc['act']}"
        ident += "_d{0:.2f}-{1:.2f}".format(*self.sa['sample_range'])
        ident += f"_s{self.sa['n_samples']}"
        if self.nerf_coarse is not None:
            coarse = self.nerf_coarse
            ident += (f"_co{coarse['nf']}x{coarse['n_layers']}"
                      f"x{coarse['n_samples']}")
        # Only flags that differ from the defaults are written.
        neg_flags = (('' if self.sa['perturb_sample'] else 'p') +
                     ('' if self.sa['lindisp'] else 'l'))
        if neg_flags:
            ident += f'_~{neg_flags}'
        pos_flags = 'd' if self.depth_ref else ''
        if pos_flags:
            ident += f'_+{pos_flags}'
        return ident

    def from_id(self, id: str):
        """Parse a config ID (as written by :meth:`to_id`) into this config.

        The part before ``'@'`` (when exactly one ``'@'`` is present) becomes
        ``self.name``. The remainder is split on ``'_'`` and each segment is
        recognized by its prefix. The first segment, if nothing else claims
        it, is read as ``<net>[-<color>]``. Segments that cannot be parsed
        are generally ignored, but a non-integer ``e``/``ed`` value raises
        ``ValueError``.

        Args:
            id: the config ID string.
        """
        id_splited = id.split('@')
        if len(id_splited) == 2:
            self.name = id_splited[0]
        segs = id_splited[-1].split('_')
        for i, seg in enumerate(segs):
            if not seg:
                continue
            m = re.match(r"^co(\d+)x(\d+)x(\d+)$", seg)
            if m is not None:  # Coarse net parameters
                self.nerf_coarse = {
                    'nf': int(m.group(1)),
                    'n_layers': int(m.group(2)),
                    'n_samples': int(m.group(3))
                }
                continue
            m = re.match(r"^fc(\d+)x(\d+)$", seg)
            if m is not None:  # Fully-connected network parameters
                self.fc['nf'] = int(m.group(1))
                self.fc['n_layers'] = int(m.group(2))
                continue
            if seg.startswith('^'):  # Skip connections, e.g. "^4,6"
                # Strip only the single '^' prefix written by to_id().
                self.fc['skips'] = [int(val) for val in seg[1:].split(',')]
                continue
            if seg.startswith('*'):  # Activation
                self.fc['act'] = seg[1:]
                continue
            if seg.startswith('ed'):  # Direction encoding; 0 means disabled
                self.encode_d = int(seg[2:])
                if self.encode_d == 0:
                    self.encode_d = None
                continue
            if seg.startswith('e'):  # Position encoding
                self.encode_x = int(seg[1:])
                continue
            if seg.startswith('d'):  # Depth range "d<near>-<far>"
                try:
                    sample_range = tuple(
                        float(val) for val in seg[1:].split('-'))
                except ValueError:
                    # Not a depth range (e.g. the 'dnerf' net segment).
                    pass
                else:
                    # to_id() always writes exactly two bounds.
                    if len(sample_range) == 2:
                        self.sa['sample_range'] = sample_range
                        continue
            if seg.startswith('s'):  # Number of samples
                try:
                    self.sa['n_samples'] = int(seg[1:])
                except ValueError:
                    # Not a sample count (e.g. the 'snerf' net segment).
                    pass
                else:
                    continue
            if seg.startswith('~'):  # Negative flags
                if 'p' in seg:
                    self.sa['perturb_sample'] = False
                if 'l' in seg:
                    self.sa['lindisp'] = False
                if 'd' in seg:
                    self.depth_ref = False
                continue
            if seg.startswith('+'):  # Positive flags
                if 'p' in seg:
                    self.sa['perturb_sample'] = True
                if 'l' in seg:
                    self.sa['lindisp'] = True
                if 'd' in seg:
                    self.depth_ref = True
                continue
            if i == 0:  # Net type and optional color
                seg_splited = seg.split('-')
                if len(seg_splited) == 1:
                    self.net = seg
                    self.c = color.RGB
                else:
                    self.net = seg_splited[0]
                    self.c = color.from_str(seg_splited[1])

    def print(self):
        """Print a human-readable summary of this config to stdout."""
        print('==== Config %s ====' % self.name)
        print('Net type: ', self.net)
        print('Encode dim: ', self.encode_x)
        print('Train with depth: ', self.depth_ref)
        print('Support direction: ', False if self.encode_d is None
              else f'encode to {self.encode_d}')
        print('Full-connected network parameters:', self.fc)
        print('Sample parameters', self.sa)
        if self.nerf_coarse:
            # NOTE(review): label says "fine" but the attribute is
            # nerf_coarse (written as '_co' in the ID) - confirm wording.
            print('NeRF fine network parameters', self.nerf_coarse)
        print('==========================')

    def create_net(self):
        """Instantiate the network described by this config.

        The net type and optional numeric suffix come from
        :meth:`split_net_type`. For ``snerf``/``snerffast`` the suffix is the
        part count (default 1); for ``dnerf``/``dnerfa`` it is the bin count
        (default 128).

        Raises:
            ValueError: if the net type is not recognized.
        """
        net, multiple = self.split_net_type()
        if net == 'nerf':
            coarse_net = None
            if self.nerf_coarse:
                # Coarse net reuses the main parameters, overriding size and
                # sample count (shallow copies, as before).
                coarse_fc = {
                    **self.fc,
                    'nf': self.nerf_coarse['nf'],
                    'n_layers': self.nerf_coarse['n_layers']
                }
                coarse_sa = {
                    **self.sa,
                    'n_samples': self.nerf_coarse['n_samples']
                }
                coarse_net = Nerf(coarse_fc, coarse_sa, c=self.c,
                                  pos_encode=self.encode_x,
                                  dir_encode=self.encode_d)
            return Nerf(core_params=self.fc,
                        sampler_params=self.sa,
                        fine_params=self.nerf_coarse,
                        c=self.c,
                        pos_encode=self.encode_x,
                        dir_encode=self.encode_d,
                        coarse_net=coarse_net)
        if net == 'bgnet':
            return BgNet(core_params=self.fc, encode=self.encode_x, c=self.c)
        if net.startswith('oracle'):
            # The activation name follows 'oracle' in the alphabetic net
            # type; a trailing numeric suffix must not leak into it.
            return Oracle(core_params=self.fc,
                          sampler_params=self.sa,
                          pos_encode=self.encode_x,
                          out_activation=net[len('oracle'):] or 'sigmoid')
        if net == 'dnerfa':
            return NerfDepth(core_params=self.fc,
                             sampler_params=self.sa,
                             c=self.c,
                             pos_encode=self.encode_x,
                             n_bins=multiple or 128,
                             include_neighbor_bins=False)
        if net == 'dnerf':
            return NerfDepth(core_params=self.fc,
                             sampler_params=self.sa,
                             c=self.c,
                             pos_encode=self.encode_x,
                             n_bins=multiple or 128)
        if net == 'snerf':
            return Snerf(core_params=self.fc,
                         sampler_params=self.sa,
                         n_parts=multiple or 1,
                         c=self.c,
                         pos_encode=self.encode_x,
                         dir_encode=self.encode_d)
        if net == 'snerffast':
            return SnerfFast(core_params=self.fc,
                             sampler_params=self.sa,
                             n_parts=multiple or 1,
                             c=self.c,
                             pos_encode=self.encode_x,
                             dir_encode=self.encode_d)
        raise ValueError(f'Invalid net type: {net} - {multiple}')