Commit 1bc644a1 authored by Nianchen Deng

sync

parent 6294701e
from ..__common__ import *
__all__ = ["Field"]
class Field(nn.Module):
def __init__(self, x_chns: int, shape: list[int], skips: list[int] = [],
act: str = 'relu', with_ln: bool = False):
super().__init__({"x": x_chns}, {"f": shape[1]})
self.net = nn.FcBlock(x_chns, 0, *shape, skips, act, with_ln=with_ln)
# stub method for type hint
def __call__(self, x: torch.Tensor) -> torch.Tensor:
...
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.net(x)
from ..__common__ import *
from .field import *
from .color_decoder import *
from .density_decoder import *
class FsNeRF(nn.Module):
def __init__(self, x_chns: int, color_chns: int, depth: int, width: int,
skips: list[int], act: str, ln: bool, n_samples: int, n_fields: int):
"""
Initialize a FS-NeRF core module.
:param x_chns `int`: channels of input positions (D_x)
:param color_chns `int`: channels of output colors (D_c)
:param depth `int`: number of layers in field network
:param width `int`: width of each layer in field network
:param skips `[int]`: skip connections from input to specific layers in field network
:param act `str`: activation function in field network and color decoder
:param ln `bool`: whether to enable layer normalization in field network and color decoder
:param n_samples `int`: number of samples per ray
:param n_fields `int`: number of sub-fields; each handles `n_samples // n_fields` samples
"""
super().__init__({"x": x_chns}, {"rgbd": 1 + color_chns})
self.n_fields = n_fields
self.samples_per_field = n_samples // n_fields
self.subnets = torch.nn.ModuleList()
for _ in range(n_fields):
field = Field(x_chns * self.samples_per_field, [depth, width], skips, act, ln)
density_decoder = DensityDecoder(field.out_chns, self.samples_per_field)
color_decoder = BasicColorDecoder(field.out_chns, color_chns * self.samples_per_field)
self.subnets.append(torch.nn.ModuleDict({
"field": field,
"density_decoder": density_decoder,
"color_decoder": color_decoder
}))
# stub method for type hint
def __call__(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""
Infer colors and densities from input samples
:param x `Tensor(B..., P, D_x)`: input positions
:return `Tensor(B..., P, D_c + D_σ)`: output colors and densities
"""
...
def forward(self, x: torch.Tensor) -> torch.Tensor:
densities = []
colors = []
for i in range(self.n_fields):
f = self.subnets[i]["field"](
x[..., i * self.samples_per_field:(i + 1) * self.samples_per_field, :].flatten(-2))
densities.append(self.subnets[i]["density_decoder"](f)
.unflatten(-1, (self.samples_per_field, -1)))
colors.append(self.subnets[i]["color_decoder"](f, None)
.unflatten(-1, (self.samples_per_field, -1)))
return torch.cat([torch.cat(colors, -2), torch.cat(densities, -2)], -1)
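# --- Illustrative sketch (not part of this commit) ---------------------------
# FsNeRF evaluates one small field per group of samples: the P samples of each
# ray are split into `n_fields` groups of `samples_per_field` samples, and each
# group is flattened so that a single MLP pass predicts the whole group at
# once. A standalone shape check (names and sizes below are only illustrative):
import torch

B, P, D_x = 4, 64, 3
n_fields = 4
spf = P // n_fields                                    # samples_per_field = 16
x = torch.rand(B, P, D_x)
groups = [x[..., i * spf:(i + 1) * spf, :].flatten(-2) for i in range(n_fields)]
print(groups[0].shape)                                 # torch.Size([4, 48]) == (B, spf * D_x)
# --- end of sketch ------------------------------------------------------------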
from ..__common__ import *
from .field import *
from .color_decoder import *
from .density_decoder import *
class NeRF(nn.Module):
def __init__(self, x_chns: int, d_chns: int, color_chns: int, depth: int, width: int,
skips: list[int], act: str, ln: bool, color_decoder_type: str):
"""
Initialize a NeRF core module.
:param x_chns `int`: channels of input positions (D_x)
:param d_chns `int`: channels of input directions (D_d)
:param color_chns `int`: channels of output colors (D_c)
:param depth `int`: number of layers in field network
:param width `int`: width of each layer in field network
:param skips `[int]`: skip connections from input to specific layers in field network
:param act `str`: activation function in field network and color decoder
:param ln `bool`: whether to enable layer normalization in field network and color decoder
:param color_decoder_type `str`: type of color decoder
"""
super().__init__({"x": x_chns, "d": d_chns}, {"density": 1, "color": color_chns})
self.field = Field(x_chns, [depth, width], skips, act, ln)
self.density_decoder = DensityDecoder(self.field.out_chns, 1)
self.color_decoder = ColorDecoder.create(self.field.out_chns, d_chns, color_chns,
color_decoder_type, {"act": act, "with_ln": ln})
# stub method for type hint
def __call__(self, x: torch.Tensor, d: torch.Tensor) -> torch.Tensor:
"""
Infer colors and densities from input samples
:param x `Tensor(B..., D_x)`: input positions
:param d `Tensor(B..., D_d)`: input directions
:return `Tensor(B..., D_c + D_σ)`: output colors and densities
"""
...
def forward(self, x: torch.Tensor, d: torch.Tensor) -> torch.Tensor:
f = self.field(x)
densities = self.density_decoder(f)
colors = self.color_decoder(f, d)
return torch.cat([colors, densities], -1)
from typing import Tuple
import torch
from .__common__ import *
from .generic import *
from utils import math
from utils.module import Module
__all__ = ["InputEncoder", "LinearEncoder", "FreqEncoder"]
class InputEncoder(Module):
class InputEncoder(nn.Module):
"""
Base class for input encoder.
"""
def __init__(self, in_chns: int, out_chns: int):
super().__init__({"_": in_chns}, {"_": out_chns})
def __init__(self, chns, L, cat_input=False):
super().__init__()
emb = torch.exp(torch.arange(L, dtype=torch.float) * math.log(2.))
# stub method for type hint
def __call__(self, x: torch.Tensor) -> torch.Tensor:
"""
Encode the input tensor.
self.emb = nn.Parameter(emb, requires_grad=False)
self.in_dim = chns
self.out_dim = chns * (L * 2 + cat_input)
self.cat_input = cat_input
:param x `Tensor(N..., D)`: D-dim inputs
:return `Tensor(N..., E)`: encoded outputs
"""
...
def forward(self, x: torch.Tensor, angular=False):
sizes = x.size()
x0 = x
def forward(self, x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError()
if angular:
x = torch.acos(x.clamp(-1, 1))
x = x[..., None] @ self.emb[None]
x = torch.cat([torch.sin(x), torch.cos(x)], -1)
x = x.flatten(-2)
if self.cat_input:
x = torch.cat([x0, x], -1)
@staticmethod
def create(chns: int, type: str, args: dict[str, Any]) -> "InputEncoder":
"""
Create an input encoder of `type` with `args`.
:param chns `int`: input channels
:param type `str`: type of input encoder, without suffix "Encoder"
:param args `{str:Any}`: arguments for initializing the input encoder
:return `InputEncoder`: the created input encoder
"""
return getattr(sys.modules[__name__], f"{type}Encoder")(chns, **args)
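# --- Illustrative sketch (not part of this commit) ---------------------------
# create() resolves the class named f"{type}Encoder" in this module via getattr
# and constructs it with `chns` plus the keyword arguments in `args`, e.g. a
# call like InputEncoder.create(3, "Freq", {...}) would instantiate FreqEncoder.
# A self-contained demo of the same factory pattern (demo names are made up):
import sys

class DemoLinearEncoder:
    def __init__(self, chns: int):
        self.chns = chns

def demo_create(chns: int, type: str, args: dict) -> object:
    # Look the class up by name in the current module, then construct it.
    return getattr(sys.modules[__name__], f"Demo{type}Encoder")(chns, **args)

assert isinstance(demo_create(3, "Linear", {}), DemoLinearEncoder)
# --- end of sketch ------------------------------------------------------------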
class LinearEncoder(InputEncoder):
"""
The linear encoder: D -> D.
"""
def __init__(self, chns):
super().__init__(chns, chns)
def forward(self, x: torch.Tensor):
return x
def extra_repr(self) -> str:
return f'in={self.in_dim}, out={self.out_dim}, cat_input={self.cat_input}'
class IntegratedPosEncoder(Module):
def __init__(self, chns, L, shape: str, cat_input=False):
super().__init__()
self.shape = shape
def _lift_gaussian(self, d: torch.Tensor, t_mean: torch.Tensor, t_var: torch.Tensor,
r_var: torch.Tensor, diag: bool):
"""Lift a Gaussian defined along a ray to 3D coordinates."""
mean = d[..., None, :] * t_mean[..., None]
d_sq = d**2
d_mag_sq = torch.sum(d_sq, -1, keepdim=True).clamp_min(1e-10)
if diag:
d_outer_diag = d_sq
null_outer_diag = 1 - d_outer_diag / d_mag_sq
t_cov_diag = t_var[..., None] * d_outer_diag[..., None, :]
xy_cov_diag = r_var[..., None] * null_outer_diag[..., None, :]
cov_diag = t_cov_diag + xy_cov_diag
return mean, cov_diag
else:
d_outer = d[..., :, None] * d[..., None, :]
eye = torch.eye(d.shape[-1], device=d.device)
null_outer = eye - d[..., :, None] * (d / d_mag_sq)[..., None, :]
t_cov = t_var[..., None, None] * d_outer[..., None, :, :]
xy_cov = r_var[..., None, None] * null_outer[..., None, :, :]
cov = t_cov + xy_cov
return mean, cov
def _conical_frustum_to_gaussian(self, d: torch.Tensor, t0: float, t1: float, base_radius: float,
diag: bool, stable: bool = True):
"""Approximate a conical frustum as a Gaussian distribution (mean+cov).
Assumes the ray is originating from the origin, and base_radius is the
radius at dist=1. Doesn't assume `d` is normalized.
Args:
d: torch.float32 3-vector, the axis of the cone
t0: float, the starting distance of the frustum.
t1: float, the ending distance of the frustum.
base_radius: float, the scale of the radius as a function of distance.
diag: boolean, whether or not the Gaussian will be diagonal or full-covariance.
stable: boolean, whether or not to use the stable computation described in
the paper (setting this to False will cause catastrophic failure).
Returns:
a Gaussian (mean and covariance).
return f"{self.in_chns} -> {self.out_chns}"
class FreqEncoder(InputEncoder):
"""
if stable:
mu = (t0 + t1) / 2
hw = (t1 - t0) / 2
t_mean = mu + (2 * mu * hw**2) / (3 * mu**2 + hw**2)
t_var = (hw**2) / 3 - (4 / 15) * ((hw**4 * (12 * mu**2 - hw**2)) /
(3 * mu**2 + hw**2)**2)
r_var = base_radius**2 * ((mu**2) / 4 + (5 / 12) * hw**2 - 4 / 15 *
(hw**4) / (3 * mu**2 + hw**2))
else:
t_mean = (3 * (t1**4 - t0**4)) / (4 * (t1**3 - t0**3))
r_var = base_radius**2 * (3 / 20 * (t1**5 - t0**5) / (t1**3 - t0**3))
t_mosq = 3 / 5 * (t1**5 - t0**5) / (t1**3 - t0**3)
t_var = t_mosq - t_mean**2
return self._lift_gaussian(d, t_mean, t_var, r_var, diag)
def _cylinder_to_gaussian(self, d: torch.Tensor, t0: float, t1: float, radius: float, diag: bool):
"""Approximate a cylinder as a Gaussian distribution (mean+cov).
Assumes the ray is originating from the origin, and radius is the
radius. Does not renormalize `d`.
Args:
d: torch.float32 3-vector, the axis of the cylinder
t0: float, the starting distance of the cylinder.
t1: float, the ending distance of the cylinder.
radius: float, the radius of the cylinder
diag: boolean, whether or not the Gaussian will be diagonal or full-covariance.
Returns:
a Gaussian (mean and covariance).
The frequency encoder introduced in [mildenhall2020nerf]: D -> 2LD[+D].
"""
t_mean = (t0 + t1) / 2
r_var = radius**2 / 4
t_var = (t1 - t0)**2 / 12
return self._lift_gaussian(d, t_mean, t_var, r_var, diag)
def cast_rays(self, t_vals: torch.Tensor, rays_o: torch.Tensor, rays_d: torch.Tensor,
rays_r: torch.Tensor, diag: bool = True):
"""Cast rays (cone- or cylinder-shaped) and featurize sections of it.
Args:
t_vals: float array, the "fencepost" distances along the ray.
rays_o: float array, the ray origin coordinates.
rays_d: float array, the ray direction vectors.
rays_r: float array, the radii (base radii for cones) of the rays.
(The ray shape, 'cone' or 'cylinder', is taken from `self.shape`.)
diag: boolean, whether or not the covariance matrices should be diagonal.
Returns:
a tuple of arrays of means and covariances.
freq_bands: torch.Tensor
"""
t0 = t_vals[..., :-1]
t1 = t_vals[..., 1:]
if self.shape == 'cone':
gaussian_fn = self._conical_frustum_to_gaussian
elif self.shape == 'cylinder':
gaussian_fn = self._cylinder_to_gaussian
else:
assert False
means, covs = gaussian_fn(rays_d, t0, t1, rays_r, diag)
means = means + rays_o[..., None, :]
return means, covs
def integrated_pos_enc(x_coord: Tuple[torch.Tensor, torch.Tensor], min_deg: int, max_deg: int,
diag: bool = True):
"""Encode `x` with sinusoids scaled by 2^[min_deg:max_deg-1].
Args:
x_coord: a tuple containing: x, torch.ndarray, variables to be encoded. Should
be in [-pi, pi]. x_cov, torch.ndarray, covariance matrices for `x`.
min_deg: int, the min degree of the encoding.
max_deg: int, the max degree of the encoding.
diag: bool, if true, expects input covariances to be diagonal (full
otherwise).
Returns:
encoded: torch.ndarray, encoded variables.
`Tensor(L)` Frequency bands (1, 2, ..., 2^(L-1))
"""
if diag:
x, x_cov_diag = x_coord
scales = torch.tensor([2**i for i in range(min_deg, max_deg)], device=x.device)[:, None]
shape = list(x.shape[:-1]) + [-1]
y = torch.reshape(x[..., None, :] * scales, shape)
y_var = torch.reshape(x_cov_diag[..., None, :] * scales**2, shape)
else:
x, x_cov = x_coord
num_dims = x.shape[-1]
basis = torch.cat([
2**i * torch.eye(num_dims, device=x.device)
for i in range(min_deg, max_deg)
], 1)
y = torch.matmul(x, basis)
# Get the diagonal of a covariance matrix (ie, variance). This is equivalent
# to jax.vmap(torch.diag)((basis.T @ covs) @ basis).
y_var = (torch.matmul(x_cov, basis) * basis).sum(-2)
return math.expected_sin(
torch.cat([y, y + 0.5 * math.pi], -1),
torch.cat([y_var] * 2, -1))[0]
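# --- Illustrative sketch (not part of this commit) ---------------------------
# The integrated positional encoding above relies on the closed form
# E[sin(y)] = exp(-var / 2) * sin(mean) for y ~ N(mean, var) (mip-NeRF, Barron
# et al. 2021); `math.expected_sin` in this repo is assumed to implement it.
# A standalone version with a Monte-Carlo sanity check:
import torch

def expected_sin_demo(mean: torch.Tensor, var: torch.Tensor) -> torch.Tensor:
    # Closed-form expectation of sin() under a Gaussian.
    return torch.exp(-0.5 * var) * torch.sin(mean)

mean, var = torch.tensor(0.7), torch.tensor(0.2)
mc = torch.sin(mean + var.sqrt() * torch.randn(200000)).mean()
print(expected_sin_demo(mean, var), mc)                # the two values should be close
# --- end of sketch ------------------------------------------------------------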
def __init__(self, chns, freqs: int, include_input: bool):
super().__init__(chns, chns * (freqs * 2 + include_input))
self.include_input = include_input
self.freqs = freqs
self.register_temp("freq_bands", (2. ** torch.arange(freqs))[:, None].expand(-1, chns))
def forward(self, x: torch.Tensor):
x_ = x.unsqueeze(-2) * self.freq_bands
result = union(torch.sin(x_), torch.cos(x_)).flatten(-2)
return union(x, result) if self.include_input else result
def extra_repr(self) -> str:
return f"{self.in_chns} -> {self.out_chns}"\
f"(2x{self.freqs}x{self.in_chns}{f'+{self.in_chns}' * self.include_input})"
import torch
from itertools import cycle
from typing import Dict, Set, Tuple, Union
from .__common__ import *
import torch.nn.functional as F
from utils.type import NetInput, ReturnData
__all__ = ["density2energy", "density2alpha", "VolumnRenderer"]
from .generic import *
from model.base import BaseModel
from utils import math
from utils.module import Module
from utils.perf import checkpoint, perf
from utils.samples import Samples
def density2energy(densities: torch.Tensor, dists: torch.Tensor, raw_noise_std: float = 0):
def density2energy(densities: torch.Tensor, dists: torch.Tensor, raw_noise_std: float = 0) -> torch.Tensor:
"""
Calculate energies from densities inferred by model.
:param densities `Tensor(N..., 1)`: model's output densities
:param densities `Tensor(N...)`: model's output densities
:param dists `Tensor(N...)`: integration times
:param raw_noise_std `float`: the noise std used to regularize the network during training (prevents
floater artifacts), defaults to 0, meaning no noise is added
:return `Tensor(N..., 1)`: energies which block light rays
:return `Tensor(N...)`: energies which block light rays
"""
if raw_noise_std > 0:
# Add noise to model's predictions for density. Can be used to
# regularize network during training (prevents floater artifacts).
densities = densities + torch.normal(0.0, raw_noise_std, densities.size())
return densities * dists[..., None]
densities = densities + torch.normal(0.0, raw_noise_std, densities.shape,
device=densities.device)
return F.relu(densities) * dists
def density2alpha(densities: torch.Tensor, dists: torch.Tensor, raw_noise_std: float = 0):
def energy2alpha(energies: torch.Tensor) -> torch.Tensor:
"""
Calculate alphas from densities inferred by model.
Convert energies to alphas.
:param densities `Tensor(N..., 1)`: model's output densities
:param dists `Tensor(N...)`: integration times
:param raw_noise_std `float`: the noise std used to regularize the network during training (prevents
floater artifacts), defaults to 0, meaning no noise is added
:return `Tensor(N..., 1)`: alphas
:param energies `Tensor(N...)`: energies (calculated from densities)
:return `Tensor(N...)`: alphas
"""
energies = density2energy(densities, dists, raw_noise_std)
return 1.0 - torch.exp(-energies)
class AlphaComposition(Module):
def __init__(self):
super().__init__()
def forward(self, colors, alphas, bg=None):
def density2alpha(densities: torch.Tensor, dists: torch.Tensor, raw_noise_std: float = 0) -> torch.Tensor:
"""
Composite the colors of samples along each ray using their alphas.
Calculate alphas from densities inferred by model.
:param colors `Tensor(N, P, C)`: colors of samples along each ray
:param alphas `Tensor(N, P, 1)`: alphas of samples along each ray
:param bg `Tensor([N, ]C)`: background color, defaults to None
:return `Tensor(N, C)`: composited colors
:param densities `Tensor(N...)`: model's output densities
:param dists `Tensor(N...)`: integration times
:param raw_noise_std `float`: the noise std used to regularize the network during training (prevents
floater artifacts), defaults to 0, meaning no noise is added
:return `Tensor(N...)`: alphas
"""
# Compute weight for RGB of each sample along each ray. A cumprod() is
# used to express the idea of the ray not having reflected up to this
# sample yet.
one_minus_alpha = torch.cumprod(1 - alphas[..., :-1, :] + math.tiny, dim=-2)
one_minus_alpha = torch.cat([
torch.ones_like(one_minus_alpha[..., :1, :]),
one_minus_alpha
], dim=-2)
weights = alphas * one_minus_alpha # (N, P, 1)
# (N, C), computed weighted color of each sample along each ray.
final_color = torch.sum(weights * colors, dim=-2)
# To composite onto a white background, use the accumulated alpha map.
if bg is not None:
# Sum of weights along each ray. This value is in [0, 1] up to numerical error.
acc_map = torch.sum(weights, -1)
final_color = final_color + bg * (1. - acc_map[..., None])
return {
'color': final_color,
'weights': weights,
}
return energy2alpha(density2energy(densities, dists, raw_noise_std))
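# --- Illustrative sketch (not part of this commit) ---------------------------
# Putting the helpers above together: densities become energies sigma_i * delta_i,
# energies become alphas 1 - exp(-energy), and the compositing weight of sample i
# is alpha_i * prod_{j<i}(1 - alpha_j). A standalone numeric check of that
# quadrature:
import torch

densities = torch.tensor([[0.0, 5.0, 50.0, 5.0]])      # (B, P)
dists = torch.full_like(densities, 0.1)                # integration step of each sample
alphas = 1.0 - torch.exp(-torch.relu(densities) * dists)
trans = torch.cumprod(torch.cat([torch.ones_like(alphas[..., :1]),
                                 1.0 - alphas[..., :-1]], -1), -1)  # transmittance before sample i
weights = alphas * trans
print(weights, weights.sum(-1))                        # weights peak at the dense sample; sum <= 1
# --- end of sketch ------------------------------------------------------------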
class VolumnRenderer(Module):
class VolumnRenderer(nn.Module):
class States:
kernel: BaseModel
samples: Samples
early_stop_tolerance: float
outputs: Set[str]
hit_mask: torch.Tensor
N: int
P: int
device: torch.device
colors: torch.Tensor
densities: torch.Tensor
energies: torch.Tensor
weights: torch.Tensor
cum_energies: torch.Tensor
exp_energies: torch.Tensor
tot_evaluations: Dict[str, int]
chunk: Tuple[slice, slice]
cum_chunk: Tuple[slice, slice]
cum_last: Tuple[slice, slice]
chunk_id: int
@property
def start(self) -> int:
return self.chunk[1].start
@property
def end(self) -> int:
return self.chunk[1].stop
def __init__(self, kernel: BaseModel, samples: Samples, early_stop_tolerance: float,
outputs: Set[str]) -> None:
self.kernel = kernel
self.samples = samples
self.early_stop_tolerance = early_stop_tolerance
self.outputs = outputs
N, P = samples.size
self.device = self.samples.device
self.hit_mask = samples.voxel_indices != -1 # (N, P) | bool
self.colors = torch.zeros(N, P, kernel.chns('color'), device=samples.device)
self.densities = torch.zeros(N, P, 1, device=samples.device)
self.energies = torch.zeros(N, P, 1, device=samples.device)
self.weights = torch.zeros(N, P, 1, device=samples.device)
self.cum_energies = torch.zeros(N, P + 1, 1, device=samples.device)
self.exp_energies = torch.ones(N, P + 1, 1, device=samples.device)
self.tot_evaluations = {}
self.N, self.P = N, P
self.chunk_id = -1
def n_hits(self, index: Union[int, slice] = None) -> int:
if not isinstance(self.hit_mask, torch.Tensor):
if index is not None:
return self.N * self.colors[:, index].shape[1]
return self.N * self.P
if index is None:
return self.hit_mask.count_nonzero().item()
return self.hit_mask[:, index].count_nonzero().item()
def accumulate_tot_evaluations(self, key: str, n: int):
if key not in self.tot_evaluations:
self.tot_evaluations[key] = 0
self.tot_evaluations[key] += n
def next_chunk(self, *, length=None, end=None):
start = 0 if not hasattr(self, "chunk") else self.end
length = length or self.P
end = min(end or start + length, self.P)
self.chunk = slice(None), slice(start, end)
self.cum_chunk = slice(None), slice(start + 1, end + 1)
self.cum_last = slice(None), slice(start, start + 1)
self.chunk_id += 1
return self
def put(self, key: str, values: torch.Tensor, indices: Union[Tuple[torch.Tensor, torch.Tensor], Tuple[slice, slice]]):
if not hasattr(self, key):
new_tensor = torch.zeros(self.N, self.P, values.shape[-1], device=self.device)
setattr(self, key, new_tensor)
tensor: torch.Tensor = getattr(self, key)
# if isinstance(indices[0], torch.Tensor):
# tensor.index_put_(indices, values)
# else:
tensor[indices] = values
def __init__(self, **kwargs):
def __init__(self):
super().__init__()
@perf
def forward(self, kernel: BaseModel, samples: Samples, *outputs: str,
raymarching_early_stop_tolerance: float = 0,
raymarching_chunk_size_or_sections: Union[int, List[int]] = None,
**kwargs) -> ReturnData:
# stub method
def __call__(self, samples: Samples, densities: torch.Tensor, colors: torch.Tensor, *outputs: str,
white_bg: bool, raw_noise_std: float) -> ReturnData:
"""
Perform volume rendering.
:param kernel `BaseModel`: render kernel
:param samples `Samples(N, P)`: samples
:param samples `Samples(B, P)`: samples
:param rgbd `Tensor(B, P, C+1)`: colors and densities
:param outputs `str...`: items should be contained in the result dict.
Optional values include 'color', 'depth', 'layers', 'states' and attribute names in class `States` (e.g. 'weights'). Defaults to []
:param raymarching_early_stop_tolerance `float`: tolerance of raymarching early stop.
Should be between 0 and 1 (0 means no early stop). Defaults to 0
:param raymarching_chunk_size_or_sections `int|list[int]`: indicates how to split raymarching process.
Use a list of integers to specify the number of samples in every chunk, or a positive integer to specify the number of chunks.
Use a negative integer to split by the number of hits in chunks; the absolute value is the maximum number of hits in a chunk.
0 and `None` means not splitting the raymarching process. Defaults to `None`
:return `dict`: render result { 'color'[, 'depth', 'layers', 'states', ...] }
:return `ReturnData`: render result { 'color'[, 'depth', 'layers', 'states', ...] }
"""
if samples.size[1] == 0:
print("VolumnRenderer.forward(): # of samples is zero")
return None
infer_outputs = set()
for key in outputs:
if key == "color":
infer_outputs.add("colors")
infer_outputs.add("densities")
elif key == "specular":
infer_outputs.add("speculars")
infer_outputs.add("densities")
elif key == "diffuse":
infer_outputs.add("diffuses")
infer_outputs.add("densities")
elif key == "depth":
infer_outputs.add("densities")
else:
infer_outputs.add(key)
s = VolumnRenderer.States(kernel, samples, raymarching_early_stop_tolerance, infer_outputs)
checkpoint("Prepare states object")
if not raymarching_chunk_size_or_sections:
raymarching_chunk_size_or_sections = [s.P]
elif isinstance(raymarching_chunk_size_or_sections, int) and \
raymarching_chunk_size_or_sections > 0:
raymarching_chunk_size_or_sections = [
math.ceil(s.P / raymarching_chunk_size_or_sections)
]
if isinstance(raymarching_chunk_size_or_sections, list):
chunk_sections = raymarching_chunk_size_or_sections
for chunk_samples in cycle(chunk_sections):
self._forward_chunk(s.next_chunk(length=chunk_samples))
if s.end >= s.P:
break
else:
chunk_size = -raymarching_chunk_size_or_sections
chunk_hits = s.n_hits(0)
for i in range(1, s.P):
n_hits = s.n_hits(i)
if chunk_hits + n_hits > chunk_size:
self._forward_chunk(s.next_chunk(end=i))
n_hits = s.n_hits(i)
chunk_hits = 0
chunk_hits += n_hits
self._forward_chunk(s.next_chunk())
checkpoint("Run forward chunks")
ret = {}
for key in outputs:
if key == 'color':
ret['color'] = torch.sum(s.colors * s.weights, 1)
elif key == 'depth':
ret['depth'] = torch.sum(s.samples.depths[..., None] * s.weights, 1)
elif key == 'diffuse' and hasattr(s, "diffuses"):
ret['diffuse'] = torch.sum(s.diffuses * s.weights, 1)
elif key == 'specular' and hasattr(s, "speculars"):
ret['specular'] = torch.sum(s.speculars * s.weights, 1)
elif key == 'layers':
ret['layers'] = torch.cat([s.colors, 1 - torch.exp(-s.energies)], dim=-1)
elif key == 'states':
ret['states'] = s
else:
if hasattr(s, key):
ret[key] = getattr(s, key)
checkpoint("Set return data")
return ret
@perf
def _calc_weights(self, s: States):
"""
Calculate weights of samples in composited outputs
:param s `States`: states
:param start `int`: chunk's start
:param end `int`: chunk's end
"""
s.energies[s.chunk] = density2energy(s.densities[s.chunk], s.samples.dists[s.chunk])
s.cum_energies[s.cum_chunk] = torch.cumsum(s.energies[s.chunk], 1) \
+ s.cum_energies[s.cum_last]
s.exp_energies[s.cum_chunk] = (-s.cum_energies[s.cum_chunk]).exp()
s.weights[s.chunk] = s.exp_energies[s.chunk] - s.exp_energies[s.cum_chunk]
@perf
def _apply_early_stop(self, s: States):
"""
Stop rays whose accumulated opacity is larger than a threshold
:param s `States`: s
:param end `int`: chunk's end
"""
if s.end < s.P and s.early_stop_tolerance > 0 and isinstance(s.hit_mask, torch.Tensor):
rays_to_stop = s.exp_energies[:, s.end, 0] < s.early_stop_tolerance
s.hit_mask[rays_to_stop, s.end:] = 0
@perf
def _forward_chunk(self, s: States) -> int:
if isinstance(s.hit_mask, torch.Tensor):
fi_idxs: Tuple[torch.Tensor, ...] = s.hit_mask[s.chunk].nonzero(as_tuple=True)
if fi_idxs[0].size(0) == 0:
s.cum_energies[s.cum_chunk] = s.cum_energies[s.cum_last]
s.exp_energies[s.cum_chunk] = s.exp_energies[s.cum_last]
return
fi_idxs[1].add_(s.start)
s.accumulate_tot_evaluations("colors", fi_idxs[0].size(0))
else:
fi_idxs = s.chunk
fi_outputs = s.kernel.infer(*s.outputs, samples=s.samples[fi_idxs], chunk_id=s.chunk_id)
for key, value in fi_outputs.items():
s.put(key, value, fi_idxs)
self._calc_weights(s)
self._apply_early_stop(s)
class DensityFirstVolumnRenderer(VolumnRenderer):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def _forward_chunk(self, s: VolumnRenderer.States) -> int:
fi_idxs: Tuple[torch.Tensor, ...] = s.hit_mask[s.chunk].nonzero(as_tuple=True) # (N')
fi_idxs[1].add_(s.start)
if fi_idxs[0].size(0) == 0:
s.cum_energies[s.cum_chunk] = s.cum_energies[s.cum_last]
s.exp_energies[s.cum_chunk] = s.exp_energies[s.cum_last]
return
# fi_* means "filtered" by hit mask
fi_samples = s.samples[fi_idxs] # N -> N'
# For all valid samples: encode X
density_inputs = s.kernel.input(fi_samples, "x", "f") # (N', Ex)
# Infer densities (shape)
density_outputs = s.kernel.infer('densities', 'features', samples=fi_samples,
inputs=density_inputs, chunk_id=s.chunk_id)
s.put('densities', density_outputs['densities'], fi_idxs)
s.accumulate_tot_evaluations("densities", fi_idxs[0].size(0))
self._calc_weights(s)
self._apply_early_stop(s)
# Remove samples whose weights are less than a threshold
s.hit_mask[s.chunk][s.weights[s.chunk][..., 0] < 0.01] = 0
# Update "filtered" tensors
fi_mask = s.hit_mask[fi_idxs]
fi_idxs = (fi_idxs[0][fi_mask], fi_idxs[1][fi_mask]) # N' -> N"
fi_samples = s.samples[fi_idxs] # N -> N"
fi_features = density_outputs['features'][fi_mask]
color_inputs = s.kernel.input(fi_samples, "d") # (N")
color_inputs.x = density_inputs.x[fi_mask]
# Infer colors (appearance)
outputs = s.outputs.copy()
if 'densities' in outputs:
outputs.remove('densities')
color_outputs = s.kernel.infer(*outputs, samples=fi_samples, inputs=color_inputs,
chunk_id=s.chunk_id, features=fi_features)
# if s.chunk_id == 0:
# fi_colors[:] *= fi_colors.new_tensor([1, 0, 0])
# elif s.chunk_id == 1:
# fi_colors[:] *= fi_colors.new_tensor([0, 1, 0])
# elif s.chunk_id == 2:
# fi_colors[:] *= fi_colors.new_tensor([0, 0, 1])
# else:
# fi_colors[:] *= fi_colors.new_tensor([1, 1, 0])
for key, value in color_outputs.items():
s.put(key, value, fi_idxs)
s.accumulate_tot_evaluations("colors", fi_idxs[0].size(0))
...
@profile
def forward(self, samples: Samples, rgbd: torch.Tensor, *outputs: str,
white_bg: bool, raw_noise_std: float) -> ReturnData:
energies = density2energy(rgbd[..., -1], samples.dists, raw_noise_std) # (B, P)
alphas = energy2alpha(energies) # (B, P)
weights = (alphas * torch.cumprod(union(1, 1. - alphas + 1e-10), -1)[..., :-1])[..., None]
output_fn = {
"color": lambda: torch.sum(weights * rgbd[..., :-1], -2) + (1. - torch.sum(weights, -2)
if white_bg else 0.),
"depth": lambda: torch.sum(weights * samples.depths[..., None], -2),
"colors": lambda: rgbd[..., :-1],
"densities": lambda: rgbd[..., -1:],
"alphas": lambda: alphas[..., None],
"energies": lambda: energies[..., None],
"weights": lambda: weights
}
return ReturnData({key: output_fn[key]() for key in outputs if key in output_fn})
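# --- Illustrative sketch (not part of this commit) ---------------------------
# The weights line above uses an "exclusive" cumulative product: prepending a
# column of ones (via union(1, ...)) and dropping the last column makes
# cumprod(...)[..., :-1] equal to prod_{j<i}(1 - alpha_j), i.e. the transmittance
# in front of sample i. A standalone check of that identity:
import torch

alphas = torch.tensor([[0.1, 0.5, 0.9]])
ones = torch.ones_like(alphas[..., :1])
exclusive = torch.cumprod(torch.cat([ones, 1.0 - alphas + 1e-10], -1), -1)[..., :-1]
manual = torch.tensor([[1.0, 0.9, 0.9 * 0.5]])
print(torch.allclose(exclusive, manual, atol=1e-6))    # True
# --- end of sketch ------------------------------------------------------------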
import torch
from typing import Tuple
from .generic import *
from .__common__ import *
from .space import Space
from clib import *
from utils import device
from utils import sphere
from utils import misc
from utils import math
from utils.module import Module
from utils.samples import Samples
from utils.perf import perf, checkpoint
class Bins(object):
@property
def up(self):
return self.bounds[1:]
@property
def lo(self):
return self.bounds[:-1]
from utils.misc import grid2d
def __init__(self, vals: torch.Tensor):
self.vals = vals
self.bounds = torch.cat([
self.vals[:1],
0.5 * (self.vals[1:] + self.vals[:-1]),
self.vals[-1:]
])
__all__ = ["Sampler", "UniformSampler", "PdfSampler"]
@staticmethod
def linspace(val_range: Tuple[float, float], N: int, device: torch.device = None):
return Bins(torch.linspace(*val_range, N, device=device))
def to(self, device: torch.device):
self.vals = self.vals.to(device)
self.bounds = self.bounds.to(device)
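# --- Illustrative sketch (not part of this commit) ---------------------------
# Bins turns a sorted list of sample positions into bin bounds: the first and
# last values are kept and every interior bound is the midpoint of two
# neighbouring values, so len(bounds) == len(vals) + 1. Standalone check:
import torch

vals = torch.tensor([1.0, 2.0, 4.0])
bounds = torch.cat([vals[:1], 0.5 * (vals[1:] + vals[:-1]), vals[-1:]])
print(bounds)                                          # tensor([1.0000, 1.5000, 3.0000, 4.0000])
# --- end of sketch ------------------------------------------------------------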
class Sampler(nn.Module):
_samples_indices_cached: torch.Tensor | None
class Sampler(Module):
def __init__(self, **kwargs):
def __init__(self, x_chns: int, d_chns: int):
"""
Initialize a Sampler module
"""
super().__init__()
super().__init__({}, {"x": x_chns, "d": d_chns})
self._samples_indices_cached = None
def _sample(self, range: Tuple[float, float], n_rays: int, n_samples: int, perturb: bool,
device: torch.device) -> torch.Tensor:
"""
Uniformly sample bounds of t in the given range.
# stub method for type hint
def __call__(self, rays: Rays, space: Space, **kwargs) -> Samples:
...
:param t_range `float, float`: sampling range
:param n_rays `int`: number of rays (B)
:param n_samples `int`: number of samples per ray (P)
:param perturb `bool`: whether to perturb sampling
:param device `torch.device`: the device used to create tensors
:return `Tensor(B, P+1)`: sampling bounds of t
def _get_samples_indices(self, pts: torch.Tensor) -> torch.Tensor:
"""
bounds = torch.linspace(*range, n_samples + 1, device=device) # (P+1)
if perturb:
rand_bounds = torch.cat([
bounds[:1],
0.5 * (bounds[1:] + bounds[:-1]),
bounds[-1:]
])
rand_vals = torch.rand(n_rays, n_samples + 1, device=device)
bounds = rand_bounds[:-1] * (1 - rand_vals) + rand_bounds[1:] * rand_vals
else:
bounds = bounds[None].expand(n_rays, -1)
return bounds
Get 2D indices of samples. The first value is the index of the ray, while the second value is
the index of the sample within that ray.
def _get_samples_indices(self, pts: torch.Tensor):
:param pts `Tensor(B, P, 3)`: the sample points
:return `Tensor(B, P)`: the 2D indices of samples
"""
if self._samples_indices_cached is None\
or self._samples_indices_cached.device != pts.device\
or self._samples_indices_cached.shape[0] < pts.shape[0]\
or self._samples_indices_cached.shape[1] < pts.shape[1]:
self._samples_indices_cached = misc.meshgrid(
*pts.shape[:2], swap_dim=True, device=pts.device)
self._samples_indices_cached = grid2d(*pts.shape[:2], indexing="ij", device=pts.device)
return self._samples_indices_cached[:pts.shape[0], :pts.shape[1]]
@perf
def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, space_: Space, *,
sample_range: Tuple[float, float], n_samples: int, lindisp: bool = False,
perturb_sample: bool = True, spherical: bool = False,
**kwargs) -> Tuple[Samples, torch.Tensor]:
def _get_samples(self, rays: Rays, space: Space, t_vals: torch.Tensor, mode: str) -> Samples:
"""
Sample points along rays.
Get samples along rays at sample steps specified by `t_vals`.
:param rays_o `Tensor(B, 3)`: rays' origin
:param rays_d `Tensor(B, 3)`: rays' direction
:param sample_range `float, float`: sampling range
:param n_samples `int`: number of samples per ray
:param lindisp `bool`: whether sample linearly in disparity space (1/depth)
:param perturb_sample `bool`: whether to perturb sampling
:param t_vals `Tensor(B, P)`: sample steps
:param mode `str`: sample mode, one of "xyz", "xyz_disp", "spherical", "spherical_radius"
:return `Samples(B, P)`: samples
"""
if spherical:
t_bounds = self._sample(sample_range, rays_o.shape[0], n_samples, perturb_sample,
rays_o.device)
t0, t1 = t_bounds[:, :-1], t_bounds[:, 1:] # (B, P)
t = (t0 + t1) * .5
p, z = sphere.ray_sphere_intersect(rays_o, rays_d, t.reciprocal())
p = sphere.cartesian2spherical(p, inverse_r=True)
vidxs = space_.get_voxel_indices(p)
return Samples(
pts=p,
dirs=rays_d[:, None].expand(-1, n_samples, -1),
depths=z,
dists=(t1 + math.tiny).reciprocal() - t0.reciprocal(),
voxel_indices=vidxs,
indices=self._get_samples_indices(p),
t=t
)
if mode == "xyz":
z_vals = t_vals
pts = rays.get_points(z_vals)
elif mode == "xyz_disp":
z_vals = t_vals.reciprocal()
pts = rays.get_points(z_vals)
elif mode == "spherical":
z_vals = t_vals.reciprocal()
pts = sphere.cartesian2spherical(rays.get_points(z_vals), inverse_r=True)
elif mode == "spherical_radius":
z_vals = sphere.ray_sphere_intersect(rays, t_vals.reciprocal())
pts = sphere.cartesian2spherical(rays.get_points(z_vals), inverse_r=True)
else:
sample_range = (1 / sample_range[0], 1 / sample_range[1]) if lindisp else sample_range
z_bounds = self._sample(sample_range, rays_o.shape[0], n_samples, perturb_sample,
rays_o.device)
if lindisp:
z_bounds = z_bounds.reciprocal()
z0, z1 = z_bounds[:, :-1], z_bounds[:, 1:] # (B, P)
z = (z0 + z1) * .5
p = rays_o[:, None] + rays_d[:, None] * z[..., None]
vidxs = space_.get_voxel_indices(p)
raise ValueError(f"Unknown mode: {mode}")
rays_d = rays.rays_d.unsqueeze(1) # (B, 1, 3)
dists = union(z_vals[..., 1:] - z_vals[..., :-1], math.huge) # (B, P)
dists *= torch.norm(rays_d, dim=-1)
return Samples(
pts=p,
dirs=rays_d[:, None].expand(-1, n_samples, -1),
depths=z,
dists=z1 - z0,
voxel_indices=vidxs,
indices=self._get_samples_indices(p),
t=z
pts=pts,
dirs=rays_d.expand(*pts.shape[:2], -1),
depths=z_vals,
t_vals=t_vals,
dists=dists,
voxel_indices=space.get_voxel_indices(pts) if space else 0,
indices=self._get_samples_indices(pts)
)
class PdfSampler(Module):
def __init__(self, *, depth_range: Tuple[float, float], n_samples: int, perturb_sample: bool,
spherical: bool, lindisp: bool, **kwargs):
class UniformSampler(Sampler):
"""
Initialize a Sampler module
:param depth_range: depth range for sampler
:param n_samples: number of samples along each ray
:param perturb_sample: perturb the sample depths
:param lindisp: If True, sample linearly in inverse depth rather than in depth
This module extends NeRF's uniform sampling code to support our spherical sampling and to enable
tracking of samples' indices.
"""
super().__init__()
self.lindisp = lindisp
self.perturb_sample = perturb_sample
self.spherical = spherical
self.n_samples = n_samples
self.s_range = (1 / depth_range[0], 1 / depth_range[1]) if self.lindisp else depth_range
def forward(self, rays_o, rays_d, *, weights, s_vals=None, include_s_vals=False, **kwargs):
"""
Sample points along rays, returning spherical or Cartesian coordinates as
specified by `self.spherical`.
def __init__(self):
super().__init__(3, 3)
:param rays_o `Tensor(B, 3)`: rays' origin
:param rays_d `Tensor(B, 3)`: rays' direction
:param weights `Tensor(B, M)`: weights of sample bins
:param s_vals `Tensor(B, M)`: (optional) center of sample bins
:param include_s_vals `bool`: (default to `False`) include `s_vals` in the sample array
:return `Tensor(B, N, 3)`: sampled points
:return `Tensor(B, N)`: corresponding depths along rays
def _sample(self, range: tuple[float, float], n_rays: int, n_samples: int, perturb: bool) -> torch.Tensor:
"""
if s_vals is None:
s_vals = torch.linspace(*self.s_range, self.n_samples, device=device.default())
s = self.sample_pdf(Bins(s_vals).bounds, weights, self.n_samples, det=self.perturb_sample)
if include_s_vals:
s = torch.cat([s, s_vals], dim=-1)
s = torch.sort(s, descending=self.lindisp)[0]
z = torch.reciprocal(s) if self.lindisp else s
if self.spherical:
pts, depths = sphere.ray_sphere_intersect(rays_o, rays_d, z)
sphers = sphere.cartesian2spherical(pts, inverse_r=self.lindisp)
return sphers, depths, s, pts
else:
return rays_o[..., None, :] + rays_d[..., None, :] * z[..., None], z, s
def sample_pdf(self, bins: torch.Tensor, weights: torch.Tensor, N: int, det=True):
'''
:param bins `Tensor(..., M+1)`: bounds of bins
:param weights `Tensor(..., M)`: weights of bins
:param N `int`: # of samples along each ray
:param det `bool`: (default to `True`) perform deterministic sampling or not
:return `Tensor(..., N)`: samples
'''
# Get pdf
weights = weights + math.tiny # prevent nans
pdf = weights / torch.sum(weights, dim=-1, keepdim=True) # [..., M]
cdf = torch.cat([
torch.zeros_like(pdf[..., :1]),
torch.cumsum(pdf, dim=-1)
], dim=-1) # [..., M+1]
# Take uniform samples
dots_sh = list(weights.shape[:-1])
M = weights.shape[-1]
u = torch.linspace(0, 1, N, device=bins.device).expand(dots_sh + [N]) \
if det else torch.rand(dots_sh + [N], device=bins.device) # [..., N]
# Invert CDF
# [..., N, 1] >= [..., 1, M] ----> [..., N, M] ----> [..., N,]
above_inds = torch.sum(u[..., None] >= cdf[..., None, :-1], dim=-1).long()
Generate sample steps along rays in the specified range.
# random sample inside each bin
below_inds = torch.clamp(above_inds - 1, min=0)
inds_g = torch.stack((below_inds, above_inds), dim=-1) # [..., N, 2]
cdf = cdf[..., None, :].expand(dots_sh + [N, M + 1]) # [..., N, M+1]
cdf_g = torch.gather(cdf, dim=-1, index=inds_g) # [..., N, 2]
bins = bins[..., None, :].expand(dots_sh + [N, M + 1]) # [..., N, M+1]
bins_g = torch.gather(bins, dim=-1, index=inds_g) # [..., N, 2]
# fix numeric issue
denom = cdf_g[..., 1] - cdf_g[..., 0] # [..., N]
denom = torch.where(denom < math.tiny, torch.ones_like(denom), denom)
t = (u - cdf_g[..., 0]) / denom
samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0] + math.tiny)
return samples
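# --- Illustrative sketch (not part of this commit) ---------------------------
# sample_pdf() draws new samples whose density follows the per-bin weights: the
# weights are normalised into a PDF, accumulated into a CDF, and uniform samples
# u are mapped through the inverse CDF by locating the bin whose CDF interval
# contains u. A tiny standalone check that heavy bins receive most samples:
import torch

bins = torch.tensor([0.0, 1.0, 2.0, 3.0])              # bounds of 3 bins
weights = torch.tensor([0.1, 0.1, 0.8])                # the last bin should dominate
pdf = weights / weights.sum()
cdf = torch.cat([torch.zeros(1), torch.cumsum(pdf, 0)])
u = torch.rand(10000)
idx = torch.searchsorted(cdf, u, right=True).clamp(1, len(weights)) - 1
samples = bins[idx] + (bins[idx + 1] - bins[idx]) * torch.rand(10000)
print((samples > 2.0).float().mean())                  # approximately 0.8
# --- end of sketch ------------------------------------------------------------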
class VoxelSampler(Module):
def __init__(self, *, sample_step: float, **kwargs):
:param range `float, float`: sampling range
:param n_rays `int`: number of rays (B)
:param n_samples `int`: number of samples per ray (P)
:param perturb `bool`: whether to perturb sampling
:return `Tensor(B, P)`: sampled "t"s along rays
"""
Initialize a VoxelSampler module
t_vals = torch.linspace(*range, n_samples, device=self.device) # (P)
if perturb:
mids = .5 * (t_vals[..., 1:] + t_vals[..., :-1])
upper = union(mids, t_vals[..., -1:])
lower = union(t_vals[..., :1], mids)
# stratified samples in those intervals
t_vals = t_vals.expand(n_rays, -1)
t_vals = lower + (upper - lower) * torch.rand_like(t_vals)
else:
t_vals = t_vals.expand(n_rays, -1)
return t_vals
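# --- Illustrative sketch (not part of this commit) ---------------------------
# With perturb=True the sampler jitters each sample inside its own stratum: the
# strata are delimited by the midpoints between neighbouring uniform steps (plus
# the two end points), so the perturbed samples stay ordered. Standalone check:
import torch

t = torch.linspace(0.0, 1.0, 8)                        # (P) uniform steps
mids = 0.5 * (t[1:] + t[:-1])
lower = torch.cat([t[:1], mids])
upper = torch.cat([mids, t[-1:]])
t_perturbed = lower + (upper - lower) * torch.rand_like(t)
print(bool((t_perturbed[1:] >= t_perturbed[:-1]).all()))  # True: still monotonically ordered
# --- end of sketch ------------------------------------------------------------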
:param perturb_sample: perturb the sample depths
:param step_size: step size
# stub method for type hint
def __call__(self, rays: Rays, space: Space, *,
range: tuple[float, float],
mode: str,
n_samples: int,
perturb: bool) -> Samples:
"""
super().__init__()
self.sample_step = sample_step
Sample points along rays.
def _forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, space_module: Space, *,
perturb_sample: bool, **kwargs) -> Tuple[Samples, torch.Tensor]:
:param rays `Rays(B)`: rays
:param space `Space`: sample space
:param range `float, float`: sampling range
:param mode `str`: sample mode, one of "xyz", "xyz_disp", "spherical", "spherical_radius"
:param n_samples `int`: number of samples per ray
:param perturb `bool`: whether to perturb sampling, defaults to `False`
:return `Samples(B, P)`: samples
"""
Sample points along rays within the voxels they pass through.
...
:param rays_o `Tensor(N, 3)`: rays' origin positions
:param rays_d `Tensor(N, 3)`: rays' directions
:param step_size `float`: gap between samples along a ray
:return `Samples(N', P)`: samples along valid rays (which hit at least one voxel)
:return `Tensor(N)`: valid rays mask
"""
intersections = space_module.ray_intersect(rays_o, rays_d, 100)
valid_rays_mask = intersections.hits > 0
rays_o = rays_o[valid_rays_mask]
rays_d = rays_d[valid_rays_mask]
intersections = intersections[valid_rays_mask] # (N) -> (N')
n_rays = rays_o.size(0)
ray_index_list = torch.arange(n_rays, device=rays_o.device, dtype=torch.long) # (N')
@profile
def forward(self, rays: Rays, space: Space, *,
range: tuple[float, float],
mode: str,
n_samples: int,
perturb: bool) -> Samples:
t_range = range if mode == "xyz" else (1. / range[0], 1. / range[1])
t_vals = self._sample(t_range, rays.shape[0], n_samples, perturb) # (B, P)
return self._get_samples(rays, space, t_vals, mode)
hits = intersections.hits
min_depths = intersections.min_depths
max_depths = intersections.max_depths
voxel_indices = intersections.voxel_indices
rays_near_depth = min_depths[:, :1] # (N', 1)
rays_far_depth = max_depths[ray_index_list, hits - 1][:, None] # (N', 1)
rays_length = rays_far_depth - rays_near_depth
rays_steps = (rays_length / self.sample_step).ceil().long()
rays_step_size = rays_length / rays_steps
max_steps = rays_steps.max().item()
rays_step = torch.arange(max_steps, device=rays_o.device,
dtype=torch.float)[None].repeat(n_rays, 1) # (N', P)
invalid_samples_mask = rays_step >= rays_steps
samples_min_depth = rays_near_depth + rays_step * rays_step_size
samples_depth = samples_min_depth + rays_step_size \
* (torch.rand_like(samples_min_depth) if perturb_sample else 0.5) # (N', P)
samples_dist = rays_step_size.repeat(1, max_steps) # (N', 1) -> (N', P)
samples_voxel_index = voxel_indices[
ray_index_list[:, None],
torch.searchsorted(max_depths, samples_depth)
] # (N', P)
samples_depth[invalid_samples_mask] = math.huge
samples_dist[invalid_samples_mask] = 0
samples_voxel_index[invalid_samples_mask] = -1
class PdfSampler(Sampler):
"""
Hierarchical sampling (section 5.2 of NeRF)
"""
rays_o, rays_d = rays_o[:, None], rays_d[:, None]
return Samples(
pts=rays_o + rays_d * samples_depth[..., None],
dirs=rays_d.expand(-1, max_steps, -1),
depths=samples_depth,
dists=samples_dist,
voxel_indices=samples_voxel_index
), valid_rays_mask
def __init__(self):
super().__init__(3, 3)
@perf
def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor,
space: Space, *, perturb_sample: bool, **kwargs) -> Tuple[Samples, torch.Tensor]:
def _sample(self, t_vals: torch.Tensor, weights: torch.Tensor, n_importance: int,
perturb: bool, include_existed: bool, sort_descending: bool) -> torch.Tensor:
"""
Sample points along rays within the voxels they pass through.
Generate sample steps by PDF sampling according to existing sample steps and their weights.
:param rays_o `Tensor(N, 3)`: rays' origin positions
:param rays_d `Tensor(N, 3)`: rays' directions
:param step_size `float`: gap between samples along a ray
:return `Samples(N, P)`: samples along valid rays (which hit at least one voxel)
:param t_vals `Tensor(B, P)`: existing sample steps
:param weights `Tensor(B, P)`: weights of existing sample steps
:param n_importance `int`: number of samples to generate for each ray
:param perturb `bool`: whether to perturb sampling
:param include_existed `bool`: whether to include the existing samples in the output
:return `Tensor(B, P'[+P])`: the output sample steps
"""
intersections = space.ray_intersect(rays_o, rays_d, 100)
valid_rays_mask = intersections.hits > 0
rays_o = rays_o[valid_rays_mask]
rays_d = rays_d[valid_rays_mask]
intersections = intersections[valid_rays_mask] # (N) -> (N')
bins = .5 * (t_vals[..., 1:] + t_vals[..., :-1]) # (B, P - 1)
weights = weights[..., 1:-1] + math.tiny # (B, P - 2)
checkpoint("Ray intersect")
# Get PDF
pdf = weights / torch.sum(weights, -1, keepdim=True)
cdf = union(0., torch.cumsum(pdf, -1)) # (B, P - 1)
if intersections.size == 0:
return None, valid_rays_mask
# Take uniform samples
if perturb:
u = torch.rand(*cdf.shape[:-1], n_importance, device=self.device)
else:
min_depth = intersections.min_depths
max_depth = intersections.max_depths
pts_idx = intersections.voxel_indices
dists = max_depth - min_depth
tot_dists = dists.sum(dim=-1, keepdim=True) # (N, 1)
probs = dists / tot_dists
steps = tot_dists[:, 0] / self.sample_step
# sample points and use middle point approximation
sampled_indices, sampled_depths, sampled_dists = inverse_cdf_sampling(
pts_idx, min_depth, max_depth, probs, steps, -1, not perturb_sample)
sampled_indices = sampled_indices.long()
invalid_idx_mask = sampled_indices.eq(-1)
sampled_dists.clamp_min_(0).masked_fill_(invalid_idx_mask, 0)
sampled_depths.masked_fill_(invalid_idx_mask, math.huge)
u = torch.linspace(0., 1., steps=n_importance, device=self.device).\
expand(*cdf.shape[:-1], -1)
checkpoint("Inverse CDF sampling")
rays_o, rays_d = rays_o[:, None], rays_d[:, None]
return Samples(
pts=rays_o + rays_d * sampled_depths[..., None],
dirs=rays_d.expand(-1, sampled_depths.size(1), -1),
depths=sampled_depths,
dists=sampled_dists,
voxel_indices=sampled_indices
), valid_rays_mask
# Invert CDF
u = u.contiguous() # (B, P')
inds = torch.searchsorted(cdf, u, right=True) # (B, P')
inds_g = torch.stack([
(inds - 1).clamp_min(0), # below
inds.clamp_max(cdf.shape[-1] - 1) # above
], -1) # (B, P', 2)
matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] # [B, P', P - 1]
cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) # (B, P', 2)
bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) # (B, P', 2)
denom = cdf_g[..., 1] - cdf_g[..., 0]
denom = torch.where(denom < math.tiny, torch.ones_like(denom), denom)
t = (u - cdf_g[..., 0]) / denom
t_samples = (bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])).detach()
if include_existed:
return torch.sort(union(t_vals, t_samples), -1, descending=sort_descending)[0]
else:
return t_samples
# stub method for type hint
def __call__(self, rays: Rays, space: Space, t_vals: torch.Tensor, weights: torch.Tensor, *,
mode: str,
n_importance: int,
perturb: bool,
include_existed_samples: bool) -> Samples:
"""
Sample points along rays using PDF sampling based on existing samples.
:param rays `Rays(B)`: rays
:param space `Space`: sample space
:param t_vals `Tensor(B, P)`: existing sample steps
:param weights `Tensor(B, P)`: weights of existing sample steps
:param mode `str`: sample mode, one of "xyz", "xyz_disp", "spherical", "spherical_radius"
:param n_importance `int`: number of samples to generate using PDF sampling for each ray
:param perturb `bool`: whether to perturb sampling, defaults to `False`
:param include_existed_samples `bool`: whether to include the existing samples in the output,
defaults to `True`
:return `Samples(B, P'[+P])`: samples
"""
...
@profile
def forward(self, rays: Rays, space: Space, t_vals: torch.Tensor, weights: torch.Tensor, *,
mode: str,
n_importance: int,
perturb: bool,
include_existed_samples: bool) -> Samples:
t_vals = self._sample(t_vals, weights, n_importance, perturb, include_existed_samples,
mode != "xyz")
return self._get_samples(rays, space, t_vals, mode)
import torch
from typing import Dict, List, Optional, Tuple, Union
from .__common__ import *
from clib import *
from model.utils import load
from utils.module import Module
#from model.utils import load
from utils.nn import Parameter
from utils.geometry import *
from utils.voxels import *
from utils.perf import perf
from utils.env import get_env
__all__ = ["Space", "Voxels", "Octree"]
class Intersections:
@@ -24,8 +22,8 @@ class Intersections:
"""`Tensor(N)` Number of hits"""
@property
def size(self):
return self.hits.size(0)
def shape(self):
return self.hits.shape
def __init__(self, min_depths: torch.Tensor, max_depths: torch.Tensor,
voxel_indices: torch.Tensor, hits: torch.Tensor) -> None:
@@ -42,9 +40,9 @@ class Intersections:
hits=self.hits[index])
class Space(Module):
bbox: Optional[torch.Tensor]
"""`Tensor(2, 3)` Bounding box"""
class Space(nn.Module):
bbox: torch.Tensor | None
"""`Tensor(2, D)` Bounding box"""
@property
def dims(self) -> int:
@@ -52,16 +50,18 @@ class Space(Module):
return self.bbox.shape[1] if self.bbox is not None else 3
@staticmethod
def create(args: dict) -> 'Space':
if 'space' not in args:
def create(type: str, args: dict[str, Any]) -> 'Space':
match type:
case "Space":
return Space(**args)
if args['space'] == 'octree':
case "Octree":
return Octree(**args)
if args['space'] == 'voxels':
case "Voxels":
return Voxels(**args)
return load(args['space']).space
case _:
return load(type).space
def __init__(self, clone_src: "Space" = None, *, bbox: List[float] = None, **kwargs):
def __init__(self, clone_src: "Space" = None, *, bbox: list[float] = None, **kwargs):
super().__init__()
if clone_src:
self.device = clone_src.device
@@ -69,10 +69,30 @@ class Space(Module):
else:
self.register_temp('bbox', None if not bbox else torch.tensor(bbox).reshape(2, -1))
def ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> Intersections:
raise NotImplementedError
def ray_intersect_with_bbox(self, rays_o: torch.Tensor, rays_d: torch.Tensor) -> Intersections:
"""
Calculate intersections of rays with the space's bounding box.
def get_voxel_indices(self, pts: torch.Tensor) -> torch.Tensor:
:param rays_o `Tensor(N..., D)`: rays' origin
:param rays_d `Tensor(N..., D)`: rays' direction
:param max_hits `int?`: max number of hits of each ray, has no effect for this method
:return `Intersect(N...)`: rays' intersection with the bounding box
"""
if self.bbox is None:
raise RuntimeError("The space has no bounding box")
inv_d = rays_d.reciprocal().unsqueeze(-2)
t = (self.bbox - rays_o.unsqueeze(-2)) * inv_d # (N..., 2, D)
t0 = t.min(dim=-2)[0].max(dim=-1, keepdim=True)[0].clamp(min=1e-4) # (N..., 1)
t1 = t.max(dim=-2)[0].min(dim=-1, keepdim=True)[0]
miss = t1 <= t0
t0[miss], t1[miss] = -1., -1.
hit = torch.logical_not(miss).long()
return Intersections(t0, t1, hit - 1, hit.squeeze(-1))
def ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, max_hits: int) -> Intersections:
return self.ray_intersect_with_bbox(rays_o, rays_d)
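# --- Illustrative sketch (not part of this commit) ---------------------------
# ray_intersect_with_bbox() is the classic slab test: intersect the ray with both
# planes of every axis-aligned slab, take the latest entry (t0) and the earliest
# exit (t1); the ray misses the box when t1 <= t0. Standalone check:
import torch

bbox = torch.tensor([[-1.0, -1.0, -1.0], [1.0, 1.0, 1.0]])   # (2, D)
rays_o = torch.tensor([[0.0, 0.0, -5.0], [0.0, 3.0, -5.0]])  # one hit, one miss
rays_d = torch.tensor([[0.0, 0.0, 1.0], [0.0, 0.0, 1.0]])

t = (bbox - rays_o.unsqueeze(-2)) / rays_d.unsqueeze(-2)      # (N, 2, D)
t0 = t.min(dim=-2)[0].max(dim=-1)[0]                          # latest entry
t1 = t.max(dim=-2)[0].min(dim=-1)[0]                          # earliest exit
print(t0, t1, t1 > t0)                                        # first ray: t0=4, t1=6; second ray misses
# --- end of sketch ------------------------------------------------------------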
def get_voxel_indices(self, pts: torch.Tensor) -> int | torch.Tensor:
if self.bbox is None:
return 0
voxel_indices = torch.zeros_like(pts[..., 0], dtype=torch.long)
@@ -81,19 +101,22 @@ class Space(Module):
return voxel_indices
@torch.no_grad()
def prune(self, keeps: torch.Tensor) -> Tuple[int, int]:
def prune(self, keeps: torch.Tensor) -> tuple[int, int]:
raise NotImplementedError()
@torch.no_grad()
def split(self) -> Tuple[int, int]:
def split(self) -> tuple[int, int]:
raise NotImplementedError()
@torch.no_grad()
def clone(self):
return Space(self)
return self.__class__(self)
class Voxels(Space):
bbox: torch.Tensor
"""`Tensor(2, D)` Bounding box"""
steps: torch.Tensor
"""`Tensor(3)` Steps along each dimension"""
@@ -131,42 +154,43 @@ class Voxels(Space):
@property
def voxel_size(self) -> torch.Tensor:
"""`Tensor(3)` Voxel size"""
if self.bbox is None:
raise RuntimeError("Cannot get property 'voxel_size' of a space which "
"doesn't have bounding box")
return (self.bbox[1] - self.bbox[0]) / self.steps
@property
def corner_embeddings(self) -> Dict[str, torch.nn.Embedding]:
def corner_embeddings(self) -> dict[str, torch.nn.Embedding]:
return {name[4:]: emb for name, emb in self.named_modules() if name.startswith("emb_")}
@property
def voxel_embeddings(self) -> Dict[str, torch.nn.Embedding]:
def voxel_embeddings(self) -> dict[str, torch.nn.Embedding]:
return {name[5:]: emb for name, emb in self.named_modules() if name.startswith("vemb_")}
def __init__(self, clone_src: "Voxels" = None, *, bbox: List[float] = None,
voxel_size: float = None, steps: Union[torch.Tensor, Tuple[int, ...]] = None,
def __init__(self, clone_src: "Voxels" = None, *, bbox: list[float] = None,
voxel_size: float = None, steps: torch.Tensor | tuple[int, ...] = None,
**kwargs) -> None:
super().__init__(clone_src, bbox=bbox, **kwargs)
if clone_src:
super().__init__(clone_src)
self.register_buffer('steps', clone_src.steps)
self.register_buffer('voxels', clone_src.voxels)
self.register_buffer("corners", clone_src.corners)
self.register_buffer("corner_indices", clone_src.corner_indices)
self.register_buffer('voxel_indices_in_grid', clone_src.voxel_indices_in_grid)
else:
if self.bbox is None:
if bbox is None:
raise ValueError("Missing argument 'bbox'")
if voxel_size is not None:
self.register_buffer('steps', get_grid_steps(self.bbox, voxel_size))
else:
super().__init__(bbox=bbox)
if steps is not None:
self.register_buffer('steps', torch.tensor(steps, dtype=torch.long))
else:
self.register_buffer('steps', get_grid_steps(self.bbox, voxel_size))
self.register_buffer('voxels', init_voxels(self.bbox, self.steps))
corners, corner_indices = get_corners(self.voxels, self.bbox, self.steps)
self.register_buffer("corners", corners)
self.register_buffer("corner_indices", corner_indices)
self.register_buffer('voxel_indices_in_grid', torch.arange(-1, self.n_voxels))
def clone(self):
return Voxels(self)
def to_vi(self, gi: torch.Tensor) -> torch.Tensor:
return self.voxel_indices_in_grid[gi + 1]
@@ -208,7 +232,7 @@ class Voxels(Space):
voxels = self.voxels[voxel_indices] # (N, 3)
corner_indices = self.corner_indices[voxel_indices] # (N, 8)
p = (pts - voxels) / self.voxel_size + .5 # (N, 3) normed-coords in voxel
return trilinear_interp(p, emb(corner_indices))
return linear_interp(p, emb(corner_indices))
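# --- Illustrative sketch (not part of this commit) ---------------------------
# The lookup above normalises each point into its voxel ([0, 1]^3) and blends the
# embeddings of the voxel's 8 corners; linear_interp() is assumed to perform a
# trilinear blend. A standalone version over scalar corner values (the corner
# ordering below is only an assumption for the demo):
import torch

def trilerp_demo(p: torch.Tensor, corners: torch.Tensor) -> torch.Tensor:
    # p: (N, 3) in [0, 1]^3; corners: (N, 8), corner index bits: 0 -> x, 1 -> y, 2 -> z.
    x, y, z = p[:, 0], p[:, 1], p[:, 2]
    wx = torch.stack([1 - x, x], -1)                   # (N, 2)
    wy = torch.stack([1 - y, y], -1)
    wz = torch.stack([1 - z, z], -1)
    w = (wz[:, :, None, None] * wy[:, None, :, None] * wx[:, None, None, :]).reshape(-1, 8)
    return (w * corners).sum(-1)

corners = torch.arange(8, dtype=torch.float)[None]     # corner value == corner index
print(trilerp_demo(torch.tensor([[0.5, 0.5, 0.5]]), corners))  # tensor([3.5]) = mean of 0..7
# --- end of sketch ------------------------------------------------------------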
def create_voxel_embedding(self, n_dims: int, name: str = 'default') -> torch.nn.Embedding:
"""
@@ -245,7 +269,7 @@ class Voxels(Space):
raise KeyError(f"Embedding '{name}' doesn't exist")
return emb(voxel_indices)
@perf
@profile
def ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> Intersections:
"""
Calculate intersections of rays and voxels.
@@ -277,7 +301,7 @@ class Voxels(Space):
hits=hits[0]
)
@perf
@profile
def get_voxel_indices(self, pts: torch.Tensor) -> torch.Tensor:
"""
Get voxel indices of points.
@@ -290,8 +314,8 @@ class Voxels(Space):
gi = to_grid_indices(pts, self.bbox, self.steps)
return self.to_vi(gi)
@perf
def get_corners(self, vidxs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
@profile
def get_corners(self, vidxs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
vidxs = vidxs.unique()
if vidxs[0] == -1:
vidxs = vidxs[1:]
@@ -303,7 +327,7 @@ class Voxels(Space):
return fi_corner_indices, fi_corners
@torch.no_grad()
def split(self) -> Tuple[int, int]:
def split(self) -> tuple[int, int]:
"""
Split voxels into smaller voxels with half size.
"""
@@ -336,7 +360,7 @@ class Voxels(Space):
return self.n_voxels // 8, self.n_voxels
@torch.no_grad()
def prune(self, keeps: torch.Tensor) -> Tuple[int, int]:
def prune(self, keeps: torch.Tensor) -> tuple[int, int]:
self.voxels = self.voxels[keeps]
self.corner_indices = self.corner_indices[keeps]
self._update_gi2vi()
@@ -351,7 +375,7 @@ class Voxels(Space):
new_emb = self.set_voxel_embedding(update_fn(emb.weight), name)
self._update_optimizer(emb.weight, new_emb.weight, update_fn)
def _update_optimizer(self, old_param: nn.Parameter, new_param: nn.Parameter, update_fn):
def _update_optimizer(self, old_param: Parameter, new_param: Parameter, update_fn):
optimizer = get_env()["trainer"].optimizer
if isinstance(optimizer, (torch.optim.Adam, torch.optim.AdamW)):
# Update related states in optimizer
@@ -384,7 +408,7 @@ class Voxels(Space):
sum_dims = [val for val in range(self.dims) if val != dim]
return self.voxel_indices_in_grid[1:].reshape(*self.steps).ne(-1).sum(sum_dims)
def balance_cut(self, dim: int, n_parts: int) -> List[int]:
def balance_cut(self, dim: int, n_parts: int) -> list[int]:
n_voxels_list = self.n_voxels_along_dim(dim)
cdf = (n_voxels_list.cumsum(0) / self.n_voxels * n_parts).tolist()
bins = []
@@ -398,7 +422,7 @@ class Voxels(Space):
bins.append(len(cdf) - offset)
return bins
def sample(self, S: int, perturb: bool = False, include_border: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
def sample(self, S: int, perturb: bool = False, include_border: bool = True) -> tuple[torch.Tensor, torch.Tensor]:
"""
For each voxel, sample `S^3` points uniformly, with small perturb if `perturb` is `True`.
@@ -419,7 +443,7 @@ class Voxels(Space):
pts += (torch.rand_like(pts) - .5) * self.voxel_size / S
return pts.reshape(-1, 3), voxel_indices.flatten()
def _ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
def _ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return aabb_ray_intersect(self.voxel_size, n_max_hits, self.voxels, rays_o, rays_d)
def _update_gi2vi(self):
@@ -456,7 +480,7 @@ class Octree(Voxels):
self.nodes_cached = None
self.tree_cached = None
def get(self) -> Tuple[torch.Tensor, torch.Tensor]:
def get(self) -> tuple[torch.Tensor, torch.Tensor]:
if self.nodes_cached is None:
self.nodes_cached, self.tree_cached = build_easy_octree(
self.voxels, 0.5 * self.voxel_size)
@@ -477,7 +501,7 @@ class Octree(Voxels):
return ret
@torch.no_grad()
def prune(self, keeps: torch.Tensor) -> Tuple[int, int]:
def prune(self, keeps: torch.Tensor) -> tuple[int, int]:
ret = super().prune(keeps)
self.clear()
return ret
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import plotly_express as px # import plotly.express as px\n",
"import plotly.graph_objects as go\n",
"import pandas as pd\n",
"import numpy as np\n",
"import os\n",
"\n",
"name = [\"NeRF\", \"PANO\", \"OURS+\", \"OURS\"]\n",
"foveal = 8.0\n",
"mid = 2.8\n",
"far = 2.6\n",
"blend = 0.1\n",
"ours_full = 562\n",
"pano = 1e4\n",
"nerf = 9e4\n",
"\n",
"times = {\n",
" \"fovea_l\": 0,\n",
" \"mid_l\": 0,\n",
" \"far_l\": 0,\n",
" \"fovea_r\": 0,\n",
" \"mid_r\": 0,\n",
" \"far_r\": 0,\n",
" \"blend\": 0,\n",
" \"ours_full\": 0,\n",
" \"pano\": 0,\n",
" \"nerf\": 0,\n",
"}\n",
"clip = 0\n",
"frame_id = 0\n",
"\n",
"\n",
"def calc_total():\n",
" return [\n",
" times[\"nerf\"],\n",
" times[\"pano\"],\n",
" times[\"ours_full\"],\n",
" times[\"fovea_l\"] + times[\"mid_l\"] + times[\"far_l\"] +\n",
" times[\"fovea_r\"] + times[\"mid_r\"] + times[\"far_r\"] + times[\"blend\"]\n",
" ]\n",
"\n",
"\n",
"def draw_frame(*, xlim=None, **kwargs):\n",
" global frame_id\n",
" for key in kwargs:\n",
" times[key] = kwargs[key]\n",
" tot = calc_total()\n",
" data = {\n",
" \"fovea_l\": [0, 0, 0, times[\"fovea_l\"]],\n",
" \"mid_l\": [0, 0, 0, times[\"mid_l\"]],\n",
" \"far_l\": [0, 0, 0, times[\"far_l\"]],\n",
" \"fovea_r\": [0, 0, 0, times[\"fovea_r\"]],\n",
" \"mid_r\": [0, 0, 0, times[\"mid_r\"]],\n",
" \"far_r\": [0, 0, 0, times[\"far_r\"]],\n",
" \"blend\": [0, 0, 0, times[\"blend\"]],\n",
" \"ours_full\": [0, 0, times[\"ours_full\"], 0],\n",
" \"pano\": [0, times[\"pano\"], 0, 0],\n",
" \"nerf\": [times[\"nerf\"], 0, 0, 0],\n",
" }\n",
" if xlim is None or xlim < max(tot) * 1.1:\n",
" xlim = max(tot) * 1.1\n",
" \n",
" fig = go.Figure()\n",
" times_keys = list(times.keys())\n",
" for key in times_keys:\n",
" if key == times_keys[-1]:\n",
" fig.add_trace(go.Bar(\n",
" y=name,\n",
" x=data[key],\n",
" name=key,\n",
" orientation='h',\n",
" text=[\"\" if item == 0 else f\"{item:.1f}\" if item < 1000 else f\"{item:.1e}\" for item in tot],\n",
" textposition=\"outside\"\n",
" ))\n",
" else:\n",
" fig.add_trace(go.Bar(\n",
" y=name,\n",
" x=data[key],\n",
" name=key,\n",
" orientation='h',\n",
" ))\n",
" fig.update_traces(width=0.5)\n",
" fig.update_layout(barmode='stack', showlegend=False,\n",
" yaxis_visible=False, yaxis_showticklabels=False, xaxis_range=[0, xlim])\n",
" \n",
" # fig.show()\n",
" fig.write_image(f\"dynamic_bar/clip_{clip}/{frame_id:04d}.png\", width=1920 // 2, height=1080 // 2, scale=2)\n",
" frame_id = frame_id + 1\n",
"\n",
"def add_animation(*, frames, xlim=None, **kwargs):\n",
" if frames == 1:\n",
" draw_frame(**kwargs, xlim=xlim)\n",
" return\n",
" data = {\n",
" key: np.linspace(times[key], kwargs[key], frames)\n",
" for key in kwargs\n",
" }\n",
" for i in range(frames):\n",
" draw_frame(**{key: data[key][i] for key in data}, xlim=xlim)\n",
"\n",
"def new_clip():\n",
" global clip, frame_id\n",
" clip += 1\n",
" frame_id = 0\n",
" os.system(f\"mkdir dynamic_bar/clip_{clip}\")\n",
"\n",
"os.system('rm -f -r dynamic_bar')\n",
"os.system('mkdir dynamic_bar')\n",
"\n",
"# ours mono\n",
"new_clip()\n",
"add_animation(fovea_l=foveal, frames=48, xlim=30) # Step 1: grow foveal\n",
"add_animation(mid_l=mid, frames=16, xlim=30) # Step 2: grow mid\n",
"add_animation(far_l=far, frames=16, xlim=30) # Step 3: grow far\n",
"add_animation(blend=blend, frames=1, xlim=30) # Step 4: grow blend\n",
"\n",
"# ours stereo\n",
"new_clip()\n",
"add_animation(fovea_r=foveal, frames=24, xlim=30) # Step 1: grow foveal\n",
"add_animation(mid_r=mid, frames=8, xlim=30) # Step 2: grow mid\n",
"add_animation(far_r=far, frames=8, xlim=30) # Step 3: grow far\n",
"\n",
"# ours stereo adapt\n",
"new_clip()\n",
"add_animation(mid_r=0, far_r=0, frames=24, xlim=30)\n",
"\n",
"# other series\n",
"new_clip()\n",
"add_animation(ours_full=ours_full, frames=48, xlim=30)\n",
"new_clip()\n",
"add_animation(pano=pano, frames=48, xlim=30)\n",
"new_clip()\n",
"add_animation(nerf=nerf, frames=48, xlim=30)\n",
"\n",
"#os.system(f'ffmpeg -y -r 24 -i dynamic_bar/%04d.png dynamic_bar.avi')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig = px.bar(\n",
" df1, # 绘图数据\n",
" x=list(times.keys()), # y轴\n",
" y=\"name\", # x轴\n",
" orientation='h', # 水平柱状图\n",
" #text=[[\"a\", \"tot\"], \"tot1\", \"tot2\", {\"fovea_l\": \"\", \"blend_r\": 13.5}] # 需要显示的数据\n",
")\n",
"fig.update_traces(textposition=\"outside\", showlegend=False, text=[[\"a\"]*11, [\"b\"]*11, [\"c\"]*11, [\"d\"]*11,[\"e\"]*11])\n",
"fig.show()"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>country</th>\n",
" <th>continent</th>\n",
" <th>year</th>\n",
" <th>lifeExp</th>\n",
" <th>pop</th>\n",
" <th>gdpPercap</th>\n",
" <th>iso_alpha</th>\n",
" <th>iso_num</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Afghanistan</td>\n",
" <td>Asia</td>\n",
" <td>1952</td>\n",
" <td>28.801</td>\n",
" <td>8425333</td>\n",
" <td>779.445314</td>\n",
" <td>AFG</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>Afghanistan</td>\n",
" <td>Asia</td>\n",
" <td>1957</td>\n",
" <td>30.332</td>\n",
" <td>9240934</td>\n",
" <td>820.853030</td>\n",
" <td>AFG</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Afghanistan</td>\n",
" <td>Asia</td>\n",
" <td>1962</td>\n",
" <td>31.997</td>\n",
" <td>10267083</td>\n",
" <td>853.100710</td>\n",
" <td>AFG</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>Afghanistan</td>\n",
" <td>Asia</td>\n",
" <td>1967</td>\n",
" <td>34.020</td>\n",
" <td>11537966</td>\n",
" <td>836.197138</td>\n",
" <td>AFG</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>Afghanistan</td>\n",
" <td>Asia</td>\n",
" <td>1972</td>\n",
" <td>36.088</td>\n",
" <td>13079460</td>\n",
" <td>739.981106</td>\n",
" <td>AFG</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1699</th>\n",
" <td>Zimbabwe</td>\n",
" <td>Africa</td>\n",
" <td>1987</td>\n",
" <td>62.351</td>\n",
" <td>9216418</td>\n",
" <td>706.157306</td>\n",
" <td>ZWE</td>\n",
" <td>716</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1700</th>\n",
" <td>Zimbabwe</td>\n",
" <td>Africa</td>\n",
" <td>1992</td>\n",
" <td>60.377</td>\n",
" <td>10704340</td>\n",
" <td>693.420786</td>\n",
" <td>ZWE</td>\n",
" <td>716</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1701</th>\n",
" <td>Zimbabwe</td>\n",
" <td>Africa</td>\n",
" <td>1997</td>\n",
" <td>46.809</td>\n",
" <td>11404948</td>\n",
" <td>792.449960</td>\n",
" <td>ZWE</td>\n",
" <td>716</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1702</th>\n",
" <td>Zimbabwe</td>\n",
" <td>Africa</td>\n",
" <td>2002</td>\n",
" <td>39.989</td>\n",
" <td>11926563</td>\n",
" <td>672.038623</td>\n",
" <td>ZWE</td>\n",
" <td>716</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1703</th>\n",
" <td>Zimbabwe</td>\n",
" <td>Africa</td>\n",
" <td>2007</td>\n",
" <td>43.487</td>\n",
" <td>12311143</td>\n",
" <td>469.709298</td>\n",
" <td>ZWE</td>\n",
" <td>716</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>1704 rows × 8 columns</p>\n",
"</div>"
],
"text/plain": [
" country continent year lifeExp pop gdpPercap iso_alpha \\\n",
"0 Afghanistan Asia 1952 28.801 8425333 779.445314 AFG \n",
"1 Afghanistan Asia 1957 30.332 9240934 820.853030 AFG \n",
"2 Afghanistan Asia 1962 31.997 10267083 853.100710 AFG \n",
"3 Afghanistan Asia 1967 34.020 11537966 836.197138 AFG \n",
"4 Afghanistan Asia 1972 36.088 13079460 739.981106 AFG \n",
"... ... ... ... ... ... ... ... \n",
"1699 Zimbabwe Africa 1987 62.351 9216418 706.157306 ZWE \n",
"1700 Zimbabwe Africa 1992 60.377 10704340 693.420786 ZWE \n",
"1701 Zimbabwe Africa 1997 46.809 11404948 792.449960 ZWE \n",
"1702 Zimbabwe Africa 2002 39.989 11926563 672.038623 ZWE \n",
"1703 Zimbabwe Africa 2007 43.487 12311143 469.709298 ZWE \n",
"\n",
" iso_num \n",
"0 4 \n",
"1 4 \n",
"2 4 \n",
"3 4 \n",
"4 4 \n",
"... ... \n",
"1699 716 \n",
"1700 716 \n",
"1701 716 \n",
"1702 716 \n",
"1703 716 \n",
"\n",
"[1704 rows x 8 columns]"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df = px.data.gapminder()\n",
"df"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.10.0 ('dvs')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "4469b029896260c1221afa6e0e6159922aafd2738570e75b7bc15e28db242604"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
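
The commented-out `ffmpeg` call at the end of the plotting cell suggests how the per-frame PNGs are meant to be stitched into a video. Since the frames are now written into per-clip folders (`dynamic_bar/clip_{clip}/{frame_id:04d}.png`), one video per clip could be produced along the lines below; this is a hedged sketch, where the clip count matches the six `new_clip()` calls above and the frame rate follows the commented invocation.

import os

n_clips = 6   # number of new_clip() calls in the cell above (assumption)
fps = 24      # frame rate taken from the commented ffmpeg line
for clip in range(1, n_clips + 1):
    # Encode each clip's numbered PNG frames into its own video file
    os.system(f"ffmpeg -y -r {fps} -i dynamic_bar/clip_{clip}/%04d.png dynamic_bar_clip_{clip}.avi")
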
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"import os\n",
"import torch\n",
"import torch.nn.functional as nn_f\n",
"import matplotlib.pyplot as plt\n",
"\n",
"rootdir = os.path.abspath(sys.path[0] + '/../../')\n",
"sys.path.append(rootdir)\n",
"\n",
"from utils import img\n",
"from utils.view import *\n",
"\n",
"datadir = f\"{rootdir}/data/__thesis/__demo/compare\"\n",
"figs = ['fsnerf', 'gt', 'nerf']\n",
"crops = {\n",
" 'barbershop': [[406, 117, 100], [209, 170, 100]],\n",
" 'gas': [[195, 69, 100], [7, 305, 100]],\n",
" 'mc': [[395, 128, 100], [97, 391, 100]],\n",
" 'pabellon': [[208, 115, 100], [22, 378, 100]]\n",
"}\n",
"colors = torch.tensor([[0, 1, 0], [1, 1, 0]], dtype=torch.float)\n",
"border = 3\n",
"\n",
"for scene in crops:\n",
" images = img.load([f\"{datadir}/origin/{scene}_{fig}.png\" for fig in figs])\n",
" halfw = images.size(-1) // 2\n",
" halfh = images.size(-2) // 2\n",
" overlay = torch.zeros(1, 4, *images.shape[2:])\n",
" mask = torch.zeros(len(crops[scene]), *images.shape[2:], dtype=torch.bool)\n",
" for i, crop in enumerate(crops[scene]):\n",
" patches = images[..., crop[1]: crop[1] + crop[2], crop[0]: crop[0] + crop[2]].clone()\n",
" patches[..., :border, :] = colors[i, :, None, None]\n",
" patches[..., -border:, :] = colors[i, :, None, None]\n",
" patches[..., :, :border] = colors[i, :, None, None]\n",
" patches[..., :, -border:] = colors[i, :, None, None]\n",
" img.save(patches, [f\"{datadir}/crop/{scene}_{i}_{fig}.png\" for fig in figs])\n",
" mask[i,\n",
" crop[1] - border: crop[1] + crop[2] + border,\n",
" crop[0] - border: crop[0] + crop[2] + border] = True\n",
" mask[i,\n",
" crop[1]: crop[1] + crop[2],\n",
" crop[0]: crop[0] + crop[2]] = False\n",
" images[:, :, mask[i]] = colors[i, :, None]\n",
" img.save(images, [f\"{datadir}/overlay/{scene}_{fig}.png\" for fig in figs])\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.10.0 ('dvs')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0"
},
"orig_nbformat": 2,
"vscode": {
"interpreter": {
"hash": "4469b029896260c1221afa6e0e6159922aafd2738570e75b7bc15e28db242604"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
......@@ -2,8 +2,11 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import sys\n",
"import os\n",
"import torch\n",
......@@ -12,25 +15,25 @@
"\n",
"rootdir = os.path.abspath(sys.path[0] + '/../')\n",
"sys.path.append(rootdir)\n",
"torch.cuda.set_device(0)\n",
"\n",
"torch.cuda.set_device(3)\n",
"print(\"Set CUDA:%d as current device.\" % torch.cuda.current_device())\n",
"torch.autograd.set_grad_enabled(False)\n",
"\n",
"from configs.spherical_view_syn import SphericalViewSynConfig\n",
"from utils import netio\n",
"from utils import img\n",
"from utils import device\n",
"import model\n",
"from data import Dataset\n",
"from utils import netio, img, device\n",
"from utils.view import *\n",
"from utils.type import PathLike\n",
"from components.fnr import FoveatedNeuralRenderer\n",
"from components.render import render\n",
"\n",
"\n",
"def load_net(path):\n",
" config = SphericalViewSynConfig()\n",
" config.from_id(os.path.splitext(path)[0])\n",
" config.sa['perturb_sample'] = False\n",
" net = config.create_net().to(device.default())\n",
" netio.load(path, net)\n",
" return net\n",
"def load_model(model_path: PathLike):\n",
" return model.deserialize(netio.load_checkpoint(model_path)[0],\n",
" raymarching_early_stop_tolerance=0.01,\n",
" raymarching_chunk_size_or_sections=None,\n",
" perturb_sample=False).eval().to(device.default())\n",
"\n",
"\n",
"def find_file(prefix):\n",
......@@ -40,6 +43,16 @@
" return None\n",
"\n",
"\n",
"def create_renderer(*nets, fov_scale=1.):\n",
" fov_list = [20, 45, 110]\n",
" for i in range(len(fov_list)):\n",
" fov_list[i] = length2fov(fov2length(fov_list[i]) * fov_scale)\n",
" res_list = [(256, 256), (256, 256), (256, 230)]\n",
" res_full = (1600, 1440)\n",
" return FoveatedNeuralRenderer(fov_list, res_list, nn.ModuleList(nets), res_full,\n",
" device=device.default())\n",
"\n",
"\n",
"def plot_images(images):\n",
" plt.figure(figsize=(12, 4))\n",
" plt.subplot(131)\n",
......@@ -49,64 +62,50 @@
" plt.subplot(133)\n",
" img.plot(images['layers_img'][2])\n",
" #plt.figure(figsize=(12, 12))\n",
" #img.plot(images['overlaid'])\n",
" # img.plot(images['overlaid'])\n",
" #plt.figure(figsize=(12, 12))\n",
" #img.plot(images['blended_raw'])\n",
" # img.plot(images['blended_raw'])\n",
" plt.figure(figsize=(12, 12))\n",
" img.plot(images['blended'])\n",
"\n",
"\n",
"def save_images(images, scene, i):\n",
" outputdir = '../__demo/mono/'\n",
" os.makedirs(outputdir, exist_ok=True)\n",
" for layer in range(len(images[\"layers_img\"])):\n",
" img.save(images['layers_img'][layer], f'{outputdir}{scene}_{i:04d}({layer}).png')\n",
" img.save(images['blended'], f'{outputdir}{scene}_{i:04d}.png')\n",
" if \"overlaid\" in images:\n",
" img.save(images['overlaid'], f'{outputdir}{scene}_{i:04d}_overlaid.png')\n",
" if \"blended_raw\" in images:\n",
" img.save(images['blended_raw'], f'{outputdir}{scene}_{i:04d}_noCE.png')\n",
" if \"nerf\" in images:\n",
" img.save(images['nerf'], f'{outputdir}{scene}_{i:04d}_nerf.png')\n",
"\n",
"\n",
"scenes = {\n",
" 'classroom': 'classroom_all',\n",
" 'stones': 'stones_all',\n",
" 'barbershop': 'barbershop_all',\n",
" 'lobby': 'lobby_all'\n",
" 'classroom': '__new/classroom_all',\n",
" 'stones': '__new/stones_all',\n",
" 'barbershop': '__new/barbershop_all',\n",
" 'lobby': '__new/lobby_all',\n",
" \"bedroom2\": \"__captured/bedroom2\"\n",
"}\n",
"\n",
"fov_list = [20, 45, 110]\n",
"res_list = [(256, 256), (256, 256), (400, 360)]\n",
"res_full = (1600, 1440)"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Set CUDA:0 as current device.\n"
]
}
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 2,
"source": [
"scene = 'barbershop'\n",
"os.chdir(f'{rootdir}/data/__new/{scenes[scene]}')\n",
"\n",
"scene = \"bedroom2\"\n",
"os.chdir(f'{rootdir}/data/{scenes[scene]}')\n",
"print('Change working directory to ', os.getcwd())\n",
"\n",
"fovea_net = load_net(find_file('fovea'))\n",
"periph_net = load_net(find_file('periph'))\n",
"renderer = FoveatedNeuralRenderer(fov_list, res_list, nn.ModuleList([fovea_net, periph_net, periph_net]),\n",
" res_full, device=device.default())"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Change working directory to /home/dengnc/dvs/data/__new/barbershop_all\n",
"Load net from fovea200@snerffast4-rgb_e6_fc512x4_d1.20-6.00_s64_~p.pth ...\n",
"Load net from periph200@snerffast2-rgb_e6_fc256x4_d1.20-6.00_s32_~p.pth ...\n"
"fovea_net = load_model(find_file('fovea'))\n",
"periph_net = load_model(find_file('periph'))\n",
"nerf_net = load_model(find_file(\"nerf\"))"
]
}
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"params = {\n",
" 'classroom': [\n",
......@@ -165,7 +164,7 @@
"\n",
"for i, param in enumerate(params[scene]):\n",
" view = Trans(torch.tensor(param[:3], device=device.default()),\n",
" torch.tensor(euler_to_matrix([-param[4], param[3], 0]), device=device.default()).view(3, 3))\n",
" torch.tensor(euler_to_matrix(-param[4], param[3], 0), device=device.default()).view(3, 3))\n",
" images = renderer(view, param[-2:], using_mask=False, ret_raw=True)\n",
" images['overlaid'] = renderer.foveation.synthesis(images['layers_raw'], param[-2:], do_blend=False)\n",
" if True:\n",
......@@ -179,49 +178,45 @@
" #img.save(images['blended_raw'], f'{outputdir}{scene}_{i}.png')\n",
" else:\n",
" images = plot_images(images)\n"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def load_views(data_desc_file) -> Trans:\n",
" with open(data_desc_file, 'r', encoding='utf-8') as file:\n",
" data_desc = json.loads(file.read())\n",
" view_centers = torch.tensor(\n",
" data_desc['view_centers'], device=device.default()).view(-1, 3)\n",
" view_rots = torch.tensor(\n",
" data_desc['view_rots'], device=device.default()).view(-1, 3, 3)\n",
" return Trans(view_centers, view_rots)\n",
"\n",
"\n",
"views = load_views('for_panorama_cvt.json')\n",
"print('Dataset loaded.')\n",
"for view_idx in range(views.size()[0]):\n",
" center = (0, 0)\n",
" images = renderer(views.get(view_idx), center, using_mask=True)\n",
" outputdir = 'panorama'\n",
" os.makedirs(outputdir, exist_ok=True)\n",
" img.save(images['blended'], f'{outputdir}/{view_idx:04d}.png')"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Dataset loaded.\n"
"def load_views(data_desc_file) -> tuple[list[int], Trans]:\n",
" dataset = Dataset(data_desc_file)\n",
" return dataset.indices.tolist(),\\\n",
" Trans(dataset.centers, dataset.rots).to(device.default())\n",
"\n",
"\n",
"demos = [ # view_idx, center_x, center_y, fov_scale\n",
" [220, 30, 25, 0.7],\n",
" [235, 0, 130, 0.7],\n",
" [239, 70, 140, 0.7],\n",
" [841, -100, 160, 0.7]\n",
"]\n",
"indices, views = load_views('images.json')\n",
"for demo_idx in [0]:\n",
" view_idx = demos[demo_idx][0]\n",
" i = indices.index(view_idx)\n",
" center = tuple(demos[demo_idx][1:3])\n",
" renderer = create_renderer(fovea_net, periph_net, periph_net, fov_scale=demos[demo_idx][3])\n",
" images = renderer(views.get(i), center, using_mask=False)\n",
" #nerf_fovea = render(nerf_net, renderer.cam, views.get(i), None, batch_size=16384)[\"color\"]\n",
" #images[\"nerf\"] = nerf_fovea\n",
" plot_images(images)\n",
" #save_images(images, scene, view_idx)\n"
]
}
],
"metadata": {}
}
],
"metadata": {
"kernelspec": {
"name": "python3",
"display_name": "Python 3.8.5 64-bit ('base': conda)"
"display_name": "Python 3.10.0 ('dvs')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
......@@ -233,15 +228,17 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
"version": "3.10.0"
},
"metadata": {
"interpreter": {
"hash": "82066b63b621a9e3d15e3b7c11ca76da6238eff3834294910d715044bd0561e5"
}
},
"vscode": {
"interpreter": {
"hash": "82066b63b621a9e3d15e3b7c11ca76da6238eff3834294910d715044bd0561e5"
"hash": "4469b029896260c1221afa6e0e6159922aafd2738570e75b7bc15e28db242604"
}
}
},
"nbformat": 4,
......
......@@ -2,14 +2,17 @@
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Set CUDA:0 as current device.\n"
"Set CUDA:0 as current device.\n",
"Change working directory to /home/dengnc/Work/fov_nerf/data/__thesis/barbershop\n",
"Load model fovea.tar\n",
"Load model periph.tar\n"
]
}
],
......@@ -20,29 +23,25 @@
"import torch.nn as nn\n",
"import matplotlib.pyplot as plt\n",
"\n",
"rootdir = os.path.abspath(sys.path[0] + '/../')\n",
"rootdir = os.path.abspath(sys.path[0] + '/../../')\n",
"sys.path.append(rootdir)\n",
"\n",
"torch.cuda.set_device(0)\n",
"print(\"Set CUDA:%d as current device.\" % torch.cuda.current_device())\n",
"torch.autograd.set_grad_enabled(False)\n",
"\n",
"from data.spherical_view_syn import *\n",
"from configs.spherical_view_syn import SphericalViewSynConfig\n",
"from utils import netio\n",
"from utils import img\n",
"from utils import device\n",
"from model import Model\n",
"from data import Dataset\n",
"from utils import netio, img, device\n",
"from utils.view import *\n",
"from utils.types import PathLike\n",
"from components.fnr import FoveatedNeuralRenderer\n",
"from components.render import render\n",
"\n",
"\n",
"def load_net(path):\n",
" config = SphericalViewSynConfig()\n",
" config.from_id(os.path.splitext(path)[0])\n",
" config.sa['perturb_sample'] = False\n",
" net = config.create_net().to(device.default())\n",
" netio.load(path, net)\n",
" return net\n",
"def load_model(model_path: PathLike):\n",
" print(\"Load model\", model_path)\n",
" return Model.load(model_path).eval().to(device.default())\n",
"\n",
"\n",
"def find_file(prefix):\n",
......@@ -52,6 +51,16 @@
" return None\n",
"\n",
"\n",
"def create_renderer(*nets, fov_scale=1.):\n",
" fov_list = [20, 45, 110]\n",
" for i in range(len(fov_list)):\n",
" fov_list[i] = length2fov(fov2length(fov_list[i]) * fov_scale)\n",
" res_list = [(256, 256), (256, 256), (256, 230)]\n",
" res_full = (1600, 1440)\n",
" return FoveatedNeuralRenderer(fov_list, res_list, nn.ModuleList(nets), res_full,\n",
" device=device.default())\n",
"\n",
"\n",
"def load_views(data_desc_file) -> Trans:\n",
" with open(data_desc_file, 'r', encoding='utf-8') as file:\n",
" data_desc = json.loads(file.read())\n",
......@@ -124,38 +133,29 @@
"scenes = {\n",
" 'classroom': 'classroom_all',\n",
" 'stones': 'stones_all',\n",
" 'barbershop': 'barbershop_all',\n",
" 'barbershop': '__thesis/barbershop',\n",
" 'lobby': 'lobby_all'\n",
"}\n",
"\n",
"\n",
"fov_list = [20, 45, 110]\n",
"res_list = [(256, 256), (256, 256), (256, 230)]\n",
"res_full = (1600, 1440)\n"
"scene = \"barbershop\"\n",
"os.chdir(f'{rootdir}/data/{scenes[scene]}')\n",
"print('Change working directory to ', os.getcwd())\n",
"\n",
"fovea_net = load_model(find_file('fovea'))\n",
"periph_net = load_model(find_file('periph'))\n",
"renderer = create_renderer(fovea_net, periph_net, periph_net)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Change working directory to /home/dengnc/dvs/data/__new/barbershop_all\n",
"Load net from fovea200@snerffast4-rgb_e6_fc512x4_d1.20-6.00_s64_~p.pth ...\n",
"Load net from periph200@snerffast2-rgb_e6_fc256x4_d1.20-6.00_s32_~p.pth ...\n",
"barbershop 0 Saved\n",
"barbershop 1 Saved\n",
"barbershop 2 Saved\n",
"barbershop 3 Saved\n",
"barbershop 4 Saved\n",
"barbershop 0 Saved\n",
"barbershop 1 Saved\n",
"barbershop 2 Saved\n",
"barbershop 3 Saved\n",
"barbershop 4 Saved\n",
"barbershop 0 Saved\n",
"barbershop 1 Saved\n",
"barbershop 2 Saved\n",
......@@ -180,35 +180,21 @@
" [(0, 0, 0, 0, 0), (21, 150), (12, 150)]\n",
" ],\n",
" 'barbershop': [\n",
" #[(0, 0, 0, 0, 0), (106, -67), (90, -67)],\n",
" #[(-0.08247789757018531, -0.17165164843400083, -0.20644832805536045, -66.00281344384992, -4.354400833888114), (0, 0), (0, 0)],\n",
" #[(0, 0, 0, 0, 0), (-114, 10), (-126, 10)],\n",
" [(0, 0, 0, 25, 20), (189, -45), (173, -45)],\n",
" [(0, 0, 0, 25, 20), (-148, 130), (-163, 130)],\n",
" [(0.15, 0.15, 0, 43, 2), (9, 0), (-9, 0)],\n",
" [(0.15, 0, 0.15, -13, -5), (6, 0), (-6, 0)],\n",
" [(-0.15, 0.15, 0.15, -53, -21), (3, 0), (-3, 0)]\n",
" [(0, 0, 0, -25, 20), (189, -45), (173, -45)],\n",
" [(0, 0, 0, -25, 20), (-148, 130), (-163, 130)],\n",
" [(0.15, 0.15, 0, -43, 2), (9, 0), (-9, 0)],\n",
" [(0.15, 0, -0.15, 13, -5), (6, 0), (-6, 0)],\n",
" [(-0.15, 0.15, -0.15, 53, -21), (3, 0), (-3, 0)]\n",
" ]\n",
"}\n",
"\n",
"#for scene in ['classroom', 'lobby', 'barbershop']:\n",
"for scene in ['barbershop']:\n",
" os.chdir(f'{rootdir}/data/__new/{scenes[scene]}')\n",
" print('Change working directory to ', os.getcwd())\n",
"\n",
" fovea_net = load_net(find_file('fovea'))\n",
" periph_net = load_net(find_file('periph'))\n",
" renderer = FoveatedNeuralRenderer(fov_list, res_list,\n",
" nn.ModuleList([fovea_net, periph_net, periph_net]),\n",
" res_full, device=device.default())\n",
"\n",
" for mono_periph in range(0,4):\n",
"for mono_periph in range(3, 5):\n",
" for i, param in enumerate(params[scene]):\n",
" view = Trans(torch.tensor(param[0][:3], device=device.default()),\n",
" torch.tensor(euler_to_matrix([-param[0][4], param[0][3], 0]),\n",
" torch.tensor(euler_to_matrix(param[0][4], param[0][3], 0),\n",
" device=device.default()).view(3, 3))\n",
" eye_offset = torch.tensor([0.03, 0, 0], device=device.default())\n",
" left_view = Trans(view.trans_point(-eye_offset), view.r)\n",
" right_view = Trans(view.trans_point(eye_offset), view.r)\n",
" left_images, right_images = renderer(view, param[1], param[2],\n",
" stereo_disparity=0.06,\n",
" using_mask=True,\n",
......@@ -245,11 +231,8 @@
}
],
"metadata": {
"interpreter": {
"hash": "82066b63b621a9e3d15e3b7c11ca76da6238eff3834294910d715044bd0561e5"
},
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3.10.0 ('dvs')",
"language": "python",
"name": "python3"
},
......@@ -263,7 +246,12 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
"version": "3.10.0"
},
"vscode": {
"interpreter": {
"hash": "4469b029896260c1221afa6e0e6159922aafd2738570e75b7bc15e28db242604"
}
}
},
"nbformat": 4,
......
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import sys\n",
"import os\n",
"import torch\n",
"import matplotlib.pyplot as plt\n",
"\n",
"rootdir = os.path.abspath(sys.path[0] + '/../../')\n",
"sys.path.append(rootdir)\n",
"\n",
"torch.autograd.set_grad_enabled(False)\n",
"\n",
"from model import Model\n",
"from data import Dataset\n",
"from utils import netio, img, device\n",
"from utils.view import *\n",
"from utils.types import *\n",
"from components.render import render\n",
"\n",
"\n",
"model: Model = None\n",
"dataset: Dataset = None\n",
"\n",
"\n",
"def load_model(path: PathLike):\n",
" ckpt_path = netio.find_checkpoint(Path(path))\n",
" ckpt = torch.load(ckpt_path)\n",
" model = Model.create(ckpt[\"args\"][\"model\"], ckpt[\"args\"][\"model_args\"])\n",
" model.load_state_dict(ckpt[\"states\"][\"model\"])\n",
" model.to(device.default()).eval()\n",
" return model\n",
"\n",
"\n",
"def load_dataset(path: PathLike):\n",
" return Dataset(path, color_mode=model.color, coord_sys=model.args.coord,\n",
" device=device.default())\n",
"\n",
"\n",
"def plot_images(images, rows, cols):\n",
" plt.figure(figsize=(20, int(20 / cols * rows)))\n",
" for r in range(rows):\n",
" for c in range(cols):\n",
" plt.subplot(rows, cols, r * cols + c + 1)\n",
" img.plot(images[r * cols + c])\n",
"\n",
"\n",
"def save_images(images, scene, i):\n",
" outputdir = f'{rootdir}/data/__demo/layers/'\n",
" os.makedirs(outputdir, exist_ok=True)\n",
" for layer in range(len(images)):\n",
" img.save(images[layer], f'{outputdir}{scene}_{i:04d}({layer}).png')\n",
"\n",
"scene = \"gas\"\n",
"model_path = f\"{rootdir}/data/__thesis/{scene}/_nets/train/snerf_fast\"\n",
"dataset_path = f\"{rootdir}/data/__thesis/{scene}/test.json\"\n",
"\n",
"\n",
"model = load_model(model_path)\n",
"dataset = load_dataset(dataset_path)\n",
"\n",
"\n",
"i = 6\n",
"cam = dataset.cam\n",
"view = Trans(dataset.centers[i], dataset.rots[i])\n",
"output = render(model, dataset.cam, view, \"colors\", \"weights\")\n",
"output_colors = output.colors * output.weights\n",
"\n",
"samples_per_layer = 4#model.core.samples_per_field\n",
"n_samples = model.args.n_samples\n",
"output_layers = [\n",
" output_colors[..., offset:offset+samples_per_layer, :].sum(-2)\n",
" for offset in range(0, n_samples, samples_per_layer)\n",
"]\n",
" \n",
"plot_images(output_layers, 8, 2)\n",
"#save_images(output_layers, scene, i)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.10.0 ('dvs')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0"
},
"metadata": {
"interpreter": {
"hash": "82066b63b621a9e3d15e3b7c11ca76da6238eff3834294910d715044bd0561e5"
}
},
"vscode": {
"interpreter": {
"hash": "4469b029896260c1221afa6e0e6159922aafd2738570e75b7bc15e28db242604"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}
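
The last cell of the notebook above groups the weighted per-sample colors into layers of `samples_per_layer` consecutive samples and sums each group, mirroring how the FS-NeRF core assigns a contiguous block of samples to each field. Below is a compact stand-alone sketch of that grouping; the function name is hypothetical, and it assumes colors of shape `(..., n_samples, 3)` and weights of shape `(..., n_samples, 1)`.

import torch

def split_into_layers(colors: torch.Tensor, weights: torch.Tensor,
                      samples_per_layer: int) -> list[torch.Tensor]:
    # Hedged sketch: per-sample contributions summed within each layer of consecutive samples.
    weighted = colors * weights                                              # (..., n_samples, 3)
    n_layers = weighted.shape[-2] // samples_per_layer
    layers = weighted.unflatten(-2, (n_layers, samples_per_layer)).sum(-2)   # (..., n_layers, 3)
    return list(layers.unbind(-2))                                           # n_layers tensors of shape (..., 3)
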
......@@ -60,7 +60,7 @@
" enable_ce = True, output_res = None):\n",
" ipd = 0.06\n",
" layers_cam = [\n",
" CameraParam({\n",
" Camera({\n",
" 'fov': 110,\n",
" 'cx': 0.5,\n",
" 'cy': 0.5,\n",
......
......@@ -31,7 +31,7 @@
"from utils import device\n",
"from utils import view\n",
"from components.gen_final import GenFinal\n",
"from utils.perf import Perf\n",
"from utils.profile import Profiler\n",
"\n",
"\n",
"def load_net(path):\n",
......@@ -135,15 +135,15 @@
" torch.tensor([[0.0, 0.0, 0.0]], device=device.default()),\n",
" torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]], device=device.default())\n",
")\n",
"perf = Perf(True, True)\n",
"profile = Profiler(True, True)\n",
"rays_o, rays_d = gen.layer_cams[0].get_global_rays(test_view, True)\n",
"perf.checkpoint(\"GetRays\")\n",
"profile.checkpoint(\"GetRays\")\n",
"rays_o = rays_o.view(-1, 3)\n",
"rays_d = rays_d.view(-1, 3)\n",
"coords, pts, depths = fovea_net.sampler(rays_o, rays_d)\n",
"perf.checkpoint(\"Sample\")\n",
"profile.checkpoint(\"Sample\")\n",
"encoded = fovea_net.input_encoder(coords)\n",
"perf.checkpoint(\"Encode\")\n",
"profile.checkpoint(\"Encode\")\n",
"print(\"Rays:\", rays_d)\n",
"print(\"Spherical coords:\", coords)\n",
"print(\"Depths:\", depths)\n",
......