Commit 1bc644a1 authored by Nianchen Deng

sync

parent 6294701e
from typing import Union
from pathlib import Path
def get_dataset_desc_path(path: Union[Path, str]) -> Path:
if isinstance(path, str):
path = Path(path)
if path.suffix != ".json":
path = Path(f"{path}.json")
return path
def get_data_path(dataset_desc_path: Path, path_pattern: str) -> str:
root = dataset_desc_path.parent
if "/" not in path_pattern:
path_pattern = f"{dataset_desc_path.stem}/{path_pattern}"
return str(root / path_pattern)
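A quick usage sketch of the two helpers above (the paths are hypothetical):

    desc_path = get_dataset_desc_path("data/scene")          # -> Path("data/scene.json")
    image_pattern = get_data_path(desc_path, "view_%04d.png")
    # -> "data/scene/view_%04d.png": a pattern without "/" is prefixed with the dataset's stem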
import torch
import torch.nn.functional as nn_f
from typing import Dict, Tuple, Union
from pathlib import Path
from utils import img
from utils import color
from .dataset import Dataset
class ViewDataset(Dataset):
"""
Data loader for spherical view synthesis task
Attributes
--------
data_dir ```str```: the directory of dataset\n
view_file_pattern ```str```: the filename pattern of view images\n
cam ```object```: camera intrinsic parameters\n
view_centers ```Tensor(N, 3)```: centers of views\n
view_rots ```Tensor(N, 3, 3)```: rotation matrices of views\n
view_images ```Tensor(N, 3, H, W)```: images of views\n
view_depths ```Tensor(N, H, W)```: depths of views\n
"""
class Chunk(object):
def __init__(self, id: int, dataset, chunk_data: Dict[str, torch.Tensor], *,
color: int, **kwargs):
"""
[summary]
:param dataset `ViewDataset`: dataset object
:param indices `Tensor(N)`: indices of views
:param centers `Tensor(N, 3)`: centers of views
"""
self.id = id
self.dataset = dataset
self.indices = chunk_data['indices']
self.centers = chunk_data['centers']
self.rots = chunk_data['rots']
self.color = color
self.n_views = self.indices.size(0)
self.n_pixels_per_view = self.dataset.res[0] * self.dataset.res[1]
self.colors = self.depths = self.bins = None
self.colors_cpu = self.depths_cpu = self.bins_cpu = None
self.loaded = False
def release(self):
self.colors = self.depths = self.bins = None
self.loaded = False
def load(self):
#print("chunk load")
try:
if self.dataset.image_path and self.colors_cpu is None:
images = color.cvt(img.load(self.dataset.image_path % i for i in self.indices),
color.RGB, self.color)
if self.dataset.res != list(images.shape[-2:]):
images = nn_f.interpolate(images, self.dataset.res)
self.colors_cpu = images.permute(0, 2, 3, 1).flatten(0, 2)
if self.colors_cpu is not None:
self.colors = self.colors_cpu.to(self.dataset.device, non_blocking=True)
if self.dataset.depth_path and self.depths_cpu is None:
depths = self.dataset._decode_depth_images(
                    img.load(self.dataset.depth_path % i for i in self.indices))
if self.dataset.res != list(depths.shape[-2:]):
depths = nn_f.interpolate(depths, self.dataset.res)
self.depths_cpu = depths.flatten(0, 2)
if self.depths_cpu is not None:
self.depths = self.depths_cpu.to(self.dataset.device, non_blocking=True)
if self.dataset.bins_path and self.bins_cpu is None:
bins = img.load([self.dataset.bins_path % i for i in self.indices])
if self.dataset.res != list(bins.shape[-2:]):
bins = nn_f.interpolate(bins, self.dataset.res)
self.bins_cpu = bins.permute(0, 2, 3, 1).flatten(0, 2)
if self.bins_cpu is not None:
self.bins = self.bins_cpu.to(self.dataset.device, non_blocking=True)
torch.cuda.current_stream(self.dataset.device).synchronize()
self.loaded = True
except Exception as ex:
print(ex)
exit(-1)
def __len__(self):
return self.n_views * self.n_pixels_per_view
def __getitem__(self, idx):
if not self.loaded:
self.load()
view_idx = idx // self.n_pixels_per_view
pix_idx = idx % self.n_pixels_per_view
global_idx = self.indices[view_idx] * self.n_pixels_per_view + pix_idx
rays_o = self.centers[view_idx]
rays_d = self.dataset.cam_rays[pix_idx][:, None] # (N, 1, 3)
            r = self.rots[view_idx].movedim(-1, -2)  # (N, 3, 3): transposed view rotations
            rays_d = torch.matmul(rays_d, r)[:, 0]  # (N, 3): d @ Rᵀ, i.e. local rays rotated into world space
data = {
'idx': global_idx,
'rays_o': rays_o,
'rays_d': rays_d,
'level': self.dataset.level
}
if self.colors is not None:
data['color'] = self.colors[idx]
if self.depths is not None:
data['depth'] = self.depths[idx]
if self.bins is not None:
data['bin'] = self.bins[idx]
#data['view_idx'] = view_idx
#data['pix_idx'] = pix_idx
return data
def __init__(self, desc: dict, desc_path: Path, *,
load_images: bool = True,
load_depths: bool = False,
load_bins: bool = False,
res: Tuple[int, int] = None,
views_to_load: Union[range, torch.Tensor] = None,
device: torch.device = None,
**kwargs):
"""
Initialize data loader for spherical view synthesis task
The dataset description file is a JSON file with following fields:
- view_file_pattern: string, the path pattern of view images
- view_res: { "x", "y" }, the resolution of view images
- cam: { "fx", "fy", "cx", "cy" }, the focal and center of camera (in normalized image space)
- view_centers: [ [ x, y, z ], ... ], centers of views
- view_rots: [ [ m00, m01, ..., m22 ], ... ], rotation matrices of views
:param dataset_desc_path ```str```: path to the data description file
:param load_images ```bool```: whether load view images and return in __getitem__()
:param load_depths ```bool```: whether load depth images and return in __getitem__()
:param c ```int```: color space to convert view images to
:param calculate_rays ```bool```: whether calculate rays
"""
super().__init__(desc, desc_path, res=res, views_to_load=views_to_load, device=device,
load_images=load_images, load_depths=load_depths, load_bins=load_bins)
    def _decode_depth_images(self, images):
        # Depth maps are stored as disparity in the first channel, normalized to
        # [0, 1] over the dataset's depth range; invert that mapping to get depth.
        disp_range = (1 / self.depth_range[0], 1 / self.depth_range[1])
        disp_val = (1 - images[..., 0, :, :]) * (disp_range[1] - disp_range[0]) + disp_range[0]
        return torch.reciprocal(disp_val)
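    # For reference, a sketch of the inverse mapping (hypothetical, not part of this
    # commit): depth would be encoded back to a normalized disparity image as
    #
    #   def _encode_depth_images(self, depths):
    #       disp_range = (1 / self.depth_range[0], 1 / self.depth_range[1])
    #       return 1 - (torch.reciprocal(depths) - disp_range[0]) / (disp_range[1] - disp_range[0])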
def _load_desc(self, res: Tuple[int, int], views_to_load: Union[range, torch.Tensor],
load_images: bool, load_depths: bool, load_bins: bool):
super()._load_desc(res, views_to_load)
self.image_path = load_images and self._get_data_path("view")
self.depth_path = load_depths and self._get_data_path("depth")
self.bins_path = load_bins and self._get_data_path("bins")
self.cam_rays = self.cam.get_local_rays(flatten=True)
from math import ceil
# Group the CDF values into integer buckets: a group closes each time the CDF
# crosses the next integer threshold; `bins` collects the group sizes.
cdf = [2.2, 3.5, 3.6, 3.7, 4.0]
bins = []
part = 1
offset = 0
for i in range(len(cdf)):
if cdf[i] >= part:
bins.append(i + 1 - offset)
offset = i + 1
part = int(cdf[i]) + 1
print(bins)
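Tracing the loop above: `part` starts at 1 and jumps to `int(cdf[i]) + 1` whenever the CDF crosses it, so with the given cdf the thresholds are hit at i = 0, 1 and 4 and the script prints [1, 1, 3].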
import itertools
import torch
from utils import math
from tqdm import tqdm
mar0 = 1. / 48.      # angular resolution threshold at zero eccentricity (degrees)
mar_slope = 0.0275   # linear growth of the threshold with eccentricity: mar = mar0 + mar_slope * e
weights = torch.tensor([1., .25, .25], device="cuda")  # (L) per-layer weights; the length also defines the number of layers
# VR configuration
res = (1440, 1600)  # (horizontal, vertical) pixels
fov = 110  # degrees
distance = .5 * res[1] / math.tan(.5 * math.radians(fov))  # distance from eye to image plane, in pixels (res[1] spans the fov)
ratio = res[0] / res[1]  # horizontal / vertical
K = 360. / math.pi / distance  # converts the angular threshold (degrees) into a per-pixel scale
L = len(weights)
min_sum = math.inf
x_of_min_sum = None
e_of_min_sum = None
s_of_min_sum = None
D_of_min_sum = None
for x1 in tqdm(itertools.product(*([range(1, res[0] - 2)] * (L - 3))),
total=int(math.pow(res[0] - 1, L - 3))):
if any([x1[i] <= x1[i - 1] for i in range(1, len(x1))]):
continue
if not x1:
x2 = torch.stack(torch.meshgrid(
[torch.arange(1, res[0], device="cuda")] * 2), -1).flatten(0, 1)
x = x2[(x2[:, 1:] > x2[:, :-1]).any(-1)]
    else:
        x2 = torch.stack(torch.meshgrid(
            [torch.arange(x1[-1] + 1, res[0], device="cuda")] * 2), -1).flatten(0, 1)
        x2 = x2[(x2[:, 1:] > x2[:, :-1]).any(-1)]  # keep strictly increasing pairs, as in the branch above
        x = torch.cat([
            torch.tensor([x1], device="cuda").expand(x2.shape[0], -1),
            x2
        ], -1)
tan_e = x / distance # (N, L - 1)
e = tan_e.arctan().rad2deg() # (N, L - 1)
mar = mar0 + mar_slope * e # (N, L - 1)
s = torch.cat([e.new_ones(e.shape[0], 1), mar * (1. + tan_e.pow(2.)) / K], -1) # (N, L)
D = torch.cat([x * 2. / s[:, :-1], res[1] / s[:, -1:]], -1) # (N, L)
P = D * D
P[:, -1] *= ratio
weighted_sum = (P * weights).sum(-1)
min_value, min_indice = weighted_sum.min(0)
min_value = min_value.item()
min_indice = min_indice.item()
if min_value < min_sum:
min_sum = min_value
x_of_min_sum = x[min_indice]
e_of_min_sum = e[min_indice]
s_of_min_sum = s[min_indice]
D_of_min_sum = D[min_indice]
print(min_sum)
print("x:", x_of_min_sum)
print("e:", e_of_min_sum)
print("s:", s_of_min_sum)
print("D:", D_of_min_sum)
gt.png (81.3 KB, binary image)
import sys
import tty
import termios
import select
import time
def readchar():
r, w, e = select.select([sys.stdin], [], [])
if sys.stdin in r:
ch = sys.stdin.read(1)
return ch
fd = sys.stdin.fileno()
oldtty = termios.tcgetattr(fd)
newtty = termios.tcgetattr(fd)
try:
termios.tcsetattr(fd, termios.TCSANOW, newtty)
tty.setraw(fd)
tty.setcbreak(fd)
while True:
print('Wait')
time.sleep(0.1)
key = readchar()
print('%d' % ord(key))
if key == 'w':
print('w')
if key == 'q':
break
finally:
termios.tcsetattr(fd, termios.TCSADRAIN, oldtty)
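Note: `tty.setraw` disables line buffering, so `readchar` returns on each keypress without waiting for Enter; pressing `q` breaks the loop and the `finally` block restores the saved terminal settings.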
import torch
from torch import Tensor
from typing import List, Dict, Union, Optional, Any
from configargparse import ArgumentParser
from modules import *
-from utils.type import InputData, NetInput, NetOutput, ReturnData
-from utils.perf import perf
from utils.samples import Samples
from utils import math, config, nn
+from utils.types import *
+from utils.profile import profile
import importlib
import os
from .utils import *
# Automatically import model files in this directory
package_dir = os.path.dirname(__file__)
package = os.path.basename(package_dir)
for file in os.listdir(package_dir):
path = os.path.join(package_dir, file)
if file.startswith('_') or file.startswith('.') or file == "utils.py":
continue
if file.endswith('.py') or os.path.isdir(path):
model_name = file[:-3] if file.endswith('.py') else file
importlib.import_module(f'{package}.{model_name}')
import sys
from inspect import isclass
from .model import Model, model_classes
from .nerf import NeRF
from .fs_nerf import FsNeRF
__all__ = ["Model", "NeRF", "FsNeRF"]
# Register all model classes
for item in __all__:
var = getattr(sys.modules[__name__], item)
if isclass(var) and issubclass(var, Model):
model_classes[item] = var
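With the registration loop above, models can be instantiated by name, e.g. (a minimal sketch; `args0` stands for whatever argument dict the model expects, and the package name is assumed to be `models`):

    from models import model_classes

    ModelClass = model_classes["NeRF"]
    model = ModelClass(args0)  # args0: the model's basic argument dict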
import torch
from torch import Tensor
from modules import *
from utils.type import *
from utils.profile import profile
import importlib
import os
from .utils import *
# Automatically import model files in this directory
package_dir = os.path.dirname(__file__)
package = os.path.basename(package_dir)
for file in os.listdir(package_dir):
path = os.path.join(package_dir, file)
if file.startswith('_') or file.startswith('.') or file == "utils.py":
continue
if file.endswith('.py') or os.path.isdir(path):
model_name = file[:-3] if file.endswith('.py') else file
importlib.import_module(f'{package}.{model_name}')
import json
from typing import Optional
from torch import Tensor
from utils import color
from utils.misc import print_and_log
from utils.samples import Samples
-from utils.module import Module
-from utils.type import NetInput, NetOutput, InputData, ReturnData
-from utils.perf import perf
+from utils.nn import Module
+from utils.types import *
+from utils.profile import profile
model_classes = {}
@@ -39,11 +36,10 @@ class BaseModel(Module, metaclass=BaseModelMeta):
self._preprocess_args()
self._init_chns()
-    def chns(self, name: str, value: int = None) -> Optional[int]:
+    def chns(self, name: str, value: int = None) -> int:
if value is not None:
self._chns[name] = value
-        else:
-            return self._chns.get(name, 1)
+        return self._chns.get(name, 1)
def input(self, samples: Samples, *whats: str) -> NetInput:
all = ["x", "d", "f"]
@@ -65,7 +61,7 @@ class BaseModel(Module, metaclass=BaseModelMeta):
"""
raise NotImplementedError()
-    @perf
+    @profile
def forward(self, data: InputData, *outputs: str, **extra_args) -> ReturnData:
"""
Perform rendering for given rays.
@@ -82,7 +78,7 @@ class BaseModel(Module, metaclass=BaseModelMeta):
return ret
def print_config(self):
-        print_and_log(json.dumps(self.args))
+        return json.dumps(self.args)
def _preprocess_args(self):
pass
@@ -93,7 +89,7 @@ class BaseModel(Module, metaclass=BaseModelMeta):
self._chns["color"] = color.chns(self.color)
self._chns.update(chns)
-    def _input(self, samples: Samples, what: str) -> Optional[Tensor]:
+    def _input(self, samples: Samples, what: str) -> Tensor | None:
raise NotImplementedError()
def _sample(self, data: InputData, **extra_args) -> Samples:
......
@@ -4,7 +4,7 @@ from .base import BaseModel
from typing import Callable
from .nerf import NeRF
-from utils.voxels import trilinear_interp
+from utils.voxels import linear_interp
class CNeRF(BaseModel):
@@ -19,17 +17,17 @@ class CNeRF(BaseModel):
self.corner_indices, self.corners = space.get_corners(vidxs)
self.feats_on_corners = feats_fn(self.corners)
-    @perf
+    @profile
def interp(self, samples: Samples) -> Tensor:
-        with perf("Prepare for coarse interpolation"):
+        with profile("Prepare for coarse interpolation"):
voxels = self.space.voxels[samples.interp_vidxs]
cidxs = self.corner_indices[samples.interp_vidxs] # (N, 8)
feats_on_corners = self.feats_on_corners[cidxs] # (N, 8, X)
# (N, 3) normed-coords in voxel
p = (samples.pts - voxels) / self.space.voxel_size + .5
-        with perf("Interpolate features"):
-            return trilinear_interp(p, feats_on_corners)
+        with profile("Interpolate features"):
+            return linear_interp(p, feats_on_corners)
@property
def stage(self):
@@ -62,21 +62,22 @@ class CNeRF(BaseModel):
self.model(stage).space = self.model(stage - 1).space.clone()
self.args0["stage"] = stage
-    @perf
+    @profile
def infer(self, *outputs: str, samples: Samples, inputs: NetInput = None, **kwargs) -> NetOutput:
inputs = inputs or self.input(samples)
return self.model(samples.level).infer(*outputs, samples=samples, inputs=inputs, **kwargs)
def print_config(self):
+        s = f"{len(self.sub_models)} levels:\n"
        for i, model in enumerate(self.sub_models):
-            print(f"Model {i} =====>")
-            model.print_config()
+            s += f"Model {i}: {model.print_config()}\n"
+        return s
@torch.no_grad()
def split(self):
return self.model(self.stage).split()
-    def _input(self, samples: Samples, what: str) -> Optional[Tensor]:
+    def _input(self, samples: Samples, what: str) -> Tensor | None:
if what == "f":
if samples.level == 0:
return None
@@ -86,7 +87,7 @@ class CNeRF(BaseModel):
else:
return self.model(samples.level)._input(samples, what)
-    @perf
+    @profile
def _sample(self, data: InputData, **extra_args) -> Samples:
samples: Samples = self.model(data["level"])._sample(data, **extra_args)
samples.level = data["level"]
@@ -95,7 +96,7 @@ class CNeRF(BaseModel):
# samples.voxel_indices, data["rays_d"])
return samples
-    @perf
+    @profile
def _render(self, samples: Samples, *outputs: str, **extra_args) -> ReturnData:
self._prepare_interp(samples, on_coarse=self.args.get("interp_on_coarse"))
return self.model(samples.level).renderer(self, samples, *outputs, **{
......
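For reference, trilinear interpolation of corner features (what `linear_interp` presumably does in the 3D case; this is a sketch assuming `p` holds normalized in-voxel coordinates in [0, 1]^3 and the 8 corners are ordered by binary index) can be written as:

    import torch
    from torch import Tensor

    def trilinear_interp_sketch(p: Tensor, feats_on_corners: Tensor) -> Tensor:
        # p: (N, 3) normalized coordinates; feats_on_corners: (N, 8, X),
        # where corner i lies at offset ((i >> 2) & 1, (i >> 1) & 1, i & 1)
        corners = torch.stack(torch.meshgrid(*([torch.tensor([0., 1.])] * 3)), -1).reshape(8, 3)
        weights = ((1 - corners) * (1 - p[:, None]) + corners * p[:, None]).prod(-1)  # (N, 8)
        return (weights[..., None] * feats_on_corners).sum(1)  # (N, X)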
@@ -22,7 +22,7 @@ class MNeRF(NeRF):
in_chns = self.x_encoder.out_dim + core_params['nf']
return MultiNerf(nets)
-    @perf
+    @profile
def _sample(self, data: InputData, **extra_args) -> Samples:
samples = super()._sample(data, **extra_args)
samples.level = data["level"]
......
from .__common__ import *
from .mnerf import MNeRF
from utils.misc import merge
class MNeRFAdvance(MNeRF):
"""
@@ -19,7 +17,7 @@ class MNeRFAdvance(MNeRF):
return super().split()
def _sample(self, data: InputData, **extra_args) -> Samples:
-        return super()._sample(data, **merge(extra_args, n_samples=self.n_samples(data["level"]) + 1))
+        return super()._sample(data, **(extra_args | {"n_samples": self.n_samples(data["level"]) + 1}))
def _render(self, samples: Samples, *outputs: str, **extra_args) -> ReturnData:
L = samples.level
......
from pathlib import Path
from operator import itemgetter
from modules.input_encoder import FreqEncoder
from .__common__ import *
from .base import BaseModel
from utils import math
from utils.misc import masked_scatter
class NeRF(BaseModel):
TrainerClass = "TrainWithSpace"
SamplerClass = None
RendererClass = None
space: Space | Voxels | Octree
@property
def multi_nets(self) -> int:
return self.args.get("multi_nets", 1)
def __init__(self, args0: dict, args1: dict = None):
"""
Initialize a NeRF model
:param args0 `dict`: basic arguments
:param args1 `dict`: extra arguments, defaults to {}
"""
super().__init__(args0, args1)
# Initialize components
self._init_space()
self._init_encoders()
self._init_core()
self._init_sampler()
self._init_renderer()
@profile
def infer(self, *outputs: str, samples: Samples, inputs: NetInput = None, **kwargs) -> NetOutput:
inputs = inputs or self.input(samples)
if len(self.cores) == 1:
return self.cores[0](inputs, *outputs, samples=samples, **kwargs)
return self._multi_infer(inputs, *outputs, samples=samples, **kwargs)
@torch.no_grad()
def split(self):
ret = self.space.split()
if 'n_samples' in self.args0:
self.args0['n_samples'] *= 2
if 'voxel_size' in self.args0:
self.args0['voxel_size'] /= 2
if "sample_step_ratio" in self.args0:
self.args1["sample_step"] = self.args0["voxel_size"] \
* self.args0["sample_step_ratio"]
if 'sample_step' in self.args0:
self.args0['sample_step'] /= 2
return ret
    def export_onnx(self, path: str | Path, batch_size: int = None):
        self.cores[0].get_exporter().export_onnx(Path(path) / "core_0.onnx", batch_size)
def _preprocess_args(self):
if "sample_step_ratio" in self.args0:
self.args1["sample_step"] = self.args0["voxel_size"] * self.args0["sample_step_ratio"]
if self.args0.get("spherical"):
sample_range = [
1 / self.args0['depth_range'][0],
1 / self.args0['depth_range'][1]
] if 'depth_range' in self.args0 else [1, 0]
rot_range = [[-180, -90], [180, 90]]
self.args1['bbox'] = [
[sample_range[0], math.radians(rot_range[0][0]), math.radians(rot_range[0][1])],
[sample_range[1], math.radians(rot_range[1][0]), math.radians(rot_range[1][1])]
]
self.args1['sample_range'] = sample_range
if not self.args.get("multi_nets"):
if not self.args.get("net_bounds"):
self.register_temp("net_bounds", None)
self.args1["multi_nets"] = 1
else:
self.register_temp("net_bounds", torch.tensor(self.args["net_bounds"]))
self.args1["multi_nets"] = self.net_bounds.size(0)
def _init_chns(self, **chns):
super()._init_chns(**{
"x": self.args.get('n_featdim') or 3,
"d": 3 if self.args.get('encode_d') else 0,
**chns
})
def _init_space(self):
self.space = Space.create(self.args)
if self.args.get('n_featdim'):
self.space.create_embedding(self.args['n_featdim'])
def _init_encoders(self):
if isinstance(self.args["encode_x"], list):
self.x_encoder = InputEncoder.create(self.chns("x"), *self.args["encode_x"])
else:
self.x_encoder = FreqEncoder(self.chns("x"), self.args['encode_x'], cat_input=True)
if self.args.get("encode_d"):
if isinstance(self.args["encode_d"], list):
self.d_encoder = InputEncoder.create(self.chns("d"), *self.args["encode_d"])
else:
self.d_encoder = FreqEncoder(self.chns("d"), self.args['encode_d'], angular=True)
else:
self.d_encoder = None
def _init_core(self):
self.cores = self.create_multiple(self._create_core_unit, self.args.get("multi_nets", 1))
def _init_sampler(self):
if self.SamplerClass is None:
SamplerClass = Sampler
else:
SamplerClass = self.SamplerClass
self.sampler = SamplerClass(**self.args)
def _init_renderer(self):
if self.RendererClass is None:
if self.args.get("core") == "nerfadv":
RendererClass = DensityFirstVolumnRenderer
else:
RendererClass = VolumnRenderer
else:
RendererClass = self.RendererClass
self.renderer = RendererClass(**self.args)
def _create_core_unit(self, core_params: dict = None, **args):
core_params = core_params or self.args["core_params"]
if self.args.get("core") == "nerfadv":
return NerfAdvCore(**{
"x_chns": self.x_encoder.out_dim,
"d_chns": self.d_encoder.out_dim,
"density_chns": self.chns('density'),
"color_chns": self.chns('color'),
**core_params,
**args
})
else:
return NerfCore(**{
"x_chns": self.x_encoder.out_dim,
"density_chns": self.chns('density'),
"color_chns": self.chns('color'),
"d_chns": self.d_encoder.out_dim if self.d_encoder else 0,
**core_params,
**args
})
@profile
def _sample(self, data: InputData, **extra_args) -> Samples:
return self.sampler(*itemgetter("rays_o", "rays_d")(data), self.space,
**self.args | extra_args)
@profile
def _render(self, samples: Samples, *outputs: str, **extra_args) -> ReturnData:
if len(samples.size) == 1:
return self.infer(*outputs, samples=samples)
return self.renderer(self, samples, *outputs, **self.args | extra_args)
def _input(self, samples: Samples, what: str) -> torch.Tensor | None:
if what == "x":
if self.args.get('n_featdim'):
return self._encode("emb", self.space.extract_embedding(
samples.pts, samples.voxel_indices))
else:
return self._encode("x", samples.pts)
elif what == "d":
if self.d_encoder and samples.dirs is not None:
return self._encode("d", samples.dirs)
else:
return None
elif what == "f":
return None
else:
ValueError(f"Don't know how to process input \"{what}\"")
def _encode(self, what: str, val: torch.Tensor) -> torch.Tensor:
if what == "x":
# Normalize x according to the encoder's range requirement using space's bounding box
if self.space.bbox is not None:
val = (val - self.space.bbox[0]) / (self.space.bbox[1] - self.space.bbox[0])
val = val * (self.x_encoder.in_range[1] - self.x_encoder.in_range[0])\
+ self.x_encoder.in_range[0]
return self.x_encoder(val)
elif what == "emb":
return self.x_encoder(val)
elif what == "d":
return self.d_encoder(val)
else:
ValueError(f"Don't know how to encode \"{what}\"")
@profile
def _multi_infer(self, inputs: NetInput, *outputs: str, samples: Samples, **kwargs) -> NetOutput:
ret: NetOutput = {}
for i, core in enumerate(self.cores):
            selector = ((samples.pts >= self.net_bounds[i, 0])
                        & (samples.pts < self.net_bounds[i, 1])).all(-1)  # elementwise &; `and` fails on tensors
partial_ret: NetOutput = core(inputs[selector], *outputs, samples=samples[selector],
**kwargs)
for key, value in partial_ret.items():
if key not in ret:
ret[key] = value.new_zeros(*inputs.shape, value.shape[-1])
ret[key] = masked_scatter(selector, value, ret[key])
return ret
class NSVF(NeRF):
SamplerClass = VoxelSampler
from .__common__ import *
from .nerf import NeRF
from utils.misc import merge
class SnerfFast(NeRF):
def infer(self, *outputs: str, samples: Samples, inputs: NetInput = None, chunk_id: int, **kwargs) -> NetOutput:
inputs = inputs or self.input(samples)
-        ret = self.cores[chunk_id](inputs, *outputs)
        return {
-            key: value.reshape(*inputs.shape, -1)
-            for key, value in ret.items()
+            key: value.reshape(*samples.size, -1)
+            for key, value in self.cores[chunk_id](inputs, *outputs).items()
}
def _preprocess_args(self):
@@ -28,7 +25,7 @@ class SnerfFast(NeRF):
density_chns=self.chns('density') * self.samples_per_part,
color_chns=self.chns('color') * self.samples_per_part)
-    def _input(self, samples: Samples, what: str) -> Optional[torch.Tensor]:
+    def _input(self, samples: Samples, what: str) -> torch.Tensor | None:
if what == "x":
return self._encode("x", samples.pts[..., -self.chns("x"):]).flatten(1, 2)
elif what == "d":
@@ -37,17 +34,25 @@ class SnerfFast(NeRF):
else:
return super()._input(samples, what)
def _encode(self, what: str, val: torch.Tensor) -> torch.Tensor:
if what == "x":
# Normalize x according to the encoder's range requirement using space's bounding box
bbox = self.space.bbox[:, -self.chns("x"):]
val = (val - bbox[0]) / (bbox[1] - bbox[0])
val = val * (self.x_encoder.in_range[1] - self.x_encoder.in_range[0])\
+ self.x_encoder.in_range[0]
return self.x_encoder(val)
return super()._encode(what, val)
def _render(self, samples: Samples, *outputs: str, **extra_args) -> ReturnData:
-        return self._render(samples, *outputs,
-                            **merge(extra_args,
-                                    raymarching_chunk_size_or_sections=[self.samples_per_part]))
+        return super()._render(samples, *outputs,
+                               **extra_args |
+                               {"raymarching_chunk_size_or_sections": [self.samples_per_part]})
def _multi_infer(self, inputs: NetInput, *outputs: str, samples: Samples, chunk_id: int, **kwargs) -> NetOutput:
ret = self.cores[chunk_id](inputs, *outputs)
return {
key: value.reshape(*samples.size, -1)
for key, value in ret.items()
}
def _sample(self, data: InputData, **extra_args) -> Samples:
samples = super()._sample(data, **extra_args)
samples.voxel_indices = 0
return samples
class SnerfFastExport(torch.nn.Module):
......