from typing import Optional

import torch
import torch.nn as nn

from modules import *
from utils import color


class NSVF(nn.Module):
    """
    Voxel-sampled radiance field model.

    Pipeline (as wired in this class): a `VoxelSampler` working on an
    `OctTreeSpace` produces per-sample features for each ray, the features
    (and optionally the ray directions) are encoded with `InputEncoder`s,
    a `NerfCore` network maps them to colors and densities, and a
    `VolumnRenderer` turns those into the final output.
    """

    def __init__(self, fc_params, sampler_params, *,
                 c: int = color.RGB,
                 n_featdim: int = 32,
                 n_pos_encode: int = 0,
                 n_dir_encode: Optional[int] = None,
                 **kwargs):
        """
        Initialize a NSVF model

        :param fc_params `dict`: parameters for full-connection network; the keys
            `'nf'`, `'n_layers'`, `'activation'` and `'skips'` are read here
        :param sampler_params `dict`: keyword arguments for `VoxelSampler`; it is not
            modified, the `'space'` entry is supplied (or overridden) by this model
        :param c `int`: color mode
        :param n_featdim `int`: number of channels of the features fed to the position encoder
        :param n_pos_encode `int`: encode position to number of dimensions
        :param n_dir_encode `int`: encode direction to number of dimensions,
            `None` means direction is ignored
        :param kwargs: accepted for call-site compatibility and ignored
        """
        super().__init__()
        self.color = c
        self.coord_chns = n_featdim
        self.color_chns = color.chns(self.color)

        self.pos_encoder = InputEncoder.Get(n_pos_encode, self.coord_chns)
        if n_dir_encode is not None:
            self.dir_chns = 3
            self.dir_encoder = InputEncoder.Get(n_dir_encode, self.dir_chns)
        else:
            self.dir_chns = 0
            self.dir_encoder = None

        # Explicit `is not None` (rather than truthiness) so an encoder that
        # happens to define `__len__`/`__bool__` is never mistaken for "absent".
        encoded_dir_chns = self.dir_encoder.out_dim if self.dir_encoder is not None else 0
        self.core = NerfCore(coord_chns=self.pos_encoder.out_dim,
                             density_chns=1,
                             color_chns=self.color_chns,
                             core_nf=fc_params['nf'],
                             core_layers=fc_params['n_layers'],
                             dir_chns=encoded_dir_chns,
                             dir_nf=fc_params['nf'] // 2,
                             activation=fc_params['activation'],
                             skips=fc_params['skips'])

        self.space = OctTreeSpace()
        # Build a merged copy instead of writing into the caller's dict, so a
        # config dict shared between several models is left untouched.
        self.sampler = VoxelSampler(**{**sampler_params, 'space': self.space})
        self.rendering = VolumnRenderer()

    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, *,
                ret_depth=False, debug=False) -> torch.Tensor:
        """
        rays -> colors

        :param rays_o `Tensor(B, 3)`: rays' origin
        :param rays_d `Tensor(B, 3)`: rays' direction
        :param ret_depth `bool`: forwarded unchanged to the renderer
        :param debug `bool`: forwarded unchanged to the renderer
        :return: the renderer's output; `Tensor(B, C)` inferred images/pixels per the
            original documentation (NOTE(review): the exact structure when `ret_depth`
            or `debug` is set depends on `VolumnRenderer` -- confirm there)
        """
        # The sampler also returns per-sample directions; they are not used here
        # because directions are encoded per ray and broadcast below.
        feats, _dirs, z_s, dz_s = self.sampler(rays_o, rays_d)
        feats_encoded = self.pos_encoder(feats)

        if self.dir_encoder is not None:
            # Encode once per ray, then repeat (as a view) along a new sample
            # axis whose length is the last dimension of `z_s`.
            dirs_encoded = self.dir_encoder(rays_d)[:, None].expand(-1, z_s.size(-1), -1)
        else:
            dirs_encoded = None

        colors, densities = self.core(feats_encoded, dirs_encoded)
        # The core is built with `density_chns=1`; drop that trailing channel axis.
        return self.rendering(colors, densities[..., 0], z_s, dz_s,
                              ret_depth=ret_depth, debug=debug)