Commit 1bc644a1 authored by Nianchen Deng

sync

parent 6294701e
@@ -2,7 +2,6 @@ from .__common__ import *
from .nerf import NeRF
from .utils import load
from utils.misc import merge
class SNeRFX(NeRF):
@@ -23,12 +22,12 @@ class SNeRFX(NeRF):
self.args0['net_samples'] = [val * 2 for val in self.args0['net_samples']]
return ret
@perf
@profile
def _render(self, samples: Samples, *outputs: str, **extra_args) -> ReturnData:
return super()._render(samples, *outputs,
**merge(extra_args,
raymarching_chunk_size_or_sections=self.args["net_samples"]))
**extra_args |
{"raymarching_chunk_size_or_sections": self.args["net_samples"]})
@perf
@profile
def _multi_infer(self, inputs: NetInput, *outputs: str, chunk_id: int, **kwargs) -> NetOutput:
return self.cores[chunk_id](inputs, *outputs, **kwargs)
from pathlib import Path
from typing import Optional, Union
from .base import model_classes, BaseModel
from utils import netio
def get_class(model_class_name: str) -> Optional[type]:
def get_class(model_class_name: str) -> type | None:
return model_classes.get(model_class_name)
@@ -31,5 +30,5 @@ def serialize(model: BaseModel) -> dict:
}
def load(path: Union[str, Path]) -> BaseModel:
def load(path: str | Path) -> BaseModel:
return deserialize(netio.load_checkpoint(path)[0])
@@ -14,7 +14,7 @@ class VNeRF(NeRF):
def _create_core_unit(self):
return super()._create_core_unit(x_chns=self.x_encoder.out_dim + self.args['n_featdim'])
def _input(self, samples: Samples, what: str) -> Optional[torch.Tensor]:
def _input(self, samples: Samples, what: str) -> torch.Tensor | None:
if what == "x":
return torch.cat([
self.space.extract_voxel_embedding(samples.voxel_indices),
......
from utils.misc import dump_tensors_to_csv
from .__common__ import *
from .base import BaseModel
from typing import Callable
from .nerf import NeRF
from utils.voxels import linear_interp
class CNeRF(BaseModel):
TrainerClass = "TrainMultiScale"
class InterpSpace(object):
def __init__(self, space: Voxels, vidxs: Tensor, feats_fn: Callable[[Any], Tensor]) -> None:
super().__init__()
self.space = space
self.corner_indices, self.corners = space.get_corners(vidxs)
self.feats_on_corners = feats_fn(self.corners)
@profile
def interp(self, samples: Samples) -> Tensor:
with profile("Prepare for coarse interpolation"):
voxels = self.space.voxels[samples.interp_vidxs]
cidxs = self.corner_indices[samples.interp_vidxs] # (N, 8)
feats_on_corners = self.feats_on_corners[cidxs] # (N, 8, X)
# (N, 3) normed-coords in voxel
p = (samples.pts - voxels) / self.space.voxel_size + .5
with profile("Interpolate features"):
return linear_interp(p, feats_on_corners)
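# Illustrative sketch of the trilinear interpolation performed by linear_interp
# above (the corner ordering here is an assumption): each of the 8 corner
# features is weighted by the product of per-axis proximities of the normalized
# in-voxel coordinate p.
import torch

def trilinear_interp_demo(p: torch.Tensor, corner_feats: torch.Tensor) -> torch.Tensor:
    # p: (N, 3) in [0, 1]; corner_feats: (N, 8, X)
    corners = torch.tensor([[i, j, k] for i in (0, 1) for j in (0, 1) for k in (0, 1)],
                           dtype=p.dtype)  # (8, 3)
    weights = (p[:, None] * corners + (1 - p[:, None]) * (1 - corners)).prod(-1)  # (N, 8)
    return (weights[..., None] * corner_feats).sum(-2)  # (N, X)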
@property
def stage(self):
return self.args.get("stage", 0)
def __init__(self, args0: dict, args1: dict = None):
super().__init__(args0, args1)
self.sub_models = []
args0_for_submodel = {
key: value for key, value in args0.items()
if key != "sub_models" and key != "interp_on_coarse"
}
for i in range(len(self.args["sub_models"])):
self.args["sub_models"][i] = {
**args0_for_submodel,
**self.args["sub_models"][i]
}
self.sub_models.append(NeRF(self.args["sub_models"][i], args1))
self.sub_models = torch.nn.ModuleList(self.sub_models)
for i in range(self.stage):
print(f"__init__: freeze model {i}")
self.model(i).freeze()
def model(self, level: int) -> NeRF:
return self.sub_models[level]
def trigger_stage(self, stage: int):
print(f"trigger_stage: freeze model {stage - 1}")
self.model(stage - 1).freeze()
self.model(stage).space = self.model(stage - 1).space.clone()
self.args0["stage"] = stage
@profile
def infer(self, *outputs: str, samples: Samples, inputs: NetInput = None, **kwargs) -> NetOutput:
inputs = inputs or self.input(samples)
return self.model(samples.level).infer(*outputs, samples=samples, inputs=inputs, **kwargs)
def print_config(self):
s = f"{len(self.sub_models)} levels:\n"
for i, model in enumerate(self.sub_models):
s += f"Model {i}: {model.print_config()}\n"
return s
@torch.no_grad()
def split(self):
return self.model(self.stage).split()
def _input(self, samples: Samples, what: str) -> Tensor | None:
if what == "f":
if samples.level == 0:
return None
if samples.interp_space is None:
return self._infer_features(pts=samples.pts, level=samples.level - 1)
return samples.interp_space.interp(samples)
else:
return self.model(samples.level)._input(samples, what)
@profile
def _sample(self, data: InputData, **extra_args) -> Samples:
samples: Samples = self.model(data["level"])._sample(data, **extra_args)
samples.level = data["level"]
# TODO remove below
#dump_tensors_to_csv(f"/home/dengnc/dvs/data/classroom/_nets/ms_train_t0.8/_cnerf_ioc/{'train' if self.training else 'test'}.csv",
# samples.voxel_indices, data["rays_d"])
return samples
@profile
def _render(self, samples: Samples, *outputs: str, **extra_args) -> ReturnData:
self._prepare_interp(samples, on_coarse=self.args.get("interp_on_coarse"))
return self.model(samples.level).renderer(self, samples, *outputs, **{
**self.model(samples.level).args, **extra_args})
def _infer_features(self, samples: Samples = None, **sample_data) -> NetOutput:
samples = samples or Samples(**sample_data)
if self.args.get("interp_on_coarse"):
self._prepare_interp(samples, on_coarse=True)
inputs = self.input(samples, "x", "f")
return self.infer("features", samples=samples, inputs=inputs)["features"]
def _prepare_interp(self, samples: Samples, on_coarse: bool):
if samples.level == 0:
return
if on_coarse:
interp_space = self.model(samples.level - 1).space
samples.interp_vidxs = interp_space.get_voxel_indices(samples.pts)
else:
interp_space = self.model(samples.level).space
samples.interp_vidxs = samples.voxel_indices
samples.interp_space = CNeRF.InterpSpace(interp_space, samples.interp_vidxs,
lambda corners: self._infer_features(
pts=corners, level=samples.level - 1))
def _after_load_state_dict(self) -> None:
return
# Unreachable debug leftovers:
# print(list(self.model(0).named_parameters())[2])
# exit()
import torch
from .__common__ import *
from .nerf import NeRF
class MNeRF(NeRF):
"""
Multi-scale NeRF
"""
TrainerClass = "TrainMultiScale"
def freeze(self, level: int):
for core in self.cores:
core.set_frozen(level, True)
def _create_core_unit(self):
nets = []
in_chns = self.x_encoder.out_dim
for core_params in self.args['core_params']:
nets.append(super()._create_core_unit(core_params, x_chns=in_chns))
in_chns = self.x_encoder.out_dim + core_params['nf']
return MultiNerf(nets)
@profile
def _sample(self, data: InputData, **extra_args) -> Samples:
samples = super()._sample(data, **extra_args)
samples.level = data["level"]
return samples
from .__common__ import *
from .mnerf import MNeRF
class MNeRFAdvance(MNeRF):
"""
Advanced Multi-scale NeRF
"""
TrainerClass = "TrainMultiScale"
def n_samples(self, level: int = -1) -> int:
return self.args["n_samples_list"][level]
def split(self):
self.args0["n_samples_list"] = [val * 2 for val in self.args["n_samples_list"]]
return super().split()
def _sample(self, data: InputData, **extra_args) -> Samples:
return super()._sample(data, **(extra_args | {"n_samples": self.n_samples(data["level"]) + 1}))
def _render(self, samples: Samples, *outputs: str, **extra_args) -> ReturnData:
L = samples.level
steps = self.args["n_samples_list"][L] // self.args["n_samples_list"][0]
curr_samples = samples[:, ::steps]
curr_samples.level = 0
curr_samples.features = None
for i in range(L):
render_out = super()._render(curr_samples, 'features', **extra_args)
next_steps = self.args["n_samples_list"][L] // self.args["n_samples_list"][i + 1]
next_samples = samples[:, ::next_steps]
features = curr_samples.interpolate(next_samples, render_out['features'])
curr_samples = next_samples
curr_samples.level = i + 1
curr_samples.features = features
return super()._render(curr_samples, *outputs, **extra_args)
from .__common__ import *
from .nerf import NeRF
from .utils import load
class SNeRFX(NeRF):
def _preprocess_args(self):
self.args0["spherical"] = True
super()._preprocess_args()
if "net_samples" not in self.args:
n_nets = self.args.get("multi_nets", 1)
cut_by_space = load(self.args['cut_by']).space if "cut_by" in self.args else self.space
k = self.args["n_samples"] // cut_by_space.steps[0].item()
self.args0["net_samples"] = [val * k for val in cut_by_space.balance_cut(0, n_nets)]
self.args1["multi_nets"] = len(self.args["net_samples"])
@torch.no_grad()
def split(self):
ret = super().split()
self.args0['net_samples'] = [val * 2 for val in self.args0['net_samples']]
return ret
@profile
def _render(self, samples: Samples, *outputs: str, **extra_args) -> ReturnData:
return super()._render(samples, *outputs,
**extra_args |
{"raymarching_chunk_size_or_sections": self.args["net_samples"]})
@profile
def _multi_infer(self, inputs: NetInput, *outputs: str, chunk_id: int, **kwargs) -> NetOutput:
return self.cores[chunk_id](inputs, *outputs, **kwargs)
from .__common__ import *
from .nerf import NeRF
class VNeRF(NeRF):
def _init_chns(self):
super()._init_chns(x=3)
def _init_space(self):
self.space = Space.create(self.args)
self.space.create_voxel_embedding(self.args['n_featdim'])
def _create_core_unit(self):
return super()._create_core_unit(x_chns=self.x_encoder.out_dim + self.args['n_featdim'])
def _input(self, samples: Samples, what: str) -> torch.Tensor | None:
if what == "x":
return torch.cat([
self.space.extract_voxel_embedding(samples.voxel_indices),
self._encode("x", samples.pts)
], dim=-1)
else:
return super()._input(samples, what)
from .__common__ import *
from .model import Model
class FsNeRF(Model):
class Args(Model.Args):
n_samples: int = 64
perturb_sampling: bool = False
with_radius: bool = False
n_fields: int = 1
depth: int = 8
width: int = 256
skips: list[int] = [4]
act: str = "relu"
ln: bool = False
xfreqs: int = 6
raw_noise_std: float = 0.
near: float = 1.
far: float = 10.
white_bg: bool = False
args: Args
def __init__(self, args: Args):
"""
Initialize a FS-NeRF model
:param args `Args`: arguments
"""
super().__init__(args)
# Initialize components
self._init_sampler()
self._init_encoders()
self._init_core()
self._init_renderer()
@profile
def forward(self, rays: Rays, *outputs: str, **args) -> ReturnData:
samples = self.sample(rays, **args)
x = self.encode(samples)
rgbd = self.infer(x)
return self.render(samples, rgbd, *outputs, **args)
def sample(self, rays: Rays, **kwargs) -> Samples:
args = self.args.merge_with(kwargs)
return self.sampler(rays, None, range=(args.near, args.far), mode="spherical_radius",
n_samples=args.n_samples,
perturb=args.perturb_sampling if self.training else False)
def encode(self, samples: Samples) -> torch.Tensor:
return self.x_encoder(samples.pts[..., -self.x_encoder.in_chns:])
def infer(self, x: torch.Tensor) -> torch.Tensor:
return self.core(x)
def render(self, samples: Samples, rgbd: torch.Tensor, *outputs: str, **kwargs) -> ReturnData:
args = self.args.merge_with(kwargs)
return self.renderer(samples, rgbd, *outputs, white_bg=args.white_bg,
raw_noise_std=args.raw_noise_std if self.training else 0.)
def _init_encoders(self):
self.x_encoder = FreqEncoder(self.sampler.out_chns["x"] - (not self.args.with_radius),
self.args.xfreqs, False)
def _init_core(self):
self.core = core.FsNeRF(self.x_encoder.out_chns, self.color.chns,
self.args.depth, self.args.width, self.args.skips,
self.args.act, self.args.ln, self.args.n_samples, self.args.n_fields)
def _init_sampler(self):
self.sampler = UniformSampler()
def _init_renderer(self):
self.renderer = VolumnRenderer()
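# Hypothetical usage of the model defined above; the Args field values are
# placeholders, and `rays` is assumed to be a Rays batch from the data loader.
def _demo_fsnerf_forward(rays: Rays) -> ReturnData:
    model = FsNeRF(FsNeRF.Args(n_samples=32))
    return model(rays, "color", "depth")  # sample -> encode -> infer -> render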
from operator import itemgetter
from .__common__ import *
from utils import netio
from utils.args import BaseArgs
model_classes: dict[str, "Model"] = {}
class Model(nn.Module):
class Args(BaseArgs):
color: str = "rgb"
coord: str = "gl"
args: Args
color: Color
def __init__(self, args: Args):
super().__init__()
self.args = args
self.color = Color[self.args.color]
# stub method
def __call__(self, rays: Rays, *outputs: str, **args) -> ReturnData:
...
def forward(self, rays: Rays, *outputs: str, **args) -> ReturnData:
raise NotImplementedError()
@staticmethod
def get_class(typename: str) -> Type["Model"] | None:
return model_classes.get(typename)
@staticmethod
def create(typename: str, args: dict|Args) -> "Model":
ModelCls = Model.get_class(typename)
if ModelCls is None:
raise ValueError(f"Model {typename} is not found")
if isinstance(args, dict):
args = ModelCls.Args(**args)
return ModelCls(args)
@staticmethod
def load(path: PathLike) -> "Model":
ckpt = netio.load_checkpoint(Path(path))[0]
model_type, model_args = itemgetter("model", "model_args")(ckpt["args"])
model = Model.create(model_type, model_args)
model.load_state_dict(ckpt["states"]["model"])
return model
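# Hypothetical usage of the factory/loader above; the registered type name and
# the checkpoint path are placeholders, not part of this commit.
def _demo_model_roundtrip():
    model = Model.create("NeRF", {"color": "rgb"})  # Args built from a plain dict
    restored = Model.load("ckpt/model.tar")         # type + args + states from checkpoint
    return model, restored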
from .__common__ import *
from .base import BaseModel
from operator import itemgetter
from utils import math
from utils.misc import masked_scatter, merge
class NeRF(BaseModel):
TrainerClass = "TrainWithSpace"
SamplerClass = None
RendererClass = None
space: Union[Space, Voxels, Octree]
@property
def multi_nets(self) -> int:
return self.args.get("multi_nets", 1)
def __init__(self, args0: dict, args1: dict = None):
from .model import Model
class NeRF(Model):
class Args(Model.Args):
n_samples: int = 64
sample_mode: str = "xyz"
perturb_sampling: bool = False
depth: int = 8
width: int = 256
skips: list[int] = [4]
act: str = "relu"
ln: bool = False
color_decoder: str = "NeRF"
n_importance: int = 0
fine_depth: int = 8
fine_width: int = 256
fine_skips: list[int] = [4]
xfreqs: int = 10
dfreqs: int = 4
raw_noise_std: float = 0.
near: float = 1.
far: float = 10.
white_bg: bool = False
args: Args
def __init__(self, args: Args):
"""
Initialize a NeRF model
:param args0 `dict`: basic arguments
:param args1 `dict`: extra arguments, defaults to {}
:param args `dict`: arguments
"""
super().__init__(args0, args1)
super().__init__(args)
if args.sample_mode == "xyz" or args.sample_mode == "xyz_disp":
args.near = 0.1
# Initialize components
self._init_space()
self._init_sampler()
self._init_encoders()
self._init_core()
self._init_sampler()
self._init_renderer()
@perf
def infer(self, *outputs: str, samples: Samples, inputs: NetInput = None, **kwargs) -> NetOutput:
inputs = inputs or self.input(samples)
if len(self.cores) == 1:
return self.cores[0](inputs, *outputs, samples=samples, **kwargs)
return self._multi_infer(inputs, *outputs, samples=samples, **kwargs)
@torch.no_grad()
def split(self):
ret = self.space.split()
if 'n_samples' in self.args0:
self.args0['n_samples'] *= 2
if 'voxel_size' in self.args0:
self.args0['voxel_size'] /= 2
if "sample_step_ratio" in self.args0:
self.args1["sample_step"] = self.args0["voxel_size"] \
* self.args0["sample_step_ratio"]
if 'sample_step' in self.args0:
self.args0['sample_step'] /= 2
return ret
def _preprocess_args(self):
if "sample_step_ratio" in self.args0:
self.args1["sample_step"] = self.args0["voxel_size"] * self.args0["sample_step_ratio"]
if self.args0.get("spherical"):
sample_range = [1 / self.args0['depth_range'][0], 1 / self.args0['depth_range'][1]] \
if self.args0.get('depth_range') else [1, 0]
rot_range = [[-180, -90], [180, 90]]
self.args1['bbox'] = [
[sample_range[0], math.radians(rot_range[0][0]), math.radians(rot_range[0][1])],
[sample_range[1], math.radians(rot_range[1][0]), math.radians(rot_range[1][1])]
]
self.args1['sample_range'] = sample_range
if not self.args.get("net_bounds"):
self.register_temp("net_bounds", None)
self.args1["multi_nets"] = 1
else:
self.register_temp("net_bounds", torch.tensor(self.args["net_bounds"]))
self.args1["multi_nets"] = self.net_bounds.size(0)
def _init_chns(self, **chns):
super()._init_chns(**{
"x": self.args.get('n_featdim') or 3,
"d": 3 if self.args.get('encode_d') else 0,
**chns
})
def _init_space(self):
self.space = Space.create(self.args)
if self.args.get('n_featdim'):
self.space.create_embedding(self.args['n_featdim'])
if self.args.n_importance > 0:
self._init_cascade()
@profile
def forward(self, rays: Rays, *outputs: str, **args) -> ReturnData:
samples = self.sample(rays, **args)
x, d = self.encode(samples)
rgbd = self.infer(x, d)
return self.render(rays, samples, rgbd, *outputs, cascade=True, **args)
def sample(self, rays: Rays, **kwargs) -> Samples:
args = self.args.merge_with(kwargs)
return self.sampler(rays, None, range=(args.near, args.far),
mode=args.sample_mode, n_samples=args.n_samples,
perturb=args.perturb_sampling if self.training else False)
def encode(self, samples: Samples) -> tuple[torch.Tensor, torch.Tensor]:
return self.x_encoder(samples.pts), self.d_encoder(math.normalize(samples.dirs))
def infer(self, x: torch.Tensor, d: torch.Tensor, *, fine: bool = False) -> torch.Tensor:
if self.args.n_importance > 0 and fine:
return self.fine_core(x, d)
return self.core(x, d)
def render(self, rays: Rays, samples: Samples, rgbd: torch.Tensor, *outputs: str,
cascade: bool = False, **kwargs) -> ReturnData:
args = self.args.merge_with(kwargs)
if args.n_importance > 0 and cascade:
coarse_outputs = [item[7:] for item in outputs if item.startswith("coarse_")]
coarse_ret = self.renderer(samples, rgbd, "weights", *coarse_outputs,
white_bg=args.white_bg,
raw_noise_std=args.raw_noise_std if self.training else 0.)
samples = self.pdf_sampler(rays, None, samples.t_vals, coarse_ret["weights"][..., 0],
mode=args.sample_mode,
n_importance=args.n_importance,
perturb=args.perturb_sampling if self.training else False,
include_existed_samples=True)
x, d = self.encode(samples)
fine_rgbd = self.infer(x, d, fine=True)
return self.renderer(samples, fine_rgbd, *outputs, white_bg=args.white_bg,
raw_noise_std=args.raw_noise_std if self.training else 0.) | {
f"coarse_{key}": coarse_ret[key]
for key in coarse_outputs
if key in coarse_ret
}
return self.renderer(samples, rgbd, *outputs, white_bg=args.white_bg,
raw_noise_std=args.raw_noise_std if self.training else 0.)
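# Minimal sketch of the coarse-to-fine cascade above (shapes simplified, names
# hypothetical): coarse weights define a piecewise-constant PDF over the ray,
# and n_importance extra samples are drawn by inverting its CDF.
import torch

def sample_pdf_demo(t_vals: torch.Tensor, weights: torch.Tensor, n_importance: int) -> torch.Tensor:
    # t_vals: (N, P) sample positions; weights: (N, P) from the coarse pass
    pdf = weights / weights.sum(-1, keepdim=True).clamp_min(1e-8)
    cdf = torch.cumsum(pdf, -1)
    u = torch.rand(weights.shape[0], n_importance)                    # uniform draws
    idx = torch.searchsorted(cdf, u).clamp_max(t_vals.shape[-1] - 1)  # invert the CDF
    return torch.gather(t_vals, -1, idx)                              # new sample positions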
def _init_encoders(self):
self.x_encoder = InputEncoder(self.chns("x"), self.args['encode_x'], cat_input=True)
self.d_encoder = InputEncoder(self.chns("d"), self.args['encode_d'])\
if self.chns("d") > 0 else None
self.x_encoder = FreqEncoder(self.sampler.out_chns["x"], self.args.xfreqs, True)
self.d_encoder = FreqEncoder(self.sampler.out_chns["d"], self.args.dfreqs, True)
def _init_core(self):
self.cores = self.create_multiple(self._create_core_unit, self.args.get("multi_nets", 1))
self.core = core.NeRF(self.x_encoder.out_chns, self.d_encoder.out_chns, self.color.chns,
self.args.depth, self.args.width, self.args.skips,
self.args.act, self.args.ln, self.args.color_decoder)
def _init_sampler(self):
if self.SamplerClass is None:
SamplerClass = Sampler
else:
SamplerClass = self.SamplerClass
self.sampler = SamplerClass(**self.args)
self.sampler = UniformSampler()
def _init_cascade(self):
self.pdf_sampler = PdfSampler()
self.fine_core = core.NeRF(self.x_encoder.out_chns, self.d_encoder.out_chns, self.color.chns,
self.args.fine_depth, self.args.fine_width,
self.args.fine_skips, self.args.act, self.args.ln,
self.args.color_decoder)
def _init_renderer(self):
if self.RendererClass is None:
if self.args.get("core") == "nerfadv":
RendererClass = DensityFirstVolumnRenderer
else:
RendererClass = VolumnRenderer
else:
RendererClass = self.RendererClass
self.renderer = RendererClass(**self.args)
def _create_core_unit(self, core_params: dict = None, **args):
core_params = core_params or self.args["core_params"]
if self.args.get("core") == "nerfadv":
return NerfAdvCore(**{
"x_chns": self.x_encoder.out_dim,
"d_chns": self.d_encoder.out_dim,
"density_chns": self.chns('density'),
"color_chns": self.chns('color'),
**core_params,
**args
})
else:
return NerfCore(**{
"x_chns": self.x_encoder.out_dim,
"density_chns": self.chns('density'),
"color_chns": self.chns('color'),
"d_chns": self.d_encoder.out_dim if self.d_encoder else 0,
**core_params,
**args
})
@perf
def _sample(self, data: InputData, **extra_args) -> Samples:
return self.sampler(*itemgetter("rays_o", "rays_d")(data), self.space,
**merge(self.args, extra_args))
@perf
def _render(self, samples: Samples, *outputs: str, **extra_args) -> ReturnData:
if len(samples.size) == 1:
return self.infer(*outputs, samples=samples)
return self.renderer(self, samples, *outputs, **merge(self.args, extra_args))
def _input(self, samples: Samples, what: str) -> Optional[torch.Tensor]:
if what == "x":
if self.args.get('n_featdim'):
return self._encode("emb", self.space.extract_embedding(
samples.pts, samples.voxel_indices))
else:
return self._encode("x", samples.pts)
elif what == "d":
if self.d_encoder and samples.dirs is not None:
return self._encode("d", samples.dirs)
else:
return None
elif what == "f":
return None
else:
ValueError(f"Don't know how to process input \"{what}\"")
def _encode(self, what: str, val: torch.Tensor) -> torch.Tensor:
if what == "x":
if self.args.get("spherical"):
sr = self.args['sample_range']
# scale val.r: [sr[0], sr[1]] -> [-PI/2, PI/2]
val = val.clone()
val[..., 0] = ((val[..., 0] - sr[0]) / (sr[1] - sr[0]) - .5) * math.pi
return self.x_encoder(val)
else:
return self.x_encoder(val * math.pi)
elif what == "emb":
return self.x_encoder(val * math.pi)
elif what == "d":
return self.d_encoder(val, angular=True)
else:
ValueError(f"Don't know how to encode \"{what}\"")
@perf
def _multi_infer(self, inputs: NetInput, *outputs: str, samples: Samples, **kwargs) -> NetOutput:
ret: NetOutput = {}
for i, core in enumerate(self.cores):
selector = ((samples.pts >= self.net_bounds[i, 0])
& (samples.pts < self.net_bounds[i, 1])).all(-1)
partial_ret: NetOutput = core(inputs[selector], *outputs, samples=samples[selector],
**kwargs)
for key, value in partial_ret.items():
if key not in ret:
ret[key] = value.new_zeros(*inputs.shape, value.shape[-1])
ret[key] = masked_scatter(selector, value, ret[key])
return ret
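# Sketch of the scatter-back pattern used in _multi_infer above (shapes
# simplified): each sub-network only sees its selected samples, and its outputs
# are written back into a full-size buffer at the selected positions.
import torch

def _demo_masked_scatter():
    selector = torch.tensor([True, False, True, False])
    values = torch.tensor([[1.0], [2.0]])  # outputs for the two selected samples
    out = torch.zeros(4, 1)
    out[selector] = values                 # same effect as masked_scatter here
    return out                             # [[1.], [0.], [2.], [0.]]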
class NSVF(NeRF):
SamplerClass = VoxelSampler
self.renderer = VolumnRenderer()
import sys
from utils import math, nn
from utils.types import *
from utils.misc import union, split
from utils.profile import profile
from .sampler import Sampler, PdfSampler, VoxelSampler
from .input_encoder import InputEncoder, IntegratedPosEncoder
from .renderer import VolumnRenderer, DensityFirstVolumnRenderer
from .space import Space, Voxels, Octree
from .core import NerfCore, NerfAdvCore, MultiNerf
\ No newline at end of file
from .core import *
from .sampler import *
from .input_encoder import *
from .renderer import *
from .space import *
import re
import torch
from typing import Iterable, Tuple
from .generic import *
from utils.misc import union, split
from utils.types import NetInput, NetOutput
from utils.module import Module
from utils.samples import Samples
class NerfCore(Module):
class NeRF(Module):
def __init__(self, *, x_chns, density_chns, color_chns, nf, n_layers,
d_chns=0, d_nf=0, act='relu', skips=[], with_layer_norm=False,
density_out_act='relu', color_out_act='sigmoid', f_chns=0):
density_out_act='relu', color_out_act='sigmoid', feature_layer=False):
super().__init__()
self.input_f = f_chns > 0
self.core_field = FcBlock(in_chns=x_chns + f_chns, out_chns=None, nf=nf, n_layers=n_layers,
self.x_chns = x_chns
self.d_chns = d_chns
self.field = FcBlock(in_chns=x_chns, out_chns=None, nf=nf, n_layers=n_layers,
skips=skips, act=act, with_ln=with_layer_norm)
self.density_out = FcLayer(nf, density_chns, density_out_act, with_ln=False) \
if density_chns > 0 else None
if color_chns == 0:
self.color_out = None
elif d_chns > 0:
self.feature_layer = feature_layer and FcLayer(nf, nf, with_ln=False)
self.color_out = FcBlock(in_chns=nf + d_chns, out_chns=color_chns,
nf=d_nf or nf // 2, n_layers=1,
act=act, out_act=color_out_act, with_ln=with_layer_norm)
@@ -31,20 +22,30 @@ class NerfCore(Module):
self.color_out = FcLayer(nf, color_chns, color_out_act, with_ln=False)
self.with_dir = False
def forward(self, inputs: NetInput, *outputs: str, features: torch.Tensor = None, **kwargs) -> NetOutput:
ret = {}
if features is None:
features = self.core_field(union(inputs.x, inputs.f) if self.input_f else inputs.x)
if 'features' in outputs:
ret['features'] = features
def forward(self, inputs: NetInput, *outputs: str, field_out: torch.Tensor = None, **kwargs) -> NetOutput:
ret = NetOutput()
if field_out is None:
field_out = self.field(inputs.x)
if 'field_out' in outputs:
ret.field_out = field_out
if 'densities' in outputs and self.density_out:
ret['densities'] = self.density_out(features)
ret.densities = self.density_out(field_out)
if 'colors' in outputs and self.color_out:
if self.with_dir:
features = union(features, inputs.d)
ret['colors'] = self.color_out(features)
if self.feature_layer:
h = self.feature_layer(field_out)
h = union(h, inputs.d)
else:
h = field_out
ret.colors = self.color_out(h)
return ret
def get_exporter(self):
return ModelExporter(self.infer, "densities", "colors", x=[self.x_chns], d=[self.d_chns])
def infer(self, x: torch.Tensor, d: torch.Tensor = None, f: torch.Tensor = None):
return tuple(self._forward(NetInput(x, d, f), "colors", "densities").values())
class NerfAdvCore(Module):
@@ -129,7 +130,7 @@ class NerfAdvCore(Module):
if 'colors' in outputs or 'specluars' in outputs or 'diffuses' in outputs:
if 'densities' in ret:
valid_mask = ret['densities'][..., 0].detach() >= 1e-4
indices: Tuple[torch.Tensor, ...] = valid_mask.nonzero(as_tuple=True)
indices: tuple[torch.Tensor, ...] = valid_mask.nonzero(as_tuple=True)
inputs, features = inputs[indices], features[indices]
else:
indices = None
......
class IntegratedPosEncoder(InputEncoder):
def __init__(self, chns, L, shape: str, cat_input=False):
super().__init__(chns)
self.shape = shape
def _lift_gaussian(self, d: torch.Tensor, t_mean: torch.Tensor, t_var: torch.Tensor,
r_var: torch.Tensor, diag: bool):
"""Lift a Gaussian defined along a ray to 3D coordinates."""
mean = d[..., None, :] * t_mean[..., None]
d_sq = d**2
d_mag_sq = torch.sum(d_sq, -1, keepdim=True).clamp_min(1e-10)
if diag:
d_outer_diag = d_sq
null_outer_diag = 1 - d_outer_diag / d_mag_sq
t_cov_diag = t_var[..., None] * d_outer_diag[..., None, :]
xy_cov_diag = r_var[..., None] * null_outer_diag[..., None, :]
cov_diag = t_cov_diag + xy_cov_diag
return mean, cov_diag
else:
d_outer = d[..., :, None] * d[..., None, :]
eye = torch.eye(d.shape[-1], device=d.device)
null_outer = eye - d[..., :, None] * (d / d_mag_sq)[..., None, :]
t_cov = t_var[..., None, None] * d_outer[..., None, :, :]
xy_cov = r_var[..., None, None] * null_outer[..., None, :, :]
cov = t_cov + xy_cov
return mean, cov
def _conical_frustum_to_gaussian(self, d: torch.Tensor, t0: float, t1: float, base_radius: float,
diag: bool, stable: bool = True):
"""Approximate a conical frustum as a Gaussian distribution (mean+cov).
Assumes the ray is originating from the origin, and base_radius is the
radius at dist=1. Doesn't assume `d` is normalized.
Args:
d: torch.float32 3-vector, the axis of the cone
t0: float, the starting distance of the frustum.
t1: float, the ending distance of the frustum.
base_radius: float, the scale of the radius as a function of distance.
diag: boolean, whether the Gaussian will be diagonal or full-covariance.
stable: boolean, whether or not to use the stable computation described in
the paper (setting this to False will cause catastrophic failure).
Returns:
a Gaussian (mean and covariance).
"""
if stable:
mu = (t0 + t1) / 2
hw = (t1 - t0) / 2
t_mean = mu + (2 * mu * hw**2) / (3 * mu**2 + hw**2)
t_var = (hw**2) / 3 - (4 / 15) * ((hw**4 * (12 * mu**2 - hw**2)) /
(3 * mu**2 + hw**2)**2)
r_var = base_radius**2 * ((mu**2) / 4 + (5 / 12) * hw**2 - 4 / 15 *
(hw**4) / (3 * mu**2 + hw**2))
else:
t_mean = (3 * (t1**4 - t0**4)) / (4 * (t1**3 - t0**3))
r_var = base_radius**2 * (3 / 20 * (t1**5 - t0**5) / (t1**3 - t0**3))
t_mosq = 3 / 5 * (t1**5 - t0**5) / (t1**3 - t0**3)
t_var = t_mosq - t_mean**2
return self._lift_gaussian(d, t_mean, t_var, r_var, diag)
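# Numerical sanity check (illustrative): the "stable" moments above are
# algebraically equivalent to the exact ones; e.g. both give
# t_mean = 3*mu*(mu^2 + hw^2) / (3*mu^2 + hw^2).
def _check_frustum_t_mean(t0: float = 1.0, t1: float = 1.2):
    mu, hw = (t0 + t1) / 2, (t1 - t0) / 2
    t_mean_stable = mu + (2 * mu * hw**2) / (3 * mu**2 + hw**2)
    t_mean_exact = (3 * (t1**4 - t0**4)) / (4 * (t1**3 - t0**3))
    assert abs(t_mean_stable - t_mean_exact) < 1e-9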
def _cylinder_to_gaussian(self, d: torch.Tensor, t0: float, t1: float, radius: float, diag: bool):
"""Approximate a cylinder as a Gaussian distribution (mean+cov).
Assumes the ray is originating from the origin, and radius is the
radius. Does not renormalize `d`.
Args:
d: torch.float32 3-vector, the axis of the cylinder
t0: float, the starting distance of the cylinder.
t1: float, the ending distance of the cylinder.
radius: float, the radius of the cylinder
diag: boolean, whether the Gaussian will be diagonal or full-covariance.
Returns:
a Gaussian (mean and covariance).
"""
t_mean = (t0 + t1) / 2
r_var = radius**2 / 4
t_var = (t1 - t0)**2 / 12
return self._lift_gaussian(d, t_mean, t_var, r_var, diag)
def cast_rays(self, t_vals: torch.Tensor, rays_o: torch.Tensor, rays_d: torch.Tensor,
rays_r: torch.Tensor, diag: bool = True):
"""Cast rays (cone- or cylinder-shaped) and featurize sections of it.
Args:
t_vals: float array, the "fencepost" distances along the ray.
rays_o: float array, the ray origin coordinates.
rays_d: float array, the ray direction vectors.
rays_r: float array, the radii (base radii for cones) of the rays.
diag: boolean, whether or not the covariance matrices should be diagonal.
Returns:
a tuple of arrays of means and covariances.
"""
t0 = t_vals[..., :-1]
t1 = t_vals[..., 1:]
if self.shape == 'cone':
gaussian_fn = self._conical_frustum_to_gaussian
elif self.shape == 'cylinder':
gaussian_fn = self._cylinder_to_gaussian
else:
assert False
means, covs = gaussian_fn(rays_d, t0, t1, rays_r, diag)
means = means + rays_o[..., None, :]
return means, covs
def integrated_pos_enc(x_coord: tuple[torch.Tensor, torch.Tensor], min_deg: int, max_deg: int,
diag: bool = True):
"""Encode `x` with sinusoids scaled by 2^[min_deg:max_deg-1].
Args:
x_coord: a tuple containing: x, torch.ndarray, variables to be encoded. Should
be in [-pi, pi]. x_cov, torch.ndarray, covariance matrices for `x`.
min_deg: int, the min degree of the encoding.
max_deg: int, the max degree of the encoding.
diag: bool, if true, expects input covariances to be diagonal (full
otherwise).
Returns:
encoded: torch.ndarray, encoded variables.
"""
if diag:
x, x_cov_diag = x_coord
scales = torch.tensor([2**i for i in range(min_deg, max_deg)], device=x.device)[:, None]
shape = list(x.shape[:-1]) + [-1]
y = torch.reshape(x[..., None, :] * scales, shape)
y_var = torch.reshape(x_cov_diag[..., None, :] * scales**2, shape)
else:
x, x_cov = x_coord
num_dims = x.shape[-1]
basis = torch.cat([
2**i * torch.eye(num_dims, device=x.device)
for i in range(min_deg, max_deg)
], 1)
y = torch.matmul(x, basis)
# Get the diagonal of a covariance matrix (ie, variance). This is equivalent
# to jax.vmap(torch.diag)((basis.T @ covs) @ basis).
y_var = (torch.matmul(x_cov, basis) * basis).sum(-2)
return math.expected_sin(
torch.cat([y, y + 0.5 * math.pi], -1),
torch.cat([y_var] * 2, -1))[0]
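# Sketch of the closed form evaluated by math.expected_sin above (standard
# mip-NeRF identity): for y ~ N(mean, var), E[sin(y)] = sin(mean) * exp(-var/2).
import torch

def expected_sin_demo(mean: torch.Tensor, var: torch.Tensor) -> torch.Tensor:
    return torch.sin(mean) * torch.exp(-0.5 * var)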
# @torch.jit.script
def interpolate_calc_weight(x, corners):
return x * corners * 2 - x - corners + 1
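# Note: for binary corner coordinates c in {0, 1}, the expression above reduces
# to x when c == 1 and to (1 - x) when c == 0, i.e. the per-axis linear
# interpolation weight; taking the product over axes yields the multilinear
# interpolation weight used by _linear_interp below.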
class MultiresHashEncoder(InputEncoder):
fast_op = True
t_ind: torch.dtype
layers: int
coarse_levels: int
layers_hashsize: list[int]
layers_res: torch.Tensor
"""Tensor(L, D)"""
local_corners: torch.Tensor
"""Tensor(C, D)"""
layers_hashoffset: torch.Tensor
"""Tensor(L+1)"""
hashtable: torch.nn.parameter.Parameter
"""Parameter(T, F)"""
def __init__(self, chns: int, layers: int, log2_hashsize: int, features: int,
res0: int | list[int], scale_up: float = 2.0):
super().__init__(chns, layers * features, (0., 1.))
res0 = torch.tensor([res0] * chns if isinstance(res0, int) else res0)
self.layers = layers
self.features = features
self.scale_up = scale_up
self.max_hashsize = 2 ** log2_hashsize
self.t_ind = torch.int if self.fast_op else torch.long
layers_res: list[torch.Tensor] = []
self.layers_hashsize: list[int] = []
self.coarse_levels = 0
layers_hashoffset: list[int] = [0]
for i in range(layers):
layers_res.append((res0 * scale_up ** i).to(self.t_ind))
if layers_res[-1].max() > self.max_hashsize ** (1 / 3)\
or layers_res[-1].prod() > self.max_hashsize:
self.layers_hashsize.append(self.max_hashsize)
else:
self.layers_hashsize.append(layers_res[-1].prod().item())
self.coarse_levels = i + 1
layers_hashoffset.append(layers_hashoffset[-1] + self.layers_hashsize[-1])
self.register_temp("layers_res", torch.stack(layers_res, 0))
self.register_temp("layers_hashoffset", torch.tensor(layers_hashoffset, dtype=self.t_ind))
self.register_temp("local_corners", split_voxels_local(1, 2, dims=chns) + .5)
# Initialize the hash table entries using the uniform distribution U(−10^−4, 10^−4) to provide
# a small amount of randomness while encouraging initial predictions close to zero [muller2022instant]
self.hashtable = torch.nn.parameter.Parameter(
(torch.rand(layers_hashoffset[-1], features, device=self.device) - .5))
@profile
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Encode inputs using multi-resolution hash encoder [muller2022instant]
:param x `Tensor(N..., D)`: D-dim inputs
:return `Tensor(N..., LF)`: encoded outputs
"""
if self.fast_op:
N_, D = x.shape[:-1], x.shape[-1]
return multires_hash_encode(self.layers, self.coarse_levels, self.layers_res,
self.layers_hashoffset, x.reshape(-1, D), self.hashtable)\
.transpose(0, 1).reshape(*N_, -1)
@profile("Calculate corners")
def calc_corners(x) -> tuple[torch.Tensor, torch.Tensor]:
grid_pos = x.unsqueeze(-2) * (self.layers_res - 1) # (N..., L, D)
grid_pos.unsqueeze_(-2) # (N..., L, 1, D)
grid_lo = torch.floor(grid_pos)
grid_pos.sub_(grid_lo)
corners = (grid_lo + self.local_corners).long().min(self.layers_res.unsqueeze(-2) - 1)
# (N..., L, C, D)
return grid_pos, corners
grid_pos, corners = calc_corners(x)
# (N..., L, 1, D), (N..., L, C, D)
@profile("Calculate encoded")
def calc_encoded(level: int) -> torch.Tensor:
if level < self.coarse_levels:
idx = to_flat_indices(corners[..., level, :, :], self.layers_res[level, None])
else:
idx = self._fast_hash(corners[..., level, :, :]) % self.max_hashsize
idx.add_(self.layers_hashoffset[level, None])
return self._linear_interp(grid_pos[..., level, :, :], self.hashtable[idx])
result = torch.stack([calc_encoded(level) for level in range(self.layers)], dim=-2)
# (N..., L, X)
return result.flatten(-2)
def _linear_interp(self, x: torch.Tensor, corner_values: torch.Tensor) -> torch.Tensor:
"""
[summary]
:param x `Tensor(N..., L, 1, D)`: [description]
:param corner_values `Tensor(N..., L, C, X)`: [description]
:return `Tensor(N..., L, X): [description]
:rtype: [type]
"""
weights = (x * self.local_corners * 2 - x - self.local_corners + 1).prod(-1, keepdim=True)
# (N..., L, C, 1)
return (weights * corner_values).sum(-2) # (N..., L, X)
def extra_repr(self) -> str:
return f"{self.in_chns} -> {self.out_chns}({self.layers}x{self.features})"\
f", resolution={self.layers_res[0].tolist()}*{self.scale_up}^L"\
f", max_hashsize={self.max_hashsize}"
@profile
def _fast_hash(self, grid_pos: torch.Tensor) -> torch.Tensor:
"""
Perform fast hash according to instant-ngp
:param grid_pos `Tensor(N..., D)`: integer grid positions
:return `Tensor(N...)`: hash values
"""
if grid_pos.shape[-1] > 7:
raise ValueError("fast_hash can only hash up to 7 dimensions.")
# While 1 is technically not a good prime for hashing (or a prime at all), it helps memory coherence
# and is sufficient for our use case of obtaining a uniformly colliding index from high-dimensional
# coordinates. [muller2022instant]
primes = [1, 2654435761, 805459861, 3674653429, 2097192037, 1434869437, 2165219737]
result = grid_pos[..., 0] * primes[0]
for i in range(1, grid_pos.shape[-1]):
result.bitwise_xor_(grid_pos[..., i] * primes[i])
return result
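# Stand-alone illustration of the spatial hash above (instant-ngp): XOR the
# per-axis coordinates multiplied by fixed primes, then fold into the table
# size. The 3-prime list here is a trimmed example for 3-D inputs.
import torch

def fast_hash_demo(grid_pos: torch.Tensor, hashmap_size: int) -> torch.Tensor:
    primes = [1, 2654435761, 805459861]
    result = grid_pos[..., 0] * primes[0]
    for i in range(1, grid_pos.shape[-1]):
        result = result ^ (grid_pos[..., i] * primes[i])
    return result % hashmap_size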
class LayeredMultiresHashEncoder(InputEncoder):
use_cpp = False
layers: int
coarse_levels: int
layers_res: torch.Tensor
"""Tensor(L, D)"""
local_corners: torch.Tensor
"""Tensor(C, D)"""
layers_hashsize: list[int]
layers_hashoffset: list[int]
t_ind: torch.dtype
def __init__(self, chns: int, layers: int, log2_hashsize: int, features: int,
res0: int | list[int], scale_up: float = 2.0, parts: int = 64):
super().__init__(chns, layers * features, (0., 1.))
res0 = torch.tensor([res0] * chns if isinstance(res0, int) else res0)
self.layers = layers
self.features = features
self.scale_up = scale_up
self.max_hashsize = 2 ** log2_hashsize // parts
self.t_ind = torch.int if self.use_cpp else torch.long
layers_res: list[torch.Tensor] = []
self.layers_hashsize: list[int] = []
self.layers_usehash: list[bool] = []
self.coarse_levels = 0
layers_hashoffset: list[int] = [0]
for i in range(layers):
layers_res.append(res0 if i == 0 else (layers_res[-1] * scale_up).to(self.t_ind))
if layers_res[-1].max() > self.max_hashsize ** (1 / 3)\
or layers_res[-1].prod() > self.max_hashsize:
self.layers_hashsize.append(self.max_hashsize)
else:
self.layers_hashsize.append(layers_res[-1].prod().item())
self.coarse_levels = i + 1
layers_hashoffset.append(layers_hashoffset[-1] + self.layers_hashsize[-1])
self.register_temp("layers_res", torch.stack(layers_res, 0))
self.register_temp("layers_hashoffset", torch.tensor(layers_hashoffset, dtype=self.t_ind))
# Initialize the hash table entries using the uniform distribution U(−10^−4, 10^−4) to provide
# a small amount of randomness while encouraging initial predictions close to zero [muller2022instant]
self.hashtable = torch.nn.parameter.Parameter(
(torch.rand(parts, layers_hashoffset[-1], features, device=self.device) - .5))
self.register_temp("local_corners", split_voxels_local(1, 2, dims=chns) + .5)
@profile
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Encode inputs using multi-resolution hash encoder [muller2022instant]
:param x `Tensor(N..., P, D)`: D-dim inputs
:return `Tensor(N..., P, LF)`: encoded outputs
"""
if self.use_cpp:
N_, P, D = x.shape[:-2], x.shape[-2], x.shape[-1]
return torch.stack([
multires_hash_encode(self.layers, self.coarse_levels, self.layers_res,
self.layers_hashoffset, x[..., i, :].reshape(-1, D),
self.hashtable[i]).reshape(*N_, -1)
for i in range(P)
], dim=-2)
@profile("Calculate corners")
def calc_corners(x) -> tuple[torch.Tensor, torch.Tensor]:
grid_pos = x.unsqueeze(-2) * (self.layers_res - 1) # (N..., P, L, D)
grid_pos.unsqueeze_(-2) # (N..., P, L, 1, D)
grid_lo = torch.floor(grid_pos)
grid_pos.sub_(grid_lo)
corners = (grid_lo + self.local_corners).long().min(self.layers_res.unsqueeze(-2) - 1)
# (N..., L, C, D)
return grid_pos, corners
grid_pos, corners = calc_corners(x)
# (N..., P, L, 1, D), (N..., P, L, C, D)
@profile("Calculate encoded")
def calc_encoded(level: int) -> torch.Tensor:
if level < self.coarse_levels:
idx = to_flat_indices(corners[..., level, :, :], self.layers_res[level, None])
else:
idx = self._fast_hash(corners[..., level, :, :]) % self.max_hashsize
idx.add_(self.layers_hashoffset[level, None])
part_idx = torch.arange(x.shape[-2], device=x.device)[:, None].broadcast_to(idx.shape)
return self._linear_interp(grid_pos[..., level, :, :], self.hashtable[part_idx, idx])
result = torch.stack([calc_encoded(level) for level in range(self.layers)], dim=-2)
# (N..., L, X)
return result.flatten(-2)
def _linear_interp(self, x: torch.Tensor, corner_values: torch.Tensor) -> torch.Tensor:
"""
[summary]
:param x `Tensor(N..., L, 1, D)`: [description]
:param corner_values `Tensor(N..., L, C, X)`: [description]
:return `Tensor(N..., L, X): [description]
:rtype: [type]
"""
weights = (x * self.local_corners * 2 - x - self.local_corners + 1).prod(-1, keepdim=True)
# (N..., L, C, 1)
return (weights * corner_values).sum(-2) # (N..., L, X)
def extra_repr(self) -> str:
return f"{self.in_chns} -> {self.out_chns}({self.layers}x{self.features})"\
f", resolution={self.layers_res[0].tolist()}*{self.scale_up}^L"\
f", max_hashsize={self.max_hashsize}"
@profile
def _fast_hash(self, grid_pos: torch.Tensor) -> torch.Tensor:
"""
Perform fast hash according to instant-ngp
:param grid_pos `Tensor(N..., D)`: integer grid positions
:return `Tensor(N...)`: hash values
"""
if grid_pos.shape[-1] > 7:
raise ValueError("fast_hash can only hash up to 7 dimensions.")
# While 1 is technically not a good prime for hashing (or a prime at all), it helps memory coherence
# and is sufficient for our use case of obtaining a uniformly colliding index from high-dimensional
# coordinates. [muller2022instant]
primes = [1, 2654435761, 805459861, 3674653429, 2097192037, 1434869437, 2165219737]
result = grid_pos[..., 0] * primes[0]
for i in range(1, grid_pos.shape[-1]):
result.bitwise_xor_(grid_pos[..., i] * primes[i])
return result
from itertools import cycle
from typing import Set
from ..__common__ import *
from model.model import BaseModel
def density2energy(densities: torch.Tensor, dists: torch.Tensor, raw_noise_std: float = 0):
"""
Calculate energies from densities inferred by model.
:param densities `Tensor(N..., 1)`: model's output densities
:param dists `Tensor(N...)`: integration times
:param raw_noise_std `float`: the noise std used to regularize the network during training
(prevents floater artifacts); defaults to 0, which means no noise is added
:return `Tensor(N..., 1)`: energies which block light rays
"""
if raw_noise_std > 0:
# Add noise to model's predictions for density. Can be used to
# regularize network during training (prevents floater artifacts).
densities = densities + torch.randn_like(densities) * raw_noise_std
return densities * dists[..., None]
def density2alpha(densities: torch.Tensor, dists: torch.Tensor, raw_noise_std: float = 0):
"""
Calculate alphas from densities inferred by model.
:param densities `Tensor(N..., 1)`: model's output densities
:param dists `Tensor(N...)`: integration times
:param raw_noise_std `float`: the noise std used to regularize the network during training
(prevents floater artifacts); defaults to 0, which means no noise is added
:return `Tensor(N..., 1)`: alphas
"""
energies = density2energy(densities, dists, raw_noise_std)
return 1.0 - torch.exp(-energies)
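# Worked example (illustrative): a density of 10 over a step of 0.02 gives
# energy 10 * 0.02 = 0.2 and alpha = 1 - exp(-0.2), approximately 0.181.
import torch

def _demo_density_to_alpha():
    densities = torch.full((1, 1), 10.0)
    dists = torch.full((1,), 0.02)
    alpha = density2alpha(densities, dists)
    assert torch.allclose(alpha, 1.0 - torch.exp(torch.tensor(-0.2)))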
class AlphaComposition(Module):
def __init__(self):
super().__init__()
def forward(self, colors, alphas, bg=None):
"""
[summary]
:param colors `Tensor(N, P, C)`: [description]
:param alphas `Tensor(N, P, 1)`: [description]
:param bg `Tensor([N, ]C)`: [description], defaults to None
:return `Tensor(N, C)`: [description]
"""
# Compute weight for RGB of each sample along each ray. A cumprod() is
# used to express the idea of the ray not having reflected up to this
# sample yet.
one_minus_alpha = torch.cumprod(1 - alphas[..., :-1, :] + math.tiny, dim=-2)
one_minus_alpha = torch.cat([
torch.ones_like(one_minus_alpha[..., :1, :]),
one_minus_alpha
], dim=-2)
weights = alphas * one_minus_alpha # (N, P, 1)
# (N, C), computed weighted color of each sample along each ray.
final_color = torch.sum(weights * colors, dim=-2)
# To composite onto a white background, use the accumulated alpha map.
if bg is not None:
# Sum of weights along each ray. This value is in [0, 1] up to numerical error.
acc_map = torch.sum(weights, dim=-2) # (N, 1)
final_color = final_color + bg * (1. - acc_map)
return {
'color': final_color,
'weights': weights,
}
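# Sketch of the front-to-back compositing rule implemented above: the weight of
# sample i is alpha_i times the transmittance of all samples in front of it.
import torch

def _demo_composite_weights():
    alphas = torch.tensor([[[0.5], [0.5], [0.5]]])  # (N=1, P=3, 1)
    trans = torch.cumprod(1 - alphas, dim=-2)       # transmittance after each sample
    trans = torch.cat([torch.ones_like(trans[:, :1]), trans[:, :-1]], dim=-2)
    weights = alphas * trans                        # [[[0.5], [0.25], [0.125]]]
    return weights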
class VolumnRenderer(Module):
class States:
kernel: BaseModel
samples: Samples
early_stop_tolerance: float
outputs: Set[str]
hit_mask: torch.Tensor
N: int
P: int
device: torch.device
colors: torch.Tensor
densities: torch.Tensor
energies: torch.Tensor
weights: torch.Tensor
cum_energies: torch.Tensor
exp_energies: torch.Tensor
tot_evaluations: dict[str, int]
chunk: tuple[slice, slice]
cum_chunk: tuple[slice, slice]
cum_last: tuple[slice, slice]
chunk_id: int
@property
def start(self) -> int:
return self.chunk[1].start
@property
def end(self) -> int:
return self.chunk[1].stop
def __init__(self, kernel: BaseModel, samples: Samples, early_stop_tolerance: float,
outputs: Set[str]) -> None:
self.kernel = kernel
self.samples = samples
self.early_stop_tolerance = early_stop_tolerance
self.outputs = outputs
N, P = samples.size
self.device = self.samples.device
self.hit_mask = samples.voxel_indices != -1 # (N, P) | bool
self.colors = torch.zeros(N, P, kernel.chns('color'), device=samples.device)
self.densities = torch.zeros(N, P, 1, device=samples.device)
self.energies = torch.zeros(N, P, 1, device=samples.device)
self.weights = torch.zeros(N, P, 1, device=samples.device)
self.cum_energies = torch.zeros(N, P + 1, 1, device=samples.device)
self.exp_energies = torch.ones(N, P + 1, 1, device=samples.device)
self.tot_evaluations = {}
self.N, self.P = N, P
self.chunk_id = -1
def n_hits(self, index: int | slice = None) -> int:
if not isinstance(self.hit_mask, torch.Tensor):
if index is not None:
return self.N * self.colors[:, index].shape[1]
return self.N * self.P
if index is None:
return self.hit_mask.count_nonzero().item()
return self.hit_mask[:, index].count_nonzero().item()
def accumulate_tot_evaluations(self, key: str, n: int):
if key not in self.tot_evaluations:
self.tot_evaluations[key] = 0
self.tot_evaluations[key] += n
def next_chunk(self, *, length=None, end=None):
start = 0 if not hasattr(self, "chunk") else self.end
length = length or self.P
end = min(end or start + length, self.P)
self.chunk = slice(None), slice(start, end)
self.cum_chunk = slice(None), slice(start + 1, end + 1)
self.cum_last = slice(None), slice(start, start + 1)
self.chunk_id += 1
return self
def put(self, key: str, values: torch.Tensor, indices: tuple[torch.Tensor, torch.Tensor] | tuple[slice, slice]):
if not hasattr(self, key):
new_tensor = torch.zeros(self.N, self.P, values.shape[-1], device=self.device)
setattr(self, key, new_tensor)
tensor: torch.Tensor = getattr(self, key)
# if isinstance(indices[0], torch.Tensor):
# tensor.index_put_(indices, values)
# else:
tensor[indices] = values
def __init__(self, **kwargs):
super().__init__()
@profile
def forward(self, kernel: BaseModel, samples: Samples, *outputs: str,
raymarching_early_stop_tolerance: float = 0,
raymarching_chunk_size_or_sections: int | list[int] = None,
**kwargs) -> ReturnData:
"""
Perform volume rendering.
:param kernel `BaseModel`: render kernel
:param samples `Samples(N, P)`: samples
:param outputs `str...`: items should be contained in the result dict.
Optional values include 'color', 'depth', 'layers', 'states' and attribute names in class `States` (e.g. 'weights'). Defaults to []
:param raymarching_early_stop_tolerance `float`: tolerance for raymarching early stop.
Should be between 0 and 1 (0 means no early stop). Defaults to 0
:param raymarching_chunk_size_or_sections `int|list[int]`: indicates how to split the raymarching process.
Use a list of integers to specify the number of samples in every chunk, or a positive integer to specify the number of chunks.
Use a negative integer to split by the number of hits per chunk, where the absolute value is the maximum number of hits in a chunk.
0 and `None` mean no splitting of the raymarching process. Defaults to `None`
:return `dict`: render result { 'color'[, 'depth', 'layers', 'states', ...] }
"""
if samples.size[1] == 0:
print("VolumnRenderer.forward(): # of samples is zero")
return None
infer_outputs = set()
for key in outputs:
if key == "color":
infer_outputs.add("colors")
infer_outputs.add("densities")
elif key == "specular":
infer_outputs.add("speculars")
infer_outputs.add("densities")
elif key == "diffuse":
infer_outputs.add("diffuses")
infer_outputs.add("densities")
elif key == "depth":
infer_outputs.add("densities")
else:
infer_outputs.add(key)
with profile("Prepare states object"):
s = VolumnRenderer.States(kernel, samples, raymarching_early_stop_tolerance,
infer_outputs)
if not raymarching_chunk_size_or_sections:
raymarching_chunk_size_or_sections = [s.P]
elif isinstance(raymarching_chunk_size_or_sections, int) and \
raymarching_chunk_size_or_sections > 0:
raymarching_chunk_size_or_sections = [
math.ceil(s.P / raymarching_chunk_size_or_sections)
]
with profile("Run forward chunks"):
if isinstance(raymarching_chunk_size_or_sections, list):
chunk_sections = raymarching_chunk_size_or_sections
for chunk_samples in cycle(chunk_sections):
self._forward_chunk(s.next_chunk(length=chunk_samples))
if s.end >= s.P:
break
else:
chunk_size = -raymarching_chunk_size_or_sections
chunk_hits = s.n_hits(0)
for i in range(1, s.P):
n_hits = s.n_hits(i)
if chunk_hits + n_hits > chunk_size:
self._forward_chunk(s.next_chunk(end=i))
n_hits = s.n_hits(i)
chunk_hits = 0
chunk_hits += n_hits
self._forward_chunk(s.next_chunk())
with profile("Set return data"):
ret = {}
for key in outputs:
if key == 'color':
ret['color'] = torch.sum(s.colors * s.weights, 1)
elif key == 'depth':
ret['depth'] = torch.sum(s.samples.depths[..., None] * s.weights, 1)
elif key == 'diffuse' and hasattr(s, "diffuses"):
ret['diffuse'] = torch.sum(s.diffuses * s.weights, 1)
elif key == 'specular' and hasattr(s, "speculars"):
ret['specular'] = torch.sum(s.speculars * s.weights, 1)
elif key == 'layers':
ret['layers'] = torch.cat([s.colors, 1 - torch.exp(-s.energies)], dim=-1)
elif key == 'states':
ret['states'] = s
else:
if hasattr(s, key):
ret[key] = getattr(s, key)
return ret
@profile
def _calc_weights(self, s: States):
"""
Calculate weights of samples in composited outputs.
:param s `States`: render states (the current chunk is given by `s.chunk`)
"""
s.energies[s.chunk] = density2energy(s.densities[s.chunk], s.samples.dists[s.chunk])
s.cum_energies[s.cum_chunk] = torch.cumsum(s.energies[s.chunk], 1) \
+ s.cum_energies[s.cum_last]
s.exp_energies[s.cum_chunk] = (-s.cum_energies[s.cum_chunk]).exp()
s.weights[s.chunk] = s.exp_energies[s.chunk] - s.exp_energies[s.cum_chunk]
@profile
def _apply_early_stop(self, s: States):
"""
Stop rays whose remaining transmittance has fallen below the tolerance.
:param s `States`: render states (the current chunk is given by `s.chunk`)
"""
if s.end < s.P and s.early_stop_tolerance > 0 and isinstance(s.hit_mask, torch.Tensor):
rays_to_stop = s.exp_energies[:, s.end, 0] < s.early_stop_tolerance
s.hit_mask[rays_to_stop, s.end:] = 0
@profile
def _forward_chunk(self, s: States) -> int:
if isinstance(s.hit_mask, torch.Tensor):
fi_idxs: tuple[torch.Tensor, ...] = s.hit_mask[s.chunk].nonzero(as_tuple=True)
if fi_idxs[0].size(0) == 0:
s.cum_energies[s.cum_chunk] = s.cum_energies[s.cum_last]
s.exp_energies[s.cum_chunk] = s.exp_energies[s.cum_last]
return
fi_idxs[1].add_(s.start)
s.accumulate_tot_evaluations("colors", fi_idxs[0].size(0))
else:
fi_idxs = s.chunk
fi_outputs = s.kernel.infer(*s.outputs, samples=s.samples[fi_idxs], chunk_id=s.chunk_id)
for key, value in fi_outputs.items():
s.put(key, value, fi_idxs)
self._calc_weights(s)
self._apply_early_stop(s)
class DensityFirstVolumnRenderer(VolumnRenderer):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def _forward_chunk(self, s: VolumnRenderer.States) -> int:
fi_idxs: tuple[torch.Tensor, ...] = s.hit_mask[s.chunk].nonzero(as_tuple=True) # (N')
fi_idxs[1].add_(s.start)
if fi_idxs[0].size(0) == 0:
s.cum_energies[s.cum_chunk] = s.cum_energies[s.cum_last]
s.exp_energies[s.cum_chunk] = s.exp_energies[s.cum_last]
return
# fi_* means "filtered" by hit mask
fi_samples = s.samples[fi_idxs] # N -> N'
# For all valid samples: encode X
density_inputs = s.kernel.input(fi_samples, "x", "f") # (N', Ex)
# Infer densities (shape)
density_outputs = s.kernel.infer('densities', 'features', samples=fi_samples,
inputs=density_inputs, chunk_id=s.chunk_id)
s.put('densities', density_outputs['densities'], fi_idxs)
s.accumulate_tot_evaluations("densities", fi_idxs[0].size(0))
self._calc_weights(s)
self._apply_early_stop(s)
# Remove samples whose weights are less than a threshold
s.hit_mask[s.chunk][s.weights[s.chunk][..., 0] < 0.01] = 0
# Update "filtered" tensors
fi_mask = s.hit_mask[fi_idxs]
fi_idxs = (fi_idxs[0][fi_mask], fi_idxs[1][fi_mask]) # N' -> N"
fi_samples = s.samples[fi_idxs] # N -> N"
fi_features = density_outputs['features'][fi_mask]
color_inputs = s.kernel.input(fi_samples, "d") # (N")
color_inputs.x = density_inputs.x[fi_mask]
# Infer colors (appearance)
outputs = s.outputs.copy()
if 'densities' in outputs:
outputs.remove('densities')
color_outputs = s.kernel.infer(*outputs, samples=fi_samples, inputs=color_inputs,
chunk_id=s.chunk_id, features=fi_features)
# if s.chunk_id == 0:
# fi_colors[:] *= fi_colors.new_tensor([1, 0, 0])
# elif s.chunk_id == 1:
# fi_colors[:] *= fi_colors.new_tensor([0, 1, 0])
# elif s.chunk_id == 2:
# fi_colors[:] *= fi_colors.new_tensor([0, 0, 1])
# else:
# fi_colors[:] *= fi_colors.new_tensor([1, 1, 0])
for key, value in color_outputs.items():
s.put(key, value, fi_idxs)
s.accumulate_tot_evaluations("colors", fi_idxs[0].size(0))
class VoxelSampler(Module):
def __init__(self, *, sample_step: float, **kwargs):
"""
Initialize a VoxelSampler module
:param perturb_sample: perturb the sample depths
:param step_size: step size
"""
super().__init__()
self.sample_step = sample_step
def _forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, space_module: Space, *,
perturb_sample: bool, **kwargs) -> tuple[Samples, torch.Tensor]:
"""
Sample along rays with a fixed step size inside the voxels they hit.
:param rays_o `Tensor(N, 3)`: rays' origin positions
:param rays_d `Tensor(N, 3)`: rays' directions
:param perturb_sample `bool`: whether to perturb the sample depths
:return `Samples(N', P)`: samples along valid rays (which hit at least one voxel)
:return `Tensor(N)`: valid rays mask
"""
intersections = space_module.ray_intersect(rays_o, rays_d, 100)
valid_rays_mask = intersections.hits > 0
rays_o = rays_o[valid_rays_mask]
rays_d = rays_d[valid_rays_mask]
intersections = intersections[valid_rays_mask] # (N) -> (N')
n_rays = rays_o.size(0)
ray_index_list = torch.arange(n_rays, device=rays_o.device, dtype=torch.long) # (N')
hits = intersections.hits
min_depths = intersections.min_depths
max_depths = intersections.max_depths
voxel_indices = intersections.voxel_indices
rays_near_depth = min_depths[:, :1] # (N', 1)
rays_far_depth = max_depths[ray_index_list, hits - 1][:, None] # (N', 1)
rays_length = rays_far_depth - rays_near_depth
rays_steps = (rays_length / self.sample_step).ceil().long()
rays_step_size = rays_length / rays_steps
max_steps = rays_steps.max().item()
rays_step = torch.arange(max_steps, device=rays_o.device,
dtype=torch.float)[None].repeat(n_rays, 1) # (N', P)
invalid_samples_mask = rays_step >= rays_steps
samples_min_depth = rays_near_depth + rays_step * rays_step_size
samples_depth = samples_min_depth + rays_step_size \
* (torch.rand_like(samples_min_depth) if perturb_sample else 0.5) # (N', P)
samples_dist = rays_step_size.repeat(1, max_steps) # (N', 1) -> (N', P)
samples_voxel_index = voxel_indices[
ray_index_list[:, None],
torch.searchsorted(max_depths, samples_depth)
] # (N', P)
samples_depth[invalid_samples_mask] = math.huge
samples_dist[invalid_samples_mask] = 0
samples_voxel_index[invalid_samples_mask] = -1
rays_o, rays_d = rays_o[:, None], rays_d[:, None]
return Samples(
pts=rays_o + rays_d * samples_depth[..., None],
dirs=rays_d.expand(-1, max_steps, -1),
depths=samples_depth,
dists=samples_dist,
voxel_indices=samples_voxel_index
), valid_rays_mask
@profile
def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor,
space: Space, *, perturb_sample: bool, **kwargs) -> tuple[Samples, torch.Tensor]:
"""
Sample along rays within the voxels they hit, using inverse-CDF sampling.
:param rays_o `Tensor(N, 3)`: rays' origin positions
:param rays_d `Tensor(N, 3)`: rays' directions
:param perturb_sample `bool`: whether to perturb the sample depths
:return `Samples(N', P)`: samples along valid rays (which hit at least one voxel)
:return `Tensor(N)`: valid rays mask
"""
with profile("Ray intersect"):
intersections = space.ray_intersect(rays_o, rays_d, 100)
valid_rays_mask = intersections.hits > 0
rays_o = rays_o[valid_rays_mask]
rays_d = rays_d[valid_rays_mask]
intersections = intersections[valid_rays_mask] # (N) -> (N')
if intersections.size == 0:
return None, valid_rays_mask
else:
with profile("Inverse CDF sampling"):
min_depth = intersections.min_depths
max_depth = intersections.max_depths
pts_idx = intersections.voxel_indices
dists = max_depth - min_depth
tot_dists = dists.sum(dim=-1, keepdim=True) # (N, 1)
probs = dists / tot_dists
steps = tot_dists[:, 0] / self.sample_step
# sample points and use middle point approximation
sampled_indices, sampled_depths, sampled_dists = inverse_cdf_sampling(
pts_idx, min_depth, max_depth, probs, steps, -1, not perturb_sample)
sampled_indices = sampled_indices.long()
invalid_idx_mask = sampled_indices.eq(-1)
sampled_dists.clamp_min_(0).masked_fill_(invalid_idx_mask, 0)
sampled_depths.masked_fill_(invalid_idx_mask, math.huge)
rays_o, rays_d = rays_o[:, None], rays_d[:, None]
return Samples(
pts=rays_o + rays_d * sampled_depths[..., None],
dirs=rays_d.expand(-1, sampled_depths.size(1), -1),
depths=sampled_depths,
dists=sampled_dists,
voxel_indices=sampled_indices
), valid_rays_mask
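# Minimal 1-D sketch (assumptions noted) of the inverse-CDF sampling used
# above: per-voxel segment lengths define a piecewise-constant PDF along each
# ray, and uniform draws are mapped through the CDF's inverse to segment
# indices. The real inverse_cdf_sampling also returns depths and dists.
import torch

def inverse_cdf_demo(seg_lengths: torch.Tensor, n_samples: int) -> torch.Tensor:
    # seg_lengths: (N, S) lengths of the voxel segments each ray passes through
    pdf = seg_lengths / seg_lengths.sum(-1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    u = (torch.arange(n_samples) + 0.5) / n_samples  # midpoints, no perturbation
    return torch.searchsorted(cdf, u.expand(cdf.shape[0], n_samples).contiguous())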
from .nerf import NeRF
from .fs_nerf import FsNeRF
__all__ = ["NeRF", "FsNeRF"]
from ..__common__ import *
__all__ = ["ColorDecoder", "BasicColorDecoder", "NeRFColorDecoder"]
class ColorDecoder(nn.Module):
def __init__(self, f_chns: int, d_chns: int, color_chns: int):
super().__init__({"f": f_chns, "d": d_chns}, {"color": color_chns})
# stub method for type hint
def __call__(self, f: torch.Tensor, d: torch.Tensor) -> torch.Tensor:
...
def forward(self, f: torch.Tensor, d: torch.Tensor) -> torch.Tensor:
raise NotImplementedError()
@staticmethod
def create(f_chns: int, d_chns: int, color_chns: int, type: str, args: dict[str, Any]) -> "ColorDecoder":
return getattr(sys.modules[__name__], f"{type}ColorDecoder")(
f_chns=f_chns, d_chns=d_chns, color_chns=color_chns, **args)
class BasicColorDecoder(ColorDecoder):
def __init__(self, f_chns: int, color_chns: int, out_act: str = "sigmoid", **kwargs):
super().__init__(f_chns, 0, color_chns)
self.net = nn.FcLayer(f_chns, color_chns, out_act)
def forward(self, f: torch.Tensor, d: torch.Tensor) -> torch.Tensor:
return self.net(f)
class NeRFColorDecoder(ColorDecoder):
def __init__(self, f_chns: int, d_chns: int, color_chns: int, act: str = "relu",
out_act: str = "sigmoid", with_ln: bool = False, **kwargs):
super().__init__(f_chns, d_chns, color_chns)
self.feature_layer = nn.FcLayer(f_chns, f_chns, with_ln=with_ln)
self.net = nn.FcBlock(f_chns + d_chns, color_chns, 1,
f_chns // 2, [], act, out_act, with_ln)
def forward(self, f: torch.Tensor, d: torch.Tensor) -> torch.Tensor:
return self.net(union(self.feature_layer(f), d))
from ..__common__ import *
__all__ = ["DensityDecoder"]
class DensityDecoder(nn.Module):
def __init__(self, f_chns: int, density_chns: int, **kwargs):
super().__init__({"f": f_chns}, {"density": density_chns})
self.net = nn.FcLayer(f_chns, density_chns)
# stub method for type hint
def __call__(self, f: torch.Tensor) -> torch.Tensor:
...
def forward(self, f: torch.Tensor) -> torch.Tensor:
return self.net(f)