Commit 1bc644a1 authored by Nianchen Deng

sync

parent 6294701e
from ..__common__ import *

__all__ = ["Field"]


class Field(nn.Module):
    def __init__(self, x_chns: int, shape: list[int], skips: list[int] = [],
                 act: str = 'relu', with_ln: bool = False):
        super().__init__({"x": x_chns}, {"f": shape[1]})
        self.net = nn.FcBlock(x_chns, 0, *shape, skips, act, with_ln=with_ln)

    # stub method for type hint
    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        ...

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)
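`nn.FcBlock` is a repository helper that is not shown in this commit. As a rough mental model only (an assumption, not the actual implementation), it can be read as an MLP of `depth` layers of `width` units with skip connections from the input at the listed layer indices; a minimal stand-in sketch:

import torch
import torch.nn as tnn

class TinyFcBlock(tnn.Module):
    """Hypothetical simplified stand-in for nn.FcBlock: depth x width MLP with input skips."""
    def __init__(self, in_chns: int, depth: int, width: int, skips: list[int]):
        super().__init__()
        self.skips = skips
        self.layers = tnn.ModuleList()
        for i in range(depth):
            in_dim = in_chns if i == 0 else width
            if i in skips and i > 0:
                in_dim += in_chns  # concatenated input at skip layers
            self.layers.append(tnn.Linear(in_dim, width))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = x
        for i, layer in enumerate(self.layers):
            if i in self.skips and i > 0:
                h = torch.cat([h, x], -1)
            h = torch.relu(layer(h))
        return h

block = TinyFcBlock(in_chns=63, depth=4, width=128, skips=[2])
out = block(torch.randn(8, 63))  # (8, 128)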
from ..__common__ import *
from .field import *
from .color_decoder import *
from .density_decoder import *


class FsNeRF(nn.Module):
    def __init__(self, x_chns: int, color_chns: int, depth: int, width: int,
                 skips: list[int], act: str, ln: bool, n_samples: int, n_fields: int):
        """
        Initialize a FS-NeRF core module.

        :param x_chns `int`: channels of input positions (D_x)
        :param color_chns `int`: channels of output colors (D_c)
        :param depth `int`: number of layers in field network
        :param width `int`: width of each layer in field network
        :param skips `[int]`: skip connections from input to specific layers in field network
        :param act `str`: activation function in field network and color decoder
        :param ln `bool`: whether to enable layer normalization in field network and color decoder
        :param n_samples `int`: number of samples per ray
        :param n_fields `int`: number of field subnetworks; each handles `n_samples // n_fields` samples
        """
        super().__init__({"x": x_chns}, {"rgbd": 1 + color_chns})
        self.n_fields = n_fields
        self.samples_per_field = n_samples // n_fields
        self.subnets = torch.nn.ModuleList()
        for _ in range(n_fields):
            field = Field(x_chns * self.samples_per_field, [depth, width], skips, act, ln)
            density_decoder = DensityDecoder(field.out_chns, self.samples_per_field)
            color_decoder = BasicColorDecoder(field.out_chns, color_chns * self.samples_per_field)
            self.subnets.append(torch.nn.ModuleDict({
                "field": field,
                "density_decoder": density_decoder,
                "color_decoder": color_decoder
            }))

    # stub method for type hint
    def __call__(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Infer colors and densities from input samples.

        :param x `Tensor(B..., P, D_x)`: input positions
        :return `Tensor(B..., P, D_c + D_σ)`: output colors and densities
        """
        ...

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        densities = []
        colors = []
        for i in range(self.n_fields):
            f = self.subnets[i]["field"](
                x[..., i * self.samples_per_field:(i + 1) * self.samples_per_field, :].flatten(-2))
            densities.append(self.subnets[i]["density_decoder"](f)
                             .unflatten(-1, (self.samples_per_field, -1)))
            colors.append(self.subnets[i]["color_decoder"](f, None)
                          .unflatten(-1, (self.samples_per_field, -1)))
        return torch.cat([torch.cat(colors, -2), torch.cat(densities, -2)], -1)
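A minimal shape walk-through (editorial sketch, not part of the repository) of how `FsNeRF.forward` partitions a ray's samples across its per-segment fields; the batch size, sample count, and channel sizes below are illustrative assumptions:

import torch

B, P, D_x = 2, 16, 63          # assumed batch, samples per ray, position channels
n_fields = 4
samples_per_field = P // n_fields

x = torch.randn(B, P, D_x)
for i in range(n_fields):
    group = x[..., i * samples_per_field:(i + 1) * samples_per_field, :]  # (B, P/n, D_x)
    flat = group.flatten(-2)                                              # fed to the i-th Field
    assert flat.shape == (B, samples_per_field * D_x)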
from ..__common__ import *
from .field import *
from .color_decoder import *
from .density_decoder import *


class NeRF(nn.Module):
    def __init__(self, x_chns: int, d_chns: int, color_chns: int, depth: int, width: int,
                 skips: list[int], act: str, ln: bool, color_decoder_type: str):
        """
        Initialize a NeRF core module.

        :param x_chns `int`: channels of input positions (D_x)
        :param d_chns `int`: channels of input directions (D_d)
        :param color_chns `int`: channels of output colors (D_c)
        :param depth `int`: number of layers in field network
        :param width `int`: width of each layer in field network
        :param skips `[int]`: skip connections from input to specific layers in field network
        :param act `str`: activation function in field network and color decoder
        :param ln `bool`: whether to enable layer normalization in field network and color decoder
        :param color_decoder_type `str`: type of color decoder
        """
        super().__init__({"x": x_chns, "d": d_chns}, {"density": 1, "color": color_chns})
        self.field = Field(x_chns, [depth, width], skips, act, ln)
        self.density_decoder = DensityDecoder(self.field.out_chns, 1)
        self.color_decoder = ColorDecoder.create(self.field.out_chns, d_chns, color_chns,
                                                 color_decoder_type, {"act": act, "with_ln": ln})

    # stub method for type hint
    def __call__(self, x: torch.Tensor, d: torch.Tensor) -> torch.Tensor:
        """
        Infer colors and densities from input samples.

        :param x `Tensor(B..., D_x)`: input positions
        :param d `Tensor(B..., D_d)`: input directions
        :return `Tensor(B..., D_c + D_σ)`: output colors and densities
        """
        ...

    def forward(self, x: torch.Tensor, d: torch.Tensor) -> torch.Tensor:
        f = self.field(x)
        densities = self.density_decoder(f)
        colors = self.color_decoder(f, d)
        return torch.cat([colors, densities], -1)
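For orientation, a minimal sketch (using plain torch modules rather than the repository's `Field`, `DensityDecoder`, and `ColorDecoder` classes, and with assumed layer sizes) of the field-then-two-decoders dataflow that `NeRF.forward` implements:

import torch
import torch.nn as tnn

D_x, D_d, D_c, W = 63, 27, 3, 256                       # assumed channel sizes
field = tnn.Sequential(tnn.Linear(D_x, W), tnn.ReLU())  # stand-in for Field
density_decoder = tnn.Linear(W, 1)                      # stand-in for DensityDecoder
color_decoder = tnn.Linear(W + D_d, D_c)                # stand-in for a basic color decoder

x, d = torch.randn(8, D_x), torch.randn(8, D_d)
f = field(x)
density = density_decoder(f)                            # (8, 1)
color = color_decoder(torch.cat([f, d], -1))            # (8, D_c)
out = torch.cat([color, density], -1)                   # (8, D_c + 1), like NeRF.forward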
from .__common__ import *

__all__ = ["InputEncoder", "LinearEncoder", "FreqEncoder"]


class InputEncoder(nn.Module):
    """
    Base class for input encoder.
    """

    def __init__(self, in_chns: int, out_chns: int):
        super().__init__({"_": in_chns}, {"_": out_chns})

    # stub method for type hint
    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode the input tensor.

        :param x `Tensor(N..., D)`: D-dim inputs
        :return `Tensor(N..., E)`: encoded outputs
        """
        ...

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError()

    @staticmethod
    def create(chns: int, type: str, args: dict[str, Any]) -> "InputEncoder":
        """
        Create an input encoder of `type` with `args`.

        :param chns `int`: input channels
        :param type `str`: type of input encoder, without suffix "Encoder"
        :param args `{str:Any}`: arguments for initializing the input encoder
        :return `InputEncoder`: the created input encoder
        """
        return getattr(sys.modules[__name__], f"{type}Encoder")(chns, **args)


class LinearEncoder(InputEncoder):
    """
    The linear encoder: D -> D.
    """

    def __init__(self, chns):
        super().__init__(chns, chns)

    def forward(self, x: torch.Tensor):
        return x

    def extra_repr(self) -> str:
        return f"{self.in_chns} -> {self.out_chns}"


class FreqEncoder(InputEncoder):
    """
    The frequency encoder introduced in [mildenhall2020nerf]: D -> 2LD[+D].
    """
    freq_bands: torch.Tensor
    """
    `Tensor(L)` Frequency bands (1, 2, ..., 2^(L-1))
    """

    def __init__(self, chns, freqs: int, include_input: bool):
        super().__init__(chns, chns * (freqs * 2 + include_input))
        self.include_input = include_input
        self.freqs = freqs
        self.register_temp("freq_bands", (2. ** torch.arange(freqs))[:, None].expand(-1, chns))

    def forward(self, x: torch.Tensor):
        x_ = x.unsqueeze(-2) * self.freq_bands
        result = union(torch.sin(x_), torch.cos(x_)).flatten(-2)
        return union(x, result) if self.include_input else result

    def extra_repr(self) -> str:
        return f"{self.in_chns} -> {self.out_chns}"\
            f"(2x{self.freqs}x{self.in_chns}{f'+{self.in_chns}' * self.include_input})"
from .__common__ import *
import torch.nn.functional as F

__all__ = ["density2energy", "density2alpha", "VolumnRenderer"]


def density2energy(densities: torch.Tensor, dists: torch.Tensor, raw_noise_std: float = 0) -> torch.Tensor:
    """
    Calculate energies from densities inferred by model.

    :param densities `Tensor(N...)`: model's output densities
    :param dists `Tensor(N...)`: integration times
    :param raw_noise_std `float`: the noise std used to regularize network during training (prevents
        floater artifacts), defaults to 0, which means no noise is added
    :return `Tensor(N...)`: energies which block light rays
    """
    if raw_noise_std > 0:
        # Add noise to model's predictions for density. Can be used to
        # regularize network during training (prevents floater artifacts).
        densities = densities + torch.normal(0.0, raw_noise_std, densities.shape,
                                             device=densities.device)
    return F.relu(densities) * dists


def energy2alpha(energies: torch.Tensor) -> torch.Tensor:
    """
    Convert energies to alphas.

    :param energies `Tensor(N...)`: energies (calculated from densities)
    :return `Tensor(N...)`: alphas
    """
    return 1.0 - torch.exp(-energies)


def density2alpha(densities: torch.Tensor, dists: torch.Tensor, raw_noise_std: float = 0) -> torch.Tensor:
    """
    Calculate alphas from densities inferred by model.

    :param densities `Tensor(N...)`: model's output densities
    :param dists `Tensor(N...)`: integration times
    :param raw_noise_std `float`: the noise std used to regularize network during training (prevents
        floater artifacts), defaults to 0, which means no noise is added
    :return `Tensor(N...)`: alphas
    """
    return energy2alpha(density2energy(densities, dists, raw_noise_std))


class VolumnRenderer(nn.Module):

    def __init__(self):
        super().__init__()

    # stub method
    def __call__(self, samples: Samples, densities: torch.Tensor, colors: torch.Tensor, *outputs: str,
                 white_bg: bool, raw_noise_std: float) -> ReturnData:
        """
        Perform volume rendering.

        :param samples `Samples(B, P)`: samples
        :param rgbd `Tensor(B, P, C+1)`: colors and densities
        :param outputs `str...`: items which should be contained in the result dict.
            Optional values include 'color', 'depth', 'layers', 'states' and attribute names in class `States` (e.g. 'weights'). Defaults to []
        :return `ReturnData`: render result { 'color'[, 'depth', 'layers', 'states', ...] }
        """
        ...

    @profile
    def forward(self, samples: Samples, rgbd: torch.Tensor, *outputs: str,
                white_bg: bool, raw_noise_std: float) -> ReturnData:
        energies = density2energy(rgbd[..., -1], samples.dists, raw_noise_std)  # (B, P)
        alphas = energy2alpha(energies)  # (B, P)
        weights = (alphas * torch.cumprod(union(1, 1. - alphas + 1e-10), -1)[..., :-1])[..., None]
        output_fn = {
            "color": lambda: torch.sum(weights * rgbd[..., :-1], -2) + (1. - torch.sum(weights, -2)
                                                                        if white_bg else 0.),
            "depth": lambda: torch.sum(weights * samples.depths[..., None], -2),
            "colors": lambda: rgbd[..., :-1],
            "densities": lambda: rgbd[..., -1:],
            "alphas": lambda: alphas[..., None],
            "energies": lambda: energies[..., None],
            "weights": lambda: weights
        }
        return ReturnData({key: output_fn[key]() for key in outputs if key in output_fn})
from .__common__ import *
from .space import Space
from clib import *
from utils import sphere
from utils.misc import grid2d

__all__ = ["Sampler", "UniformSampler", "PdfSampler"]


class Sampler(nn.Module):
    _samples_indices_cached: torch.Tensor | None

    def __init__(self, x_chns: int, d_chns: int):
        """
        Initialize a Sampler module
        """
        super().__init__({}, {"x": x_chns, "d": d_chns})
        self._samples_indices_cached = None

    # stub method for type hint
    def __call__(self, rays: Rays, space: Space, **kwargs) -> Samples:
        ...

    def _get_samples_indices(self, pts: torch.Tensor) -> torch.Tensor:
        """
        Get 2D indices of samples. The first value is the index of ray, while the second value is
        the index of sample in a ray.

        :param pts `Tensor(B, P, 3)`: the sample points
        :return `Tensor(B, P)`: the 2D indices of samples
        """
        if self._samples_indices_cached is None\
                or self._samples_indices_cached.shape[0] < pts.shape[0]\
                or self._samples_indices_cached.shape[1] < pts.shape[1]:
            self._samples_indices_cached = grid2d(*pts.shape[:2], indexing="ij", device=pts.device)
        return self._samples_indices_cached[:pts.shape[0], :pts.shape[1]]

    def _get_samples(self, rays: Rays, space: Space, t_vals: torch.Tensor, mode: str) -> Samples:
        """
        Get samples along rays at sample steps specified by `t_vals`.

        :param rays `Rays(B)`: rays
        :param t_vals `Tensor(B, P)`: sample steps
        :param mode `str`: sample mode, one of "xyz", "xyz_disp", "spherical", "spherical_radius"
        :return `Samples(B, P)`: samples
        """
        if mode == "xyz":
            z_vals = t_vals
            pts = rays.get_points(z_vals)
        elif mode == "xyz_disp":
            z_vals = t_vals.reciprocal()
            pts = rays.get_points(z_vals)
        elif mode == "spherical":
            z_vals = t_vals.reciprocal()
            pts = sphere.cartesian2spherical(rays.get_points(z_vals), inverse_r=True)
        elif mode == "spherical_radius":
            z_vals = sphere.ray_sphere_intersect(rays, t_vals.reciprocal())
            pts = sphere.cartesian2spherical(rays.get_points(z_vals), inverse_r=True)
        else:
            raise ValueError(f"Unknown mode: {mode}")
        rays_d = rays.rays_d.unsqueeze(1)  # (B, 1, 3)
        dists = union(z_vals[..., 1:] - z_vals[..., :-1], math.huge)  # (B, P)
        dists *= torch.norm(rays_d, dim=-1)
        return Samples(
            pts=pts,
            dirs=rays_d.expand(*pts.shape[:2], -1),
            depths=z_vals,
            t_vals=t_vals,
            dists=dists,
            voxel_indices=space.get_voxel_indices(pts) if space else 0,
            indices=self._get_samples_indices(pts)
        )


class UniformSampler(Sampler):
    """
    This module expands NeRF's code of uniform sampling to support our spherical sampling and enable
    the trace of samples' indices.
    """

    def __init__(self):
        super().__init__(3, 3)

    def _sample(self, range: tuple[float, float], n_rays: int, n_samples: int, perturb: bool) -> torch.Tensor:
        """
        Generate sample steps along rays in the specified range.

        :param range `float, float`: sampling range
        :param n_rays `int`: number of rays (B)
        :param n_samples `int`: number of samples per ray (P)
        :param perturb `bool`: whether perturb sampling
        :return `Tensor(B, P)`: sampled "t"s along rays
        """
        t_vals = torch.linspace(*range, n_samples, device=self.device)  # (P)
        if perturb:
            mids = .5 * (t_vals[..., 1:] + t_vals[..., :-1])
            upper = union(mids, t_vals[..., -1:])
            lower = union(t_vals[..., :1], mids)
            # stratified samples in those intervals
            t_vals = t_vals.expand(n_rays, -1)
            t_vals = lower + (upper - lower) * torch.rand_like(t_vals)
        else:
            t_vals = t_vals.expand(n_rays, -1)
        return t_vals

    # stub method for type hint
    def __call__(self, rays: Rays, space: Space, *,
                 range: tuple[float, float],
                 mode: str,
                 n_samples: int,
                 perturb: bool) -> Samples:
        """
        Sample points along rays.

        :param rays `Rays(B)`: rays
        :param space `Space`: sample space
        :param range `float, float`: sampling range
        :param mode `str`: sample mode, one of "xyz", "xyz_disp", "spherical", "spherical_radius"
        :param n_samples `int`: number of samples per ray
        :param perturb `bool`: whether perturb sampling, defaults to `False`
        :return `Samples(B, P)`: samples
        """
        ...

    @profile
    def forward(self, rays: Rays, space: Space, *,
                range: tuple[float, float],
                mode: str,
                n_samples: int,
                perturb: bool) -> Samples:
        t_range = range if mode == "xyz" else (1. / range[0], 1. / range[1])
        t_vals = self._sample(t_range, rays.shape[0], n_samples, perturb)  # (B, P)
        return self._get_samples(rays, space, t_vals, mode)


class PdfSampler(Sampler):
    """
    Hierarchical sampling (section 5.2 of NeRF)
    """

    def __init__(self):
        super().__init__(3, 3)

    def _sample(self, t_vals: torch.Tensor, weights: torch.Tensor, n_importance: int,
                perturb: bool, include_existed: bool, sort_descending: bool) -> torch.Tensor:
        """
        Generate sample steps by PDF according to existed sample steps and their weights.

        :param t_vals `Tensor(B, P)`: existed sample steps
        :param weights `Tensor(B, P)`: weights of existed sample steps
        :param n_importance `int`: number of samples to generate for each ray
        :param perturb `bool`: whether perturb sampling
        :param include_existed `bool`: whether to include existed samples in the output
        :return `Tensor(B, P'[+P])`: the output sample steps
        """
        bins = .5 * (t_vals[..., 1:] + t_vals[..., :-1])  # (B, P - 1)
        weights = weights[..., 1:-1] + math.tiny  # (B, P - 2)

        # Get PDF
        pdf = weights / torch.sum(weights, -1, keepdim=True)
        cdf = union(0., torch.cumsum(pdf, -1))  # (B, P - 1)

        # Take uniform samples
        if perturb:
            u = torch.rand(*cdf.shape[:-1], n_importance, device=self.device)
        else:
            u = torch.linspace(0., 1., steps=n_importance, device=self.device).\
                expand(*cdf.shape[:-1], -1)

        # Invert CDF
        u = u.contiguous()  # (B, P')
        inds = torch.searchsorted(cdf, u, right=True)  # (B, P')
        inds_g = torch.stack([
            (inds - 1).clamp_min(0),  # below
            inds.clamp_max(cdf.shape[-1] - 1)  # above
        ], -1)  # (B, P', 2)

        matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]  # [B, P', P - 1]
        cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)  # (B, P', 2)
        bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)  # (B, P', 2)

        denom = cdf_g[..., 1] - cdf_g[..., 0]
        denom = torch.where(denom < math.tiny, torch.ones_like(denom), denom)
        t = (u - cdf_g[..., 0]) / denom
        t_samples = (bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])).detach()

        if include_existed:
            return torch.sort(union(t_vals, t_samples), -1, descending=sort_descending)[0]
        else:
            return t_samples

    # stub method for type hint
    def __call__(self, rays: Rays, space: Space, t_vals: torch.Tensor, weights: torch.Tensor, *,
                 mode: str,
                 n_importance: int,
                 perturb: bool,
                 include_existed_samples: bool) -> Samples:
        """
        Sample points along rays using PDF sampling based on existed samples.

        :param rays `Rays(B)`: rays
        :param space `Space`: sample space
        :param t_vals `Tensor(B, P)`: existed sample steps
        :param weights `Tensor(B, P)`: weights of existed sample steps
        :param mode `str`: sample mode, one of "xyz", "xyz_disp", "spherical", "spherical_radius"
        :param n_importance `int`: number of samples to generate using PDF sampling for each ray
        :param perturb `bool`: whether perturb sampling, defaults to `False`
        :param include_existed_samples `bool`: whether to include existed samples in the output,
            defaults to `True`
        :return `Samples(B, P'[+P])`: samples
        """
        ...

    @profile
    def forward(self, rays: Rays, space: Space, t_vals: torch.Tensor, weights: torch.Tensor, *,
                mode: str,
                n_importance: int,
                perturb: bool,
                include_existed_samples: bool) -> Samples:
        t_vals = self._sample(t_vals, weights, n_importance, perturb, include_existed_samples,
                              mode != "xyz")
        return self._get_samples(rays, space, t_vals, mode)
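For reference, a compact sketch of the inverse-CDF step that `PdfSampler._sample` performs, written with plain torch and `torch.cat` instead of the repository's `union` helper; the bin edges and weights below are made-up numbers, and the bins are taken directly as edges rather than as midpoints of existing steps:

import torch

bins = torch.tensor([[0.0, 0.25, 0.5, 0.75, 1.0]])       # (B, M+1) assumed bin edges
weights = torch.tensor([[0.1, 0.2, 0.6, 0.1]])            # (B, M) assumed per-bin weights

pdf = weights / weights.sum(-1, keepdim=True)
cdf = torch.cat([torch.zeros_like(pdf[..., :1]), torch.cumsum(pdf, -1)], -1)  # (B, M+1)

u = torch.linspace(0., 1., 8).expand(1, -1).contiguous()  # (B, N) deterministic samples
inds = torch.searchsorted(cdf, u, right=True)
below = (inds - 1).clamp_min(0)
above = inds.clamp_max(cdf.shape[-1] - 1)

cdf_g0, cdf_g1 = cdf.gather(-1, below), cdf.gather(-1, above)
bin_g0, bin_g1 = bins.gather(-1, below), bins.gather(-1, above)
denom = (cdf_g1 - cdf_g0).clamp_min(1e-10)
t = (u - cdf_g0) / denom
samples = bin_g0 + t * (bin_g1 - bin_g0)                   # more samples land where weight is high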
from .__common__ import *
from clib import *
#from model.utils import load
from utils.nn import Parameter
from utils.geometry import *
from utils.voxels import *

__all__ = ["Space", "Voxels", "Octree"]


class Intersections:

@@ -24,8 +22,8 @@ class Intersections:
    """`Tensor(N)` Number of hits"""

    @property
    def shape(self):
        return self.hits.shape

    def __init__(self, min_depths: torch.Tensor, max_depths: torch.Tensor,
                 voxel_indices: torch.Tensor, hits: torch.Tensor) -> None:

@@ -42,9 +40,9 @@ class Intersections:
            hits=self.hits[index])


class Space(nn.Module):
    bbox: torch.Tensor | None
    """`Tensor(2, D)` Bounding box"""

    @property
    def dims(self) -> int:
@@ -52,16 +50,18 @@ class Space(Module):
        return self.bbox.shape[1] if self.bbox is not None else 3

    @staticmethod
    def create(type: str, args: dict[str, Any]) -> 'Space':
        match type:
            case "Space":
                return Space(**args)
            case "Octree":
                return Octree(**args)
            case "Voxels":
                return Voxels(**args)
            case _:
                return load(type).space

    def __init__(self, clone_src: "Space" = None, *, bbox: list[float] = None, **kwargs):
        super().__init__()
        if clone_src:
            self.device = clone_src.device
@@ -69,10 +69,30 @@ class Space(Module):
        else:
            self.register_temp('bbox', None if not bbox else torch.tensor(bbox).reshape(2, -1))

    def ray_intersect_with_bbox(self, rays_o: torch.Tensor, rays_d: torch.Tensor) -> Intersections:
        """
        [summary]

        :param rays_o `Tensor(N..., D)`: rays' origin
        :param rays_d `Tensor(N..., D)`: rays' direction
        :param max_hits `int?`: max number of hits of each ray, have no effect for this method
        :return `Intersect(N...)`: rays' intersection with the bounding box
        """
        if self.bbox is None:
            raise RuntimeError("The space has no bounding box")
        inv_d = rays_d.reciprocal().unsqueeze(-2)
        t = (self.bbox - rays_o.unsqueeze(-2)) * inv_d  # (N..., 2, D)
        t0 = t.min(dim=-2)[0].max(dim=-1, keepdim=True)[0].clamp(min=1e-4)  # (N..., 1)
        t1 = t.max(dim=-2)[0].min(dim=-1, keepdim=True)[0]
        miss = t1 <= t0
        t0[miss], t1[miss] = -1., -1.
        hit = torch.logical_not(miss).long()
        return Intersections(t0, t1, hit - 1, hit.squeeze(-1))

    def ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, max_hits: int) -> Intersections:
        return self.ray_intersect_with_bbox(rays_o, rays_d)

    def get_voxel_indices(self, pts: torch.Tensor) -> int | torch.Tensor:
        if self.bbox is None:
            return 0
        voxel_indices = torch.zeros_like(pts[..., 0], dtype=torch.long)
@@ -81,19 +101,22 @@ class Space(Module):
        return voxel_indices

    @torch.no_grad()
    def prune(self, keeps: torch.Tensor) -> tuple[int, int]:
        raise NotImplementedError()

    @torch.no_grad()
    def split(self) -> tuple[int, int]:
        raise NotImplementedError()

    @torch.no_grad()
    def clone(self):
        return self.__class__(self)


class Voxels(Space):
    bbox: torch.Tensor
    """`Tensor(2, D)` Bounding box"""

    steps: torch.Tensor
    """`Tensor(3)` Steps along each dimension"""

@@ -131,42 +154,43 @@ class Voxels(Space):
    @property
    def voxel_size(self) -> torch.Tensor:
        """`Tensor(3)` Voxel size"""
        return (self.bbox[1] - self.bbox[0]) / self.steps

    @property
    def corner_embeddings(self) -> dict[str, torch.nn.Embedding]:
        return {name[4:]: emb for name, emb in self.named_modules() if name.startswith("emb_")}

    @property
    def voxel_embeddings(self) -> dict[str, torch.nn.Embedding]:
        return {name[5:]: emb for name, emb in self.named_modules() if name.startswith("vemb_")}

    def __init__(self, clone_src: "Voxels" = None, *, bbox: list[float] = None,
                 voxel_size: float = None, steps: torch.Tensor | tuple[int, ...] = None,
                 **kwargs) -> None:
        if clone_src:
            super().__init__(clone_src)
            self.register_buffer('steps', clone_src.steps)
            self.register_buffer('voxels', clone_src.voxels)
            self.register_buffer("corners", clone_src.corners)
            self.register_buffer("corner_indices", clone_src.corner_indices)
            self.register_buffer('voxel_indices_in_grid', clone_src.voxel_indices_in_grid)
        else:
            if bbox is None:
                raise ValueError("Missing argument 'bbox'")
            super().__init__(bbox=bbox)
            if steps is not None:
                self.register_buffer('steps', torch.tensor(steps, dtype=torch.long))
            else:
                self.register_buffer('steps', get_grid_steps(self.bbox, voxel_size))
            self.register_buffer('voxels', init_voxels(self.bbox, self.steps))
            corners, corner_indices = get_corners(self.voxels, self.bbox, self.steps)
            self.register_buffer("corners", corners)
            self.register_buffer("corner_indices", corner_indices)
            self.register_buffer('voxel_indices_in_grid', torch.arange(-1, self.n_voxels))

    def to_vi(self, gi: torch.Tensor) -> torch.Tensor:
        return self.voxel_indices_in_grid[gi + 1]

@@ -208,7 +232,7 @@ class Voxels(Space):
        voxels = self.voxels[voxel_indices]  # (N, 3)
        corner_indices = self.corner_indices[voxel_indices]  # (N, 8)
        p = (pts - voxels) / self.voxel_size + .5  # (N, 3) normed-coords in voxel
        return linear_interp(p, emb(corner_indices))

    def create_voxel_embedding(self, n_dims: int, name: str = 'default') -> torch.nn.Embedding:
        """
@@ -245,7 +269,7 @@ class Voxels(Space):
            raise KeyError(f"Embedding '{name}' doesn't exist")
        return emb(voxel_indices)

    @profile
    def ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> Intersections:
        """
        Calculate intersections of rays and voxels.
@@ -277,7 +301,7 @@ class Voxels(Space):
            hits=hits[0]
        )

    @profile
    def get_voxel_indices(self, pts: torch.Tensor) -> torch.Tensor:
        """
        Get voxel indices of points.
@@ -290,8 +314,8 @@ class Voxels(Space):
        gi = to_grid_indices(pts, self.bbox, self.steps)
        return self.to_vi(gi)

    @profile
    def get_corners(self, vidxs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        vidxs = vidxs.unique()
        if vidxs[0] == -1:
            vidxs = vidxs[1:]
@@ -303,7 +327,7 @@ class Voxels(Space):
        return fi_corner_indices, fi_corners

    @torch.no_grad()
    def split(self) -> tuple[int, int]:
        """
        Split voxels into smaller voxels with half size.
        """
@@ -336,7 +360,7 @@ class Voxels(Space):
        return self.n_voxels // 8, self.n_voxels

    @torch.no_grad()
    def prune(self, keeps: torch.Tensor) -> tuple[int, int]:
        self.voxels = self.voxels[keeps]
        self.corner_indices = self.corner_indices[keeps]
        self._update_gi2vi()
@@ -351,7 +375,7 @@ class Voxels(Space):
        new_emb = self.set_voxel_embedding(update_fn(emb.weight), name)
        self._update_optimizer(emb.weight, new_emb.weight, update_fn)

    def _update_optimizer(self, old_param: Parameter, new_param: Parameter, update_fn):
        optimizer = get_env()["trainer"].optimizer
        if isinstance(optimizer, (torch.optim.Adam, torch.optim.AdamW)):
            # Update related states in optimizer
@@ -384,7 +408,7 @@ class Voxels(Space):
        sum_dims = [val for val in range(self.dims) if val != dim]
        return self.voxel_indices_in_grid[1:].reshape(*self.steps).ne(-1).sum(sum_dims)

    def balance_cut(self, dim: int, n_parts: int) -> list[int]:
        n_voxels_list = self.n_voxels_along_dim(dim)
        cdf = (n_voxels_list.cumsum(0) / self.n_voxels * n_parts).tolist()
        bins = []
@@ -398,7 +422,7 @@ class Voxels(Space):
        bins.append(len(cdf) - offset)
        return bins

    def sample(self, S: int, perturb: bool = False, include_border: bool = True) -> tuple[torch.Tensor, torch.Tensor]:
        """
        For each voxel, sample `S^3` points uniformly, with small perturb if `perturb` is `True`.
@@ -419,7 +443,7 @@ class Voxels(Space):
        pts += (torch.rand_like(pts) - .5) * self.voxel_size / S
return pts.reshape(-1, 3), voxel_indices.flatten() return pts.reshape(-1, 3), voxel_indices.flatten()
def _ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: def _ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return aabb_ray_intersect(self.voxel_size, n_max_hits, self.voxels, rays_o, rays_d) return aabb_ray_intersect(self.voxel_size, n_max_hits, self.voxels, rays_o, rays_d)
def _update_gi2vi(self): def _update_gi2vi(self):
...@@ -456,7 +480,7 @@ class Octree(Voxels): ...@@ -456,7 +480,7 @@ class Octree(Voxels):
self.nodes_cached = None self.nodes_cached = None
self.tree_cached = None self.tree_cached = None
def get(self) -> Tuple[torch.Tensor, torch.Tensor]: def get(self) -> tuple[torch.Tensor, torch.Tensor]:
if self.nodes_cached is None: if self.nodes_cached is None:
self.nodes_cached, self.tree_cached = build_easy_octree( self.nodes_cached, self.tree_cached = build_easy_octree(
self.voxels, 0.5 * self.voxel_size) self.voxels, 0.5 * self.voxel_size)
...@@ -477,7 +501,7 @@ class Octree(Voxels): ...@@ -477,7 +501,7 @@ class Octree(Voxels):
return ret return ret
@torch.no_grad() @torch.no_grad()
def prune(self, keeps: torch.Tensor) -> Tuple[int, int]: def prune(self, keeps: torch.Tensor) -> tuple[int, int]:
ret = super().prune(keeps) ret = super().prune(keeps)
self.clear() self.clear()
return ret return ret
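The hunks above swap `trilinear_interp` for `linear_interp` when decoding voxel-corner embeddings, and `to_vi` maps grid indices to voxel indices through a lookup table whose first entry absorbs the -1 "empty cell" index. Below is a minimal sketch of both ideas on toy data; it is not the repository's implementation, and the corner ordering (x varying fastest) is an assumption.

import torch

def trilinear_interp(p: torch.Tensor, corner_values: torch.Tensor) -> torch.Tensor:
    # p: (N, 3) normalized in-voxel coordinates in [0, 1]
    # corner_values: (N, 8, C) values at the 8 voxel corners; corner k is assumed to
    # sit at offsets (k & 1, k >> 1 & 1, k >> 2 & 1) along (x, y, z)
    offsets = torch.tensor([[k & 1, k >> 1 & 1, k >> 2 & 1] for k in range(8)],
                           dtype=p.dtype, device=p.device)                       # (8, 3)
    w = torch.where(offsets.bool(), p[:, None, :], 1. - p[:, None, :]).prod(-1)  # (N, 8)
    return (w[..., None] * corner_values).sum(-2)                                # (N, C)

# Grid-index -> voxel-index lookup with the "+1" offset used by to_vi: entry 0 of the
# table corresponds to grid index -1 (outside / pruned), so empty cells map back to -1
# without any branching.
n_voxels = 4
voxel_indices_in_grid = torch.arange(-1, n_voxels)   # (n_voxels + 1,)
gi = torch.tensor([-1, 0, 2])                         # grid indices, -1 = empty
vi = voxel_indices_in_grid[gi + 1]                    # tensor([-1,  0,  2])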
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import plotly_express as px # import plotly.express as px\n",
"import plotly.graph_objects as go\n",
"import pandas as pd\n",
"import numpy as np\n",
"import os\n",
"\n",
"name = [\"NeRF\", \"PANO\", \"OURS+\", \"OURS\"]\n",
"foveal = 8.0\n",
"mid = 2.8\n",
"far = 2.6\n",
"blend = 0.1\n",
"ours_full = 562\n",
"pano = 1e4\n",
"nerf = 9e4\n",
"\n",
"times = {\n",
" \"fovea_l\": 0,\n",
" \"mid_l\": 0,\n",
" \"far_l\": 0,\n",
" \"fovea_r\": 0,\n",
" \"mid_r\": 0,\n",
" \"far_r\": 0,\n",
" \"blend\": 0,\n",
" \"ours_full\": 0,\n",
" \"pano\": 0,\n",
" \"nerf\": 0,\n",
"}\n",
"clip = 0\n",
"frame_id = 0\n",
"\n",
"\n",
"def calc_total():\n",
" return [\n",
" times[\"nerf\"],\n",
" times[\"pano\"],\n",
" times[\"ours_full\"],\n",
" times[\"fovea_l\"] + times[\"mid_l\"] + times[\"far_l\"] +\n",
" times[\"fovea_r\"] + times[\"mid_r\"] + times[\"far_r\"] + times[\"blend\"]\n",
" ]\n",
"\n",
"\n",
"def draw_frame(*, xlim=None, **kwargs):\n",
" global frame_id\n",
" for key in kwargs:\n",
" times[key] = kwargs[key]\n",
" tot = calc_total()\n",
" data = {\n",
" \"fovea_l\": [0, 0, 0, times[\"fovea_l\"]],\n",
" \"mid_l\": [0, 0, 0, times[\"mid_l\"]],\n",
" \"far_l\": [0, 0, 0, times[\"far_l\"]],\n",
" \"fovea_r\": [0, 0, 0, times[\"fovea_r\"]],\n",
" \"mid_r\": [0, 0, 0, times[\"mid_r\"]],\n",
" \"far_r\": [0, 0, 0, times[\"far_r\"]],\n",
" \"blend\": [0, 0, 0, times[\"blend\"]],\n",
" \"ours_full\": [0, 0, times[\"ours_full\"], 0],\n",
" \"pano\": [0, times[\"pano\"], 0, 0],\n",
" \"nerf\": [times[\"nerf\"], 0, 0, 0],\n",
" }\n",
" if xlim is None or xlim < max(tot) * 1.1:\n",
" xlim = max(tot) * 1.1\n",
" \n",
" fig = go.Figure()\n",
" times_keys = list(times.keys())\n",
" for key in times_keys:\n",
" if key == times_keys[-1]:\n",
" fig.add_trace(go.Bar(\n",
" y=name,\n",
" x=data[key],\n",
" name=key,\n",
" orientation='h',\n",
" text=[\"\" if item == 0 else f\"{item:.1f}\" if item < 1000 else f\"{item:.1e}\" for item in tot],\n",
" textposition=\"outside\"\n",
" ))\n",
" else:\n",
" fig.add_trace(go.Bar(\n",
" y=name,\n",
" x=data[key],\n",
" name=key,\n",
" orientation='h',\n",
" ))\n",
" fig.update_traces(width=0.5)\n",
" fig.update_layout(barmode='stack', showlegend=False,\n",
" yaxis_visible=False, yaxis_showticklabels=False, xaxis_range=[0, xlim])\n",
" \n",
" # fig.show()\n",
" fig.write_image(f\"dynamic_bar/clip_{clip}/{frame_id:04d}.png\", width=1920 // 2, height=1080 // 2, scale=2)\n",
" frame_id = frame_id + 1\n",
"\n",
"def add_animation(*, frames, xlim=None, **kwargs):\n",
" if frames == 1:\n",
" draw_frame(**kwargs, xlim=xlim)\n",
" return\n",
" data = {\n",
" key: np.linspace(times[key], kwargs[key], frames)\n",
" for key in kwargs\n",
" }\n",
" for i in range(frames):\n",
" draw_frame(**{key: data[key][i] for key in data}, xlim=xlim)\n",
"\n",
"def new_clip():\n",
" global clip, frame_id\n",
" clip += 1\n",
" frame_id = 0\n",
" os.system(f\"mkdir dynamic_bar/clip_{clip}\")\n",
"\n",
"os.system('rm -f -r dynamic_bar')\n",
"os.system('mkdir dynamic_bar')\n",
"\n",
"# ours mono\n",
"new_clip()\n",
"add_animation(fovea_l=foveal, frames=48, xlim=30) # Step 1: grow foveal\n",
"add_animation(mid_l=mid, frames=16, xlim=30) # Step 2: grow mid\n",
"add_animation(far_l=far, frames=16, xlim=30) # Step 3: grow far\n",
"add_animation(blend=blend, frames=1, xlim=30) # Step 4: grow blend\n",
"\n",
"# ours stereo\n",
"new_clip()\n",
"add_animation(fovea_r=foveal, frames=24, xlim=30) # Step 1: grow foveal\n",
"add_animation(mid_r=mid, frames=8, xlim=30) # Step 2: grow mid\n",
"add_animation(far_r=far, frames=8, xlim=30) # Step 3: grow far\n",
"\n",
"# ours stereo adapt\n",
"new_clip()\n",
"add_animation(mid_r=0, far_r=0, frames=24, xlim=30)\n",
"\n",
"# other series\n",
"new_clip()\n",
"add_animation(ours_full=ours_full, frames=48, xlim=30)\n",
"new_clip()\n",
"add_animation(pano=pano, frames=48, xlim=30)\n",
"new_clip()\n",
"add_animation(nerf=nerf, frames=48, xlim=30)\n",
"\n",
"#os.system(f'ffmpeg -y -r 24 -i dynamic_bar/%04d.png dynamic_bar.avi')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig = px.bar(\n",
" df1, # 绘图数据\n",
" x=list(times.keys()), # y轴\n",
" y=\"name\", # x轴\n",
" orientation='h', # 水平柱状图\n",
" #text=[[\"a\", \"tot\"], \"tot1\", \"tot2\", {\"fovea_l\": \"\", \"blend_r\": 13.5}] # 需要显示的数据\n",
")\n",
"fig.update_traces(textposition=\"outside\", showlegend=False, text=[[\"a\"]*11, [\"b\"]*11, [\"c\"]*11, [\"d\"]*11,[\"e\"]*11])\n",
"fig.show()"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>country</th>\n",
" <th>continent</th>\n",
" <th>year</th>\n",
" <th>lifeExp</th>\n",
" <th>pop</th>\n",
" <th>gdpPercap</th>\n",
" <th>iso_alpha</th>\n",
" <th>iso_num</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Afghanistan</td>\n",
" <td>Asia</td>\n",
" <td>1952</td>\n",
" <td>28.801</td>\n",
" <td>8425333</td>\n",
" <td>779.445314</td>\n",
" <td>AFG</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>Afghanistan</td>\n",
" <td>Asia</td>\n",
" <td>1957</td>\n",
" <td>30.332</td>\n",
" <td>9240934</td>\n",
" <td>820.853030</td>\n",
" <td>AFG</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Afghanistan</td>\n",
" <td>Asia</td>\n",
" <td>1962</td>\n",
" <td>31.997</td>\n",
" <td>10267083</td>\n",
" <td>853.100710</td>\n",
" <td>AFG</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>Afghanistan</td>\n",
" <td>Asia</td>\n",
" <td>1967</td>\n",
" <td>34.020</td>\n",
" <td>11537966</td>\n",
" <td>836.197138</td>\n",
" <td>AFG</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>Afghanistan</td>\n",
" <td>Asia</td>\n",
" <td>1972</td>\n",
" <td>36.088</td>\n",
" <td>13079460</td>\n",
" <td>739.981106</td>\n",
" <td>AFG</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1699</th>\n",
" <td>Zimbabwe</td>\n",
" <td>Africa</td>\n",
" <td>1987</td>\n",
" <td>62.351</td>\n",
" <td>9216418</td>\n",
" <td>706.157306</td>\n",
" <td>ZWE</td>\n",
" <td>716</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1700</th>\n",
" <td>Zimbabwe</td>\n",
" <td>Africa</td>\n",
" <td>1992</td>\n",
" <td>60.377</td>\n",
" <td>10704340</td>\n",
" <td>693.420786</td>\n",
" <td>ZWE</td>\n",
" <td>716</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1701</th>\n",
" <td>Zimbabwe</td>\n",
" <td>Africa</td>\n",
" <td>1997</td>\n",
" <td>46.809</td>\n",
" <td>11404948</td>\n",
" <td>792.449960</td>\n",
" <td>ZWE</td>\n",
" <td>716</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1702</th>\n",
" <td>Zimbabwe</td>\n",
" <td>Africa</td>\n",
" <td>2002</td>\n",
" <td>39.989</td>\n",
" <td>11926563</td>\n",
" <td>672.038623</td>\n",
" <td>ZWE</td>\n",
" <td>716</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1703</th>\n",
" <td>Zimbabwe</td>\n",
" <td>Africa</td>\n",
" <td>2007</td>\n",
" <td>43.487</td>\n",
" <td>12311143</td>\n",
" <td>469.709298</td>\n",
" <td>ZWE</td>\n",
" <td>716</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>1704 rows × 8 columns</p>\n",
"</div>"
],
"text/plain": [
" country continent year lifeExp pop gdpPercap iso_alpha \\\n",
"0 Afghanistan Asia 1952 28.801 8425333 779.445314 AFG \n",
"1 Afghanistan Asia 1957 30.332 9240934 820.853030 AFG \n",
"2 Afghanistan Asia 1962 31.997 10267083 853.100710 AFG \n",
"3 Afghanistan Asia 1967 34.020 11537966 836.197138 AFG \n",
"4 Afghanistan Asia 1972 36.088 13079460 739.981106 AFG \n",
"... ... ... ... ... ... ... ... \n",
"1699 Zimbabwe Africa 1987 62.351 9216418 706.157306 ZWE \n",
"1700 Zimbabwe Africa 1992 60.377 10704340 693.420786 ZWE \n",
"1701 Zimbabwe Africa 1997 46.809 11404948 792.449960 ZWE \n",
"1702 Zimbabwe Africa 2002 39.989 11926563 672.038623 ZWE \n",
"1703 Zimbabwe Africa 2007 43.487 12311143 469.709298 ZWE \n",
"\n",
" iso_num \n",
"0 4 \n",
"1 4 \n",
"2 4 \n",
"3 4 \n",
"4 4 \n",
"... ... \n",
"1699 716 \n",
"1700 716 \n",
"1701 716 \n",
"1702 716 \n",
"1703 716 \n",
"\n",
"[1704 rows x 8 columns]"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df = px.data.gapminder()\n",
"df"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.10.0 ('dvs')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "4469b029896260c1221afa6e0e6159922aafd2738570e75b7bc15e28db242604"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
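The notebook above writes each clip's frames to `dynamic_bar/clip_{clip}/` and leaves the encoding step commented out. Since the frames are now split per clip, one ffmpeg pass per clip directory is needed; a minimal sketch, assuming ffmpeg is on PATH (output file names are illustrative):

import os

# Six clips are created by the cell above (mono, stereo, stereo-adapt, ours_full, pano, nerf).
for clip in range(1, 7):
    os.system(f"ffmpeg -y -r 24 -i dynamic_bar/clip_{clip}/%04d.png dynamic_bar_clip{clip}.avi")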
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"import os\n",
"import torch\n",
"import torch.nn.functional as nn_f\n",
"import matplotlib.pyplot as plt\n",
"\n",
"rootdir = os.path.abspath(sys.path[0] + '/../../')\n",
"sys.path.append(rootdir)\n",
"\n",
"from utils import img\n",
"from utils.view import *\n",
"\n",
"datadir = f\"{rootdir}/data/__thesis/__demo/compare\"\n",
"figs = ['fsnerf', 'gt', 'nerf']\n",
"crops = {\n",
" 'barbershop': [[406, 117, 100], [209, 170, 100]],\n",
" 'gas': [[195, 69, 100], [7, 305, 100]],\n",
" 'mc': [[395, 128, 100], [97, 391, 100]],\n",
" 'pabellon': [[208, 115, 100], [22, 378, 100]]\n",
"}\n",
"colors = torch.tensor([[0, 1, 0], [1, 1, 0]], dtype=torch.float)\n",
"border = 3\n",
"\n",
"for scene in crops:\n",
" images = img.load([f\"{datadir}/origin/{scene}_{fig}.png\" for fig in figs])\n",
" halfw = images.size(-1) // 2\n",
" halfh = images.size(-2) // 2\n",
" overlay = torch.zeros(1, 4, *images.shape[2:])\n",
" mask = torch.zeros(len(crops[scene]), *images.shape[2:], dtype=torch.bool)\n",
" for i, crop in enumerate(crops[scene]):\n",
" patches = images[..., crop[1]: crop[1] + crop[2], crop[0]: crop[0] + crop[2]].clone()\n",
" patches[..., :border, :] = colors[i, :, None, None]\n",
" patches[..., -border:, :] = colors[i, :, None, None]\n",
" patches[..., :, :border] = colors[i, :, None, None]\n",
" patches[..., :, -border:] = colors[i, :, None, None]\n",
" img.save(patches, [f\"{datadir}/crop/{scene}_{i}_{fig}.png\" for fig in figs])\n",
" mask[i,\n",
" crop[1] - border: crop[1] + crop[2] + border,\n",
" crop[0] - border: crop[0] + crop[2] + border] = True\n",
" mask[i,\n",
" crop[1]: crop[1] + crop[2],\n",
" crop[0]: crop[0] + crop[2]] = False\n",
" images[:, :, mask[i]] = colors[i, :, None]\n",
" img.save(images, [f\"{datadir}/overlay/{scene}_{fig}.png\" for fig in figs])\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.10.0 ('dvs')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0"
},
"orig_nbformat": 2,
"vscode": {
"interpreter": {
"hash": "4469b029896260c1221afa6e0e6159922aafd2738570e75b7bc15e28db242604"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
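The crop notebook marks each highlight with a colored frame by setting a boolean mask over an outer box and then clearing the inner box, leaving a ring that is `border` pixels wide. A self-contained sketch of that ring construction on a toy image (names and sizes are illustrative, not taken from the repository):

import torch

border, x, y, size = 3, 10, 8, 20                    # toy crop: top-left (x, y), side length size
mask = torch.zeros(64, 64, dtype=torch.bool)
mask[y - border: y + size + border, x - border: x + size + border] = True   # outer box
mask[y: y + size, x: x + size] = False                                      # clear inner box
# mask is now True only on the border ring; assigning image[:, :, mask] = color
# paints the frame without touching the cropped content itself.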
...@@ -2,8 +2,11 @@ ...@@ -2,8 +2,11 @@
"cells": [ "cells": [
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"%matplotlib inline\n",
"import sys\n", "import sys\n",
"import os\n", "import os\n",
"import torch\n", "import torch\n",
...@@ -12,25 +15,25 @@ ...@@ -12,25 +15,25 @@
"\n", "\n",
"rootdir = os.path.abspath(sys.path[0] + '/../')\n", "rootdir = os.path.abspath(sys.path[0] + '/../')\n",
"sys.path.append(rootdir)\n", "sys.path.append(rootdir)\n",
"torch.cuda.set_device(0)\n", "\n",
"torch.cuda.set_device(3)\n",
"print(\"Set CUDA:%d as current device.\" % torch.cuda.current_device())\n", "print(\"Set CUDA:%d as current device.\" % torch.cuda.current_device())\n",
"torch.autograd.set_grad_enabled(False)\n", "torch.autograd.set_grad_enabled(False)\n",
"\n", "\n",
"from configs.spherical_view_syn import SphericalViewSynConfig\n", "import model\n",
"from utils import netio\n", "from data import Dataset\n",
"from utils import img\n", "from utils import netio, img, device\n",
"from utils import device\n",
"from utils.view import *\n", "from utils.view import *\n",
"from utils.type import PathLike\n",
"from components.fnr import FoveatedNeuralRenderer\n", "from components.fnr import FoveatedNeuralRenderer\n",
"from components.render import render\n",
"\n", "\n",
"\n", "\n",
"def load_net(path):\n", "def load_model(model_path: PathLike):\n",
" config = SphericalViewSynConfig()\n", " return model.deserialize(netio.load_checkpoint(model_path)[0],\n",
" config.from_id(os.path.splitext(path)[0])\n", " raymarching_early_stop_tolerance=0.01,\n",
" config.sa['perturb_sample'] = False\n", " raymarching_chunk_size_or_sections=None,\n",
" net = config.create_net().to(device.default())\n", " perturb_sample=False).eval().to(device.default())\n",
" netio.load(path, net)\n",
" return net\n",
"\n", "\n",
"\n", "\n",
"def find_file(prefix):\n", "def find_file(prefix):\n",
...@@ -40,6 +43,16 @@ ...@@ -40,6 +43,16 @@
" return None\n", " return None\n",
"\n", "\n",
"\n", "\n",
"def create_renderer(*nets, fov_scale=1.):\n",
" fov_list = [20, 45, 110]\n",
" for i in range(len(fov_list)):\n",
" fov_list[i] = length2fov(fov2length(fov_list[i]) * fov_scale)\n",
" res_list = [(256, 256), (256, 256), (256, 230)]\n",
" res_full = (1600, 1440)\n",
" return FoveatedNeuralRenderer(fov_list, res_list, nn.ModuleList(nets), res_full,\n",
" device=device.default())\n",
"\n",
"\n",
"def plot_images(images):\n", "def plot_images(images):\n",
" plt.figure(figsize=(12, 4))\n", " plt.figure(figsize=(12, 4))\n",
" plt.subplot(131)\n", " plt.subplot(131)\n",
...@@ -49,64 +62,50 @@ ...@@ -49,64 +62,50 @@
" plt.subplot(133)\n", " plt.subplot(133)\n",
" img.plot(images['layers_img'][2])\n", " img.plot(images['layers_img'][2])\n",
" #plt.figure(figsize=(12, 12))\n", " #plt.figure(figsize=(12, 12))\n",
" #img.plot(images['overlaid'])\n", " # img.plot(images['overlaid'])\n",
" #plt.figure(figsize=(12, 12))\n", " #plt.figure(figsize=(12, 12))\n",
" #img.plot(images['blended_raw'])\n", " # img.plot(images['blended_raw'])\n",
" plt.figure(figsize=(12, 12))\n", " plt.figure(figsize=(12, 12))\n",
" img.plot(images['blended'])\n", " img.plot(images['blended'])\n",
"\n", "\n",
"\n", "\n",
"def save_images(images, scene, i):\n",
" outputdir = '../__demo/mono/'\n",
" os.makedirs(outputdir, exist_ok=True)\n",
" for layer in range(len(images[\"layers_img\"])):\n",
" img.save(images['layers_img'][layer], f'{outputdir}{scene}_{i:04d}({layer}).png')\n",
" img.save(images['blended'], f'{outputdir}{scene}_{i:04d}.png')\n",
" if \"overlaid\" in images:\n",
" img.save(images['overlaid'], f'{outputdir}{scene}_{i:04d}_overlaid.png')\n",
" if \"blended_raw\" in images:\n",
" img.save(images['blended_raw'], f'{outputdir}{scene}_{i:04d}_noCE.png')\n",
" if \"nerf\" in images:\n",
" img.save(images['nerf'], f'{outputdir}{scene}_{i:04d}_nerf.png')\n",
"\n",
"\n",
"scenes = {\n", "scenes = {\n",
" 'classroom': 'classroom_all',\n", " 'classroom': '__new/classroom_all',\n",
" 'stones': 'stones_all',\n", " 'stones': '__new/stones_all',\n",
" 'barbershop': 'barbershop_all',\n", " 'barbershop': '__new/barbershop_all',\n",
" 'lobby': 'lobby_all'\n", " 'lobby': '__new/lobby_all',\n",
" \"bedroom2\": \"__captured/bedroom2\"\n",
"}\n", "}\n",
"\n", "\n",
"fov_list = [20, 45, 110]\n", "\n",
"res_list = [(256, 256), (256, 256), (400, 360)]\n", "scene = \"bedroom2\"\n",
"res_full = (1600, 1440)" "os.chdir(f'{rootdir}/data/{scenes[scene]}')\n",
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Set CUDA:0 as current device.\n"
]
}
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 2,
"source": [
"scene = 'barbershop'\n",
"os.chdir(f'{rootdir}/data/__new/{scenes[scene]}')\n",
"print('Change working directory to ', os.getcwd())\n", "print('Change working directory to ', os.getcwd())\n",
"\n", "\n",
"fovea_net = load_net(find_file('fovea'))\n", "fovea_net = load_model(find_file('fovea'))\n",
"periph_net = load_net(find_file('periph'))\n", "periph_net = load_model(find_file('periph'))\n",
"renderer = FoveatedNeuralRenderer(fov_list, res_list, nn.ModuleList([fovea_net, periph_net, periph_net]),\n", "nerf_net = load_model(find_file(\"nerf\"))"
" res_full, device=device.default())" ]
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Change working directory to /home/dengnc/dvs/data/__new/barbershop_all\n",
"Load net from fovea200@snerffast4-rgb_e6_fc512x4_d1.20-6.00_s64_~p.pth ...\n",
"Load net from periph200@snerffast2-rgb_e6_fc256x4_d1.20-6.00_s32_~p.pth ...\n"
]
}
],
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"params = {\n", "params = {\n",
" 'classroom': [\n", " 'classroom': [\n",
...@@ -165,7 +164,7 @@ ...@@ -165,7 +164,7 @@
"\n", "\n",
"for i, param in enumerate(params[scene]):\n", "for i, param in enumerate(params[scene]):\n",
" view = Trans(torch.tensor(param[:3], device=device.default()),\n", " view = Trans(torch.tensor(param[:3], device=device.default()),\n",
" torch.tensor(euler_to_matrix([-param[4], param[3], 0]), device=device.default()).view(3, 3))\n", " torch.tensor(euler_to_matrix(-param[4], param[3], 0), device=device.default()).view(3, 3))\n",
" images = renderer(view, param[-2:], using_mask=False, ret_raw=True)\n", " images = renderer(view, param[-2:], using_mask=False, ret_raw=True)\n",
" images['overlaid'] = renderer.foveation.synthesis(images['layers_raw'], param[-2:], do_blend=False)\n", " images['overlaid'] = renderer.foveation.synthesis(images['layers_raw'], param[-2:], do_blend=False)\n",
" if True:\n", " if True:\n",
...@@ -179,49 +178,45 @@ ...@@ -179,49 +178,45 @@
" #img.save(images['blended_raw'], f'{outputdir}{scene}_{i}.png')\n", " #img.save(images['blended_raw'], f'{outputdir}{scene}_{i}.png')\n",
" else:\n", " else:\n",
" images = plot_images(images)\n" " images = plot_images(images)\n"
], ]
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"def load_views(data_desc_file) -> Trans:\n", "def load_views(data_desc_file) -> tuple[list[int], Trans]:\n",
" with open(data_desc_file, 'r', encoding='utf-8') as file:\n", " dataset = Dataset(data_desc_file)\n",
" data_desc = json.loads(file.read())\n", " return dataset.indices.tolist(),\\\n",
" view_centers = torch.tensor(\n", " Trans(dataset.centers, dataset.rots).to(device.default())\n",
" data_desc['view_centers'], device=device.default()).view(-1, 3)\n", "\n",
" view_rots = torch.tensor(\n", "\n",
" data_desc['view_rots'], device=device.default()).view(-1, 3, 3)\n", "demos = [ # view_idx, center_x, center_y, fov_scale\n",
" return Trans(view_centers, view_rots)\n", " [220, 30, 25, 0.7],\n",
"\n", " [235, 0, 130, 0.7],\n",
"\n", " [239, 70, 140, 0.7],\n",
"views = load_views('for_panorama_cvt.json')\n", " [841, -100, 160, 0.7]\n",
"print('Dataset loaded.')\n", "]\n",
"for view_idx in range(views.size()[0]):\n", "indices, views = load_views('images.json')\n",
" center = (0, 0)\n", "for demo_idx in [0]:\n",
" images = renderer(views.get(view_idx), center, using_mask=True)\n", " view_idx = demos[demo_idx][0]\n",
" outputdir = 'panorama'\n", " i = indices.index(view_idx)\n",
" os.makedirs(outputdir, exist_ok=True)\n", " center = tuple(demos[demo_idx][1:3])\n",
" img.save(images['blended'], f'{outputdir}/{view_idx:04d}.png')" " renderer = create_renderer(fovea_net, periph_net, periph_net, fov_scale=demos[demo_idx][3])\n",
], " images = renderer(views.get(i), center, using_mask=False)\n",
"outputs": [ " #nerf_fovea = render(nerf_net, renderer.cam, views.get(i), None, batch_size=16384)[\"color\"]\n",
{ " #images[\"nerf\"] = nerf_fovea\n",
"output_type": "stream", " plot_images(images)\n",
"name": "stdout", " #save_images(images, scene, view_idx)\n"
"text": [ ]
"Dataset loaded.\n"
]
}
],
"metadata": {}
} }
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"name": "python3", "display_name": "Python 3.10.0 ('dvs')",
"display_name": "Python 3.8.5 64-bit ('base': conda)" "language": "python",
"name": "python3"
}, },
"language_info": { "language_info": {
"codemirror_mode": { "codemirror_mode": {
...@@ -233,17 +228,19 @@ ...@@ -233,17 +228,19 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.8.5" "version": "3.10.0"
}, },
"metadata": { "metadata": {
"interpreter": { "interpreter": {
"hash": "82066b63b621a9e3d15e3b7c11ca76da6238eff3834294910d715044bd0561e5" "hash": "82066b63b621a9e3d15e3b7c11ca76da6238eff3834294910d715044bd0561e5"
} }
}, },
"interpreter": { "vscode": {
"hash": "82066b63b621a9e3d15e3b7c11ca76da6238eff3834294910d715044bd0561e5" "interpreter": {
"hash": "4469b029896260c1221afa6e0e6159922aafd2738570e75b7bc15e28db242604"
}
} }
}, },
"nbformat": 4, "nbformat": 4,
"nbformat_minor": 4 "nbformat_minor": 4
} }
\ No newline at end of file
...@@ -2,14 +2,17 @@ ...@@ -2,14 +2,17 @@
"cells": [ "cells": [
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 3,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Set CUDA:0 as current device.\n" "Set CUDA:0 as current device.\n",
"Change working directory to /home/dengnc/Work/fov_nerf/data/__thesis/barbershop\n",
"Load model fovea.tar\n",
"Load model periph.tar\n"
] ]
} }
], ],
...@@ -20,29 +23,25 @@ ...@@ -20,29 +23,25 @@
"import torch.nn as nn\n", "import torch.nn as nn\n",
"import matplotlib.pyplot as plt\n", "import matplotlib.pyplot as plt\n",
"\n", "\n",
"rootdir = os.path.abspath(sys.path[0] + '/../')\n", "rootdir = os.path.abspath(sys.path[0] + '/../../')\n",
"sys.path.append(rootdir)\n", "sys.path.append(rootdir)\n",
"\n", "\n",
"torch.cuda.set_device(0)\n", "torch.cuda.set_device(0)\n",
"print(\"Set CUDA:%d as current device.\" % torch.cuda.current_device())\n", "print(\"Set CUDA:%d as current device.\" % torch.cuda.current_device())\n",
"torch.autograd.set_grad_enabled(False)\n", "torch.autograd.set_grad_enabled(False)\n",
"\n", "\n",
"from data.spherical_view_syn import *\n", "from model import Model\n",
"from configs.spherical_view_syn import SphericalViewSynConfig\n", "from data import Dataset\n",
"from utils import netio\n", "from utils import netio, img, device\n",
"from utils import img\n",
"from utils import device\n",
"from utils.view import *\n", "from utils.view import *\n",
"from utils.types import PathLike\n",
"from components.fnr import FoveatedNeuralRenderer\n", "from components.fnr import FoveatedNeuralRenderer\n",
"from components.render import render\n",
"\n", "\n",
"\n", "\n",
"def load_net(path):\n", "def load_model(model_path: PathLike):\n",
" config = SphericalViewSynConfig()\n", " print(\"Load model\", model_path)\n",
" config.from_id(os.path.splitext(path)[0])\n", " return Model.load(model_path).eval().to(device.default())\n",
" config.sa['perturb_sample'] = False\n",
" net = config.create_net().to(device.default())\n",
" netio.load(path, net)\n",
" return net\n",
"\n", "\n",
"\n", "\n",
"def find_file(prefix):\n", "def find_file(prefix):\n",
...@@ -52,6 +51,16 @@ ...@@ -52,6 +51,16 @@
" return None\n", " return None\n",
"\n", "\n",
"\n", "\n",
"def create_renderer(*nets, fov_scale=1.):\n",
" fov_list = [20, 45, 110]\n",
" for i in range(len(fov_list)):\n",
" fov_list[i] = length2fov(fov2length(fov_list[i]) * fov_scale)\n",
" res_list = [(256, 256), (256, 256), (256, 230)]\n",
" res_full = (1600, 1440)\n",
" return FoveatedNeuralRenderer(fov_list, res_list, nn.ModuleList(nets), res_full,\n",
" device=device.default())\n",
"\n",
"\n",
"def load_views(data_desc_file) -> Trans:\n", "def load_views(data_desc_file) -> Trans:\n",
" with open(data_desc_file, 'r', encoding='utf-8') as file:\n", " with open(data_desc_file, 'r', encoding='utf-8') as file:\n",
" data_desc = json.loads(file.read())\n", " data_desc = json.loads(file.read())\n",
...@@ -124,38 +133,29 @@ ...@@ -124,38 +133,29 @@
"scenes = {\n", "scenes = {\n",
" 'classroom': 'classroom_all',\n", " 'classroom': 'classroom_all',\n",
" 'stones': 'stones_all',\n", " 'stones': 'stones_all',\n",
" 'barbershop': 'barbershop_all',\n", " 'barbershop': '__thesis/barbershop',\n",
" 'lobby': 'lobby_all'\n", " 'lobby': 'lobby_all'\n",
"}\n", "}\n",
"\n", "\n",
"\n", "\n",
"fov_list = [20, 45, 110]\n", "scene = \"barbershop\"\n",
"res_list = [(256, 256), (256, 256), (256, 230)]\n", "os.chdir(f'{rootdir}/data/{scenes[scene]}')\n",
"res_full = (1600, 1440)\n" "print('Change working directory to ', os.getcwd())\n",
"\n",
"fovea_net = load_model(find_file('fovea'))\n",
"periph_net = load_model(find_file('periph'))\n",
"renderer = create_renderer(fovea_net, periph_net, periph_net)"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 15, "execution_count": 8,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Change working directory to /home/dengnc/dvs/data/__new/barbershop_all\n",
"Load net from fovea200@snerffast4-rgb_e6_fc512x4_d1.20-6.00_s64_~p.pth ...\n",
"Load net from periph200@snerffast2-rgb_e6_fc256x4_d1.20-6.00_s32_~p.pth ...\n",
"barbershop 0 Saved\n",
"barbershop 1 Saved\n",
"barbershop 2 Saved\n",
"barbershop 3 Saved\n",
"barbershop 4 Saved\n",
"barbershop 0 Saved\n",
"barbershop 1 Saved\n",
"barbershop 2 Saved\n",
"barbershop 3 Saved\n",
"barbershop 4 Saved\n",
"barbershop 0 Saved\n", "barbershop 0 Saved\n",
"barbershop 1 Saved\n", "barbershop 1 Saved\n",
"barbershop 2 Saved\n", "barbershop 2 Saved\n",
...@@ -180,60 +180,46 @@ ...@@ -180,60 +180,46 @@
" [(0, 0, 0, 0, 0), (21, 150), (12, 150)]\n", " [(0, 0, 0, 0, 0), (21, 150), (12, 150)]\n",
" ],\n", " ],\n",
" 'barbershop': [\n", " 'barbershop': [\n",
" #[(0, 0, 0, 0, 0), (106, -67), (90, -67)],\n", " #[(-0.08247789757018531, -0.17165164843400083, -0.20644832805536045, -66.00281344384992, -4.354400833888114), (0, 0), (0, 0)],\n",
" #[(0, 0, 0, 0, 0), (-114, 10), (-126, 10)],\n", " #[(0, 0, 0, 0, 0), (-114, 10), (-126, 10)],\n",
" [(0, 0, 0, 25, 20), (189, -45), (173, -45)],\n", " [(0, 0, 0, -25, 20), (189, -45), (173, -45)],\n",
" [(0, 0, 0, 25, 20), (-148, 130), (-163, 130)],\n", " [(0, 0, 0, -25, 20), (-148, 130), (-163, 130)],\n",
" [(0.15, 0.15, 0, 43, 2), (9, 0), (-9, 0)],\n", " [(0.15, 0.15, 0, -43, 2), (9, 0), (-9, 0)],\n",
" [(0.15, 0, 0.15, -13, -5), (6, 0), (-6, 0)],\n", " [(0.15, 0, -0.15, 13, -5), (6, 0), (-6, 0)],\n",
" [(-0.15, 0.15, 0.15, -53, -21), (3, 0), (-3, 0)]\n", " [(-0.15, 0.15, -0.15, 53, -21), (3, 0), (-3, 0)]\n",
" ]\n", " ]\n",
"}\n", "}\n",
"\n", "\n",
"#for scene in ['classroom', 'lobby', 'barbershop']:\n", "for mono_periph in range(3, 5):\n",
"for scene in ['barbershop']:\n", " for i, param in enumerate(params[scene]):\n",
" os.chdir(f'{rootdir}/data/__new/{scenes[scene]}')\n", " view = Trans(torch.tensor(param[0][:3], device=device.default()),\n",
" print('Change working directory to ', os.getcwd())\n", " torch.tensor(euler_to_matrix(param[0][4], param[0][3], 0),\n",
"\n", " device=device.default()).view(3, 3))\n",
" fovea_net = load_net(find_file('fovea'))\n", " left_images, right_images = renderer(view, param[1], param[2],\n",
" periph_net = load_net(find_file('periph'))\n", " stereo_disparity=0.06,\n",
" renderer = FoveatedNeuralRenderer(fov_list, res_list,\n", " using_mask=True,\n",
" nn.ModuleList([fovea_net, periph_net, periph_net]),\n", " mono_periph_mode=mono_periph,\n",
" res_full, device=device.default())\n", " ret_raw=False)\n",
"\n", " if True:\n",
" for mono_periph in range(0,4):\n", " outputdir = '../__demo/stereo_m%d' % mono_periph if mono_periph else '../__demo/stereo'\n",
" for i, param in enumerate(params[scene]):\n", " os.makedirs(outputdir, exist_ok=True)\n",
" view = Trans(torch.tensor(param[0][:3], device=device.default()),\n", " img.save(torch.cat([\n",
" torch.tensor(euler_to_matrix([-param[0][4], param[0][3], 0]),\n", " left_images['blended'],\n",
" device=device.default()).view(3, 3))\n", " right_images['blended']\n",
" eye_offset = torch.tensor([0.03, 0, 0], device=device.default())\n", " ], dim=-1), '%s/%s_%d.png' % (outputdir, scene, i))\n",
" left_view = Trans(view.trans_point(-eye_offset), view.r)\n", " img.save(left_images['blended'], '%s/%s_%d_l.png' % (outputdir, scene, i))\n",
" right_view = Trans(view.trans_point(eye_offset), view.r)\n", " img.save(right_images['blended'], '%s/%s_%d_r.png' % (outputdir, scene, i))\n",
" left_images, right_images = renderer(view, param[1], param[2],\n", " stereo_overlap = torch.cat([\n",
" stereo_disparity=0.06,\n", " left_images['blended'][:, 0:1],\n",
" using_mask=True,\n", " right_images['blended'][:, 1:3]\n",
" mono_periph_mode=mono_periph,\n", " ], dim=1)\n",
" ret_raw=False)\n", " img.save(stereo_overlap, '%s/%s_%d_stereo.png' % (outputdir, scene, i))\n",
" if True:\n", " #os.makedirs(outputdir + '/mid', exist_ok=True)\n",
" outputdir = '../__demo/stereo_m%d' % mono_periph if mono_periph else '../__demo/stereo'\n", " #img.save(left_images['layers_img'][1], '%s/mid/%s_%d_l.png' % (outputdir, scene, i))\n",
" os.makedirs(outputdir, exist_ok=True)\n", " #img.save(right_images['layers_img'][1], '%s/mid/%s_%d_r.png' % (outputdir, scene, i))\n",
" img.save(torch.cat([\n", " print(\"%s %d Saved\" % (scene, i))\n",
" left_images['blended'],\n", " else:\n",
" right_images['blended']\n", " plot_figures(left_images, right_images, param[1], param[2])\n"
" ], dim=-1), '%s/%s_%d.png' % (outputdir, scene, i))\n",
" img.save(left_images['blended'], '%s/%s_%d_l.png' % (outputdir, scene, i))\n",
" img.save(right_images['blended'], '%s/%s_%d_r.png' % (outputdir, scene, i))\n",
" stereo_overlap = torch.cat([\n",
" left_images['blended'][:, 0:1],\n",
" right_images['blended'][:, 1:3]\n",
" ], dim=1)\n",
" img.save(stereo_overlap, '%s/%s_%d_stereo.png' % (outputdir, scene, i))\n",
" #os.makedirs(outputdir + '/mid', exist_ok=True)\n",
" #img.save(left_images['layers_img'][1], '%s/mid/%s_%d_l.png' % (outputdir, scene, i))\n",
" #img.save(right_images['layers_img'][1], '%s/mid/%s_%d_r.png' % (outputdir, scene, i))\n",
" print(\"%s %d Saved\" % (scene, i))\n",
" else:\n",
" plot_figures(left_images, right_images, param[1], param[2])\n"
] ]
}, },
{ {
...@@ -245,11 +231,8 @@ ...@@ -245,11 +231,8 @@
} }
], ],
"metadata": { "metadata": {
"interpreter": {
"hash": "82066b63b621a9e3d15e3b7c11ca76da6238eff3834294910d715044bd0561e5"
},
"kernelspec": { "kernelspec": {
"display_name": "Python 3", "display_name": "Python 3.10.0 ('dvs')",
"language": "python", "language": "python",
"name": "python3" "name": "python3"
}, },
...@@ -263,7 +246,12 @@ ...@@ -263,7 +246,12 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.8.5" "version": "3.10.0"
},
"vscode": {
"interpreter": {
"hash": "4469b029896260c1221afa6e0e6159922aafd2738570e75b7bc15e28db242604"
}
} }
}, },
"nbformat": 4, "nbformat": 4,
......
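Earlier revisions of the stereo notebook built the eye pair by hand, shifting the head pose by half the interpupillary distance (`ipd = 0.06`, so ±0.03 m) along the view's local x-axis before `stereo_disparity` moved this step into the renderer. A minimal sketch of that construction, assuming the pose rotation maps view space to world space (the repository's `Trans.trans_point` may differ in detail):

import torch

def stereo_eye_positions(center_t: torch.Tensor, center_r: torch.Tensor, ipd: float = 0.06):
    # center_t: (3,) head position in world space; center_r: (3, 3) view-to-world rotation
    eye_offset = torch.tensor([ipd / 2, 0., 0.])      # half-IPD shift along the local x-axis
    left_t = center_t + center_r @ -eye_offset
    right_t = center_t + center_r @ eye_offset
    return left_t, right_t                            # both eyes keep the rotation center_r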
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import sys\n",
"import os\n",
"import torch\n",
"import matplotlib.pyplot as plt\n",
"\n",
"rootdir = os.path.abspath(sys.path[0] + '/../../')\n",
"sys.path.append(rootdir)\n",
"\n",
"torch.autograd.set_grad_enabled(False)\n",
"\n",
"from model import Model\n",
"from data import Dataset\n",
"from utils import netio, img, device\n",
"from utils.view import *\n",
"from utils.types import *\n",
"from components.render import render\n",
"\n",
"\n",
"model: Model = None\n",
"dataset: Dataset = None\n",
"\n",
"\n",
"def load_model(path: PathLike):\n",
" ckpt_path = netio.find_checkpoint(Path(path))\n",
" ckpt = torch.load(ckpt_path)\n",
" model = Model.create(ckpt[\"args\"][\"model\"], ckpt[\"args\"][\"model_args\"])\n",
" model.load_state_dict(ckpt[\"states\"][\"model\"])\n",
" model.to(device.default()).eval()\n",
" return model\n",
"\n",
"\n",
"def load_dataset(path: PathLike):\n",
" return Dataset(path, color_mode=model.color, coord_sys=model.args.coord,\n",
" device=device.default())\n",
"\n",
"\n",
"def plot_images(images, rows, cols):\n",
" plt.figure(figsize=(20, int(20 / cols * rows)))\n",
" for r in range(rows):\n",
" for c in range(cols):\n",
" plt.subplot(rows, cols, r * cols + c + 1)\n",
" img.plot(images[r * cols + c])\n",
"\n",
"\n",
"def save_images(images, scene, i):\n",
" outputdir = f'{rootdir}/data/__demo/layers/'\n",
" os.makedirs(outputdir, exist_ok=True)\n",
" for layer in range(len(images)):\n",
" img.save(images[layer], f'{outputdir}{scene}_{i:04d}({layer}).png')\n",
"\n",
"scene = \"gas\"\n",
"model_path = f\"{rootdir}/data/__thesis/{scene}/_nets/train/snerf_fast\"\n",
"dataset_path = f\"{rootdir}/data/__thesis/{scene}/test.json\"\n",
"\n",
"\n",
"model = load_model(model_path)\n",
"dataset = load_dataset(dataset_path)\n",
"\n",
"\n",
"i = 6\n",
"cam = dataset.cam\n",
"view = Trans(dataset.centers[i], dataset.rots[i])\n",
"output = render(model, dataset.cam, view, \"colors\", \"weights\")\n",
"output_colors = output.colors * output.weights\n",
"\n",
"samples_per_layer = 4#model.core.samples_per_field\n",
"n_samples = model.args.n_samples\n",
"output_layers = [\n",
" output_colors[..., offset:offset+samples_per_layer, :].sum(-2)\n",
" for offset in range(0, n_samples, samples_per_layer)\n",
"]\n",
" \n",
"plot_images(output_layers, 8, 2)\n",
"#save_images(output_layers, scene, i)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.10.0 ('dvs')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0"
},
"metadata": {
"interpreter": {
"hash": "82066b63b621a9e3d15e3b7c11ca76da6238eff3834294910d715044bd0561e5"
}
},
"vscode": {
"interpreter": {
"hash": "4469b029896260c1221afa6e0e6159922aafd2738570e75b7bc15e28db242604"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}
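The last cell above splits the per-sample outputs into groups of `samples_per_layer` and sums the weight-modulated colors within each group, so each layer image is the sum of w_i * c_i over the samples in that layer, and the layers sum back to the fully blended image. A small sketch on dummy tensors (shapes are illustrative):

import torch

H, W, n_samples, samples_per_layer = 4, 4, 16, 4
colors = torch.rand(H, W, n_samples, 3)               # per-sample colors
weights = torch.rand(H, W, n_samples, 1)               # per-sample blending weights

weighted = colors * weights                             # (H, W, n_samples, 3)
layers = [weighted[..., k:k + samples_per_layer, :].sum(-2)   # one (H, W, 3) image per layer
          for k in range(0, n_samples, samples_per_layer)]

# The layer images partition the blended result: summing them reproduces the full image.
assert torch.allclose(sum(layers), weighted.sum(-2))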
@@ -60,7 +60,7 @@
" enable_ce = True, output_res = None):\n",
" ipd = 0.06\n",
" layers_cam = [\n",
-" CameraParam({\n",
+" Camera({\n",
" 'fov': 110,\n",
" 'cx': 0.5,\n",
" 'cy': 0.5,\n",
......
@@ -31,7 +31,7 @@
"from utils import device\n",
"from utils import view\n",
"from components.gen_final import GenFinal\n",
-"from utils.perf import Perf\n",
+"from utils.profile import Profiler\n",
"\n",
"\n",
"def load_net(path):\n",
@@ -135,15 +135,15 @@
" torch.tensor([[0.0, 0.0, 0.0]], device=device.default()),\n",
" torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]], device=device.default())\n",
")\n",
-"perf = Perf(True, True)\n",
+"profile = Profiler(True, True)\n",
"rays_o, rays_d = gen.layer_cams[0].get_global_rays(test_view, True)\n",
-"perf.checkpoint(\"GetRays\")\n",
+"profile.checkpoint(\"GetRays\")\n",
"rays_o = rays_o.view(-1, 3)\n",
"rays_d = rays_d.view(-1, 3)\n",
"coords, pts, depths = fovea_net.sampler(rays_o, rays_d)\n",
-"perf.checkpoint(\"Sample\")\n",
+"profile.checkpoint(\"Sample\")\n",
"encoded = fovea_net.input_encoder(coords)\n",
-"perf.checkpoint(\"Encode\")\n",
+"profile.checkpoint(\"Encode\")\n",
"print(\"Rays:\", rays_d)\n",
"print(\"Spherical coords:\", coords)\n",
"print(\"Depths:\", depths)\n",
......
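The final hunks rename the timing helper from `Perf` to `Profiler` while keeping the same checkpoint-style call sites. A minimal sketch of such a checkpoint timer; the interface (`Profiler(enabled, use_cuda)` with `checkpoint(name)`) is an assumption inferred from the calls above, not the actual `utils.profile` module:

import time
import torch

class Profiler:
    # Minimal checkpoint timer sketch; the real utils.profile.Profiler may differ.
    def __init__(self, enabled: bool = True, use_cuda: bool = False):
        self.enabled = enabled
        self.use_cuda = use_cuda and torch.cuda.is_available()
        self._last = time.perf_counter()

    def checkpoint(self, name: str):
        if not self.enabled:
            return
        if self.use_cuda:
            torch.cuda.synchronize()                   # include pending GPU work in the timing
        now = time.perf_counter()
        print(f"[{name}] {(now - self._last) * 1000:.2f} ms")
        self._last = now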