Commit 6294701e authored by Nianchen Deng

sync

parent 2824f796
from .sampler import Sampler, PdfSampler, VoxelSampler
from .input_encoder import InputEncoder, IntegratedPosEncoder
from .renderer import VolumnRenderer, DensityFirstVolumnRenderer
from .space import Space, Voxels, Octree
from .core import NerfCore, NerfAdvCore, MultiNerf
import re
import torch
from typing import Iterable, Tuple
from .generic import *
from utils.misc import union, split
from utils.type import NetInput, NetOutput
from utils.module import Module
from utils.samples import Samples
class NerfCore(Module):
def __init__(self, *, x_chns, density_chns, color_chns, nf, n_layers,
d_chns=0, d_nf=0, act='relu', skips=[], with_layer_norm=False,
density_out_act='relu', color_out_act='sigmoid', f_chns=0):
super().__init__()
self.input_f = f_chns > 0
self.core_field = FcBlock(in_chns=x_chns + f_chns, out_chns=None, nf=nf, n_layers=n_layers,
skips=skips, act=act, with_ln=with_layer_norm)
self.density_out = FcLayer(nf, density_chns, density_out_act, with_ln=False) \
if density_chns > 0 else None
if color_chns == 0:
self.color_out = None
elif d_chns > 0:
self.color_out = FcBlock(in_chns=nf + d_chns, out_chns=color_chns,
nf=d_nf or nf // 2, n_layers=1,
act=act, out_act=color_out_act, with_ln=with_layer_norm)
self.with_dir = True
else:
self.color_out = FcLayer(nf, color_chns, color_out_act, with_ln=False)
self.with_dir = False
def forward(self, inputs: NetInput, *outputs: str, features: torch.Tensor = None, **kwargs) -> NetOutput:
ret = {}
if features is None:
features = self.core_field(union(inputs.x, inputs.f) if self.input_f else inputs.x)
if 'features' in outputs:
ret['features'] = features
if 'densities' in outputs and self.density_out:
ret['densities'] = self.density_out(features)
if 'colors' in outputs and self.color_out:
if self.with_dir:
features = union(features, inputs.d)
ret['colors'] = self.color_out(features)
return ret
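# Usage sketch (illustrative only; channel sizes are hypothetical and `NetInput`
# is assumed to bundle encoded positions/directions):
#
#   core = NerfCore(x_chns=63, density_chns=1, color_chns=3, nf=256, n_layers=8,
#                   d_chns=27, d_nf=128, skips=[4])
#   out = core(NetInput(x=encoded_x, d=encoded_d), 'densities', 'colors')
#   # out['densities']: (..., 1), out['colors']: (..., 3)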
class NerfAdvCore(Module):
def __init__(self, *, x_chns: int, d_chns: int, density_chns: int, color_chns: int,
density_net: dict, color_net: dict, specular_net: dict = None,
appearance="decomposite", with_layer_norm=False, f_chns=0):
"""
Create a NeRF-Adv Core Net.
Required parameters for the sub-mlps include: "nf", "n_layers", "skips" and "act".
:param color_net `dict`: parameters for the color net
:param specular_net `dict`: (optional) parameters for the optional specular net, defaults to None
:param appearance `str`: (optional) options are [decomposite|combined], defaults to "decomposite"
"""
super().__init__()
self.input_f = f_chns > 0
self.density_chns = density_chns
self.color_chns = color_chns
self.specular_feature_chns = color_net["nf"] if specular_net else 0
self.color_feature_chns = density_net["nf"]
self.appearance = appearance
self.density_net = FcBlock(**density_net,
in_chns=x_chns + f_chns,
out_chns=self.density_chns + self.color_feature_chns,
out_act='relu',
with_ln=with_layer_norm)
if self.appearance == "newtype":
self.specular_feature_chns = d_chns * 3
self.color_net = FcBlock(**color_net,
in_chns=x_chns + self.color_feature_chns,
out_chns=self.color_chns + self.specular_feature_chns,
with_ln=with_layer_norm)
self.specular_net = "Placeholder"
else:
match = re.match(r"mlp_basis\((\d+)\)", self.appearance)
if match is not None:
basis_dim = int(match.group(1))
self.color_net = FcBlock(**color_net,
in_chns=x_chns + self.color_feature_chns,
out_chns=self.color_chns * basis_dim)
elif self.appearance == "decomposite":
self.color_net = FcBlock(**color_net,
in_chns=x_chns + self.color_feature_chns,
out_chns=self.color_chns + self.specular_feature_chns,
with_ln=with_layer_norm)
else:
if specular_net:
self.color_net = FcBlock(**color_net,
in_chns=x_chns + self.color_feature_chns,
out_chns=self.specular_feature_chns,
with_ln=with_layer_norm)
else:
self.color_net = FcBlock(**color_net,
in_chns=x_chns + d_chns + self.color_feature_chns,
out_chns=self.color_chns,
with_ln=with_layer_norm)
self.specular_net = FcBlock(**specular_net,
in_chns=d_chns + self.specular_feature_chns,
out_chns=self.color_chns,
with_ln=with_layer_norm) if specular_net else None
def forward(self, inputs: NetInput, *outputs: str, features: torch.Tensor = None, **kwargs) -> NetOutput:
output_shape = inputs.shape
ret: NetOutput = {}
if 'densities' in outputs or 'features' in outputs:
density_net_in = union(inputs.x, inputs.f) if self.input_f else inputs.x
density_net_out: torch.Tensor = self.density_net(density_net_in)
densities, features = split(density_net_out, self.density_chns, -1)
if 'features' in outputs:
ret['features'] = features
if 'densities' in outputs:
ret['densities'] = densities
if 'colors' in outputs or 'speculars' in outputs or 'diffuses' in outputs:
if 'densities' in ret:
valid_mask = ret['densities'][..., 0].detach() >= 1e-4
indices: Tuple[torch.Tensor, ...] = valid_mask.nonzero(as_tuple=True)
inputs, features = inputs[indices], features[indices]
else:
indices = None
speculars = None
color_net_in = [inputs.x, features]
if not self.specular_net:
color_net_in.append(inputs.d)
color_net_out: torch.Tensor = self.color_net(union(*color_net_in))
diffuses = color_net_out[..., :self.color_chns]
specular_features = color_net_out[..., -self.specular_feature_chns:]
if self.appearance == "newtype":
speculars = torch.matmul(
specular_features.reshape(*inputs.shape, -1, inputs.d.shape[-1]),
inputs.d[..., None])[..., 0]
# TODO relu or not?
diffuses = diffuses.relu()
speculars = speculars.relu()
if not self.specular_net:
colors = diffuses
diffuses = None
speculars = None
else:
specular_net_out = self.specular_net(union(inputs.d, specular_features))
if self.appearance == "decomposite":
speculars = specular_net_out
colors = diffuses + speculars
else:
diffuses = None
speculars = None
colors = specular_net_out
colors = torch.sigmoid(colors) # TODO indent or not?
def postprocess(data: torch.Tensor):
return data.new_zeros(*output_shape, data.shape[-1]).index_put(indices, data)\
if indices is not None else data
if 'colors' in outputs:
ret['colors'] = postprocess(colors)
if 'diffuses' in outputs and diffuses is not None:
ret['diffuses'] = postprocess(diffuses)
if 'speculars' in outputs and speculars is not None:
ret['speculars'] = postprocess(speculars)
return ret
class MultiNerf(Module):
@property
def n_levels(self):
return len(self.nets)
def __init__(self, nets: Iterable[Module]):
super().__init__()
self.nets = list(nets)
for i in range(len(self.nets)):
self.add_module(f"Level {i}", self.nets[i])
def set_frozen(self, level: int, on: bool):
self.nets[level].train(not on)
def forward(self, inputs: NetInput, *outputs: str, samples: Samples, **kwargs) -> NetOutput:
L = samples.level
if L == 0:
return self.nets[L](inputs, *outputs)
if samples.features is not None:
return self.nets[L](NetInput(inputs.x, inputs.d, samples.features), *outputs)
features = self.nets[0](inputs, 'features')['features']
for i in range(1, L):
features = self.nets[i](NetInput(inputs.x, inputs.d, features), 'features')['features']
samples.features = features
return self(inputs, *outputs, samples=samples)
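# Example (illustrative): with two levels, a sample at level 1 is evaluated by
# first extracting level-0 features, then feeding them into the level-1 net:
#
#   feats = nets[0](inputs, 'features')['features']
#   out = nets[1](NetInput(inputs.x, inputs.d, feats), 'densities', 'colors')
#
# which is exactly the recursion `forward` performs via `samples.features`.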
from .linear import *
import torch
import torch.nn.functional as F
class Sine(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x: torch.Tensor):
return (30 * x).sin()
class Mise(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x: torch.Tensor):
return x * torch.tanh(F.softplus(x))
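# Note: x * torch.tanh(F.softplus(x)) is the Mish activation (Misra, 2019).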
from collections import OrderedDict
from typing import List
from .weight_init import *
from .fn import *
class BatchLinear(nn.Linear):
'''
A linear meta-layer that can deal with batched weight matrices and biases,
as for instance output by a hypernetwork.
'''
__doc__ = nn.Linear.__doc__
def forward(self, input, params=None):
if params is None:
params = OrderedDict(self.named_parameters())
bias = params.get('bias', None)
weight = params['weight']
output = input.matmul(weight.permute(*[i for i in range(len(weight.shape) - 2)], -1, -2))
if bias is not None:
output += bias.unsqueeze(-2)
return output
class FcLayer(nn.Module):
def __init__(self, in_chns: int, out_chns: int, act: str = 'linear', skip_chns: int = 0,
with_ln: bool = True):
super().__init__()
nls_and_inits = {
'sine': (Sine, init_weights_sine),
'relu': (nn.ReLU, init_weights_relu),
'leakyrelu': (nn.LeakyReLU, init_weights_leakyrelu),
'sigmoid': (nn.Sigmoid, init_weights_xavier),
'tanh': (nn.Tanh, init_weights_xavier),
'selu': (nn.SELU, init_weights_selu),
'softplus': (nn.Softplus, init_weights_trunc_normal),
'elu': (nn.ELU, init_weights_elu),
'softmax': (nn.Softmax, init_weights_softmax),
'logsoftmax': (nn.LogSoftmax, init_weights_softmax),
'mise': (Mise, init_weights_xavier),
'linear': (nn.Identity, init_weights_xavier)
}
nl_cls, weight_init_fn = nls_and_inits[act]
self.net = [nn.Linear(in_chns + skip_chns, out_chns)]
if with_ln:
self.net += [nn.LayerNorm([out_chns])]
self.net += [nl_cls()]
self.net = nn.Sequential(*self.net)
self.skip = skip_chns != 0
self.with_ln = with_ln
self.net.apply(weight_init_fn)
def forward(self, x: torch.Tensor, x0: torch.Tensor = None) -> torch.Tensor:
return self.net(torch.cat([x0, x], -1) if self.skip else x)
def __repr__(self) -> str:
s = f"{self.net[0].in_features} -> {self.net[0].out_features}, "\
+ ", ".join(module.__class__.__name__ for module in self.net[1:])
return f"{self._get_name()}({s})"
class FcBlock(nn.Module):
def __init__(self, *, in_chns: int, out_chns: int, nf: int, n_layers: int,
skips: List[int] = [], act: str = 'relu', out_act='linear', with_ln=True):
"""
Initialize a full-connection net
:kwarg in_chns: channels of input
:kwarg out_chns: channels of output
:kwarg nf: number of features in each hidden layer
:kwarg n_layers: number of layers
:kwarg skips: create skip connections from input to layers in this list
"""
super().__init__()
self.layers = nn.ModuleList([
FcLayer(in_chns, nf, act, with_ln=with_ln)] + [
FcLayer(nf, nf, act, skip_chns=in_chns if i in skips else 0, with_ln=with_ln)
for i in range(1, n_layers)
])
if out_chns:
self.layers.append(FcLayer(nf, out_chns, out_act, with_ln=False))
def forward(self, x: torch.Tensor) -> torch.Tensor:
x0 = x
for layer in self.layers:
x = layer(x, x0)
return x
def __repr__(self):
lines = []
for key, module in self.layers._modules.items():
mod_str = repr(module)
mod_str = nn.modules.module._addindent(mod_str, 2)
lines.append('(' + key + '): ' + mod_str)
main_str = self._get_name() + '('
if lines:
main_str += '\n ' + '\n '.join(lines) + '\n'
main_str += ')'
return main_str
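# Usage sketch (illustrative): an 8-layer, 256-wide MLP with a skip connection
# from the input to layer 4, as in the original NeRF trunk. Shapes are hypothetical.
#
#   net = FcBlock(in_chns=63, out_chns=4, nf=256, n_layers=8, skips=[4])
#   y = net(torch.rand(1024, 63))   # y: (1024, 4)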
import torch
import torch.nn as nn
from utils import math
def init_weights_trunc_normal(m):
# For PINNet, Raissi et al. 2019
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1. + math.erf(x / math.sqrt(2.))) / 2.
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
if isinstance(m, nn.Linear):
fan_in = m.weight.size(1)
fan_out = m.weight.size(0)
std = math.sqrt(2.0 / float(fan_in + fan_out))
mean = 0.
# initialize with the same behavior as tf.truncated_normal
# "The generated values follow a normal distribution with specified mean and
# standard deviation, except that values whose magnitude is more than 2
# standard deviations from the mean are dropped and re-picked."
_no_grad_trunc_normal_(m.weight, mean, std, -2 * std, 2 * std)
def init_weights_relu(m):
if isinstance(m, nn.Linear):
nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu')
if m.bias is not None:
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight)
bound = 1 / math.sqrt(fan_in)
nn.init.uniform_(m.bias, -bound, bound)
def init_weights_leakyrelu(m):
if isinstance(m, nn.Linear):
nn.init.kaiming_normal_(m.weight, a=math.sqrt(5))
if m.bias is not None:
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight)
bound = 1 / math.sqrt(fan_in)
nn.init.uniform_(m.bias, -bound, bound)
def init_weights_selu(m):
if isinstance(m, nn.Linear):
num_input = m.weight.size(-1)
nn.init.normal_(m.weight, std=1 / math.sqrt(num_input))
def init_weights_elu(m):
if isinstance(m, nn.Linear):
num_input = m.weight.size(-1)
nn.init.normal_(m.weight, std=math.sqrt(1.5505188080679277) / math.sqrt(num_input))
def init_weights_xavier(m):
if isinstance(m, nn.Linear):
nn.init.xavier_normal_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def init_weights_sine(m):
with torch.no_grad():
if hasattr(m, 'weight'):
num_input = m.weight.size(-1)
# See supplement Sec. 1.5 for discussion of factor 30
m.weight.uniform_(-math.sqrt(6 / num_input) / 30, math.sqrt(6 / num_input) / 30)
def init_weights_sine_first_layer(m):
with torch.no_grad():
if hasattr(m, 'weight'):
num_input = m.weight.size(-1)
# See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of factor 30
m.weight.uniform_(-1 / num_input, 1 / num_input)
def init_weights_softmax(m):
with torch.no_grad():
nn.init.normal_(m.weight, mean=0, std=0.01)
nn.init.constant_(m.bias, val=0)
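# Usage sketch (illustrative): these initializers are applied per activation
# type, e.g. by FcLayer via `self.net.apply(weight_init_fn)`:
#
#   layer = nn.Linear(64, 64)
#   init_weights_relu(layer)   # Kaiming fan-in weights + uniform bias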
from typing import Tuple
import torch
import torch.nn as nn
from .generic import *
from utils import math
from utils.module import Module
class InputEncoder(Module):
def __init__(self, chns, L, cat_input=False):
super().__init__()
emb = torch.exp(torch.arange(L, dtype=torch.float) * math.log(2.))
self.emb = nn.Parameter(emb, requires_grad=False)
self.in_dim = chns
self.out_dim = chns * (L * 2 + cat_input)
self.cat_input = cat_input
def forward(self, x: torch.Tensor, angular=False):
sizes = x.size()
x0 = x
if angular:
x = torch.acos(x.clamp(-1, 1))
x = x[..., None] @ self.emb[None]
x = torch.cat([torch.sin(x), torch.cos(x)], -1)
x = x.flatten(-2)
if self.cat_input:
x = torch.cat([x0, x], -1)
return x
def extra_repr(self) -> str:
return f'in={self.in_dim}, out={self.out_dim}, cat_input={self.cat_input}'
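# Example (illustrative): the classic NeRF positional encoding of 3-D points
# with L = 10 frequencies, keeping the raw input:
#
#   enc = InputEncoder(3, 10, cat_input=True)   # out_dim = 3 * (10*2 + 1) = 63
#   y = enc(torch.rand(1024, 3))                # y: (1024, 63)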
class IntegratedPosEncoder(Module):
def __init__(self, chns, L, shape: str, cat_input=False):
super().__init__()
self.shape = shape
def _lift_gaussian(self, d: torch.Tensor, t_mean: torch.Tensor, t_var: torch.Tensor,
r_var: torch.Tensor, diag: bool):
"""Lift a Gaussian defined along a ray to 3D coordinates."""
mean = d[..., None, :] * t_mean[..., None]
d_sq = d**2
d_mag_sq = torch.sum(d_sq, -1, keepdim=True).clamp_min(1e-10)
if diag:
d_outer_diag = d_sq
null_outer_diag = 1 - d_outer_diag / d_mag_sq
t_cov_diag = t_var[..., None] * d_outer_diag[..., None, :]
xy_cov_diag = r_var[..., None] * null_outer_diag[..., None, :]
cov_diag = t_cov_diag + xy_cov_diag
return mean, cov_diag
else:
d_outer = d[..., :, None] * d[..., None, :]
eye = torch.eye(d.shape[-1], device=d.device)
null_outer = eye - d[..., :, None] * (d / d_mag_sq)[..., None, :]
t_cov = t_var[..., None, None] * d_outer[..., None, :, :]
xy_cov = r_var[..., None, None] * null_outer[..., None, :, :]
cov = t_cov + xy_cov
return mean, cov
def _conical_frustum_to_gaussian(self, d: torch.Tensor, t0: float, t1: float, base_radius: float,
diag: bool, stable: bool = True):
"""Approximate a conical frustum as a Gaussian distribution (mean+cov).
Assumes the ray is originating from the origin, and base_radius is the
radius at dist=1. Doesn't assume `d` is normalized.
Args:
d: torch.float32 3-vector, the axis of the cone
t0: float, the starting distance of the frustum.
t1: float, the ending distance of the frustum.
base_radius: float, the scale of the radius as a function of distance.
diag: boolean, whether the Gaussian will be diagonal or full-covariance.
stable: boolean, whether or not to use the stable computation described in
the paper (setting this to False will cause catastrophic failure).
Returns:
a Gaussian (mean and covariance).
"""
if stable:
mu = (t0 + t1) / 2
hw = (t1 - t0) / 2
t_mean = mu + (2 * mu * hw**2) / (3 * mu**2 + hw**2)
t_var = (hw**2) / 3 - (4 / 15) * ((hw**4 * (12 * mu**2 - hw**2)) /
(3 * mu**2 + hw**2)**2)
r_var = base_radius**2 * ((mu**2) / 4 + (5 / 12) * hw**2 - 4 / 15 *
(hw**4) / (3 * mu**2 + hw**2))
else:
t_mean = (3 * (t1**4 - t0**4)) / (4 * (t1**3 - t0**3))
r_var = base_radius**2 * (3 / 20 * (t1**5 - t0**5) / (t1**3 - t0**3))
t_mosq = 3 / 5 * (t1**5 - t0**5) / (t1**3 - t0**3)
t_var = t_mosq - t_mean**2
return self._lift_gaussian(d, t_mean, t_var, r_var, diag)
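# The `stable` branch computes the frustum's moments in the numerically stable
# (mu, hw) reparameterization from mip-NeRF (Barron et al., 2021); the other
# branch uses the raw power-series form, which becomes ill-conditioned when
# t0 is close to t1 (hence the warning in the docstring).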
def _cylinder_to_gaussian(self, d: torch.Tensor, t0: float, t1: float, radius: float, diag: bool):
"""Approximate a cylinder as a Gaussian distribution (mean+cov).
Assumes the ray is originating from the origin, and radius is the
radius. Does not renormalize `d`.
Args:
d: torch.float32 3-vector, the axis of the cylinder
t0: float, the starting distance of the cylinder.
t1: float, the ending distance of the cylinder.
radius: float, the radius of the cylinder
diag: boolean, whether the Gaussian will be diagonal or full-covariance.
Returns:
a Gaussian (mean and covariance).
"""
t_mean = (t0 + t1) / 2
r_var = radius**2 / 4
t_var = (t1 - t0)**2 / 12
return self._lift_gaussian(d, t_mean, t_var, r_var, diag)
def cast_rays(self, t_vals: torch.Tensor, rays_o: torch.Tensor, rays_d: torch.Tensor,
rays_r: torch.Tensor, diag: bool = True):
"""Cast rays (cone- or cylinder-shaped) and featurize sections of it.
Args:
t_vals: float array, the "fencepost" distances along the ray.
rays_o: float array, the ray origin coordinates.
rays_d: float array, the ray direction vectors.
rays_r: float array, the radii (base radii for cones) of the rays.
diag: boolean, whether or not the covariance matrices should be diagonal.
Returns:
a tuple of arrays of means and covariances.
"""
t0 = t_vals[..., :-1]
t1 = t_vals[..., 1:]
if self.shape == 'cone':
gaussian_fn = self._conical_frustum_to_gaussian
elif self.shape == 'cylinder':
gaussian_fn = self._cylinder_to_gaussian
else:
assert False
means, covs = gaussian_fn(rays_d, t0, t1, rays_r, diag)
means = means + rays_o[..., None, :]
return means, covs
def integrated_pos_enc(x_coord: Tuple[torch.Tensor, torch.Tensor], min_deg: int, max_deg: int,
diag: bool = True):
"""Encode `x` with sinusoids scaled by 2^[min_deg:max_deg-1].
Args:
x_coord: a tuple containing: x, torch.ndarray, variables to be encoded. Should
be in [-pi, pi]. x_cov, torch.ndarray, covariance matrices for `x`.
min_deg: int, the min degree of the encoding.
max_deg: int, the max degree of the encoding.
diag: bool, if true, expects input covariances to be diagonal (full
otherwise).
Returns:
encoded: torch.ndarray, encoded variables.
"""
if diag:
x, x_cov_diag = x_coord
scales = torch.tensor([2**i for i in range(min_deg, max_deg)], device=x.device)[:, None]
shape = list(x.shape[:-1]) + [-1]
y = torch.reshape(x[..., None, :] * scales, shape)
y_var = torch.reshape(x_cov_diag[..., None, :] * scales**2, shape)
else:
x, x_cov = x_coord
num_dims = x.shape[-1]
basis = torch.cat([
2**i * torch.eye(num_dims, device=x.device)
for i in range(min_deg, max_deg)
], 1)
y = torch.matmul(x, basis)
# Get the diagonal of a covariance matrix (ie, variance). This is equivalent
# to jax.vmap(torch.diag)((basis.T @ covs) @ basis).
y_var = (torch.matmul(x_cov, basis) * basis).sum(-2)
return math.expected_sin(
torch.cat([y, y + 0.5 * math.pi], -1),
torch.cat([y_var] * 2, -1))[0]
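# Identity assumed to back `math.expected_sin`: for y ~ N(mu, var),
# E[sin(y)] = sin(mu) * exp(-var / 2). Concatenating [y, y + pi/2] yields the
# matching cosine terms, so growing variance smoothly damps high frequencies.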
from itertools import cycle
import torch
from typing import Dict, Set, Tuple, Union
from utils.type import NetInput, ReturnData
from .generic import *
from model.base import BaseModel
from utils import math
from utils.module import Module
from utils.perf import checkpoint, perf
from utils.samples import Samples
def density2energy(densities: torch.Tensor, dists: torch.Tensor, raw_noise_std: float = 0):
def density2alpha(densities: torch.Tensor, dists: torch.Tensor, raw_noise_std: float = 0):
energies = density2energy(densities, dists, raw_noise_std)
return 1.0 - torch.exp(-energies)
class AlphaComposition(Module):
def __init__(self):
super().__init__()
# Compute weight for RGB of each sample along each ray. A cumprod() is
# used to express the idea of the ray not having reflected up to this
# sample yet.
one_minus_alpha = torch.cumprod(1 - alphas[..., :-1, :] + math.tiny, dim=-2)
one_minus_alpha = torch.cat([
torch.ones_like(one_minus_alpha[..., :1, :]),
one_minus_alpha
}
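# Blending sketch (standard alpha composition, as the comment above describes):
#   T_i = prod_{j<i} (1 - alpha_j)   # transmittance before sample i (the cumprod)
#   w_i = alpha_i * T_i              # per-sample blend weight
#   color = sum_i w_i * c_i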
class VolumnRenderer(Module):
class States:
kernel: BaseModel
samples: Samples
early_stop_tolerance: float
outputs: Set[str]
hit_mask: torch.Tensor
N: int
P: int
device: torch.device
colors: torch.Tensor
diffuses: torch.Tensor
speculars: torch.Tensor
densities: torch.Tensor
energies: torch.Tensor
weights: torch.Tensor
cum_energies: torch.Tensor
exp_energies: torch.Tensor
tot_evaluations: Dict[str, int]
chunk: Tuple[slice, slice]
def end(self) -> int:
return self.chunk[1].stop
def __init__(self, kernel: BaseModel, samples: Samples, early_stop_tolerance: float,
outputs: Set[str]) -> None:
self.kernel = kernel
self.samples = samples
self.early_stop_tolerance = early_stop_tolerance
self.outputs = outputs
N, P = samples.size
self.device = self.samples.device
self.hit_mask = samples.voxel_indices != -1 # (N, P) | bool
self.colors = torch.zeros(N, P, kernel.chns('color'), device=samples.device)
self.energies = torch.zeros(N, P, 1, device=samples.device)
self.weights = torch.zeros(N, P, 1, device=samples.device)
self.cum_energies = torch.zeros(N, P + 1, 1, device=samples.device)
self.N, self.P = N, P
self.chunk_id = -1
def n_hits(self, index: Union[int, slice] = None) -> int:
if not isinstance(self.hit_mask, torch.Tensor):
if index is not None:
return self.N * self.colors[:, index].shape[1]
return self.N * self.P
if index is None:
return self.hit_mask.count_nonzero().item()
return self.hit_mask[:, index].count_nonzero().item()
def accumulate_tot_evaluations(self, key: str, n: int):
if key not in self.tot_evaluations:
self.chunk_id += 1
return self
def put(self, key: str, values: torch.Tensor, indices: Union[Tuple[torch.Tensor, torch.Tensor], Tuple[slice, slice]]):
if not hasattr(self, key):
new_tensor = torch.zeros(self.N, self.P, values.shape[-1], device=self.device)
setattr(self, key, new_tensor)
tensor: torch.Tensor = getattr(self, key)
# if isinstance(indices[0], torch.Tensor):
# tensor.index_put_(indices, values)
# else:
tensor[indices] = values
def __init__(self, **kwargs):
super().__init__()
@perf
def forward(self, kernel: BaseModel, samples: Samples, *outputs: str,
raymarching_early_stop_tolerance: float = 0,
raymarching_chunk_size_or_sections: Union[int, List[int]] = None,
**kwargs) -> ReturnData:
"""
Perform volume rendering.
:param kernel `BaseModel`: render kernel
:param samples `Samples(N, P)`: samples
:param outputs `str...`: items to be contained in the result dict.
Optional values include 'color', 'depth', 'layers', 'states' and attribute names in class `States` (e.g. 'weights'). Defaults to []
:param raymarching_early_stop_tolerance `float`: tolerance of raymarching early stop.
Should between 0 and 1 (0 means no early stop). Defaults to 0
:param raymarching_chunk_size_or_sections `int|list[int]`: indicates how to split raymarching process.
print("VolumnRenderer.forward(): # of samples is zero")
return None
infer_outputs = set()
for key in outputs:
if key == "color":
infer_outputs.add("colors")
infer_outputs.add("densities")
elif key == "specular":
infer_outputs.add("speculars")
infer_outputs.add("densities")
elif key == "diffuse":
infer_outputs.add("diffuses")
infer_outputs.add("densities")
elif key == "depth":
infer_outputs.add("densities")
else:
infer_outputs.add(key)
s = VolumnRenderer.States(kernel, samples, raymarching_early_stop_tolerance, infer_outputs)
checkpoint("Prepare states object")
if not raymarching_chunk_size_or_sections:
raymarching_chunk_size_or_sections = [s.P]
elif isinstance(raymarching_chunk_size_or_sections, int) and \
raymarching_chunk_size_or_sections > 0:
raymarching_chunk_size_or_sections = [
math.ceil(s.P / raymarching_chunk_size_or_sections)
]
if isinstance(raymarching_chunk_size_or_sections, list):
chunk_sections = raymarching_chunk_size_or_sections
chunk_hits += n_hits
self._forward_chunk(s.next_chunk())
checkpoint("Run forward chunks")
ret = {}
for key in outputs:
if key == 'color':
ret['color'] = torch.sum(s.colors * s.weights, 1)
elif key == 'depth':
ret['depth'] = torch.sum(s.samples.depths[..., None] * s.weights, 1)
elif key == 'diffuse' and hasattr(s, "diffuses"):
ret['diffuse'] = torch.sum(s.diffuses * s.weights, 1)
elif key == 'specular' and hasattr(s, "speculars"):
ret['specular'] = torch.sum(s.speculars * s.weights, 1)
elif key == 'layers':
ret['layers'] = torch.cat([s.colors, 1 - torch.exp(-s.energies)], dim=-1)
elif key == 'states':
ret['states'] = s
else:
if hasattr(s, key):
ret[key] = getattr(s, key)
checkpoint("Set return data")
return ret
@perf
def _calc_weights(self, s: States):
"""
Calculate weights of samples in composited outputs
"""
s.energies[s.chunk] = density2energy(s.densities[s.chunk], s.samples.dists[s.chunk])
s.cum_energies[s.cum_chunk] = torch.cumsum(s.energies[s.chunk], 1) \
+ s.cum_energies[s.cum_last]
s.exp_energies[s.cum_chunk] = (-s.cum_energies[s.cum_chunk]).exp()
s.weights[s.chunk] = s.exp_energies[s.chunk] - s.exp_energies[s.cum_chunk]
@perf
def _apply_early_stop(self, s: States):
"""
Stop rays whose accumulated opacity is larger than a threshold
:param s `States`: s
"""
if s.end < s.P and s.early_stop_tolerance > 0 and isinstance(s.hit_mask, torch.Tensor):
rays_to_stop = s.exp_energies[:, s.end, 0] < s.early_stop_tolerance
s.hit_mask[rays_to_stop, s.end:] = 0
@perf
def _forward_chunk(self, s: States) -> int:
if isinstance(s.hit_mask, torch.Tensor):
fi_idxs: Tuple[torch.Tensor, ...] = s.hit_mask[s.chunk].nonzero(as_tuple=True)
if fi_idxs[0].size(0) == 0:
s.cum_energies[s.cum_chunk] = s.cum_energies[s.cum_last]
s.exp_energies[s.cum_chunk] = s.exp_energies[s.cum_last]
return
fi_idxs[1].add_(s.start)
s.accumulate_tot_evaluations("colors", fi_idxs[0].size(0))
else:
fi_idxs = s.chunk
# Infer densities and colors
fi_outputs = s.kernel.infer(*s.outputs, samples=s.samples[fi_idxs], chunk_id=s.chunk_id)
for key, value in fi_outputs.items():
s.put(key, value, fi_idxs)
self._calc_weights(s)
self._apply_early_stop(s)
class DensityFirstVolumnRenderer(VolumnRenderer):
if fi_idxs[0].size(0) == 0:
s.cum_energies[s.cum_chunk] = s.cum_energies[s.cum_last]
s.exp_energies[s.cum_chunk] = s.exp_energies[s.cum_last]
return
# fi_* means "filtered" by hit mask
fi_samples = s.samples[fi_idxs] # N -> N'
density_inputs = s.kernel.input(fi_samples, "x", "f") # (N', Ex)
# Infer densities (shape)
density_outputs = s.kernel.infer('densities', 'features', samples=fi_samples,
inputs=density_inputs, chunk_id=s.chunk_id)
s.put('densities', density_outputs['densities'], fi_idxs)
s.accumulate_tot_evaluations("densities", fi_idxs[0].size(0))
self._calc_weights(s)
self._apply_early_stop(s)
# Update "filtered" tensors
fi_mask = s.hit_mask[fi_idxs]
fi_idxs = (fi_idxs[0][fi_mask], fi_idxs[1][fi_mask]) # N' -> N"
fi_samples = s.samples[fi_idxs] # N -> N"
fi_features = density_outputs['features'][fi_mask]
color_inputs = s.kernel.input(fi_samples, "d") # (N")
color_inputs.x = density_inputs.x[fi_mask]
# Infer colors (appearance)
outputs = s.outputs.copy()
if 'densities' in outputs:
outputs.remove('densities')
color_outputs = s.kernel.infer(*outputs, samples=fi_samples, inputs=color_inputs,
chunk_id=s.chunk_id, features=fi_features)
# if s.chunk_id == 0:
# fi_colors[:] *= fi_colors.new_tensor([1, 0, 0])
# elif s.chunk_id == 1:
# fi_colors[:] *= fi_colors.new_tensor([0, 0, 1])
# else:
# fi_colors[:] *= fi_colors.new_tensor([1, 1, 0])
for key, value in color_outputs.items():
s.put(key, value, fi_idxs)
s.accumulate_tot_evaluations("colors", fi_idxs[0].size(0))
from .space import Space, Voxels
import torch
from typing import Tuple
from utils import sphere
from utils import misc
from utils import math
from utils.module import Module
from utils.samples import Samples
from utils.perf import perf, checkpoint
from .generic import *
from clib import *
class Bins(object):
self.bounds = self.bounds.to(device)
class Sampler(Module):
def __init__(self, **kwargs):
"""
Initialize a Sampler module
"""
super().__init__()
self._samples_indices_cached = None
def _sample(self, t_range: Tuple[float, float], n_rays: int, n_samples: int, perturb: bool,
device: torch.device) -> torch.Tensor:
"""
[summary]
:param t_range `float, float`: sampling range
:param n_rays `int`: number of rays (B)
:param n_samples `int`: number of samples per ray (P)
:param perturb `bool`: whether perturb sampling
:param device `torch.device`: the device used to create tensors
:return `Tensor(B, P+1)`: sampling bounds of t
"""
bounds = torch.linspace(*t_range, n_samples + 1, device=device) # (P+1)
if perturb:
rand_bounds = torch.cat([
bounds[:1],
0.5 * (bounds[1:] + bounds[:-1]),
bounds[-1:]
])
rand_vals = torch.rand(n_rays, n_samples + 1, device=device)
bounds = rand_bounds[:-1] * (1 - rand_vals) + rand_bounds[1:] * rand_vals
else:
bounds = bounds[None].expand(n_rays, -1)
return bounds
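# Stratified sampling sketch (illustrative): with perturb=True each of the P bins
# receives one uniformly jittered bound, so consecutive bounds still bracket
# exactly one bin.
#
#   bounds = sampler._sample((0.5, 6.0), n_rays=4, n_samples=8, perturb=True,
#                            device=torch.device('cpu'))     # (4, 9)
#   t0, t1 = bounds[:, :-1], bounds[:, 1:]                   # per-bin bounds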
def _get_samples_indices(self, pts: torch.Tensor):
if self._samples_indices_cached is None\
or self._samples_indices_cached.shape[0] < pts.shape[0]\
or self._samples_indices_cached.shape[1] < pts.shape[1]:
self._samples_indices_cached = misc.meshgrid(
*pts.shape[:2], swap_dim=True, device=pts.device)
return self._samples_indices_cached[:pts.shape[0], :pts.shape[1]]
@perf
def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, space_: Space, *,
sample_range: Tuple[float, float], n_samples: int, lindisp: bool = False,
perturb_sample: bool = True, spherical: bool = False,
**kwargs) -> Tuple[Samples, torch.Tensor]:
"""
Sample points along rays.
:param rays_o `Tensor(B, 3)`: rays' origin
:param rays_d `Tensor(B, 3)`: rays' direction
:param sample_range `float, float`: sampling range
:param n_samples `int`: number of samples per ray
:param lindisp `bool`: whether sample linearly in disparity space (1/depth)
:param perturb_sample `bool`: whether perturb sampling
:return `Samples(B, P)`: samples
"""
if spherical:
t_bounds = self._sample(sample_range, rays_o.shape[0], n_samples, perturb_sample,
rays_o.device)
t0, t1 = t_bounds[:, :-1], t_bounds[:, 1:] # (B, P)
t = (t0 + t1) * .5
p, z = sphere.ray_sphere_intersect(rays_o, rays_d, t.reciprocal())
p = sphere.cartesian2spherical(p, inverse_r=True)
vidxs = space_.get_voxel_indices(p)
return Samples(
pts=p,
dirs=rays_d[:, None].expand(-1, n_samples, -1),
depths=z,
dists=(t1 + math.tiny).reciprocal() - t0.reciprocal(),
voxel_indices=vidxs,
indices=self._get_samples_indices(p),
t=t
)
else:
sample_range = (1 / sample_range[0], 1 / sample_range[1]) if lindisp else sample_range
z_bounds = self._sample(sample_range, rays_o.shape[0], n_samples, perturb_sample,
rays_o.device)
if lindisp:
z_bounds = z_bounds.reciprocal()
z0, z1 = z_bounds[:, :-1], z_bounds[:, 1:] # (B, P)
z = (z0 + z1) * .5
p = rays_o[:, None] + rays_d[:, None] * z[..., None]
vidxs = space_.get_voxel_indices(p)
return Samples(
pts=p,
dirs=rays_d[:, None].expand(-1, n_samples, -1),
depths=z,
dists=z1 - z0,
voxel_indices=vidxs,
indices=self._get_samples_indices(p),
t=z
)
class PdfSampler(Module):
def __init__(self, *, depth_range: Tuple[float, float], n_samples: int, perturb_sample: bool,
spherical: bool, lindisp: bool, **kwargs):
:return `Tensor(..., N)`: samples
'''
# Get pdf
weights = weights + math.tiny # prevent nans
pdf = weights / torch.sum(weights, dim=-1, keepdim=True) # [..., M]
cdf = torch.cat([
torch.zeros_like(pdf[..., :1]),
# fix numeric issue
denom = cdf_g[..., 1] - cdf_g[..., 0] # [..., N]
denom = torch.where(denom < math.tiny, torch.ones_like(denom), denom)
t = (u - cdf_g[..., 0]) / denom
samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0] + math.tiny)
return samples
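# The resampling above is standard NeRF hierarchical sampling: u ~ U[0, 1) is
# mapped through the inverse of the piecewise-linear CDF built from `weights`,
# so bins with larger weights receive proportionally more of the new samples.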
class VoxelSampler(Module):
def __init__(self, *, sample_step: float, **kwargs):
"""
Initialize a VoxelSampler module
:param step_size: step size
"""
super().__init__()
self.sample_step = sample_step
def _forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, space_module: Space, *,
perturb_sample: bool, **kwargs) -> Tuple[Samples, torch.Tensor]:
"""
[summary]
invalid_samples_mask = rays_step >= rays_steps
samples_min_depth = rays_near_depth + rays_step * rays_step_size
samples_depth = samples_min_depth + rays_step_size \
* (torch.rand_like(samples_min_depth) if perturb_sample else 0.5) # (N', P)
samples_dist = rays_step_size.repeat(1, max_steps) # (N', 1) -> (N', P)
samples_voxel_index = voxel_indices[
ray_index_list[:, None],
torch.searchsorted(max_depths, samples_depth)
] # (N', P)
samples_depth[invalid_samples_mask] = math.huge
samples_dist[invalid_samples_mask] = 0
samples_voxel_index[invalid_samples_mask] = -1
), valid_rays_mask
@perf
def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor,
space: Space, *, perturb_sample: bool, **kwargs) -> Tuple[Samples, torch.Tensor]:
"""
[summary]
:param step_size `float`: [description]
:return `Samples(N, P)`: [description]
"""
intersections = space.ray_intersect(rays_o, rays_d, 100)
valid_rays_mask = intersections.hits > 0
rays_o = rays_o[valid_rays_mask]
rays_d = rays_d[valid_rays_mask]
# sample points and use middle point approximation
sampled_indices, sampled_depths, sampled_dists = inverse_cdf_sampling(
pts_idx, min_depth, max_depth, probs, steps, -1, not perturb_sample)
sampled_indices = sampled_indices.long()
invalid_idx_mask = sampled_indices.eq(-1)
sampled_dists.clamp_min_(0).masked_fill_(invalid_idx_mask, 0)
sampled_depths.masked_fill_(invalid_idx_mask, math.huge)
checkpoint("Inverse CDF sampling")
......
import torch
from typing import List, Tuple, Union
from torch import nn
from typing import Dict, List, Optional, Tuple, Union
from clib import *
from model.utils import load
from utils.module import Module
from utils.geometry import *
from utils.constants import *
from utils.voxels import *
from utils.perf import perf
from clib import *
from utils.env import get_env
class Intersections:
......@@ -41,32 +42,42 @@ class Intersections:
hits=self.hits[index])
class Space(nn.Module):
bbox: Union[torch.Tensor, None]
class Space(Module):
bbox: Optional[torch.Tensor]
"""`Tensor(2, 3)` Bounding box"""
def __init__(self, *, bbox: List[float] = None, **kwargs):
@property
def dims(self) -> int:
"""`int` Number of dimensions"""
return self.bbox.shape[1] if self.bbox is not None else 3
@staticmethod
def create(args: dict) -> 'Space':
if 'space' not in args:
return Space(**args)
if args['space'] == 'octree':
return Octree(**args)
if args['space'] == 'voxels':
return Voxels(**args)
return load(args['space']).space
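# Hypothetical usage (bbox is a flat [min..., max...] list reshaped to (2, D)):
#   space = Space.create({"space": "voxels", "bbox": [-1, -1, -1, 1, 1, 1],
#                         "steps": (64, 64, 64)})
# Any other value of "space" is treated as a checkpoint path to load a saved space.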
def __init__(self, clone_src: "Space" = None, *, bbox: List[float] = None, **kwargs):
super().__init__()
if bbox is None:
self.bbox = None
if clone_src:
self.device = clone_src.device
self.register_temp('bbox', clone_src.bbox)
else:
self.register_buffer('bbox', torch.Tensor(bbox).reshape(2, 3), persistent=False)
def create_embedding(self, n_dims: int, name: str = 'default') -> torch.nn.Embedding:
raise NotImplementedError
def extract_embedding(self, pts: torch.Tensor, voxel_indices: torch.Tensor,
name: str = 'default') -> torch.Tensor:
raise NotImplementedError
self.register_temp('bbox', None if not bbox else torch.tensor(bbox).reshape(2, -1))
def ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> Intersections:
raise NotImplementedError
def get_voxel_indices(self, pts: torch.Tensor) -> torch.Tensor:
if self.bbox is None:
return 0
voxel_indices = torch.zeros_like(pts[..., 0], dtype=torch.long)
if self.bbox is not None:
out_bbox = torch.logical_or(pts < self.bbox[0], pts >= self.bbox[1]).any(-1) # (N...)
voxel_indices[out_bbox] = -1
out_bbox = get_out_of_bound_mask(pts, self.bbox) # (N...)
voxel_indices[out_bbox] = -1
return voxel_indices
@torch.no_grad()
......@@ -74,9 +85,13 @@ class Space(nn.Module):
raise NotImplementedError()
@torch.no_grad()
def split(self):
def split(self) -> Tuple[int, int]:
raise NotImplementedError()
@torch.no_grad()
def clone(self):
return Space(self)
class Voxels(Space):
steps: torch.Tensor
......@@ -92,12 +107,11 @@ class Voxels(Space):
"""`Tensor(M, 8)` Voxel corner indices"""
voxel_indices_in_grid: torch.Tensor
"""`Tensor(G)` Indices in voxel list or -1 for pruned space"""
"""`Tensor(G)` Indices in voxel list or -1 for pruned space
@property
def dims(self) -> int:
"""`int` Number of dimensions"""
return self.steps.size(0)
Note that the first element is reserved for the 'invalid voxel' (-1), so a grid
index must be offset by 1 before querying for the corresponding voxel index.
"""
@property
def n_voxels(self) -> int:
......@@ -109,30 +123,52 @@ class Voxels(Space):
"""`int` Number of corners"""
return self.corners.size(0)
@property
def n_grids(self) -> int:
"""`int` Number of grids, i.e. steps[0] * steps[1] * ... * steps[D]"""
return self.steps.prod().item()
@property
def voxel_size(self) -> torch.Tensor:
"""`Tensor(3)` Voxel size"""
return (self.bbox[1] - self.bbox[0]) / self.steps
@property
def device(self) -> torch.device:
return self.voxels.device
def corner_embeddings(self) -> Dict[str, torch.nn.Embedding]:
return {name[4:]: emb for name, emb in self.named_modules() if name.startswith("emb_")}
def __init__(self, *, voxel_size: float = None,
steps: Union[torch.Tensor, Tuple[int, int, int]] = None, **kwargs) -> None:
super().__init__(**kwargs)
if self.bbox is None:
raise ValueError("Missing argument 'bbox'")
if voxel_size is not None:
self.register_buffer('steps', get_grid_steps(self.bbox, voxel_size))
@property
def voxel_embeddings(self) -> Dict[str, torch.nn.Embedding]:
return {name[5:]: emb for name, emb in self.named_modules() if name.startswith("vemb_")}
def __init__(self, clone_src: "Voxels" = None, *, bbox: List[float] = None,
voxel_size: float = None, steps: Union[torch.Tensor, Tuple[int, ...]] = None,
**kwargs) -> None:
super().__init__(clone_src, bbox=bbox, **kwargs)
if clone_src:
self.register_buffer('steps', clone_src.steps)
self.register_buffer('voxels', clone_src.voxels)
self.register_buffer("corners", clone_src.corners)
self.register_buffer("corner_indices", clone_src.corner_indices)
self.register_buffer('voxel_indices_in_grid', clone_src.voxel_indices_in_grid)
else:
self.register_buffer('steps', torch.tensor(steps, dtype=torch.long))
self.register_buffer('voxels', init_voxels(self.bbox, self.steps))
corners, corner_indices = get_corners(self.voxels, self.bbox, self.steps)
self.register_buffer("corners", corners)
self.register_buffer("corner_indices", corner_indices)
self.register_buffer('voxel_indices_in_grid', torch.arange(self.n_voxels))
self._register_load_state_dict_pre_hook(self._before_load_state_dict)
if self.bbox is None:
raise ValueError("Missing argument 'bbox'")
if voxel_size is not None:
self.register_buffer('steps', get_grid_steps(self.bbox, voxel_size))
else:
self.register_buffer('steps', torch.tensor(steps, dtype=torch.long))
self.register_buffer('voxels', init_voxels(self.bbox, self.steps))
corners, corner_indices = get_corners(self.voxels, self.bbox, self.steps)
self.register_buffer("corners", corners)
self.register_buffer("corner_indices", corner_indices)
self.register_buffer('voxel_indices_in_grid', torch.arange(-1, self.n_voxels))
def clone(self):
return Voxels(self)
def to_vi(self, gi: torch.Tensor) -> torch.Tensor:
return self.voxel_indices_in_grid[gi + 1]
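# Illustrative note (not in the original source): voxel_indices_in_grid has
# length n_grids + 1 with element 0 fixed at -1, so an out-of-bound grid index
# of -1 lands on that sentinel and yields voxel index -1, while a valid grid
# index g reads voxel_indices_in_grid[g + 1].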
def create_embedding(self, n_dims: int, name: str = 'default') -> torch.nn.Embedding:
"""
......@@ -144,7 +180,7 @@ class Voxels(Space):
"""
if self.get_embedding(name) is not None:
raise KeyError(f"Embedding '{name}' already exists")
emb = torch.nn.Embedding(self.n_corners, n_dims, device=self.device)
emb = torch.nn.Embedding(self.n_corners, n_dims).to(self.device)
setattr(self, f'emb_{name}', emb)
return emb
......@@ -152,8 +188,9 @@ class Voxels(Space):
return getattr(self, f'emb_{name}', None)
def set_embedding(self, weight: torch.Tensor, name: str = 'default'):
emb = torch.nn.Embedding(*weight.shape, _weight=weight, device=self.device)
emb = torch.nn.Embedding(*weight.shape, _weight=weight).to(self.device)
setattr(self, f'emb_{name}', emb)
return emb
def extract_embedding(self, pts: torch.Tensor, voxel_indices: torch.Tensor,
name: str = 'default') -> torch.Tensor:
......@@ -173,6 +210,41 @@ class Voxels(Space):
p = (pts - voxels) / self.voxel_size + .5 # (N, 3) normed-coords in voxel
return trilinear_interp(p, emb(corner_indices))
def create_voxel_embedding(self, n_dims: int, name: str = 'default') -> torch.nn.Embedding:
"""
Create an embedding on voxels.
:param name `str`: embedding name
:param n_dims `int`: embedding dimension
:return `Embedding(n_voxels, n_dims)`: new embedding on voxels
"""
if self.get_voxel_embedding(name) is not None:
raise KeyError(f"Embedding '{name}' already exists")
emb = torch.nn.Embedding(self.n_voxels, n_dims).to(self.device)
setattr(self, f'vemb_{name}', emb)
return emb
def get_voxel_embedding(self, name: str = 'default') -> torch.nn.Embedding:
return getattr(self, f'vemb_{name}', None)
def set_voxel_embedding(self, weight: torch.Tensor, name: str = 'default'):
emb = torch.nn.Embedding(*weight.shape, _weight=weight).to(self.device)
setattr(self, f'vemb_{name}', emb)
return emb
def extract_voxel_embedding(self, voxel_indices: torch.Tensor, name: str = 'default') -> torch.Tensor:
"""
Extract embedding values at given voxels.
:param voxel_indices `Tensor(N)`: voxel indices
:param name `str`: embedding name, default to 'default'
:return `Tensor(N, X)`: extracted values
"""
emb = self.get_voxel_embedding(name)
if emb is None:
raise KeyError(f"Embedding '{name}' doesn't exist")
return emb(voxel_indices)
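# Hypothetical usage of the voxel-embedding API above:
#   emb = space.create_voxel_embedding(16)         # Embedding(n_voxels, 16)
#   feats = space.extract_voxel_embedding(vidxs)   # Tensor(N, 16)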
@perf
def ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> Intersections:
"""
......@@ -192,8 +264,8 @@ class Voxels(Space):
hits = n_max_hits - invalid_voxel_mask.sum(-1)
# Sort intersections according to their depths
min_depths.masked_fill_(invalid_voxel_mask, HUGE_FLOAT)
max_depths.masked_fill_(invalid_voxel_mask, HUGE_FLOAT)
min_depths.masked_fill_(invalid_voxel_mask, math.huge)
max_depths.masked_fill_(invalid_voxel_mask, math.huge)
min_depths, sorted_idx = min_depths.sort(dim=-1)
max_depths = max_depths.gather(-1, sorted_idx)
voxel_indices = voxel_indices.gather(-1, sorted_idx)
......@@ -215,52 +287,102 @@ class Voxels(Space):
:param pts `Tensor(N..., 3)`: points
:return `Tensor(N...)`: corresponding voxel indices
"""
grid_indices, out_mask = to_grid_indices(pts, self.bbox, steps=self.steps)
grid_indices[out_mask] = 0
voxel_indices = self.voxel_indices_in_grid[grid_indices]
voxel_indices[out_mask] = -1
return voxel_indices
gi = to_grid_indices(pts, self.bbox, self.steps)
return self.to_vi(gi)
@perf
def get_corners(self, vidxs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
vidxs = vidxs.unique()
if vidxs[0] == -1:
vidxs = vidxs[1:]
cidxs = self.corner_indices[vidxs].unique()
fi_cidxs = torch.full([self.n_corners], -1, dtype=torch.long, device=self.device)
fi_cidxs[cidxs] = torch.arange(cidxs.shape[0], device=self.device)
fi_corner_indices = fi_cidxs[self.corner_indices]
fi_corners = self.corners[cidxs]
return fi_corner_indices, fi_corners
@torch.no_grad()
def split(self) -> None:
def split(self) -> Tuple[int, int]:
"""
Split each voxel into half-size children (8 per voxel in 3D).
"""
# Calculate new voxels and corners
new_steps = self.steps * 2
new_voxels = split_voxels(self.voxels, self.voxel_size, 2, align_border=False)\
.reshape(-1, 3)
new_corners, new_corner_indices = get_corners(new_voxels, self.bbox, new_steps)
# Calculate new embeddings through trilinear interpolation
grid_indices_of_new_corners = to_flat_indices(
to_grid_coords(new_corners, self.bbox, steps=self.steps).min(self.steps - 1),
self.steps)
voxel_indices_of_new_corners = self.voxel_indices_in_grid[grid_indices_of_new_corners]
for name, _ in self.named_modules():
if not name.startswith("emb_"):
continue
new_emb_weight = self.extract_embedding(new_corners, voxel_indices_of_new_corners,
name=name[4:])
self.set_embedding(new_emb_weight, name=name[4:])
# Split corner embeddings through interpolation
corner_embs = self.corner_embeddings
if len(corner_embs) > 0:
gi_of_new_corners = to_grid_indices(new_corners, self.bbox, self.steps)
vi_of_new_corners = self.to_vi(gi_of_new_corners)
for name, emb in corner_embs.items():
new_emb_weight = self.extract_embedding(new_corners, vi_of_new_corners, name=name)
self.set_embedding(new_emb_weight, name=name)
# Remove old embedding weight and related state from optimizer
self._update_optimizer(emb.weight)
# Split voxel embeddings
self._update_voxel_embeddings(lambda val: torch.repeat_interleave(val, 8, dim=0))
# Apply new tensors
self.steps = new_steps
self.voxels = new_voxels
self.corners = new_corners
self.corner_indices = new_corner_indices
self._update_voxel_indices_in_grid()
self._update_gi2vi()
return self.n_voxels // 8, self.n_voxels
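# Illustrative usage: each voxel is replaced by 8 half-size children, so
#   n_before, n_after = space.split()   # n_after == 8 * n_before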
@torch.no_grad()
def prune(self, keeps: torch.Tensor) -> Tuple[int, int]:
self.voxels = self.voxels[keeps]
self.corner_indices = self.corner_indices[keeps]
self._update_voxel_indices_in_grid()
self._update_gi2vi()
# Prune voxel embeddings
self._update_voxel_embeddings(lambda val: val[keeps])
return keeps.size(0), keeps.sum().item()
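# Hypothetical usage: `keeps` is a Bool Tensor(M) over the current voxels, e.g.
#   n_before, n_kept = space.prune(per_voxel_score > threshold)
# where `per_voxel_score` is any importance measure computed by the caller.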
def _update_voxel_embeddings(self, update_fn):
for name, emb in self.voxel_embeddings.items():
new_emb = self.set_voxel_embedding(update_fn(emb.weight), name)
self._update_optimizer(emb.weight, new_emb.weight, update_fn)
def _update_optimizer(self, old_param: nn.Parameter, new_param: nn.Parameter, update_fn):
optimizer = get_env()["trainer"].optimizer
if isinstance(optimizer, (torch.optim.Adam, torch.optim.AdamW)):
# Update related states in optimizer
if old_param in optimizer.state:
if new_param is not None:
# Transfer state from old parameter to new parameter
state = optimizer.state[old_param]
state.update({
key: update_fn(state[key])
for key in ['exp_avg', 'exp_avg_sq', 'max_exp_avg_sq'] if key in state
})
optimizer.state[new_param] = state
# Remove state of old parameter
optimizer.state.pop(old_param)
# Update parameter list in optimizer
for group in optimizer.param_groups:
try:
if new_param is not None:
# Replace old parameter with new one
idx = group['params'].index(old_param)
group['params'][idx] = new_param
else:
# Or just remove old parameter if new parameter is not specified
group['params'].remove(old_param)
except Exception:
pass
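# Note (illustrative): transplanting exp_avg / exp_avg_sq (and max_exp_avg_sq when
# amsgrad is enabled) keeps Adam's moment estimates aligned with the resized
# parameter instead of restarting them from zero after a split or prune.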
def n_voxels_along_dim(self, dim: int) -> torch.Tensor:
sum_dims = [val for val in range(self.dims) if val != dim]
return self.voxel_indices_in_grid.reshape(*self.steps).ne(-1).sum(sum_dims)
return self.voxel_indices_in_grid[1:].reshape(*self.steps).ne(-1).sum(sum_dims)
def balance_cut(self, dim: int, n_parts: int) -> List[int]:
n_voxels_list = self.n_voxels_along_dim(dim)
......@@ -269,10 +391,11 @@ class Voxels(Space):
part = 1
offset = 0
for i in range(len(cdf)):
if cdf[i] >= part:
bins.append(i + 1 - offset)
offset = i + 1
if cdf[i] > part:
bins.append(i - offset)
offset = i
part = int(cdf[i]) + 1
bins.append(len(cdf) - offset)
return bins
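# Illustrative example (assuming, per the surrounding code, that cdf is the
# cumulative per-slice voxel count normalized to end at n_parts): with counts
# [4, 4, 8, 8, 8] and n_parts=2, cdf = [0.25, 0.5, 1.0, 1.5, 2.0]; the first
# index with cdf > 1 is i=3, so the method returns bins [3, 2].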
def sample(self, S: int, perturb: bool = False, include_border: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
......@@ -299,17 +422,16 @@ class Voxels(Space):
def _ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return aabb_ray_intersect(self.voxel_size, n_max_hits, self.voxels, rays_o, rays_d)
def _update_voxel_indices_in_grid(self):
def _update_gi2vi(self):
"""
Rebuild the grid-index to voxel-index (gi -> vi) lookup table.
"""
grid_indices, _ = to_grid_indices(self.voxels, self.bbox, steps=self.steps)
self.voxel_indices_in_grid = grid_indices.new_full([self.steps.prod().item()], -1)
self.voxel_indices_in_grid[grid_indices] = torch.arange(self.n_voxels, device=self.device)
gi = to_grid_indices(self.voxels, self.bbox, self.steps)
# Reserve the first element of voxel_indices_in_grid for the 'invalid voxel' (-1)
self.voxel_indices_in_grid = gi.new_full([self.n_grids + 1], -1)
self.voxel_indices_in_grid[gi + 1] = torch.arange(self.n_voxels, device=self.device)
@torch.no_grad()
def _before_load_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys,
unexpected_keys, error_msgs):
def _before_load_state_dict(self, state_dict, prefix, *args):
# Handle buffers
for name, buffer in self.named_buffers(recurse=False):
if name in self._non_persistent_buffers_set:
......@@ -320,12 +442,17 @@ class Voxels(Space):
for name, module in self.named_modules():
if name.startswith('emb_'):
setattr(self, name, torch.nn.Embedding(self.n_corners, module.embedding_dim))
if name.startswith('vemb_'):
setattr(self, name, torch.nn.Embedding(self.n_voxels, module.embedding_dim))
def _after_load_state_dict(self):
self._update_gi2vi()
class Octree(Voxels):
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
def __init__(self, clone_src: "Octree" = None, **kwargs) -> None:
super().__init__(clone_src, **kwargs)
self.nodes_cached = None
self.tree_cached = None
......
MIT License
Copyright (c) 2020 bmild
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
# NeRF: Neural Radiance Fields
### [Project Page](http://tancik.com/nerf) | [Video](https://youtu.be/JuH79E8rdKc) | [Paper](https://arxiv.org/abs/2003.08934) | [Data](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1)
[![Open Tiny-NeRF in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bmild/nerf/blob/master/tiny_nerf.ipynb)<br>
Tensorflow implementation of optimizing a neural representation for a single scene and rendering new views.<br><br>
[NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis](http://tancik.com/nerf)
[Ben Mildenhall](https://people.eecs.berkeley.edu/~bmild/)\*<sup>1</sup>,
[Pratul P. Srinivasan](https://people.eecs.berkeley.edu/~pratul/)\*<sup>1</sup>,
[Matthew Tancik](http://tancik.com/)\*<sup>1</sup>,
[Jonathan T. Barron](http://jonbarron.info/)<sup>2</sup>,
[Ravi Ramamoorthi](http://cseweb.ucsd.edu/~ravir/)<sup>3</sup>,
[Ren Ng](https://www2.eecs.berkeley.edu/Faculty/Homepages/yirenng.html)<sup>1</sup> <br>
<sup>1</sup>UC Berkeley, <sup>2</sup>Google Research, <sup>3</sup>UC San Diego
\*denotes equal contribution
in ECCV 2020 (Oral Presentation, Best Paper Honorable Mention)
<img src='imgs/pipeline.jpg'/>
## TL;DR quickstart
To set up a conda environment, download example training data, begin the training process, and launch Tensorboard:
```
conda env create -f environment.yml
conda activate nerf
bash download_example_data.sh
python run_nerf.py --config config_fern.txt
tensorboard --logdir=logs/summaries --port=6006
```
If everything works without errors, you can now go to `localhost:6006` in your browser and watch the "Fern" scene train.
## Setup
Python 3 dependencies:
* Tensorflow 1.15
* matplotlib
* numpy
* imageio
* configargparse
The LLFF data loader requires ImageMagick.
We provide a conda environment setup file including all of the above dependencies. Create the conda environment `nerf` by running:
```
conda env create -f environment.yml
```
You will also need the [LLFF code](http://github.com/fyusion/llff) (and COLMAP) set up to compute poses if you want to run on your own real data.
## What is a NeRF?
A neural radiance field is a simple fully connected network (weights are ~5MB) trained to reproduce input views of a single scene using a rendering loss. The network directly maps from spatial location and viewing direction (5D input) to color and opacity (4D output), acting as the "volume" so we can use volume rendering to differentiably render new views.
Optimizing a NeRF takes between a few hours and a day or two (depending on resolution) and only requires a single GPU. Rendering an image from an optimized NeRF takes somewhere between less than a second and ~30 seconds, again depending on resolution.
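The rendering step reduces to a simple quadrature along each ray. As a rough sketch (PyTorch-style and purely illustrative; this repo's actual implementation is in TensorFlow), compositing per-sample densities and colors into a pixel color looks like:
```
import torch

def composite(sigma, rgb, z):
    """sigma: Tensor(N, P) densities; rgb: Tensor(N, P, 3) colors; z: Tensor(N, P) depths."""
    # distance between adjacent samples; pad the last interval with a large value
    dists = torch.cat([z[..., 1:] - z[..., :-1],
                       torch.full_like(z[..., :1], 1e10)], dim=-1)
    alpha = 1.0 - torch.exp(-sigma * dists)  # per-sample opacity
    # transmittance: probability that the ray reaches each sample unoccluded
    trans = torch.cumprod(torch.cat([torch.ones_like(alpha[..., :1]),
                                     1.0 - alpha + 1e-10], dim=-1), dim=-1)[..., :-1]
    weights = alpha * trans  # compositing weights, also reused for hierarchical sampling
    return (weights[..., None] * rgb).sum(dim=-2)  # Tensor(N, 3) pixel colors
```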
## Running code
Here we show how to run our code on two example scenes. You can download the rest of the synthetic and real data used in the paper [here](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1).
### Optimizing a NeRF
Run
```
bash download_example_data.sh
```
to get our synthetic Lego dataset and the LLFF Fern dataset.
To optimize a low-res Fern NeRF:
```
python run_nerf.py --config config_fern.txt
```
After 200k iterations (about 15 hours), you should get a video like this at `logs/fern_test/fern_test_spiral_200000_rgb.mp4`:
![ferngif](https://people.eecs.berkeley.edu/~bmild/nerf/fern_200k_256w.gif)
To optimize a low-res Lego NeRF:
```
python run_nerf.py --config config_lego.txt
```
After 200k iterations, you should get a video like this:
![legogif](https://people.eecs.berkeley.edu/~bmild/nerf/lego_200k_256w.gif)
### Rendering a NeRF
Run
```
bash download_example_weights.sh
```
to get a pretrained high-res NeRF for the Fern dataset. Now you can use [`render_demo.ipynb`](https://github.com/bmild/nerf/blob/master/render_demo.ipynb) to render new views.
### Replicating the paper results
The example config files run at lower resolutions than the quantitative/qualitative results in the paper and video. To replicate the results from the paper, start with the config files in [`paper_configs/`](https://github.com/bmild/nerf/tree/master/paper_configs). Our synthetic Blender data and LLFF scenes are hosted [here](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1) and the DeepVoxels data is hosted by Vincent Sitzmann [here](https://drive.google.com/open?id=1lUvJWB6oFtT8EQ_NzBrXnmi25BufxRfl).
### Extracting geometry from a NeRF
Check out [`extract_mesh.ipynb`](https://github.com/bmild/nerf/blob/master/extract_mesh.ipynb) for an example of running marching cubes to extract a triangle mesh from a trained NeRF network. You'll need to install the [PyMCubes](https://github.com/pmneila/PyMCubes) package for marching cubes, plus the [trimesh](https://github.com/mikedh/trimesh) and [pyrender](https://github.com/mmatl/pyrender) packages if you want to render the mesh inside the notebook:
```
pip install trimesh pyrender PyMCubes
```
## Generating poses for your own scenes
### Don't have poses?
We recommend using the `imgs2poses.py` script from the [LLFF code](https://github.com/fyusion/llff). Then you can pass the base scene directory into our code using `--datadir <myscene>` along with `--dataset_type llff`. You can take a look at the `config_fern.txt` config file for example settings to use for a forward-facing scene. For a spherically captured 360 scene, we recommend adding the `--no_ndc --spherify --lindisp` flags. A sketch of the typical flow follows below.
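The commands below are illustrative (paths are placeholders, `imgs2poses.py` lives in the LLFF repo, and the exact flags may differ):
```
python imgs2poses.py <myscene>
python run_nerf.py --config config_fern.txt --datadir <myscene> --dataset_type llff
```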
### Already have poses!
In `run_nerf.py` and all other code, we use the same pose coordinate system as in OpenGL: the local camera coordinate system of an image is defined such that the X axis points to the right, the Y axis upward, and the Z axis backward, as seen from the image.
Poses are stored as 3x4 numpy arrays that represent camera-to-world transformation matrices. The other data you will need is simple pinhole camera intrinsics (`hwf = [height, width, focal length]`) and near/far scene bounds. Take a look at [our data loading code](https://github.com/bmild/nerf/blob/master/run_nerf.py#L406) to see more.
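As a concrete sketch (NumPy-style and illustrative only; `get_rays` is our name for it, not necessarily the repo's), turning one such pose plus `hwf` into per-pixel rays under the convention above:
```
import numpy as np

def get_rays(c2w, hwf):
    H, W, focal = int(hwf[0]), int(hwf[1]), hwf[2]
    i, j = np.meshgrid(np.arange(W, dtype=np.float32),
                       np.arange(H, dtype=np.float32), indexing='xy')
    # pixel -> camera-space direction (X right, Y up, Z backward)
    dirs = np.stack([(i - W * .5) / focal,
                     -(j - H * .5) / focal,
                     -np.ones_like(i)], axis=-1)
    rays_d = dirs @ c2w[:3, :3].T                       # rotate into world space
    rays_o = np.broadcast_to(c2w[:3, 3], rays_d.shape)  # all rays share the camera origin
    return rays_o, rays_d
```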
## Citation
```
@inproceedings{mildenhall2020nerf,
title={NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis},
author={Ben Mildenhall and Pratul P. Srinivasan and Matthew Tancik and Jonathan T. Barron and Ravi Ramamoorthi and Ren Ng},
year={2020},
booktitle={ECCV},
}
```
conda create --name nerf python=3.7
conda activate nerf
conda install cudatoolkit=10.0
conda install tensorflow-gpu=1.15
conda install numpy
conda install matplotlib
conda install imageio
conda install imageio-ffmpeg
conda install configargparse
\ No newline at end of file
expname = bar0514
basedir = ./logs
datadir = ./data/barbershop_2021.05.04
dataset_type = llff
factor = 1
llffhold = 20
N_rand = 1024
N_samples = 64
N_importance = 64
use_viewdirs = True
raw_noise_std = 1e0
expname = bedroom_nerf_2021.01.18
basedir = ./logs
datadir = ./data/bedroom_nerf_2021.01.18
dataset_type = llff
factor = 1
llffhold = 8
N_rand = 1024
N_samples = 16
N_importance = 0
netdepth = 4
netwidth = 128
use_viewdirs = True
raw_noise_std = 1e0
expname = gallery_nerf_2021.01.20
basedir = ./logs
datadir = ./data/gallery_nerf_2021.01.20
dataset_type = llff
factor = 1
llffhold = 8
N_rand = 1024
N_samples = 16
N_importance = 0
netdepth = 4
netwidth = 128
use_viewdirs = True
raw_noise_std = 1e0
expname = gas_nerf_2021.01.17
basedir = ./logs
datadir = ./data/gas_nerf_2021.01.17
dataset_type = llff
factor = 1
llffhold = 8
N_rand = 1024
N_samples = 16
N_importance = 0
netdepth = 4
netwidth = 128
use_viewdirs = True
raw_noise_std = 1e0
expname = lego_test
basedir = ./logs
datadir = ./data/nerf_synthetic/lego
dataset_type = blender
half_res = True
no_batching = True
N_samples = 64
N_importance = 64
use_viewdirs = True
white_bkgd = True
N_rand = 1024
\ No newline at end of file
expname = lobby_nerf_2021.01.20
basedir = ./logs
datadir = ./data/lobby_nerf_2021.01.20
dataset_type = llff
factor = 1
llffhold = 8
N_rand = 1024
N_samples = 16
N_importance = 0
netdepth = 4
netwidth = 128
use_viewdirs = True
raw_noise_std = 1e0