Commit 5699ccbf authored by Nianchen Deng

sync

parent 338ae906
from typing import Tuple
import torch
import torch.nn as nn
from torch.nn.modules.linear import Identity
from utils.constants import *
from .generic import *
from .sampler import *
from .input_encoder import *
from .renderer import *
class NerfCore(nn.Module):
def __init__(self, *, coord_chns, density_chns, color_chns, core_nf, core_layers,
dir_chns=0, dir_nf=0, activation='relu', skips=[]):
super().__init__()
self.core = FcNet(in_chns=coord_chns, out_chns=0, nf=core_nf, n_layers=core_layers,
skips=skips, activation=activation)
self.density_out = FcLayer(core_nf, density_chns) if density_chns > 0 else None
if color_chns == 0:
self.feature_out = None
self.color_out = None
elif dir_chns > 0:
self.feature_out = FcLayer(core_nf, core_nf)
self.color_out = nn.Sequential(
FcLayer(core_nf + dir_chns, dir_nf, activation),
FcLayer(dir_nf, color_chns)
)
else:
self.feature_out = Identity()
self.color_out = FcLayer(core_nf, color_chns)
def forward(self, coord: torch.Tensor, dir: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
core_output = self.core(coord)
density = self.density_out(core_output) if self.density_out is not None else None
if self.color_out is None:
color = None
else:
feature = self.feature_out(core_output)
if dir is not None:
feature = torch.cat([feature, dir], dim=-1)
color = torch.sigmoid(self.color_out(feature))
return color, density
\ No newline at end of file
from .space import *
from .core import *
\ No newline at end of file
from .generic import *
from typing import Dict
class NerfCore(nn.Module):
def __init__(self, *, coord_chns, density_chns, color_chns, core_nf, core_layers,
dir_chns=0, dir_nf=0, act='relu', skips=[]):
super().__init__()
self.core = FcNet(in_chns=coord_chns, out_chns=None, nf=core_nf, n_layers=core_layers,
skips=skips, act=act)
self.density_out = FcLayer(core_nf, density_chns) if density_chns > 0 else None
if color_chns == 0:
self.feature_out = None
self.color_out = None
elif dir_chns > 0:
self.feature_out = FcLayer(core_nf, core_nf)
self.color_out = nn.Sequential(
FcLayer(core_nf + dir_chns, dir_nf, act),
FcLayer(dir_nf, color_chns)
)
else:
self.feature_out = torch.nn.Identity()
self.color_out = FcLayer(core_nf, color_chns)
def forward(self, x: torch.Tensor, d: torch.Tensor, outputs: List[str]) -> Dict[str, torch.Tensor]:
ret = {}
core_output = self.core(x)
if 'density' in outputs:
ret['density'] = torch.relu(self.density_out(core_output)) \
if self.density_out is not None else None
if 'color' in outputs:
if self.color_out is None:
ret['color'] = None
else:
feature = self.feature_out(core_output)
if d is not None:
feature = torch.cat([feature, d], dim=-1)
ret['color'] = self.color_out(feature).sigmoid()
for key in outputs:
if key == 'density' or key == 'color':
continue
ret[key] = None
return ret
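# A minimal usage sketch (hedged: FcNet/FcLayer come from .generic; the
# 63-dim position and 27-dim direction codes assume a classic NeRF-style
# input encoding):
# core = NerfCore(coord_chns=63, density_chns=1, color_chns=3,
#                 core_nf=256, core_layers=8, dir_chns=27, dir_nf=128,
#                 act='relu', skips=[4])
# out = core(torch.rand(1024, 63), torch.rand(1024, 27), ['density', 'color'])
# out['density']: (1024, 1) after relu; out['color']: (1024, 3) after sigmoid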
class NerfAdvCore(nn.Module):
def __init__(self, *, x_chns: int, d_chns: int, density_chns: int, color_chns: int,
density_net_params: dict, color_net_params: dict,
specular_net_params: dict = None,
appearance="decomposite",
density_color_connection=False):
"""
Create a NeRF-Adv Core Net.
Required parameters for the sub-mlps include: "nf", "n_layers", "skips" and "act".
Other parameters will be properly set automatically.
:param x_chns `int`: the channels of input "position"
:param d_chns `int`: the channels of input "direction"
:param density_chns `int`: the channels of output "density"
:param color_chns `int`: the channels of output "color"
:param density_net_params `dict`: parameters for the density net
:param color_net_params `dict`: parameters for the color net
:param specular_net_params `dict`: (optional) parameters for the optional specular net, defaults to None
:param appearance `str`: (optional) options are [decomposite|combined|newtype], defaults to "decomposite"
:param density_color_connection `bool`: (optional) whether to add connections between
density net and color net, defaults to False
"""
super().__init__()
self.density_chns = density_chns
self.color_chns = color_chns
self.specular_feature_chns = color_net_params["nf"] if specular_net_params else 0
self.color_feature_chns = density_net_params["nf"] if density_color_connection else 0
self.appearance = appearance
self.density_color_connection = density_color_connection
self.density_net = FcNet(**density_net_params,
in_chns=x_chns,
out_chns=self.density_chns + self.color_feature_chns,
out_act='relu')
if self.appearance == "newtype":
self.specular_feature_chns = d_chns * 3
self.color_net = FcNet(**color_net_params,
in_chns=x_chns + self.color_feature_chns,
out_chns=self.color_chns + self.specular_feature_chns)
self.specular_net = "Placeholder"
else:
if self.appearance == "decomposite":
self.color_net = FcNet(**color_net_params,
in_chns=x_chns + self.color_feature_chns,
out_chns=self.color_chns + self.specular_feature_chns)
else:
if specular_net_params:
self.color_net = FcNet(**color_net_params,
in_chns=x_chns + self.color_feature_chns,
out_chns=self.specular_feature_chns)
else:
self.color_net = FcNet(**color_net_params,
in_chns=x_chns + d_chns + self.color_feature_chns,
out_chns=self.color_chns)
self.specular_net = FcNet(**specular_net_params,
in_chns=d_chns + self.specular_feature_chns,
out_chns=self.color_chns) if specular_net_params else None
def forward(self, x: torch.Tensor, d: torch.Tensor, outputs: List[str], *,
color_feats: torch.Tensor = None) -> Dict[str, torch.Tensor]:
input_shape = x.shape[:-1]
if len(input_shape) > 1:
x = x.flatten(0, -2)
d = d.flatten(0, -2)
n = x.shape[0]
c = self.color_chns
ret: Dict[str, torch.Tensor] = {}
if 'density' in outputs:
density_net_out: torch.Tensor = self.density_net(x)
ret['density'] = density_net_out[:, :self.density_chns]
color_feats = density_net_out[:, self.density_chns:]
if 'color_feat' in outputs:
ret['color_feat'] = color_feats
if 'color' in outputs or 'specular' in outputs:
if 'density' in ret:
valid_mask = ret['density'][:, 0].detach() >= 1e-4
indices = valid_mask.nonzero()[:, 0]
x, d, color_feats = x[indices], d[indices], color_feats[indices]
else:
indices = None
speculars = None
color_net_in = [x]
if not self.specular_net:
color_net_in.append(d)
if self.density_color_connection:
color_net_in.append(color_feats)
color_net_in = torch.cat(color_net_in, -1)
color_net_out: torch.Tensor = self.color_net(color_net_in)
diffuses = color_net_out[:, :c]
specular_features = color_net_out[:, -self.specular_feature_chns:]
if self.appearance == "newtype":
speculars = torch.bmm(specular_features.reshape(-1, 3, d.shape[-1]),
d[..., None])[..., 0]
# TODO relu or not?
diffuses = diffuses.relu()
speculars = speculars.relu()
colors = diffuses + speculars
else:
if not self.specular_net:
colors = diffuses
diffuses = None
else:
specular_net_in = torch.cat([d, specular_features], -1)
specular_net_out = self.specular_net(specular_net_in)
if self.appearance == "decomposite":
speculars = specular_net_out
colors = diffuses + speculars
else:
diffuses = None
colors = specular_net_out
colors = torch.sigmoid(colors) # TODO indent or not?
if 'color' in outputs:
ret['color'] = colors.new_zeros(n, c).index_copy(0, indices, colors) \
if indices is not None else colors
if 'diffuse' in outputs:
ret['diffuse'] = diffuses.new_zeros(n, c).index_copy(0, indices, diffuses) \
if indices is not None and diffuses is not None else diffuses
if 'specular' in outputs:
ret['specular'] = speculars.new_zeros(n, c).index_copy(0, indices, speculars) \
if indices is not None and speculars is not None else speculars
if len(input_shape) > 1:
ret = {key: val.reshape(*input_shape, -1) for key, val in ret.items()}
return ret
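# Usage sketch (hypothetical sub-MLP parameters; "nf", "n_layers", "skips"
# and "act" are the required keys named in the docstring above):
# adv = NerfAdvCore(
#     x_chns=63, d_chns=27, density_chns=1, color_chns=3,
#     density_net_params={"nf": 256, "n_layers": 8, "skips": [4], "act": "relu"},
#     color_net_params={"nf": 256, "n_layers": 4, "skips": [], "act": "relu"},
#     specular_net_params={"nf": 128, "n_layers": 2, "skips": [], "act": "relu"})
# out = adv(torch.rand(4, 1024, 63), torch.rand(4, 1024, 27),
#           ['density', 'color', 'diffuse', 'specular'])
# Batched inputs are flattened internally; outputs are reshaped to (4, 1024, -1).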
......@@ -34,7 +34,7 @@ class Sine(nn.Module):
class FcLayer(nn.Module):
def __init__(self, in_chns: int, out_chns: int, activation: str = 'linear', skip_chns: int = 0):
def __init__(self, in_chns: int, out_chns: int, act: str = 'linear', skip_chns: int = 0):
super().__init__()
nls_and_inits = {
'sine': (Sine(), sine_init),
......@@ -48,7 +48,7 @@ class FcLayer(nn.Module):
'logsoftmax': (nn.LogSoftmax(dim=-1), softmax_init),
'linear': (None, None)
}
nl, nl_weight_init = nls_and_inits[activation]
nl, nl_weight_init = nls_and_inits[act]
self.net = nn.Sequential(
nn.Linear(in_chns + skip_chns, out_chns),
......@@ -59,7 +59,7 @@ class FcLayer(nn.Module):
if nl_weight_init is not None:
nl_weight_init(self.net if isinstance(self.net, nn.Linear) else self.net[0])
else:
self.init_params(activation)
self.init_params(act)
def forward(self, x: torch.Tensor, x0: torch.Tensor = None) -> torch.Tensor:
return self.net(torch.cat([x0, x], dim=-1) if self.skip else x)
......@@ -68,9 +68,9 @@ class FcLayer(nn.Module):
linear_net = self.net if isinstance(self.net, nn.Linear) else self.net[0]
return linear_net.weight, linear_net.bias
def init_params(self, activation):
def init_params(self, act):
weight, bias = self.get_params()
nn.init.xavier_normal_(weight, gain=nn.init.calculate_gain(activation))
nn.init.xavier_normal_(weight, gain=nn.init.calculate_gain(act))
nn.init.zeros_(bias)
def copy_to(self, layer):
......@@ -83,7 +83,7 @@ class FcLayer(nn.Module):
class FcNet(nn.Module):
def __init__(self, *, in_chns: int, out_chns: int, nf: int, n_layers: int,
skips: List[int] = [], activation: str = 'relu'):
skips: List[int] = [], act: str = 'relu', out_act = 'linear'):
"""
Initialize a fully-connected network
......@@ -95,12 +95,12 @@ class FcNet(nn.Module):
"""
super().__init__()
self.layers = [FcLayer(in_chns, nf, activation)] + [
FcLayer(nf, nf, activation, skip_chns=in_chns if i in skips else 0)
self.layers = [FcLayer(in_chns, nf, act)] + [
FcLayer(nf, nf, act, skip_chns=in_chns if i in skips else 0)
for i in range(n_layers - 1)
]
if out_chns > 0:
self.layers.append(FcLayer(nf, out_chns))
if out_chns:
self.layers.append(FcLayer(nf, out_chns, out_act))
for i, layer in enumerate(self.layers):
self.add_module(f"layer{i}", layer)
......
from itertools import cycle
from math import ceil
from typing import Dict, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as nn_f
from utils.constants import *
from utils.perf import perf
from .generic import *
from .sampler import Samples
def density2energy(densities: torch.Tensor, dists: torch.Tensor, raw_noise_std: float = 0):
"""
Calculate energies from densities inferred by model.
:param densities `Tensor(N..., 1)`: model's output densities
:param dists `Tensor(N...)`: integration times
:param raw_noise_std `float`: the noise std used to regularize the network during training (prevents
floater artifacts), defaults to 0, meaning no noise is added
:return `Tensor(N..., 1)`: energies which block light rays
"""
if raw_noise_std > 0:
# Add noise to model's predictions for density. Can be used to
# regularize network during training (prevents floater artifacts).
densities = densities + torch.normal(0.0, raw_noise_std, densities.size())
return densities * dists[..., None]
def density2alpha(densities: torch.Tensor, dists: torch.Tensor, raw_noise_std: float = 0):
"""
Calculate alphas from densities inferred by model.
:param densities `Tensor(N..., 1)`: model's output densities
:param dists `Tensor(N...)`: integration times
:param raw_noise_std `float`: the noise std used to regularize the network during training (prevents
floater artifacts), defaults to 0, meaning no noise is added
:return `Tensor(N..., 1)`: alphas
"""
energies = density2energy(densities, dists, raw_noise_std)
return 1.0 - torch.exp(-energies)
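# Numeric sanity check: with density 4.0 and segment length 0.25,
# alpha = 1 - exp(-4.0 * 0.25) = 1 - e^-1 ≈ 0.632, e.g.:
# sigma = torch.full((2, 3, 1), 4.0)  # (N=2 rays, P=3 samples, 1)
# dt = torch.full((2, 3), 0.25)       # integration segment lengths
# density2alpha(sigma, dt)            # ≈ 0.632 everywhere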
class AlphaComposition(nn.Module):
......@@ -11,18 +47,26 @@ class AlphaComposition(nn.Module):
super().__init__()
def forward(self, colors, alphas, bg=None):
"""
Composite the sampled colors along each ray by their alpha values.
:param colors `Tensor(N, P, C)`: colors of samples along rays
:param alphas `Tensor(N, P, 1)`: alpha values of samples along rays
:param bg `Tensor([N, ]C)`: background color, defaults to None
:return `Tensor(N, C)`: composited colors of rays
"""
# Compute weight for RGB of each sample along each ray. A cumprod() is
# used to express the idea of the ray not having reflected up to this
# sample yet.
one_minus_alpha = torch.cumprod(1 - alphas[..., :-1, :] + TINY_FLOAT, dim=-2)
one_minus_alpha = torch.cat([
torch.ones_like(one_minus_alpha[..., :1, :]),
one_minus_alpha
], dim=-2)
weights = alphas * one_minus_alpha  # (N, P, 1)
# (N, C), computed weighted color of each sample along each ray.
final_color = torch.sum(weights * colors, dim=-2)
# To composite onto a white background, use the accumulated alpha map.
if bg is not None:
......@@ -38,58 +82,290 @@ class AlphaComposition(nn.Module):
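# Sanity check for the weighting above: with alphas a_i along a ray,
# weight_i = a_i * prod_{j<i}(1 - a_j); e.g. alphas (0.5, 0.5, 1.0) give
# weights (0.5, 0.25, 0.25), which sum to 1 (the ray is fully absorbed).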
class VolumnRenderer(nn.Module):
class States:
kernel: nn.Module
samples: Samples
hit_mask: torch.Tensor
early_stop_tolerance: float
N: int
P: int
colors: torch.Tensor
diffuses: torch.Tensor
speculars: torch.Tensor
energies: torch.Tensor
weights: torch.Tensor
cum_energies: torch.Tensor
exp_energies: torch.Tensor
tot_evaluations: Dict[str, int]
chunk: Tuple[slice, slice]
cum_chunk: Tuple[slice, slice]
cum_last: Tuple[slice, slice]
chunk_id: int
@property
def start(self) -> int:
return self.chunk[1].start
@property
def end(self) -> int:
return self.chunk[1].stop
def __init__(self, kernel: nn.Module, samples: Samples, early_stop_tolerance: float) -> None:
self.kernel = kernel
self.samples = samples
self.early_stop_tolerance = early_stop_tolerance
N, P = samples.size
self.hit_mask = samples.voxel_indices != -1 # (N, P)
self.colors = torch.zeros(N, P, kernel.chns('color'), device=samples.device)
self.diffuses = torch.zeros(N, P, kernel.chns('color'), device=samples.device)
self.speculars = torch.zeros(N, P, kernel.chns('color'), device=samples.device)
self.energies = torch.zeros(N, P, 1, device=samples.device)
self.weights = torch.zeros(N, P, 1, device=samples.device)
self.cum_energies = torch.zeros(N, P + 1, 1, device=samples.device)
self.exp_energies = torch.ones(N, P + 1, 1, device=samples.device)
self.tot_evaluations = {}
self.N, self.P = N, P
self.chunk_id = -1
def n_hits(self, start: int = None, end: int = None) -> int:
if start is None:
return self.hit_mask.count_nonzero().item()
if end is None:
return self.hit_mask[:, start].count_nonzero().item()
return self.hit_mask[:, start:end].count_nonzero().item()
def accumulate_tot_evaluations(self, key: str, n: int):
if key not in self.tot_evaluations:
self.tot_evaluations[key] = 0
self.tot_evaluations[key] += n
def next_chunk(self, *, length=None, end=None):
start = 0 if not hasattr(self, "chunk") else self.end
length = length or self.P
end = min(end or start + length, self.P)
self.chunk = slice(None), slice(start, end)
self.cum_chunk = slice(None), slice(start + 1, end + 1)
self.cum_last = slice(None), slice(start, start + 1)
self.chunk_id += 1
return self
def __init__(self, **kwargs):
super().__init__()
self.alpha_composition = AlphaComposition()
@perf
def forward(self, kernel: nn.Module, samples: Samples, extra_outputs: List[str] = [], *,
raymarching_early_stop_tolerance: float = 0,
raymarching_chunk_size_or_sections: Union[int, List[int]] = None,
**kwargs):
"""
Perform volume rendering.
:param kernel: render kernel
:param samples `Samples(N, P)`: samples
:param extra_outputs `list[str]`: extra items to include in the result dict.
Optional values include 'depth', 'layers', 'states' and attribute names in class `States` (e.g. 'weights'). Defaults to []
:param raymarching_early_stop_tolerance `float`: tolerance for raymarching early stop.
Should be between 0 and 1 (0 means no early stop). Defaults to 0
:param raymarching_chunk_size_or_sections `int|list[int]`: how to split the raymarching process.
Use a list of integers to specify the number of samples in every chunk, or a positive integer to specify the number of chunks.
Use a negative integer to split by the number of hits in chunks, where the absolute value is the maximum number of hits per chunk.
0 and `None` mean no splitting of the raymarching process. Defaults to `None`
:return `dict`: render result { 'color'[, 'depth', 'layers', 'states', ...] }
"""
if samples.size[1] == 0:
print("VolumnRenderer.forward(): # of samples is zero")
return None
s = VolumnRenderer.States(kernel, samples, raymarching_early_stop_tolerance)
if not raymarching_chunk_size_or_sections:
raymarching_chunk_size_or_sections = [s.P]
elif isinstance(raymarching_chunk_size_or_sections, int) and \
raymarching_chunk_size_or_sections > 0:
raymarching_chunk_size_or_sections = [ceil(s.P / raymarching_chunk_size_or_sections)]
if isinstance(raymarching_chunk_size_or_sections, list):
chunk_sections = raymarching_chunk_size_or_sections
for chunk_samples in cycle(chunk_sections):
self._forward_chunk(s.next_chunk(length=chunk_samples))
if s.end >= s.P:
break
else:
chunk_size = -raymarching_chunk_size_or_sections
chunk_hits = s.n_hits(0)
for i in range(1, s.P):
n_hits = s.n_hits(i)
if chunk_hits + n_hits > chunk_size:
self._forward_chunk(s.next_chunk(end=i))
n_hits = s.n_hits(i)
chunk_hits = 0
chunk_hits += n_hits
self._forward_chunk(s.next_chunk())
ret = {
'color': torch.sum(s.colors * s.weights, 1),
'tot_evaluations': s.tot_evaluations
}
for key in extra_outputs:
if key == 'depth':
ret['depth'] = torch.sum(s.samples.depths[..., None] * s.weights, 1)
elif key == 'diffuse':
ret['diffuse'] = torch.sum(s.diffuses * s.weights, 1)
elif key == 'specular':
ret['specular'] = torch.sum(s.speculars * s.weights, 1)
elif key == 'layers':
ret['layers'] = torch.cat([s.colors, 1 - torch.exp(-s.energies)], dim=-1)
elif key == 'states':
ret['states'] = s
else:
ret[key] = getattr(s, key)
return ret
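# The chunking modes accepted above, illustrated (kernel and samples come
# from the surrounding pipeline):
# renderer(kernel, samples, raymarching_chunk_size_or_sections=None)      # single pass
# renderer(kernel, samples, raymarching_chunk_size_or_sections=4)         # ~4 equal chunks
# renderer(kernel, samples, raymarching_chunk_size_or_sections=[32, 64])  # 32/64 samples, cycled
# renderer(kernel, samples, raymarching_chunk_size_or_sections=-4096)     # <=4096 hits per chunk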
# if raymarching_chunk_size == 0:
# raymarching_chunk_samples = 1
# if raymarching_chunk_samples != 0:
# if isinstance(raymarching_chunk_samples, int):
# raymarching_chunk_samples = repeat(raymarching_chunk_samples,
# ceil(s.P / raymarching_chunk_samples))
# chunk_offset = 0
# for chunk_samples in raymarching_chunk_samples:
# start, end = chunk_offset, chunk_offset + chunk_samples
# n_hits = self._forward_chunk(s, start, end)
# if n_hits > 0 and tolerance > 0: # Early stop
# s.hit_mask[s.cum_energies[:, end, 0] > tolerance] = 0
# chunk_offset += chunk_samples
# elif raymarching_chunk_size > 0:
# chunk_offset, chunk_hits = 0, s.n_hits(0)
# for i in range(1, s.P):
# n_hits = s.n_hits(i)
# if chunk_hits + n_hits > raymarching_chunk_size:
# self._forward_chunk(s, chunk_offset, i, chunk_hits)
# if chunk_hits > 0 and tolerance > 0: # Early stop
# s.hit_mask[s.cum_energies[:, i, 0] > tolerance] = 0
# n_hits = s.n_hits(i)
# chunk_hits, chunk_offset = 0, i
# chunk_hits += n_hits
# self._forward_chunk(s, chunk_offset, s.P, chunk_hits)
# else:
# self._forward_chunk(s, 0, s.P)
# return self._composite(s, extra_outputs)
# original_depth = samples.get('original_point_depth', None)
# if original_depth is not None:
# results['z'] = (original_depth * probs).sum(-1)
# if getattr(input_fn, "track_max_probs", False) and (not self.training):
# input_fn.track_voxel_probs(samples['sampled_point_voxel_idx'].long(), results['probs'])
def _calc_weights(self, s: States):
"""
Calculate weights of samples in composited outputs
:param s `States`: render states
"""
s.cum_energies[s.cum_chunk] = torch.cumsum(s.energies[s.chunk], 1) \
+ s.cum_energies[s.cum_last]
s.exp_energies[s.cum_chunk] = (-s.cum_energies[s.cum_chunk]).exp()
s.weights[s.chunk] = s.exp_energies[s.chunk] - s.exp_energies[s.cum_chunk]
def _apply_early_stop(self, s: States):
"""
Stop rays whose accumulated opacity is larger than a threshold
:param s `States`: render states
"""
if s.end < s.P and s.early_stop_tolerance > 0:
rays_to_stop = s.exp_energies[:, s.end, 0] < s.early_stop_tolerance
s.hit_mask[rays_to_stop, s.end:] = 0
def _forward_chunk(self, s: States) -> int:
fi_idxs: Tuple[torch.Tensor, ...] = s.hit_mask[s.chunk].nonzero(as_tuple=True) # (N')
fi_idxs[1].add_(s.start)
if fi_idxs[0].size(0) == 0:
s.cum_energies[s.cum_chunk] = s.cum_energies[s.cum_last]
s.exp_energies[s.cum_chunk] = s.exp_energies[s.cum_last]
return 0
# fi_* means "filtered" by hit mask
fi_samples = s.samples[fi_idxs] # N -> N'
# Infer densities and colors
fi_outputs = s.kernel.render(fi_samples, 'color', 'density', 'specular', 'diffuse',
chunk_id=s.chunk_id)
s.colors.index_put_(fi_idxs, fi_outputs['color'])
if fi_outputs['specular'] is not None:
s.speculars.index_put_(fi_idxs, fi_outputs['specular'])
if fi_outputs['diffuse'] is not None:
s.diffuses.index_put_(fi_idxs, fi_outputs['diffuse'])
s.energies.index_put_(fi_idxs, density2energy(fi_outputs['density'], fi_samples.dists))
s.accumulate_tot_evaluations("color", fi_idxs[0].size(0))
self._calc_weights(s)
self._apply_early_stop(s)
class DensityFirstVolumnRenderer(VolumnRenderer):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def _forward_chunk(self, s: VolumnRenderer.States) -> int:
fi_idxs: Tuple[torch.Tensor, ...] = s.hit_mask[s.chunk].nonzero(as_tuple=True) # (N')
fi_idxs[1].add_(s.start)
if fi_idxs[0].size(0) == 0:
s.cum_energies[s.cum_chunk] = s.cum_energies[s.cum_last]
s.exp_energies[s.cum_chunk] = s.exp_energies[s.cum_last]
return 0
# fi_* means "filtered" by hit mask
fi_samples = s.samples[fi_idxs] # N -> N'
# For all valid samples: encode X
fi_encoded_x = s.kernel.encode_x(fi_samples) # (N', Ex)
# Infer densities (shape)
fi_outputs = s.kernel.infer(fi_encoded_x, None, 'density', 'color_feat',
chunk_id=s.chunk_id)
s.energies.index_put_(fi_idxs, density2energy(fi_outputs['density'], fi_samples.dists))
s.accumulate_tot_evaluations("density", fi_idxs[0].size(0))
self._calc_weights(s)
self._apply_early_stop(s)
# Remove samples whose weights are less than a threshold
s.hit_mask[s.chunk][s.weights[s.chunk][..., 0] < 0.01] = 0
# Update "filtered" tensors
fi_mask = s.hit_mask[fi_idxs]
fi_idxs = (fi_idxs[0][fi_mask], fi_idxs[1][fi_mask]) # N' -> N"
fi_encoded_x = fi_encoded_x[fi_mask] # (N", Ex)
fi_color_feats = fi_outputs['color_feat'][fi_mask]
# For all valid samples: encode D
fi_encoded_d = s.kernel.encode_d(s.samples[fi_idxs]) # (N", Ed)
# Infer colors (appearance)
fi_outputs = s.kernel.infer(fi_encoded_x, fi_encoded_d, 'color', 'specular', 'diffuse',
chunk_id=s.chunk_id,
extras={"color_feats": fi_color_feats})
# if s.chunk_id == 0:
# fi_colors[:] *= fi_colors.new_tensor([1, 0, 0])
# elif s.chunk_id == 1:
# fi_colors[:] *= fi_colors.new_tensor([0, 1, 0])
# elif s.chunk_id == 2:
# fi_colors[:] *= fi_colors.new_tensor([0, 0, 1])
# else:
# fi_colors[:] *= fi_colors.new_tensor([1, 1, 0])
s.colors.index_put_(fi_idxs, fi_outputs['color'])
if fi_outputs['specular'] is not None:
s.speculars.index_put_(fi_idxs, fi_outputs['specular'])
if fi_outputs['diffuse'] is not None:
s.diffuses.index_put_(fi_idxs, fi_outputs['diffuse'])
s.accumulate_tot_evaluations("color", fi_idxs[0].size(0))
from typing import Tuple
from .space import Space, Voxels
import torch
import torch.nn as nn
from typing import Tuple
from utils import device
from utils import sphere
from utils.constants import *
from utils.perf import perf, checkpoint
from .generic import *
from clib import *
class Bins(object):
@property
def up(self):
return self.bounds[1:]
@property
def lo(self):
return self.bounds[:-1]
def __init__(self, vals: torch.Tensor):
self.vals = vals
self.bounds = torch.cat([
......@@ -16,8 +28,6 @@ class Bins(object):
0.5 * (self.vals[1:] + self.vals[:-1]),
self.vals[-1:]
])
self.up = self.bounds[1:]
self.lo = self.bounds[:-1]
@staticmethod
def linspace(val_range: Tuple[float, float], N: int, device: torch.device = None):
......@@ -26,14 +36,60 @@ class Bins(object):
def to(self, device: torch.device):
self.vals = self.vals.to(device)
self.bounds = self.bounds.to(device)
self.up = self.bounds[1:]
self.lo = self.bounds[:-1]
class Samples:
pts: torch.Tensor
"""`Tensor(N[, P], 3)`"""
dirs: torch.Tensor
"""`Tensor(N[, P], 3)`"""
depths: torch.Tensor
"""`Tensor(N[, P])`"""
dists: torch.Tensor
"""`Tensor(N[, P])`"""
voxel_indices: torch.Tensor
"""`Tensor(N[, P])`"""
@property
def size(self):
return self.pts.size()[:-1]
@property
def device(self):
return self.pts.device
def __init__(self, pts: torch.Tensor, dirs: torch.Tensor, depths: torch.Tensor,
dists: torch.Tensor, voxel_indices: torch.Tensor) -> None:
self.pts = pts
self.dirs = dirs
self.depths = depths
self.dists = dists
self.voxel_indices = voxel_indices
def __getitem__(self, index):
return Samples(
pts=self.pts[index],
dirs=self.dirs[index],
depths=self.depths[index],
dists=self.dists[index],
voxel_indices=self.voxel_indices[index])
def reshape(self, *shape: int):
return Samples(
pts=self.pts.reshape(*shape, 3),
dirs=self.dirs.reshape(*shape, 3),
depths=self.depths.reshape(*shape),
dists=self.dists.reshape(*shape),
voxel_indices=self.voxel_indices.reshape(*shape))
class Sampler(nn.Module):
def __init__(self, *, sample_range: Tuple[float, float], n_samples: int, lindisp: bool, **kwargs):
"""
Initialize a Sampler module
......@@ -44,37 +100,81 @@ class Sampler(nn.Module):
"""
super().__init__()
self.lindisp = lindisp
# Shrink the range slightly to avoid sampling exactly at its bounds
# (use a list: tuples are immutable and would break the in-place updates).
s_range = [1 / sample_range[0], 1 / sample_range[1]] if self.lindisp else list(sample_range)
if s_range[1] > s_range[0]:
s_range[0] += 1e-4
s_range[1] -= 1e-4
else:
s_range[0] -= 1e-4
s_range[1] += 1e-4
self.bins = Bins.linspace(s_range, n_samples, device=device.default())
@perf
def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, space_module: Space,
perturb_sample: bool, **kwargs) -> Tuple[Samples, torch.Tensor]:
"""
Sample points along rays.
:param rays_o `Tensor(N, 3)`: rays' origin
:param rays_d `Tensor(N, 3)`: rays' direction
:return `Samples(N, P)`: samples
:return `Tensor(N)`: valid rays mask
"""
s = self.bins.vals.expand(rays_o.size(0), -1)
if perturb_sample:
s = self.bins.lo + (self.bins.up - self.bins.lo) * torch.rand_like(s)
pts, depths = self._get_sample_points(rays_o, rays_d, s)
voxel_indices = space_module.get_voxel_indices(pts)
valid_rays_mask = voxel_indices.ne(-1).any(dim=-1)
return Samples(
pts=pts,
dirs=rays_d[:, None].expand(-1, depths.size(1), -1),
depths=depths,
dists=self._calc_dists(depths),
voxel_indices=voxel_indices
)[valid_rays_mask], valid_rays_mask
def _get_sample_points(self, rays_o, rays_d, s):
z = torch.reciprocal(s) if self.lindisp else s
pts = rays_o[:, None] + rays_d[:, None] * z[..., None]
depths = z
return pts, depths
def _calc_dists(self, vals):
# Compute 'distance' (in time) between each integration time along a ray.
# The 'distance' from the last integration time is infinity.
# dists: (N_rays, N)
dists = vals[..., 1:] - vals[..., :-1]
last_dist = torch.zeros_like(vals[..., :1]) + TINY_FLOAT
return torch.cat([dists, last_dist], -1)
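# Usage sketch (hedged): a base Space without a bbox accepts every point, so
# all rays stay valid; tensors should live on device.default(), where the
# sampler's bins are created.
# space = Space()
# sampler = Sampler(sample_range=(1.0, 50.0), n_samples=64, lindisp=False)
# rays_o = torch.zeros(8, 3, device=device.default())
# rays_d = torch.rand(8, 3, device=device.default()) - 0.5
# rays_d /= rays_d.norm(dim=-1, keepdim=True)
# samples, valid_mask = sampler(rays_o, rays_d, space, perturb_sample=False)
# samples.pts: (8, 64, 3); samples.dists: (8, 64); valid_mask: (8,)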
class SphericalSampler(Sampler):
def __init__(self, *, sample_range: Tuple[float, float], n_samples: int,
perturb_sample: bool, **kwargs):
"""
Initialize a SphericalSampler module
:param sample_range: range of radii to sample
:param n_samples: count to sample along ray
:param perturb_sample: perturb the sample depths
"""
super().__init__(sample_range=sample_range, n_samples=n_samples,
perturb_sample=perturb_sample, lindisp=False)
def _get_sample_points(self, rays_o, rays_d, s):
r = torch.reciprocal(s)
pts, depths = sphere.ray_sphere_intersect(rays_o, rays_d, r)
pts = sphere.cartesian2spherical(pts, inverse_r=True)
return pts, depths
class PdfSampler(nn.Module):
def __init__(self, *, depth_range: Tuple[float, float], n_samples: int, perturb_sample: bool,
spherical: bool, lindisp: bool, **kwargs):
"""
Initialize a Sampler module
......@@ -90,7 +190,7 @@ class PdfSampler(nn.Module):
self.n_samples = n_samples
self.s_range = (1 / depth_range[0], 1 / depth_range[1]) if self.lindisp else depth_range
def forward(self, rays_o, rays_d, *, weights, s_vals=None, include_s_vals=False, **kwargs):
"""
Sample points along rays. Return spherical or Cartesian coordinates,
specified by `self.spherical`
......@@ -166,22 +266,116 @@ class PdfSampler(nn.Module):
class VoxelSampler(nn.Module):
def __init__(self, *, perturb_sample: bool, sample_step: float, **kwargs):
"""
Initialize a VoxelSampler module
:param perturb_sample: perturb the sample depths
:param sample_step: gap between samples along a ray
"""
super().__init__()
self.perturb_sample = perturb_sample
self.sample_step = sample_step
def _forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, space_module: Space,
**kwargs) -> Tuple[Samples, torch.Tensor]:
"""
Sample along valid rays (rays that hit at least one voxel) with a fixed step size.
:param rays_o `Tensor(N, 3)`: rays' origin positions
:param rays_d `Tensor(N, 3)`: rays' directions
:param step_size `float`: gap between samples along a ray
:return `Samples(N', P)`: samples along valid rays (which hit at least one voxel)
:return `Tensor(N)`: valid rays mask
"""
intersections = space_module.ray_intersect(rays_o, rays_d, 100)
valid_rays_mask = intersections.hits > 0
rays_o = rays_o[valid_rays_mask]
rays_d = rays_d[valid_rays_mask]
intersections = intersections[valid_rays_mask] # (N) -> (N')
n_rays = rays_o.size(0)
ray_index_list = torch.arange(n_rays, device=rays_o.device, dtype=torch.long) # (N')
hits = intersections.hits
min_depths = intersections.min_depths
max_depths = intersections.max_depths
voxel_indices = intersections.voxel_indices
rays_near_depth = min_depths[:, :1] # (N', 1)
rays_far_depth = max_depths[ray_index_list, hits - 1][:, None] # (N', 1)
rays_length = rays_far_depth - rays_near_depth
rays_steps = (rays_length / self.sample_step).ceil().long()
rays_step_size = rays_length / rays_steps
max_steps = rays_steps.max().item()
rays_step = torch.arange(max_steps, device=rays_o.device,
dtype=torch.float)[None].repeat(n_rays, 1) # (N', P)
invalid_samples_mask = rays_step >= rays_steps
samples_min_depth = rays_near_depth + rays_step * rays_step_size
samples_depth = samples_min_depth + rays_step_size \
* (torch.rand_like(samples_min_depth) if self.perturb_sample else 0.5) # (N', P)
samples_dist = rays_step_size.repeat(1, max_steps) # (N', 1) -> (N', P)
samples_voxel_index = voxel_indices[
ray_index_list[:, None],
torch.searchsorted(max_depths, samples_depth)
] # (N', P)
samples_depth[invalid_samples_mask] = HUGE_FLOAT
samples_dist[invalid_samples_mask] = 0
samples_voxel_index[invalid_samples_mask] = -1
rays_o, rays_d = rays_o[:, None], rays_d[:, None]
return Samples(
pts=rays_o + rays_d * samples_depth[..., None],
dirs=rays_d.expand(-1, max_steps, -1),
depths=samples_depth,
dists=samples_dist,
voxel_indices=samples_voxel_index
), valid_rays_mask
@perf
def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, space_module: Space,
**kwargs) -> Tuple[Samples, torch.Tensor]:
"""
Sample along valid rays using inverse CDF sampling over the intersected voxel segments.
:param rays_o `Tensor(N, 3)`: rays' origin positions
:param rays_d `Tensor(N, 3)`: rays' directions
:return `Samples(N', P)`: samples along valid rays (which hit at least one voxel)
:return `Tensor(N)`: valid rays mask
"""
intersections = space_module.ray_intersect(rays_o, rays_d, 100)
valid_rays_mask = intersections.hits > 0
rays_o = rays_o[valid_rays_mask]
rays_d = rays_d[valid_rays_mask]
intersections = intersections[valid_rays_mask] # (N) -> (N')
checkpoint("Ray intersect")
if intersections.size == 0:
return None, valid_rays_mask
else:
min_depth = intersections.min_depths
max_depth = intersections.max_depths
pts_idx = intersections.voxel_indices
dists = max_depth - min_depth
tot_dists = dists.sum(dim=-1, keepdim=True) # (N, 1)
probs = dists / tot_dists
steps = tot_dists[:, 0] / self.sample_step
# sample points and use middle point approximation
sampled_indices, sampled_depths, sampled_dists = inverse_cdf_sampling(
pts_idx, min_depth, max_depth, probs, steps, -1, not self.perturb_sample)
sampled_indices = sampled_indices.long()
invalid_idx_mask = sampled_indices.eq(-1)
sampled_dists.clamp_min_(0).masked_fill_(invalid_idx_mask, 0)
sampled_depths.masked_fill_(invalid_idx_mask, HUGE_FLOAT)
checkpoint("Inverse CDF sampling")
rays_o, rays_d = rays_o[:, None], rays_d[:, None]
return Samples(
pts=rays_o + rays_d * sampled_depths[..., None],
dirs=rays_d.expand(-1, sampled_depths.size(1), -1),
depths=sampled_depths,
dists=sampled_dists,
voxel_indices=sampled_indices
), valid_rays_mask
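# Usage sketch (hedged: requires a Voxels space; ray_intersect is backed by
# the CUDA extension in clib):
# sampler = VoxelSampler(perturb_sample=False, sample_step=0.01)
# samples, valid_rays_mask = sampler(rays_o, rays_d, voxels_space)
# samples covers only the rays with at least one voxel hit; use
# valid_rays_mask to scatter results back to the full ray batch.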
from math import ceil
import torch
import numpy as np
from typing import List, NoReturn, Tuple, Union
from torch import nn
from plyfile import PlyData, PlyElement
from utils.geometry import *
from utils.constants import *
from utils.voxels import *
from utils.perf import perf
from clib import *
class Intersections:
min_depths: torch.Tensor
"""`Tensor(N, P)` Min ray depths of intersected voxels"""
max_depths: torch.Tensor
"""`Tensor(N, P)` Max ray depths of intersected voxels"""
voxel_indices: torch.Tensor
"""`Tensor(N, P)` Indices of intersected voxels"""
hits: torch.Tensor
"""`Tensor(N)` Number of hits"""
@property
def size(self):
return self.hits.size(0)
def __init__(self, min_depths: torch.Tensor, max_depths: torch.Tensor,
voxel_indices: torch.Tensor, hits: torch.Tensor) -> None:
self.min_depths = min_depths
self.max_depths = max_depths
self.voxel_indices = voxel_indices
self.hits = hits
def __getitem__(self, index):
return Intersections(
min_depths=self.min_depths[index],
max_depths=self.max_depths[index],
voxel_indices=self.voxel_indices[index],
hits=self.hits[index])
class Space(nn.Module):
bbox: Union[torch.Tensor, None]
"""`Tensor(2, 3)` Bounding box"""
def __init__(self, *, bbox: List[float] = None, **kwargs):
super().__init__()
if bbox is None:
self.bbox = None
else:
self.register_buffer('bbox', torch.Tensor(bbox).reshape(2, 3), persistent=False)
def create_embedding(self, n_dims: int, name: str = 'default') -> torch.nn.Embedding:
raise NotImplementedError
def extract_embedding(self, pts: torch.Tensor, voxel_indices: torch.Tensor,
name: str = 'default') -> torch.Tensor:
raise NotImplementedError
def ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> Intersections:
raise NotImplementedError
def get_voxel_indices(self, pts: torch.Tensor) -> torch.Tensor:
voxel_indices = torch.zeros_like(pts[..., 0], dtype=torch.long)
if self.bbox is not None:
out_bbox = torch.logical_or(pts < self.bbox[0], pts >= self.bbox[1]).any(-1) # (N...)
voxel_indices[out_bbox] = -1
return voxel_indices
@torch.no_grad()
def pruning(self, score_fn, threshold: float = 0.5, train_stats=False):
raise NotImplementedError()
@torch.no_grad()
def splitting(self):
raise NotImplementedError()
class Voxels(Space):
steps: torch.Tensor
"""`Tensor(3)` Steps along each dimension"""
corners: torch.Tensor
"""`Tensor(C, 3)` Corner positions"""
voxels: torch.Tensor
"""`Tensor(M, 3)` Voxel centers"""
corner_indices: torch.Tensor
"""`Tensor(M, 8)` Voxel corner indices"""
voxel_indices_in_grid: torch.Tensor
"""`Tensor(G)` Indices in voxel list or -1 for pruned space"""
@property
def dims(self) -> int:
"""`int` Number of dimensions"""
return self.steps.size(0)
@property
def n_voxels(self) -> int:
"""`int` Number of voxels"""
return self.voxels.size(0)
@property
def n_corners(self) -> int:
"""`int` Number of corners"""
return self.corners.size(0)
@property
def voxel_size(self) -> torch.Tensor:
"""`Tensor(3)` Voxel size"""
return (self.bbox[1] - self.bbox[0]) / self.steps
@property
def device(self) -> torch.device:
return self.voxels.device
def __init__(self, *, voxel_size: float = None,
steps: Union[torch.Tensor, Tuple[int, int, int]] = None, **kwargs) -> None:
super().__init__(**kwargs)
if self.bbox is None:
raise ValueError("Missing argument 'bbox'")
if voxel_size is not None:
self.register_buffer('steps', get_grid_steps(self.bbox, voxel_size))
else:
self.register_buffer('steps', torch.tensor(steps, dtype=torch.long))
self.register_buffer('voxels', init_voxels(self.bbox, self.steps))
corners, corner_indices = get_corners(self.voxels, self.bbox, self.steps)
self.register_buffer("corners", corners)
self.register_buffer("corner_indices", corner_indices)
self.register_buffer('voxel_indices_in_grid', torch.arange(self.n_voxels))
self._register_load_state_dict_pre_hook(self._before_load_state_dict)
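# Construction sketch: a 64x64x64 grid over [-1, 1]^3 (bbox is a flat
# [min_x, min_y, min_z, max_x, max_y, max_z] list, reshaped to (2, 3)):
# space = Voxels(bbox=[-1, -1, -1, 1, 1, 1], steps=(64, 64, 64))
# emb = space.create_embedding(32)  # 32-dim feature on every voxel corner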
def create_embedding(self, n_dims: int, name: str = 'default') -> torch.nn.Embedding:
"""
Create an embedding on voxel corners.
:param n_dims `int`: embedding dimension
:param name `str`: embedding name
:return `Embedding(n_corners, n_dims)`: new embedding on voxel corners
"""
name = f'emb_{name}'
self.add_module(name, torch.nn.Embedding(self.n_corners, n_dims))
return self.__getattr__(name)
def get_embedding(self, name: str = 'default') -> torch.nn.Embedding:
return getattr(self, f'emb_{name}')
def extract_embedding(self, pts: torch.Tensor, voxel_indices: torch.Tensor,
name: str = 'default') -> torch.Tensor:
"""
Extract embedding values at given points using trilinear interpolation.
:param pts `Tensor(N, 3)`: points to extract values
:param voxel_indices `Tensor(N)`: corresponding voxel indices
:param name `str`: embedding name, default to 'default'
:return `Tensor(N, X)`: extracted values
"""
emb = self.get_embedding(name)
if emb is None:
raise KeyError(f"Embedding '{name}' doesn't exist")
voxels = self.voxels[voxel_indices] # (N, 3)
corner_indices = self.corner_indices[voxel_indices] # (N, 8)
p = (pts - voxels) / self.voxel_size + 0.5 # (N, 3) normed-coords in voxel
features = emb(corner_indices).reshape(pts.size(0), 8, -1) # (N, 8, X)
return trilinear_interp(p, features)
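# A reference sketch of the interpolation performed by trilinear_interp
# (which lives in utils.voxels), assuming corners are ordered with the first
# coordinate varying slowest:
# def trilinear_interp_sketch(p: torch.Tensor, corner_feats: torch.Tensor) -> torch.Tensor:
#     # p: (N, 3) coords normalized to [0, 1] inside the voxel
#     # corner_feats: (N, 8, X) features at the 8 voxel corners
#     w = torch.stack([1 - p, p], dim=-1)  # (N, 3, 2) weights per axis
#     weights = (w[:, 0, :, None, None] * w[:, 1, None, :, None]
#                * w[:, 2, None, None, :]).reshape(-1, 8, 1)  # (N, 8, 1)
#     return (weights * corner_feats).sum(1)  # (N, X)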
@perf
def ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> Intersections:
"""
Calculate intersections of rays and voxels.
:param rays_o `Tensor(N, 3)`: rays' origin
:param rays_d `Tensor(N, 3)`: rays' direction
:param n_max_hits `int`: maximum number of hits (for allocating enough space)
:return `Intersection`: intersections of rays and voxels
"""
# Prepend a dim to meet the requirement of external call
rays_o = rays_o[None].contiguous()
rays_d = rays_d[None].contiguous()
voxel_indices, min_depths, max_depths = self._ray_intersect(rays_o, rays_d, n_max_hits)
invalid_voxel_mask = voxel_indices.eq(-1)
hits = n_max_hits - invalid_voxel_mask.sum(-1)
# Sort intersections according to their depths
min_depths.masked_fill_(invalid_voxel_mask, HUGE_FLOAT)
max_depths.masked_fill_(invalid_voxel_mask, HUGE_FLOAT)
min_depths, sorted_idx = min_depths.sort(dim=-1)
max_depths = max_depths.gather(-1, sorted_idx)
voxel_indices = voxel_indices.gather(-1, sorted_idx)
return Intersections(
min_depths=min_depths[0],
max_depths=max_depths[0],
voxel_indices=voxel_indices[0],
hits=hits[0]
)
@perf
def get_voxel_indices(self, pts: torch.Tensor) -> torch.Tensor:
"""
Get voxel indices of points.
If a point is not in any valid voxels, its corresponding voxel index is -1.
:param pts `Tensor(N..., 3)`: points
:return `Tensor(N...)`: corresponding voxel indices
"""
grid_indices, out_mask = to_grid_indices(pts, self.bbox, steps=self.steps)
grid_indices[out_mask] = 0
voxel_indices = self.voxel_indices_in_grid[grid_indices]
voxel_indices[out_mask] = -1
return voxel_indices
@torch.no_grad()
def splitting(self) -> Tuple[int, int]:
"""
Split each voxel into eight sub-voxels of half size.
:return `(int, int)`: number of voxels before and after splitting
"""
n_voxels_before = self.n_voxels
self.steps *= 2
self.voxels = split_voxels(self.voxels, self.voxel_size, 2, align_border=False)\
.reshape(-1, 3)
self._update_corners()
self._update_voxel_indices_in_grid()
return n_voxels_before, self.n_voxels
@torch.no_grad()
def prune(self, keeps: torch.Tensor) -> Tuple[int, int]:
self.voxels = self.voxels[keeps]
self.corner_indices = self.corner_indices[keeps]
self._update_voxel_indices_in_grid()
return keeps.size(0), keeps.sum().item()
@torch.no_grad()
def pruning(self, score_fn, threshold: float = 0.5) -> Tuple[int, int]:
scores = self._get_scores(score_fn, lambda x: torch.max(x, -1)[0]) # (M)
return self.prune(scores > threshold)
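# Example (hypothetical score_fn): keep voxels whose peak predicted density
# over the sampled points exceeds the threshold:
# space.pruning(lambda pts, idxs: model.infer_density(pts, idxs), threshold=0.5)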
def n_voxels_along_dim(self, dim: int) -> torch.Tensor:
sum_dims = [val for val in range(self.dims) if val != dim]
return self.voxel_indices_in_grid.reshape(*self.steps).ne(-1).sum(sum_dims)
def balance_cut(self, dim: int, n_parts: int) -> List[int]:
n_voxels_list = self.n_voxels_along_dim(dim)
cdf = (n_voxels_list.cumsum(0) / self.n_voxels * n_parts).tolist()
bins = []
part = 1
offset = 0
for i in range(len(cdf)):
if cdf[i] >= part:
bins.append(i + 1 - offset)
offset = i + 1
part = int(cdf[i]) + 1
return bins
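# Worked example: per-slice voxel counts [3, 1, 4, 2] along `dim` with
# n_parts=2 give cdf = [0.6, 0.8, 1.6, 2.0], so balance_cut returns [3, 1]:
# the first three slices (8 voxels) form one part, the last slice (2 voxels)
# the other.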
def sample(self, bits: int, perturb: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
sampled_xyz = split_voxels(self.voxels, self.voxel_size, bits)
sampled_idx = torch.arange(self.n_voxels, device=self.device)[:, None].expand(
*sampled_xyz.shape[:2])
sampled_xyz, sampled_idx = sampled_xyz.reshape(-1, 3), sampled_idx.flatten()
return sampled_xyz, sampled_idx
@torch.no_grad()
def _get_scores(self, score_fn, reduce_fn=None, bits=16) -> torch.Tensor:
def get_scores_once(pts, idxs):
scores = score_fn(pts, idxs).reshape(-1, bits ** 3) # (B, P)
if reduce_fn is not None:
scores = reduce_fn(scores) # (B[, ...])
return scores
sampled_xyz, sampled_idx = self.sample(bits)
chunk_size = 64
return torch.cat([
get_scores_once(sampled_xyz[i:i + chunk_size], sampled_idx[i:i + chunk_size])
for i in range(0, self.voxels.size(0), chunk_size)
], 0) # (M[, ...])
def _ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return aabb_ray_intersect(self.voxel_size, n_max_hits, self.voxels, rays_o, rays_d)
def _update_corners(self):
"""
Update voxel corners.
"""
corners, corner_indices = get_corners(self.voxels, self.bbox, self.steps)
self.register_buffer("corners", corners)
self.register_buffer("corner_indices", corner_indices)
def _update_voxel_indices_in_grid(self):
"""
Update voxel indices in grid.
"""
grid_indices, _ = to_grid_indices(self.voxels, self.bbox, steps=self.steps)
self.voxel_indices_in_grid = grid_indices.new_full([self.steps.prod().item()], -1)
self.voxel_indices_in_grid[grid_indices] = torch.arange(self.n_voxels, device=self.device)
@torch.no_grad()
def _before_load_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys,
unexpected_keys, error_msgs):
# Handle buffers
for name, buffer in self.named_buffers(recurse=False):
if name in self._non_persistent_buffers_set:
continue
buffer.resize_as_(state_dict[prefix + name])
# Handle embeddings
for name, module in self.named_modules():
if name.startswith('emb_'):
setattr(self, name, torch.nn.Embedding(self.n_corners, module.embedding_dim))
class Octree(Voxels):
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
self.nodes_cached = None
self.tree_cached = None
def get(self) -> Tuple[torch.Tensor, torch.Tensor]:
if self.nodes_cached is None:
self.nodes_cached, self.tree_cached = build_easy_octree(
self.voxels, 0.5 * self.voxel_size)
return self.nodes_cached, self.tree_cached
def clear(self):
self.nodes_cached = None
self.tree_cached = None
def _ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int):
nodes, tree = self.get()
return octree_ray_intersect(self.voxel_size, n_max_hits, nodes, tree, rays_o, rays_d)
@torch.no_grad()
def splitting(self):
ret = super().splitting()
self.clear()
return ret
@torch.no_grad()
def prune(self, keeps: torch.Tensor) -> Tuple[int, int]:
ret = super().prune(keeps)
self.clear()
return ret
nerf++ @ a30f1a5a
Subproject commit a30f1a5ad116e43aad90c426a966b2a3fcedaf7e
import torch
import torch.nn as nn
from modules import *
from utils import color
class Nerf(nn.Module):
def __init__(self, fc_params, sampler_params, *,
c: int = color.RGB,
n_pos_encode: int = 0,
n_dir_encode: int = None,
coarse_net=None, **kwargs):
"""
Initialize a NeRF unit
:param fc_params `dict`: parameters for full-connection network
:param sampler_params `dict`: parameters for sampler
:param c `int`: color mode
:param n_pos_encode `int`: encode position to number of dimensions
:param n_dir_encode `int`: encode direction to number of dimensions, `None` means direction is ignored
:param coarse_net `NerfUnit`: optional coarse net
"""
super().__init__()
self.coarse_net = coarse_net
self.color = c
self.coord_chns = 3
self.color_chns = color.chns(self.color)
self.pos_encoder = InputEncoder.Get(n_pos_encode, self.coord_chns)
if n_dir_encode is not None:
self.dir_chns = 3
self.dir_encoder = InputEncoder.Get(n_dir_encode, self.dir_chns)
else:
self.dir_chns = 0
self.dir_encoder = None
self.core = NerfCore(coord_chns=self.pos_encoder.out_dim,
density_chns=1,
color_chns=self.color_chns,
core_nf=fc_params['nf'],
core_layers=fc_params['n_layers'],
dir_chns=self.dir_encoder.out_dim if self.dir_encoder else 0,
dir_nf=fc_params['nf'] // 2,
activation=fc_params['activation'],
skips=fc_params['skips'])
sampler_params['spherical'] = False
self.sampler = PdfSampler(**sampler_params) if self.coarse_net is not None \
else Sampler(**sampler_params)
self.rendering = VolumnRenderer()
def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, *,
ret_depth=False, debug=False) -> torch.Tensor:
"""
rays -> colors
:param rays_o `Tensor(B, 3)`: rays' origin
:param rays_d `Tensor(B, 3)`: rays' direction
:param ret_depth `bool`: whether to return depth
:return `Tensor(B, C)`: inferred images/pixels
"""
if self.coarse_net is not None:
coarse_ret = self.coarse_net(rays_o, rays_d, ret_depth=ret_depth, debug=debug)
coords, depths, s_vals, _ = self.sampler(rays_o, rays_d, coarse_ret['sample'],
coarse_ret['weight'])
else:
coords, depths, s_vals, _ = self.sampler(rays_o, rays_d)
coords_encoded = self.pos_encoder(coords)
dirs_encoded = self.dir_encoder(rays_d)[:, None].expand(-1, s_vals.size(-1), -1) \
if self.dir_encoder is not None else None
colors, densities = self.core(coords_encoded, dirs_encoded)
ret = self.rendering(colors, densities[..., 0], depths, ret_depth=ret_depth, debug=debug)
ret['sample'] = s_vals
if self.coarse_net is not None:
ret['coarse'] = coarse_ret
return ret
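# Construction sketch (hypothetical parameter dicts; the keys follow the
# usages above):
# fc_params = {'nf': 256, 'n_layers': 8, 'activation': 'relu', 'skips': [4]}
# sampler_params = {'sample_range': (1.0, 50.0), 'n_samples': 64,
#                   'perturb_sample': True, 'lindisp': False}
# model = Nerf(fc_params, sampler_params, c=color.RGB, n_pos_encode=10, n_dir_encode=4)
# ret = model(rays_o, rays_d, ret_depth=True)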
import torch
import torch.nn as nn
from modules import *
from utils import color
class NSVF(nn.Module):
def __init__(self, fc_params, sampler_params, *,
c: int = color.RGB,
n_featdim: int = 32,
n_pos_encode: int = 0,
n_dir_encode: int = None,
**kwargs):
"""
Initialize a NSVF model
:param fc_params `dict`: parameters for full-connection network
:param sampler_params `dict`: parameters for sampler
:param c `int`: color mode
:param n_pos_encode `int`: encode position to number of dimensions
:param n_dir_encode `int`: encode direction to number of dimensions, `None` means direction is ignored
:param n_featdim `int`: dimension of voxel features
"""
super().__init__()
self.color = c
self.coord_chns = n_featdim
self.color_chns = color.chns(self.color)
self.pos_encoder = InputEncoder.Get(n_pos_encode, self.coord_chns)
if n_dir_encode is not None:
self.dir_chns = 3
self.dir_encoder = InputEncoder.Get(n_dir_encode, self.dir_chns)
else:
self.dir_chns = 0
self.dir_encoder = None
self.core = NerfCore(coord_chns=self.pos_encoder.out_dim,
density_chns=1,
color_chns=self.color_chns,
core_nf=fc_params['nf'],
core_layers=fc_params['n_layers'],
dir_chns=self.dir_encoder.out_dim if self.dir_encoder else 0,
dir_nf=fc_params['nf'] // 2,
activation=fc_params['activation'],
skips=fc_params['skips'])
self.space = OctTreeSpace()
sampler_params['space'] = self.space
self.sampler = VoxelSampler(**sampler_params)
self.rendering = VolumnRenderer()
def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, *,
ret_depth=False, debug=False) -> torch.Tensor:
"""
rays -> colors
:param rays_o `Tensor(B, 3)`: rays' origin
:param rays_d `Tensor(B, 3)`: rays' direction
:param ret_depth `bool`: whether to return depth
:return `Tensor(B, C)`: inferred images/pixels
"""
feats, dirs, z_s, dz_s = self.sampler(rays_o, rays_d)
feats_encoded = self.pos_encoder(feats)
dirs_encoded = self.dir_encoder(rays_d)[:, None].expand(-1, z_s.size(-1), -1) \
if self.dir_encoder is not None else None
colors, densities = self.core(feats_encoded, dirs_encoded)
ret = self.rendering(colors, densities[..., 0], z_s, dz_s, ret_depth=ret_depth, debug=debug)
return ret
import torch
import torch.nn as nn
from modules import *
from utils import sphere
from utils import color
class Snerf(nn.Module):
def __init__(self, fc_params, sampler_params, *,
n_parts: int = 1,
c: int = color.RGB,
pos_encode: int = 10,
dir_encode: int = None,
spherical_dir: bool = False, **kwargs):
"""
Initialize a multi-sphere-layer net
:param fc_params: parameters for full-connection network
:param sampler_params: parameters for sampler
:param n_parts: number of sub-networks, each handling a slice of the samples
:param c: color mode
:param pos_encode: encode position to number of dimensions
"""
super().__init__()
self.color = c
self.spherical_dir = spherical_dir
self.n_samples = sampler_params['n_samples']
self.n_parts = n_parts
self.samples_per_part = self.n_samples // self.n_parts
self.coord_chns = 3
self.color_chns = color.chns(self.color)
self.pos_encoder = InputEncoder.Get(pos_encode, self.coord_chns)
if dir_encode is not None:
self.dir_encoder = InputEncoder.Get(dir_encode, 2 if self.spherical_dir else 3)
self.dir_chns_encoded = self.dir_encoder.out_dim
else:
self.dir_encoder = None
self.dir_chns_encoded = 0
self.nets = nn.ModuleList(
NerfCore(coord_chns=self.pos_encoder.out_dim,
density_chns=1,
color_chns=self.color_chns,
core_nf=fc_params['nf'],
core_layers=fc_params['n_layers'],
dir_chns=self.dir_chns_encoded,
dir_nf=fc_params['nf'] // 2,
activation=fc_params['activation'])
for _ in range(self.n_parts))
sampler_params['spherical'] = True
self.sampler = Sampler(**sampler_params)
self.rendering = VolumnRenderer()
def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor,
ret_depth=False, debug=False) -> torch.Tensor:
"""
rays -> colors
:param rays_o `Tensor(B, 3)`: rays' origin
:param rays_d `Tensor(B, 3)`: rays' direction
:return `Tensor(B, C)`: inferred images/pixels
"""
n_rays = rays_o.size(0)
coords, depths, _, pts = self.sampler(rays_o, rays_d)
coords_encoded = self.pos_encoder(coords)
if self.dir_encoder is not None:
if self.spherical_dir:
dirs_encoded = self.dir_encoder(sphere.calc_local_dir(rays_d, coords, pts))
else:
dirs_encoded = self.dir_encoder(rays_d)[:, None].expand(-1, self.n_samples, -1)
else:
dirs_encoded = None
densities = torch.empty(n_rays, self.n_samples, device=device.default())
colors = torch.empty(n_rays, self.n_samples, self.color_chns, device=device.default())
for i, net in enumerate(self.nets):
s = slice(i * self.samples_per_part, (i + 1) * self.samples_per_part)
c, d = net(coords_encoded[:, s],
dirs_encoded[:, s] if dirs_encoded is not None else None)
colors[:, s] = c
densities[:, s] = d[..., 0]
ret = self.rendering(colors.view(-1, self.n_samples, self.color_chns),
densities, depths, ret_depth=ret_depth, debug=debug)
if debug:
ret['sample_densities'] = densities
ret['sample_depths'] = depths
return ret
class SnerfExport(nn.Module):
def __init__(self, net: Snerf):
super().__init__()
self.net = net
def forward(self, coords_encoded, z_vals):
colors = []
densities = []
for i in range(self.net.n_parts):
s = slice(i * self.net.samples_per_part, (i + 1) * self.net.samples_per_part)
mlp = self.net.nets[i] if self.net.nets is not None else self.net.net
c, d = mlp(coords_encoded[:, s].flatten(1, 2))
colors.append(c.view(-1, self.net.samples_per_part, self.net.color_chns))
densities.append(d)
colors = torch.cat(colors, 1)
densities = torch.cat(densities, 1)
alphas = self.net.rendering.density2alpha(densities, z_vals)
return torch.cat([colors, alphas[..., None]], -1)
......@@ -3,13 +3,11 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"import os\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as nn_f\n",
"import matplotlib.pyplot as plt\n",
"\n",
"rootdir = os.path.abspath(sys.path[0] + '/../')\n",
......@@ -18,21 +16,16 @@
"print(\"Set CUDA:%d as current device.\" % torch.cuda.current_device())\n",
"torch.autograd.set_grad_enabled(False)\n",
"\n",
"from data.spherical_view_syn import *\n",
"from configs.spherical_view_syn import SphericalViewSynConfig\n",
"from utils import netio\n",
"from utils import img\n",
"from utils import device\n",
"from utils.view import *\n",
"from components.fnr import FoveatedNeuralRenderer\n",
"\n",
"datadir = f\"{rootdir}/data/__new/__demo/for_crop\"\n",
"figs = ['our', 'gt', 'nerf', 'fgt']\n",
"crops = {\n",
" 'classroom_0': [[720, 800, 128], [1097, 982, 256]],\n",
" 'lobby_1': [[570, 1000, 100], [1049, 1049, 256]],\n",
" 'stones_2': [[720, 800, 100], [680, 1317, 256]],\n",
" 'barbershop_3': [[745, 810, 100], [1135, 627, 256]]\n",
" 'classroom_0': [[720, 790, 100], [370, 1160, 200]],\n",
" 'lobby_1': [[570, 1000, 100], [1300, 1000, 200]],\n",
" 'stones_2': [[720, 800, 100], [680, 1317, 200]],\n",
" 'barbershop_3': [[745, 810, 100], [950, 900, 200]]\n",
"}\n",
"colors = torch.tensor([[0, 1, 0, 1], [1, 1, 0, 1]], dtype=torch.float)\n",
"border = 10\n",
......@@ -78,16 +71,18 @@
" img.save(torch.cat([fovea_patches, periph_patches], dim=-1),\n",
" [f\"{datadir}/patch/{scene}_{fig}.png\" for fig in figs])\n",
" img.save(overlay, f\"{datadir}/overlay/{scene}.png\")\n"
]
],
"outputs": [],
"metadata": {}
}
],
"metadata": {
"interpreter": {
"hash": "82066b63b621a9e3d15e3b7c11ca76da6238eff3834294910d715044bd0561e5"
"hash": "65406b00395a48e1d89cf658ae895e7869e05878f5469716b06a752a3915211c"
},
"kernelspec": {
"display_name": "Python 3.8.5 64-bit ('base': conda)",
"name": "python3"
"name": "python3",
"display_name": "Python 3.8.5 64-bit ('base': conda)"
},
"language_info": {
"codemirror_mode": {
......
......@@ -170,7 +170,7 @@
" images['overlaid'] = renderer.foveation.synthesis(images['layers_raw'], param[-2:], do_blend=False)\n",
" if True:\n",
" outputdir = '../__demo/mono/'\n",
" misc.create_dir(outputdir)\n",
" os.makedirs(outputdir, exist_ok=True)\n",
" img.save(images['layers_img'][0], f'{outputdir}{scene}_{i}_fovea.png')\n",
" img.save(images['layers_img'][1], f'{outputdir}{scene}_{i}_mid.png')\n",
" img.save(images['layers_img'][2], f'{outputdir}{scene}_{i}_periph.png')\n",
......@@ -203,7 +203,7 @@
" center = (0, 0)\n",
" images = renderer(views.get(view_idx), center, using_mask=True)\n",
" outputdir = 'panorama'\n",
" misc.create_dir(outputdir)\n",
" os.makedirs(outputdir, exist_ok=True)\n",
" img.save(images['blended'], f'{outputdir}/{view_idx:04d}.png')"
],
"outputs": [
......
......@@ -216,7 +216,7 @@
" ret_raw=False)\n",
" if True:\n",
" outputdir = '../__demo/stereo_m%d' % mono_periph if mono_periph else '../__demo/stereo'\n",
" misc.create_dir(outputdir)\n",
" os.makedirs(outputdir, exist_ok=True)\n",
" img.save(torch.cat([\n",
" left_images['blended'],\n",
" right_images['blended']\n",
......@@ -228,7 +228,7 @@
" right_images['blended'][:, 1:3]\n",
" ], dim=1)\n",
" img.save(stereo_overlap, '%s/%s_%d_stereo.png' % (outputdir, scene, i))\n",
" #misc.create_dir(outputdir + '/mid')\n",
" #os.makedirs(outputdir + '/mid', exist_ok=True)\n",
" #img.save(left_images['layers_img'][1], '%s/mid/%s_%d_l.png' % (outputdir, scene, i))\n",
" #img.save(right_images['layers_img'][1], '%s/mid/%s_%d_r.png' % (outputdir, scene, i))\n",
" print(\"%s %d Saved\" % (scene, i))\n",
......
......@@ -110,7 +110,7 @@
" #plot_figures(images, center)\n",
"\n",
" outputdir = '../__1_eval/output_mono_periph/ref_as_right_eye/%s/' % scene\n",
" misc.create_dir(outputdir)\n",
" os.makedirs(outputdir, exist_ok=True)\n",
" #for key in images:\n",
" key = 'blended'\n",
" img.save(images[key], outputdir + 'view%04d_%s.png' % (view_idx, key))\n"
......@@ -131,7 +131,7 @@
" images = gen.gen(center, test_view, True)\n",
" #plot_figures(images, center)\n",
"\n",
" misc.create_dir('output/eval_gaze')\n",
" os.makedirs('output/eval_gaze', exist_ok=True)\n",
" out_path = 'output/eval_gaze/gaze%03d_%d,%d.png' % (gaze_idx, x, y)\n",
" img.save(images['blended'], out_path)\n",
" print('Output ' + out_path)\n",
......
......@@ -130,7 +130,7 @@
" images = gen.gen(center, test_view, True)\n",
" #plot_figures(images, center)\n",
"\n",
" misc.create_dir('output/teasers')\n",
" os.makedirs('output/teasers', exist_ok=True)\n",
" for key in images:\n",
" img.save(\n",
" images[key], 'output/teasers/view%04d_%s.png' % (view_idx, key))\n"
......
......@@ -150,7 +150,7 @@
"print(\"Encoded:\", encoded)\n",
"#plot_figures(images, center)\n",
"\n",
"#misc.create_dir('output/teasers')\n",
"#os.makedirs('output/teasers', exist_ok=True)\n",
"#for key in images:\n",
"# img.save(\n",
"# images[key], 'output/teasers/view%04d_%s.png' % (view_idx, key))\n"
......
......@@ -188,7 +188,7 @@
"\n",
"#plot_figures(left_images, right_images, centers[set_id][0], centers[set_id][1])\n",
"\n",
"misc.create_dir('output')\n",
"os.makedirs('output', exist_ok=True)\n",
"for key in left_images:\n",
" img.save(\n",
" left_images[key], 'output/set%d_%s_l.png' % (set_id, key))\n",
......
......@@ -117,7 +117,7 @@
" left_images = gen.gen(left_center, left_view, mono_trans=mono_trans)\n",
" right_images = gen.gen(right_center, right_view, mono_trans=mono_trans)\n",
" \n",
" misc.create_dir('output/video_frames/hmd2')\n",
" os.makedirs('output/video_frames/hmd2', exist_ok=True)\n",
" img.save(torch.cat([left_images['blended'], right_images['blended']], -1),\n",
" 'output/video_frames/hmd2/view%04d.png' % view_idx)\n",
" print('Frame %d saved' % view_idx)\n"
......
......@@ -155,7 +155,7 @@
" images['overlaid'] = renderer.foveation.synthesis(images['layers_raw'], param[-2:], do_blend=False)\n",
" if True:\n",
" outputdir = '../__demo/mono/'\n",
" misc.create_dir(outputdir)\n",
" os.makedirs(outputdir, exist_ok=True)\n",
" img.save(images['layers_img'][0], f'{outputdir}{scene}_{i}_fovea.png')\n",
" img.save(images['layers_img'][1], f'{outputdir}{scene}_{i}_mid.png')\n",
" img.save(images['layers_img'][2], f'{outputdir}{scene}_{i}_periph.png')\n",
......@@ -196,7 +196,7 @@
" center = (0, 0)\n",
" images = renderer(views.get(view_idx), center, using_mask=True)\n",
" outputdir = 'nerf_our'\n",
" misc.create_dir(outputdir)\n",
" os.makedirs(outputdir, exist_ok=True)\n",
" img.save(images['blended'], f'{outputdir}/{view_idx:04d}.png')"
]
}
......
......@@ -101,7 +101,7 @@
"gaze = [37.55656052, 20.7297554]\n",
"images = renderer(view, gaze, using_mask=False, ret_raw=True)\n",
"outputdir = '../__demo/mono_f60&m110/'\n",
"misc.create_dir(outputdir)\n",
"os.makedirs(outputdir, exist_ok=True)\n",
"img.save(images['layers_img'][0], f'{outputdir}{scene}_fovea.png')\n",
"img.save(images['blended'], f'{outputdir}{scene}.png')\n",
"img.save(images['blended_raw'], f'{outputdir}{scene}_noCE.png')"
......