Commit 5699ccbf authored by Nianchen Deng

sync

parent 338ae906
from typing import Tuple
import torch
import torch.nn as nn
from torch.nn.modules.linear import Identity
from utils.constants import *
from .generic import *
from .sampler import *
from .input_encoder import *
from .renderer import *
from .space import *
from .core import *
class NerfCore(nn.Module):
def __init__(self, *, coord_chns, density_chns, color_chns, core_nf, core_layers,
dir_chns=0, dir_nf=0, activation='relu', skips=[]):
super().__init__()
self.core = FcNet(in_chns=coord_chns, out_chns=0, nf=core_nf, n_layers=core_layers,
skips=skips, activation=activation)
self.density_out = FcLayer(core_nf, density_chns) if density_chns > 0 else None
if color_chns == 0:
self.feature_out = None
self.color_out = None
elif dir_chns > 0:
self.feature_out = FcLayer(core_nf, core_nf)
self.color_out = nn.Sequential(
FcLayer(core_nf + dir_chns, dir_nf, activation),
FcLayer(dir_nf, color_chns)
)
else:
self.feature_out = Identity()
self.color_out = FcLayer(core_nf, color_chns)
def forward(self, coord: torch.Tensor, dir: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
core_output = self.core(coord)
density = self.density_out(core_output) if self.density_out is not None else None
if self.color_out is None:
color = None
else:
feature = self.feature_out(core_output)
if dir is not None:
feature = torch.cat([feature, dir], dim=-1)
color = torch.sigmoid(self.color_out(feature))
return color, density
from .generic import *
from typing import Dict
class NerfCore(nn.Module):
def __init__(self, *, coord_chns, density_chns, color_chns, core_nf, core_layers,
dir_chns=0, dir_nf=0, act='relu', skips=[]):
super().__init__()
self.core = FcNet(in_chns=coord_chns, out_chns=None, nf=core_nf, n_layers=core_layers,
skips=skips, act=act)
self.density_out = FcLayer(core_nf, density_chns) if density_chns > 0 else None
if color_chns == 0:
self.feature_out = None
self.color_out = None
elif dir_chns > 0:
self.feature_out = FcLayer(core_nf, core_nf)
self.color_out = nn.Sequential(
FcLayer(core_nf + dir_chns, dir_nf, act),
FcLayer(dir_nf, color_chns)
)
else:
self.feature_out = torch.nn.Identity()
self.color_out = FcLayer(core_nf, color_chns)
def forward(self, x: torch.Tensor, d: torch.Tensor, outputs: List[str]) -> Dict[str, torch.Tensor]:
ret = {}
core_output = self.core(x)
if 'density' in outputs:
ret['density'] = torch.relu(self.density_out(core_output)) \
if self.density_out is not None else None
if 'color' in outputs:
if self.color_out is None:
ret['color'] = None
else:
feature = self.feature_out(core_output)
                if d is not None:
feature = torch.cat([feature, d], dim=-1)
ret['color'] = self.color_out(feature).sigmoid()
for key in outputs:
if key == 'density' or key == 'color':
continue
ret[key] = None
return ret
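
# Usage sketch (not part of the original commit): drives NerfCore with
# illustrative channel sizes; the encoders that would produce `x` and `d`
# are elided here, so all sizes below are assumptions.
def _nerf_core_usage_sketch():
    core = NerfCore(coord_chns=63, density_chns=1, color_chns=3,
                    core_nf=256, core_layers=8, dir_chns=27, dir_nf=128)
    x = torch.rand(1024, 63)  # encoded sample positions
    d = torch.rand(1024, 27)  # encoded view directions (matches dir_chns)
    out = core(x, d, ['color', 'density'])
    assert out['color'].shape == (1024, 3)    # sigmoid-activated
    assert out['density'].shape == (1024, 1)  # relu-activated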
class NerfAdvCore(nn.Module):
def __init__(self, *, x_chns: int, d_chns: int, density_chns: int, color_chns: int,
density_net_params: dict, color_net_params: dict,
specular_net_params: dict = None,
appearance="decomposite",
density_color_connection=False):
"""
Create a NeRF-Adv Core Net.
Required parameters for the sub-mlps include: "nf", "n_layers", "skips" and "act".
Other parameters will be properly set automatically.
:param x_chns `int`: the channels of input "position"
:param d_chns `int`: the channels of input "direction"
:param density_chns `int`: the channels of output "density"
:param color_chns `int`: the channels of output "color"
:param density_net_params `dict`: parameters for the density net
:param color_net_params `dict`: parameters for the color net
:param specular_net_params `dict`: (optional) parameters for the optional specular net, defaults to None
        :param appearance `str`: (optional) options are [decomposite|combined|newtype], defaults to "decomposite"
:param density_color_connection `bool`: (optional) whether to add connections between
density net and color net, defaults to False
"""
super().__init__()
self.density_chns = density_chns
self.color_chns = color_chns
self.specular_feature_chns = color_net_params["nf"] if specular_net_params else 0
self.color_feature_chns = density_net_params["nf"] if density_color_connection else 0
self.appearance = appearance
self.density_color_connection = density_color_connection
self.density_net = FcNet(**density_net_params,
in_chns=x_chns,
out_chns=self.density_chns + self.color_feature_chns,
out_act='relu')
if self.appearance == "newtype":
self.specular_feature_chns = d_chns * 3
self.color_net = FcNet(**color_net_params,
in_chns=x_chns + self.color_feature_chns,
out_chns=self.color_chns + self.specular_feature_chns)
self.specular_net = "Placeholder"
else:
if self.appearance == "decomposite":
self.color_net = FcNet(**color_net_params,
in_chns=x_chns + self.color_feature_chns,
out_chns=self.color_chns + self.specular_feature_chns)
else:
if specular_net_params:
self.color_net = FcNet(**color_net_params,
in_chns=x_chns + self.color_feature_chns,
out_chns=self.specular_feature_chns)
else:
self.color_net = FcNet(**color_net_params,
in_chns=x_chns + d_chns + self.color_feature_chns,
out_chns=self.color_chns)
self.specular_net = FcNet(**specular_net_params,
in_chns=d_chns + self.specular_feature_chns,
out_chns=self.color_chns) if specular_net_params else None
def forward(self, x: torch.Tensor, d: torch.Tensor, outputs: List[str], *,
color_feats: torch.Tensor = None) -> Dict[str, torch.Tensor]:
input_shape = x.shape[:-1]
if len(input_shape) > 1:
x = x.flatten(0, -2)
d = d.flatten(0, -2)
n = x.shape[0]
c = self.color_chns
ret: Dict[str, torch.Tensor] = {}
if 'density' in outputs:
density_net_out: torch.Tensor = self.density_net(x)
ret['density'] = density_net_out[:, :self.density_chns]
color_feats = density_net_out[:, self.density_chns:]
if 'color_feat' in outputs:
ret['color_feat'] = color_feats
        if 'color' in outputs or 'specular' in outputs:
if 'density' in ret:
valid_mask = ret['density'][:, 0].detach() >= 1e-4
indices = valid_mask.nonzero()[:, 0]
x, d, color_feats = x[indices], d[indices], color_feats[indices]
else:
indices = None
speculars = None
color_net_in = [x]
if not self.specular_net:
color_net_in.append(d)
if self.density_color_connection:
color_net_in.append(color_feats)
color_net_in = torch.cat(color_net_in, -1)
color_net_out: torch.Tensor = self.color_net(color_net_in)
diffuses = color_net_out[:, :c]
specular_features = color_net_out[:, -self.specular_feature_chns:]
if self.appearance == "newtype":
                speculars = torch.bmm(specular_features.reshape(-1, 3, d.shape[-1]),
                                      d[..., None])[..., 0]
# TODO relu or not?
diffuses = diffuses.relu()
speculars = speculars.relu()
colors = diffuses + speculars
else:
if not self.specular_net:
colors = diffuses
diffuses = None
else:
specular_net_in = torch.cat([d, specular_features], -1)
specular_net_out = self.specular_net(specular_net_in)
if self.appearance == "decomposite":
speculars = specular_net_out
colors = diffuses + speculars
else:
diffuses = None
colors = specular_net_out
colors = torch.sigmoid(colors) # TODO indent or not?
if 'color' in outputs:
            ret['color'] = colors.new_zeros(n, c).index_copy(0, indices, colors) \
                if indices is not None else colors
if 'diffuse' in outputs:
ret['diffuse'] = diffuses.new_zeros(n, c).index_copy(0, indices, diffuses) \
if indices is not None and diffuses is not None else diffuses
if 'specular' in outputs:
ret['specular'] = speculars.new_zeros(n, c).index_copy(0, indices, speculars) \
if indices is not None and speculars is not None else speculars
if len(input_shape) > 1:
ret = {key: val.reshape(*input_shape, -1) for key, val in ret.items()}
return ret
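
# Usage sketch (not part of the original commit): drives NerfAdvCore in the
# "decomposite" mode, where color = sigmoid(diffuse + specular). The sub-net
# dicts carry the required "nf", "n_layers", "skips" and "act" keys; all
# sizes are illustrative assumptions.
def _nerf_adv_core_usage_sketch():
    core = NerfAdvCore(
        x_chns=63, d_chns=27, density_chns=1, color_chns=3,
        density_net_params=dict(nf=256, n_layers=4, skips=[], act='relu'),
        color_net_params=dict(nf=256, n_layers=4, skips=[], act='relu'),
        specular_net_params=dict(nf=128, n_layers=2, skips=[], act='relu'),
        appearance="decomposite")
    x, d = torch.rand(1024, 63), torch.rand(1024, 27)
    out = core(x, d, ['density', 'color', 'diffuse', 'specular'])
    assert out['color'].shape == (1024, 3)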
@@ -34,7 +34,7 @@ class Sine(nn.Module):
class FcLayer(nn.Module):
-    def __init__(self, in_chns: int, out_chns: int, activation: str = 'linear', skip_chns: int = 0):
+    def __init__(self, in_chns: int, out_chns: int, act: str = 'linear', skip_chns: int = 0):
        super().__init__()
        nls_and_inits = {
            'sine': (Sine(), sine_init),
@@ -48,7 +48,7 @@ class FcLayer(nn.Module):
            'logsoftmax': (nn.LogSoftmax(dim=-1), softmax_init),
            'linear': (None, None)
        }
-        nl, nl_weight_init = nls_and_inits[activation]
+        nl, nl_weight_init = nls_and_inits[act]
        self.net = nn.Sequential(
            nn.Linear(in_chns + skip_chns, out_chns),
@@ -59,7 +59,7 @@ class FcLayer(nn.Module):
        if nl_weight_init is not None:
            nl_weight_init(self.net if isinstance(self.net, nn.Linear) else self.net[0])
        else:
-            self.init_params(activation)
+            self.init_params(act)

    def forward(self, x: torch.Tensor, x0: torch.Tensor = None) -> torch.Tensor:
        return self.net(torch.cat([x0, x], dim=-1) if self.skip else x)
@@ -68,9 +68,9 @@ class FcLayer(nn.Module):
        linear_net = self.net if isinstance(self.net, nn.Linear) else self.net[0]
        return linear_net.weight, linear_net.bias

-    def init_params(self, activation):
+    def init_params(self, act):
        weight, bias = self.get_params()
-        nn.init.xavier_normal_(weight, gain=nn.init.calculate_gain(activation))
+        nn.init.xavier_normal_(weight, gain=nn.init.calculate_gain(act))
        nn.init.zeros_(bias)

    def copy_to(self, layer):
@@ -83,7 +83,7 @@ class FcLayer(nn.Module):
class FcNet(nn.Module):
    def __init__(self, *, in_chns: int, out_chns: int, nf: int, n_layers: int,
-                 skips: List[int] = [], activation: str = 'relu'):
+                 skips: List[int] = [], act: str = 'relu', out_act='linear'):
        """
        Initialize a full-connection net
@@ -95,12 +95,12 @@ class FcNet(nn.Module):
        """
        super().__init__()
-        self.layers = [FcLayer(in_chns, nf, activation)] + [
-            FcLayer(nf, nf, activation, skip_chns=in_chns if i in skips else 0)
+        self.layers = [FcLayer(in_chns, nf, act)] + [
+            FcLayer(nf, nf, act, skip_chns=in_chns if i in skips else 0)
            for i in range(n_layers - 1)
        ]
-        if out_chns > 0:
-            self.layers.append(FcLayer(nf, out_chns))
+        if out_chns:
+            self.layers.append(FcLayer(nf, out_chns, out_act))
        for i, layer in enumerate(self.layers):
            self.add_module(f"layer{i}", layer)
...
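
# Usage sketch (not part of the original diff): out_chns=0/None means "no
# output layer", so FcNet returns nf-dim features (as NerfCore uses it).
# Relies on FcNet.forward, which is elided from the hunks above.
def _fc_net_usage_sketch():
    net = FcNet(in_chns=63, out_chns=None, nf=256, n_layers=8, skips=[4], act='relu')
    feats = net(torch.rand(16, 63))
    assert feats.shape == (16, 256)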
+from itertools import cycle
+from math import ceil
+from typing import Dict, List, Tuple, Union
import torch
import torch.nn as nn
-import torch.nn.functional as nn_f
from utils.constants import *
+from utils.perf import perf
from .generic import *
+from .sampler import Samples
def density2energy(densities: torch.Tensor, dists: torch.Tensor, raw_noise_std: float = 0):
"""
Calculate energies from densities inferred by model.
:param densities `Tensor(N..., 1)`: model's output densities
:param dists `Tensor(N...)`: integration times
    :param raw_noise_std `float`: the noise std used to regularize the network during training
        (prevents floater artifacts); defaults to 0, meaning no noise is added
:return `Tensor(N..., 1)`: energies which block light rays
"""
if raw_noise_std > 0:
# Add noise to model's predictions for density. Can be used to
# regularize network during training (prevents floater artifacts).
        densities = densities + torch.randn_like(densities) * raw_noise_std
return densities * dists[..., None]
def density2alpha(densities: torch.Tensor, dists: torch.Tensor, raw_noise_std: float = 0):
"""
Calculate alphas from densities inferred by model.
:param densities `Tensor(N..., 1)`: model's output densities
:param dists `Tensor(N...)`: integration times
    :param raw_noise_std `float`: the noise std used to regularize the network during training
        (prevents floater artifacts); defaults to 0, meaning no noise is added
:return `Tensor(N..., 1)`: alphas
"""
energies = density2energy(densities, dists, raw_noise_std)
return 1.0 - torch.exp(-energies)
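
# Worked example (sketch, not part of the original commit): the relation used
# above is alpha = 1 - exp(-sigma * dist), the probability that a ray is
# absorbed within a segment of length `dist` given density `sigma`.
def _density2alpha_example():
    densities = torch.full((2, 4, 1), 0.5)  # sigma = 0.5 for every sample
    dists = torch.full((2, 4), 0.1)         # segment length 0.1
    alphas = density2alpha(densities, dists)
    # energy = 0.5 * 0.1 = 0.05, so alpha = 1 - exp(-0.05) ~= 0.0488
    assert torch.allclose(alphas, 1 - torch.exp(torch.tensor(-0.05)))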
class AlphaComposition(nn.Module):
@@ -11,18 +47,26 @@ class AlphaComposition(nn.Module):
        super().__init__()

    def forward(self, colors, alphas, bg=None):
+        """
+        Composite the sampled colors along each ray using their alpha values.
+
+        :param colors `Tensor(N, P, C)`: sampled colors
+        :param alphas `Tensor(N, P, 1)`: sampled alphas
+        :param bg `Tensor([N, ]C)`: background color, defaults to None
+        :return `Tensor(N, C)`: composited colors
+        """
        # Compute weight for RGB of each sample along each ray. A cumprod() is
        # used to express the idea of the ray not having reflected up to this
        # sample yet.
-        one_minus_alpha = torch.cumprod(1 - alphas[..., :-1] + TINY_FLOAT, dim=-1)
+        one_minus_alpha = torch.cumprod(1 - alphas[..., :-1, :] + TINY_FLOAT, dim=-2)
        one_minus_alpha = torch.cat([
-            torch.ones_like(one_minus_alpha[..., 0:1]),
+            torch.ones_like(one_minus_alpha[..., :1, :]),
            one_minus_alpha
-        ], dim=-1)
+        ], dim=-2)
-        weights = alphas * one_minus_alpha  # (N_rays, N)
+        weights = alphas * one_minus_alpha  # (N, P, 1)
-        # (N_rays, 1|3), computed weighted color of each sample along each ray.
+        # (N, C), computed weighted color of each sample along each ray.
-        final_color = torch.sum(weights[..., None] * colors, dim=-2)
+        final_color = torch.sum(weights * colors, dim=-2)

        # To composite onto a white background, use the accumulated alpha map.
        if bg is not None:
@@ -38,58 +82,290 @@ class AlphaComposition(nn.Module):
class VolumnRenderer(nn.Module):

    def __init__(self, *, raw_noise_std=0.0, sigma_as_density=True):
        """
        Initialize a Rendering module
        """
        super().__init__()
        self.alpha_composition = AlphaComposition()
        self.sigma_as_density = sigma_as_density
        self.raw_noise_std = raw_noise_std

    def forward(self, colors, sigmas, z_vals, bg_color=None, ret_depth=False, debug=False):
        """Transforms model's predictions to semantically meaningful values.

        Args:
            colors: [num_rays, num_samples along ray, 1|3]. Predicted color from model.
            sigmas: [num_rays, num_samples along ray]. Predicted density from model.
            z_vals: [num_rays, num_samples along ray]. Integration time.

        Returns:
            rgb_map: [num_rays, 1|3]. Estimated RGB color of a ray.
            disp_map: [num_rays]. Disparity map. Inverse of depth map.
            acc_map: [num_rays]. Sum of weights along each ray.
            weights: [num_rays, num_samples]. Weights assigned to each sampled color.
            depth_map: [num_rays]. Estimated distance to object.
        """
        alphas = self.density2alpha(sigmas, z_vals) if self.sigma_as_density \
            else nn_f.sigmoid(sigmas)
        ret = self.alpha_composition(colors, alphas, bg_color)
        if ret_depth:
            ret['depth'] = torch.sum(ret['weights'] * z_vals, dim=-1)
        if debug:
            ret['layers'] = torch.cat([colors, alphas[..., None]], dim=-1)
        return ret

    def density2alpha(self, densities: torch.Tensor, z_vals: torch.Tensor):
        """
        Convert raw density values inferred from the model to alphas.

        :param densities `Tensor(N.rays, N.samples)`: model's output density
        :param z_vals `Tensor(N.rays, N.samples)`: integration time
        :return `Tensor(N.rays, N.samples)`: alpha
        """
        # Compute 'distance' (in time) between each integration time along a ray.
        # The 'distance' from the last integration time is infinity.
        # dists: (N_rays, N)
        dists = z_vals[..., 1:] - z_vals[..., :-1]
        last_dist = torch.zeros_like(z_vals[..., 0:1]) + TINY_FLOAT
        dists = torch.cat([dists, last_dist], -1)

        if self.raw_noise_std > 0.:
            # Add noise to model's predictions for density. Can be used to
            # regularize network during training (prevents floater artifacts).
            noise = torch.normal(0.0, self.raw_noise_std, densities.size())
            densities = densities + noise
        return -torch.exp(-torch.relu(densities) * dists) + 1.0


class VolumnRenderer(nn.Module):

    class States:
        kernel: nn.Module
        samples: Samples
        hit_mask: torch.Tensor
        early_stop_tolerance: float
        N: int
        P: int

        colors: torch.Tensor
        diffuses: torch.Tensor
        speculars: torch.Tensor
        energies: torch.Tensor
        weights: torch.Tensor
        cum_energies: torch.Tensor
        exp_energies: torch.Tensor
        tot_evaluations: Dict[str, int]

        chunk: Tuple[slice, slice]
        cum_chunk: Tuple[slice, slice]
        cum_last: Tuple[slice, slice]
        chunk_id: int

        @property
        def start(self) -> int:
            return self.chunk[1].start

        @property
        def end(self) -> int:
            return self.chunk[1].stop

        def __init__(self, kernel: nn.Module, samples: Samples, early_stop_tolerance: float) -> None:
            self.kernel = kernel
            self.samples = samples
            self.early_stop_tolerance = early_stop_tolerance

            N, P = samples.size
            self.hit_mask = samples.voxel_indices != -1  # (N, P)
            self.colors = torch.zeros(N, P, kernel.chns('color'), device=samples.device)
            self.diffuses = torch.zeros(N, P, kernel.chns('color'), device=samples.device)
            self.speculars = torch.zeros(N, P, kernel.chns('color'), device=samples.device)
            self.energies = torch.zeros(N, P, 1, device=samples.device)
            self.weights = torch.zeros(N, P, 1, device=samples.device)
            self.cum_energies = torch.zeros(N, P + 1, 1, device=samples.device)
            self.exp_energies = torch.ones(N, P + 1, 1, device=samples.device)
            self.tot_evaluations = {}
            self.N, self.P = N, P
            self.chunk_id = -1

        def n_hits(self, start: int = None, end: int = None) -> int:
            if start is None:
                return self.hit_mask.count_nonzero().item()
            if end is None:
                return self.hit_mask[:, start].count_nonzero().item()
            return self.hit_mask[:, start:end].count_nonzero().item()

        def accumulate_tot_evaluations(self, key: str, n: int):
            if key not in self.tot_evaluations:
                self.tot_evaluations[key] = 0
            self.tot_evaluations[key] += n

        def next_chunk(self, *, length=None, end=None):
            start = 0 if not hasattr(self, "chunk") else self.end
            length = length or self.P
            end = min(end or start + length, self.P)
            self.chunk = slice(None), slice(start, end)
            self.cum_chunk = slice(None), slice(start + 1, end + 1)
            self.cum_last = slice(None), slice(start, start + 1)
            self.chunk_id += 1
            return self

    def __init__(self, **kwargs):
        super().__init__()

    @perf
    def forward(self, kernel: nn.Module, samples: Samples, extra_outputs: List[str] = [], *,
                raymarching_early_stop_tolerance: float = 0,
                raymarching_chunk_size_or_sections: Union[int, List[int]] = None,
                **kwargs):
        """
        Perform volume rendering.

        :param kernel: render kernel
        :param samples `Samples(N, P)`: samples
        :param extra_outputs `list[str]`: extra items to include in the result dict.
            Optional values are 'depth', 'layers', 'states' and attribute names in
            class `States` (e.g. 'weights'). Defaults to []
        :param raymarching_early_stop_tolerance `float`: tolerance for raymarching early stop.
            Should be between 0 and 1 (0 means no early stop). Defaults to 0
        :param raymarching_chunk_size_or_sections `int|list[int]`: indicates how to split the
            raymarching process. Use a list of integers to specify the number of samples in
            every chunk, or a positive integer to specify the number of chunks. Use a negative
            integer to split by the number of hits in chunks, where the absolute value is the
            maximum number of hits in a chunk. 0 and `None` mean no splitting of the
            raymarching process. Defaults to `None`
        :return `dict`: render result { 'color'[, 'depth', 'layers', 'states', ...] }
        """
        if samples.size[1] == 0:
            print("VolumnRenderer.forward(): # of samples is zero")
            return None

        s = VolumnRenderer.States(kernel, samples, raymarching_early_stop_tolerance)

        if not raymarching_chunk_size_or_sections:
            raymarching_chunk_size_or_sections = [s.P]
        elif isinstance(raymarching_chunk_size_or_sections, int) and \
                raymarching_chunk_size_or_sections > 0:
            raymarching_chunk_size_or_sections = [ceil(s.P / raymarching_chunk_size_or_sections)]

        if isinstance(raymarching_chunk_size_or_sections, list):
            chunk_sections = raymarching_chunk_size_or_sections
            for chunk_samples in cycle(chunk_sections):
                self._forward_chunk(s.next_chunk(length=chunk_samples))
                if s.end >= s.P:
                    break
        else:
            chunk_size = -raymarching_chunk_size_or_sections
            chunk_hits = s.n_hits(0)
            for i in range(1, s.P):
                n_hits = s.n_hits(i)
                if chunk_hits + n_hits > chunk_size:
                    self._forward_chunk(s.next_chunk(end=i))
                    n_hits = s.n_hits(i)
                    chunk_hits = 0
                chunk_hits += n_hits
            self._forward_chunk(s.next_chunk())

        ret = {
            'color': torch.sum(s.colors * s.weights, 1),
            'tot_evaluations': s.tot_evaluations
        }
        for key in extra_outputs:
            if key == 'depth':
                ret['depth'] = torch.sum(s.samples.depths[..., None] * s.weights, 1)
            elif key == 'diffuse':
                ret['diffuse'] = torch.sum(s.diffuses * s.weights, 1)
            elif key == 'specular':
                ret['specular'] = torch.sum(s.speculars * s.weights, 1)
            elif key == 'layers':
                ret['layers'] = torch.cat([s.colors, 1 - torch.exp(-s.energies)], dim=-1)
            elif key == 'states':
                ret['states'] = s
            else:
                ret[key] = getattr(s, key)
        return ret

        # if raymarching_chunk_size == 0:
        #     raymarching_chunk_samples = 1
        # if raymarching_chunk_samples != 0:
        #     if isinstance(raymarching_chunk_samples, int):
        #         raymarching_chunk_samples = repeat(raymarching_chunk_samples,
        #                                            ceil(s.P / raymarching_chunk_samples))
        #     chunk_offset = 0
        #     for chunk_samples in raymarching_chunk_samples:
        #         start, end = chunk_offset, chunk_offset + chunk_samples
        #         n_hits = self._forward_chunk(s, start, end)
        #         if n_hits > 0 and tolerance > 0:  # Early stop
        #             s.hit_mask[s.cum_energies[:, end, 0] > tolerance] = 0
        #         chunk_offset += chunk_samples
        # elif raymarching_chunk_size > 0:
        #     chunk_offset, chunk_hits = 0, s.n_hits(0)
        #     for i in range(1, s.P):
        #         n_hits = s.n_hits(i)
        #         if chunk_hits + n_hits > raymarching_chunk_size:
        #             self._forward_chunk(s, chunk_offset, i, chunk_hits)
        #             if chunk_hits > 0 and tolerance > 0:  # Early stop
        #                 s.hit_mask[s.cum_energies[:, i, 0] > tolerance] = 0
        #             n_hits = s.n_hits(i)
        #             chunk_hits, chunk_offset = 0, i
        #         chunk_hits += n_hits
        #     self._forward_chunk(s, chunk_offset, s.P, chunk_hits)
        # else:
        #     self._forward_chunk(s, 0, s.P)
        # return self._composite(s, extra_outputs)
        # original_depth = samples.get('original_point_depth', None)
        # if original_depth is not None:
        #     results['z'] = (original_depth * probs).sum(-1)
        # if getattr(input_fn, "track_max_probs", False) and (not self.training):
        #     input_fn.track_voxel_probs(samples['sampled_point_voxel_idx'].long(), results['probs'])

    def _calc_weights(self, s: States):
        """
        Calculate weights of samples in composited outputs

        :param s `States`: states
        """
        s.cum_energies[s.cum_chunk] = torch.cumsum(s.energies[s.chunk], 1) \
            + s.cum_energies[s.cum_last]
        s.exp_energies[s.cum_chunk] = (-s.cum_energies[s.cum_chunk]).exp()
        s.weights[s.chunk] = s.exp_energies[s.chunk] - s.exp_energies[s.cum_chunk]

    def _apply_early_stop(self, s: States):
        """
        Stop rays whose accumulated opacity is larger than a threshold

        :param s `States`: states
        """
        if s.end < s.P and s.early_stop_tolerance > 0:
            rays_to_stop = s.exp_energies[:, s.end, 0] < s.early_stop_tolerance
            s.hit_mask[rays_to_stop, s.end:] = 0

    def _forward_chunk(self, s: States) -> int:
        fi_idxs: Tuple[torch.Tensor, ...] = s.hit_mask[s.chunk].nonzero(as_tuple=True)  # (N')
        fi_idxs[1].add_(s.start)
        if fi_idxs[0].size(0) == 0:
            s.cum_energies[s.cum_chunk] = s.cum_energies[s.cum_last]
            s.exp_energies[s.cum_chunk] = s.exp_energies[s.cum_last]
            return 0

        # fi_* means "filtered" by hit mask
        fi_samples = s.samples[fi_idxs]  # N -> N'

        # Infer densities and colors
        fi_outputs = s.kernel.render(fi_samples, 'color', 'density', 'specular', 'diffuse',
                                     chunk_id=s.chunk_id)
        s.colors.index_put_(fi_idxs, fi_outputs['color'])
        if fi_outputs['specular'] is not None:
            s.speculars.index_put_(fi_idxs, fi_outputs['specular'])
        if fi_outputs['diffuse'] is not None:
            s.diffuses.index_put_(fi_idxs, fi_outputs['diffuse'])
        s.energies.index_put_(fi_idxs, density2energy(fi_outputs['density'], fi_samples.dists))
        s.accumulate_tot_evaluations("color", fi_idxs[0].size(0))
        self._calc_weights(s)
        self._apply_early_stop(s)


class DensityFirstVolumnRenderer(VolumnRenderer):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _forward_chunk(self, s: VolumnRenderer.States) -> int:
        fi_idxs: Tuple[torch.Tensor, ...] = s.hit_mask[s.chunk].nonzero(as_tuple=True)  # (N')
        fi_idxs[1].add_(s.start)
        if fi_idxs[0].size(0) == 0:
            s.cum_energies[s.cum_chunk] = s.cum_energies[s.cum_last]
            s.exp_energies[s.cum_chunk] = s.exp_energies[s.cum_last]
            return 0

        # fi_* means "filtered" by hit mask
        fi_samples = s.samples[fi_idxs]  # N -> N'

        # For all valid samples: encode X
        fi_encoded_x = s.kernel.encode_x(fi_samples)  # (N', Ex)

        # Infer densities (shape)
        fi_outputs = s.kernel.infer(fi_encoded_x, None, 'density', 'color_feat',
                                    chunk_id=s.chunk_id)
        s.energies.index_put_(fi_idxs, density2energy(fi_outputs['density'], fi_samples.dists))
        s.accumulate_tot_evaluations("density", fi_idxs[0].size(0))
        self._calc_weights(s)
        self._apply_early_stop(s)

        # Remove samples whose weights are less than a threshold
        s.hit_mask[s.chunk][s.weights[s.chunk][..., 0] < 0.01] = 0
        # Update "filtered" tensors
        fi_mask = s.hit_mask[fi_idxs]
        fi_idxs = (fi_idxs[0][fi_mask], fi_idxs[1][fi_mask])  # N' -> N"
        fi_encoded_x = fi_encoded_x[fi_mask]  # (N", Ex)
        fi_color_feats = fi_outputs['color_feat'][fi_mask]

        # For all valid samples: encode D
        fi_encoded_d = s.kernel.encode_d(s.samples[fi_idxs])  # (N", Ed)

        # Infer colors (appearance)
        fi_outputs = s.kernel.infer(fi_encoded_x, fi_encoded_d, 'color', 'specular', 'diffuse',
                                    chunk_id=s.chunk_id,
                                    extras={"color_feats": fi_color_feats})
        # if s.chunk_id == 0:
        #     fi_colors[:] *= fi_colors.new_tensor([1, 0, 0])
        # elif s.chunk_id == 1:
        #     fi_colors[:] *= fi_colors.new_tensor([0, 1, 0])
        # elif s.chunk_id == 2:
        #     fi_colors[:] *= fi_colors.new_tensor([0, 0, 1])
        # else:
        #     fi_colors[:] *= fi_colors.new_tensor([1, 1, 0])
        s.colors.index_put_(fi_idxs, fi_outputs['color'])
        if fi_outputs['specular'] is not None:
            s.speculars.index_put_(fi_idxs, fi_outputs['specular'])
        if fi_outputs['diffuse'] is not None:
            s.diffuses.index_put_(fi_idxs, fi_outputs['diffuse'])
        s.accumulate_tot_evaluations("color", fi_idxs[0].size(0))
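
# Usage sketch (not part of the original commit): a toy kernel with just
# enough surface (`chns` / `render`) to drive the renderer above; a real
# kernel would evaluate its networks here. All names and sizes below are
# illustrative assumptions.
def _volumn_renderer_usage_sketch():
    class _ToyKernel(nn.Module):
        def chns(self, _):
            return 3

        def render(self, samples, *outputs, chunk_id=None):
            n = samples.pts.size(0)
            return {'color': torch.rand(n, 3), 'density': torch.rand(n, 1),
                    'specular': None, 'diffuse': None}

    N, P = 8, 64
    samples = Samples(pts=torch.rand(N, P, 3),
                      dirs=torch.rand(N, P, 3),
                      depths=torch.rand(N, P),
                      dists=torch.full((N, P), 0.1),
                      voxel_indices=torch.zeros(N, P, dtype=torch.long))
    renderer = VolumnRenderer()
    # Split raymarching into chunks of 16 samples; see the docstring of
    # `forward` for the other accepted forms (positive int, negative int, None).
    ret = renderer(_ToyKernel(), samples, ['depth'],
                   raymarching_chunk_size_or_sections=[16])
    assert ret['color'].shape == (N, 3) and ret['depth'].shape == (N, 1)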
-from typing import Tuple
+from .space import Space, Voxels
import torch
import torch.nn as nn
+from typing import Tuple
from utils import device
from utils import sphere
from utils.constants import *
+from utils.perf import perf, checkpoint
from .generic import *
+from clib import *
class Bins(object):

+    @property
+    def up(self):
+        return self.bounds[1:]
+
+    @property
+    def lo(self):
+        return self.bounds[:-1]
+
    def __init__(self, vals: torch.Tensor):
        self.vals = vals
        self.bounds = torch.cat([
@@ -16,8 +28,6 @@ class Bins(object):
            0.5 * (self.vals[1:] + self.vals[:-1]),
            self.vals[-1:]
        ])
-        self.up = self.bounds[1:]
-        self.lo = self.bounds[:-1]

    @staticmethod
    def linspace(val_range: Tuple[float, float], N: int, device: torch.device = None):
@@ -26,14 +36,60 @@ class Bins(object):
    def to(self, device: torch.device):
        self.vals = self.vals.to(device)
        self.bounds = self.bounds.to(device)
-        self.up = self.bounds[1:]
-        self.lo = self.bounds[:-1]
class Samples:
pts: torch.Tensor
"""`Tensor(N[, P], 3)`"""
dirs: torch.Tensor
"""`Tensor(N[, P], 3)`"""
depths: torch.Tensor
"""`Tensor(N[, P])`"""
dists: torch.Tensor
"""`Tensor(N[, P])`"""
voxel_indices: torch.Tensor
"""`Tensor(N[, P])`"""
@property
def size(self):
return self.pts.size()[:-1]
@property
def device(self):
return self.pts.device
def __init__(self, pts: torch.Tensor, dirs: torch.Tensor, depths: torch.Tensor,
dists: torch.Tensor, voxel_indices: torch.Tensor) -> None:
self.pts = pts
self.dirs = dirs
self.depths = depths
self.dists = dists
self.voxel_indices = voxel_indices
def __getitem__(self, index):
return Samples(
pts=self.pts[index],
dirs=self.dirs[index],
depths=self.depths[index],
dists=self.dists[index],
voxel_indices=self.voxel_indices[index])
def reshape(self, *shape: int):
return Samples(
pts=self.pts.reshape(*shape, 3),
dirs=self.dirs.reshape(*shape, 3),
depths=self.depths.reshape(*shape),
dists=self.dists.reshape(*shape),
voxel_indices=self.voxel_indices.reshape(*shape))
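
# Usage sketch (not part of the original commit): Samples supports tensor-style
# indexing; __getitem__ applies the same index to every field, so boolean masks
# and (ray, point) index tuples both work.
def _samples_indexing_sketch():
    s = Samples(pts=torch.rand(8, 64, 3), dirs=torch.rand(8, 64, 3),
                depths=torch.rand(8, 64), dists=torch.rand(8, 64),
                voxel_indices=torch.zeros(8, 64, dtype=torch.long))
    assert s[torch.tensor([0, 2])].size == (2, 64)  # select rays
    assert s.reshape(-1).size == (8 * 64,)          # flatten rays and points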
class Sampler(nn.Module):

-    def __init__(self, *, sample_range: Tuple[float, float], n_samples: int,
-                 perturb_sample: bool, spherical: bool, lindisp: bool):
+    def __init__(self, *, sample_range: Tuple[float, float], n_samples: int, lindisp: bool, **kwargs):
        """
        Initialize a Sampler module
@@ -44,37 +100,81 @@ class Sampler(nn.Module):
        """
        super().__init__()
        self.lindisp = lindisp
-        self.spherical = spherical
-        self.perturb_sample = perturb_sample
-        s_range = (1 / sample_range[0], 1 / sample_range[1]) if self.lindisp else sample_range
+        s_range = [1 / sample_range[0], 1 / sample_range[1]] if self.lindisp else list(sample_range)
+        if s_range[1] > s_range[0]:
+            s_range[0] += 1e-4
+            s_range[1] -= 1e-4
+        else:
+            s_range[0] -= 1e-4
+            s_range[1] += 1e-4
        self.bins = Bins.linspace(s_range, n_samples, device=device.default())

+    @perf
-    def forward(self, rays_o, rays_d):
+    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, space_module: Space,
+                perturb_sample: bool, **kwargs) -> Tuple[Samples, torch.Tensor]:
        """
        Sample points along rays. Return spherical or Cartesian coordinates,
        specified by `self.spherical`

-        :param rays_o `Tensor(B, 3)`: rays' origin
-        :param rays_d `Tensor(B, 3)`: rays' direction
-        :return `Tensor(B, N, 3)`: sampled points
-        :return `Tensor(B, N)`: corresponding depths along rays
+        :param rays_o `Tensor(N, 3)`: rays' origin
+        :param rays_d `Tensor(N, 3)`: rays' direction
+        :return `Samples(N, P)`: samples
+        :return `Tensor(N)`: valid rays mask
        """
        s = self.bins.vals.expand(rays_o.size(0), -1)
-        if self.perturb_sample:
+        if perturb_sample:
            s = self.bins.lo + (self.bins.up - self.bins.lo) * torch.rand_like(s)
+        pts, depths = self._get_sample_points(rays_o, rays_d, s)
+        voxel_indices = space_module.get_voxel_indices(pts)
+        valid_rays_mask = voxel_indices.ne(-1).any(dim=-1)
+        return Samples(
+            pts=pts,
+            dirs=rays_d[:, None].expand(-1, depths.size(1), -1),
+            depths=depths,
+            dists=self._calc_dists(depths),
+            voxel_indices=voxel_indices
+        )[valid_rays_mask], valid_rays_mask

+    def _get_sample_points(self, rays_o, rays_d, s):
        z = torch.reciprocal(s) if self.lindisp else s
-        if self.spherical:
-            pts, depths = sphere.ray_sphere_intersect(rays_o, rays_d, z)
-            sphers = sphere.cartesian2spherical(pts, inverse_r=self.lindisp)
-            return sphers, depths, s, pts
-        else:
-            return rays_o[..., None, :] + rays_d[..., None, :] * z[..., None], z, s, None
+        pts = rays_o[:, None] + rays_d[:, None] * z[..., None]
+        depths = z
+        return pts, depths

+    def _calc_dists(self, vals):
+        # Compute 'distance' (in time) between each integration time along a ray.
+        # The 'distance' from the last integration time is infinity.
+        # dists: (N_rays, N)
+        dists = vals[..., 1:] - vals[..., :-1]
+        last_dist = torch.zeros_like(vals[..., :1]) + TINY_FLOAT
+        return torch.cat([dists, last_dist], -1)


class SphericalSampler(Sampler):

    def __init__(self, *, sample_range: Tuple[float, float], n_samples: int,
                 perturb_sample: bool, **kwargs):
        """
        Initialize a SphericalSampler module

        :param sample_range: sample range
        :param n_samples: count to sample along ray
        :param perturb_sample: perturb the sample depths
        """
        super().__init__(sample_range=sample_range, n_samples=n_samples,
                         perturb_sample=perturb_sample, lindisp=False)

    def _get_sample_points(self, rays_o, rays_d, s):
        r = torch.reciprocal(s)
        pts, depths = sphere.ray_sphere_intersect(rays_o, rays_d, r)
        pts = sphere.cartesian2spherical(pts, inverse_r=True)
        return pts, depths
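
# Usage sketch (not part of the original commit): sampling against the base
# Space, which marks every point as valid (voxel index 0) when no bbox is
# given. Assumes utils.device.default() matches the rays' device.
def _sampler_usage_sketch():
    sampler = Sampler(sample_range=(1.0, 10.0), n_samples=64, lindisp=False)
    rays_o = torch.zeros(8, 3)
    rays_d = torch.rand(8, 3)
    rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True)
    samples, valid_mask = sampler(rays_o, rays_d, Space(), perturb_sample=True)
    assert samples.size == (int(valid_mask.sum()), 64)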
class PdfSampler(nn.Module):

    def __init__(self, *, depth_range: Tuple[float, float], n_samples: int, perturb_sample: bool,
-                 spherical: bool, lindisp: bool):
+                 spherical: bool, lindisp: bool, **kwargs):
        """
        Initialize a Sampler module
@@ -90,7 +190,7 @@ class PdfSampler(nn.Module):
        self.n_samples = n_samples
        self.s_range = (1 / depth_range[0], 1 / depth_range[1]) if self.lindisp else depth_range

-    def forward(self, rays_o, rays_d, *, weights, s_vals=None, include_s_vals=False):
+    def forward(self, rays_o, rays_d, *, weights, s_vals=None, include_s_vals=False, **kwargs):
        """
        Sample points along rays. Return spherical or Cartesian coordinates,
        specified by `self.spherical`
@@ -166,22 +266,116 @@ class PdfSampler(nn.Module):
class VoxelSampler(nn.Module):

-    def __init__(self, *, depth_range: Tuple[float, float], n_samples: int, perturb_sample: bool,
-                 lindisp: bool, space):
+    def __init__(self, *, perturb_sample: bool, sample_step: float, **kwargs):
        """
-        Initialize a Sampler module
+        Initialize a VoxelSampler module
-        :param depth_range: depth range for sampler
-        :param n_samples: count to sample along ray
        :param perturb_sample: perturb the sample depths
-        :param lindisp: If True, sample linearly in inverse depth rather than in depth
+        :param sample_step: gap between samples along a ray
        """
        super().__init__()
-        self.lindisp = lindisp
        self.perturb_sample = perturb_sample
-        self.n_samples = n_samples
-        self.space = space
-        self.s_range = (1 / depth_range[0], 1 / depth_range[1]) if self.lindisp else depth_range
+        self.sample_step = sample_step

-    def forward(self, rays_o, rays_d, *, weights, s_vals=None, include_s_vals=False):
+    def _forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, space_module: Space,
+                 **kwargs) -> Tuple[Samples, torch.Tensor]:
+        """
+        Sample along rays by marching through the intersected voxels with a fixed step.
+
+        :param rays_o `Tensor(N, 3)`: rays' origin positions
+        :param rays_d `Tensor(N, 3)`: rays' directions
+        :return `Samples(N', P)`: samples along valid rays (which hit at least one voxel)
+        :return `Tensor(N)`: valid rays mask
+        """
+        intersections = space_module.ray_intersect(rays_o, rays_d, 100)
+        valid_rays_mask = intersections.hits > 0
+        rays_o = rays_o[valid_rays_mask]
+        rays_d = rays_d[valid_rays_mask]
+        intersections = intersections[valid_rays_mask]  # (N) -> (N')

+        n_rays = rays_o.size(0)
+        ray_index_list = torch.arange(n_rays, device=rays_o.device, dtype=torch.long)  # (N')

+        hits = intersections.hits
+        min_depths = intersections.min_depths
+        max_depths = intersections.max_depths
+        voxel_indices = intersections.voxel_indices

+        rays_near_depth = min_depths[:, :1]  # (N', 1)
+        rays_far_depth = max_depths[ray_index_list, hits - 1][:, None]  # (N', 1)
+        rays_length = rays_far_depth - rays_near_depth
+        rays_steps = (rays_length / self.sample_step).ceil().long()
+        rays_step_size = rays_length / rays_steps
+        max_steps = rays_steps.max().item()

+        rays_step = torch.arange(max_steps, device=rays_o.device,
+                                 dtype=torch.float)[None].repeat(n_rays, 1)  # (N', P)
+        invalid_samples_mask = rays_step >= rays_steps
+        samples_min_depth = rays_near_depth + rays_step * rays_step_size
+        samples_depth = samples_min_depth + rays_step_size \
+            * (torch.rand_like(samples_min_depth) if self.perturb_sample else 0.5)  # (N', P)
+        samples_dist = rays_step_size.repeat(1, max_steps)  # (N', 1) -> (N', P)
+        samples_voxel_index = voxel_indices[
+            ray_index_list[:, None],
+            torch.searchsorted(max_depths, samples_depth)
+        ]  # (N', P)
+        samples_depth[invalid_samples_mask] = HUGE_FLOAT
+        samples_dist[invalid_samples_mask] = 0
+        samples_voxel_index[invalid_samples_mask] = -1

+        rays_o, rays_d = rays_o[:, None], rays_d[:, None]
+        return Samples(
+            pts=rays_o + rays_d * samples_depth[..., None],
+            dirs=rays_d.expand(-1, max_steps, -1),
+            depths=samples_depth,
+            dists=samples_dist,
+            voxel_indices=samples_voxel_index
+        ), valid_rays_mask

+    @perf
+    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, space_module: Space,
+                **kwargs) -> Tuple[Samples, torch.Tensor]:
+        """
+        Sample along rays using inverse CDF sampling over the intersected voxels.
+
+        :param rays_o `Tensor(N, 3)`: rays' origin positions
+        :param rays_d `Tensor(N, 3)`: rays' directions
+        :return `Samples(N', P)`: samples along valid rays (which hit at least one voxel)
+        :return `Tensor(N)`: valid rays mask
+        """
+        intersections = space_module.ray_intersect(rays_o, rays_d, 100)
+        valid_rays_mask = intersections.hits > 0
+        rays_o = rays_o[valid_rays_mask]
+        rays_d = rays_d[valid_rays_mask]
+        intersections = intersections[valid_rays_mask]  # (N) -> (N')
+        checkpoint("Ray intersect")

+        if intersections.size == 0:
+            return None, valid_rays_mask
+        else:
+            min_depth = intersections.min_depths
+            max_depth = intersections.max_depths
+            pts_idx = intersections.voxel_indices
+            dists = max_depth - min_depth
+            tot_dists = dists.sum(dim=-1, keepdim=True)  # (N, 1)
+            probs = dists / tot_dists
+            steps = tot_dists[:, 0] / self.sample_step

+            # sample points and use middle point approximation
+            sampled_indices, sampled_depths, sampled_dists = inverse_cdf_sampling(
+                pts_idx, min_depth, max_depth, probs, steps, -1, not self.perturb_sample)
+            sampled_indices = sampled_indices.long()
+            invalid_idx_mask = sampled_indices.eq(-1)
+            sampled_dists.clamp_min_(0).masked_fill_(invalid_idx_mask, 0)
+            sampled_depths.masked_fill_(invalid_idx_mask, HUGE_FLOAT)
+            checkpoint("Inverse CDF sampling")

+            rays_o, rays_d = rays_o[:, None], rays_d[:, None]
+            return Samples(
+                pts=rays_o + rays_d * sampled_depths[..., None],
+                dirs=rays_d.expand(-1, sampled_depths.size(1), -1),
+                depths=sampled_depths,
+                dists=sampled_dists,
+                voxel_indices=sampled_indices
+            ), valid_rays_mask
from math import ceil
import torch
import numpy as np
from typing import List, NoReturn, Tuple, Union
from torch import nn
from plyfile import PlyData, PlyElement
from utils.geometry import *
from utils.constants import *
from utils.voxels import *
from utils.perf import perf
from clib import *
class Intersections:
min_depths: torch.Tensor
"""`Tensor(N, P)` Min ray depths of intersected voxels"""
max_depths: torch.Tensor
"""`Tensor(N, P)` Max ray depths of intersected voxels"""
voxel_indices: torch.Tensor
"""`Tensor(N, P)` Indices of intersected voxels"""
hits: torch.Tensor
"""`Tensor(N)` Number of hits"""
@property
def size(self):
return self.hits.size(0)
def __init__(self, min_depths: torch.Tensor, max_depths: torch.Tensor,
voxel_indices: torch.Tensor, hits: torch.Tensor) -> None:
self.min_depths = min_depths
self.max_depths = max_depths
self.voxel_indices = voxel_indices
self.hits = hits
def __getitem__(self, index):
return Intersections(
min_depths=self.min_depths[index],
max_depths=self.max_depths[index],
voxel_indices=self.voxel_indices[index],
hits=self.hits[index])
class Space(nn.Module):
bbox: Union[torch.Tensor, None]
"""`Tensor(2, 3)` Bounding box"""
def __init__(self, *, bbox: List[float] = None, **kwargs):
super().__init__()
if bbox is None:
self.bbox = None
else:
self.register_buffer('bbox', torch.Tensor(bbox).reshape(2, 3), persistent=False)
def create_embedding(self, n_dims: int, name: str = 'default') -> torch.nn.Embedding:
raise NotImplementedError
def extract_embedding(self, pts: torch.Tensor, voxel_indices: torch.Tensor,
name: str = 'default') -> torch.Tensor:
raise NotImplementedError
def ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> Intersections:
raise NotImplementedError
def get_voxel_indices(self, pts: torch.Tensor) -> torch.Tensor:
voxel_indices = torch.zeros_like(pts[..., 0], dtype=torch.long)
if self.bbox is not None:
out_bbox = torch.logical_or(pts < self.bbox[0], pts >= self.bbox[1]).any(-1) # (N...)
voxel_indices[out_bbox] = -1
return voxel_indices
@torch.no_grad()
def pruning(self, score_fn, threshold: float = 0.5, train_stats=False):
raise NotImplementedError()
@torch.no_grad()
def splitting(self):
raise NotImplementedError()
class Voxels(Space):
steps: torch.Tensor
"""`Tensor(3)` Steps along each dimension"""
corners: torch.Tensor
"""`Tensor(C, 3)` Corner positions"""
voxels: torch.Tensor
"""`Tensor(M, 3)` Voxel centers"""
corner_indices: torch.Tensor
"""`Tensor(M, 8)` Voxel corner indices"""
voxel_indices_in_grid: torch.Tensor
"""`Tensor(G)` Indices in voxel list or -1 for pruned space"""
@property
def dims(self) -> int:
"""`int` Number of dimensions"""
return self.steps.size(0)
@property
def n_voxels(self) -> int:
"""`int` Number of voxels"""
return self.voxels.size(0)
@property
def n_corner(self) -> int:
"""`int` Number of corners"""
return self.corners.size(0)
@property
def voxel_size(self) -> torch.Tensor:
"""`Tensor(3)` Voxel size"""
return (self.bbox[1] - self.bbox[0]) / self.steps
@property
def device(self) -> torch.device:
return self.voxels.device
def __init__(self, *, voxel_size: float = None,
steps: Union[torch.Tensor, Tuple[int, int, int]] = None, **kwargs) -> None:
super().__init__(**kwargs)
if self.bbox is None:
raise ValueError("Missing argument 'bbox'")
if voxel_size is not None:
self.register_buffer('steps', get_grid_steps(self.bbox, voxel_size))
else:
self.register_buffer('steps', torch.tensor(steps, dtype=torch.long))
self.register_buffer('voxels', init_voxels(self.bbox, self.steps))
corners, corner_indices = get_corners(self.voxels, self.bbox, self.steps)
self.register_buffer("corners", corners)
self.register_buffer("corner_indices", corner_indices)
self.register_buffer('voxel_indices_in_grid', torch.arange(self.n_voxels))
self._register_load_state_dict_pre_hook(self._before_load_state_dict)
def create_embedding(self, n_dims: int, name: str = 'default') -> torch.nn.Embedding:
"""
        Create an embedding on voxel corners.

        :param n_dims `int`: embedding dimension
        :param name `str`: embedding name
:return `Embedding(n_corners, n_dims)`: new embedding on voxel corners
"""
name = f'emb_{name}'
        self.add_module(name, torch.nn.Embedding(self.n_corner, n_dims))
return self.__getattr__(name)
def get_embedding(self, name: str = 'default') -> torch.nn.Embedding:
return getattr(self, f'emb_{name}')
def extract_embedding(self, pts: torch.Tensor, voxel_indices: torch.Tensor,
name: str = 'default') -> torch.Tensor:
"""
Extract embedding values at given points using trilinear interpolation.
:param pts `Tensor(N, 3)`: points to extract values
:param voxel_indices `Tensor(N)`: corresponding voxel indices
:param name `str`: embedding name, default to 'default'
:return `Tensor(N, X)`: extracted values
"""
emb = self.get_embedding(name)
if emb is None:
raise KeyError(f"Embedding '{name}' doesn't exist")
voxels = self.voxels[voxel_indices] # (N, 3)
corner_indices = self.corner_indices[voxel_indices] # (N, 8)
p = (pts - voxels) / self.voxel_size + 0.5 # (N, 3) normed-coords in voxel
features = emb(corner_indices).reshape(pts.size(0), 8, -1) # (N, 8, X)
return trilinear_interp(p, features)
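
    # Worked note (sketch, not in the original commit): `p` above is the point's
    # offset within its voxel normalized to [0, 1]^3, and trilinear_interp (from
    # utils.voxels) blends the 8 corner features, weighting each corner by the
    # product over axes of (p_k or 1 - p_k); at the voxel center
    # p = (0.5, 0.5, 0.5), every corner contributes exactly 1/8.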
@perf
def ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> Intersections:
"""
Calculate intersections of rays and voxels.
:param rays_o `Tensor(N, 3)`: rays' origin
:param rays_d `Tensor(N, 3)`: rays' direction
:param n_max_hits `int`: maximum number of hits (for allocating enough space)
:return `Intersection`: intersections of rays and voxels
"""
# Prepend a dim to meet the requirement of external call
rays_o = rays_o[None].contiguous()
rays_d = rays_d[None].contiguous()
voxel_indices, min_depths, max_depths = self._ray_intersect(rays_o, rays_d, n_max_hits)
invalid_voxel_mask = voxel_indices.eq(-1)
hits = n_max_hits - invalid_voxel_mask.sum(-1)
# Sort intersections according to their depths
min_depths.masked_fill_(invalid_voxel_mask, HUGE_FLOAT)
max_depths.masked_fill_(invalid_voxel_mask, HUGE_FLOAT)
min_depths, sorted_idx = min_depths.sort(dim=-1)
max_depths = max_depths.gather(-1, sorted_idx)
voxel_indices = voxel_indices.gather(-1, sorted_idx)
return Intersections(
min_depths=min_depths[0],
max_depths=max_depths[0],
voxel_indices=voxel_indices[0],
hits=hits[0]
)
@perf
def get_voxel_indices(self, pts: torch.Tensor) -> torch.Tensor:
"""
Get voxel indices of points.
If a point is not in any valid voxels, its corresponding voxel index is -1.
:param pts `Tensor(N..., 3)`: points
:return `Tensor(N...)`: corresponding voxel indices
"""
grid_indices, out_mask = to_grid_indices(pts, self.bbox, steps=self.steps)
grid_indices[out_mask] = 0
voxel_indices = self.voxel_indices_in_grid[grid_indices]
voxel_indices[out_mask] = -1
return voxel_indices
@torch.no_grad()
    def splitting(self) -> Tuple[int, int]:
"""
Split voxels into smaller voxels with half size.
"""
n_voxels_before = self.n_voxels
self.steps *= 2
self.voxels = split_voxels(self.voxels, self.voxel_size, 2, align_border=False)\
.reshape(-1, 3)
self._update_corners()
self._update_voxel_indices_in_grid()
return n_voxels_before, self.n_voxels
@torch.no_grad()
def prune(self, keeps: torch.Tensor) -> Tuple[int, int]:
self.voxels = self.voxels[keeps]
self.corner_indices = self.corner_indices[keeps]
self._update_voxel_indices_in_grid()
return keeps.size(0), keeps.sum().item()
@torch.no_grad()
    def pruning(self, score_fn, threshold: float = 0.5) -> Tuple[int, int]:
scores = self._get_scores(score_fn, lambda x: torch.max(x, -1)[0]) # (M)
return self.prune(scores > threshold)
def n_voxels_along_dim(self, dim: int) -> torch.Tensor:
sum_dims = [val for val in range(self.dims) if val != dim]
return self.voxel_indices_in_grid.reshape(*self.steps).ne(-1).sum(sum_dims)
def balance_cut(self, dim: int, n_parts: int) -> List[int]:
n_voxels_list = self.n_voxels_along_dim(dim)
cdf = (n_voxels_list.cumsum(0) / self.n_voxels * n_parts).tolist()
bins = []
part = 1
offset = 0
for i in range(len(cdf)):
if cdf[i] >= part:
bins.append(i + 1 - offset)
offset = i + 1
part = int(cdf[i]) + 1
return bins
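
    # Worked example (not in the original commit): with per-slice voxel counts
    # [4, 4, 0, 4, 4] along `dim` and n_parts=2, the scaled CDF is
    # [0.5, 1.0, 1.0, 1.5, 2.0], so balance_cut returns [2, 3]: slices 0-1 and
    # slices 2-4 each contain half of the voxels.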
def sample(self, bits: int, perturb: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
sampled_xyz = split_voxels(self.voxels, self.voxel_size, bits)
sampled_idx = torch.arange(self.n_voxels, device=self.device)[:, None].expand(
*sampled_xyz.shape[:2])
        sampled_xyz, sampled_idx = sampled_xyz.reshape(-1, 3), sampled_idx.flatten()
        return sampled_xyz, sampled_idx
@torch.no_grad()
def _get_scores(self, score_fn, reduce_fn=None, bits=16) -> torch.Tensor:
def get_scores_once(pts, idxs):
scores = score_fn(pts, idxs).reshape(-1, bits ** 3) # (B, P)
if reduce_fn is not None:
scores = reduce_fn(scores) # (B[, ...])
return scores
sampled_xyz, sampled_idx = self.sample(bits)
chunk_size = 64
return torch.cat([
get_scores_once(sampled_xyz[i:i + chunk_size], sampled_idx[i:i + chunk_size])
for i in range(0, self.voxels.size(0), chunk_size)
], 0) # (M[, ...])
def _ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return aabb_ray_intersect(self.voxel_size, n_max_hits, self.voxels, rays_o, rays_d)
def _update_corners(self):
"""
Update voxel corners.
"""
corners, corner_indices = get_corners(self.voxels, self.bbox, self.steps)
self.register_buffer("corners", corners)
self.register_buffer("corner_indices", corner_indices)
def _update_voxel_indices_in_grid(self):
"""
Update voxel indices in grid.
"""
grid_indices, _ = to_grid_indices(self.voxels, self.bbox, steps=self.steps)
self.voxel_indices_in_grid = grid_indices.new_full([self.steps.prod().item()], -1)
self.voxel_indices_in_grid[grid_indices] = torch.arange(self.n_voxels, device=self.device)
@torch.no_grad()
def _before_load_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys,
unexpected_keys, error_msgs):
# Handle buffers
for name, buffer in self.named_buffers(recurse=False):
if name in self._non_persistent_buffers_set:
continue
buffer.resize_as_(state_dict[prefix + name])
# Handle embeddings
for name, module in self.named_modules():
if name.startswith('emb_'):
                setattr(self, name, torch.nn.Embedding(self.n_corner, module.embedding_dim))
class Octree(Voxels):
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
self.nodes_cached = None
self.tree_cached = None
def get(self) -> Tuple[torch.Tensor, torch.Tensor]:
if self.nodes_cached is None:
self.nodes_cached, self.tree_cached = build_easy_octree(
self.voxels, 0.5 * self.voxel_size)
return self.nodes_cached, self.tree_cached
def clear(self):
self.nodes_cached = None
self.tree_cached = None
def _ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int):
nodes, tree = self.get()
return octree_ray_intersect(self.voxel_size, n_max_hits, nodes, tree, rays_o, rays_d)
@torch.no_grad()
def splitting(self):
ret = super().splitting()
self.clear()
return ret
@torch.no_grad()
def prune(self, keeps: torch.Tensor) -> Tuple[int, int]:
ret = super().prune(keeps)
self.clear()
return ret
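
# Usage sketch (not part of the original commit): build a small dense grid.
# `bbox` holds the min and max corners; `steps` is the voxel count per axis.
# Assumes init_voxels (from utils.voxels) fills the whole grid.
def _voxels_usage_sketch():
    space = Voxels(bbox=[-1, -1, -1, 1, 1, 1], steps=(4, 4, 4))
    assert space.n_voxels == 4 * 4 * 4
    # Points at the origin fall inside the bbox, so they get valid indices
    assert space.get_voxel_indices(torch.zeros(5, 3)).ne(-1).all()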
nerf++ @ a30f1a5a
Subproject commit a30f1a5ad116e43aad90c426a966b2a3fcedaf7e
import torch
import torch.nn as nn
from modules import *
from utils import color
class Nerf(nn.Module):
def __init__(self, fc_params, sampler_params, *,
c: int = color.RGB,
n_pos_encode: int = 0,
n_dir_encode: int = None,
coarse_net=None, **kwargs):
"""
Initialize a NeRF unit
:param fc_params `dict`: parameters for full-connection network
:param sampler_params `dict`: parameters for sampler
:param c `int`: color mode
:param n_pos_encode `int`: encode position to number of dimensions
:param n_dir_encode `int`: encode direction to number of dimensions, `None` means direction is ignored
:param coarse_net `NerfUnit`: optional coarse net
"""
super().__init__()
self.coarse_net = coarse_net
self.color = c
self.coord_chns = 3
self.color_chns = color.chns(self.color)
self.pos_encoder = InputEncoder.Get(n_pos_encode, self.coord_chns)
if n_dir_encode is not None:
self.dir_chns = 3
self.dir_encoder = InputEncoder.Get(n_dir_encode, self.dir_chns)
else:
self.dir_chns = 0
self.dir_encoder = None
self.core = NerfCore(coord_chns=self.pos_encoder.out_dim,
density_chns=1,
color_chns=self.color_chns,
core_nf=fc_params['nf'],
core_layers=fc_params['n_layers'],
dir_chns=self.dir_encoder.out_dim if self.dir_encoder else 0,
dir_nf=fc_params['nf'] // 2,
activation=fc_params['activation'],
skips=fc_params['skips'])
sampler_params['spherical'] = False
self.sampler = PdfSampler(**sampler_params) if self.coarse_net is not None \
else Sampler(**sampler_params)
self.rendering = VolumnRenderer()
def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, *,
ret_depth=False, debug=False) -> torch.Tensor:
"""
rays -> colors
:param rays_o `Tensor(B, 3)`: rays' origin
:param rays_d `Tensor(B, 3)`: rays' direction
        :param ret_depth `bool`: whether to also return per-ray depths
        :return: `Tensor(B, C)`, inferred images/pixels
"""
if self.coarse_net is not None:
coarse_ret = self.coarse_net(rays_o, rays_d, ret_depth=ret_depth, debug=debug)
            coords, depths, s_vals, _ = self.sampler(rays_o, rays_d, s_vals=coarse_ret['sample'],
                                                     weights=coarse_ret['weight'])
else:
coords, depths, s_vals, _ = self.sampler(rays_o, rays_d)
coords_encoded = self.pos_encoder(coords)
dirs_encoded = self.dir_encoder(rays_d)[:, None].expand(-1, s_vals.size(-1), -1) \
if self.dir_encoder is not None else None
colors, densities = self.core(coords_encoded, dirs_encoded)
ret = self.rendering(colors, densities[..., 0], depths, ret_depth=ret_depth, debug=debug)
ret['sample'] = s_vals
if self.coarse_net is not None:
ret['coarse'] = coarse_ret
return ret
import torch
import torch.nn as nn
from modules import *
from utils import color
class NSVF(nn.Module):
def __init__(self, fc_params, sampler_params, *,
c: int = color.RGB,
n_featdim: int = 32,
n_pos_encode: int = 0,
n_dir_encode: int = None,
**kwargs):
"""
Initialize a NSVF model
:param fc_params `dict`: parameters for full-connection network
:param sampler_params `dict`: parameters for sampler
:param c `int`: color mode
:param n_pos_encode `int`: encode position to number of dimensions
:param n_dir_encode `int`: encode direction to number of dimensions, `None` means direction is ignored
        :param n_featdim `int`: dimension of the voxel feature embedding
"""
super().__init__()
self.color = c
self.coord_chns = n_featdim
self.color_chns = color.chns(self.color)
self.pos_encoder = InputEncoder.Get(n_pos_encode, self.coord_chns)
if n_dir_encode is not None:
self.dir_chns = 3
self.dir_encoder = InputEncoder.Get(n_dir_encode, self.dir_chns)
else:
self.dir_chns = 0
self.dir_encoder = None
self.core = NerfCore(coord_chns=self.pos_encoder.out_dim,
density_chns=1,
color_chns=self.color_chns,
core_nf=fc_params['nf'],
core_layers=fc_params['n_layers'],
dir_chns=self.dir_encoder.out_dim if self.dir_encoder else 0,
dir_nf=fc_params['nf'] // 2,
activation=fc_params['activation'],
skips=fc_params['skips'])
self.space = OctTreeSpace()
sampler_params['space'] = self.space
self.sampler = VoxelSampler(**sampler_params)
self.rendering = VolumnRenderer()
def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, *,
ret_depth=False, debug=False) -> torch.Tensor:
"""
rays -> colors
:param rays_o `Tensor(B, 3)`: rays' origin
:param rays_d `Tensor(B, 3)`: rays' direction
        :param ret_depth `bool`: whether to also return per-ray depths
        :return: `Tensor(B, C)`, inferred images/pixels
"""
feats, dirs, z_s, dz_s = self.sampler(rays_o, rays_d)
feats_encoded = self.pos_encoder(feats)
dirs_encoded = self.dir_encoder(rays_d)[:, None].expand(-1, z_s.size(-1), -1) \
if self.dir_encoder is not None else None
colors, densities = self.core(feats_encoded, dirs_encoded)
ret = self.rendering(colors, densities[..., 0], z_s, dz_s, ret_depth=ret_depth, debug=debug)
return ret
import torch
import torch.nn as nn
from modules import *
from utils import device
from utils import sphere
from utils import color
class Snerf(nn.Module):
def __init__(self, fc_params, sampler_params, *,
n_parts: int = 1,
c: int = color.RGB,
pos_encode: int = 10,
dir_encode: int = None,
spherical_dir: bool = False, **kwargs):
"""
Initialize a multi-sphere-layer net
        :param fc_params: parameters for full-connection network
        :param sampler_params: parameters for sampler
        :param n_parts: number of sub-networks to split the samples between
        :param c: color mode
        :param pos_encode: encode position to number of dimensions
        :param dir_encode: encode direction to number of dimensions, `None` means direction is ignored
        :param spherical_dir: whether to encode view directions in local spherical coordinates
"""
super().__init__()
self.color = c
self.spherical_dir = spherical_dir
self.n_samples = sampler_params['n_samples']
self.n_parts = n_parts
self.samples_per_part = self.n_samples // self.n_parts
self.coord_chns = 3
self.color_chns = color.chns(self.color)
self.pos_encoder = InputEncoder.Get(pos_encode, self.coord_chns)
if dir_encode is not None:
self.dir_encoder = InputEncoder.Get(dir_encode, 2 if self.spherical_dir else 3)
self.dir_chns_encoded = self.dir_encoder.out_dim
else:
self.dir_encoder = None
self.dir_chns_encoded = 0
self.nets = nn.ModuleList(
NerfCore(coord_chns=self.pos_encoder.out_dim,
density_chns=1,
color_chns=self.color_chns,
core_nf=fc_params['nf'],
core_layers=fc_params['n_layers'],
dir_chns=self.dir_chns_encoded,
dir_nf=fc_params['nf'] // 2,
activation=fc_params['activation'])
for _ in range(self.n_parts))
sampler_params['spherical'] = True
self.sampler = Sampler(**sampler_params)
self.rendering = VolumnRenderer()
def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor,
ret_depth=False, debug=False) -> torch.Tensor:
"""
rays -> colors
:param rays_o `Tensor(B, 3)`: rays' origin
:param rays_d `Tensor(B, 3)`: rays' direction
        :return: `Tensor(B, C)`, inferred images/pixels
"""
n_rays = rays_o.size(0)
coords, depths, _, pts = self.sampler(rays_o, rays_d)
coords_encoded = self.pos_encoder(coords)
if self.dir_encoder is not None:
if self.spherical_dir:
dirs_encoded = self.dir_encoder(sphere.calc_local_dir(rays_d, coords, pts))
else:
dirs_encoded = self.dir_encoder(rays_d)[:, None].expand(-1, self.n_samples, -1)
else:
dirs_encoded = None
densities = torch.empty(n_rays, self.n_samples, device=device.default())
colors = torch.empty(n_rays, self.n_samples, self.color_chns, device=device.default())
for i, net in enumerate(self.nets):
s = slice(i * self.samples_per_part, (i + 1) * self.samples_per_part)
c, d = net(coords_encoded[:, s],
dirs_encoded[:, s] if dirs_encoded is not None else None)
colors[:, s] = c
densities[:, s] = d
ret = self.rendering(colors.view(-1, self.n_samples, self.color_chns),
densities, depths, ret_depth=ret_depth, debug=debug)
if debug:
ret['sample_densities'] = densities
ret['sample_depths'] = depths
return ret
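# --- Illustrative sketch (not repository code) ---------------------------------
# How the multi-part split in Snerf.forward works: sub-network i handles the
# contiguous sample slice [i * samples_per_part, (i + 1) * samples_per_part).
n_samples_demo, n_parts_demo = 32, 4
samples_per_part_demo = n_samples_demo // n_parts_demo
parts = [slice(i * samples_per_part_demo, (i + 1) * samples_per_part_demo)
         for i in range(n_parts_demo)]
assert [(s.start, s.stop) for s in parts] == [(0, 8), (8, 16), (16, 24), (24, 32)]
# Caveat grounded in the code above: n_samples should be divisible by n_parts,
# otherwise the trailing n_samples % n_parts samples are never fed to any sub-net.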
class SnerfExport(nn.Module):
def __init__(self, net: Snerf):
super().__init__()
self.net = net
def forward(self, coords_encoded, z_vals):
colors = []
densities = []
for i in range(self.net.n_parts):
s = slice(i * self.net.samples_per_part, (i + 1) * self.net.samples_per_part)
mlp = self.net.nets[i] if self.net.nets is not None else self.net.net
c, d = mlp(coords_encoded[:, s].flatten(1, 2))
colors.append(c.view(-1, self.net.samples_per_part, self.net.color_chns))
densities.append(d)
colors = torch.cat(colors, 1)
densities = torch.cat(densities, 1)
alphas = self.net.rendering.density2alpha(densities, z_vals)
return torch.cat([colors, alphas[..., None]], -1)
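# --- Illustrative sketch (not repository code) ---------------------------------
# density2alpha above presumably computes the standard NeRF opacity
# alpha_i = 1 - exp(-sigma_i * delta_i), with delta_i the spacing between
# consecutive depth samples. A minimal version, assuming that convention:
def density2alpha_sketch(densities: torch.Tensor, z_vals: torch.Tensor) -> torch.Tensor:
    """densities, z_vals: Tensor(B, N) -> alphas: Tensor(B, N); last delta is padded."""
    deltas = z_vals[..., 1:] - z_vals[..., :-1]         # (B, N-1)
    deltas = torch.cat([deltas, deltas[..., -1:]], -1)  # pad to (B, N)
    return 1.0 - torch.exp(-torch.relu(densities) * deltas)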
@@ -3,13 +3,11 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"metadata": {},
-"outputs": [],
 "source": [
 "import sys\n",
 "import os\n",
 "import torch\n",
-"import torch.nn as nn\n",
+"import torch.nn.functional as nn_f\n",
 "import matplotlib.pyplot as plt\n",
 "\n",
 "rootdir = os.path.abspath(sys.path[0] + '/../')\n",
@@ -18,21 +16,16 @@
 "print(\"Set CUDA:%d as current device.\" % torch.cuda.current_device())\n",
 "torch.autograd.set_grad_enabled(False)\n",
 "\n",
-"from data.spherical_view_syn import *\n",
-"from configs.spherical_view_syn import SphericalViewSynConfig\n",
-"from utils import netio\n",
 "from utils import img\n",
-"from utils import device\n",
 "from utils.view import *\n",
-"from components.fnr import FoveatedNeuralRenderer\n",
 "\n",
 "datadir = f\"{rootdir}/data/__new/__demo/for_crop\"\n",
 "figs = ['our', 'gt', 'nerf', 'fgt']\n",
 "crops = {\n",
-" 'classroom_0': [[720, 800, 128], [1097, 982, 256]],\n",
-" 'lobby_1': [[570, 1000, 100], [1049, 1049, 256]],\n",
-" 'stones_2': [[720, 800, 100], [680, 1317, 256]],\n",
-" 'barbershop_3': [[745, 810, 100], [1135, 627, 256]]\n",
+" 'classroom_0': [[720, 790, 100], [370, 1160, 200]],\n",
+" 'lobby_1': [[570, 1000, 100], [1300, 1000, 200]],\n",
+" 'stones_2': [[720, 800, 100], [680, 1317, 200]],\n",
+" 'barbershop_3': [[745, 810, 100], [950, 900, 200]]\n",
 "}\n",
 "colors = torch.tensor([[0, 1, 0, 1], [1, 1, 0, 1]], dtype=torch.float)\n",
 "border = 10\n",
@@ -78,16 +71,18 @@
 " img.save(torch.cat([fovea_patches, periph_patches], dim=-1),\n",
 " [f\"{datadir}/patch/{scene}_{fig}.png\" for fig in figs])\n",
 " img.save(overlay, f\"{datadir}/overlay/{scene}.png\")\n"
-]
+],
+"outputs": [],
+"metadata": {}
 }
 ],
 "metadata": {
 "interpreter": {
-"hash": "82066b63b621a9e3d15e3b7c11ca76da6238eff3834294910d715044bd0561e5"
+"hash": "65406b00395a48e1d89cf658ae895e7869e05878f5469716b06a752a3915211c"
 },
 "kernelspec": {
-"display_name": "Python 3.8.5 64-bit ('base': conda)",
-"name": "python3"
+"name": "python3",
+"display_name": "Python 3.8.5 64-bit ('base': conda)"
 },
 "language_info": {
 "codemirror_mode": {
...
@@ -170,7 +170,7 @@
 " images['overlaid'] = renderer.foveation.synthesis(images['layers_raw'], param[-2:], do_blend=False)\n",
 " if True:\n",
 " outputdir = '../__demo/mono/'\n",
-" misc.create_dir(outputdir)\n",
+" os.makedirs(outputdir, exist_ok=True)\n",
 " img.save(images['layers_img'][0], f'{outputdir}{scene}_{i}_fovea.png')\n",
 " img.save(images['layers_img'][1], f'{outputdir}{scene}_{i}_mid.png')\n",
 " img.save(images['layers_img'][2], f'{outputdir}{scene}_{i}_periph.png')\n",
@@ -203,7 +203,7 @@
 " center = (0, 0)\n",
 " images = renderer(views.get(view_idx), center, using_mask=True)\n",
 " outputdir = 'panorama'\n",
-" misc.create_dir(outputdir)\n",
+" os.makedirs(outputdir, exist_ok=True)\n",
 " img.save(images['blended'], f'{outputdir}/{view_idx:04d}.png')"
 ],
 "outputs": [
...
@@ -216,7 +216,7 @@
 " ret_raw=False)\n",
 " if True:\n",
 " outputdir = '../__demo/stereo_m%d' % mono_periph if mono_periph else '../__demo/stereo'\n",
-" misc.create_dir(outputdir)\n",
+" os.makedirs(outputdir, exist_ok=True)\n",
 " img.save(torch.cat([\n",
 " left_images['blended'],\n",
 " right_images['blended']\n",
@@ -228,7 +228,7 @@
 " right_images['blended'][:, 1:3]\n",
 " ], dim=1)\n",
 " img.save(stereo_overlap, '%s/%s_%d_stereo.png' % (outputdir, scene, i))\n",
-" #misc.create_dir(outputdir + '/mid')\n",
+" #os.makedirs(outputdir + '/mid', exist_ok=True)\n",
 " #img.save(left_images['layers_img'][1], '%s/mid/%s_%d_l.png' % (outputdir, scene, i))\n",
 " #img.save(right_images['layers_img'][1], '%s/mid/%s_%d_r.png' % (outputdir, scene, i))\n",
 " print(\"%s %d Saved\" % (scene, i))\n",
...
@@ -110,7 +110,7 @@
 " #plot_figures(images, center)\n",
 "\n",
 " outputdir = '../__1_eval/output_mono_periph/ref_as_right_eye/%s/' % scene\n",
-" misc.create_dir(outputdir)\n",
+" os.makedirs(outputdir, exist_ok=True)\n",
 " #for key in images:\n",
 " key = 'blended'\n",
 " img.save(images[key], outputdir + 'view%04d_%s.png' % (view_idx, key))\n"
@@ -131,7 +131,7 @@
 " images = gen.gen(center, test_view, True)\n",
 " #plot_figures(images, center)\n",
 "\n",
-" misc.create_dir('output/eval_gaze')\n",
+" os.makedirs('output/eval_gaze', exist_ok=True)\n",
 " out_path = 'output/eval_gaze/gaze%03d_%d,%d.png' % (gaze_idx, x, y)\n",
 " img.save(images['blended'], out_path)\n",
 " print('Output ' + out_path)\n",
...
@@ -130,7 +130,7 @@
 " images = gen.gen(center, test_view, True)\n",
 " #plot_figures(images, center)\n",
 "\n",
-" misc.create_dir('output/teasers')\n",
+" os.makedirs('output/teasers', exist_ok=True)\n",
 " for key in images:\n",
 " img.save(\n",
 " images[key], 'output/teasers/view%04d_%s.png' % (view_idx, key))\n"
...
@@ -150,7 +150,7 @@
 "print(\"Encoded:\", encoded)\n",
 "#plot_figures(images, center)\n",
 "\n",
-"#misc.create_dir('output/teasers')\n",
+"#os.makedirs('output/teasers', exist_ok=True)\n",
 "#for key in images:\n",
 "# img.save(\n",
 "# images[key], 'output/teasers/view%04d_%s.png' % (view_idx, key))\n"
...
@@ -188,7 +188,7 @@
 "\n",
 "#plot_figures(left_images, right_images, centers[set_id][0], centers[set_id][1])\n",
 "\n",
-"misc.create_dir('output')\n",
+"os.makedirs('output', exist_ok=True)\n",
 "for key in left_images:\n",
 " img.save(\n",
 " left_images[key], 'output/set%d_%s_l.png' % (set_id, key))\n",
...
@@ -117,7 +117,7 @@
 " left_images = gen.gen(left_center, left_view, mono_trans=mono_trans)\n",
 " right_images = gen.gen(right_center, right_view, mono_trans=mono_trans)\n",
 " \n",
-" misc.create_dir('output/video_frames/hmd2')\n",
+" os.makedirs('output/video_frames/hmd2', exist_ok=True)\n",
 " img.save(torch.cat([left_images['blended'], right_images['blended']], -1),\n",
 " 'output/video_frames/hmd2/view%04d.png' % view_idx)\n",
 " print('Frame %d saved' % view_idx)\n"
...
@@ -155,7 +155,7 @@
 " images['overlaid'] = renderer.foveation.synthesis(images['layers_raw'], param[-2:], do_blend=False)\n",
 " if True:\n",
 " outputdir = '../__demo/mono/'\n",
-" misc.create_dir(outputdir)\n",
+" os.makedirs(outputdir, exist_ok=True)\n",
 " img.save(images['layers_img'][0], f'{outputdir}{scene}_{i}_fovea.png')\n",
 " img.save(images['layers_img'][1], f'{outputdir}{scene}_{i}_mid.png')\n",
 " img.save(images['layers_img'][2], f'{outputdir}{scene}_{i}_periph.png')\n",
@@ -196,7 +196,7 @@
 " center = (0, 0)\n",
 " images = renderer(views.get(view_idx), center, using_mask=True)\n",
 " outputdir = 'nerf_our'\n",
-" misc.create_dir(outputdir)\n",
+" os.makedirs(outputdir, exist_ok=True)\n",
 " img.save(images['blended'], f'{outputdir}/{view_idx:04d}.png')"
 ]
 }
...
@@ -101,7 +101,7 @@
 "gaze = [37.55656052, 20.7297554]\n",
 "images = renderer(view, gaze, using_mask=False, ret_raw=True)\n",
 "outputdir = '../__demo/mono_f60&m110/'\n",
-"misc.create_dir(outputdir)\n",
+"os.makedirs(outputdir, exist_ok=True)\n",
 "img.save(images['layers_img'][0], f'{outputdir}{scene}_fovea.png')\n",
 "img.save(images['blended'], f'{outputdir}{scene}.png')\n",
 "img.save(images['blended_raw'], f'{outputdir}{scene}_noCE.png')"
...
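The recurring change across these notebook hunks is mechanical: the project helper `misc.create_dir` is swapped one-for-one for the standard-library call, which creates any missing intermediate directories and is a no-op when the target already exists (assuming `create_dir` did nothing more than directory creation, as the direct replacement suggests):

import os
os.makedirs('output/eval_gaze', exist_ok=True)  # idempotent, recursive (Python >= 3.2)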