Commit 6294701e authored by Nianchen Deng's avatar Nianchen Deng

sync

parent 2824f796
from .sampler import Sampler, PdfSampler, VoxelSampler
from .input_encoder import InputEncoder, IntegratedPosEncoder
from .renderer import VolumnRenderer, DensityFirstVolumnRenderer
from .space import Space, Voxels, Octree
from .core import NerfCore, NerfAdvCore, MultiNerf
\ No newline at end of file
import re
import torch
from typing import Iterable, Tuple

from .generic import *
from utils.misc import union, split
from utils.type import NetInput, NetOutput
from utils.module import Module
from utils.samples import Samples
class NerfCore(Module):

    def __init__(self, *, x_chns, density_chns, color_chns, nf, n_layers,
                 d_chns=0, d_nf=0, act='relu', skips=[], with_layer_norm=False,
                 density_out_act='relu', color_out_act='sigmoid', f_chns=0):
        super().__init__()
        self.input_f = f_chns > 0
        self.core_field = FcBlock(in_chns=x_chns + f_chns, out_chns=None, nf=nf, n_layers=n_layers,
                                  skips=skips, act=act, with_ln=with_layer_norm)
        self.density_out = FcLayer(nf, density_chns, density_out_act, with_ln=False) \
            if density_chns > 0 else None
        if color_chns == 0:
            self.color_out = None
        elif d_chns > 0:
            self.color_out = FcBlock(in_chns=nf + d_chns, out_chns=color_chns,
                                     nf=d_nf or nf // 2, n_layers=1,
                                     act=act, out_act=color_out_act, with_ln=with_layer_norm)
            self.with_dir = True
        else:
            self.color_out = FcLayer(nf, color_chns, color_out_act, with_ln=False)
            self.with_dir = False
    def forward(self, inputs: NetInput, *outputs: str, features: torch.Tensor = None, **kwargs) -> NetOutput:
        ret = {}
        if features is None:
            features = self.core_field(union(inputs.x, inputs.f) if self.input_f else inputs.x)
        if 'features' in outputs:
            ret['features'] = features
        if 'densities' in outputs and self.density_out:
            ret['densities'] = self.density_out(features)
        if 'colors' in outputs and self.color_out:
            if self.with_dir:
                features = union(features, inputs.d)
            ret['colors'] = self.color_out(features)
        return ret
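# Usage sketch (illustrative, not part of this commit): querying a NerfCore.
# `NetInput` is assumed to bundle the encoded position `x`, the encoded view
# direction `d` and optional extra features `f`; channel sizes below are
# hypothetical.
#
#   core = NerfCore(x_chns=63, density_chns=1, color_chns=3, nf=256, n_layers=8,
#                   d_chns=27, skips=[4])
#   out = core(NetInput(x=enc_x, d=enc_d), 'densities', 'colors')
#   sigma, rgb = out['densities'], out['colors']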
class NerfAdvCore(Module):

    def __init__(self, *, x_chns: int, d_chns: int, density_chns: int, color_chns: int,
                 density_net: dict, color_net: dict, specular_net: dict = None,
                 appearance="decomposite", with_layer_norm=False, f_chns=0):
        """
        Create a NeRF-Adv Core Net.
        Required parameters for the sub-mlps include: "nf", "n_layers", "skips" and "act".
@@ -64,84 +64,87 @@ class NerfAdvCore(nn.Module):
        :param color_net `dict`: parameters for the color net
        :param specular_net `dict`: (optional) parameters for the optional specular net, defaults to None
        :param appearance `str`: (optional) options are [decomposite|combined], defaults to "decomposite"
        """
        super().__init__()
        self.input_f = f_chns > 0
        self.density_chns = density_chns
        self.color_chns = color_chns
        self.specular_feature_chns = color_net["nf"] if specular_net else 0
        self.color_feature_chns = density_net["nf"]
        self.appearance = appearance
        self.density_net = FcBlock(**density_net,
                                   in_chns=x_chns + f_chns,
                                   out_chns=self.density_chns + self.color_feature_chns,
                                   out_act='relu',
                                   with_ln=with_layer_norm)
        if self.appearance == "newtype":
            self.specular_feature_chns = d_chns * 3
            self.color_net = FcBlock(**color_net,
                                     in_chns=x_chns + self.color_feature_chns,
                                     out_chns=self.color_chns + self.specular_feature_chns,
                                     with_ln=with_layer_norm)
            self.specular_net = "Placeholder"
        else:
            match = re.match(r"mlp_basis\((\d+)\)", self.appearance)
            if match is not None:
                basis_dim = int(match.group(1))
                self.color_net = FcBlock(**color_net,
                                         in_chns=x_chns + self.color_feature_chns,
                                         out_chns=self.color_chns * basis_dim)
            elif self.appearance == "decomposite":
                self.color_net = FcBlock(**color_net,
                                         in_chns=x_chns + self.color_feature_chns,
                                         out_chns=self.color_chns + self.specular_feature_chns,
                                         with_ln=with_layer_norm)
            else:
                if specular_net:
                    self.color_net = FcBlock(**color_net,
                                             in_chns=x_chns + self.color_feature_chns,
                                             out_chns=self.specular_feature_chns,
                                             with_ln=with_layer_norm)
                else:
                    self.color_net = FcBlock(**color_net,
                                             in_chns=x_chns + d_chns + self.color_feature_chns,
                                             out_chns=self.color_chns,
                                             with_ln=with_layer_norm)
            self.specular_net = FcBlock(**specular_net,
                                        in_chns=d_chns + self.specular_feature_chns,
                                        out_chns=self.color_chns,
                                        with_ln=with_layer_norm) if specular_net else None
    def forward(self, inputs: NetInput, *outputs: str, features: torch.Tensor = None, **kwargs) -> NetOutput:
        output_shape = inputs.shape

        ret: NetOutput = {}

        if 'densities' in outputs or 'features' in outputs:
            density_net_in = union(inputs.x, inputs.f) if self.input_f else inputs.x
            density_net_out: torch.Tensor = self.density_net(density_net_in)
            densities, features = split(density_net_out, self.density_chns, -1)
            if 'features' in outputs:
                ret['features'] = features
            if 'densities' in outputs:
                ret['densities'] = densities

        if 'colors' in outputs or 'speculars' in outputs or 'diffuses' in outputs:
            if 'densities' in ret:
                valid_mask = ret['densities'][..., 0].detach() >= 1e-4
                indices: Tuple[torch.Tensor, ...] = valid_mask.nonzero(as_tuple=True)
                inputs, features = inputs[indices], features[indices]
            else:
                indices = None

            color_net_in = [inputs.x, features]
            if not self.specular_net:
                color_net_in.append(inputs.d)
            color_net_out: torch.Tensor = self.color_net(union(*color_net_in))
            diffuses = color_net_out[..., :self.color_chns]
            specular_features = color_net_out[..., -self.specular_feature_chns:]

            if self.appearance == "newtype":
                speculars = torch.matmul(
                    specular_features.reshape(*inputs.shape, -1, inputs.d.shape[-1]),
                    inputs.d[..., None])[..., 0]
                # TODO relu or not?
                diffuses = diffuses.relu()
                speculars = speculars.relu()
@@ -150,26 +153,55 @@ class NerfAdvCore(nn.Module):
            if not self.specular_net:
                colors = diffuses
                diffuses = None
                speculars = None
            else:
                specular_net_out = self.specular_net(union(inputs.d, specular_features))
                if self.appearance == "decomposite":
                    speculars = specular_net_out
                    colors = diffuses + speculars
                else:
                    diffuses = None
                    speculars = None
                    colors = specular_net_out
            colors = torch.sigmoid(colors)  # TODO indent or not?

            def postprocess(data: torch.Tensor):
                return data.new_zeros(*output_shape, data.shape[-1]).index_put(indices, data) \
                    if indices is not None else data

            if 'colors' in outputs:
                ret['colors'] = postprocess(colors)
            if 'diffuses' in outputs and diffuses is not None:
                ret['diffuses'] = postprocess(diffuses)
            if 'speculars' in outputs and speculars is not None:
                ret['speculars'] = postprocess(speculars)

        return ret
class MultiNerf(Module):

    @property
    def n_levels(self):
        return len(self.nets)

    def __init__(self, nets: Iterable[Module]):
        super().__init__()
        self.nets = nets
        for i in range(len(nets)):
            self.add_module(f"Level {i}", nets[i])

    def set_frozen(self, level: int, on: bool):
        self.nets[level].train(not on)

    def forward(self, inputs: NetInput, *outputs: str, samples: Samples, **kwargs) -> NetOutput:
        L = samples.level
        if L == 0:
            return self.nets[L](inputs, *outputs)
        if samples.features is not None:
            return self.nets[L](NetInput(inputs.x, inputs.d, samples.features), *outputs)
        features = self.nets[0](inputs, 'features')['features']
        for i in range(1, L):
            features = self.nets[i](NetInput(inputs.x, inputs.d, features), 'features')['features']
        samples.features = features
        return self(inputs, *outputs, samples=samples)
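# How MultiNerf cascades (a sketch, assumptions marked): level 0 consumes the
# raw input, each level i > 0 consumes the features emitted by level i - 1,
# and the level recorded in `samples.level` answers the actual query.
#
#   model = MultiNerf([coarse_net, fine_net])   # hypothetical sub-networks
#   samples.level = 1                           # route the query to level 1
#   out = model(inputs, 'densities', 'colors', samples=samples)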
from typing import List
from collections import OrderedDict
import math
import torch
import torch.nn as nn

from utils.constants import *
class BatchLinear(nn.Linear):
'''
A linear meta-layer that can deal with batched weight matrices and biases,
as for instance output by a hypernetwork.
'''
__doc__ = nn.Linear.__doc__
    def forward(self, input, params=None):
        if params is None:
            params = OrderedDict(self.named_parameters())
        bias = params.get('bias', None)
        weight = params['weight']
        # Batched matmul against the transposed weight: (..., in) @ (..., in, out)
        output = input.matmul(weight.permute(*[i for i in range(len(weight.shape) - 2)], -1, -2))
        if bias is not None:
            output += bias.unsqueeze(-2)
        return output
class Sine(nn.Module):
def __init__(self):
super().__init__()
def forward(self, input):
return torch.sin(30 * input)
class FcLayer(nn.Module):
def __init__(self, in_chns: int, out_chns: int, act: str = 'linear', skip_chns: int = 0):
super().__init__()
nls_and_inits = {
'sine': (Sine(), sine_init),
'relu': (nn.ReLU(), None),
'sigmoid': (nn.Sigmoid(), None),
'tanh': (nn.Tanh(), None),
'selu': (nn.SELU(), init_weights_selu),
'softplus': (nn.Softplus(), init_weights_normal),
'elu': (nn.ELU(), init_weights_elu),
'softmax': (nn.Softmax(dim=-1), softmax_init),
'logsoftmax': (nn.LogSoftmax(dim=-1), softmax_init),
'linear': (None, None)
}
nl, nl_weight_init = nls_and_inits[act]
self.net = nn.Sequential(
nn.Linear(in_chns + skip_chns, out_chns),
nl
) if nl else nn.Linear(in_chns + skip_chns, out_chns)
self.skip = skip_chns != 0
if nl_weight_init is not None:
nl_weight_init(self.net if isinstance(self.net, nn.Linear) else self.net[0])
else:
self.init_params(act)
def forward(self, x: torch.Tensor, x0: torch.Tensor = None) -> torch.Tensor:
return self.net(torch.cat([x0, x], dim=-1) if self.skip else x)
def get_params(self):
linear_net = self.net if isinstance(self.net, nn.Linear) else self.net[0]
return linear_net.weight, linear_net.bias
def init_params(self, act):
weight, bias = self.get_params()
nn.init.xavier_normal_(weight, gain=nn.init.calculate_gain(act))
nn.init.zeros_(bias)
def copy_to(self, layer):
weight, bias = self.get_params()
dst_weight, dst_bias = layer.get_params()
dst_weight.copy_(weight)
dst_bias.copy_(bias)
class FcNet(nn.Module):
def __init__(self, *, in_chns: int, out_chns: int, nf: int, n_layers: int,
skips: List[int] = [], act: str = 'relu', out_act = 'linear'):
"""
Initialize a full-connection net
:kwarg in_chns: channels of input
:kwarg out_chns: channels of output
:kwarg nf: number of features in each hidden layer
:kwarg n_layers: number of layers
:kwarg skips: create skip connections from input to layers in this list
"""
super().__init__()
self.layers = [FcLayer(in_chns, nf, act)] + [
FcLayer(nf, nf, act, skip_chns=in_chns if i in skips else 0)
for i in range(n_layers - 1)
]
if out_chns:
self.layers.append(FcLayer(nf, out_chns, out_act))
for i, layer in enumerate(self.layers):
self.add_module(f"layer{i}", layer)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x0 = x
for layer in self.layers:
x = layer(x, x0)
return x
########################
# Initialization methods
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# For PINNet, Raissi et al. 2019
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
# grab from upstream pytorch branch and paste here for now
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1. + math.erf(x / math.sqrt(2.))) / 2.
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def init_weights_trunc_normal(m):
# For PINNet, Raissi et al. 2019
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
if type(m) == BatchLinear or type(m) == nn.Linear:
if hasattr(m, 'weight'):
fan_in = m.weight.size(1)
fan_out = m.weight.size(0)
std = math.sqrt(2.0 / float(fan_in + fan_out))
mean = 0.
# initialize with the same behavior as tf.truncated_normal
# "The generated values follow a normal distribution with specified mean and
# standard deviation, except that values whose magnitude is more than 2
# standard deviations from the mean are dropped and re-picked."
_no_grad_trunc_normal_(m.weight, mean, std, -2 * std, 2 * std)
def init_weights_normal(m):
if type(m) == BatchLinear or type(m) == nn.Linear:
if hasattr(m, 'weight'):
nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in')
def init_weights_selu(m):
if type(m) == BatchLinear or type(m) == nn.Linear:
if hasattr(m, 'weight'):
num_input = m.weight.size(-1)
nn.init.normal_(m.weight, std=1 / math.sqrt(num_input))
def init_weights_elu(m):
if type(m) == BatchLinear or type(m) == nn.Linear:
if hasattr(m, 'weight'):
num_input = m.weight.size(-1)
nn.init.normal_(m.weight, std=math.sqrt(1.5505188080679277) / math.sqrt(num_input))
def init_weights_xavier(m):
if type(m) == BatchLinear or type(m) == nn.Linear:
if hasattr(m, 'weight'):
nn.init.xavier_normal_(m.weight)
def sine_init(m):
with torch.no_grad():
if hasattr(m, 'weight'):
num_input = m.weight.size(-1)
# See supplement Sec. 1.5 for discussion of factor 30
m.weight.uniform_(-math.sqrt(6 / num_input) / 30, math.sqrt(6 / num_input) / 30)
def first_layer_sine_init(m):
with torch.no_grad():
if hasattr(m, 'weight'):
num_input = m.weight.size(-1)
# See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of factor 30
m.weight.uniform_(-1 / num_input, 1 / num_input)
def softmax_init(m):
with torch.no_grad():
nn.init.normal_(m.weight, mean=0, std=0.01)
nn.init.constant_(m.bias, val=0)
from .linear import *
import torch
import torch.nn.functional as F
class Sine(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x: torch.Tensor):
return (30 * x).sin()
class Mise(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x: torch.Tensor):
return x * torch.tanh(F.softplus(x))
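# Note: Mise is numerically the activation usually called "Mish" (Misra, 2019):
# mish(x) = x * tanh(softplus(x)). It is smooth, non-monotonic near zero and
# approaches the identity for large positive x.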
from typing import List
from collections import OrderedDict

from .weight_init import *
from .fn import *
class BatchLinear(nn.Linear):
    '''
    A linear meta-layer that can deal with batched weight matrices and biases,
    as for instance output by a hypernetwork.
    '''
    __doc__ = nn.Linear.__doc__

    def forward(self, input, params=None):
        if params is None:
            params = OrderedDict(self.named_parameters())
        bias = params.get('bias', None)
        weight = params['weight']
        # Batched matmul against the transposed weight: (..., in) @ (..., in, out)
        output = input.matmul(weight.permute(*[i for i in range(len(weight.shape) - 2)], -1, -2))
        if bias is not None:
            output += bias.unsqueeze(-2)
        return output
class FcLayer(nn.Module):
def __init__(self, in_chns: int, out_chns: int, act: str = 'linear', skip_chns: int = 0,
with_ln: bool = True):
super().__init__()
nls_and_inits = {
'sine': (Sine, init_weights_sine),
'relu': (nn.ReLU, init_weights_relu),
'leakyrelu': (nn.LeakyReLU, init_weights_leakyrelu),
'sigmoid': (nn.Sigmoid, init_weights_xavier),
'tanh': (nn.Tanh, init_weights_xavier),
'selu': (nn.SELU, init_weights_selu),
'softplus': (nn.Softplus, init_weights_trunc_normal),
'elu': (nn.ELU, init_weights_elu),
'softmax': (nn.Softmax, init_weights_softmax),
'logsoftmax': (nn.LogSoftmax, init_weights_softmax),
'mise': (Mise, init_weights_xavier),
'linear': (nn.Identity, init_weights_xavier)
}
nl_cls, weight_init_fn = nls_and_inits[act]
self.net = [nn.Linear(in_chns + skip_chns, out_chns)]
if with_ln:
self.net += [nn.LayerNorm([out_chns])]
self.net += [nl_cls()]
self.net = nn.Sequential(*self.net)
self.skip = skip_chns != 0
self.with_ln = with_ln
self.net.apply(weight_init_fn)
def forward(self, x: torch.Tensor, x0: torch.Tensor = None) -> torch.Tensor:
return self.net(torch.cat([x0, x], -1) if self.skip else x)
def __repr__(self) -> str:
s = f"{self.net[0].in_features} -> {self.net[0].out_features}, "\
+ ", ".join(module.__class__.__name__ for module in self.net[1:])
return f"{self._get_name()}({s})"
class FcBlock(nn.Module):
def __init__(self, *, in_chns: int, out_chns: int, nf: int, n_layers: int,
skips: List[int] = [], act: str = 'relu', out_act='linear', with_ln=True):
"""
Initialize a full-connection net
:kwarg in_chns: channels of input
:kwarg out_chns: channels of output
:kwarg nf: number of features in each hidden layer
:kwarg n_layers: number of layers
:kwarg skips: create skip connections from input to layers in this list
"""
super().__init__()
self.layers = nn.ModuleList([
FcLayer(in_chns, nf, act, with_ln=with_ln)] + [
FcLayer(nf, nf, act, skip_chns=in_chns if i in skips else 0, with_ln=with_ln)
for i in range(1, n_layers)
])
if out_chns:
self.layers.append(FcLayer(nf, out_chns, out_act, with_ln=False))
def forward(self, x: torch.Tensor) -> torch.Tensor:
x0 = x
for layer in self.layers:
x = layer(x, x0)
return x
def __repr__(self):
lines = []
for key, module in self.layers._modules.items():
mod_str = repr(module)
mod_str = nn.modules.module._addindent(mod_str, 2)
lines.append('(' + key + '): ' + mod_str)
main_str = self._get_name() + '('
if lines:
main_str += '\n ' + '\n '.join(lines) + '\n'
main_str += ')'
return main_str
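# Usage sketch: an 8-layer, 256-wide block in the style of the original NeRF
# trunk MLP, re-injecting the raw input at layer 4 (all values illustrative).
#
#   block = FcBlock(in_chns=63, out_chns=None, nf=256, n_layers=8, skips=[4])
#   feats = block(torch.randn(1024, 63))   # -> (1024, 256)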
import torch
import torch.nn as nn
from utils import math
def init_weights_trunc_normal(m):
# For PINNet, Raissi et al. 2019
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1. + math.erf(x / math.sqrt(2.))) / 2.
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
if isinstance(m, nn.Linear):
fan_in = m.weight.size(1)
fan_out = m.weight.size(0)
std = math.sqrt(2.0 / float(fan_in + fan_out))
mean = 0.
# initialize with the same behavior as tf.truncated_normal
# "The generated values follow a normal distribution with specified mean and
# standard deviation, except that values whose magnitude is more than 2
# standard deviations from the mean are dropped and re-picked."
_no_grad_trunc_normal_(m.weight, mean, std, -2 * std, 2 * std)
def init_weights_relu(m):
if isinstance(m, nn.Linear):
nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu')
if m.bias is not None:
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight)
bound = 1 / math.sqrt(fan_in)
nn.init.uniform_(m.bias, -bound, bound)
def init_weights_leakyrelu(m):
if isinstance(m, nn.Linear):
nn.init.kaiming_normal_(m.weight, a=math.sqrt(5))
if m.bias is not None:
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight)
bound = 1 / math.sqrt(fan_in)
nn.init.uniform_(m.bias, -bound, bound)
def init_weights_selu(m):
if isinstance(m, nn.Linear):
num_input = m.weight.size(-1)
nn.init.normal_(m.weight, std=1 / math.sqrt(num_input))
def init_weights_elu(m):
if isinstance(m, nn.Linear):
num_input = m.weight.size(-1)
nn.init.normal_(m.weight, std=math.sqrt(1.5505188080679277) / math.sqrt(num_input))
def init_weights_xavier(m):
if isinstance(m, nn.Linear):
nn.init.xavier_normal_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def init_weights_sine(m):
with torch.no_grad():
if hasattr(m, 'weight'):
num_input = m.weight.size(-1)
# See supplement Sec. 1.5 for discussion of factor 30
m.weight.uniform_(-math.sqrt(6 / num_input) / 30, math.sqrt(6 / num_input) / 30)
def init_weights_sine_first_layer(m):
with torch.no_grad():
if hasattr(m, 'weight'):
num_input = m.weight.size(-1)
# See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of factor 30
m.weight.uniform_(-1 / num_input, 1 / num_input)
def init_weights_softmax(m):
    # Guard on nn.Linear: this function is applied to every submodule via
    # nn.Module.apply(), including activation layers without weights.
    if isinstance(m, nn.Linear):
        with torch.no_grad():
            nn.init.normal_(m.weight, mean=0, std=0.01)
            nn.init.constant_(m.bias, val=0)
from typing import Tuple
import torch
import torch.nn as nn

from .generic import *
from utils import math
from utils.module import Module
class InputEncoder(Module):

    def __init__(self, chns, L, cat_input=False):
        super().__init__()
        emb = torch.exp(torch.arange(L, dtype=torch.float) * math.log(2.))
        self.emb = nn.Parameter(emb, requires_grad=False)
        self.in_dim = chns
        self.out_dim = chns * (L * 2 + cat_input)
        self.cat_input = cat_input

    def forward(self, x: torch.Tensor, angular=False):
        x0 = x
        if angular:
            x = torch.acos(x.clamp(-1, 1))
        x = x[..., None] @ self.emb[None]
        x = torch.cat([torch.sin(x), torch.cos(x)], -1)
        x = x.flatten(-2)
        if self.cat_input:
            x = torch.cat([x0, x], -1)
        return x

    def extra_repr(self) -> str:
        return f'in={self.in_dim}, out={self.out_dim}, cat_input={self.cat_input}'
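# Usage sketch: the classic NeRF frequency encoding. With chns=3 and L=10 the
# output has 3 * (10 * 2) = 60 channels, or 63 with cat_input=True.
#
#   enc_x = InputEncoder(3, 10, cat_input=True)(torch.rand(1024, 3))   # -> (1024, 63)
#   enc_d = InputEncoder(3, 4)(dirs, angular=True)   # acos() first for directions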
class IntegratedPosEncoder(Module):

    def __init__(self, chns, L, shape: str, cat_input=False):
        super().__init__()
        self.shape = shape
def _lift_gaussian(self, d: torch.Tensor, t_mean: torch.Tensor, t_var: torch.Tensor,
r_var: torch.Tensor, diag: bool):
"""Lift a Gaussian defined along a ray to 3D coordinates."""
mean = d[..., None, :] * t_mean[..., None]
d_sq = d**2
d_mag_sq = torch.sum(d_sq, -1, keepdim=True).clamp_min(1e-10)
if diag:
d_outer_diag = d_sq
null_outer_diag = 1 - d_outer_diag / d_mag_sq
t_cov_diag = t_var[..., None] * d_outer_diag[..., None, :]
xy_cov_diag = r_var[..., None] * null_outer_diag[..., None, :]
cov_diag = t_cov_diag + xy_cov_diag
return mean, cov_diag
else:
d_outer = d[..., :, None] * d[..., None, :]
eye = torch.eye(d.shape[-1], device=d.device)
null_outer = eye - d[..., :, None] * (d / d_mag_sq)[..., None, :]
t_cov = t_var[..., None, None] * d_outer[..., None, :, :]
xy_cov = r_var[..., None, None] * null_outer[..., None, :, :]
cov = t_cov + xy_cov
return mean, cov
    def _conical_frustum_to_gaussian(self, d: torch.Tensor, t0: float, t1: float, base_radius: float,
                                     diag: bool, stable: bool = True):
        """Approximate a conical frustum as a Gaussian distribution (mean+cov).

        Assumes the ray is originating from the origin, and base_radius is the
        radius at dist=1. Doesn't assume `d` is normalized.

        Args:
            d: torch.float32 3-vector, the axis of the cone
            t0: float, the starting distance of the frustum.
            t1: float, the ending distance of the frustum.
            base_radius: float, the scale of the radius as a function of distance.
            diag: boolean, whether or not the Gaussian will be diagonal or full-covariance.
            stable: boolean, whether or not to use the stable computation described in
                the paper (setting this to False will cause catastrophic failure).

        Returns:
            a Gaussian (mean and covariance).
        """
        if stable:
            mu = (t0 + t1) / 2
            hw = (t1 - t0) / 2
            t_mean = mu + (2 * mu * hw**2) / (3 * mu**2 + hw**2)
            t_var = (hw**2) / 3 - (4 / 15) * ((hw**4 * (12 * mu**2 - hw**2)) /
                                              (3 * mu**2 + hw**2)**2)
            r_var = base_radius**2 * ((mu**2) / 4 + (5 / 12) * hw**2 - 4 / 15 *
                                      (hw**4) / (3 * mu**2 + hw**2))
        else:
            t_mean = (3 * (t1**4 - t0**4)) / (4 * (t1**3 - t0**3))
            r_var = base_radius**2 * (3 / 20 * (t1**5 - t0**5) / (t1**3 - t0**3))
            t_mosq = 3 / 5 * (t1**5 - t0**5) / (t1**3 - t0**3)
            t_var = t_mosq - t_mean**2
        return self._lift_gaussian(d, t_mean, t_var, r_var, diag)
    def _cylinder_to_gaussian(self, d: torch.Tensor, t0: float, t1: float, radius: float, diag: bool):
        """Approximate a cylinder as a Gaussian distribution (mean+cov).

        Assumes the ray is originating from the origin, and radius is the
        radius. Does not renormalize `d`.

        Args:
            d: torch.float32 3-vector, the axis of the cylinder
            t0: float, the starting distance of the cylinder.
            t1: float, the ending distance of the cylinder.
            radius: float, the radius of the cylinder
            diag: boolean, whether or not the Gaussian will be diagonal or full-covariance.

        Returns:
            a Gaussian (mean and covariance).
        """
        t_mean = (t0 + t1) / 2
        r_var = radius**2 / 4
        t_var = (t1 - t0)**2 / 12
        return self._lift_gaussian(d, t_mean, t_var, r_var, diag)
    def cast_rays(self, t_vals: torch.Tensor, rays_o: torch.Tensor, rays_d: torch.Tensor,
                  rays_r: torch.Tensor, diag: bool = True):
        """Cast rays (cone- or cylinder-shaped, according to `self.shape`) and
        featurize sections of them.

        Args:
            t_vals: float array, the "fencepost" distances along the ray.
            rays_o: float array, the ray origin coordinates.
            rays_d: float array, the ray direction vectors.
            rays_r: float array, the radii (base radii for cones) of the rays.
            diag: boolean, whether or not the covariance matrices should be diagonal.

        Returns:
            a tuple of arrays of means and covariances.
        """
        t0 = t_vals[..., :-1]
        t1 = t_vals[..., 1:]
        if self.shape == 'cone':
            gaussian_fn = self._conical_frustum_to_gaussian
        elif self.shape == 'cylinder':
            gaussian_fn = self._cylinder_to_gaussian
        else:
            assert False
        means, covs = gaussian_fn(rays_d, t0, t1, rays_r, diag)
        means = means + rays_o[..., None, :]
        return means, covs
def integrated_pos_enc(x_coord: Tuple[torch.Tensor, torch.Tensor], min_deg: int, max_deg: int,
diag: bool = True):
"""Encode `x` with sinusoids scaled by 2^[min_deg:max_deg-1].
Args:
x_coord: a tuple containing: x, torch.ndarray, variables to be encoded. Should
be in [-pi, pi]. x_cov, torch.ndarray, covariance matrices for `x`.
min_deg: int, the min degree of the encoding.
max_deg: int, the max degree of the encoding.
diag: bool, if true, expects input covariances to be diagonal (full
otherwise).
Returns:
encoded: torch.ndarray, encoded variables.
"""
if diag:
x, x_cov_diag = x_coord
scales = torch.tensor([2**i for i in range(min_deg, max_deg)], device=x.device)[:, None]
shape = list(x.shape[:-1]) + [-1]
y = torch.reshape(x[..., None, :] * scales, shape)
y_var = torch.reshape(x_cov_diag[..., None, :] * scales**2, shape)
else:
x, x_cov = x_coord
num_dims = x.shape[-1]
basis = torch.cat([
2**i * torch.eye(num_dims, device=x.device)
for i in range(min_deg, max_deg)
], 1)
y = torch.matmul(x, basis)
# Get the diagonal of a covariance matrix (ie, variance). This is equivalent
# to jax.vmap(torch.diag)((basis.T @ covs) @ basis).
y_var = (torch.matmul(x_cov, basis) * basis).sum(-2)
return math.expected_sin(
torch.cat([y, y + 0.5 * math.pi], -1),
torch.cat([y_var] * 2, -1))[0]
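# Usage sketch (mip-NeRF style, names are assumptions): cast cone-shaped rays
# and encode the section Gaussians with scales 2^[min_deg, max_deg).
#
#   ipe = IntegratedPosEncoder(3, 16, 'cone')
#   means, covs = ipe.cast_rays(t_vals, rays_o, rays_d, rays_r, diag=True)
#   enc = integrated_pos_enc((means, covs), min_deg=0, max_deg=16)   # (..., P, 96)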
import torch
from itertools import cycle
from typing import Dict, List, Set, Tuple, Union

from model.base import BaseModel
from utils import math
from utils.module import Module
from utils.perf import checkpoint, perf
from utils.samples import Samples
from utils.type import NetInput, ReturnData
from .generic import *
def density2energy(densities: torch.Tensor, dists: torch.Tensor, raw_noise_std: float = 0):
@@ -41,7 +43,7 @@ def density2alpha(densities: torch.Tensor, dists: torch.Tensor, raw_noise_std: f
    return 1.0 - torch.exp(-energies)
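# The helpers above implement the standard volume-rendering relation
# (Max 1995): energy = density * dist (optionally with noise added to the raw
# densities), and alpha = 1 - exp(-energy). For example, a sample with density
# 2.0 over a step of 0.25 contributes alpha = 1 - exp(-0.5) ≈ 0.39.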
class AlphaComposition(Module):

    def __init__(self):
        super().__init__()
@@ -58,7 +60,7 @@ class AlphaComposition(nn.Module):
        # Compute weight for RGB of each sample along each ray. A cumprod() is
        # used to express the idea of the ray not having reflected up to this
        # sample yet.
        one_minus_alpha = torch.cumprod(1 - alphas[..., :-1, :] + math.tiny, dim=-2)
        one_minus_alpha = torch.cat([
            torch.ones_like(one_minus_alpha[..., :1, :]),
            one_minus_alpha
@@ -80,23 +82,25 @@ class AlphaComposition(nn.Module):
        }
class VolumnRenderer(Module):

    class States:
        kernel: BaseModel
        samples: Samples
        early_stop_tolerance: float
        outputs: Set[str]
        hit_mask: torch.Tensor
        N: int
        P: int
        device: torch.device

        colors: torch.Tensor
        densities: torch.Tensor
        energies: torch.Tensor
        weights: torch.Tensor
        cum_energies: torch.Tensor
        exp_energies: torch.Tensor

        tot_evaluations: Dict[str, int]
        chunk: Tuple[slice, slice]
@@ -112,16 +116,18 @@ class VolumnRenderer(nn.Module):
        def end(self) -> int:
            return self.chunk[1].stop

        def __init__(self, kernel: BaseModel, samples: Samples, early_stop_tolerance: float,
                     outputs: Set[str]) -> None:
            self.kernel = kernel
            self.samples = samples
            self.early_stop_tolerance = early_stop_tolerance
            self.outputs = outputs

            N, P = samples.size
            self.device = self.samples.device
            self.hit_mask = samples.voxel_indices != -1  # (N, P) | bool
            self.colors = torch.zeros(N, P, kernel.chns('color'), device=samples.device)
            self.densities = torch.zeros(N, P, 1, device=samples.device)
            self.energies = torch.zeros(N, P, 1, device=samples.device)
            self.weights = torch.zeros(N, P, 1, device=samples.device)
            self.cum_energies = torch.zeros(N, P + 1, 1, device=samples.device)
@@ -130,12 +136,14 @@ class VolumnRenderer(nn.Module):
            self.N, self.P = N, P
            self.chunk_id = -1
        def n_hits(self, index: Union[int, slice] = None) -> int:
            if not isinstance(self.hit_mask, torch.Tensor):
                if index is not None:
                    return self.N * self.colors[:, index].shape[1]
                return self.N * self.P
            if index is None:
                return self.hit_mask.count_nonzero().item()
            return self.hit_mask[:, index].count_nonzero().item()
        def accumulate_tot_evaluations(self, key: str, n: int):
            if key not in self.tot_evaluations:
@@ -152,21 +160,31 @@ class VolumnRenderer(nn.Module):
            self.chunk_id += 1
            return self

        def put(self, key: str, values: torch.Tensor,
                indices: Union[Tuple[torch.Tensor, torch.Tensor], Tuple[slice, slice]]):
            if not hasattr(self, key):
                new_tensor = torch.zeros(self.N, self.P, values.shape[-1], device=self.device)
                setattr(self, key, new_tensor)
            tensor: torch.Tensor = getattr(self, key)
            # if isinstance(indices[0], torch.Tensor):
            #     tensor.index_put_(indices, values)
            # else:
            tensor[indices] = values
    def __init__(self, **kwargs):
        super().__init__()

    @perf
    def forward(self, kernel: BaseModel, samples: Samples, *outputs: str,
                raymarching_early_stop_tolerance: float = 0,
                raymarching_chunk_size_or_sections: Union[int, List[int]] = None,
                **kwargs) -> ReturnData:
        """
        Perform volume rendering.

        :param kernel `BaseModel`: render kernel
        :param samples `Samples(N, P)`: samples
        :param outputs `str...`: items that should be contained in the result dict.
                Optional values include 'color', 'depth', 'layers', 'states' and attribute names in class `States` (e.g. 'weights'). Defaults to []
        :param raymarching_early_stop_tolerance `float`: tolerance of raymarching early stop.
                Should be between 0 and 1 (0 means no early stop). Defaults to 0
        :param raymarching_chunk_size_or_sections `int|list[int]`: indicates how to split raymarching process.
@@ -179,13 +197,32 @@ class VolumnRenderer(nn.Module):
            print("VolumnRenderer.forward(): # of samples is zero")
            return None

        infer_outputs = set()
        for key in outputs:
            if key == "color":
                infer_outputs.add("colors")
                infer_outputs.add("densities")
            elif key == "specular":
                infer_outputs.add("speculars")
                infer_outputs.add("densities")
            elif key == "diffuse":
                infer_outputs.add("diffuses")
                infer_outputs.add("densities")
            elif key == "depth":
                infer_outputs.add("densities")
            else:
                infer_outputs.add(key)
        s = VolumnRenderer.States(kernel, samples, raymarching_early_stop_tolerance, infer_outputs)
        checkpoint("Prepare states object")

        if not raymarching_chunk_size_or_sections:
            raymarching_chunk_size_or_sections = [s.P]
        elif isinstance(raymarching_chunk_size_or_sections, int) and \
                raymarching_chunk_size_or_sections > 0:
            raymarching_chunk_size_or_sections = [
                math.ceil(s.P / raymarching_chunk_size_or_sections)
            ]

        if isinstance(raymarching_chunk_size_or_sections, list):
            chunk_sections = raymarching_chunk_size_or_sections
@@ -205,60 +242,31 @@ class VolumnRenderer(nn.Module):
                chunk_hits += n_hits
        self._forward_chunk(s.next_chunk())
        checkpoint("Run forward chunks")

        ret = {}
        for key in outputs:
            if key == 'color':
                ret['color'] = torch.sum(s.colors * s.weights, 1)
            elif key == 'depth':
                ret['depth'] = torch.sum(s.samples.depths[..., None] * s.weights, 1)
            elif key == 'diffuse' and hasattr(s, "diffuses"):
                ret['diffuse'] = torch.sum(s.diffuses * s.weights, 1)
            elif key == 'specular' and hasattr(s, "speculars"):
                ret['specular'] = torch.sum(s.speculars * s.weights, 1)
            elif key == 'layers':
                ret['layers'] = torch.cat([s.colors, 1 - torch.exp(-s.energies)], dim=-1)
            elif key == 'states':
                ret['states'] = s
            elif hasattr(s, key):
                ret[key] = getattr(s, key)
        checkpoint("Set return data")

        return ret
    @perf
    def _calc_weights(self, s: States):
        """
        Calculate weights of samples in composited outputs
@@ -267,11 +275,13 @@ class VolumnRenderer(nn.Module):
        :param start `int`: chunk's start
        :param end `int`: chunk's end
        """
        s.energies[s.chunk] = density2energy(s.densities[s.chunk], s.samples.dists[s.chunk])
        s.cum_energies[s.cum_chunk] = torch.cumsum(s.energies[s.chunk], 1) \
            + s.cum_energies[s.cum_last]
        s.exp_energies[s.cum_chunk] = (-s.cum_energies[s.cum_chunk]).exp()
        s.weights[s.chunk] = s.exp_energies[s.chunk] - s.exp_energies[s.cum_chunk]
    @perf
    def _apply_early_stop(self, s: States):
        """
        Stop rays whose accumulated opacity exceeds a threshold
@@ -279,32 +289,26 @@ class VolumnRenderer(nn.Module):
        :param s `States`: s
        :param end `int`: chunk's end
        """
        if s.end < s.P and s.early_stop_tolerance > 0 and isinstance(s.hit_mask, torch.Tensor):
            rays_to_stop = s.exp_energies[:, s.end, 0] < s.early_stop_tolerance
            s.hit_mask[rays_to_stop, s.end:] = 0
    @perf
    def _forward_chunk(self, s: States) -> None:
        if isinstance(s.hit_mask, torch.Tensor):
            fi_idxs: Tuple[torch.Tensor, ...] = s.hit_mask[s.chunk].nonzero(as_tuple=True)
            if fi_idxs[0].size(0) == 0:
                s.cum_energies[s.cum_chunk] = s.cum_energies[s.cum_last]
                s.exp_energies[s.cum_chunk] = s.exp_energies[s.cum_last]
                return
            fi_idxs[1].add_(s.start)
            s.accumulate_tot_evaluations("colors", fi_idxs[0].size(0))
        else:
            fi_idxs = s.chunk
        fi_outputs = s.kernel.infer(*s.outputs, samples=s.samples[fi_idxs], chunk_id=s.chunk_id)
        for key, value in fi_outputs.items():
            s.put(key, value, fi_idxs)

        self._calc_weights(s)
        self._apply_early_stop(s)
@@ -322,19 +326,19 @@ class DensityFirstVolumnRenderer(VolumnRenderer):
        if fi_idxs[0].size(0) == 0:
            s.cum_energies[s.cum_chunk] = s.cum_energies[s.cum_last]
            s.exp_energies[s.cum_chunk] = s.exp_energies[s.cum_last]
            return

        # fi_* means "filtered" by hit mask
        fi_samples = s.samples[fi_idxs]  # N -> N'

        # For all valid samples: encode X
        density_inputs = s.kernel.input(fi_samples, "x", "f")  # (N', Ex)

        # Infer densities (shape)
        density_outputs = s.kernel.infer('densities', 'features', samples=fi_samples,
                                         inputs=density_inputs, chunk_id=s.chunk_id)
        s.put('densities', density_outputs['densities'], fi_idxs)
        s.accumulate_tot_evaluations("densities", fi_idxs[0].size(0))

        self._calc_weights(s)
        self._apply_early_stop(s)
@@ -345,16 +349,17 @@ class DensityFirstVolumnRenderer(VolumnRenderer):
        # Update "filtered" tensors
        fi_mask = s.hit_mask[fi_idxs]
        fi_idxs = (fi_idxs[0][fi_mask], fi_idxs[1][fi_mask])  # N' -> N"
        fi_samples = s.samples[fi_idxs]  # N -> N"
        fi_features = density_outputs['features'][fi_mask]
        color_inputs = s.kernel.input(fi_samples, "d")  # (N")
        color_inputs.x = density_inputs.x[fi_mask]

        # Infer colors (appearance)
        outputs = s.outputs.copy()
        if 'densities' in outputs:
            outputs.remove('densities')
        color_outputs = s.kernel.infer(*outputs, samples=fi_samples, inputs=color_inputs,
                                       chunk_id=s.chunk_id, features=fi_features)
        # if s.chunk_id == 0:
        #     fi_colors[:] *= fi_colors.new_tensor([1, 0, 0])
        # elif s.chunk_id == 1:
@@ -363,9 +368,6 @@ class DensityFirstVolumnRenderer(VolumnRenderer):
        #     fi_colors[:] *= fi_colors.new_tensor([0, 0, 1])
        # else:
        #     fi_colors[:] *= fi_colors.new_tensor([1, 1, 0])
        for key, value in color_outputs.items():
            s.put(key, value, fi_idxs)
        s.accumulate_tot_evaluations("colors", fi_idxs[0].size(0))
import torch
from typing import Tuple

from .space import Space, Voxels
from utils import device
from utils import sphere
from utils import misc
from utils import math
from utils.module import Module
from utils.samples import Samples
from utils.perf import perf, checkpoint
from .generic import *
from clib import *
class Bins(object):
@@ -38,140 +40,104 @@ class Bins(object):
        self.bounds = self.bounds.to(device)
class Sampler(Module):

    def __init__(self, **kwargs):
        """
        Initialize a Sampler module
        """
        super().__init__()
        self._samples_indices_cached = None
    def _sample(self, range: Tuple[float, float], n_rays: int, n_samples: int, perturb: bool,
                device: torch.device) -> torch.Tensor:
        """
        Sample `n_samples` bounds per ray uniformly within `range`.

        :param range `float, float`: sampling range
        :param n_rays `int`: number of rays (B)
        :param n_samples `int`: number of samples per ray (P)
        :param perturb `bool`: whether to perturb sampling
        :param device `torch.device`: the device used to create tensors
        :return `Tensor(B, P+1)`: sampling bounds of t
        """
        bounds = torch.linspace(*range, n_samples + 1, device=device)  # (P+1)
        if perturb:
            rand_bounds = torch.cat([
                bounds[:1],
                0.5 * (bounds[1:] + bounds[:-1]),
                bounds[-1:]
            ])
            rand_vals = torch.rand(n_rays, n_samples + 1, device=device)
            bounds = rand_bounds[:-1] * (1 - rand_vals) + rand_bounds[1:] * rand_vals
        else:
            bounds = bounds[None].expand(n_rays, -1)
        return bounds
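    # Note on _sample()'s perturbation: this is NeRF-style stratified sampling.
    # Midpoints between the uniform bounds define bins, and every bound is then
    # redrawn uniformly inside its bin, so the bins still tile the whole range
    # while the exact fencepost positions are randomized per ray.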
    def _get_samples_indices(self, pts: torch.Tensor):
        if self._samples_indices_cached is None \
                or self._samples_indices_cached.shape[0] < pts.shape[0] \
                or self._samples_indices_cached.shape[1] < pts.shape[1]:
            self._samples_indices_cached = misc.meshgrid(
                *pts.shape[:2], swap_dim=True, device=pts.device)
        return self._samples_indices_cached[:pts.shape[0], :pts.shape[1]]
def __init__(self, *, sample_range: Tuple[float, float], n_samples: int, @perf
perturb_sample: bool, **kwargs): def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, space_: Space, *,
sample_range: Tuple[float, float], n_samples: int, lindisp: bool = False,
perturb_sample: bool = True, spherical: bool = False,
**kwargs) -> Tuple[Samples, torch.Tensor]:
""" """
Initialize a Sampler module Sample points along rays.
:param depth_range: depth range for sampler :param rays_o `Tensor(B, 3)`: rays' origin
:param n_samples: count to sample along ray :param rays_d `Tensor(B, 3)`: rays' direction
:param perturb_sample: perturb the sample depths :param sample_range `float, float`: sampling range
:param lindisp: If True, sample linearly in inverse depth rather than in depth :param n_samples `int`: number of samples per ray
:param lindisp `bool`: whether sample linearly in disparity space (1/depth)
:param perturb_sample `bool`: whether perturb sampling
:return `Samples(B, P)`: samples
""" """
super().__init__(sample_range=sample_range, n_samples=n_samples, if spherical:
perturb_sample=perturb_sample, lindisp=False) t_bounds = self._sample(sample_range, rays_o.shape[0], n_samples, perturb_sample,
rays_o.device)
def _get_sample_points(self, rays_o, rays_d, s): t0, t1 = t_bounds[:, :-1], t_bounds[:, 1:] # (B, P)
r = torch.reciprocal(s) t = (t0 + t1) * .5
pts, depths = sphere.ray_sphere_intersect(rays_o, rays_d, r)
pts = sphere.cartesian2spherical(pts, inverse_r=True) p, z = sphere.ray_sphere_intersect(rays_o, rays_d, t.reciprocal())
return pts, depths p = sphere.cartesian2spherical(p, inverse_r=True)
vidxs = space_.get_voxel_indices(p)
return Samples(
pts=p,
dirs=rays_d[:, None].expand(-1, n_samples, -1),
depths=z,
dists=(t1 + math.tiny).reciprocal() - t0.reciprocal(),
voxel_indices=vidxs,
indices=self._get_samples_indices(p),
t=t
)
else:
sample_range = (1 / sample_range[0], 1 / sample_range[1]) if lindisp else sample_range
z_bounds = self._sample(sample_range, rays_o.shape[0], n_samples, perturb_sample,
rays_o.device)
if lindisp:
z_bounds = z_bounds.reciprocal()
z0, z1 = z_bounds[:, :-1], z_bounds[:, 1:] # (B, P)
z = (z0 + z1) * .5
p = rays_o[:, None] + rays_d[:, None] * z[..., None]
vidxs = space_.get_voxel_indices(p)
return Samples(
pts=p,
dirs=rays_d[:, None].expand(-1, n_samples, -1),
depths=z,
dists=z1 - z0,
voxel_indices=vidxs,
indices=self._get_samples_indices(p),
t=z
)
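# Usage sketch (hypothetical values; `space` is a `Space` instance as defined
# in this commit):
#   sampler = Sampler()
#   samples = sampler(rays_o, rays_d, space,
#                     sample_range=(1., 50.), n_samples=64,
#                     lindisp=True, perturb_sample=True)
# With `lindisp=True` the bounds are drawn uniformly in 1/depth and mapped back
# through `reciprocal()`, concentrating samples near the camera; with
# `spherical=True` the same t-bounds act as inverse radii of concentric spheres.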
class PdfSampler(Module):

    def __init__(self, *, depth_range: Tuple[float, float], n_samples: int, perturb_sample: bool,
                 spherical: bool, lindisp: bool, **kwargs):
...@@ -226,7 +192,7 @@ class PdfSampler(nn.Module):
        :return `Tensor(..., N)`: samples
        '''
        # Get pdf
        weights = weights + math.tiny  # prevent nans
        pdf = weights / torch.sum(weights, dim=-1, keepdim=True)  # [..., M]
        cdf = torch.cat([
            torch.zeros_like(pdf[..., :1]),
...@@ -256,17 +222,17 @@ class PdfSampler(nn.Module):
        # fix numeric issue
        denom = cdf_g[..., 1] - cdf_g[..., 0]  # [..., N]
        denom = torch.where(denom < math.tiny, torch.ones_like(denom), denom)
        t = (u - cdf_g[..., 0]) / denom
        samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0] + math.tiny)
        return samples
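# How the inverse-CDF draw above works: `u` is uniform in [0, 1), `cdf_g` and
# `bins_g` hold the CDF values and bin edges bracketing each `u` (gathered in
# the elided lines), and `t` linearly interpolates inside the bracketing bin.
# More samples therefore land where `weights`, and hence the pdf, are large.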
class VoxelSampler(Module):

    def __init__(self, *, sample_step: float, **kwargs):
        """
        Initialize a VoxelSampler module
...@@ -274,11 +240,10 @@ class VoxelSampler(Module):
        :param step_size: step size
        """
        super().__init__()
        self.sample_step = sample_step

    def _forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, space_module: Space, *,
                 perturb_sample: bool, **kwargs) -> Tuple[Samples, torch.Tensor]:
        """
        Sample points at a fixed step size inside the voxels hit by the given rays.
...@@ -312,13 +277,13 @@ class VoxelSampler(nn.Module):
        invalid_samples_mask = rays_step >= rays_steps
        samples_min_depth = rays_near_depth + rays_step * rays_step_size
        samples_depth = samples_min_depth + rays_step_size \
            * (torch.rand_like(samples_min_depth) if perturb_sample else 0.5)  # (N', P)
        samples_dist = rays_step_size.repeat(1, max_steps)  # (N', 1) -> (N', P)
        samples_voxel_index = voxel_indices[
            ray_index_list[:, None],
            torch.searchsorted(max_depths, samples_depth)
        ]  # (N', P)
        samples_depth[invalid_samples_mask] = math.huge
        samples_dist[invalid_samples_mask] = 0
        samples_voxel_index[invalid_samples_mask] = -1
...@@ -332,8 +297,8 @@ class VoxelSampler(nn.Module):
        ), valid_rays_mask

    @perf
    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor,
                space: Space, *, perturb_sample: bool, **kwargs) -> Tuple[Samples, torch.Tensor]:
        """
        Sample points along rays within the voxels they intersect.
...@@ -342,7 +307,7 @@ class VoxelSampler(nn.Module):
        :param step_size `float`: sampling step size
        :return `Samples(N, P)`: samples along valid rays, plus the valid-ray mask
        """
        intersections = space.ray_intersect(rays_o, rays_d, 100)
        valid_rays_mask = intersections.hits > 0
        rays_o = rays_o[valid_rays_mask]
        rays_d = rays_d[valid_rays_mask]
...@@ -363,11 +328,11 @@ class VoxelSampler(nn.Module):
        # sample points and use middle point approximation
        sampled_indices, sampled_depths, sampled_dists = inverse_cdf_sampling(
            pts_idx, min_depth, max_depth, probs, steps, -1, not perturb_sample)
        sampled_indices = sampled_indices.long()
        invalid_idx_mask = sampled_indices.eq(-1)
        sampled_dists.clamp_min_(0).masked_fill_(invalid_idx_mask, 0)
        sampled_depths.masked_fill_(invalid_idx_mask, math.huge)

        checkpoint("Inverse CDF sampling")
...
import torch
from typing import Dict, List, Optional, Tuple, Union
from torch import nn

from clib import *
from model.utils import load
from utils.module import Module
from utils.geometry import *
from utils.voxels import *
from utils.perf import perf
from utils.env import get_env
class Intersections:
...@@ -41,32 +42,42 @@ class Intersections:
            hits=self.hits[index])
class Space(Module):
    bbox: Optional[torch.Tensor]
    """`Tensor(2, 3)` Bounding box"""

    @property
    def dims(self) -> int:
        """`int` Number of dimensions"""
        return self.bbox.shape[1] if self.bbox is not None else 3

    @staticmethod
    def create(args: dict) -> 'Space':
        if 'space' not in args:
            return Space(**args)
        if args['space'] == 'octree':
            return Octree(**args)
        if args['space'] == 'voxels':
            return Voxels(**args)
        return load(args['space']).space

    def __init__(self, clone_src: "Space" = None, *, bbox: List[float] = None, **kwargs):
        super().__init__()
        if clone_src:
            self.device = clone_src.device
            self.register_temp('bbox', clone_src.bbox)
        else:
            self.register_temp('bbox', None if not bbox else torch.tensor(bbox).reshape(2, -1))

    def create_embedding(self, n_dims: int, name: str = 'default') -> torch.nn.Embedding:
        raise NotImplementedError

    def extract_embedding(self, pts: torch.Tensor, voxel_indices: torch.Tensor,
                          name: str = 'default') -> torch.Tensor:
        raise NotImplementedError

    def ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> Intersections:
        raise NotImplementedError

    def get_voxel_indices(self, pts: torch.Tensor) -> torch.Tensor:
        if self.bbox is None:
            return 0
        voxel_indices = torch.zeros_like(pts[..., 0], dtype=torch.long)
        out_bbox = get_out_of_bound_mask(pts, self.bbox)  # (N...)
        voxel_indices[out_bbox] = -1
        return voxel_indices

    @torch.no_grad()
...@@ -74,9 +85,13 @@ class Space(Module):
        raise NotImplementedError()

    @torch.no_grad()
    def split(self) -> Tuple[int, int]:
        raise NotImplementedError()

    @torch.no_grad()
    def clone(self):
        return Space(self)
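# Construction sketch (hypothetical args dict, in the shape of the config files
# in this repo):
#   space = Space.create({'space': 'voxels',
#                         'bbox': [-1, -1, -1, 1, 1, 1],  # reshaped to (2, 3)
#                         'voxel_size': 0.5})
# A dict without a 'space' key yields a bare bounding-box Space; any other
# string value is resolved through `load(...)` and its `space` attribute.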
class Voxels(Space):
    steps: torch.Tensor
...@@ -92,12 +107,11 @@ class Voxels(Space):
    """`Tensor(M, 8)` Voxel corner indices"""

    voxel_indices_in_grid: torch.Tensor
    """`Tensor(G+1)` Indices in voxel list or -1 for pruned space

    Note that the first element is preserved for the 'invalid voxel' (-1), so a grid
    index should be offset by 1 before querying the corresponding voxel index.
    """

    @property
    def n_voxels(self) -> int:
...@@ -109,30 +123,52 @@ class Voxels(Space):
        """`int` Number of corners"""
        return self.corners.size(0)

    @property
    def n_grids(self) -> int:
        """`int` Number of grids, i.e. steps[0] * steps[1] * ... * steps[D]"""
        return self.steps.prod().item()
    @property
    def voxel_size(self) -> torch.Tensor:
        """`Tensor(3)` Voxel size"""
        return (self.bbox[1] - self.bbox[0]) / self.steps

    @property
    def corner_embeddings(self) -> Dict[str, torch.nn.Embedding]:
        return {name[4:]: emb for name, emb in self.named_modules() if name.startswith("emb_")}

    @property
    def voxel_embeddings(self) -> Dict[str, torch.nn.Embedding]:
        return {name[5:]: emb for name, emb in self.named_modules() if name.startswith("vemb_")}

    def __init__(self, clone_src: "Voxels" = None, *, bbox: List[float] = None,
                 voxel_size: float = None, steps: Union[torch.Tensor, Tuple[int, ...]] = None,
                 **kwargs) -> None:
        super().__init__(clone_src, bbox=bbox, **kwargs)
        if clone_src:
            self.register_buffer('steps', clone_src.steps)
            self.register_buffer('voxels', clone_src.voxels)
            self.register_buffer("corners", clone_src.corners)
            self.register_buffer("corner_indices", clone_src.corner_indices)
            self.register_buffer('voxel_indices_in_grid', clone_src.voxel_indices_in_grid)
        else:
            if self.bbox is None:
                raise ValueError("Missing argument 'bbox'")
            if voxel_size is not None:
                self.register_buffer('steps', get_grid_steps(self.bbox, voxel_size))
            else:
                self.register_buffer('steps', torch.tensor(steps, dtype=torch.long))
            self.register_buffer('voxels', init_voxels(self.bbox, self.steps))
            corners, corner_indices = get_corners(self.voxels, self.bbox, self.steps)
            self.register_buffer("corners", corners)
            self.register_buffer("corner_indices", corner_indices)
            self.register_buffer('voxel_indices_in_grid', torch.arange(-1, self.n_voxels))

    def clone(self):
        return Voxels(self)

    def to_vi(self, gi: torch.Tensor) -> torch.Tensor:
        return self.voxel_indices_in_grid[gi + 1]
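# Example of the offset-by-1 lookup (hypothetical 4-cell grid, cell 2 pruned):
#   voxel_indices_in_grid == [-1, 0, 1, -1, 2]
#   to_vi(torch.tensor([-1, 0, 2, 3])) == tensor([-1, 0, -1, 2])
# The sentinel at element 0 lets out-of-bound grid indices (-1) map straight
# to -1 without a separate mask.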
    def create_embedding(self, n_dims: int, name: str = 'default') -> torch.nn.Embedding:
        """
...@@ -144,7 +180,7 @@ class Voxels(Space):
        """
        if self.get_embedding(name) is not None:
            raise KeyError(f"Embedding '{name}' already existed")
        emb = torch.nn.Embedding(self.n_corners, n_dims).to(self.device)
        setattr(self, f'emb_{name}', emb)
        return emb
...@@ -152,8 +188,9 @@ class Voxels(Space):
        return getattr(self, f'emb_{name}', None)

    def set_embedding(self, weight: torch.Tensor, name: str = 'default'):
        emb = torch.nn.Embedding(*weight.shape, _weight=weight).to(self.device)
        setattr(self, f'emb_{name}', emb)
        return emb

    def extract_embedding(self, pts: torch.Tensor, voxel_indices: torch.Tensor,
                          name: str = 'default') -> torch.Tensor:
...@@ -173,6 +210,41 @@ class Voxels(Space):
        p = (pts - voxels) / self.voxel_size + .5  # (N, 3) normed-coords in voxel
        return trilinear_interp(p, emb(corner_indices))
    def create_voxel_embedding(self, n_dims: int, name: str = 'default') -> torch.nn.Embedding:
        """
        Create an embedding on voxels.

        :param name `str`: embedding name
        :param n_dims `int`: embedding dimension
        :return `Embedding(n_voxels, n_dims)`: new embedding on voxels
        """
        if self.get_voxel_embedding(name) is not None:
            raise KeyError(f"Embedding '{name}' already existed")
        emb = torch.nn.Embedding(self.n_voxels, n_dims).to(self.device)
        setattr(self, f'vemb_{name}', emb)
        return emb

    def get_voxel_embedding(self, name: str = 'default') -> torch.nn.Embedding:
        return getattr(self, f'vemb_{name}', None)

    def set_voxel_embedding(self, weight: torch.Tensor, name: str = 'default'):
        emb = torch.nn.Embedding(*weight.shape, _weight=weight).to(self.device)
        setattr(self, f'vemb_{name}', emb)
        return emb

    def extract_voxel_embedding(self, voxel_indices: torch.Tensor, name: str = 'default') -> torch.Tensor:
        """
        Extract embedding values at given voxels.

        :param voxel_indices `Tensor(N)`: voxel indices
        :param name `str`: embedding name, default to 'default'
        :return `Tensor(N, X)`: extracted values
        """
        emb = self.get_voxel_embedding(name)
        if emb is None:
            raise KeyError(f"Embedding '{name}' doesn't exist")
        return emb(voxel_indices)
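# Usage sketch (assumed feature width, hypothetical validity mask):
#   space.create_voxel_embedding(16)
#   feats = space.extract_voxel_embedding(samples.voxel_indices[valid])  # (N, 16)
# Unlike corner embeddings, voxel embeddings store one vector per voxel, so no
# trilinear interpolation is involved.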
    @perf
    def ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> Intersections:
        """
...@@ -192,8 +264,8 @@ class Voxels(Space):
        hits = n_max_hits - invalid_voxel_mask.sum(-1)

        # Sort intersections according to their depths
        min_depths.masked_fill_(invalid_voxel_mask, math.huge)
        max_depths.masked_fill_(invalid_voxel_mask, math.huge)
        min_depths, sorted_idx = min_depths.sort(dim=-1)
        max_depths = max_depths.gather(-1, sorted_idx)
        voxel_indices = voxel_indices.gather(-1, sorted_idx)
...@@ -215,52 +287,102 @@ class Voxels(Space):
        :param pts `Tensor(N..., 3)`: points
        :return `Tensor(N...)`: corresponding voxel indices
        """
        gi = to_grid_indices(pts, self.bbox, self.steps)
        return self.to_vi(gi)

    @perf
    def get_corners(self, vidxs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        vidxs = vidxs.unique()
        if vidxs[0] == -1:
            vidxs = vidxs[1:]
        cidxs = self.corner_indices[vidxs].unique()
        fi_cidxs = torch.full([self.n_corners], -1, dtype=torch.long, device=self.device)
        fi_cidxs[cidxs] = torch.arange(cidxs.shape[0], device=self.device)
        fi_corner_indices = fi_cidxs[self.corner_indices]
        fi_corners = self.corners[cidxs]
        return fi_corner_indices, fi_corners
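# Reading of the re-indexing above (as far as the code shows): `fi_corners` is
# the compact set of unique corners touched by `vidxs`, and `fi_corner_indices`
# remaps every voxel's 8 corner slots into that compact set, with -1 left for
# corners that belong only to untouched voxels.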
    @torch.no_grad()
    def split(self) -> Tuple[int, int]:
        """
        Split voxels into smaller voxels with half size.
        """
        # Calculate new voxels and corners
        new_steps = self.steps * 2
        new_voxels = split_voxels(self.voxels, self.voxel_size, 2, align_border=False)\
            .reshape(-1, 3)
        new_corners, new_corner_indices = get_corners(new_voxels, self.bbox, new_steps)

        # Split corner embeddings through interpolation
        corner_embs = self.corner_embeddings
        if len(corner_embs) > 0:
            gi_of_new_corners = to_grid_indices(new_corners, self.bbox, self.steps)
            vi_of_new_corners = self.to_vi(gi_of_new_corners)
            for name, emb in corner_embs.items():
                new_emb_weight = self.extract_embedding(new_corners, vi_of_new_corners, name=name)
                self.set_embedding(new_emb_weight, name=name)
                # Remove old embedding weight and related state from optimizer
                self._update_optimizer(emb.weight)

        # Split voxel embeddings
        self._update_voxel_embeddings(lambda val: torch.repeat_interleave(val, 8, dim=0))

        # Apply new tensors
        self.steps = new_steps
        self.voxels = new_voxels
        self.corners = new_corners
        self.corner_indices = new_corner_indices
        self._update_gi2vi()
        return self.n_voxels // 8, self.n_voxels

    @torch.no_grad()
    def prune(self, keeps: torch.Tensor) -> Tuple[int, int]:
        self.voxels = self.voxels[keeps]
        self.corner_indices = self.corner_indices[keeps]
        self._update_gi2vi()
        # Prune voxel embeddings
        self._update_voxel_embeddings(lambda val: val[keeps])
        return keeps.size(0), keeps.sum().item()
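# Note: splitting halves the voxel size along all three axes, so each voxel
# becomes 8 children. Voxel embeddings are therefore repeated 8x per parent
# (`repeat_interleave` above), while corner embeddings are re-interpolated at
# the new, denser corner positions.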
    def _update_voxel_embeddings(self, update_fn):
        for name, emb in self.voxel_embeddings.items():
            new_emb = self.set_voxel_embedding(update_fn(emb.weight), name)
            self._update_optimizer(emb.weight, new_emb.weight, update_fn)

    def _update_optimizer(self, old_param: nn.Parameter, new_param: nn.Parameter = None,
                          update_fn=None):
        optimizer = get_env()["trainer"].optimizer
        if isinstance(optimizer, (torch.optim.Adam, torch.optim.AdamW)):
            # Update related states in optimizer
            if old_param in optimizer.state:
                if new_param is not None:
                    # Transfer state from old parameter to new parameter
                    state = optimizer.state[old_param]
                    state.update({
                        key: update_fn(state[key])
                        for key in ['exp_avg', 'exp_avg_sq', 'max_exp_avg_sq'] if key in state
                    })
                    optimizer.state[new_param] = state
                # Remove state of old parameter
                optimizer.state.pop(old_param)
        # Update parameter list in optimizer
        for group in optimizer.param_groups:
            try:
                if new_param is not None:
                    # Replace old parameter with new one
                    idx = group['params'].index(old_param)
                    group['params'][idx] = new_param
                else:
                    # Or just remove old parameter if new parameter is not specified
                    group['params'].remove(old_param)
            except Exception:
                pass
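# Rationale: Adam/AdamW keep per-parameter moment buffers ('exp_avg',
# 'exp_avg_sq', and 'max_exp_avg_sq' with amsgrad) shaped like the parameter
# itself. After split/prune the embedding weight changes shape, so the stale
# buffers must be remapped with the same `update_fn` (or dropped) for the
# optimizer to keep stepping correctly.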
    def n_voxels_along_dim(self, dim: int) -> torch.Tensor:
        sum_dims = [val for val in range(self.dims) if val != dim]
        return self.voxel_indices_in_grid[1:].reshape(*self.steps).ne(-1).sum(sum_dims)

    def balance_cut(self, dim: int, n_parts: int) -> List[int]:
        n_voxels_list = self.n_voxels_along_dim(dim)
...@@ -269,10 +391,11 @@ class Voxels(Space):
        part = 1
        offset = 0
        for i in range(len(cdf)):
            if cdf[i] > part:
                bins.append(i - offset)
                offset = i
                part = int(cdf[i]) + 1
        bins.append(len(cdf) - offset)
        return bins
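# Worked example (assuming `cdf` in the elided lines is the cumulative voxel
# count normalized to sum to `n_parts`): with n_voxels_along_dim(dim) == [3, 1, 2, 2]
# and n_parts == 2, cdf == [0.75, 1.0, 1.5, 2.0] and balance_cut returns
# [2, 2] -- two slabs of two slices, each holding 4 voxels.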
    def sample(self, S: int, perturb: bool = False, include_border: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
...@@ -299,17 +422,16 @@ class Voxels(Space):
    def _ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        return aabb_ray_intersect(self.voxel_size, n_max_hits, self.voxels, rays_o, rays_d)
    def _update_gi2vi(self):
        """
        Update the mapping from grid indices to voxel indices.
        """
        gi = to_grid_indices(self.voxels, self.bbox, self.steps)
        # Preserve the first element of voxel_indices_in_grid for the 'invalid voxel' (-1)
        self.voxel_indices_in_grid = gi.new_full([self.n_grids + 1], -1)
        self.voxel_indices_in_grid[gi + 1] = torch.arange(self.n_voxels, device=self.device)

    def _before_load_state_dict(self, state_dict, prefix, *args):
        # Handle buffers
        for name, buffer in self.named_buffers(recurse=False):
            if name in self._non_persistent_buffers_set:
...@@ -320,12 +442,17 @@ class Voxels(Space):
        for name, module in self.named_modules():
            if name.startswith('emb_'):
                setattr(self, name, torch.nn.Embedding(self.n_corners, module.embedding_dim))
            if name.startswith('vemb_'):
                setattr(self, name, torch.nn.Embedding(self.n_voxels, module.embedding_dim))

    def _after_load_state_dict(self):
        self._update_gi2vi()


class Octree(Voxels):
    def __init__(self, clone_src: "Octree" = None, **kwargs) -> None:
        super().__init__(clone_src, **kwargs)
        self.nodes_cached = None
        self.tree_cached = None
...
MIT License
Copyright (c) 2020 bmild
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
# NeRF: Neural Radiance Fields
### [Project Page](http://tancik.com/nerf) | [Video](https://youtu.be/JuH79E8rdKc) | [Paper](https://arxiv.org/abs/2003.08934) | [Data](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1)
[![Open Tiny-NeRF in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bmild/nerf/blob/master/tiny_nerf.ipynb)<br>
Tensorflow implementation of optimizing a neural representation for a single scene and rendering new views.<br><br>
[NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis](http://tancik.com/nerf)
[Ben Mildenhall](https://people.eecs.berkeley.edu/~bmild/)\*<sup>1</sup>,
[Pratul P. Srinivasan](https://people.eecs.berkeley.edu/~pratul/)\*<sup>1</sup>,
[Matthew Tancik](http://tancik.com/)\*<sup>1</sup>,
[Jonathan T. Barron](http://jonbarron.info/)<sup>2</sup>,
[Ravi Ramamoorthi](http://cseweb.ucsd.edu/~ravir/)<sup>3</sup>,
[Ren Ng](https://www2.eecs.berkeley.edu/Faculty/Homepages/yirenng.html)<sup>1</sup> <br>
<sup>1</sup>UC Berkeley, <sup>2</sup>Google Research, <sup>3</sup>UC San Diego
\*denotes equal contribution
in ECCV 2020 (Oral Presentation, Best Paper Honorable Mention)
<img src='imgs/pipeline.jpg'/>
## TL;DR quickstart
To setup a conda environment, download example training data, begin the training process, and launch Tensorboard:
```
conda env create -f environment.yml
conda activate nerf
bash download_example_data.sh
python run_nerf.py --config config_fern.txt
tensorboard --logdir=logs/summaries --port=6006
```
If everything works without errors, you can now go to `localhost:6006` in your browser and watch the "Fern" scene train.
## Setup
Python 3 dependencies:
* Tensorflow 1.15
* matplotlib
* numpy
* imageio
* configargparse
The LLFF data loader requires ImageMagick.
We provide a conda environment setup file including all of the above dependencies. Create the conda environment `nerf` by running:
```
conda env create -f environment.yml
```
You will also need the [LLFF code](http://github.com/fyusion/llff) (and COLMAP) set up to compute poses if you want to run on your own real data.
## What is a NeRF?
A neural radiance field is a simple fully connected network (weights are ~5MB) trained to reproduce input views of a single scene using a rendering loss. The network directly maps from spatial location and viewing direction (5D input) to color and opacity (4D output), acting as the "volume" so we can use volume rendering to differentiably render new views.
Optimizing a NeRF takes between a few hours and a day or two (depending on resolution) and only requires a single GPU. Rendering an image from an optimized NeRF takes somewhere between less than a second and ~30 seconds, again depending on resolution.
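In code, the core idea fits in a few lines (an illustrative PyTorch sketch with made-up layer sizes, not this repo's TensorFlow implementation; positional encoding and the view-dependent branch are omitted):

```
import torch

# 5D input (x, y, z, theta, phi) -> 4D output (r, g, b, sigma)
net = torch.nn.Sequential(
    torch.nn.Linear(5, 256), torch.nn.ReLU(),
    torch.nn.Linear(256, 256), torch.nn.ReLU(),
    torch.nn.Linear(256, 4),
)
rgb_sigma = net(torch.rand(1024, 5))  # one radiance/density query per row
```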
## Running code
Here we show how to run our code on two example scenes. You can download the rest of the synthetic and real data used in the paper [here](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1).
### Optimizing a NeRF
Run
```
bash download_example_data.sh
```
to get our synthetic Lego dataset and the LLFF Fern dataset.
To optimize a low-res Fern NeRF:
```
python run_nerf.py --config config_fern.txt
```
After 200k iterations (about 15 hours), you should get a video like this at `logs/fern_test/fern_test_spiral_200000_rgb.mp4`:
![ferngif](https://people.eecs.berkeley.edu/~bmild/nerf/fern_200k_256w.gif)
To optimize a low-res Lego NeRF:
```
python run_nerf.py --config config_lego.txt
```
After 200k iterations, you should get a video like this:
![legogif](https://people.eecs.berkeley.edu/~bmild/nerf/lego_200k_256w.gif)
### Rendering a NeRF
Run
```
bash download_example_weights.sh
```
to get a pretrained high-res NeRF for the Fern dataset. Now you can use [`render_demo.ipynb`](https://github.com/bmild/nerf/blob/master/render_demo.ipynb) to render new views.
### Replicating the paper results
The example config files run at lower resolutions than the quantitative/qualitative results in the paper and video. To replicate the results from the paper, start with the config files in [`paper_configs/`](https://github.com/bmild/nerf/tree/master/paper_configs). Our synthetic Blender data and LLFF scenes are hosted [here](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1) and the DeepVoxels data is hosted by Vincent Sitzmann [here](https://drive.google.com/open?id=1lUvJWB6oFtT8EQ_NzBrXnmi25BufxRfl).
### Extracting geometry from a NeRF
Check out [`extract_mesh.ipynb`](https://github.com/bmild/nerf/blob/master/extract_mesh.ipynb) for an example of running marching cubes to extract a triangle mesh from a trained NeRF network. You'll need to install the [PyMCubes](https://github.com/pmneila/PyMCubes) package for marching cubes, plus the [trimesh](https://github.com/mikedh/trimesh) and [pyrender](https://github.com/mmatl/pyrender) packages if you want to render the mesh inside the notebook:
```
pip install trimesh pyrender PyMCubes
```
## Generating poses for your own scenes
### Don't have poses?
We recommend using the `imgs2poses.py` script from the [LLFF code](https://github.com/fyusion/llff). Then you can pass the base scene directory into our code using `--datadir <myscene>` along with `--dataset_type llff`. You can take a look at the `config_fern.txt` config file for example settings to use for a forward facing scene. For a spherically captured 360 scene, we recommend adding the `--no_ndc --spherify --lindisp` flags.
### Already have poses!
In `run_nerf.py` and all other code, we use the same pose coordinate system as in OpenGL: the local camera coordinate system of an image is defined in a way that the X axis points to the right, the Y axis upwards, and the Z axis backwards as seen from the image.
Poses are stored as 3x4 numpy arrays that represent camera-to-world transformation matrices. The other data you will need is simple pinhole camera intrinsics (`hwf = [height, width, focal length]`) and near/far scene bounds. Take a look at [our data loading code](https://github.com/bmild/nerf/blob/master/run_nerf.py#L406) to see more.
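A minimal sketch of that layout (values are made up, nothing is loaded from a real scene):

```
import numpy as np

# 3x4 camera-to-world matrix [R | t] in the OpenGL convention described above
pose = np.hstack([np.eye(3), np.zeros((3, 1))])  # identity rotation, zero translation
hwf = [756, 1008, 815.0]  # [height, width, focal length] in pixels
near, far = 0.1, 100.0    # scene depth bounds
```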
## Citation
```
@inproceedings{mildenhall2020nerf,
title={NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis},
author={Ben Mildenhall and Pratul P. Srinivasan and Matthew Tancik and Jonathan T. Barron and Ravi Ramamoorthi and Ren Ng},
year={2020},
booktitle={ECCV},
}
```
conda create --name nerf python=3.7
conda activate nerf
conda install cudatoolkit=10.0
conda install tensorflow-gpu=1.15
conda install numpy
conda install matplotlib
conda install imageio
conda install imageio-ffmpeg
conda install configargparse
\ No newline at end of file
expname = bar0514
basedir = ./logs
datadir = ./data/barbershop_2021.05.04
dataset_type = llff
factor = 1
llffhold = 20
N_rand = 1024
N_samples = 64
N_importance = 64
use_viewdirs = True
raw_noise_std = 1e0
expname = bedroom_nerf_2021.01.18
basedir = ./logs
datadir = ./data/bedroom_nerf_2021.01.18
dataset_type = llff
factor = 1
llffhold = 8
N_rand = 1024
N_samples = 16
N_importance = 0
netdepth = 4
netwidth = 128
use_viewdirs = True
raw_noise_std = 1e0
expname = gallery_nerf_2021.01.20
basedir = ./logs
datadir = ./data/gallery_nerf_2021.01.20
dataset_type = llff
factor = 1
llffhold = 8
N_rand = 1024
N_samples = 16
N_importance = 0
netdepth = 4
netwidth = 128
use_viewdirs = True
raw_noise_std = 1e0
expname = gas_nerf_2021.01.17
basedir = ./logs
datadir = ./data/gas_nerf_2021.01.17
dataset_type = llff
factor = 1
llffhold = 8
N_rand = 1024
N_samples = 16
N_importance = 0
netdepth = 4
netwidth = 128
use_viewdirs = True
raw_noise_std = 1e0
expname = lego_test
basedir = ./logs
datadir = ./data/nerf_synthetic/lego
dataset_type = blender
half_res = True
no_batching = True
N_samples = 64
N_importance = 64
use_viewdirs = True
white_bkgd = True
N_rand = 1024
\ No newline at end of file
expname = lobby_nerf_2021.01.20
basedir = ./logs
datadir = ./data/lobby_nerf_2021.01.20
dataset_type = llff
factor = 1
llffhold = 8
N_rand = 1024
N_samples = 16
N_importance = 0
netdepth = 4
netwidth = 128
use_viewdirs = True
raw_noise_std = 1e0