Commit 5699ccbf authored by Nianchen Deng's avatar Nianchen Deng
Browse files


parent 338ae906
from typing import Tuple
import torch
import torch.nn as nn
from torch.nn.modules.linear import Identity
from utils.constants import *
from .generic import *
from .sampler import *
from .input_encoder import *
from .renderer import *
class NerfCore(nn.Module):
    """
    Core MLP of NeRF: a shared trunk followed by optional density and color heads.

    NOTE(review): the listing this block was taken from had lost several lines
    (`super().__init__()`, the `else:` branches, a closing parenthesis and the
    `torch.cat` call). They are restored here; confirm against the repository copy.
    """

    def __init__(self, *, coord_chns, density_chns, color_chns, core_nf, core_layers,
                 dir_chns=0, dir_nf=0, activation='relu', skips=[]):
        """
        Build the trunk and the output heads.

        :param coord_chns: channels of the (encoded) coordinate input
        :param density_chns: channels of the density output, 0 disables the density head
        :param color_chns: channels of the color output, 0 disables the color head
        :param core_nf: width of the trunk layers
        :param core_layers: number of trunk layers
        :param dir_chns: channels of the (encoded) direction input, 0 means color ignores direction
        :param dir_nf: width of the direction-conditioned color layer
        :param activation: name of the activation used in hidden layers
        :param skips: indices of trunk layers which receive a skip connection
        """
        # nn.Module must be initialised before any sub-module is assigned
        super().__init__()
        # Pass a copy so the shared default list can never be mutated downstream
        self.core = FcNet(in_chns=coord_chns, out_chns=0, nf=core_nf, n_layers=core_layers,
                          skips=list(skips), activation=activation)
        self.density_out = FcLayer(core_nf, density_chns) if density_chns > 0 else None
        if color_chns == 0:
            self.feature_out = None
            self.color_out = None
        elif dir_chns > 0:
            # Direction-dependent color: feature layer, then a small MLP fed with [feature, dir]
            self.feature_out = FcLayer(core_nf, core_nf)
            self.color_out = nn.Sequential(
                FcLayer(core_nf + dir_chns, dir_nf, activation),
                FcLayer(dir_nf, color_chns)
            )
        else:
            # Direction-independent color: read color straight from the trunk output
            self.feature_out = Identity()
            self.color_out = FcLayer(core_nf, color_chns)

    def forward(self, coord: torch.Tensor, dir: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Infer color and density.

        :param coord: input of the trunk
        :param dir: (optional) tensor concatenated to the trunk feature before the color head
        :return: `(color, density)`; either item is None when the matching head is disabled
        """
        core_output = self.core(coord)
        density = self.density_out(core_output) if self.density_out is not None else None
        if self.color_out is None:
            color = None
        else:
            feature = self.feature_out(core_output)
            if dir is not None:
                feature = torch.cat([feature, dir], dim=-1)
            color = torch.sigmoid(self.color_out(feature))
        return color, density
\ No newline at end of file
from .space import *
from .core import *
\ No newline at end of file
from .generic import *
from typing import Dict
class NerfCore(nn.Module):
    """
    Core MLP of NeRF: a shared trunk followed by optional density and color heads.

    NOTE(review): the listing this block was taken from had lost several lines
    (`super().__init__()`, the `else:` branches, a closing parenthesis, the
    `torch.cat` call and, presumably, a `continue`). They are restored here;
    confirm against the repository copy.
    """

    def __init__(self, *, coord_chns, density_chns, color_chns, core_nf, core_layers,
                 dir_chns=0, dir_nf=0, act='relu', skips=[]):
        """
        Build the trunk and the output heads.

        :param coord_chns: channels of the (encoded) coordinate input
        :param density_chns: channels of the density output, 0 disables the density head
        :param color_chns: channels of the color output, 0 disables the color head
        :param core_nf: width of the trunk layers
        :param core_layers: number of trunk layers
        :param dir_chns: channels of the (encoded) direction input, 0 means color ignores direction
        :param dir_nf: width of the direction-conditioned color layer
        :param act: name of the activation used in hidden layers
        :param skips: indices of trunk layers which receive a skip connection
        """
        # nn.Module must be initialised before any sub-module is assigned
        super().__init__()
        # Pass a copy so the shared default list can never be mutated downstream
        self.core = FcNet(in_chns=coord_chns, out_chns=None, nf=core_nf, n_layers=core_layers,
                          skips=list(skips), act=act)
        self.density_out = FcLayer(core_nf, density_chns) if density_chns > 0 else None
        if color_chns == 0:
            self.feature_out = None
            self.color_out = None
        elif dir_chns > 0:
            # Direction-dependent color: feature layer, then a small MLP fed with [feature, d]
            self.feature_out = FcLayer(core_nf, core_nf)
            self.color_out = nn.Sequential(
                FcLayer(core_nf + dir_chns, dir_nf, act),
                FcLayer(dir_nf, color_chns)
            )
        else:
            # Direction-independent color: read color straight from the trunk output
            self.feature_out = torch.nn.Identity()
            self.color_out = FcLayer(core_nf, color_chns)

    def forward(self, x: torch.Tensor, d: torch.Tensor, outputs: List[str]) -> Dict[str, torch.Tensor]:
        """
        Infer the requested outputs.

        :param x: input of the trunk
        :param d: tensor concatenated to the trunk feature before the color head, may be None
        :param outputs: names of the requested items; 'density' and 'color' are inferred,
            every other name is returned as None
        :return: dict with one entry per requested name
        """
        ret = {}
        core_output = self.core(x)
        if 'density' in outputs:
            ret['density'] = torch.relu(self.density_out(core_output)) \
                if self.density_out is not None else None
        if 'color' in outputs:
            if self.color_out is None:
                ret['color'] = None
            else:
                feature = self.feature_out(core_output)
                # The original tested the builtin `dir` here, which is never None,
                # so passing d=None could not skip the concatenation.
                if d is not None:
                    feature = torch.cat([feature, d], dim=-1)
                ret['color'] = self.color_out(feature).sigmoid()
        # Items this net cannot produce are reported as None instead of being omitted
        for key in outputs:
            if key not in ('density', 'color'):
                ret[key] = None
        return ret
class NerfAdvCore(nn.Module):
def __init__(self, *, x_chns: int, d_chns: int, density_chns: int, color_chns: int,
density_net_params: dict, color_net_params: dict,
specular_net_params: dict = None,
Create a NeRF-Adv Core Net.
Required parameters for the sub-mlps include: "nf", "n_layers", "skips" and "act".
Other parameters will be properly set automatically.
:param x_chns `int`: the channels of input "position"
:param d_chns `int`: the channels of input "direction"
:param density_chns `int`: the channels of output "density"
:param color_chns `int`: the channels of output "color"
:param density_net_params `dict`: parameters for the density net
:param color_net_params `dict`: parameters for the color net
:param specular_net_params `dict`: (optional) parameters for the optional specular net, defaults to None
:param appearance `str`: (optional) options are [decomposite|combined], defaults to "decomposite"
:param density_color_connection `bool`: (optional) whether to add connections between
density net and color net, defaults to False
self.density_chns = density_chns
self.color_chns = color_chns
self.specular_feature_chns = color_net_params["nf"] if specular_net_params else 0
self.color_feature_chns = density_net_params["nf"] if density_color_connection else 0
self.appearance = appearance
self.density_color_connection = density_color_connection
self.density_net = FcNet(**density_net_params,
out_chns=self.density_chns + self.color_feature_chns,
if self.appearance == "newtype":
self.specular_feature_chns = d_chns * 3
self.color_net = FcNet(**color_net_params,
in_chns=x_chns + self.color_feature_chns,
out_chns=self.color_chns + self.specular_feature_chns)
self.specular_net = "Placeholder"
if self.appearance == "decomposite":
self.color_net = FcNet(**color_net_params,
in_chns=x_chns + self.color_feature_chns,
out_chns=self.color_chns + self.specular_feature_chns)
if specular_net_params:
self.color_net = FcNet(**color_net_params,
in_chns=x_chns + self.color_feature_chns,
self.color_net = FcNet(**color_net_params,
in_chns=x_chns + d_chns + self.color_feature_chns,
self.specular_net = FcNet(**specular_net_params,
in_chns=d_chns + self.specular_feature_chns,
out_chns=self.color_chns) if specular_net_params else None
# Infer the items named in `outputs` ('density', 'color_feat', 'color', 'diffuse', 'specular')
# from inputs `x` and `d`, returning them in a dict.
# `color_feats` may be supplied by the caller; it is overwritten when 'density' is requested.
# NOTE(review): this listing has lost lines and tokens (several `else:` branches and the
# `torch.cat(` calls), so the statements below do not parse as shown - restore them from
# the repository copy before editing the logic.
def forward(self, x: torch.Tensor, d: torch.Tensor, outputs: List[str], *,
color_feats: torch.Tensor = None) -> Dict[str, torch.Tensor]:
# Remember the leading shape so results can be reshaped back before returning
input_shape = x.shape[:-1]
if len(input_shape) > 1:
x = x.flatten(0, -2)
d = d.flatten(0, -2)
n = x.shape[0]
c = self.color_chns
ret: Dict[str, torch.Tensor] = {}
if 'density' in outputs:
# The density net emits density channels first, then the features handed to the color net
density_net_out: torch.Tensor = self.density_net(x)
ret['density'] = density_net_out[:, :self.density_chns]
color_feats = density_net_out[:, self.density_chns:]
if 'color_feat' in outputs:
ret['color_feat'] = color_feats
# NOTE(review): 'specluar' looks like a typo of 'specular' (the key tested further below);
# as written, requesting only 'specular' never enters this branch.
if 'color' in outputs or 'specluar' in outputs:
if 'density' in ret:
# Only rows whose first density channel reaches 1e-4 are passed to the color net
valid_mask = ret['density'][:, 0].detach() >= 1e-4
indices = valid_mask.nonzero()[:, 0]
x, d, color_feats = x[indices], d[indices], color_feats[indices]
# NOTE(review): presumably the `else:` of the `if` above - confirm
indices = None
speculars = None
color_net_in = [x]
if not self.specular_net:
if self.density_color_connection:
color_net_in =, -1)
color_net_out: torch.Tensor = self.color_net(color_net_in)
# First `c` channels are the diffuse part, the trailing channels feed the specular part
diffuses = color_net_out[:, :c]
specular_features = color_net_out[:, -self.specular_feature_chns:]
if self.appearance == "newtype":
# NOTE(review): reshapes with `n`, the row count before the `indices` filtering above;
# verify this still matches once rows have been filtered
speculars = torch.bmm(specular_features.reshape(n, 3, d.shape[-1]),
d[..., None])[..., 0]
# TODO relu or not?
diffuses = diffuses.relu()
speculars = speculars.relu()
colors = diffuses + speculars
if not self.specular_net:
colors = diffuses
diffuses = None
specular_net_in =[d, specular_features], -1)
specular_net_out = self.specular_net(specular_net_in)
if self.appearance == "decomposite":
speculars = specular_net_out
colors = diffuses + speculars
diffuses = None
colors = specular_net_out
colors = torch.sigmoid(colors) # TODO indent or not?
if 'color' in outputs:
# Scatter the filtered rows back into a zero tensor of `n` rows
# NOTE(review): `if indices` takes the truth value of a tensor, unlike the
# `indices is not None` tests used for 'diffuse' and 'specular' below; this looks
# unintended (PyTorch rejects the truth value of a multi-element tensor) - confirm
ret['color'] = colors.new_zeros(n, c).index_copy(0, indices, colors) \
if indices else colors
if 'diffuse' in outputs:
ret['diffuse'] = diffuses.new_zeros(n, c).index_copy(0, indices, diffuses) \
if indices is not None and diffuses is not None else diffuses
if 'specular' in outputs:
ret['specular'] = speculars.new_zeros(n, c).index_copy(0, indices, speculars) \
if indices is not None and speculars is not None else speculars
if len(input_shape) > 1:
# Restore the leading shape that was flattened at the top
# NOTE(review): `val` can be None here (e.g. 'diffuse'), which has no `reshape` - confirm
ret = {key: val.reshape(*input_shape, -1) for key, val in ret.items()}
return ret
......@@ -34,7 +34,7 @@ class Sine(nn.Module):
class FcLayer(nn.Module):
def __init__(self, in_chns: int, out_chns: int, activation: str = 'linear', skip_chns: int = 0):
def __init__(self, in_chns: int, out_chns: int, act: str = 'linear', skip_chns: int = 0):
nls_and_inits = {
'sine': (Sine(), sine_init),
......@@ -48,7 +48,7 @@ class FcLayer(nn.Module):
'logsoftmax': (nn.LogSoftmax(dim=-1), softmax_init),
'linear': (None, None)
nl, nl_weight_init = nls_and_inits[activation]
nl, nl_weight_init = nls_and_inits[act] = nn.Sequential(
nn.Linear(in_chns + skip_chns, out_chns),
......@@ -59,7 +59,7 @@ class FcLayer(nn.Module):
if nl_weight_init is not None:
nl_weight_init( if isinstance(, nn.Linear) else[0])
def forward(self, x: torch.Tensor, x0: torch.Tensor = None) -> torch.Tensor:
return[x0, x], dim=-1) if self.skip else x)
......@@ -68,9 +68,9 @@ class FcLayer(nn.Module):
linear_net = if isinstance(, nn.Linear) else[0]
return linear_net.weight, linear_net.bias
def init_params(self, activation):
def init_params(self, act):
weight, bias = self.get_params()
nn.init.xavier_normal_(weight, gain=nn.init.calculate_gain(activation))
nn.init.xavier_normal_(weight, gain=nn.init.calculate_gain(act))
def copy_to(self, layer):
......@@ -83,7 +83,7 @@ class FcLayer(nn.Module):
class FcNet(nn.Module):
def __init__(self, *, in_chns: int, out_chns: int, nf: int, n_layers: int,
skips: List[int] = [], activation: str = 'relu'):
skips: List[int] = [], act: str = 'relu', out_act = 'linear'):
Initialize a full-connection net
......@@ -95,12 +95,12 @@ class FcNet(nn.Module):
self.layers = [FcLayer(in_chns, nf, activation)] + [
FcLayer(nf, nf, activation, skip_chns=in_chns if i in skips else 0)
self.layers = [FcLayer(in_chns, nf, act)] + [
FcLayer(nf, nf, act, skip_chns=in_chns if i in skips else 0)
for i in range(n_layers - 1)
if out_chns > 0:
self.layers.append(FcLayer(nf, out_chns))
if out_chns:
self.layers.append(FcLayer(nf, out_chns, out_act))
for i, layer in enumerate(self.layers):
self.add_module(f"layer{i}", layer)
from itertools import cycle
from math import ceil
from typing import Dict, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as nn_f
from utils.constants import *
from utils.perf import perf
from .generic import *
from .sampler import Samples
def density2energy(densities: torch.Tensor, dists: torch.Tensor, raw_noise_std: float = 0):
    """
    Calculate energies from densities inferred by model.

    :param densities `Tensor(N..., 1)`: model's output densities
    :param dists `Tensor(N...)`: integration times
    :param raw_noise_std `float`: the noise std used to regularize network during training (prevents
        floater artifacts), defaults to 0, means no noise is added
    :return `Tensor(N..., 1)`: energies which block light rays
    """
    if raw_noise_std > 0:
        # Add noise to model's predictions for density. Can be used to
        # regularize network during training (prevents floater artifacts).
        # randn_like creates the noise on the device and with the dtype of `densities`;
        # torch.normal(mean, std, size) would allocate it on the default device instead.
        densities = densities + torch.randn_like(densities) * raw_noise_std
    return densities * dists[..., None]
def density2alpha(densities: torch.Tensor, dists: torch.Tensor, raw_noise_std: float = 0):
    """
    Calculate alphas from densities inferred by model.

    :param densities `Tensor(N..., 1)`: model's output densities
    :param dists `Tensor(N...)`: integration times
    :param raw_noise_std `float`: the noise std used to regularize network during training (prevents
        floater artifacts), defaults to 0, means no noise is added
    :return `Tensor(N..., 1)`: alphas
    """
    energies = density2energy(densities, dists, raw_noise_std)
    # alpha = 1 - exp(-energy); expm1 avoids the cancellation of `1 - exp(-e)` for tiny energies
    return -torch.expm1(-energies)
class AlphaComposition(nn.Module):
......@@ -11,18 +47,26 @@ class AlphaComposition(nn.Module):
def forward(self, colors, alphas, bg=None):
:param colors `Tensor(N, P, C)`: [description]
:param alphas `Tensor(N, P, 1)`: [description]
:param bg `Tensor([N, ]C)`: [description], defaults to None
:return `Tensor(N, C)`: [description]
# Compute weight for RGB of each sample along each ray. A cumprod() is
# used to express the idea of the ray not having reflected up to this
# sample yet.
one_minus_alpha = torch.cumprod(1 - alphas[..., :-1] + TINY_FLOAT, dim=-1)
one_minus_alpha = torch.cumprod(1 - alphas[..., :-1, :] + TINY_FLOAT, dim=-2)
one_minus_alpha =[
torch.ones_like(one_minus_alpha[..., 0:1]),
torch.ones_like(one_minus_alpha[..., :1, :]),
], dim=-1)
weights = alphas * one_minus_alpha # (N_rays, N)
], dim=-2)
weights = alphas * one_minus_alpha # (N, P, 1)
# (N_rays, 1|3), computed weighted color of each sample along each ray.
final_color = torch.sum(weights[..., None] * colors, dim=-2)
# (N, C), computed weighted color of each sample along each ray.
final_color = torch.sum(weights * colors, dim=-2)
# To composite onto a white background, use the accumulated alpha map.
if bg is not None:
......@@ -38,58 +82,290 @@ class AlphaComposition(nn.Module):
class VolumnRenderer(nn.Module):
def __init__(self, *, raw_noise_std=0.0, sigma_as_density=True):
Initialize a Rendering module
class States:
    """
    Working state of one volume rendering pass over `N` rays with `P` samples each.

    Holds the per-sample buffers filled chunk by chunk and the slices
    describing the chunk which is currently processed.
    """

    kernel: nn.Module
    samples: "Samples"
    hit_mask: torch.Tensor
    early_stop_tolerance: float
    N: int
    P: int
    colors: torch.Tensor
    diffuses: torch.Tensor
    speculars: torch.Tensor
    energies: torch.Tensor
    weights: torch.Tensor
    cum_energies: torch.Tensor
    exp_energies: torch.Tensor
    tot_evaluations: Dict[str, int]
    chunk: Tuple[slice, slice]
    cum_chunk: Tuple[slice, slice]
    cum_last: Tuple[slice, slice]
    chunk_id: int

    # `start` and `end` are read as plain attributes by `next_chunk` and by the
    # renderer (e.g. `s.end >= s.P`), hence they must be properties.
    @property
    def start(self) -> int:
        """Index of the first sample of the current chunk."""
        return self.chunk[1].start

    @property
    def end(self) -> int:
        """Index one past the last sample of the current chunk."""
        return self.chunk[1].stop

    def __init__(self, kernel: nn.Module, samples: "Samples", early_stop_tolerance: float) -> None:
        """
        Allocate the buffers for a rendering pass.

        :param kernel: render kernel, `kernel.chns('color')` gives the number of color channels
        :param samples `Samples(N, P)`: samples to render
        :param early_stop_tolerance `float`: tolerance of raymarching early stop
        """
        self.kernel = kernel
        self.samples = samples
        self.early_stop_tolerance = early_stop_tolerance

        N, P = samples.size
        device = samples.device
        color_chns = kernel.chns('color')
        # A sample takes part in rendering only if it lies in a voxel (index != -1)
        self.hit_mask = samples.voxel_indices != -1  # (N, P)
        self.colors = torch.zeros(N, P, color_chns, device=device)
        self.diffuses = torch.zeros(N, P, color_chns, device=device)
        self.speculars = torch.zeros(N, P, color_chns, device=device)
        self.energies = torch.zeros(N, P, 1, device=device)
        self.weights = torch.zeros(N, P, 1, device=device)
        # Cumulative buffers have one extra leading slot holding the value "before sample 0"
        self.cum_energies = torch.zeros(N, P + 1, 1, device=device)
        self.exp_energies = torch.ones(N, P + 1, 1, device=device)
        self.tot_evaluations = {}
        self.N, self.P = N, P
        self.chunk_id = -1

    def n_hits(self, start: int = None, end: int = None) -> int:
        """
        Count the samples marked in `hit_mask`.

        :param start `int`: (optional) sample index; without `end`, only this sample is counted
        :param end `int`: (optional) count samples in `[start, end)`
        :return `int`: number of hits over all samples when `start` is None
        """
        if start is None:
            mask = self.hit_mask
        elif end is None:
            mask = self.hit_mask[:, start]
        else:
            mask = self.hit_mask[:, start:end]
        return mask.count_nonzero().item()

    def accumulate_tot_evaluations(self, key: str, n: int):
        """Add `n` to the evaluation counter named `key`, creating it on first use."""
        self.tot_evaluations[key] = self.tot_evaluations.get(key, 0) + n

    def next_chunk(self, *, length=None, end=None):
        """
        Advance to the chunk following the current one.

        :param length `int`: (optional) number of samples in the chunk, defaults to `P`
        :param end `int`: (optional) end index of the chunk, takes precedence over `length`
        :return `States`: self, so calls can be chained
        """
        # The first chunk starts at 0; afterwards each chunk starts where the previous ended
        start = self.end if hasattr(self, "chunk") else 0
        length = length or self.P
        end = min(end or start + length, self.P)
        self.chunk = slice(None), slice(start, end)
        # Cumulative buffers are shifted by one slot (see __init__)
        self.cum_chunk = slice(None), slice(start + 1, end + 1)
        self.cum_last = slice(None), slice(start, start + 1)
        self.chunk_id += 1
        return self
def __init__(self, **kwargs):
self.alpha_composition = AlphaComposition()
self.sigma_as_density = sigma_as_density
self.raw_noise_std = raw_noise_std
def forward(self, colors, sigmas, z_vals, bg_color=None, ret_depth=False, debug=False):
"""Transforms model's predictions to semantically meaningful values.
color: [num_rays, num_samples along ray, 1|3]. Predicted color from model.
density: [num_rays, num_samples along ray]. Predicted density from model.
z_vals: [num_rays, num_samples along ray]. Integration time.
rgb_map: [num_rays, 1|3]. Estimated RGB color of a ray.
disp_map: [num_rays]. Disparity map. Inverse of depth map.
acc_map: [num_rays]. Sum of weights along each ray.
weights: [num_rays, num_samples]. Weights assigned to each sampled color.
depth_map: [num_rays]. Estimated distance to object.
def forward(self, kernel: nn.Module, samples: Samples, extra_outputs: List[str] = [], *,
raymarching_early_stop_tolerance: float = 0,
raymarching_chunk_size_or_sections: Union[int, List[int]] = None,
Perform volumn rendering.
:param kernel: render kernel
:param samples `Samples(N, P)`: samples
:param extra_outputs `list[str]`: extra items should be contained in the result dict.
Optional values include 'depth', 'layers', 'states' and attribute names in class `States` (e.g. 'weights'). Defaults to []
:param raymarching_early_stop_tolerance `float`: tolerance of raymarching early stop.
Should between 0 and 1 (0 means no early stop). Defaults to 0
:param raymarching_chunk_size_or_sections `int|list[int]`: indicates how to split raymarching process.
Use a list of integers to specify samples of every chunk, or a positive integer to specify number of chunks.
Use a negative interger to split by number of hits in chunks, and the absolute value means maximum number of hits in a chunk.
0 and `None` means not splitting the raymarching process. Defaults to `None`
:return `dict`: render result { 'color'[, 'depth', 'layers', 'states', ...] }
alphas = self.density2alpha(sigmas, z_vals) if self.sigma_as_density \
else nn_f.sigmoid(sigmas)
ret = self.alpha_composition(colors, alphas, bg_color)
if ret_depth:
ret['depth'] = torch.sum(ret['weights'] * z_vals, dim=-1)
if debug:
ret['layers'] =[colors, alphas[..., None]], dim=-1)
if samples.size[1] == 0:
print("VolumnRenderer.forward(): # of samples is zero")
return None
s = VolumnRenderer.States(kernel, samples, raymarching_early_stop_tolerance)
if not raymarching_chunk_size_or_sections:
raymarching_chunk_size_or_sections = [s.P]
elif isinstance(raymarching_chunk_size_or_sections, int) and \
raymarching_chunk_size_or_sections > 0:
raymarching_chunk_size_or_sections = [ceil(s.P / raymarching_chunk_size_or_sections)]
if isinstance(raymarching_chunk_size_or_sections, list):
chunk_sections = raymarching_chunk_size_or_sections
for chunk_samples in cycle(chunk_sections):
if s.end >= s.P:
chunk_size = -raymarching_chunk_size_or_sections
chunk_hits = s.n_hits(0)
for i in range(1, s.P):
n_hits = s.n_hits(i)
if chunk_hits + n_hits > chunk_size:
n_hits = s.n_hits(i)
chunk_hits = 0
chunk_hits += n_hits
ret = {
'color': torch.sum(s.colors * s.weights, 1),
'tot_evaluations': s.tot_evaluations
for key in extra_outputs:
if key == 'depth':
ret['depth'] = torch.sum(s.samples.depths[..., None] * s.weights, 1)
elif key == 'diffuse':
ret['diffuse'] = torch.sum(s.diffuses * s.weights, 1)
elif key == 'specular':
ret['specular'] = torch.sum(s.speculars * s.weights, 1)
elif key == 'layers':
ret['layers'] =[s.colors, 1 - torch.exp(-s.energies)], dim=-1)
elif key == 'states':
ret['states'] = s
ret[key] = getattr(s, key)
return ret
def density2alpha(self, densities: torch.Tensor, z_vals: torch.Tensor):
# if raymarching_chunk_size == 0:
# raymarching_chunk_samples = 1
# if raymarching_chunk_samples != 0:
# if isinstance(raymarching_chunk_samples, int):
# raymarching_chunk_samples = repeat(raymarching_chunk_samples,
# ceil(s.P / raymarching_chunk_samples))
# chunk_offset = 0
# for chunk_samples in raymarching_chunk_samples:
# start, end = chunk_offset, chunk_offset + chunk_samples
# n_hits = self._forward_chunk(s, start, end)
# if n_hits > 0 and tolerance > 0: # Early stop
# s.hit_mask[s.cum_energies[:, end, 0] > tolerance] = 0
# chunk_offset += chunk_samples
# elif raymarching_chunk_size > 0:
# chunk_offset, chunk_hits = 0, s.n_hits(0)
# for i in range(1, s.P):
# n_hits = s.n_hits(i)
# if chunk_hits + n_hits > raymarching_chunk_size:
# self._forward_chunk(s, chunk_offset, i, chunk_hits)
# if chunk_hits > 0 and tolerance > 0: # Early stop
# s.hit_mask[s.cum_energies[:, i, 0] > tolerance] = 0
# n_hits = s.n_hits(i)
# chunk_hits, chunk_offset = 0, i
# chunk_hits += n_hits
# self._forward_chunk(s, chunk_offset, s.P, chunk_hits)
# else:
# self._forward_chunk(s, 0, s.P)
# return self._composite(s, extra_outputs)
# original_depth = samples.get('original_point_depth', None)
# if original_depth is not None:
# results['z'] = (original_depth * probs).sum(-1)
# if getattr(input_fn, "track_max_probs", False) and (not
# input_fn.track_voxel_probs(samples['sampled_point_voxel_idx'].long(), results['probs'])
def _calc_weights(self, s: States):
Raw value inferred from model to color and alpha
Calculate weights of samples in composited outputs
:param densities `Tensor(N.rays, N.samples)`: model's output density
:param z_vals `Tensor(N.rays, N.samples)`: integration time
:return `Tensor(N.rays, N.samples)`: alpha
:param s `States`: states
:param start `int`: chunk's start
:param end `int`: chunk's end
s.cum_energies[s.cum_chunk] = torch.cumsum(s.energies[s.chunk], 1) \
+ s.cum_energies[s.cum_last]
s.exp_energies[s.cum_chunk] = (-s.cum_energies[s.cum_chunk]).exp()
s.weights[s.chunk] = s.exp_energies[s.chunk] - s.exp_energies[s.cum_chunk]
# Compute 'distance' (in time) between each integration time along a ray.
# The 'distance' from the last integration time is infinity.
# dists: (N_rays, N)
dists = z_vals[..., 1:] - z_vals[..., :-1]
last_dist = torch.zeros_like(z_vals[..., 0:1]) + TINY_FLOAT
dists =[dists, last_dist], -1)
def _apply_early_stop(self, s: States):
Stop rays whose accumulated opacity are larger than a threshold
if self.raw_noise_std > 0.:
# Add noise to model's predictions for density. Can be used to
# regularize network during training (prevents floater artifacts).
noise = torch.normal(0.0, self.raw_noise_std, densities.size())
densities = densities + noise
return -torch.exp(-torch.relu(densities) * dists) + 1.0
:param s `States`: s
:param end `int`: chunk's end
if s.end < s.P and s.early_stop_tolerance > 0:
rays_to_stop = s.exp_energies[:, s.end, 0] < s.early_stop_tolerance
s.hit_mask[rays_to_stop, s.end:] = 0
def _forward_chunk(self, s: States) -> int:
fi_idxs: Tuple[torch.Tensor, ...] = s.hit_mask[s.chunk].nonzero(as_tuple=True) # (N')
if fi_idxs[0].size(0) == 0:
s.cum_energies[s.cum_chunk] = s.cum_energies[s.cum_last]
s.exp_energies[s.cum_chunk] = s.exp_energies[s.cum_last]
return 0
# fi_* means "filtered" by hit mask
fi_samples = s.samples[fi_idxs] # N -> N'
# Infer densities and colors
fi_outputs = s.kernel.render(fi_samples, 'color', 'density', 'specular', 'diffuse',
s.colors.index_put_(fi_idxs, fi_outputs['color'])
if fi_outputs['specular'] is not None:
s.speculars.index_put_(fi_idxs, fi_outputs['specular'])
if fi_outputs['diffuse'] is not None:
s.diffuses.index_put_(fi_idxs, fi_outputs['diffuse'])
s.energies.index_put_(fi_idxs, density2energy(fi_outputs['density'], fi_samples.dists))
s.accumulate_tot_evaluations("color", fi_idxs[0].size(0))
class DensityFirstVolumnRenderer(VolumnRenderer):
def __init__(self, **kwargs):
def _forward_chunk(self, s: VolumnRenderer.States) -> int:
fi_idxs: Tuple[torch.Tensor, ...] = s.hit_mask[s.chunk].nonzero(as_tuple=True) # (N')
if fi_idxs[0].size(0) == 0:
s.cum_energies[s.cum_chunk] = s.cum_energies[s.cum_last]
s.exp_energies[s.cum_chunk] = s.exp_energies[s.cum_last]
return 0
# fi_* means "filtered" by hit mask
fi_samples = s.samples[fi_idxs] # N -> N'
# For all valid samples: encode X
fi_encoded_x = s.kernel.encode_x(fi_samples) # (N', Ex)
# Infer densities (shape)
fi_outputs = s.kernel.infer(fi_encoded_x, None, 'density', 'color_feat',
s.energies.index_put_(fi_idxs, density2energy(fi_outputs['density'], fi_samples.dists))
s.accumulate_tot_evaluations("density", fi_idxs[0].size(0))
# Remove samples whose weights are less than a threshold
s.hit_mask[s.chunk][s.weights[s.chunk][..., 0] < 0.01] = 0
# Update "filtered" tensors
fi_mask = s.hit_mask[fi_idxs]
fi_idxs = (fi_idxs[0][fi_mask], fi_idxs[1][fi_mask]) # N' -> N"
fi_encoded_x = fi_encoded_x[fi_mask] # (N", Ex)
fi_color_feats = fi_outputs['color_feat'][fi_mask]
# For all valid samples: encode D
fi_encoded_d = s.kernel.encode_d(s.samples[fi_idxs]) # (N", Ed)
# Infer colors (appearance)
fi_outputs = s.kernel.infer(fi_encoded_x, fi_encoded_d, 'color', 'specular', 'diffuse',
extras={"color_feats": fi_color_feats})
# if s.chunk_id == 0:
# fi_colors[:] *= fi_colors.new_tensor([1, 0, 0])
# elif s.chunk_id == 1:
# fi_colors[:] *= fi_colors.new_tensor([0, 1, 0])
# elif s.chunk_id == 2:
# fi_colors[:] *= fi_colors.new_tensor([0, 0, 1])
# else:
# fi_colors[:] *= fi_colors.new_tensor([1, 1, 0])
s.colors.index_put_(fi_idxs, fi_outputs['color'])
if fi_outputs['specular'] is not None:
s.speculars.index_put_(fi_idxs, fi_outputs['specular'])
if fi_outputs['diffuse'] is not None:
s.diffuses.index_put_(fi_idxs, fi_outputs['diffuse'])
s.accumulate_tot_evaluations("color", fi_idxs[0].size(0))
from typing import Tuple
from .space import Space, Voxels
import torch
import torch.nn as nn
from typing import Tuple
from utils import device
from utils import sphere
from utils.constants import *
from utils.perf import perf, checkpoint
from .generic import *
from clib import *
class Bins(object):
def up(self):
return self.bounds[1:]
def lo(self):
return self.bounds[:-1]
def __init__(self, vals: torch.Tensor):
self.vals = vals
self.bounds =[
......@@ -16,8 +28,6 @@ class Bins(object):
0.5 * (self.vals[1:] + self.vals[:-1]),
self.up = self.bounds[1:]
self.lo = self.bounds[:-1]
def linspace(val_range: Tuple[float, float], N: int, device: torch.device = None):
......@@ -26,14 +36,60 @@ class Bins(object):
def to(self, device: torch.device):
self.vals =
self.bounds =
self.up = self.bounds[1:]
self.lo = self.bounds[:-1]
class Samples:
    """
    A batch of samples along rays, stored as parallel tensors sharing the leading shape `N[, P]`.

    NOTE(review): the bodies of `__getitem__` and `reshape` were truncated in the
    listing this block was taken from; they are completed field by field here
    following the shapes documented on the attributes - confirm against the repository copy.
    """

    pts: torch.Tensor
    """`Tensor(N[, P], 3)`"""
    dirs: torch.Tensor
    """`Tensor(N[, P], 3)`"""
    depths: torch.Tensor
    """`Tensor(N[, P])`"""
    dists: torch.Tensor
    """`Tensor(N[, P])`"""
    voxel_indices: torch.Tensor
    """`Tensor(N[, P])`"""

    # `size` and `device` are read as plain attributes by callers
    # (e.g. `N, P = samples.size`, `device=samples.device`), hence properties.
    @property
    def size(self):
        """Leading shape of the samples, i.e. the shape of `pts` without its last dim."""
        return self.pts.size()[:-1]

    @property
    def device(self):
        """Device holding `pts`."""
        return self.pts.device

    def __init__(self, pts: torch.Tensor, dirs: torch.Tensor, depths: torch.Tensor,
                 dists: torch.Tensor, voxel_indices: torch.Tensor) -> None:
        self.pts = pts
        self.dirs = dirs
        self.depths = depths
        self.dists = dists
        self.voxel_indices = voxel_indices

    def __getitem__(self, index):
        """
        Select a subset of samples.

        :param index: any index valid for the leading dims (mask, slice, tuple of index tensors)
        :return `Samples`: new object whose tensors are indexed by `index`
        """
        return Samples(
            pts=self.pts[index],
            dirs=self.dirs[index],
            depths=self.depths[index],
            dists=self.dists[index],
            voxel_indices=self.voxel_indices[index])

    def reshape(self, *shape: int):
        """
        Reshape the leading dims of every tensor.

        :param shape: new leading shape; `pts` and `dirs` keep their trailing dim of 3
        :return `Samples`: new object with reshaped tensors
        """
        return Samples(
            pts=self.pts.reshape(*shape, 3),
            dirs=self.dirs.reshape(*shape, 3),
            depths=self.depths.reshape(*shape),
            dists=self.dists.reshape(*shape),
            voxel_indices=self.voxel_indices.reshape(*shape))
class Sampler(nn.Module):
def __init__(self, *, sample_range: Tuple[float, float], n_samples: int,
perturb_sample: bool, spherical: bool, lindisp: bool):
def __init__(self, *, sample_range: Tuple[float, float], n_samples: int, lindisp: bool, **kwargs):
Initialize a Sampler module
......@@ -44,37 +100,81 @@ class Sampler(nn.Module):
self.lindisp = lindisp
self.spherical = spherical
self.perturb_sample = perturb_sample
s_range = (1 / sample_range[0], 1 / sample_range[1]) if self.lindisp else sample_range
if s_range[1] > s_range[0]:
s_range[0] += 1e-4
s_range[1] -= 1e-4
s_range[0] -= 1e-4
s_range[1] += 1e-4
self.bins = Bins.linspace(s_range, n_samples, device=device.default())
def forward(self, rays_o, rays_d):
def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, space_module: Space,
perturb_sample: bool, **kwargs) -> Tuple[Samples, torch.Tensor]:
Sample points along rays. return Spherical or Cartesian coordinates,
specified by `self.shperical`
:param rays_o `Tensor(B, 3)`: rays' origin
:param rays_d `Tensor(B, 3)`: rays' direction
:return `Tensor(B, N, 3)`: sampled points
:return `Tensor(B, N)`: corresponding depths along rays
:param rays_o `Tensor(N, 3)`: rays' origin
:param rays_d `Tensor(N, 3)`: rays' direction
:return `Samples(N, P)`: samples
s = self.bins.vals.expand(rays_o.size(0), -1)
if self.perturb_sample:
if perturb_sample:
s = self.bins.lo + (self.bins.up - self.bins.lo) * torch.rand_like(s)
pts, depths = self._get_sample_points(rays_o, rays_d, s)
voxel_indices = space_module.get_voxel_indices(pts)
valid_rays_mask =
return Samples(
dirs=rays_d[:, None].expand(-1, depths.size(1), -1),
)[valid_rays_mask], valid_rays_mask
def _get_sample_points(self, rays_o, rays_d, s):
z = torch.reciprocal(s) if self.lindisp else s
if self.spherical:
pts, depths = sphere.ray_sphere_intersect(rays_o, rays_d, z)
sphers = sphere.cartesian2spherical(pts, inverse_r=self.lindisp)
return sphers, depths, s, pts
return rays_o[..., None, :] + rays_d[..., None, :] * z[..., None], z, s, None
pts = rays_o[:, None] + rays_d[:, None] * z[..., None]
depths = z
return pts, depths
def _calc_dists(self, vals):
    """
    Compute the 'distance' (in time) between consecutive integration times along a ray.

    :param vals `Tensor(N_rays, N)`: integration times
    :return `Tensor(N_rays, N)`: distance from every sample to the next one;
        the last sample has no successor and gets `TINY_FLOAT`
    """
    dists = vals[..., 1:] - vals[..., :-1]
    # Pad with a tiny positive value so the result keeps the shape of `vals`
    last_dist = torch.zeros_like(vals[..., :1]) + TINY_FLOAT
    return torch.cat([dists, last_dist], -1)
class SphericalSampler(Sampler):
    """Sampler which places samples on spheres and returns them in spherical coordinates."""

    def __init__(self, *, sample_range: Tuple[float, float], n_samples: int,
                 perturb_sample: bool, **kwargs):
        """
        Initialize a Sampler module

        :param sample_range: value range for sampler
        :param n_samples: count to sample along ray
        :param perturb_sample: perturb the sample depths
        :param kwargs: accepted for interface compatibility and ignored
        """
        # `lindisp` is fixed to False here: the reciprocal is applied in `_get_sample_points`
        super().__init__(sample_range=sample_range, n_samples=n_samples,
                         perturb_sample=perturb_sample, lindisp=False)

    def _get_sample_points(self, rays_o, rays_d, s):
        """
        Convert sample values to points along the rays.

        :param rays_o: rays' origin
        :param rays_d: rays' direction
        :param s: sample values, their reciprocals are passed to `sphere.ray_sphere_intersect`
        :return: points converted by `sphere.cartesian2spherical(..., inverse_r=True)`
            and the depths returned by the intersection
        """
        r = torch.reciprocal(s)
        pts, depths = sphere.ray_sphere_intersect(rays_o, rays_d, r)
        pts = sphere.cartesian2spherical(pts, inverse_r=True)
        return pts, depths
class PdfSampler(nn.Module):
def __init__(self, *, depth_range: Tuple[float, float], n_samples: int, perturb_sample: bool,
spherical: bool, lindisp: bool):
spherical: bool, lindisp: bool, **kwargs):
Initialize a Sampler module
......@@ -90,7 +190,7 @@ class PdfSampler(nn.Module):
self.n_samples = n_samples
self.s_range = (1 / depth_range[0], 1 / depth_range[1]) if self.lindisp else depth_range
def forward(self, rays_o, rays_d, *, weights, s_vals=None, include_s_vals=False):
def forward(self, rays_o, rays_d, *, weights, s_vals=None, include_s_vals=False, **kwargs):
Sample points along rays. return Spherical or Cartesian coordinates,
specified by `self.shperical`
......@@ -166,22 +266,116 @@ class PdfSampler(nn.Module):
class VoxelSampler(nn.Module):
def __init__(self, *, depth_range: Tuple[float, float], n_samples: int, perturb_sample: bool,
lindisp: bool, space):
def __init__(self, *, perturb_sample: bool, sample_step: float, **kwargs):
Initialize a Sampler module
Initialize a VoxelSampler module
:param depth_range: depth range for sampler
:param n_samples: count to sample along ray
:param perturb_sample: perturb the sample depths
:param lindisp: If True, sample linearly in inverse depth rather than in depth
:param step_size: step size
self.lindisp = lindisp
self.perturb_sample = perturb_sample
self.n_samples = n_samples = space
self.s_range = (1 / depth_range[0], 1 / depth_range[1]) if self.lindisp else depth_range
self.sample_step = sample_step
def _forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, space_module: Space,
**kwargs) -> Tuple[Samples, torch.Tensor]:
def forward(self, rays_o, rays_d, *, weights, s_vals=None, include_s_vals=False):
:param rays_o `Tensor(N, 3)`: rays' origin positions
:param rays_d `Tensor(N, 3)`: rays' directions
:param step_size `float`: gap between samples along a ray
:return `Samples(N', P)`: samples along valid rays (which hit at least one voxel)
:return `Tensor(N)`: valid rays mask
intersections = space_module.ray_intersect(rays_o, rays_d, 100)
valid_rays_mask = intersections.hits > 0
rays_o = rays_o[valid_rays_mask]
rays_d = rays_d[valid_rays_mask]
intersections = intersections[valid_rays_mask] # (N) -> (N')
n_rays = rays_o.size(0)
ray_index_list = torch.arange(n_rays, device=rays_o.device, dtype=torch.long) # (N')
hits = intersections.hits
min_depths = intersections.min_depths
max_depths = intersections.max_depths
voxel_indices = intersections.voxel_indices
rays_near_depth = min_depths[:, :1] # (N', 1)
rays_far_depth = max_depths[ray_index_list, hits - 1][:, None] # (N', 1)
rays_length = rays_far_depth - rays_near_depth
rays_steps = (rays_length / self.sample_step).ceil().long()
rays_step_size = rays_length / rays_steps
max_steps = rays_steps.max().item()
rays_step = torch.arange(max_steps, device=rays_o.device,
dtype=torch.float)[None].repeat(n_rays, 1) # (N', P)
invalid_samples_mask = rays_step >= rays_steps
samples_min_depth = rays_near_depth + rays_step * rays_step_size
samples_depth = samples_min_depth + rays_step_size \
* (torch.rand_like(samples_min_depth) if self.perturb_sample else 0.5) # (N', P)
samples_dist = rays_step_size.repeat(1, max_steps) # (N', 1) -> (N', P)
samples_voxel_index = voxel_indices[
ray_index_list[:, None],
torch.searchsorted(max_depths, samples_depth)
] # (N', P)
samples_depth[invalid_samples_mask] = HUGE_FLOAT
samples_dist[invalid_samples_mask] = 0
samples_voxel_index[invalid_samples_mask] = -1
rays_o, rays_d = rays_o[:, None], rays_d[:, None]
return Samples(
pts=rays_o + rays_d * samples_depth[..., None],
dirs=rays_d.expand(-1, max_steps, -1),
), valid_rays_mask
def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, space_module: Space,
**kwargs) -> Tuple[Samples, torch.Tensor]:
:param rays_o `Tensor(N, 3)`: [description]
:param rays_d `Tensor(N, 3)`: [description]
:param step_size `float`: [description]
:return `Samples(N, P)`: [description]
intersections = space_module.ray_intersect(rays_o, rays_d, 100)
valid_rays_mask = intersections.hits > 0
rays_o = rays_o[valid_rays_mask]
rays_d = rays_d[valid_rays_mask]
intersections = intersections[valid_rays_mask] # (N) -> (N')
checkpoint("Ray intersect")
if intersections.size == 0:
return None, valid_rays_mask
min_depth = intersections.min_depths
max_depth = intersections.max_depths
pts_idx = intersections.voxel_indices
dists = max_depth - min_depth
tot_dists = dists.sum(dim=-1, keepdim=True) # (N, 1)
probs = dists / tot_dists
steps = tot_dists[:, 0] / self.sample_step
# sample points and use middle point approximation
sampled_indices, sampled_depths, sampled_dists = inverse_cdf_sampling(
pts_idx, min_depth, max_depth, probs, steps, -1, not self.perturb_sample)
sampled_indices = sampled_indices.long()
invalid_idx_mask = sampled_indices.eq(-1)
sampled_dists.clamp_min_(0).masked_fill_(invalid_idx_mask, 0)
sampled_depths.masked_fill_(invalid_idx_mask, HUGE_FLOAT)
checkpoint("Inverse CDF sampling")
rays_o, rays_d = rays_o[:, None], rays_d[:, None]
return Samples(
pts=rays_o + rays_d * sampled_depths[..., None],
dirs=rays_d.expand(-1, sampled_depths.size(1), -1),
), valid_rays_mask
from math import ceil
import torch
import numpy as np
from typing import List, NoReturn, Tuple, Union
from torch import nn
from plyfile import PlyData, PlyElement
from utils.geometry import *
from utils.constants import *
from utils.voxels import *
from utils.perf import perf
from clib import *
class Intersections:
    """
    Ray/voxel intersection records for a batch of `N` rays, each with up to `P` slots.

    Indexing an instance (e.g. with a boolean ray mask) applies the same index to
    every field and returns a new `Intersections`.
    """

    min_depths: torch.Tensor
    """`Tensor(N, P)` Min ray depths of intersected voxels"""
    max_depths: torch.Tensor
    """`Tensor(N, P)` Max ray depths of intersected voxels"""
    voxel_indices: torch.Tensor
    """`Tensor(N, P)` Indices of intersected voxels"""
    hits: torch.Tensor
    """`Tensor(N)` Number of hits"""

    @property
    def size(self):
        """`int` Number of rays, i.e. the length of `hits`"""
        # Exposed as a property because callers compare it directly
        # (`intersections.size == 0`); a bound method would never equal 0.
        return self.hits.size(0)

    def __init__(self, min_depths: torch.Tensor, max_depths: torch.Tensor,
                 voxel_indices: torch.Tensor, hits: torch.Tensor) -> None:
        """
        :param min_depths `Tensor(N, P)`: min ray depths of intersected voxels
        :param max_depths `Tensor(N, P)`: max ray depths of intersected voxels
        :param voxel_indices `Tensor(N, P)`: indices of intersected voxels
        :param hits `Tensor(N)`: number of hits of every ray
        """
        self.min_depths = min_depths
        self.max_depths = max_depths
        self.voxel_indices = voxel_indices
        self.hits = hits

    def __getitem__(self, index):
        """
        Select a subset of rays.

        :param index: anything accepted by tensor indexing along the ray dimension
        :return `Intersections`: the selected rays' intersections
        """
        return Intersections(
            min_depths=self.min_depths[index],
            max_depths=self.max_depths[index],
            voxel_indices=self.voxel_indices[index],
            hits=self.hits[index])
class Space(nn.Module):
bbox: Union[torch.Tensor, None]
"""`Tensor(2, 3)` Bounding box"""
def __init__(self, *, bbox: List[float] = None, **kwargs):
    """
    Initialize a space, optionally bounded by a box.

    :param bbox `List[float]`: 6 values reshaped to `(2, 3)`; `None` leaves the space unbounded
    :param kwargs: ignored here; accepted so subclasses can share one argument dict
    """
    # Module state must exist before any buffer can be registered.
    super().__init__()
    if bbox is None:
        self.bbox = None
    else:
        # Non-persistent: the box follows the module across devices but is not
        # written to checkpoints.
        self.register_buffer('bbox', torch.Tensor(bbox).reshape(2, 3), persistent=False)
def create_embedding(self, n_dims: int, name: str = 'default') -> torch.nn.Embedding:
raise NotImplementedError
def extract_embedding(self, pts: torch.Tensor, voxel_indices: torch.Tensor,
name: str = 'default') -> torch.Tensor:
raise NotImplementedError
def ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> Intersections:
raise NotImplementedError
def get_voxel_indices(self, pts: torch.Tensor) -> torch.Tensor:
    """
    Get voxel indices of points, treating the whole space as a single voxel 0.

    :param pts `Tensor(N..., 3)`: points
    :return `Tensor(N...)`: 0 for every point, or -1 where a bounding box is set
        and the point lies outside `[bbox[0], bbox[1])`
    """
    indices = torch.zeros_like(pts[..., 0], dtype=torch.long)
    if self.bbox is None:
        return indices
    lower, upper = self.bbox[0], self.bbox[1]
    # A point is outside as soon as any coordinate leaves the half-open box.
    outside = ((pts < lower) | (pts >= upper)).any(-1)  # (N...)
    indices[outside] = -1
    return indices
def pruning(self, score_fn, threshold: float = 0.5, train_stats=False):
raise NotImplementedError()
def splitting(self):
raise NotImplementedError()
class Voxels(Space):
steps: torch.Tensor
"""`Tensor(3)` Steps along each dimension"""
corners: torch.Tensor
"""`Tensor(C, 3)` Corner positions"""
voxels: torch.Tensor
"""`Tensor(M, 3)` Voxel centers"""
corner_indices: torch.Tensor
"""`Tensor(M, 8)` Voxel corner indices"""
voxel_indices_in_grid: torch.Tensor
"""`Tensor(G)` Indices in voxel list or -1 for pruned space"""
@property
def dims(self) -> int:
    """`int` Number of dimensions"""
    # Read attribute-style elsewhere (`range(self.dims)`), hence a property.
    return self.steps.size(0)
@property
def n_voxels(self) -> int:
    """`int` Number of voxels"""
    # Read attribute-style elsewhere (`torch.arange(self.n_voxels)`), hence a property.
    return self.voxels.size(0)
@property
def n_corner(self) -> int:
    """`int` Number of corners"""
    # Property for parity with the sibling accessors (`dims`, `n_voxels`, ...).
    # NOTE(review): other code in this file spells it `n_corners` - confirm the intended name.
    return self.corners.size(0)
@property
def voxel_size(self) -> torch.Tensor:
    """`Tensor(3)` Voxel size"""
    # Used directly in arithmetic elsewhere (`/ self.voxel_size`), hence a property.
    return (self.bbox[1] - self.bbox[0]) / self.steps
@property
def device(self) -> torch.device:
    """`torch.device` Device holding the voxel buffer"""
    # Passed as `device=self.device` elsewhere, hence a property.
    return self.voxels.device
def __init__(self, *, voxel_size: float = None,
             steps: Union[torch.Tensor, Tuple[int, int, int]] = None, **kwargs) -> None:
    """
    Initialize a dense voxel grid filling the bounding box.

    Exactly one of `voxel_size` / `steps` is needed; `voxel_size` wins when both are given.

    :param voxel_size `float`: voxel size, used to derive the grid steps
    :param steps `Tensor(3)|(int, int, int)`: steps along each dimension
    :param kwargs: forwarded to the parent initializer (must carry `bbox`)
    :raises ValueError: if `bbox` is missing, or neither `voxel_size` nor `steps` is given
    """
    # The parent sets up `self.bbox`, which everything below reads.
    super().__init__(**kwargs)
    if self.bbox is None:
        raise ValueError("Missing argument 'bbox'")
    if voxel_size is not None:
        self.register_buffer('steps', get_grid_steps(self.bbox, voxel_size))
    elif steps is not None:
        self.register_buffer('steps', torch.tensor(steps, dtype=torch.long))
    else:
        raise ValueError("Either 'voxel_size' or 'steps' must be specified")
    self.register_buffer('voxels', init_voxels(self.bbox, self.steps))
    corners, corner_indices = get_corners(self.voxels, self.bbox, self.steps)
    self.register_buffer("corners", corners)
    self.register_buffer("corner_indices", corner_indices)
    # The grid starts fully populated: cell i holds voxel i.
    self.register_buffer('voxel_indices_in_grid', torch.arange(self.voxels.size(0)))
def create_embedding(self, n_dims: int, name: str = 'default') -> torch.nn.Embedding:
    """
    Create an embedding on voxel corners.

    The embedding is registered as the sub-module `emb_<name>`.

    :param n_dims `int`: embedding dimension
    :param name `str`: embedding name
    :return `Embedding(n_corners, n_dims)`: new embedding on voxel corners
    """
    name = f'emb_{name}'
    # One row per corner; the count comes straight from the `corners` buffer.
    self.add_module(name, torch.nn.Embedding(self.corners.size(0), n_dims))
    return getattr(self, name)
def get_embedding(self, name: str = 'default') -> torch.nn.Embedding:
return getattr(self, f'emb_{name}')
def extract_embedding(self, pts: torch.Tensor, voxel_indices: torch.Tensor,
                      name: str = 'default') -> torch.Tensor:
    """
    Extract embedding values at given points using trilinear interpolation.

    :param pts `Tensor(N, 3)`: points to extract values
    :param voxel_indices `Tensor(N)`: corresponding voxel indices
    :param name `str`: embedding name, default to 'default'
    :return `Tensor(N, X)`: extracted values
    :raises KeyError: if no embedding called `name` has been created
    """
    # `get_embedding` goes through `getattr`, which raises AttributeError for an
    # unknown name rather than returning None; map that to the KeyError below.
    try:
        emb = self.get_embedding(name)
    except AttributeError:
        emb = None
    if emb is None:
        raise KeyError(f"Embedding '{name}' doesn't exist")
    voxels = self.voxels[voxel_indices]  # (N, 3)
    corner_indices = self.corner_indices[voxel_indices]  # (N, 8)
    p = (pts - voxels) / self.voxel_size + 0.5  # (N, 3) normed-coords in voxel
    features = emb(corner_indices).reshape(pts.size(0), 8, -1)  # (N, 8, X)
    return trilinear_interp(p, features)
def ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> Intersections:
Calculate intersections of rays and voxels.
:param rays_o `Tensor(N, 3)`: rays' origin
:param rays_d `Tensor(N, 3)`: rays' direction
:param n_max_hits `int`: maximum number of hits (for allocating enough space)
:return `Intersection`: intersections of rays and voxels
# Prepend a dim to meet the requirement of external call
rays_o = rays_o[None].contiguous()
rays_d = rays_d[None].contiguous()
voxel_indices, min_depths, max_depths = self._ray_intersect(rays_o, rays_d, n_max_hits)
invalid_voxel_mask = voxel_indices.eq(-1)
hits = n_max_hits - invalid_voxel_mask.sum(-1)
# Sort intersections according to their depths
min_depths.masked_fill_(invalid_voxel_mask, HUGE_FLOAT)
max_depths.masked_fill_(invalid_voxel_mask, HUGE_FLOAT)
min_depths, sorted_idx = min_depths.sort(dim=-1)
max_depths = max_depths.gather(-1, sorted_idx)
voxel_indices = voxel_indices.gather(-1, sorted_idx)
return Intersections(
def get_voxel_indices(self, pts: torch.Tensor) -> torch.Tensor:
    """
    Get voxel indices of points.

    If a point is not in any valid voxels, its corresponding voxel index is -1.

    :param pts `Tensor(N..., 3)`: points
    :return `Tensor(N...)`: corresponding voxel indices
    """
    grid_indices, out_mask = to_grid_indices(pts, self.bbox, steps=self.steps)
    # Points flagged by `out_mask` are parked on cell 0 so the table lookup below
    # has a usable index; their result is overwritten with -1 right after.
    # NOTE(review): `out_mask` presumably marks points outside the bbox - confirm in `to_grid_indices`.
    grid_indices[out_mask] = 0
    # Grid cell -> index into the voxel list.
    voxel_indices = self.voxel_indices_in_grid[grid_indices]
    voxel_indices[out_mask] = -1
    return voxel_indices
def splitting(self) -> Tuple[int, int]:
    """
    Split voxels into smaller voxels with half size.

    :return `(int, int)`: number of voxels before and after splitting
    """
    n_voxels_before = self.n_voxels
    # Doubling the steps in place halves `voxel_size` (it is derived from `steps`),
    # so the size passed to `split_voxels` below is already the new one.
    self.steps *= 2
    self.voxels = split_voxels(self.voxels, self.voxel_size, 2, align_border=False)\
        .reshape(-1, 3)
    # NOTE(review): `corners`, `corner_indices` and `voxel_indices_in_grid` are not
    # refreshed here - confirm whether the caller is expected to do so.
    return n_voxels_before, self.n_voxels
def prune(self, keeps: torch.Tensor) -> Tuple[int, int]:
    """
    Drop every voxel whose entry in `keeps` is false.

    :param keeps `Tensor(M)`: boolean mask over the current voxels
    :return `(int, int)`: number of voxels before and after pruning
    """
    n_before = keeps.size(0)
    n_after = keeps.sum().item()
    # Filter both per-voxel tables with the same mask so they stay aligned.
    self.voxels, self.corner_indices = self.voxels[keeps], self.corner_indices[keeps]
    return n_before, n_after
def pruning(self, score_fn, threshold: float = 0.5) -> Tuple[int, int]:
    """
    Prune voxels whose best sampled score does not exceed `threshold`.

    :param score_fn: callable scoring sampled points, passed to `_get_scores`
    :param threshold `float`: a voxel is kept when its maximum score is strictly greater
    :return `(int, int)`: whatever `prune` returns, i.e. voxel counts before and after
    """
    # Reduce each voxel's samples to their maximum score.
    scores = self._get_scores(score_fn, lambda x: torch.max(x, -1)[0])  # (M)
    return self.prune(scores > threshold)
def n_voxels_along_dim(self, dim: int) -> torch.Tensor:
    """
    Count the occupied grid cells in every slice along one dimension.

    :param dim `int`: dimension to keep
    :return `Tensor(steps[dim])`: number of cells whose grid entry is not -1, per slice
    """
    occupied = self.voxel_indices_in_grid.reshape(*self.steps) != -1
    # Sum over every axis except the requested one.
    other_axes = [axis for axis in range(self.dims) if axis != dim]
    return occupied.sum(other_axes)
def balance_cut(self, dim: int, n_parts: int) -> List[int]:
    """
    Cut the grid along one dimension into runs of slices holding similar voxel counts.

    :param dim `int`: dimension to cut along
    :param n_parts `int`: desired number of parts
    :return `List[int]`: number of slices in each emitted part
    """
    counts = self.n_voxels_along_dim(dim)
    # Cumulative share of voxels, scaled so that part k ends where the value reaches k.
    progress = (counts.cumsum(0) / self.n_voxels * n_parts).tolist()
    bins: List[int] = []
    next_part, start = 1, 0
    for end, value in enumerate(progress, start=1):
        if value < next_part:
            continue
        bins.append(end - start)
        start = end
        # A single slice may cross several part boundaries at once; skip past them.
        next_part = int(value) + 1
    return bins
def sample(self, bits: int, perturb: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
sampled_xyz = split_voxels(self.voxels, self.voxel_size, bits)
sampled_idx = torch.arange(self.n_voxels, device=self.device)[:, None].expand(
sampled_xyz, sampled_idx = sampled_xyz.reshape(-1, 3), sampled_idx.flatten()
def _get_scores(self, score_fn, reduce_fn=None, bits=16) -> torch.Tensor:
def get_scores_once(pts, idxs):
scores = score_fn(pts, idxs).reshape(-1, bits ** 3) # (B, P)
if reduce_fn is not None:
scores = reduce_fn(scores) # (B[, ...])
return scores
sampled_xyz, sampled_idx = self.sample(bits)
chunk_size = 64
get_scores_once(sampled_xyz[i:i + chunk_size], sampled_idx[i:i + chunk_size])
for i in range(0, self.voxels.size(0), chunk_size)
], 0) # (M[, ...])
def _ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # Backend hook for `ray_intersect`: delegates to the external `aabb_ray_intersect`
    # routine with the current voxel size and voxel centers.
    # `ray_intersect` unpacks the result as (voxel_indices, min_depths, max_depths)
    # and passes rays with a leading dimension already prepended.
    return aabb_ray_intersect(self.voxel_size, n_max_hits, self.voxels, rays_o, rays_d)
def _update_corners(self):
    """
    Update voxel corners.
    """
    # Recompute corners from the current voxels, bbox and steps, then re-register
    # both buffers under their existing names so the old tensors are replaced.
    corners, corner_indices = get_corners(self.voxels, self.bbox, self.steps)
    self.register_buffer("corners", corners)
    self.register_buffer("corner_indices", corner_indices)
def _update_voxel_indices_in_grid(self):
    """
    Update voxel indices in grid.

    Rebuilds the grid-cell -> voxel-index table: cells occupied by a voxel get
    that voxel's position in `self.voxels`, all other cells get -1.
    """
    grid_indices, _ = to_grid_indices(self.voxels, self.bbox, steps=self.steps)
    # The table needs one entry per grid cell (it is reshaped to `steps` and
    # indexed by grid index elsewhere), so size it by the product of the steps.
    n_cells = int(self.steps.prod())
    self.voxel_indices_in_grid = grid_indices.new_full([n_cells], -1)
    self.voxel_indices_in_grid[grid_indices] = torch.arange(
        self.voxels.size(0), device=self.voxels.device)
def _before_load_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys,
                            unexpected_keys, error_msgs):
    """
    Resize this module's tensors so a checkpoint with a different voxel layout can load.

    :param state_dict: checkpoint being loaded
    :param prefix `str`: key prefix of this module inside `state_dict`
    :param local_metadata, strict, missing_keys, unexpected_keys, error_msgs: unused;
        part of the load-state-dict hook signature
    """
    # Handle buffers
    for name, buffer in self.named_buffers(recurse=False):
        key = prefix + name
        # Non-persistent buffers are never stored in a state dict; buffers absent
        # from this checkpoint are left for the regular loader to report.
        if name in self._non_persistent_buffers_set or key not in state_dict:
            continue
        buffer.resize_as_(state_dict[key])
    # Handle embeddings
    # `corners` has just been resized, so its length is the checkpoint's corner count.
    n_corners = self.corners.size(0)
    # Snapshot the children first: the loop replaces sub-modules while walking them.
    for name, module in list(self.named_children()):
        if name.startswith('emb_'):
            setattr(self, name, torch.nn.Embedding(n_corners, module.embedding_dim))
class Octree(Voxels):
def __init__(self, **kwargs) -> None:
    """
    Initialize an octree-accelerated voxel space.

    :param kwargs: forwarded unchanged to the parent initializer
    """
    # The parent builds the voxel buffers that `get` later turns into an octree.
    super().__init__(**kwargs)
    # Octree cache, built lazily by `get` and dropped by `clear`.
    self.nodes_cached = None
    self.tree_cached = None
def get(self) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Return the octree of the current voxels, building and caching it on first use.

    :return: the `(nodes, tree)` pair produced by `build_easy_octree`
    """
    if self.nodes_cached is not None:
        return self.nodes_cached, self.tree_cached
    # The builder is given half of the voxel size.
    half_size = 0.5 * self.voxel_size
    self.nodes_cached, self.tree_cached = build_easy_octree(self.voxels, half_size)
    return self.nodes_cached, self.tree_cached
def clear(self):
    """Drop the cached octree so the next `get` rebuilds it."""
    for attr in ('nodes_cached', 'tree_cached'):
        setattr(self, attr, None)
def _ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int):
nodes, tree = self.get()
return octree_ray_intersect(self.voxel_size, n_max_hits, nodes, tree, rays_o, rays_d)
def splitting(self):
    """
    Split voxels as the parent does, then invalidate the cached octree.

    :return: the parent's result (voxel counts before and after splitting)
    """
    ret = super().splitting()
    # The voxel set changed, so an octree built before the split is stale.
    self.clear()
    return ret
def prune(self, keeps: torch.Tensor) -> Tuple[int, int]:
    """
    Prune voxels as the parent does, then invalidate the cached octree.

    :param keeps `Tensor(M)`: boolean mask over the current voxels
    :return `(int, int)`: number of voxels before and after pruning
    """
    ret = super().prune(keeps)
    # The voxel set changed, so an octree built before pruning is stale.
    self.clear()
    return ret
nerf++ @ a30f1a5a
Subproject commit a30f1a5ad116e43aad90c426a966b2a3fcedaf7e
import torch
import torch.nn as nn
from modules import *
from utils import color
class Nerf(nn.Module):
def __init__(self, fc_params, sampler_params, *,
c: int = color.RGB,
n_pos_encode: int = 0,
n_dir_encode: int = None,
coarse_net=None, **kwargs):
Initialize a NeRF unit
:param fc_params `dict`: parameters for full-connection network
:param sampler_params `dict`: parameters for sampler
:param c `int`: color mode
:param n_pos_encode `int`: encode position to number of dimensions
:param n_dir_encode `int`: encode direction to number of dimensions, `None` means direction is ignored
:param coarse_net `NerfUnit`: optional coarse net
self.coarse_net = coarse_net
self.color = c
self.coord_chns = 3
self.color_chns = color.chns(self.color)
self.pos_encoder = InputEncoder.Get(n_pos_encode, self.coord_chns)
if n_dir_encode is not None:
self.dir_chns = 3
self.dir_encoder = InputEncoder.Get(n_dir_encode, self.dir_chns)
self.dir_chns = 0
self.dir_encoder = None
self.core = NerfCore(coord_chns=self.pos_encoder.out_dim,
dir_chns=self.dir_encoder.out_dim if self.dir_encoder else 0,
dir_nf=fc_params['nf'] // 2,
sampler_params['spherical'] = False
self.sampler = PdfSampler(**sampler_params) if self.coarse_net is not None \
else Sampler(**sampler_params)
self.rendering = VolumnRenderer()
def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, *,
ret_depth=False, debug=False) -> torch.Tensor:
rays -> colors
:param rays_o `Tensor(B, 3)`: rays' origin
:param rays_d `Tensor(B, 3)`: rays' direction
:param prev_ret `Mapping`:
:param ret_depth `bool`:
:return: `Tensor(B, C)``, inferred images/pixels
if self.coarse_net is not None:
coarse_ret = self.coarse_net(rays_o, rays_d, ret_depth=ret_depth, debug=debug)
coords, depths, s_vals, _ = self.sampler(rays_o, rays_d, coarse_ret['sample'],
coords, depths, s_vals, _ = self.sampler(rays_o, rays_d)
coords_encoded = self.pos_encoder(coords)
dirs_encoded = self.dir_encoder(rays_d)[:, None].expand(-1, s_vals.size(-1), -1) \
if self.dir_encoder is not None else None
colors, densities = self.core(coords_encoded, dirs_encoded)
ret = self.rendering(colors, densities[..., 0], depths, ret_depth=ret_depth, debug=debug)
ret['sample'] = s_vals
if self.coarse_net is not None:
ret['coarse'] = coarse_ret
return ret
import torch
import torch.nn as nn
from modules import *
from utils import color
class NSVF(nn.Module):
def __init__(self, fc_params, sampler_params, *,
c: int = color.RGB,
n_featdim: int = 32,
n_pos_encode: int = 0,
n_dir_encode: int = None,
Initialize a NSVF model
:param fc_params `dict`: parameters for full-connection network
:param sampler_params `dict`: parameters for sampler
:param c `int`: color mode
:param n_pos_encode `int`: encode position to number of dimensions
:param n_dir_encode `int`: encode direction to number of dimensions, `None` means direction is ignored
:param coarse_net `NerfUnit`: optional coarse net
self.color = c
self.coord_chns = n_featdim
self.color_chns = color.chns(self.color)
self.pos_encoder = InputEncoder.Get(n_pos_encode, self.coord_chns)
if n_dir_encode is not None:
self.dir_chns = 3
self.dir_encoder = InputEncoder.Get(n_dir_encode, self.dir_chns)
self.dir_chns = 0
self.dir_encoder = None
self.core = NerfCore(coord_chns=self.pos_encoder.out_dim,
dir_chns=self.dir_encoder.out_dim if self.dir_encoder else 0,
dir_nf=fc_params['nf'] // 2,
skips=fc_params['skips']) = OctTreeSpace()
sampler_params['space'] =
self.sampler = VoxelSampler(**sampler_params)
self.rendering = VolumnRenderer()
def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, *,
ret_depth=False, debug=False) -> torch.Tensor:
rays -> colors
:param rays_o `Tensor(B, 3)`: rays' origin
:param rays_d `Tensor(B, 3)`: rays' direction
:param prev_ret `Mapping`:
:param ret_depth `bool`:
:return: `Tensor(B, C)``, inferred images/pixels
feats, dirs, z_s, dz_s = self.sampler(rays_o, rays_d)
feats_encoded = self.pos_encoder(feats)
dirs_encoded = self.dir_encoder(rays_d)[:, None].expand(-1, z_s.size(-1), -1) \
if self.dir_encoder is not None else None
colors, densities = self.core(feats_encoded, dirs_encoded)
ret = self.rendering(colors, densities[..., 0], z_s, dz_s, ret_depth=ret_depth, debug=debug)
return ret
import torch
import torch.nn as nn
from modules import *
from utils import sphere
from utils import color
class Snerf(nn.Module):
def __init__(self, fc_params, sampler_params, *,
n_parts: int = 1,
c: int = color.RGB,
pos_encode: int = 10,
dir_encode: int = None,
spherical_dir: bool = False, **kwargs):
Initialize a multi-sphere-layer net
:param fc_params: parameters for full-connection network
:param sampler_params: parameters for sampler
:param normalize_coord: whether normalize the spherical coords to [0, 2pi] before encode
:param c: color mode
:param encode_to_dim: encode input to number of dimensions
self.color = c
self.spherical_dir = spherical_dir
self.n_samples = sampler_params['n_samples']
self.n_parts = n_parts
self.samples_per_part = self.n_samples // self.n_parts
self.coord_chns = 3
self.color_chns = color.chns(self.color)
self.pos_encoder = InputEncoder.Get(pos_encode, self.coord_chns)
if dir_encode is not None:
self.dir_encoder = InputEncoder.Get(dir_encode, 2 if self.spherical_dir else 3)
self.dir_chns_encoded = self.dir_encoder.out_dim
self.dir_encoder = None
self.dir_chns_encoded = 0
self.nets = nn.ModuleList(
dir_nf=fc_params['nf'] // 2,
for _ in range(self.n_parts))
sampler_params['spherical'] = True
self.sampler = Sampler(**sampler_params)
self.rendering = VolumnRenderer()
def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor,
ret_depth=False, debug=False) -> torch.Tensor:
rays -> colors
:param rays_o `Tensor(B, 3)`: rays' origin
:param rays_d `Tensor(B, 3)`: rays' direction
:return: `Tensor(B, C)``, inferred images/pixels
n_rays = rays_o.size(0)
coords, depths, _, pts = self.sampler(rays_o, rays_d)
coords_encoded = self.pos_encoder(coords)
if self.dir_encoder is not None:
if self.spherical_dir:
dirs_encoded = self.dir_encoder(sphere.calc_local_dir(rays_d, coords, pts))
dirs_encoded = self.dir_encoder(rays_d)[:, None].expand(-1, self.n_samples, -1)
dirs_encoded = None
densities = torch.empty(n_rays, self.n_samples, device=device.default())
colors = torch.empty(n_rays, self.n_samples, self.color_chns, device=device.default())
for i, net in enumerate(self.nets):
s = slice(i * self.samples_per_part, (i + 1) * self.samples_per_part)
c, d = net(coords_encoded[:, s],
dirs_encoded[:, s] if dirs_encoded is not None else None)
colors[:, s] = c
densities[:, s] = d
ret = self.rendering(colors.view(-1, self.n_samples, self.color_chns),
densities, depths, ret_depth=ret_depth, debug=debug)
if debug:
ret['sample_densities'] = densities
ret['sample_depths'] = depths
return ret
class SnerfExport(nn.Module):
def __init__(self, net: Snerf):
super().__init__() = net
def forward(self, coords_encoded, z_vals):
colors = []
densities = []
for i in range(
s = slice(i *, (i + 1) *
mlp =[i] if is not None else
c, d = mlp(coords_encoded[:, s].flatten(1, 2))
colors =, 1)
densities =, 1)
alphas =, z_vals)
return[colors, alphas[..., None]], -1)
......@@ -3,13 +3,11 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"import os\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as nn_f\n",
"import matplotlib.pyplot as plt\n",
"rootdir = os.path.abspath(sys.path[0] + '/../')\n",
......@@ -18,21 +16,16 @@
"print(\"Set CUDA:%d as current device.\" % torch.cuda.current_device())\n",
"from data.spherical_view_syn import *\n",
"from configs.spherical_view_syn import SphericalViewSynConfig\n",
"from utils import netio\n",
"from utils import img\n",
"from utils import device\n",
"from utils.view import *\n",
"from components.fnr import FoveatedNeuralRenderer\n",
"datadir = f\"{rootdir}/data/__new/__demo/for_crop\"\n",
"figs = ['our', 'gt', 'nerf', 'fgt']\n",
"crops = {\n",
" 'classroom_0': [[720, 800, 128], [1097, 982, 256]],\n",
" 'lobby_1': [[570, 1000, 100], [1049, 1049, 256]],\n",
" 'stones_2': [[720, 800, 100], [680, 1317, 256]],\n",
" 'barbershop_3': [[745, 810, 100], [1135, 627, 256]]\n",
" 'classroom_0': [[720, 790, 100], [370, 1160, 200]],\n",
" 'lobby_1': [[570, 1000, 100], [1300, 1000, 200]],\n",
" 'stones_2': [[720, 800, 100], [680, 1317, 200]],\n",
" 'barbershop_3': [[745, 810, 100], [950, 900, 200]]\n",
"colors = torch.tensor([[0, 1, 0, 1], [1, 1, 0, 1]], dtype=torch.float)\n",
"border = 10\n",
......@@ -78,16 +71,18 @@
"[fovea_patches, periph_patches], dim=-1),\n",
" [f\"{datadir}/patch/{scene}_{fig}.png\" for fig in figs])\n",
", f\"{datadir}/overlay/{scene}.png\")\n"
"outputs": [],
"metadata": {}
"metadata": {
"interpreter": {
"hash": "82066b63b621a9e3d15e3b7c11ca76da6238eff3834294910d715044bd0561e5"
"hash": "65406b00395a48e1d89cf658ae895e7869e05878f5469716b06a752a3915211c"
"kernelspec": {
"display_name": "Python 3.8.5 64-bit ('base': conda)",
"name": "python3"
"name": "python3",
"display_name": "Python 3.8.5 64-bit ('base': conda)"
"language_info": {
"codemirror_mode": {
......@@ -170,7 +170,7 @@
" images['overlaid'] = renderer.foveation.synthesis(images['layers_raw'], param[-2:], do_blend=False)\n",
" if True:\n",
" outputdir = '../__demo/mono/'\n",
" misc.create_dir(outputdir)\n",
" os.makedirs(outputdir, exist_ok=True)\n",
"['layers_img'][0], f'{outputdir}{scene}_{i}_fovea.png')\n",
"['layers_img'][1], f'{outputdir}{scene}_{i}_mid.png')\n",
"['layers_img'][2], f'{outputdir}{scene}_{i}_periph.png')\n",
......@@ -203,7 +203,7 @@
" center = (0, 0)\n",
" images = renderer(views.get(view_idx), center, using_mask=True)\n",
" outputdir = 'panorama'\n",
" misc.create_dir(outputdir)\n",
" os.makedirs(outputdir, exist_ok=True)\n",
"['blended'], f'{outputdir}/{view_idx:04d}.png')"
"outputs": [
......@@ -216,7 +216,7 @@
" ret_raw=False)\n",
" if True:\n",
" outputdir = '../__demo/stereo_m%d' % mono_periph if mono_periph else '../__demo/stereo'\n",
" misc.create_dir(outputdir)\n",
" os.makedirs(outputdir, exist_ok=True)\n",
" left_images['blended'],\n",
" right_images['blended']\n",
......@@ -228,7 +228,7 @@
" right_images['blended'][:, 1:3]\n",
" ], dim=1)\n",
", '%s/%s_%d_stereo.png' % (outputdir, scene, i))\n",
" #misc.create_dir(outputdir + '/mid')\n",
" #os.makedirs(outputdir + '/mid', exist_ok=True)\n",
"['layers_img'][1], '%s/mid/%s_%d_l.png' % (outputdir, scene, i))\n",
"['layers_img'][1], '%s/mid/%s_%d_r.png' % (outputdir, scene, i))\n",
" print(\"%s %d Saved\" % (scene, i))\n",
......@@ -110,7 +110,7 @@
" #plot_figures(images, center)\n",
" outputdir = '../__1_eval/output_mono_periph/ref_as_right_eye/%s/' % scene\n",
" misc.create_dir(outputdir)\n",
" os.makedirs(outputdir, exist_ok=True)\n",
" #for key in images:\n",
" key = 'blended'\n",
"[key], outputdir + 'view%04d_%s.png' % (view_idx, key))\n"
......@@ -131,7 +131,7 @@
" images = gen.gen(center, test_view, True)\n",
" #plot_figures(images, center)\n",
" misc.create_dir('output/eval_gaze')\n",
" os.makedirs('output/eval_gaze', exist_ok=True)\n",
" out_path = 'output/eval_gaze/gaze%03d_%d,%d.png' % (gaze_idx, x, y)\n",
"['blended'], out_path)\n",
" print('Output ' + out_path)\n",
......@@ -130,7 +130,7 @@
" images = gen.gen(center, test_view, True)\n",
" #plot_figures(images, center)\n",
" misc.create_dir('output/teasers')\n",
" os.makedirs('output/teasers', exist_ok=True)\n",
" for key in images:\n",
" images[key], 'output/teasers/view%04d_%s.png' % (view_idx, key))\n"
......@@ -150,7 +150,7 @@
"print(\"Encoded:\", encoded)\n",
"#plot_figures(images, center)\n",
"#os.makedirs('output/teasers', exist_ok=True)\n",
"#for key in images:\n",
"# images[key], 'output/teasers/view%04d_%s.png' % (view_idx, key))\n"
......@@ -188,7 +188,7 @@
"#plot_figures(left_images, right_images, centers[set_id][0], centers[set_id][1])\n",
"os.makedirs('output', exist_ok=True)\n",
"for key in left_images:\n",
" left_images[key], 'output/set%d_%s_l.png' % (set_id, key))\n",
......@@ -117,7 +117,7 @@
" left_images = gen.gen(left_center, left_view, mono_trans=mono_trans)\n",
" right_images = gen.gen(right_center, right_view, mono_trans=mono_trans)\n",
" \n",
" misc.create_dir('output/video_frames/hmd2')\n",
" os.makedirs('output/video_frames/hmd2', exist_ok=True)\n",
"[left_images['blended'], right_images['blended']], -1),\n",
" 'output/video_frames/hmd2/view%04d.png' % view_idx)\n",
" print('Frame %d saved' % view_idx)\n"
......@@ -155,7 +155,7 @@
" images['overlaid'] = renderer.foveation.synthesis(images['layers_raw'], param[-2:], do_blend=False)\n",
" if True:\n",
" outputdir = '../__demo/mono/'\n",
" misc.create_dir(outputdir)\n",
" os.makedirs(outputdir, exist_ok=True)\n",
"['layers_img'][0], f'{outputdir}{scene}_{i}_fovea.png')\n",
"['layers_img'][1], f'{outputdir}{scene}_{i}_mid.png')\n",
"['layers_img'][2], f'{outputdir}{scene}_{i}_periph.png')\n",
......@@ -196,7 +196,7 @@
" center = (0, 0)\n",
" images = renderer(views.get(view_idx), center, using_mask=True)\n",
" outputdir = 'nerf_our'\n",
" misc.create_dir(outputdir)\n",
" os.makedirs(outputdir, exist_ok=True)\n",
"['blended'], f'{outputdir}/{view_idx:04d}.png')"
......@@ -101,7 +101,7 @@
"gaze = [37.55656052, 20.7297554]\n",
"images = renderer(view, gaze, using_mask=False, ret_raw=True)\n",
"outputdir = '../__demo/mono_f60&m110/'\n",
"os.makedirs(outputdir, exist_ok=True)\n",
"['layers_img'][0], f'{outputdir}{scene}_fovea.png')\n",
"['blended'], f'{outputdir}{scene}.png')\n",
"['blended_raw'], f'{outputdir}{scene}_noCE.png')"
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment