Commit 6294701e authored by Nianchen Deng

sync

parent 2824f796
from typing import Union
from .train_with_space import TrainWithSpace
import torch
from pathlib import Path
from model.cnerf import CNeRF
from data.loader import DataLoader, MultiScaleDataLoader
from modules import Voxels
from utils.misc import print_and_log
class TrainMultiScale(TrainWithSpace):
model: CNeRF
data_loader: Union[DataLoader, MultiScaleDataLoader]
def __init__(self, model: CNeRF, run_dir: Path, states: dict) -> None:
super().__init__(model, run_dir, states)
self.freeze_epochs = self._arg("freeze_epochs", [])
self.level_by_level = True  # self._arg("level_by_level", False)
def _train_epoch(self):
level = self._check_epoch_matches(self.freeze_epochs, self.epoch)
if level >= 0:
self.model.trigger_stage(level + 1)
cur_level = self.model.stage
if isinstance(self.data_loader, MultiScaleDataLoader):
if self.level_by_level:
self.data_loader.set_active_sub_loaders(cur_level)
else:
self.data_loader.set_active_sub_loaders(slice(cur_level, None))
self._split()
space: Voxels = self.model.model(self.model.stage).space
if self._check_epoch_matches(self.prune_epochs, self.epoch + 1) >= 0:
self.voxel_access = torch.zeros(space.n_voxels, dtype=torch.long, device=space.device)
super(TrainWithSpace, self)._train_epoch()
if self.voxel_access is not None:
before, after = space.prune(self.voxel_access > 0)
print_and_log(f"Prune by weights: {before} -> {after}")
self.voxel_access = None
from modules.sampler import Samples
from modules.space import Octree, Voxels
from .train import Train
import sys
import torch
from pathlib import Path
from typing import List
from modules import Voxels
from model.base import BaseModel
from data.loader import DataLoader
from utils.samples import Samples
from utils.mem_profiler import MemProfiler
from utils.misc import print_and_log
from .base import *
from utils.type import InputData, ReturnData
class TrainWithSpace(Train):
def __init__(self, model: BaseModel, pruning_loop: int = 10000, splitting_loop: int = 10000,
**kwargs) -> None:
super().__init__(model, **kwargs)
self.pruning_loop = pruning_loop
self.splitting_loop = splitting_loop
def __init__(self, model: BaseModel, run_dir: Path, states: dict) -> None:
super().__init__(model, run_dir, states)
self.prune_epochs = [] if self.perf_mode else self._arg("prune_epochs", [])
self.split_epochs = [] if self.perf_mode else self._arg("split_epochs", [])
self.voxel_access = None
#MemProfiler.enable = True
def _train_epoch(self):
if not self.perf_mode:
if self.epoch != 1:
if self.splitting_loop == 1 or self.epoch % self.splitting_loop == 1:
try:
with torch.no_grad():
before, after = self.model.split()
print_and_log(f"Splitting done: {before} -> {after}")
except NotImplementedError:
print_and_log(
"Note: The space does not support the splitting operation; skipping it.")
if self.pruning_loop == 1 or self.epoch % self.pruning_loop == 1:
try:
with torch.no_grad():
# self._prune_voxels_by_densities()
self._prune_voxels_by_weights()
except NotImplementedError:
print_and_log(
"Note: The space does not support the pruning operation; skipping it.")
self._split()
space: Voxels = self.model.space
if self._check_epoch_matches(self.prune_epochs, self.epoch + 1) >= 0:
self.voxel_access = torch.zeros(space.n_voxels, dtype=torch.long, device=space.device)
super()._train_epoch()
if self.voxel_access is not None:
before, after = space.prune(self.voxel_access > 0)
print_and_log(f"Prune by weights: {before} -> {after}")
self.voxel_access = None
# self._prune()
def _forward(self, data: InputData) -> ReturnData:
if self.voxel_access is None:
return super()._forward(data)
out = self.model(data, 'color', 'energies', 'speculars', 'weights', 'samples')
with torch.no_grad():
access_voxels = out['samples'].voxel_indices[out['weights'][..., 0] > 0.01]
self.voxel_access.index_add_(0, access_voxels, torch.ones_like(access_voxels))
return out
@torch.no_grad()
def _split(self):
if self._check_epoch_matches(self.split_epochs) < 0:
return
try:
before, after = self.model.split()
print_and_log(f"Splitting done: {before} -> {after}")
except NotImplementedError:
print_and_log("The space does not support splitting operation. Just skip it.")
@torch.no_grad()
def _prune(self):
if self._check_epoch_matches(self.prune_epochs) < 0:  # -1 means "no match", mirroring _split above
return
try:
# self._prune_voxels_by_densities()
self._prune_voxels_by_weights()
except NotImplementedError:
print_and_log("The space does not support pruning operation. Just skip it.")
def _check_epoch_matches(self, key_epochs: List[int], epoch: int = None):
epoch = epoch if epoch is not None else self.epoch
if epoch == 0 or len(key_epochs) == 0:
return -1
if len(key_epochs) == 1:
return 0 if epoch % key_epochs[0] == 0 else -1
try:
return key_epochs.index(epoch)
except ValueError:
return -1
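# Usage sketch (hypothetical values, not part of this commit): a single-element
# list acts as a period, while a longer list enumerates explicit key epochs.
#   _check_epoch_matches([10], epoch=30)     -> 0   (30 % 10 == 0)
#   _check_epoch_matches([10], epoch=35)     -> -1
#   _check_epoch_matches([5, 20], epoch=20)  -> 1   (position in the list)
#   _check_epoch_matches([5, 20], epoch=7)   -> -1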
def _prune_voxels_by_densities(self):
space: Voxels = self.model.space
@@ -43,7 +81,7 @@ class TrainWithSpace(Train):
@torch.no_grad()
def get_scores(sampled_points: torch.Tensor, sampled_voxel_indices: torch.Tensor) -> torch.Tensor:
densities = self.model.render(
densities = self.model.render_pass(
Samples(sampled_points, None, None, None, sampled_voxel_indices),
'density')
return 1 - (-densities).exp()
@@ -57,28 +95,38 @@ class TrainWithSpace(Train):
], 0) # (M[, ...])
return space.prune(scores > threshold)
def _prune_voxels_by_weights(self):
def _prune_voxels_by_weights(self, data_loader: DataLoader = None):
space: Voxels = self.model.space
voxel_access_counts = torch.zeros(space.n_voxels, dtype=torch.long,
device=space.voxels.device)
data_loader = data_loader or self.data_loader
batch_size = data_loader.batch_size
data_loader.batch_size = 2 ** 12
# Note:
# The first element of `voxel_access_counts` and `vidx_map` is preserved for the 'invalid voxel' index (-1),
# so indices must be offset by 1 when querying these tensors.
voxel_access_counts = torch.zeros(space.n_voxels + 1, dtype=torch.long, device=space.device)
vidx_map = torch.arange(-1, space.n_voxels, device=space.device)
iters_in_epoch = 0
batch_size = self.data_loader.batch_size
self.data_loader.batch_size = 2 ** 14
for _, rays_o, rays_d, _ in self.data_loader:
ret = self.model(rays_o, rays_d,
raymarching_early_stop_tolerance=0,
raymarching_chunk_size_or_sections=None,
perturb_sample=False,
extra_outputs=['weights'])
for data in data_loader:
samples = self.model.sample(data, perturb_sample=False)
rays_mask = vidx_map[samples.voxel_indices + 1].ne(-1).any(dim=-1) # (N)
samples = samples[rays_mask] # (N, P) -> (N', P)
ret = self.model.render(samples, 'weights',
raymarching_early_stop_tolerance=0,
raymarching_chunk_size_or_sections=None)
valid_mask = ret['weights'][..., 0] > 0.01
accessed_voxels = ret['samples'].voxel_indices[valid_mask]
voxel_access_counts.index_add_(0, accessed_voxels, torch.ones_like(accessed_voxels))
accessed_voxels = samples.voxel_indices[valid_mask]
voxel_access_counts.index_add_(0, accessed_voxels + 1, torch.ones_like(accessed_voxels))
# Mark already-counted voxels as invalid so later batches can skip rays that hit nothing new
vidx_map[voxel_access_counts.ne(0)] = -1
iters_in_epoch += 1
percent = iters_in_epoch / len(self.data_loader) * 100
sys.stdout.write(f'Pruning by weights...{percent:.1f}% \r')
self.data_loader.batch_size = batch_size
before, after = space.prune(voxel_access_counts > 0)
print_and_log(f"Prune by weights: {before} -> {after}")
percent = iters_in_epoch / len(data_loader) * 100
n_voxel_accessed = voxel_access_counts.count_nonzero()
sys.stdout.write(f'Pruning by weights...{percent:.1f}% '
f'(Accessed {n_voxel_accessed}, Tot {space.n_voxels}) \r')
data_loader.batch_size = batch_size
before, after = space.prune(voxel_access_counts[1:] > 0)
print_and_log(f"Prune by weights: {before} -> {after} ")
def _prune_voxels_by_voxel_weights(self):
space: Voxels = self.model.space
@@ -88,12 +136,14 @@ class TrainWithSpace(Train):
batch_size = self.data_loader.batch_size
self.data_loader.batch_size = 2 ** 14
iters_in_epoch = 0
for _, rays_o, rays_d, _ in self.data_loader:
ret = self.model(rays_o, rays_d,
raymarching_early_stop_tolerance=0,
raymarching_chunk_size_or_sections=None,
perturb_sample=False,
extra_outputs=['weights'])
for data in self.data_loader:
ret = self.model(**data,
extra_outputs=['weights'],
extra_args={
'raymarching_early_stop_tolerance': 0,
'raymarching_chunk_size_or_sections': None,
'perturb_sample': False
})
self._accumulate_access_count_by_weight(ret['samples'], ret['weights'][..., 0],
voxel_access_counts)
iters_in_epoch += 1
......
import os
import sys
import argparse
import shutil
from typing import Mapping
from utils.constants import TINY_FLOAT
import torch
import torch.optim
import math
import time
from tensorboardX import SummaryWriter
from torch import nn
from numpy.core.numeric import NaN
parser = argparse.ArgumentParser()
# Arguments for train >>>
@@ -49,10 +44,11 @@ print("Set CUDA:%d as current device." % torch.cuda.current_device())
from utils import netio
from utils import misc
from utils import math
from utils import device
from utils import img
from utils import interact
from utils import color
from utils.progress_bar import progress_bar
from utils.perf import Perf
from data.spherical_view_syn import *
@@ -341,8 +337,8 @@ def test():
out['bins'] = out['bins'].permute(0, 3, 1, 2)
if args.output_flags['perf']:
perf_errors = torch.ones(n) * NaN
perf_ssims = torch.ones(n) * NaN
perf_errors = torch.ones(n) * math.nan
perf_ssims = torch.ones(n) * math.nan
if dataset.view_images is not None:
for i in range(n):
perf_errors[i] = loss_func(dataset.view_images[i], out['color'][i]).item()
......
from utils import netio
from pathlib import Path
dir = "/home/dengnc/dvs/data/classroom/_nets/ms_train_t0.8/_cnerfadv_ioc/"
for epochs in range(1, 151):
path = f"{dir}checkpoint_{epochs}.tar"
if not Path(path).exists():
continue
print(f"Update epoch {epochs}")
s = netio.load_checkpoint(path)[0]
args0 = s["args"]
args0_for_submodel = {
key: value for key, value in args0.items()
if key != "sub_models" and key != "interp_on_coarse"
}
for i in range(len(args0["sub_models"])):
args0["sub_models"][i] = {**args0_for_submodel, **args0["sub_models"][i]}
if epochs >= 30:
args0["sub_models"][0]["n_samples"] = 64
elif epochs >= 10:
args0["sub_models"][0]["n_samples"] = 32
if epochs >= 70:
args0["sub_models"][1]["n_samples"] = 128
if epochs >= 120:
args0["sub_models"][2]["n_samples"] = 256
netio.save_checkpoint(s, dir, epochs)
\ No newline at end of file
import math
HUGE_FLOAT = 1e10
TINY_FLOAT = 1e-6
PI = math.pi
NAN = math.nan
E = math.e
\ No newline at end of file
env = None
def get_env():
return env
def set_env(new_env: dict):
global env
env = new_env
@@ -7,8 +7,7 @@ import numpy as np
import torch.nn.functional as nn_f
from typing import List, Tuple, Union
from . import misc
from .constants import *
from . import math
def is_image_file(filename):
"""
@@ -186,7 +185,7 @@ def translate(input: torch.Tensor, offset: Tuple[float, float]) -> torch.Tensor:
def mse2psnr(x):
logfunc = torch.log if isinstance(x, torch.Tensor) else np.log
return -10. * logfunc(x + TINY_FLOAT) / np.log(10.)
return -10. * logfunc(x + math.tiny) / np.log(10.)
def colorize_depthmap(depthmap: torch.Tensor, depth_range, inverse=True, colormap='binary'):
......
import torch
from math import *
huge = 1e10
tiny = 1e-6
def expected_sin(x: torch.Tensor, x_var: torch.Tensor):
"""Estimates mean and variance of sin(z), z ~ N(x, var)."""
# When the variance is wide, shrink sin towards zero.
y = (-.5 * x_var).exp() * x.sin()
y_var = torch.clamp_min(.5 * (1 - (-2 * x_var).exp() * (2 * x).cos()) - y**2, 0)
return y, y_var
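# Sanity check (illustrative): for z ~ N(x, var), E[sin z] = exp(-var/2) * sin(x),
# so with zero variance expected_sin reduces to a plain sine with zero variance:
#   y, y_var = expected_sin(torch.tensor([0.5]), torch.zeros(1))
#   # y == sin(0.5), y_var == 0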
def lift_gaussian(d: torch.Tensor, t_mean: torch.Tensor, t_var: torch.Tensor,
r_var: torch.Tensor, diag: bool):
"""Lift a Gaussian defined along a ray to 3D coordinates."""
mean = d[..., None, :] * t_mean[..., None]
d_sq = d**2
d_mag_sq = torch.sum(d_sq, -1, keepdim=True).clamp_min(1e-10)
if diag:
d_outer_diag = d_sq
null_outer_diag = 1 - d_outer_diag / d_mag_sq
t_cov_diag = t_var[..., None] * d_outer_diag[..., None, :]
xy_cov_diag = r_var[..., None] * null_outer_diag[..., None, :]
cov_diag = t_cov_diag + xy_cov_diag
return mean, cov_diag
else:
d_outer = d[..., :, None] * d[..., None, :]
eye = torch.eye(d.shape[-1], device=d.device)
null_outer = eye - d[..., :, None] * (d / d_mag_sq)[..., None, :]
t_cov = t_var[..., None, None] * d_outer[..., None, :, :]
xy_cov = r_var[..., None, None] * null_outer[..., None, :, :]
cov = t_cov + xy_cov
return mean, cov
def conical_frustum_to_gaussian(d: torch.Tensor, t0: float, t1: float, base_radius: float,
diag: bool, stable: bool = True):
"""Approximate a conical frustum as a Gaussian distribution (mean+cov).
Assumes the ray originates from the origin, and base_radius is the
radius at dist=1. Doesn't assume `d` is normalized.
Args:
d: torch.float32 3-vector, the axis of the cone
t0: float, the starting distance of the frustum.
t1: float, the ending distance of the frustum.
base_radius: float, the scale of the radius as a function of distance.
diag: boolean, whether or not the Gaussian will be diagonal or full-covariance.
stable: boolean, whether or not to use the stable computation described in
the paper (setting this to False will cause catastrophic failure).
Returns:
a Gaussian (mean and covariance).
"""
if stable:
mu = (t0 + t1) / 2
hw = (t1 - t0) / 2
t_mean = mu + (2 * mu * hw**2) / (3 * mu**2 + hw**2)
t_var = (hw**2) / 3 - (4 / 15) * ((hw**4 * (12 * mu**2 - hw**2)) /
(3 * mu**2 + hw**2)**2)
r_var = base_radius**2 * ((mu**2) / 4 + (5 / 12) * hw**2 - 4 / 15 *
(hw**4) / (3 * mu**2 + hw**2))
else:
t_mean = (3 * (t1**4 - t0**4)) / (4 * (t1**3 - t0**3))
r_var = base_radius**2 * (3 / 20 * (t1**5 - t0**5) / (t1**3 - t0**3))
t_mosq = 3 / 5 * (t1**5 - t0**5) / (t1**3 - t0**3)
t_var = t_mosq - t_mean**2
return lift_gaussian(d, t_mean, t_var, r_var, diag)
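# Usage sketch (hypothetical values): approximate one cone segment along a ray
# as a diagonal Gaussian, as mip-NeRF does for its integrated positional encoding.
#   d = torch.tensor([0., 0., 1.])
#   mean, cov_diag = conical_frustum_to_gaussian(d, t0=1.0, t1=1.5,
#                                                base_radius=0.01, diag=True)
#   # mean: expected 3D position of the segment; cov_diag: its per-axis variance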
def cylinder_to_gaussian(d: torch.Tensor, t0: float, t1: float, radius: float, diag: bool):
"""Approximate a cylinder as a Gaussian distribution (mean+cov).
Assumes the ray originates from the origin, and `radius` is the cylinder's
radius. Does not renormalize `d`.
Args:
d: torch.float32 3-vector, the axis of the cylinder
t0: float, the starting distance of the cylinder.
t1: float, the ending distance of the cylinder.
radius: float, the radius of the cylinder
diag: boolean, whether or not the Gaussian will be diagonal or full-covariance.
Returns:
a Gaussian (mean and covariance).
"""
t_mean = (t0 + t1) / 2
r_var = radius**2 / 4
t_var = (t1 - t0)**2 / 12
return lift_gaussian(d, t_mean, t_var, r_var, diag)
@@ -7,9 +7,9 @@ import torch
import glm
import csv
import numpy as np
from typing import List, Union
from typing import List, Tuple, Union
from torch.types import Number
from .constants import *
from . import math
from .device import *
@@ -46,7 +46,7 @@ def glm2torch(val) -> torch.Tensor:
return torch.from_numpy(np.array(val))
def meshgrid(*size: int, normalize: bool = False, swap_dim: bool = False) -> torch.Tensor:
def meshgrid(*size: int, normalize: bool = False, swap_dim: bool = False, device: torch.device = None) -> torch.Tensor:
"""
Generate a mesh grid
@@ -57,7 +57,8 @@ def meshgrid(*size: int, normalize: bool = False, swap_dim: bool = False) -> tor
"""
if len(size) == 1:
size = (size[0], size[0])
y, x = torch.meshgrid(torch.arange(size[0]), torch.arange(size[1]), indexing='ij')
y, x = torch.meshgrid(torch.arange(size[0], device=device),
torch.arange(size[1], device=device))
if normalize:
x.div_(size[1] - 1.)
y.div_(size[0] - 1.)
@@ -65,7 +66,7 @@ def meshgrid(*size: int, normalize: bool = False, swap_dim: bool = False) -> tor
def get_angle(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
angle = -torch.atan(x / y) - (y < 0) * PI + 0.5 * PI
angle = -torch.atan(x / y) - (y < 0) * math.pi + 0.5 * math.pi
return angle
@@ -167,13 +168,6 @@ def masked_scatter(mask: torch.Tensor, value: torch.Tensor, initial: Union[torch
return initial.masked_scatter(mask.reshape(*M_, *repeat(1, len(D_))), value)
def list_epochs(dir: Path, pattern: str) -> List[int]:
prefix = pattern.split("*")[0]
epoch_list = [int(str(path.stem)[len(prefix):]) for path in dir.glob(pattern)]
epoch_list.sort()
return epoch_list
def rename_seqs_with_offset(dir: Path, file_pattern: str, offset: int):
start, end = re.search(r'%0\dd', file_pattern).span()
prefix, suffix = start, len(file_pattern) - end
@@ -185,3 +179,43 @@ def rename_seqs_with_offset(dir: Path, file_pattern: str, offset: int):
seqs.sort(reverse=offset > 0)
for i in seqs:
(dir / (file_pattern % i)).rename(dir / (file_pattern % (i + offset)))
def merge(*args: dict, **kwargs) -> dict:
ret_args = {}
for arg in args:
ret_args.update(arg)
ret_args.update(kwargs)
return ret_args
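# Example (illustrative): later dicts and keyword arguments win on key clashes.
#   merge({"a": 1, "b": 2}, {"b": 3}, c=4)  # -> {"a": 1, "b": 3, "c": 4}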
def union(*tensors: torch.Tensor) -> torch.Tensor:
return torch.cat(tensors, dim=-1)
def split(tensor: torch.Tensor, *sizes: int) -> Tuple[torch.Tensor, ...]:
sizes = list(sizes)
tot_size = sum(sizes)
for i in range(len(sizes)):
if sizes[i] == -1:
sizes[i] = tensor.shape[-1] - tot_size - 1
tot_size = tensor.shape[-1]
break
if tot_size > tensor.shape[-1]:
raise ValueError("The total number of sizes is larger than the last dim of input tensor")
if any([size < 0 for size in sizes]):
raise ValueError("Only one element in sizes could be -1")
if tot_size < tensor.shape[-1]:
sizes = [*sizes, tensor.shape[-1] - tot_size]
return torch.split(tensor, sizes, -1)[:-1]
else:
return torch.split(tensor, sizes, -1)
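# Example of the -1 "fill" semantics above (assumed usage, not from this commit):
# one size may be -1 and absorbs whatever the named sizes leave over, while an
# unnamed trailing remainder is silently dropped.
#   t = torch.randn(8, 10)
#   a, b, c = split(t, 2, -1, 3)  # a: (8, 2), b: (8, 5), c: (8, 3)
#   a, b = split(t, 2, 3)         # a: (8, 2), b: (8, 3); last 5 channels dropped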
def dump_tensors_to_csv(path, *tensors: torch.Tensor, open_mode="w"):
tensors = list(tensors)  # *tensors arrives as a tuple, which forbids item assignment
for i in range(len(tensors)):
if len(tensors[i].shape) == 1:
tensors[i] = tensors[i][:, None]
elif len(tensors[i].shape) > 2:
tensors[i] = tensors[i].flatten(1, -1)
with open(path, open_mode) as fp:
csv_writer = csv.writer(fp)
csv_writer.writerows(torch.cat(tensors, -1).tolist())
\ No newline at end of file
import torch
from collections import OrderedDict
from typing import Any, Optional, Union
class Module(torch.nn.Module):
@property
def cls(self) -> str:
return self._get_name()
def __init__(self):
super().__init__()
self.device = torch.device("cpu")
self._temp = OrderedDict()
self._register_load_state_dict_pre_hook(self._before_load_state_dict)
@staticmethod
def create_multiple(fn, n: int) -> torch.nn.ModuleList:
return torch.nn.ModuleList([fn() for _ in range(n)])
def register_temp(self, name: str, value: Union[torch.Tensor, torch.nn.Module]):
temp = self.__dict__.get('_temp')
if temp is None:
raise AttributeError("cannot assign temp before Module.__init__() call")
if '.' in name:
raise KeyError("temp name can't contain \".\"")
if name == '':
raise KeyError("temp name can't be empty string \"\"")
if hasattr(self, name) and name not in temp:
raise KeyError(f"attribute '{name}' already exists")
if not isinstance(value, (type(None), torch.Tensor, torch.nn.Module)):
raise TypeError(f"cannot assign '{torch.typename(value)}' object to temp '{name}' "
"(torch Tensor, Module or None required)")
if value is not None:
value = value.to(self.device)
temp[name] = value
def freeze(self):
for parameter in self.parameters():
parameter.requires_grad = False
return self
def add_module(self, name: str, module: Optional[torch.nn.Module]) -> None:
if isinstance(module, torch.nn.Module):
module = module.to(self.device)
return super().add_module(name, module)
def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> None:
if isinstance(param, torch.nn.Parameter):
param = param.to(self.device)
return super().register_parameter(name, param)
def register_buffer(self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True) -> None:
if isinstance(tensor, torch.Tensor):
tensor = tensor.to(self.device)
return super().register_buffer(name, tensor, persistent=persistent)
def to(self, *args, **kwargs):
super().to(*args, **kwargs)
target_device = None
try:
target_device = torch.device(kwargs['device'] if 'device' in kwargs else args[0])
except Exception:
pass
if target_device is not None:
def move_to_device(m):
if isinstance(m, Module):
m.device = target_device
self.apply(move_to_device)
return self
def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
ret = super().load_state_dict(state_dict, strict=strict)
def fn(module):
if isinstance(module, Module):
module._after_load_state_dict()
self.apply(fn)
return ret
def __getattr__(self, name: str) -> Union[torch.Tensor, torch.nn.Module, Any]:
temp = self.__dict__.get('_temp')
if temp is not None and name in temp:
return temp[name]
return super().__getattr__(name)
def __setattr__(self, name: str, value: Union[torch.Tensor, torch.nn.Module, Any]) -> None:
if isinstance(value, (torch.Tensor, torch.nn.Module)):
value = value.to(self.device)
temp = self.__dict__.get('_temp')
if temp is not None and name in temp:
if not isinstance(value, (type(None), torch.Tensor, torch.nn.Module)):
raise TypeError("cannot assign '{}' object to temp '{}' "
"(torch Tensor, Module or None required)"
.format(torch.typename(value), name))
temp[name] = value
else:
super().__setattr__(name, value)
def __delattr__(self, name) -> None:
temp = self.__dict__.get('_temp')
if temp is not None and name in temp:
del self._temp[name]
else:
super().__delattr__(name)
def _apply(self, fn):
for key, value in self._temp.items():
if value is not None:
self._temp[key] = fn(value)
return super()._apply(fn)
def _before_load_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys,
unexpected_keys, error_msgs):
pass
def _after_load_state_dict(self) -> None:
pass
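# Minimal usage sketch of the temp mechanism (illustrative, not from this
# commit): temps follow the module across devices like buffers, but live in
# `_temp` instead of the state dict, so they are never checkpointed.
class CacheModule(Module):
    def __init__(self):
        super().__init__()
        self.register_temp("lut", torch.arange(8).float())

m = CacheModule()
assert "lut" not in m.state_dict()  # not persisted
m.to(torch.device("cpu"))           # _apply() moves temps along with the module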
from typing import List, Tuple, Union
import torch
import numpy as np
from utils import device
from pathlib import Path
checkpoint_file_prefix = "checkpoint_"
checkpoint_file_suffix = ".tar"
def get_checkpoint_filename(epoch):
return f"{checkpoint_file_prefix}{epoch}{checkpoint_file_suffix}"
def list_epochs(directory: Union[str, Path]) -> List[int]:
directory = Path(directory)
epoch_list = [
int(file_path.stem[len(checkpoint_file_prefix):])
for file_path in directory.glob(get_checkpoint_filename("*"))
]
epoch_list.sort()
return epoch_list
def load_checkpoint(path: Union[str, Path]) -> Tuple[dict, Path]:
path = Path(path)
if path.suffix != checkpoint_file_suffix:
existed_epochs = list_epochs(path)
if len(existed_epochs) == 0:
raise FileNotFoundError(f"{path} does not contain checkpoint files")
path = path / get_checkpoint_filename(existed_epochs[-1])
return torch.load(path), path
def save_checkpoint(states_dict: dict, directory: Union[str, Path], epoch: int):
torch.save(states_dict, Path(directory) / get_checkpoint_filename(epoch))
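# Example round-trip (hypothetical paths): load the newest checkpoint in a run
# directory, then write it back under the next epoch number.
#   states, path = load_checkpoint("/path/to/run_dir")  # picks the highest epoch
#   save_checkpoint(states, "/path/to/run_dir", list_epochs("/path/to/run_dir")[-1] + 1)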
def log(model):
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
print("%d" % params)
def load(path, model, **extra_models):
print('Load net from %s ...' % path)
whole_dict = torch.load(path, map_location=device.default())
model.load_state_dict(whole_dict['model'])
for key, extra_model in extra_models.items():  # dicts iterate over keys, so unpack items()
if key in whole_dict:
extra_model.load_state_dict(whole_dict[key])
return whole_dict.get('iters', 0)
def save(path, model, iters, print_log=True, **extra_models):
if print_log:
print('Saving net to %s ...' % path)
whole_dict = {
'iters': iters,
'model': model.state_dict(),
}
for key, extra_model in extra_models.items():
whole_dict[key] = extra_model.state_dict()
torch.save(whole_dict, path)
from numpy import average
import torch
import torch.cuda
from typing import Dict, List, OrderedDict
from numpy import average
from typing import Any, Callable, Dict, List, OrderedDict, Union
class Perf(object):
@@ -94,24 +94,60 @@ def enable_perf():
default_perf_object = Perf()
def perf(fn_or_name):
class _PerfWrap(object):
def __init__(self, fn: Callable = None, name: str = None) -> None:
super().__init__()
self.fn = fn
self.name = name
def __call__(self, *args: Any, **kwargs: Any) -> Any:
if self.fn is None and len(args) == 1 and isinstance(args[0], Callable):
self.fn = args[0]
return lambda *args, **kwargs: self(*args, **kwargs)
self.__enter__()
ret = self.fn(*args, **kwargs)
self.__exit__()
return ret
def __enter__(self):
#print(f"Start node \"{self.name or self.fn.__qualname__}\"")
start_node(self.name or self.fn.__qualname__)
return self
def __exit__(self, *args: Any, **kwargs: Any):
#print(f"End node \"{self.name or self.fn.__qualname__}\"")
end_node()
def perf(arg: Union[str, Callable]):
if isinstance(arg, str):
return _PerfWrap(name=arg)
else:
return lambda *args, **kwargs: _PerfWrap(fn=arg)(*args, **kwargs)
def debug_perf(fn_or_name):
if isinstance(fn_or_name, str):
name = fn_or_name
def perf_with_name(fn):
def wrap_perf(*args, **kwargs):
start_node(name)
node = Perf.Node(name)
ret = fn(*args, **kwargs)
end_node()
node.close()
torch.cuda.synchronize()
print(f"Debug Node {name}: {node.duration():.1f}ms")
return ret
return wrap_perf
return perf_with_name
fn = fn_or_name
def wrap_perf(*args, **kwargs):
start_node(fn.__qualname__)
node = Perf.Node(fn.__qualname__)
ret = fn(*args, **kwargs)
end_node()
node.close()
torch.cuda.synchronize()
print(f"Debug Node {fn.__qualname__}: {node.duration():.1f}ms")
return ret
return wrap_perf
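# Usage sketch (illustrative): `perf` accepts either a function or a node name,
# and _PerfWrap doubles as a context manager.
#   @perf
#   def render_chunk(rays): ...
#
#   @perf("render")
#   def render_chunk(rays): ...
#
#   with _PerfWrap(name="render"):
#       do_work()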
......
@@ -2,22 +2,27 @@ import shutil
import sys
import time
from .misc import format_time
from .constants import NAN
last_time = time.time()
begin_time = last_time
begin_time = 0
recent_times = []
def progress_bar(current, total, msg=None, premsg=None, barmsg=None):
global last_time, begin_time
global begin_time, recent_times
current_time = time.time()
if current == 0:
begin_time = time.time() # Reset for new bar.
current_time = time.time()
step_time = current_time - last_time
total_time = current_time - begin_time
last_time = current_time
estimated_time = 0 if current == 0 else total_time / current * (total - current)
recent_times = [begin_time]
total_time = 0
estimated_time = 0
step_time = 0
else:
recent_elapse = current_time - recent_times[0]
step_time = recent_elapse / len(recent_times)
total_time = current_time - begin_time
estimated_time = (total - current) * step_time
recent_times = (recent_times + [current_time])[-100:]
show_opt = int(current_time) % 6 >= 3 and current < total
show_barmsg = barmsg is not None and show_opt
......
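# ETA sketch (illustrative): the bar keeps the timestamps of up to the last 100
# ticks and derives the per-step time from that sliding window, so the estimate
# tracks recent throughput rather than the run-wide average:
#   step_time = (now - recent_times[0]) / len(recent_times)
#   estimated_time = (total - current) * step_time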
import torch
from typing import Optional, Union, Any, List, Tuple
from . import math
class Samples(object):
indices: torch.Tensor
""" Tensor(N[, P], 2)` The unique indices of samples, e.g. (i-th ray, j-th sample)"""
pts: torch.Tensor
"""`Tensor(N[, P], 3)`"""
dirs: torch.Tensor
"""`Tensor(N[, P], 3)`"""
depths: torch.Tensor
"""`Tensor(N[, P])`"""
dists: torch.Tensor
"""`Tensor(N[, P])`"""
voxel_indices: torch.Tensor
"""`Tensor(N[, P])`"""
size: List[int]
"""Size of the samples"""
device: torch.device
"""Device where tensors of this object locate"""
def __init__(self, **_data: Union[torch.Tensor, float, int]) -> None:
super().__init__()
super().__setattr__("_data", _data)
super().__setattr__("size", self.pts.size()[:-1])
super().__setattr__("device", self.pts.device)
def __getitem__(self, index: Union[int, slice, list, tuple, torch.Tensor, None]):
if isinstance(index, torch.Tensor) and index.dtype == torch.bool:
index = index.nonzero(as_tuple=True)
return Samples(**{
key: value[index] if isinstance(value, torch.Tensor) else value
for key, value in self._data.items()
})
def __getattr__(self, __name: str) -> Union[torch.Tensor, Any]:
try:
return self._data[__name]
except KeyError:
return None
def __setattr__(self, __name: str, __value: Any) -> None:
self._data[__name] = __value
def reshape(self, *shape: int):
return Samples(**{
key: value.reshape(*shape, *value.shape[len(self.size):])
if isinstance(value, torch.Tensor) else value
for key, value in self._data.items()
})
def filter_rays(self) -> Optional[torch.Tensor]:
if isinstance(self.voxel_indices, torch.Tensor):
valid_rays_mask = self.voxel_indices.ne(-1).any(dim=-1) # (N)
rays_filter = valid_rays_mask.nonzero(as_tuple=True)[0] # (N) -> (N')
super().__setattr__("_data", {
key: value[rays_filter] if isinstance(value, torch.Tensor) else value
for key, value in self._data.items()
})
super().__setattr__("size", self.pts.size()[:-1])
return rays_filter
return None
def interpolate(self, fine_samples, *values: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
P1 = self.size[-1]
P2 = fine_samples.size[-1]
K = P2 // P1
if K > 1:
# Do interpolation
t1 = self.t # ([N, ]P1)
t2 = fine_samples.t # ([N, ]P2)
lo = torch.arange(P1, device=fine_samples.device).repeat_interleave(K)[:P2]
up = (lo + 1).clamp(max=P1 - 1)
t1_lo, t1_up = t1[..., lo], t1[..., up]
k = ((t2 - t1_lo) / (t1_up - t1_lo + math.tiny))[..., None] # ([N, ]P2, 1)
values = [
value[..., lo, :] * (1 - k) + value[..., up, :] * k # ([N, ]P2, X)
for value in values
]
return values[0] if len(values) == 1 else tuple(values)
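# Worked sketch of the interpolation above (assumed sizes): with P1 = 2 coarse
# samples expanded to P2 = 4 fine samples (K = 2),
#   lo = [0, 0, 1, 1], up = [1, 1, 1, 1],
# so each fine value blends its bracketing coarse values with weight
# k = (t2 - t1_lo) / (t1_up - t1_lo); the trailing samples clamp to the last
# coarse interval.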
import numpy as np
from .constants import *
from . import math
def helix(t_range, loops, n):
n_per_loop = n // loops
angles = np.linspace(0, 2 * PI, n_per_loop, endpoint=False)[None, :]. \
angles = np.linspace(0, 2 * math.pi, n_per_loop, endpoint=False)[None, :]. \
repeat(loops, axis=0).flatten()
centers = np.empty([n, 3])
centers[:, 0] = 0.5 * t_range[0] * np.cos(angles)
@@ -18,9 +17,9 @@ def helix(t_range, loops, n):
def scan_around(t_range, circles, n):
angles = np.linspace(-PI, PI, n // circles, endpoint=False)
angles = np.linspace(-math.pi, math.pi, n // circles, endpoint=False)
x_rots = angles[None, :].repeat(circles, axis=0).flatten()
c_angles = angles + 0.8 * PI
c_angles = angles + 0.8 * math.pi
centers = np.empty([n, 3])
for i in range(circles):
r = (0.5 * t_range[0] / circles * (i + 1),
@@ -28,16 +27,16 @@ def scan_around(t_range, circles, n):
0.5 * t_range[2])
s = slice(i * len(angles), (i + 1) * len(angles))
centers[s, 0] = r[0] * np.sin(c_angles)
centers[s, 1] = r[1] * np.sin(angles * 10 + i * 2 * PI / circles)
centers[s, 1] = r[1] * np.sin(angles * 10 + i * 2 * math.pi / circles)
centers[s, 2] = r[2] * np.cos(c_angles)
rots = np.stack([x_rots, np.zeros_like(x_rots)], axis=1)
return centers, rots
def look_around(t_range, n):
angles = np.linspace(-PI, PI, n, endpoint=False)
angles = np.linspace(-math.pi, math.pi, n, endpoint=False)
x_rots = angles
c_angles = angles + 0.8 * PI
c_angles = angles + 0.8 * math.pi
centers = np.empty([n, 3])
r = (0.5 * t_range[0], 0.5 * t_range[1], 0.5 * t_range[2])
centers[:, 0] = r[0] * np.sin(c_angles)
......
from typing import Union
import torch
import math
from . import math
from . import misc
......
from typing import Any, Dict, Union
import torch
InputData = Dict[str, Union[torch.Tensor, Any]]
ReturnData = Dict[str, Union[torch.Tensor, Any]]
NetOutput = Dict[str, torch.Tensor]
class NetInput:
def __init__(self, x: torch.Tensor = None, d: torch.Tensor = None, f: torch.Tensor = None) -> None:
self.x = x
self.d = d
self.f = f
if x is not None:
self.shape = x.shape[:-1]
elif d is not None:
self.shape = d.shape[:-1]
else:
self.shape = [0]
def __getitem__(self, index: Union[int, slice, list, tuple, torch.Tensor, None]) -> 'NetInput':
if isinstance(index, torch.Tensor) and index.dtype == torch.bool:
index = index.nonzero(as_tuple=True)
return NetInput(
self.x[index] if self.x is not None else None,
self.d[index] if self.d is not None else None,
self.f[index] if self.f is not None else None
)
from typing import List, Mapping, Tuple, Union
from numpy import ones_like
import torch
import math
import glm
from . import misc
from . import math
def fov2length(angle):
......
import torch
from typing import Tuple, Union
from . import math
def get_grid_steps(bbox: torch.Tensor, step_size: Union[torch.Tensor, float]) -> torch.Tensor:
"""
@@ -13,6 +15,18 @@ def get_grid_steps(bbox: torch.Tensor, step_size: Union[torch.Tensor, float]) ->
return ((bbox[1] - bbox[0]) / step_size).ceil().long()
def get_out_of_bound_mask(pts: torch.Tensor, bbox: torch.Tensor) -> torch.Tensor:
"""
Get a mask tensor indicating which elements in `pts` fall outside the bounding box `bbox`
:param pts `Tensor(N..., D)`: points
:param bbox `Tensor(2, D)`: bounding box
:return `Tensor(N...)`: a mask tensor
"""
k = (pts - bbox[0]) / (bbox[1] - bbox[0])
return torch.logical_or(k < -math.tiny, k > 1 + math.tiny).any(-1)
def to_flat_indices(grid_coords: torch.Tensor, steps: torch.Tensor) -> torch.Tensor:
indices = grid_coords[..., 0]
for i in range(1, grid_coords.shape[-1]):
@@ -20,9 +34,7 @@ def to_flat_indices(grid_coords: torch.Tensor, steps: torch.Tensor) -> torch.Ten
return indices
def to_grid_coords(pts: torch.Tensor, bbox: torch.Tensor, *,
step_size: Union[torch.Tensor, float] = None,
steps: torch.Tensor = None) -> torch.Tensor:
def to_grid_coords(pts: torch.Tensor, bbox: torch.Tensor, steps: torch.Tensor) -> torch.Tensor:
"""
Get discretized (integer) grid coordinates of points.
@@ -32,18 +44,13 @@ def to_grid_coords(pts: torch.Tensor, bbox: torch.Tensor, *,
:param pts `Tensor(N..., D)`: points
:param bbox `Tensor(2, D)`: bounding box
:param step_size `Tensor(1|D) | float`: (optional) step size
:param steps `Tensor(1|D)`: (optional) steps along every dim
:param steps `Tensor(1|D)`: steps along every dim
:return `Tensor(N..., D)`: discretized grid coordinates
"""
if step_size is not None:
return ((pts - bbox[0]) / step_size).floor().long()
return ((pts - bbox[0]) / (bbox[1] - bbox[0]) * steps).floor().long()
def to_grid_indices(pts: torch.Tensor, bbox: torch.Tensor, *,
step_size: Union[torch.Tensor, float] = None,
steps: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
def to_grid_indices(pts: torch.Tensor, bbox: torch.Tensor, steps: torch.Tensor) -> torch.Tensor:
"""
Get flattened grid indices of points.
@@ -53,46 +60,34 @@ def to_grid_indices(pts: torch.Tensor, bbox: torch.Tensor, *,
:param pts `Tensor(N..., D)`: points
:param bbox `Tensor(2, D)`: bounding box
:param step_size `Tensor(1|D) | float`: (optional) step size
:param steps `Tensor(1|D)`: (optional) steps along every dim
:return `Tensor(N...)`: grid indices
:return `Tensor(N...)`: a mask tensor indicating whether each point is outside the bbox
"""
if step_size is not None:
steps = get_grid_steps(bbox, step_size) # (D)
grid_coords = to_grid_coords(pts, bbox, step_size=step_size, steps=steps) # (N..., D)
outside_mask = torch.logical_or(grid_coords < 0, grid_coords >= steps).any(-1) # (N...)
:param steps `Tensor(1|D)`: steps along every dim
:return `Tensor(N...)`: flattened grid indices
"""
grid_coords = to_grid_coords(pts, bbox, steps).minimum(steps - 1) # (N..., D)
outside_mask = get_out_of_bound_mask(pts, bbox) # (N...)
grid_indices = to_flat_indices(grid_coords, steps)
return grid_indices, outside_mask
grid_indices[outside_mask] = -1
return grid_indices
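# Standalone sketch (assumed 2D values, hypothetical data) of the convention
# above: in-bound points get flattened grid indices, out-of-bound points get -1.
#   bbox = torch.tensor([[0., 0.], [1., 1.]])
#   steps = torch.tensor([4, 4])
#   pts = torch.tensor([[0.05, 0.05], [1.5, 0.5]])
#   to_grid_indices(pts, bbox, steps)  # -> tensor([0, -1])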
def init_voxels(bbox: torch.Tensor, steps: torch.Tensor):
"""
Initialize voxels.
"""
x, y, z = torch.meshgrid(*[torch.arange(steps[i]) for i in range(3)], indexing="ij")
return to_voxel_centers(torch.stack([x, y, z], -1).reshape(-1, 3), bbox, steps=steps)
x, y, z = torch.meshgrid(*[torch.arange(steps[i]) for i in range(3)])
return to_voxel_centers(torch.stack([x, y, z], -1).reshape(-1, 3), bbox, steps)
def to_voxel_centers(grid_coords: torch.Tensor, bbox: torch.Tensor, *,
step_size: Union[torch.Tensor, float] = None,
steps: torch.Tensor = None) -> torch.Tensor:
def to_voxel_centers(grid_coords: torch.Tensor, bbox: torch.Tensor, steps: torch.Tensor) -> torch.Tensor:
"""
Get discretized (integer) grid coordinates of points.
Get center positions of grids.
At least one of the parameters `step_size` and `steps` should be specified. If `step_size` is
specified, then the grid coordinates will be calculated according to the step size, ignoring
the value of `steps`.
:param pts `Tensor(N..., D)`: points
:param grid_coords `Tensor(N..., D)`: grid coordinates
:param bbox `Tensor(2, D)`: bounding box
:param step_size `Tensor(1|D) | float`: (optional) step size
:param steps `Tensor(1|D)`: (optional) steps along every dim
:return `Tensor(N..., D)`: discretized grid coordinates
:param steps `Tensor(1|D)`: steps along every dim
:return `Tensor(N..., D)`: center positions of grids
"""
grid_coords = grid_coords.float() + .5
if step_size is not None:
return grid_coords * step_size + bbox[0]
return grid_coords / steps * (bbox[1] - bbox[0]) + bbox[0]
@@ -115,7 +110,7 @@ def split_voxels_local(voxel_size: Union[torch.Tensor, float], n: int, align_bor
dtype = like.dtype
device = like.device
c = torch.arange(1 - n, n, 2, dtype=dtype, device=device)
offset = torch.stack(torch.meshgrid([c] * dims, indexing='ij'), -1).flatten(0, -2)\
offset = torch.stack(torch.meshgrid([c] * dims), -1).flatten(0, -2)\
* voxel_size * .5 / (n - 1 if align_border else n)
return offset
@@ -141,14 +136,14 @@ def get_corners(voxel_centers: torch.Tensor, bbox: torch.Tensor, steps: torch.Te
expand_bbox = bbox.clone()
expand_bbox[0] -= 0.5 * half_voxel_size
expand_bbox[1] += 0.5 * half_voxel_size
double_grid_coords = to_grid_coords(voxel_centers, expand_bbox, step_size=half_voxel_size)
double_grid_coords = to_grid_coords(voxel_centers, expand_bbox, steps * 2 + 1)
# (M, 3) -> [1, 3, 5, ...]
corner_coords = split_voxels(double_grid_coords, 2, 2).reshape(-1, 3)
# (8M, 3) -> [0, 2, 4, ...]
corner_coords, corner_indices = corner_coords.unique(dim=0, sorted=True, return_inverse=True)
corners = to_voxel_centers(corner_coords, expand_bbox, step_size=half_voxel_size)
corners = to_voxel_centers(corner_coords, expand_bbox, steps * 2 + 1)
return corners, corner_indices.reshape(-1, 8)
@@ -163,6 +158,7 @@ def trilinear_interp(pts: torch.Tensor, corner_values: torch.Tensor) -> torch.Te
"""
pts = pts[:, None] # (N, 1, 3)
corners = split_voxels_local(1, 2, like=pts) + 0.5 # (8, 3)
weights = (pts * corners * 2 - pts - corners + 1).prod(-1, keepdim=True) # (N, 8, 1)
corner_values = corner_values.reshape(corner_values.size(0), 8, -1) # (N, 8, X)
weights = (pts * corners * 2 - pts - corners + 1).prod(-1, keepdim=True) # (N, 8, 1)
return (weights * corner_values).sum(1)
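# Quick check of the weight formula above (illustrative): per axis, with
# p, c in {0, 1}, the factor 2*p*c - p - c + 1 equals 1 when p == c and 0
# otherwise, so a point lying exactly on a corner receives weight 1 for that
# corner and 0 for all others -- plain trilinear interpolation in unit-cube
# coordinates.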