Commit 5699ccbf authored by Nianchen Deng

sync

parent 338ae906
from modules.sampler import Samples
from modules.space import Octree, Voxels
from utils.mem_profiler import MemProfiler
from utils.misc import print_and_log
from .base import *
class TrainWithSpace(Train):

    def __init__(self, model: BaseModel, pruning_loop: int = 10000, splitting_loop: int = 10000,
                 **kwargs) -> None:
        super().__init__(model, **kwargs)
        self.pruning_loop = pruning_loop
        self.splitting_loop = splitting_loop
        #MemProfiler.enable = True

    def _train_epoch(self):
        if not self.perf_mode:
            if self.epoch != 1:
                if self.splitting_loop == 1 or self.epoch % self.splitting_loop == 1:
                    try:
                        with torch.no_grad():
                            before, after = self.model.splitting()
                        print_and_log(
                            f"Splitting done. # of voxels before: {before}, after: {after}")
                    except NotImplementedError:
                        print_and_log(
                            "Note: the space does not support the splitting operation, skip it.")
                if self.pruning_loop == 1 or self.epoch % self.pruning_loop == 1:
                    try:
                        with torch.no_grad():
                            #before, after = self.model.pruning()
                            # print(f"Pruning by voxel densities done. # of voxels before: {before}, after: {after}")
                            # self._prune_inner_voxels()
                            self._prune_voxels_by_weights()
                    except NotImplementedError:
                        print_and_log(
                            "Note: the space does not support the pruning operation, skip it.")
        super()._train_epoch()

    def _prune_inner_voxels(self):
        space: Voxels = self.model.space
        voxel_access_counts = torch.zeros(space.n_voxels, dtype=torch.long,
                                          device=space.voxels.device)
        iters_in_epoch = 0
        batch_size = self.data_loader.batch_size
        self.data_loader.batch_size = 2 ** 14
        for _, rays_o, rays_d, _ in self.data_loader:
            self.model(rays_o, rays_d,
                       raymarching_early_stop_tolerance=0.01,
                       raymarching_chunk_size_or_sections=[1],
                       perturb_sample=False,
                       voxel_access_counts=voxel_access_counts,
                       voxel_access_tolerance=0)
            iters_in_epoch += 1
            percent = iters_in_epoch / len(self.data_loader) * 100
            sys.stdout.write(f'Pruning inner voxels...{percent:.1f}% \r')
        self.data_loader.batch_size = batch_size
        before, after = space.prune(voxel_access_counts > 0)
        print(f"Prune inner voxels: {before} -> {after}")

    def _prune_voxels_by_weights(self):
        space: Voxels = self.model.space
        voxel_access_counts = torch.zeros(space.n_voxels, dtype=torch.long,
                                          device=space.voxels.device)
        iters_in_epoch = 0
        batch_size = self.data_loader.batch_size
        self.data_loader.batch_size = 2 ** 14
        for _, rays_o, rays_d, _ in self.data_loader:
            ret = self.model(rays_o, rays_d,
                             raymarching_early_stop_tolerance=0,
                             raymarching_chunk_size_or_sections=None,
                             perturb_sample=False,
                             extra_outputs=['weights'])
            valid_mask = ret['weights'][..., 0] > 0.01
            accessed_voxels = ret['samples'].voxel_indices[valid_mask]
            voxel_access_counts.index_add_(0, accessed_voxels, torch.ones_like(accessed_voxels))
            iters_in_epoch += 1
            percent = iters_in_epoch / len(self.data_loader) * 100
            sys.stdout.write(f'Pruning by weights...{percent:.1f}% \r')
        self.data_loader.batch_size = batch_size
        before, after = space.prune(voxel_access_counts > 0)
        print_and_log(f"Prune by weights: {before} -> {after}")

    def _prune_voxels_by_voxel_weights(self):
        space: Voxels = self.model.space
        voxel_access_counts = torch.zeros(space.n_voxels, dtype=torch.long,
                                          device=space.voxels.device)
        with torch.no_grad():
            batch_size = self.data_loader.batch_size
            self.data_loader.batch_size = 2 ** 14
            iters_in_epoch = 0
            for _, rays_o, rays_d, _ in self.data_loader:
                ret = self.model(rays_o, rays_d,
                                 raymarching_early_stop_tolerance=0,
                                 raymarching_chunk_size_or_sections=None,
                                 perturb_sample=False,
                                 extra_outputs=['weights'])
                self._accumulate_access_count_by_weight(ret['samples'], ret['weights'][..., 0],
                                                        voxel_access_counts)
                iters_in_epoch += 1
                percent = iters_in_epoch / len(self.data_loader) * 100
                sys.stdout.write(f'Pruning by voxel weights...{percent:.1f}% \r')
            self.data_loader.batch_size = batch_size
        before, after = space.prune(voxel_access_counts > 0)
        print_and_log(f"Prune by voxel weights: {before} -> {after}")

    def _accumulate_access_count_by_weight(self, samples: Samples, weights: torch.Tensor,
                                           voxel_access_counts: torch.Tensor):
        uni_vidxs = -torch.ones_like(samples.voxel_indices)
        vidx_accu = torch.zeros_like(samples.voxel_indices, dtype=torch.float)
        uni_vidxs_row = torch.arange(samples.size[0], dtype=torch.long, device=samples.device)
        uni_vidxs_head = torch.zeros_like(samples.voxel_indices[:, 0])
        uni_vidxs[:, 0] = samples.voxel_indices[:, 0]
        vidx_accu[:, 0].add_(weights[:, 0])
        for i in range(samples.size[1]):
            # For rows whose voxel changed, move the head one step forward
            next_voxel = uni_vidxs[uni_vidxs_row, uni_vidxs_head].ne(samples.voxel_indices[:, i])
            uni_vidxs_head[next_voxel].add_(1)
            # Set voxel indices and accumulate weights
            uni_vidxs[uni_vidxs_row, uni_vidxs_head] = samples.voxel_indices[:, i]
            vidx_accu[uni_vidxs_row, uni_vidxs_head].add_(weights[:, i])
        max_accu = vidx_accu.max(dim=1, keepdim=True)[0]
        uni_vidxs[vidx_accu < max_accu * 0.1] = -1
        access_voxels, access_count = uni_vidxs.unique(return_counts=True)
        voxel_access_counts[access_voxels[1:]].add_(access_count[1:])
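
For intuition, a small self-contained sketch (not part of the commit) of the per-ray rule implemented above: contiguous runs of samples falling in the same voxel are merged, runs whose accumulated weight is below 10% of the ray's strongest run are dropped, and the surviving voxels are counted as accessed.

import torch

voxel_indices = torch.tensor([[3, 3, 7, 7, 7]])            # stands in for samples.voxel_indices
weights = torch.tensor([[0.02, 0.01, 0.30, 0.40, 0.10]])   # stands in for ret['weights'][..., 0]

runs, run_weights = [], []
for v, w in zip(voxel_indices[0].tolist(), weights[0].tolist()):
    if runs and runs[-1] == v:
        run_weights[-1] += w        # accumulate weight within a contiguous run
    else:
        runs.append(v)              # a new voxel run starts here
        run_weights.append(w)
best = max(run_weights)
kept = [v for v, w in zip(runs, run_weights) if w >= 0.1 * best]
print(kept)  # [7] -- voxel 3 accumulates only 0.03 < 0.1 * 0.80, so it would be pruned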
@@ -260,8 +260,8 @@ def train():
     if epochRange.start > 1:
         iters = netio.load(f'{run_dir}model-epoch_{epochRange.start - 1}.pth', model)
     else:
-        misc.create_dir(run_dir)
-        misc.create_dir(log_dir)
+        os.makedirs(run_dir, exist_ok=True)
+        os.makedirs(log_dir, exist_ok=True)
         iters = 0

     # 3. Train
@@ -333,7 +333,7 @@ def test():
     # 4. Save results
     print('Saving results...')
-    misc.create_dir(output_dir)
+    os.makedirs(output_dir, exist_ok=True)
     for key in out:
         shape = [n] + list(dataset.view_res) + list(out[key].size()[1:])
@@ -367,7 +367,7 @@ def test():
         for i in range(n)
     ])
     output_subdir = f"{output_dir}/{output_dataset_id}_bins"
-    misc.create_dir(output_subdir)
+    os.makedirs(output_subdir, exist_ok=True)
     img.save(out['bins'], [f'{output_subdir}/{i:0>4d}.png' for i in dataset.view_idxs])
......
@@ -60,7 +60,7 @@ args.color = color.from_str(args.color)
 def train():
-    misc.create_dir(run_dir)
+    os.makedirs(run_dir, exist_ok=True)
     train_set = UpsamplingDataset('.', 'input/out_view_%04d.png',
                                   'gt/view_%04d.png', color=args.color)
     training_data_loader = FastDataLoader(dataset=train_set,
@@ -80,7 +80,7 @@ def train():
 def test():
-    misc.create_dir(os.path.dirname(args.testOutPatt))
+    os.makedirs(os.path.dirname(args.testOutPatt), exist_ok=True)
     train_set = UpsamplingDataset(
         '.', 'input/out_view_%04d.png', None, color=args.color)
     training_data_loader = FastDataLoader(dataset=train_set,
......
@@ -2,4 +2,6 @@ import math
 HUGE_FLOAT = 1e10
 TINY_FLOAT = 1e-6
 PI = math.pi
+NAN = math.nan
+E = math.e
\ No newline at end of file
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Union
import numpy as np
import torch
import torch.nn.functional as F
INF = 1000.0
def ones_like(x):
    T = torch if isinstance(x, torch.Tensor) else np
    return T.ones_like(x)


def stack(x):
    T = torch if isinstance(x[0], torch.Tensor) else np
    return T.stack(x)


def matmul(x, y):
    T = torch if isinstance(x, torch.Tensor) else np
    return T.matmul(x, y)


def cross(x, y, axis=0):
    T = torch if isinstance(x, torch.Tensor) else np
    return T.cross(x, y, axis)


def cat(x, axis=1):
    if isinstance(x[0], torch.Tensor):
        return torch.cat(x, dim=axis)
    return np.concatenate(x, axis=axis)


def normalize(x, axis=-1, order=2):
    if isinstance(x, torch.Tensor):
        l2 = x.norm(p=order, dim=axis, keepdim=True)
        return x / (l2 + 1e-8), l2
    else:
        l2 = np.linalg.norm(x, order, axis)
        l2 = np.expand_dims(l2, axis)
        l2[l2 == 0] = 1
        return x / l2, l2


def parse_extrinsics(extrinsics, world2camera=True):
    """ this function is only for numpy for now """
    if extrinsics.shape[0] == 3 and extrinsics.shape[1] == 4:
        extrinsics = np.vstack([extrinsics, np.array([[0, 0, 0, 1.0]])])
    if extrinsics.shape[0] == 1 and extrinsics.shape[1] == 16:
        extrinsics = extrinsics.reshape(4, 4)
    if world2camera:
        extrinsics = np.linalg.inv(extrinsics).astype(np.float32)
    return extrinsics


def parse_intrinsics(intrinsics):
    fx = intrinsics[0, 0]
    fy = intrinsics[1, 1]
    cx = intrinsics[0, 2]
    cy = intrinsics[1, 2]
    return fx, fy, cx, cy


def uv2cam(uv, z, intrinsics, homogeneous=False):
    fx, fy, cx, cy = parse_intrinsics(intrinsics)
    x_lift = (uv[0] - cx) / fx * z
    y_lift = (uv[1] - cy) / fy * z
    z_lift = ones_like(x_lift) * z
    if homogeneous:
        return stack([x_lift, y_lift, z_lift, ones_like(z_lift)])
    else:
        return stack([x_lift, y_lift, z_lift])


def cam2world(xyz_cam, inv_RT):
    return matmul(inv_RT, xyz_cam)[:3]


def r6d2mat(d6: torch.Tensor) -> torch.Tensor:
    """
    Converts 6D rotation representation by Zhou et al. [1] to rotation matrices
    using Gram-Schmidt orthogonalisation per Section B of [1].

    Args:
        d6: 6D rotation representation, of size (*, 6)

    Returns:
        batch of rotation matrices of size (*, 3, 3)

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    a1, a2 = d6[..., :3], d6[..., 3:]
    b1 = F.normalize(a1, dim=-1)
    b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
    b2 = F.normalize(b2, dim=-1)
    b3 = torch.cross(b1, b2, dim=-1)
    return torch.stack((b1, b2, b3), dim=-2)
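
A quick sanity check of `r6d2mat` (illustrative, not from the source): the first two rows of the identity rotation, flattened to a 6D vector, should map back to the identity matrix.

import torch

d6 = torch.tensor([1., 0., 0., 0., 1., 0.])   # rows (1,0,0) and (0,1,0) of the identity
R = r6d2mat(d6)
print(R)                                                   # ~ 3x3 identity
print(torch.allclose(R @ R.T, torch.eye(3), atol=1e-6))    # True: the output is orthonormal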
def get_ray_direction(ray_start, uv, intrinsics, inv_RT, depths=None):
    if depths is None:
        depths = 1
    rt_cam = uv2cam(uv, depths, intrinsics, True)
    rt = cam2world(rt_cam, inv_RT)
    ray_dir, _ = normalize(rt - ray_start[:, None], axis=0)
    return ray_dir


def look_at_rotation(camera_position, at=None, up=None, inverse=False, cv=False):
    """
    This function takes a vector `camera_position` which specifies the location
    of the camera in world coordinates and two vectors `at` and `up` which
    indicate the position of the object and the up direction of the world
    coordinate system respectively. The object is assumed to be centered at
    the origin.

    The output is a rotation matrix representing the transformation
    from world coordinates -> view coordinates.

    Input:
        camera_position: 3
        at: 1 x 3 or N x 3, (0, 0, 0) by default
        up: 1 x 3 or N x 3, (0, 1, 0) by default
    """
    if at is None:
        at = torch.zeros_like(camera_position)
    else:
        at = torch.tensor(at).type_as(camera_position)
    if up is None:
        up = torch.zeros_like(camera_position)
        up[2] = -1
    else:
        up = torch.tensor(up).type_as(camera_position)
    z_axis = normalize(at - camera_position)[0]
    x_axis = normalize(cross(up, z_axis))[0]
    y_axis = normalize(cross(z_axis, x_axis))[0]
    R = cat([x_axis[:, None], y_axis[:, None], z_axis[:, None]], axis=1)
    return R


def ray(ray_start, ray_dir, depths):
    return ray_start + ray_dir * depths


def compute_normal_map(ray_start, ray_dir, depths, RT, width=512, proj=False):
    raise NotImplementedError("This function needs fairnr.data.data_utils to work. "
                              "Will remove this dependency later.")
    # TODO:
    # this function is pytorch-only (for now)
    wld_coords = ray(ray_start, ray_dir, depths.unsqueeze(-1)).transpose(0, 1)
    cam_coords = matmul(RT[:3, :3], wld_coords) + RT[:3, 3].unsqueeze(-1)
    cam_coords = D.unflatten_img(cam_coords, width)

    # estimate local normal
    shift_l = cam_coords[:, 2:, :]
    shift_r = cam_coords[:, :-2, :]
    shift_u = cam_coords[:, :, 2:]
    shift_d = cam_coords[:, :, :-2]
    diff_hor = normalize(shift_r - shift_l, axis=0)[0][:, :, 1:-1]
    diff_ver = normalize(shift_u - shift_d, axis=0)[0][:, 1:-1, :]
    normal = cross(diff_hor, diff_ver)
    _normal = normal.new_zeros(*cam_coords.size())
    _normal[:, 1:-1, 1:-1] = normal
    _normal = _normal.reshape(3, -1).transpose(0, 1)

    # compute the projected color
    if proj:
        _normal = normalize(_normal, axis=1)[0]
        wld_coords0 = ray(ray_start, ray_dir, 0).transpose(0, 1)
        cam_coords0 = matmul(RT[:3, :3], wld_coords0) + RT[:3, 3].unsqueeze(-1)
        cam_coords0 = D.unflatten_img(cam_coords0, width)
        cam_raydir = normalize(cam_coords - cam_coords0, 0)[0].reshape(3, -1).transpose(0, 1)
        proj_factor = (_normal * cam_raydir).sum(-1).abs() * 0.8 + 0.2
        return proj_factor
    return _normal


# helper functions for encoder

def padding_points(xs, pad):
    if len(xs) == 1:
        return xs[0].unsqueeze(0)
    maxlen = max([x.size(0) for x in xs])
    xt = xs[0].new_ones(len(xs), maxlen, xs[0].size(1)).fill_(pad)
    for i in range(len(xs)):
        xt[i, :xs[i].size(0)] = xs[i]
    return xt


def pruning_points(feats, points, scores, depth=0, th=0.5):
    if depth > 0:
        g = int(8 ** depth)
        scores = scores.reshape(scores.size(0), -1, g).sum(-1, keepdim=True)
        scores = scores.expand(*scores.size()[:2], g).reshape(scores.size(0), -1)
    alpha = (1 - torch.exp(-scores)) > th
    feats = [feats[i][alpha[i]] for i in range(alpha.size(0))]
    points = [points[i][alpha[i]] for i in range(alpha.size(0))]
    points = padding_points(points, INF)
    feats = padding_points(feats, 0)
    return feats, points


def offset_points(point_xyz: torch.Tensor, half_voxel: Union[torch.Tensor, int, float] = 1,
                  offset_only: bool = False, bits: int = 2) -> torch.Tensor:
    """
    Get offset points inside voxels, or the offset pattern itself.

    :param point_xyz `Tensor(N, 3)`: voxel centers
    :param half_voxel `Tensor(1) | int | float`: half size of a voxel, defaults to 1
    :param offset_only `bool`: if True, return only the offsets, defaults to False
    :param bits `int`: number of samples along each axis, defaults to 2
    :return `Tensor(N, X, 3) | Tensor(X, 3)`: offset points or offsets, where X = bits^3
    """
    c = torch.arange(1 - bits, bits, 2, dtype=point_xyz.dtype, device=point_xyz.device)
    offset = (torch.stack(torch.meshgrid(c, c, c), dim=-1).reshape(-1, 3)) / (bits - 1) * half_voxel
    return offset if offset_only else point_xyz[:, None] + offset
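
Illustrative usage (assumed, not from the source): with the default `bits=2`, `offset_points` yields the eight corner offsets of a voxel.

import torch

centers = torch.tensor([[0., 0., 0.]])
corners = offset_points(centers, half_voxel=0.5)          # Tensor(1, 8, 3): all (+/-0.5, +/-0.5, +/-0.5)
offsets = offset_points(centers, 0.5, offset_only=True)   # Tensor(8, 3): the offsets alone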
def discretize_points(voxel_points, voxel_size):
    # this function turns voxel centers/corners into integer indices
    # we assume all points are already put as voxels (real numbers)
    minimal_voxel_point = voxel_points.min(dim=0, keepdim=True)[0]
    voxel_indices = ((voxel_points - minimal_voxel_point) / voxel_size).round_().long()  # float
    residual = (voxel_points - voxel_indices.type_as(voxel_points)
                * voxel_size).mean(0, keepdim=True)
    return voxel_indices, residual


def expand_points(voxel_points, voxel_size):
    _voxel_size = min([
        torch.sqrt(((voxel_points[j:j + 1] - voxel_points[j + 1:]) ** 2).sum(-1).min())
        for j in range(100)])
    depth = int(np.round(torch.log2(_voxel_size / voxel_size)))
    if depth > 0:
        half_voxel = _voxel_size / 2.0
        for _ in range(depth):
            voxel_points = offset_points(voxel_points, half_voxel / 2.0).reshape(-1, 3)
            half_voxel = half_voxel / 2.0
    return voxel_points, depth


def get_edge(depth_pts, voxel_pts, voxel_size, th=0.05):
    voxel_pts = offset_points(voxel_pts, voxel_size / 2.0)
    diff_pts = (voxel_pts - depth_pts[:, None, :]).norm(dim=2)
    ab = diff_pts.sort(dim=1)[0][:, :2]
    a, b = ab[:, 0], ab[:, 1]
    c = voxel_size
    p = (ab.sum(-1) + c) / 2.0
    h = (p * (p - a) * (p - b) * (p - c)) ** 0.5 / c
    return h < (th * voxel_size)


# fill-in image
def fill_in(shape, hits, input, initial=1.0):
    input_sizes = [k for k in input.size()]
    if (len(input_sizes) == len(shape)) and \
            all([shape[i] == input_sizes[i] for i in range(len(shape))]):
        return input  # shape is the same, no need to fill

    if isinstance(initial, torch.Tensor):
        output = initial.expand(*shape)
    else:
        output = input.new_ones(*shape) * initial
    if input is not None:
        if len(shape) == 1:
            return output.masked_scatter(hits, input)
        return output.masked_scatter(hits.unsqueeze(-1).expand(*shape), input)
    return output
 import os
+from pathlib import Path
 import shutil
 import torch
 import matplotlib.pyplot as plt
 import numpy as np
 import torch.nn.functional as nn_f
-from typing import Tuple
+from typing import List, Tuple, Union

 from . import misc
 from .constants import *
@@ -65,7 +66,7 @@ def load(*paths: str, permute=True, with_alpha=False) -> torch.Tensor:
     chns = 4 if with_alpha else 3
     new_paths = []
     for path in paths:
-        new_paths += [path] if isinstance(path, str) else list(path)
+        new_paths += [path] if isinstance(path, (str, Path)) else list(path)
     imgs = np.stack([plt.imread(path)[..., :chns] for path in new_paths])
     if imgs.dtype == 'uint8':
         imgs = imgs.astype(np.float32) / 255
@@ -76,7 +77,7 @@ def load_seq(path: str, n: int, permute=True, with_alpha=False) -> torch.Tensor:
     return load([path % i for i in range(n)], permute=permute, with_alpha=with_alpha)


-def save(input: torch.Tensor, *paths: str):
+def save(input: torch.Tensor, *paths: Union[str, Path, List[Union[str, Path]]]):
     """
     Save one or multiple torch-image(s) to `paths`
@@ -86,7 +87,7 @@ def save(input: torch.Tensor, *paths: str):
     """
     new_paths = []
     for path in paths:
-        new_paths += [path] if isinstance(path, str) else list(path)
+        new_paths += [path] if isinstance(path, (str, Path)) else list(path)
     if len(input.size()) < 4:
         input = input[None]
     if input.size(0) != len(new_paths):
@@ -100,9 +101,9 @@ def save(input: torch.Tensor, *paths: str):
         plt.imsave(path, np_img[i])


-def save_seq(input: torch.Tensor, path: str):
+def save_seq(input: torch.Tensor, path: Union[str, Path]):
     n = 1 if len(input.size()) <= 3 else input.size(0)
-    return save(input, [path % i for i in range(n)])
+    return save(input, [str(path) % i for i in range(n)])


 def plot(input: torch.Tensor, *, ax: plt.Axes = None):
@@ -118,7 +119,7 @@ def plot(input: torch.Tensor, *, ax: plt.Axes = None):
     return plt.imshow(im) if ax is None else ax.imshow(im)


-def save_video(frames: torch.Tensor, path: str, fps: int,
+def save_video(frames: torch.Tensor, path: Union[str, Path], fps: int,
                repeat: int = 1, pingpong: bool = False):
     """
     Encode and save a sequence of frames as video file
@@ -134,19 +135,16 @@ def save_video(frames: torch.Tensor, path: str, fps: int,
         frames = torch.cat([frames, frames.flip(0)], 0)
     if repeat > 1:
         frames = frames.expand(repeat, -1, -1, -1, -1).flatten(0, 1)
-    dir, file_name = os.path.split(path)
-    if not dir:
-        dir = './'
-    misc.create_dir(dir)
-    cwd = os.getcwd()
-    os.chdir(dir)
-    temp_out_dir = os.path.splitext(file_name)[0] + '_tempout'
-    misc.create_dir(temp_out_dir)
-    os.chdir(temp_out_dir)
-    save_seq(frames, 'out_%04d.png')
-    os.system(f'ffmpeg -y -r {fps:d} -i out_%04d.png -c:v libx264 ../{file_name}')
-    os.chdir(cwd)
-    shutil.rmtree(os.path.join(dir, temp_out_dir))
+    path = Path(path)
+    tempdir = Path('/dev/shm/dvs_tmp/video')
+    inferout = tempdir / path.stem / f"%04d.bmp"
+    os.makedirs(inferout.parent, exist_ok=True)
+    os.makedirs(path.parent, exist_ok=True)
+
+    save_seq(frames, inferout)
+    os.system(f'ffmpeg -y -r {fps:d} -i {inferout} -c:v libx264 {path}')
+    shutil.rmtree(inferout.parent)


 def horizontal_shift(input: torch.Tensor, offset: int, dim=-1) -> torch.Tensor:
......
@@ -2,13 +2,14 @@ from cgitb import enable
 import torch

 from .device import *


 class MemProfiler:
     enable = False

     @staticmethod
-    def print_memory_stats(prefix, last_allocated=None, device=None):
-        if not MemProfiler.enable:
+    def print_memory_stats(prefix, last_allocated=None, device=None, enable_once=False):
+        if not enable_once and not MemProfiler.enable:
             return
         if device is None:
             device = default()
......
-import os
+from itertools import repeat
+import logging
+from pathlib import Path
+import re
+import shutil
 import torch
 import glm
 import csv
 import numpy as np
-from typing import List, Tuple, Union
+from typing import List, Union
 from torch.types import Number

 from .constants import *
 from .device import *
@@ -59,31 +63,11 @@ def meshgrid(*size: int, normalize: bool = False, swap_dim: bool = False) -> tor
     return torch.stack([x / (size[1] - 1.), y / (size[0] - 1.)], 2) if normalize else torch.stack([x, y], 2)


-def create_dir(path):
-    if not os.path.exists(path):
-        os.makedirs(path)
-
-
 def get_angle(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-    angle = -torch.atan(x / y) + (y < 0) * PI + 0.5 * PI
+    angle = -torch.atan(x / y) - (y < 0) * PI + 0.5 * PI
     return angle


-def depth_sample(depth_range: Tuple[float, float], n: int, lindisp: bool) -> torch.Tensor:
-    """
-    Get [n_layers] foreground layers whose diopters are distributed uniformly
-    in [depth_range] plus a background layer
-
-    :param depth_range: depth range of foreground layers
-    :param n_layers: number of foreground layers
-    :return: list of [n_layers+1] depths
-    """
-    if lindisp:
-        depth_range = (1 / depth_range[0], 1 / depth_range[1])
-    samples = torch.linspace(depth_range[0], depth_range[1], n)
-    return samples
-
-
 def broadcast_cat(input: torch.Tensor,
                   s: Union[Number, List[Number], torch.Tensor],
                   dim=-1,
@@ -130,4 +114,73 @@ def view_like(input: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
     return input.view(out_shape)


-def values(map, *keys): return list(map[key] for key in keys)
+def format_time(seconds):
+    days = int(seconds / 3600 / 24)
+    seconds = seconds - days * 3600 * 24
+    hours = int(seconds / 3600)
+    seconds = seconds - hours * 3600
+    minutes = int(seconds / 60)
+    seconds = seconds - minutes * 60
+    seconds_final = int(seconds)
+    seconds = seconds - seconds_final
+    millis = int(seconds * 1000)
+    if days > 0:
+        output = f"{days}D{hours:0>2d}h{minutes:0>2d}m"
+    elif hours > 0:
+        output = f"{hours:0>2d}h{minutes:0>2d}m{seconds_final:0>2d}s"
+    elif minutes > 0:
+        output = f"{minutes:0>2d}m{seconds_final:0>2d}s"
+    elif seconds_final > 0:
+        output = f"{seconds_final:0>2d}s{millis:0>3d}ms"
+    elif millis > 0:
+        output = f"{millis:0>3d}ms"
+    else:
+        output = '0ms'
+    return output
+
+
+def print_and_log(s):
+    print(s)
+    logging.info(s)
+
+
+def masked_scatter(mask: torch.Tensor, value: torch.Tensor, initial: Union[torch.Tensor, Number] = 0):
+    """
+    Extend PyTorch's built-in `masked_scatter` function
+
+    :param mask `Tensor(M...)`: the boolean mask
+    :param value `Tensor(N, D...)`: the value to fill in with, should have at least as many elements
+                                    as the number of ones in `mask`
+    :param initial `Tensor(M..., D...) | Number`: the destination tensor to fill, or a number used to
+                                    initialize a newly created destination tensor, defaults to 0
+    :return `Tensor(M..., D...)`: the destination tensor after filling
+    """
+    M_ = mask.size()
+    D_ = value.size()[1:]
+    if not isinstance(initial, torch.Tensor):
+        initial = value.new_full([*M_, *D_], initial)
+    return initial.masked_scatter(mask.reshape(*M_, *repeat(1, len(D_))), value)
+
+
+def list_epochs(dir: Path, pattern: str) -> List[int]:
+    prefix = pattern.split("*")[0]
+    epoch_list = [int(str(path.stem)[len(prefix):]) for path in dir.glob(pattern)]
+    epoch_list.sort()
+    return epoch_list
+
+
+def rename_seqs_with_offset(dir: Path, file_pattern: str, offset: int):
+    start, end = re.search(r'%0\dd', file_pattern).span()
+    prefix, suffix = start, len(file_pattern) - end
+    seqs = [
+        int(path.name[prefix:-suffix])
+        for path in dir.glob(re.sub(r'%0\dd', "*", file_pattern))
+    ]
+    seqs.sort(reverse=offset > 0)
+    for i in seqs:
+        (dir / (file_pattern % i)).rename(dir / (file_pattern % (i + offset)))
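
A minimal usage sketch of the extended `masked_scatter` added above (illustrative values, not from the source):

import torch

mask = torch.tensor([True, False, True, False])
value = torch.tensor([[1., 1., 1.],
                      [2., 2., 2.]])
print(masked_scatter(mask, value))
# tensor([[1., 1., 1.],
#         [0., 0., 0.],
#         [2., 2., 2.],
#         [0., 0., 0.]])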
+from numpy import average
+import torch
 import torch.cuda
+from typing import Dict, List, OrderedDict


 class Perf(object):
-    def __init__(self, enable, start=False) -> None:
+    frames: List[Dict[str, float]]
+
+    class Node:
+        def __init__(self, name, parent=None) -> None:
+            self.name = name
+            self.parent = parent
+            self.events = []
+            self.event_names = []
+            self.child_nodes = []
+            self.child_nodes_event_idx = []
+            self.add_checkpoint("Start")
+
+        def add_checkpoint(self, name):
+            event = torch.cuda.Event(enable_timing=True)
+            event.record()
+            self.events.append(event)
+            self.event_names.append(name)
+
+        def add_child(self, name):
+            child = Perf.Node(name, self)
+            self.child_nodes.append(child)
+            self.child_nodes_event_idx.append(len(self.events))
+            return child
+
+        def close(self):
+            self.add_checkpoint("End")
+            return self.parent
+
+        def duration(self, i0=0, i1=-1) -> float:
+            return self.events[i0].elapsed_time(self.events[i1])
+
+        def result(self, prefix: str = '') -> OrderedDict[str, float]:
+            path = f"{prefix}{self.name}"
+            res = {path: self.duration()}
+            j = 0
+            for i in range(1, len(self.events) - 1):
+                event_path = f"{path}/{self.event_names[i]}"
+                res[event_path] = self.duration(i - 1, i)
+                while j < len(self.child_nodes):
+                    if self.child_nodes_event_idx[j] > i:
+                        break
+                    res.update(self.child_nodes[j].result(f"{event_path}/"))
+                    j += 1
+            while j < len(self.child_nodes):
+                res.update(self.child_nodes[j].result(f"{path}/"))
+                j += 1
+            return res
+
+    def __init__(self) -> None:
         super().__init__()
-        self.enable = enable
-        self.start_event = None
-        if start:
-            self.start()
+        self.root_node = None
+        self.current_node = None
+        self.frames = []

-    def start(self):
-        if not self.enable:
-            return
-        if self.start_event == None:
-            self.start_event = torch.cuda.Event(enable_timing=True)
-            self.end_event = torch.cuda.Event(enable_timing=True)
-        torch.cuda.synchronize()
-        self.start_event.record()
+    def start_node(self, name):
+        if self.current_node is None:
+            self.root_node = self.current_node = Perf.Node(name)
+        else:
+            self.current_node = self.current_node.add_child(name)

-    def checkpoint(self, name: str = None, end: bool = False):
-        if not self.enable:
-            return 0
-        self.end_event.record()
-        torch.cuda.synchronize()
-        duration = self.start_event.elapsed_time(self.end_event)
-        if name:
-            print('%s: %.1fms' % (name, duration))
-        if not end:
-            self.start_event.record()
-        return duration
+    def checkpoint(self, name):
+        self.current_node.add_checkpoint(name)
+
+    def end_node(self):
+        self.current_node = self.current_node.close()
+        if self.current_node is None:
+            torch.cuda.synchronize()
+            self.frames.append(self.root_node.result())
+
+    def get_result(self, i=None):
+        if i is not None:
+            return self.frames[i]
+        if len(self.frames) == 0:
+            return {}
+        res = {key: [val] for key, val in self.frames[0].items()}
+        for i in range(1, len(self.frames)):
+            for key, val in self.frames[i].items():
+                res[key].append(val)
+        return {key: average(val) for key, val in res.items()}
+
+
+default_perf_object = None
+
+
+def enable_perf():
+    global default_perf_object
+    default_perf_object = Perf()
+
+
+def perf(fn_or_name):
+    if isinstance(fn_or_name, str):
+        name = fn_or_name
+
+        def perf_with_name(fn):
+            def wrap_perf(*args, **kwargs):
+                start_node(name)
+                ret = fn(*args, **kwargs)
+                end_node()
+                return ret
+            return wrap_perf
+        return perf_with_name
+
+    fn = fn_or_name
+
+    def wrap_perf(*args, **kwargs):
+        start_node(fn.__qualname__)
+        ret = fn(*args, **kwargs)
+        end_node()
+        return ret
+    return wrap_perf
+
+
+def start_node(name):
+    if default_perf_object is not None:
+        default_perf_object.start_node(name)
+
+
+def end_node():
+    if default_perf_object is not None:
+        default_perf_object.end_node()
+
+
+def checkpoint(name):
+    if default_perf_object is not None:
+        default_perf_object.checkpoint(name)
+
+
+def get_perf_result(i=None):
+    if default_perf_object is not None:
+        return default_perf_object.get_result(i)
+    return None
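
An assumed usage sketch of the new node-based profiler (the module path `utils.perf` and the tiny `render` function are illustrative; timing requires a CUDA device since `Perf` records `torch.cuda` events):

from utils.perf import enable_perf, perf, checkpoint, get_perf_result

enable_perf()                     # install the global default Perf() instance


@perf("render")                   # or bare @perf to use the wrapped function's __qualname__
def render():
    checkpoint("prepare")         # named checkpoint inside the current node
    checkpoint("shade")


render()
print(get_perf_result())          # e.g. {'render': ..., 'render/prepare': ..., 'render/shade': ...}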
+import shutil
 import sys
 import time
-import os

-bar_length = 50
-LAST_T = time.time()
-BEGIN_T = LAST_T
+from .misc import format_time
+from .constants import NAN

+last_time = time.time()
+begin_time = last_time

-def get_terminal_columns():
-    return os.get_terminal_size().columns

-def progress_bar(current, total, msg=None, premsg=None):
-    global LAST_T, BEGIN_T
+def progress_bar(current, total, msg=None, premsg=None, barmsg=None):
+    global last_time, begin_time
     if current == 0:
-        BEGIN_T = time.time()  # Reset for new bar.
+        begin_time = time.time()  # Reset for new bar.
     current_time = time.time()
-    step_time = current_time - LAST_T
-    LAST_T = current_time
-    total_time = current_time - BEGIN_T
+    step_time = current_time - last_time
+    total_time = current_time - begin_time
+    last_time = current_time
+    estimated_time = 0 if current == 0 else total_time / current * (total - current)
+    show_opt = int(current_time) % 6 >= 3 and current < total
+    show_barmsg = barmsg is not None and show_opt
     str0 = f"{premsg} [" if premsg else '['
-    str1 = f"] {current + 1:d}/{total:d} | Step: {format_time(step_time)} | Tot: {format_time(total_time)}"
+    str1 = f"] {current:d}/{total:d} | Step: {format_time(step_time)} | " + (
+        f"Eta: {format_time(estimated_time)}" if show_opt else f"Tot: {format_time(total_time)}"
+    )
     if msg:
         str1 += f" | {msg}"
-    tot_cols = get_terminal_columns()
+    tot_cols = shutil.get_terminal_size().columns - 10
     bar_length = tot_cols - len(str0) - len(str1)
-    current_len = int(bar_length * (current + 1) / total)
-    rest_len = int(bar_length - current_len)
-    if current_len == 0:
-        str_bar = '.' * rest_len
+    if show_barmsg and bar_length < len(barmsg):
+        sys.stdout.write(str0[:-1] + barmsg)
+    elif bar_length <= 0:
+        sys.stdout.write(str0[:-1] + str1[2:])
     else:
-        str_bar = '=' * (current_len - 1) + '>' + '.' * rest_len
-    sys.stdout.write(str0 + str_bar + str1)
-    if current < total - 1:
-        sys.stdout.write('\r')
-    else:
-        sys.stdout.write('\n')
+        current_len = int(bar_length * current / total)
+        rest_len = int(bar_length - current_len)
+        str_bar = ''
+        if current_len > 0:
+            str_bar += '=' * (current_len - 1) + '>'
+        str_bar += '.' * rest_len
+        if show_barmsg:
+            str_bar = barmsg + str_bar[len(barmsg):]
+        sys.stdout.write(str0 + str_bar + str1)
+    sys.stdout.write('\r' if current < total else '\n')
     sys.stdout.flush()

-
-# return the formatted time
-def format_time(seconds):
-    days = int(seconds / 3600 / 24)
-    seconds = seconds - days * 3600 * 24
-    hours = int(seconds / 3600)
-    seconds = seconds - hours * 3600
-    minutes = int(seconds / 60)
-    seconds = seconds - minutes * 60
-    seconds_final = int(seconds)
-    seconds = seconds - seconds_final
-    millis = int(seconds * 1000)
-    output = ''
-    time_index = 1
-    if days > 0:
-        output += str(days) + 'D'
-        time_index += 1
-    if hours > 0 and time_index <= 2:
-        output += str(hours) + 'h'
-        time_index += 1
-    if minutes > 0 and time_index <= 2:
-        output += str(minutes) + 'm'
-        time_index += 1
-    if seconds_final > 0 and time_index <= 2:
-        output += '%02ds' % seconds_final
-        time_index += 1
-    if millis > 0 and time_index <= 2:
-        output += '%03dms' % millis
-        time_index += 1
-    if output == '':
-        output = '0ms'
-    return output
-from typing import List, Union
+from typing import Union
 import torch
 import math

 from . import misc
@@ -13,12 +13,12 @@ def cartesian2spherical(cart: torch.Tensor, inverse_r: bool = False) -> torch.Te
     :return `Tensor(..., 3)`: coordinates in Spherical (r, theta, phi)
     """
     rho = torch.sqrt(torch.sum(cart * cart, dim=-1))
-    theta = misc.get_angle(cart[..., 0], cart[..., 2])
+    theta = misc.get_angle(cart[..., 2], cart[..., 0])
     if inverse_r:
         rho = rho.reciprocal()
-        phi = torch.acos(cart[..., 1] * rho)
+        phi = torch.asin(cart[..., 1] * rho)
     else:
-        phi = torch.acos(cart[..., 1] / rho)
+        phi = torch.asin(cart[..., 1] / rho)
     return torch.stack([rho, theta, phi], dim=-1)
@@ -34,9 +34,9 @@ def spherical2cartesian(spher: torch.Tensor, inverse_r: bool = False) -> torch.T
         rho = rho.reciprocal()
     sin_theta_phi = torch.sin(spher[..., 1:3])
     cos_theta_phi = torch.cos(spher[..., 1:3])
-    x = rho * cos_theta_phi[..., 0] * sin_theta_phi[..., 1]
-    y = rho * cos_theta_phi[..., 1]
-    z = rho * sin_theta_phi[..., 0] * sin_theta_phi[..., 1]
+    x = rho * sin_theta_phi[..., 0] * cos_theta_phi[..., 1]
+    y = rho * sin_theta_phi[..., 1]
+    z = rho * cos_theta_phi[..., 0] * cos_theta_phi[..., 1]
     return torch.stack([x, y, z], dim=-1)
......
import torch
from typing import Tuple, Union
def get_grid_steps(bbox: torch.Tensor, step_size: Union[torch.Tensor, float]) -> torch.Tensor:
    """
    Get grid steps along every dim.

    :param bbox `Tensor(2, D)`: bounding box
    :param step_size `Tensor(1|D) | float`: step size
    :return `Tensor(D)`: grid steps along every dim
    """
    return ((bbox[1] - bbox[0]) / step_size).ceil().long()


def to_grid_coords(pts: torch.Tensor, bbox: torch.Tensor, *,
                   step_size: Union[torch.Tensor, float] = None,
                   steps: torch.Tensor = None) -> torch.Tensor:
    """
    Get discretized (integer) grid coordinates of points.

    At least one of the parameters `step_size` and `steps` should be specified. If `step_size` is
    specified, the grid coordinates will be calculated according to the step size, ignoring
    the value of `steps`.

    :param pts `Tensor(N..., D)`: points
    :param bbox `Tensor(2, D)`: bounding box
    :param step_size `Tensor(1|D) | float`: (optional) step size
    :param steps `Tensor(1|D)`: (optional) steps along every dim
    :return `Tensor(N..., D)`: discretized grid coordinates
    """
    if step_size is not None:
        return ((pts - bbox[0]) / step_size).floor().long()
    return ((pts - bbox[0]) / (bbox[1] - bbox[0]) * steps).floor().long()
def to_grid_indices(pts: torch.Tensor, bbox: torch.Tensor, *,
                    step_size: Union[torch.Tensor, float] = None,
                    steps: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Get flattened grid indices of points.

    At least one of the parameters `step_size` and `steps` should be specified. If `step_size` is
    specified, the grid indices will be calculated according to the step size, ignoring
    the value of `steps`.

    :param pts `Tensor(N..., D)`: points
    :param bbox `Tensor(2, D)`: bounding box
    :param step_size `Tensor(1|D) | float`: (optional) step size
    :param steps `Tensor(1|D)`: (optional) steps along every dim
    :return `Tensor(N...)`: grid indices
    :return `Tensor(N...)`: a mask tensor indicating whether the corresponding points fall outside the bounding box
    """
    if step_size is not None:
        steps = get_grid_steps(bbox, step_size)  # (D)
    grid_coords = to_grid_coords(pts, bbox, step_size=step_size, steps=steps)  # (N..., D)
    outside_mask = torch.logical_or(grid_coords < 0, grid_coords >= steps).any(-1)  # (N...)
    if pts.size(-1) == 1:
        grid_indices = grid_coords[..., 0]
    elif pts.size(-1) == 2:
        grid_indices = grid_coords[..., 0] * steps[1] + grid_coords[..., 1]
    elif pts.size(-1) == 3:
        grid_indices = grid_coords[..., 0] * steps[1] * steps[2] \
            + grid_coords[..., 1] * steps[2] + grid_coords[..., 2]
    elif pts.size(-1) == 4:
        grid_indices = grid_coords[..., 0] * steps[1] * steps[2] * steps[3] \
            + grid_coords[..., 1] * steps[2] * steps[3] \
            + grid_coords[..., 2] * steps[3] \
            + grid_coords[..., 3]
    else:
        raise NotImplementedError("The function does not support D > 4")
    return grid_indices, outside_mask
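
Illustrative usage (values chosen by hand, not from the source): indexing 2D points into a 4x4 grid over the unit square.

import torch

bbox = torch.tensor([[0., 0.], [1., 1.]])
pts = torch.tensor([[0.1, 0.1], [0.9, 0.3], [1.5, 0.5]])        # the last point lies outside
indices, outside = to_grid_indices(pts, bbox, steps=torch.tensor([4, 4]))
# cell size 0.25: (0.1, 0.1) -> cell (0, 0) -> index 0; (0.9, 0.3) -> cell (3, 1) -> index 13
# outside == tensor([False, False, True])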
def init_voxels(bbox: torch.Tensor, steps: torch.Tensor):
    """
    Initialize voxels.
    """
    x, y, z = torch.meshgrid(*[torch.arange(steps[i]) for i in range(3)])
    return to_voxel_centers(torch.stack([x, y, z], -1).reshape(-1, 3), bbox, steps=steps)


def to_voxel_centers(grid_coords: torch.Tensor, bbox: torch.Tensor, *,
                     step_size: Union[torch.Tensor, float] = None,
                     steps: torch.Tensor = None) -> torch.Tensor:
    """
    Get center positions of grid cells from discretized (integer) grid coordinates.

    At least one of the parameters `step_size` and `steps` should be specified. If `step_size` is
    specified, the centers will be calculated according to the step size, ignoring the value
    of `steps`.

    :param grid_coords `Tensor(N..., D)`: discretized grid coordinates
    :param bbox `Tensor(2, D)`: bounding box
    :param step_size `Tensor(1|D) | float`: (optional) step size
    :param steps `Tensor(1|D)`: (optional) steps along every dim
    :return `Tensor(N..., D)`: voxel centers
    """
    grid_coords = grid_coords.float() + 0.5
    if step_size is not None:
        return grid_coords * step_size + bbox[0]
    return grid_coords / steps * (bbox[1] - bbox[0]) + bbox[0]
def split_voxels_local(voxel_size: Union[torch.Tensor, float], n: int, align_border: bool = True,
                       dims=3, *, dtype: torch.dtype = None, device: torch.device = None,
                       like: torch.Tensor = None):
    """
    Get offsets of subdivided voxels' centers relative to the center of a parent voxel.

    :param voxel_size `Tensor(D)|float`: size of the parent voxel
    :param n `int`: number of subdivisions along every dim
    :param align_border `bool`: whether the outermost offsets align with the voxel border, defaults to True
    :param dims `int`: number of dimensions, defaults to 3
    :param dtype `dtype`: data type of the output, defaults to None
    :param device `device`: device of the output, defaults to None
    :param like `Tensor(*)`: if specified, the output takes this tensor's dtype and device
    :return `Tensor(X, D)`: offsets, where X = n^dims
    """
    if like is not None:
        dtype = like.dtype
        device = like.device
    c = torch.arange(1 - n, n, 2, dtype=dtype, device=device)
    offset = torch.stack(torch.meshgrid([c] * dims), -1).flatten(0, -2) * voxel_size / 2 / \
        (n - 1 if align_border else n)
    return offset


def split_voxels(voxel_centers: torch.Tensor, voxel_size: Union[torch.Tensor, float], n: int,
                 align_border: bool = True):
    """
    Subdivide voxels into n^D child voxels and return their center positions.

    :param voxel_centers `Tensor(N, D)`: centers of the voxels to split
    :param voxel_size `Tensor(D)|float`: size of the voxels to split
    :param n `int`: number of subdivisions along every dim
    :param align_border `bool`: whether the outermost centers align with the voxel border, defaults to True
    :return `Tensor(N, X, D)`: centers of child voxels, where X = n^D
    """
    return voxel_centers[:, None] + split_voxels_local(
        voxel_size, n, align_border, voxel_centers.shape[-1], like=voxel_centers)


def get_corners(voxel_centers: torch.Tensor, bbox: torch.Tensor, steps: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    half_voxel_size = (bbox[1] - bbox[0]) / steps * 0.5
    expand_bbox = bbox
    expand_bbox[0] -= 0.5 * half_voxel_size
    expand_bbox[1] += 0.5 * half_voxel_size
    double_grid_coords = to_grid_coords(voxel_centers, expand_bbox, step_size=half_voxel_size)
    # (M, 3) -> [1, 3, 5, ...]
    corner_coords = split_voxels(double_grid_coords, 2, 2).reshape(-1, 3)
    # (8M, 3) -> [0, 2, 4, ...]
    corner_coords, corner_indices = corner_coords.unique(dim=0, sorted=True, return_inverse=True)
    corners = to_voxel_centers(corner_coords, expand_bbox, step_size=half_voxel_size)
    return corners, corner_indices.reshape(-1, 8)
def trilinear_interp(pts: torch.Tensor, corner_values: torch.Tensor) -> torch.Tensor:
    """
    Perform trilinear interpolation in unit voxels ([0,0,0] ~ [1,1,1]).

    :param pts `Tensor(N, 3)`: uniform coordinates in voxels
    :param corner_values `Tensor(N, 8X)|Tensor(N, 8, X)`: values at corners of voxels
    :return `Tensor(N, X)`: interpolated values
    """
    pts = pts[:, None]  # (N, 1, 3)
    corners = split_voxels_local(1, 2, like=pts) + 0.5  # (8, 3)
    weights = (pts * corners * 2 - pts - corners + 1).prod(-1, keepdim=True)  # (N, 8, 1)
    corner_values = corner_values.reshape(corner_values.size(0), 8, -1)  # (N, 8, X)
    return (weights * corner_values).sum(1)
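
A small sanity check of `trilinear_interp` (illustrative, not from the source): with corner values equal to each corner's x coordinate, interpolation reproduces the query point's x coordinate.

import torch

pts = torch.tensor([[0.25, 0.5, 0.75]])
corners = split_voxels_local(1, 2, like=pts) + 0.5    # (8, 3): corners of the unit cube
corner_values = corners[None, :, 0:1]                 # (1, 8, 1): value at each corner = its x
print(trilinear_interp(pts, corner_values))           # ~ tensor([[0.2500]])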