Commit 1bc644a1 authored by Nianchen Deng's avatar Nianchen Deng

sync

parent 6294701e
Showing with 773 additions and 56 deletions
#include "utils.h"
void multires_hash_encode_kernel_wrapper_fullp(const uint n, const uint levels,
const uint coarse_levels, const uint dims,
const uint feature_dims, const float *x,
const uint *res_list, const void *hash_table,
const uint *hash_table_offsets, void *o_encoded,
const bool requires_grad, float *o_weights,
uint *o_indices);
at::Tensor multires_hash_encode(const int levels, const int coarse_levels, at::Tensor x,
at::Tensor res_list, at::Tensor hash_table,
at::Tensor hash_table_offsets) {
CHECK_CUDA_CONT_TENSOR(FLOAT, x);
CHECK_CUDA_CONT_TENSOR(FLOAT, hash_table);
CHECK_CUDA_CONT_TENSOR(INT, res_list);
CHECK_CUDA_CONT_TENSOR(INT, hash_table_offsets);
const uint n = x.size(0);
const uint dims = x.size(1);
const uint feature_dims = hash_table.size(-1);
at::Tensor encoded =
torch::empty({levels, n, feature_dims}, at::device(x.device()).dtype(hash_table.dtype()));
multires_hash_encode_kernel_wrapper_fullp(
n, (uint)levels, (uint)coarse_levels, dims, feature_dims, x.data_ptr<float>(),
(uint *)res_list.data_ptr(), hash_table.data_ptr(), (uint *)hash_table_offsets.data_ptr(),
encoded.data_ptr(), false, nullptr, nullptr);
return encoded;
}
std::tuple<at::Tensor, at::Tensor, at::Tensor>
multires_hash_encode_with_grad(const int levels, const int coarse_levels, at::Tensor x,
at::Tensor res_list, at::Tensor hash_table,
at::Tensor hash_table_offsets) {
CHECK_CUDA_CONT_TENSOR(FLOAT, x);
CHECK_CUDA_CONT_TENSOR(FLOAT, hash_table);
CHECK_CUDA_CONT_TENSOR(INT, res_list);
CHECK_CUDA_CONT_TENSOR(INT, hash_table_offsets);
const uint n = x.size(0);
const uint dims = x.size(1);
const uint feature_dims = hash_table.size(-1);
at::Tensor encoded =
torch::empty({levels, n, feature_dims}, at::device(x.device()).dtype(hash_table.dtype()));
at::Tensor weights =
torch::empty({levels, n, 1 << dims}, at::device(x.device()).dtype(at::kFloat));
at::Tensor indices =
torch::empty({levels, n, 1 << dims}, at::device(x.device()).dtype(at::kInt));
multires_hash_encode_kernel_wrapper_fullp(
n, (uint)levels, (uint)coarse_levels, dims, feature_dims, x.data_ptr<float>(),
(uint *)res_list.data_ptr(), hash_table.data_ptr(), (uint *)hash_table_offsets.data_ptr(),
encoded.data_ptr(), true, weights.data_ptr<float>(), (uint *)indices.data_ptr());
return std::make_tuple(encoded, weights, indices);
}
\ No newline at end of file
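Besides the encoded features, the *_with_grad binding returns the per-corner interpolation weights and the flat hash-table indices, which is enough to accumulate gradients into the hash table without re-running the kernel. Below is a minimal PyTorch sketch of that accumulation, assuming the tensor layouts produced above ((levels, n, feature_dims) upstream gradients, (levels, n, 2**dims) weights and indices); the function name is illustrative, not the repo's actual autograd binding.

import torch

def hash_table_backward(grad_encoded: torch.Tensor, weights: torch.Tensor,
                        indices: torch.Tensor, total_entries: int) -> torch.Tensor:
    # grad_encoded: (levels, n, feature_dims) gradient w.r.t. the encoded features
    # weights:      (levels, n, 2**dims)      interpolation weight of each cell corner
    # indices:      (levels, n, 2**dims)      flat index of each corner's hash-table entry
    levels, n, feature_dims = grad_encoded.shape
    grad_table = grad_encoded.new_zeros(total_entries, feature_dims)
    # each corner entry receives weight * upstream gradient
    contrib = weights.unsqueeze(-1) * grad_encoded.unsqueeze(-2)  # (levels, n, 2**dims, feature_dims)
    grad_table.index_add_(0, indices.reshape(-1).long(),
                          contrib.reshape(-1, feature_dims).to(grad_table.dtype))
    return grad_table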
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "cuda_utils.h"
#include "cutil_math.h" // required for float3 vector math
#include "utils.h"
namespace debug {
template <uint DIMS> __device__ uint fast_hash(const uvec<DIMS> gpos, const uint hashmap_size) {
static_assert(DIMS <= 7, "fast_hash can only hash up to 7 dimensions.");
// While 1 is technically not a good prime for hashing (or a prime at all), it helps memory
// coherence and is sufficient for our use case of obtaining a uniformly colliding index from
// high-dimensional coordinates.
constexpr uint primes[7] = {1, 2654435761, 805459861, 3674653429,
2097192037, 1434869437, 2165219737};
uint result = gpos[0];
#pragma unroll
for (uint dim = 1; dim < DIMS; ++dim)
result ^= gpos[dim] * primes[dim];
return result % hashmap_size;
}
template <uint DIMS> __device__ uint gidx(const uvec<DIMS> gpos, const uvec<DIMS> res) {
uint index = gpos[0] * res[1] + gpos[1];
#pragma unroll
for (uint dim = 2; dim < DIMS; ++dim)
index = index * res[dim] + gpos[dim];
return index;
}
__global__ void multires_hash_encode_kernel(const uint n, const uint coarse_levels,
const uvec<3> *res_list, const fvec<2> *hash_table,
const uint *hash_table_offsets, const fvec<3> *x,
fvec<2> *o_encoded, fvec<3> *o_local_pos,
uint *o_idx) {
const uint i = blockDim.x * blockIdx.x + threadIdx.x;
if (i >= n)
return;
const uint level = blockIdx.y;
const uint hash_table_offset = hash_table_offsets[level];
const uint hash_table_size = hash_table_offsets[level + 1] - hash_table_offset;
const uvec<3> res = res_list[level];
hash_table += hash_table_offset;
fvec<3> pos = x[i];
uvec<3> gpos;
#pragma unroll
for (uint dim = 0; dim < 3; ++dim) {
pos[dim] *= res[dim] - 1;
gpos[dim] = (uint)floor(pos[dim]);
pos[dim] -= gpos[dim];
}
    // Debug output: record the in-cell fractional position for inspection
o_local_pos[n * level + i] = pos;
auto grid_idx = [&](const uvec<3> gpos) {
uint idx;
if (level >= coarse_levels)
idx = fast_hash(gpos, hash_table_size);
else
idx = gidx(gpos, res);
return idx;
};
// N-linear interpolation
fvec<2> result = {};
#pragma unroll
for (uint corner_idx = 0; corner_idx < (1 << 3); ++corner_idx) {
float weight = 1;
uvec<3> corner_gpos;
#pragma unroll
for (uint dim = 0; dim < 3; ++dim) {
if ((corner_idx & (1 << dim)) == 0) {
weight *= 1 - pos[dim];
corner_gpos[dim] = gpos[dim];
} else {
weight *= pos[dim];
corner_gpos[dim] = min(gpos[dim] + 1, res[dim] - 1);
}
}
auto idx = grid_idx(corner_gpos);
auto val = hash_table[idx];
o_idx[level * n * 8 + i * 8 + corner_idx] = idx;
#pragma unroll
for (uint feature = 0; feature < 2; ++feature) {
result[feature] += weight * val[feature];
}
}
o_encoded[level * n + i] = result;
}
} // namespace debug
std::tuple<at::Tensor, at::Tensor, at::Tensor>
multires_hash_encode_debug(const int levels, const int coarse_levels, at::Tensor x,
at::Tensor res_list, at::Tensor hash_table,
at::Tensor hash_table_offsets) {
const uint n = x.size(0);
const uint dims = x.size(1);
const uint feature_dims = hash_table.size(-1);
res_list = res_list.to(at::kInt);
hash_table_offsets = hash_table_offsets.to(at::kInt);
at::Tensor encoded =
torch::empty({levels, n, feature_dims}, at::device(x.device()).dtype(hash_table.dtype()));
at::Tensor local_pos =
torch::empty({levels, n, dims}, at::device(x.device()).dtype(at::kFloat));
at::Tensor idxs = torch::empty({levels, n, 8}, at::device(x.device()).dtype(at::kInt));
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const uint threads = opt_n_threads(n);
const dim3 blocks = {(uint)ceil((float)n / threads), (uint)levels, 1};
debug::multires_hash_encode_kernel<<<blocks, threads, 0, stream>>>(
n, (uint)coarse_levels, (uvec<3> *)res_list.data_ptr(), (fvec<2> *)hash_table.data_ptr(),
(uint *)hash_table_offsets.data_ptr(), (fvec<3> *)x.data_ptr(),
(fvec<2> *)encoded.data_ptr(), (fvec<3> *)local_pos.data_ptr(),
(uint *)idxs.data_ptr());
return std::make_tuple(encoded.transpose(0, 1).reshape({n, -1}),
local_pos.transpose(0, 1).unsqueeze(-2),
idxs.to(at::kLong).transpose(0, 1));
}
\ No newline at end of file
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <stdexcept> // for std::invalid_argument thrown by the dispatch helpers below
#include "cuda_utils.h"
#include "cutil_math.h" // required for float3 vector math
template <uint DIMS> __device__ uint fast_hash(const uvec<DIMS> gpos, const uint hashmap_size) {
static_assert(DIMS <= 7, "fast_hash can only hash up to 7 dimensions.");
// While 1 is technically not a good prime for hashing (or a prime at all), it helps memory
// coherence and is sufficient for our use case of obtaining a uniformly colliding index from
// high-dimensional coordinates.
constexpr uint primes[7] = {1, 2654435761, 805459861, 3674653429,
2097192037, 1434869437, 2165219737};
uint result = gpos[0];
#pragma unroll
for (uint dim = 1; dim < DIMS; ++dim)
result ^= gpos[dim] * primes[dim];
return result % hashmap_size;
}
template <uint DIMS> __device__ uint gidx(const uvec<DIMS> gpos, const uvec<DIMS> res) {
uint index = gpos[0] * res[1] + gpos[1];
#pragma unroll
for (uint dim = 2; dim < DIMS; ++dim)
index = index * res[dim] + gpos[dim];
return index;
}
template <typename T, uint DIMS, uint FEATURE_DIMS>
__global__ void multires_hash_encode_kernel(const uint n, const uint coarse_levels,
const uvec<DIMS> *__restrict__ res_list,
const vec<T, FEATURE_DIMS> *__restrict__ hash_table,
const uint *__restrict__ hash_table_offsets,
const fvec<DIMS> *__restrict__ x,
vec<T, FEATURE_DIMS> *__restrict__ o_encoded,
const bool requires_grad, float *__restrict__ o_weights,
uint *__restrict__ o_indices) {
const uint i = blockDim.x * blockIdx.x + threadIdx.x;
if (i >= n)
return;
const uint level = blockIdx.y;
const uint hash_table_offset = hash_table_offsets[level];
const uint hash_table_size = hash_table_offsets[level + 1] - hash_table_offset;
const uvec<DIMS> res = res_list[level];
hash_table += hash_table_offset;
fvec<DIMS> pos = x[i];
uvec<DIMS> gpos;
#pragma unroll
for (uint dim = 0; dim < DIMS; ++dim) {
pos[dim] *= res[dim] - 1;
gpos[dim] = (uint)floor(pos[dim]);
pos[dim] -= gpos[dim];
}
auto hash_idx = [&](const uvec<DIMS> gpos) {
uint idx;
if (level >= coarse_levels)
idx = fast_hash(gpos, hash_table_size);
else
idx = gidx(gpos, res);
return idx;
};
// N-linear interpolation
vec<T, FEATURE_DIMS> result = {};
auto n_corners = (1 << DIMS);
#pragma unroll
for (uint corner_idx = 0; corner_idx < n_corners; ++corner_idx) {
float weight = 1;
uvec<DIMS> corner_gpos;
#pragma unroll
for (uint dim = 0; dim < DIMS; ++dim) {
if ((corner_idx & (1 << dim)) == 0) {
weight *= 1 - pos[dim];
corner_gpos[dim] = gpos[dim];
} else {
weight *= pos[dim];
corner_gpos[dim] = gpos[dim] + 1;
}
}
auto idx = hash_idx(corner_gpos);
auto val = hash_table[idx];
#pragma unroll
for (uint feature = 0; feature < FEATURE_DIMS; ++feature) {
result[feature] += (T)(weight * (float)val[feature]);
}
// For backward
if (requires_grad) {
auto j = (level * n + i) * n_corners + corner_idx;
o_indices[j] = idx + hash_table_offset;
o_weights[j] = weight;
}
}
o_encoded[level * n + i] = result;
}
template <typename T, uint FEATURE_DIMS>
void multires_hash_encode_kernel_wrapper(const uint n, const uint levels, const uint coarse_levels,
const uint dims, const float *x, const uint *res_list,
const vec<T, FEATURE_DIMS> *hash_table,
const uint *hash_table_offsets,
vec<T, FEATURE_DIMS> *o_encoded, const bool requires_grad,
float *o_weights, uint *o_indices) {
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const uint threads = opt_n_threads(n);
const dim3 blocks = {(uint)ceil((float)n / threads), levels, 1};
#define DISPATCH_KERNEL_CASE(__DIMS__) \
case __DIMS__: \
multires_hash_encode_kernel<<<blocks, threads, 0, stream>>>( \
n, coarse_levels, (const uvec<__DIMS__> *)res_list, hash_table, hash_table_offsets, \
(const fvec<__DIMS__> *)x, o_encoded, requires_grad, o_weights, o_indices); \
break;
switch (dims) {
DISPATCH_KERNEL_CASE(2)
DISPATCH_KERNEL_CASE(3)
default:
throw std::invalid_argument("'dims' should be 2 or 3");
}
CUDA_CHECK_ERRORS();
#undef DISPATCH_KERNEL_CASE
}
template <typename T>
void multires_hash_encode_kernel_wrapper(const uint n, const uint levels, const uint coarse_levels,
const uint dims, const uint feature_dims, const float *x,
const uint *res_list, const void *hash_table,
const uint *hash_table_offsets, void *o_encoded,
const bool requires_grad, float *o_weights,
uint *o_indices) {
#define KERNEL_WRAPPER_CASE(__FEATURE_DIMS__) \
case __FEATURE_DIMS__: \
multires_hash_encode_kernel_wrapper( \
n, levels, coarse_levels, dims, x, res_list, \
(const vec<T, __FEATURE_DIMS__> *)hash_table, hash_table_offsets, \
(vec<T, __FEATURE_DIMS__> *)o_encoded, requires_grad, o_weights, o_indices); \
break;
switch (feature_dims) {
KERNEL_WRAPPER_CASE(1)
KERNEL_WRAPPER_CASE(2)
KERNEL_WRAPPER_CASE(4)
KERNEL_WRAPPER_CASE(8)
KERNEL_WRAPPER_CASE(16)
default:
throw std::invalid_argument("'feature_dims' should be 1, 2, 4, 8, 16");
}
#undef KERNEL_WRAPPER_CASE
}
#if !defined(__CUDA_NO_HALF_CONVERSIONS__)
void multires_hash_encode_kernel_wrapper_halfp(const uint n, const uint levels,
const uint coarse_levels, const uint dims,
const uint feature_dims, const float *x,
const uint *res_list, const void *hash_table,
const uint *hash_table_offsets, void *o_encoded,
const bool requires_grad, float *o_weights,
uint *o_indices) {
multires_hash_encode_kernel_wrapper<__half>(n, levels, coarse_levels, dims, feature_dims, x,
res_list, hash_table, hash_table_offsets, o_encoded,
requires_grad, o_weights, o_indices);
}
#endif
void multires_hash_encode_kernel_wrapper_fullp(const uint n, const uint levels,
const uint coarse_levels, const uint dims,
const uint feature_dims, const float *x,
const uint *res_list, const void *hash_table,
const uint *hash_table_offsets, void *o_encoded,
const bool requires_grad, float *o_weights,
uint *o_indices) {
multires_hash_encode_kernel_wrapper<float>(n, levels, coarse_levels, dims, feature_dims, x,
res_list, hash_table, hash_table_offsets, o_encoded,
requires_grad, o_weights, o_indices);
}
\ No newline at end of file
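For readability, here is a hedged, pure-PyTorch sketch of the same forward logic for dims == 3 (single precision, no template dispatch). Tensor layouts mirror the kernel: res_list is (levels, 3), hash_table is (total_entries, feature_dims), hash_table_offsets is (levels + 1,), and x is assumed to lie in [0, 1). The function name is made up for illustration and is not part of the extension.

import torch

PRIMES = [1, 2654435761, 805459861]  # same constants as fast_hash (first entry is 1)

def multires_hash_encode_reference(x, res_list, hash_table, hash_table_offsets, coarse_levels):
    levels = res_list.shape[0]
    n, dims = x.shape
    feature_dims = hash_table.shape[-1]
    encoded = hash_table.new_empty(levels, n, feature_dims)
    for level in range(levels):
        offset = int(hash_table_offsets[level])
        size = int(hash_table_offsets[level + 1]) - offset
        res = res_list[level].long()                      # (3,)
        pos = x * (res - 1)                               # scale into this level's grid
        gpos = pos.floor().long()
        frac = pos - gpos
        result = hash_table.new_zeros(n, feature_dims)
        for corner in range(1 << dims):                   # 8 cell corners in 3-D
            mask = torch.tensor([(corner >> d) & 1 for d in range(dims)], device=x.device)
            corner_gpos = gpos + mask                     # (n, 3)
            weight = torch.where(mask.bool(), frac, 1 - frac).prod(-1)  # (n,)
            if level >= coarse_levels:                    # fine levels: spatial hash
                idx = corner_gpos[:, 0].clone()
                for d in range(1, dims):
                    idx ^= (corner_gpos[:, d] * PRIMES[d]) & 0xFFFFFFFF  # emulate 32-bit wrap
                idx %= size
            else:                                         # coarse levels: dense grid index
                idx = corner_gpos[:, 0]
                for d in range(1, dims):
                    idx = idx * res[d] + corner_gpos[:, d]
            result += weight.unsqueeze(-1).to(result.dtype) * hash_table[offset + idx]
        encoded[level] = result
    return encoded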
import torch
import torch.nn.functional as nn_f
-from typing import Any, List, Mapping, Tuple
+from typing import Any
from torch import nn
from utils.view import *
from utils import math
@@ -10,28 +10,29 @@ from .foveation import Foveation

class FoveatedNeuralRenderer(object):
-    def __init__(self, layers_fov: List[float],
-                 layers_res: List[Tuple[int, int]],
+    def __init__(self, layers_fov: list[float],
+                 layers_res: list[tuple[int, int]],
                  layers_net: nn.ModuleList,
-                 output_res: Tuple[int, int], *,
+                 output_res: tuple[int, int], *,
+                 coord_sys: str = "gl",
                  device: torch.device = None):
        super().__init__()
        self.layers_net = layers_net.to(device=device)
        self.layers_cam = [
-            CameraParam({
+            Camera.create({
                'fov': layers_fov[i],
                'cx': 0.5,
                'cy': 0.5,
                'normalized': True
-            }, layers_res[i], device=device)
+            }, layers_res[i], coord_sys=coord_sys, device=device)
            for i in range(len(layers_fov))
        ]
-        self.cam = CameraParam({
+        self.cam = Camera.create({
            'fov': layers_fov[-1],
            'cx': 0.5,
            'cy': 0.5,
            'normalized': True
-        }, output_res, device=device)
+        }, output_res, coord_sys=coord_sys, device=device)
        self.foveation = Foveation(layers_fov, layers_res, output_res, device=device)
        self.device = device
@@ -44,14 +45,11 @@ class FoveatedNeuralRenderer(object):
        self.device = device
        return self

-    def __call__(self, *args: Any, **kwds: Any) -> Any:
-        return self.render(*args, **kwds)
-
-    def render(self, view: Trans, gaze, right_gaze=None, *,
-               stereo_disparity=0,
-               using_mask=True,
-               mono_periph_mode=0,
-               ret_raw=False) -> Union[Mapping[str, torch.Tensor], Tuple[Mapping[str, torch.Tensor]]]:
+    def __call__(self, view: Trans, gaze, right_gaze=None, *,
+                 stereo_disparity: float = 0,
+                 using_mask: bool = True,
+                 mono_periph_mode: int = 0,
+                 ret_raw: bool = False) -> dict[str, torch.Tensor] | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
        if stereo_disparity > math.tiny:
            left_view = Trans(
                view.trans_point(torch.tensor([-stereo_disparity / 2, 0, 0], device=self.device)),
@@ -71,7 +69,7 @@ class FoveatedNeuralRenderer(object):
                                     layer_mask=layers_mask[0])['color']
            fovea_right = self._render(self.layers_net[0], self.layers_cam[0], right_view, right_gaze,
                                       layer_mask=layers_mask[0])['color']
-            if mono_periph_mode == 3:
+            if mono_periph_mode == 3 or mono_periph_mode == 4:
                mid = self._render(self.layers_net[1], self.layers_cam[1], view,
                                   ((left_gaze[0] + right_gaze[0]) // 2, left_gaze[1]),
                                   layer_mask=layers_mask[1])['color']
@@ -79,8 +77,8 @@ class FoveatedNeuralRenderer(object):
                raw_left = [fovea_left, mid, periph]
                raw_right = [fovea_right, mid, periph]
                shift = int(left_gaze[0] - right_gaze[0]) // 2
-                left_shifts = [0, 0, shift]
-                right_shifts = [0, 0, -shift]
+                left_shifts = [0, 0, shift if mono_periph_mode == 3 else 0]
+                right_shifts = [0, 0, -shift if mono_periph_mode == 3 else 0]
            else:
                mid_left = self._render_mid(self.layers_net[1], self.layers_cam[1], left_view, left_gaze,
                                            layer_mask=layers_mask[1], mono_view=view,
@@ -115,17 +113,22 @@ class FoveatedNeuralRenderer(object):
            ]
            return self._gen_output(res_raw, gaze, ret_raw=ret_raw)

-    def _render(self, net, cam: CameraParam, view: Trans, gaze=None, *,
-                ret_depth=False,
-                layer_mask=None) -> Mapping[str, torch.Tensor]:
+    def _render(self, net, cam: Camera, view: Trans, gaze=None, *,
+                ret_depth=False, layer_mask=None) -> dict[str, torch.Tensor]:
+        output_types = ["color"]
+        if ret_depth:
+            output_types.append("depth")
        if gaze is not None:
            cam = self._adjust_cam(cam, gaze)
-        rays_o, rays_d = cam.get_global_rays(view, False)  # (1, H, W, 3)
+        rays_d = view.trans_vector(cam.local_rays.reshape(*cam.res, -1))  # (1, H, W, 3)
+        rays_o = view.t.broadcast_to(rays_d.shape)
        if layer_mask is not None:
            infer_mask = layer_mask >= 0
-            rays_o = rays_o[:, infer_mask]
-            rays_d = rays_d[:, infer_mask]
-            net_output = net(rays_o.view(-1, 3), rays_d.view(-1, 3), ret_depth=ret_depth)
+            net_input = Rays({
+                "rays_o": rays_o[:, infer_mask].reshape(-1, 3),
+                "rays_d": rays_d[:, infer_mask].reshape(-1, 3)
+            })
+            net_output = net(net_input, *output_types)
            ret = {
                'color': torch.zeros(1, cam.res[0], cam.res[1], 3, device=self.device)
            }
@@ -136,17 +139,21 @@ class FoveatedNeuralRenderer(object):
                ret['depth'][:, infer_mask] = net_output['depth']
            return ret
        else:
-            net_output = net(rays_o.view(-1, 3), rays_d.view(-1, 3), ret_depth=ret_depth)
+            net_input = Rays({
+                "rays_o": rays_o.reshape(-1, 3),
+                "rays_d": rays_d.reshape(-1, 3)
+            })
+            net_output = net(net_input, *output_types)
            return {
                'color': net_output['color'].view(1, cam.res[0], cam.res[1], -1).permute(0, 3, 1, 2),
                'depth': net_output['depth'].view(1, cam.res[0], cam.res[1]) if ret_depth else None
            }

-    def _render_mid(self, net, cam: CameraParam, view: Trans, gaze=None, *,
+    def _render_mid(self, net, cam: Camera, view: Trans, gaze=None, *,
                    layer_mask: torch.Tensor,
                    mono_view: Trans,
                    blend_view: bool,
-                    ret_depth=False) -> Mapping[str, torch.Tensor]:
+                    ret_depth=False) -> dict[str, torch.Tensor]:
        """
        [summary]
@@ -159,18 +166,23 @@ class FoveatedNeuralRenderer(object):
        :param ret_depth: [description], defaults to False
        :return: [description]
        """
+        output_types = ["color"]
+        if ret_depth:
+            output_types.append("depth")
        if gaze is not None:
            cam = self._adjust_cam(cam, gaze)
        k = layer_mask[None, ..., None].clamp(1 if blend_view else 2, 2) - 1  # (1, H, W, 1)
        rays_o = (1 - k) * view.t + k * mono_view.t  # (1, H, W, 3)
-        rays_d = view.trans_vector(cam.get_local_rays())  # (1, H, W, 3)
+        rays_d = view.trans_vector(cam.local_rays.reshape(*cam.res, -1))  # (1, H, W, 3)
        if layer_mask is not None:
            infer_mask = layer_mask >= 0
-            rays_o = rays_o[:, infer_mask]
-            rays_d = rays_d[:, infer_mask]
-            net_output = net(rays_o.view(-1, 3), rays_d.view(-1, 3), ret_depth=ret_depth)
+            net_input = Rays({
+                "rays_o": rays_o[:, infer_mask].reshape(-1, 3),
+                "rays_d": rays_d[:, infer_mask].reshape(-1, 3)
+            })
+            net_output = net(net_input, *output_types)
            ret = {
                'color': torch.zeros(1, cam.res[0], cam.res[1], 3, device=self.device)
            }
@@ -181,13 +193,18 @@ class FoveatedNeuralRenderer(object):
                ret['depth'][:, infer_mask] = net_output['depth']
            return ret
        else:
-            net_output = net(rays_o.view(-1, 3), rays_d.view(-1, 3), ret_depth=ret_depth)
+            net_input = {
+                "rays_o": rays_o.reshape(-1, 3),
+                "rays_d": rays_d.reshape(-1, 3)
+            }
+            net_output = net(net_input, *output_types)
            return {
                'color': net_output['color'].view(1, cam.res[0], cam.res[1], -1).permute(0, 3, 1, 2),
                'depth': net_output['depth'].view(1, cam.res[0], cam.res[1]) if ret_depth else None
            }

-    def _gen_output(self, layers_img: List[torch.Tensor], gaze: Tuple[float, float], shifts=None, ret_raw=False) -> Mapping[str, torch.Tensor]:
+    def _gen_output(self, layers_img: list[torch.Tensor], gaze: tuple[float, float], shifts=None,
+                    ret_raw=False) -> dict[str, torch.Tensor]:
        refined = self._post_process(layers_img)
        blended = self.foveation.synthesis(refined, gaze, shifts)
        ret = {
@@ -196,10 +213,10 @@ class FoveatedNeuralRenderer(object):
        }
        if ret_raw:
            ret['layers_raw'] = layers_img
-            ret['blended_raw'] = self.foveation.synthesis(layers_img, gaze)
+            ret['blended_raw'] = self.foveation.synthesis(layers_img, gaze, shifts)
        return ret

-    def _post_process(self, layers_img: List[torch.Tensor]) -> List[torch.Tensor]:
+    def _post_process(self, layers_img: list[torch.Tensor]) -> list[torch.Tensor]:
        return [
            #grad_aware_median(constrast_enhance(layers_img[0], 3, 0.2), 3, 3, True),
            constrast_enhance(layers_img[0], 3, 0.2),
@@ -207,20 +224,18 @@ class FoveatedNeuralRenderer(object):
            constrast_enhance(layers_img[2], 5, 0.2)
        ]

-    def _adjust_cam(self, layer_cam: CameraParam, gaze: Tuple[float, float]) -> CameraParam:
+    def _adjust_cam(self, layer_cam: Camera, gaze: tuple[float, float]) -> Camera:
        fovea_offset = (
            (gaze[0]) / self.cam.f[0].item() * layer_cam.f[0].item(),
            (gaze[1]) / self.cam.f[1].item() * layer_cam.f[1].item()
        )
-        return CameraParam({
-            'fx': layer_cam.f[0].item(),
-            'fy': layer_cam.f[1].item(),
-            'cx': layer_cam.c[0].item() - fovea_offset[0],
-            'cy': layer_cam.c[1].item() - fovea_offset[1]
-        }, layer_cam.res, device=self.device)
+        return Camera.create({
+            'f': [layer_cam.f[0].item(), layer_cam.f[1].item()],
+            'c': [layer_cam.c[0].item() - fovea_offset[0], layer_cam.c[1].item() - fovea_offset[1]]
+        }, layer_cam.res, coord_sys=layer_cam.coord_sys, device=self.device)

    def _warp(self, trans: Trans, trans0: Trans,
-              cam: CameraParam, z_list: torch.Tensor,
+              cam: Camera, z_list: torch.Tensor,
              image: torch.Tensor, depthmap: torch.Tensor) -> torch.Tensor:
        """
        [summary]
        ...
@@ -8,8 +8,8 @@ from utils import math

class Foveation(object):
-    def __init__(self, layers_fov: List[float], layers_res: List[Tuple[float, float]],
-                 out_res: Tuple[int, int], *, blend: float = 0.6, device: torch.device = None):
+    def __init__(self, layers_fov: list[float], layers_res: list[tuple[float, float]],
+                 out_res: tuple[int, int], *, blend: float = 0.6, device: torch.device = None):
        self.layers_fov = layers_fov
        self.layers_res = layers_res
        self.out_res = out_res
@@ -20,15 +20,15 @@ class Foveation(object):
            self._gen_layer_blendmap(i)
            for i in range(self.n_layers - 1)
        ]  # blend maps of fovea layers
-        self.coords = misc.meshgrid(*out_res).to(device=device)
+        self.coords = misc.grid2d(*out_res, device=device)

    def to(self, device: torch.device):
        self.eye_fovea_blend = [x.to(device=device) for x in self.eye_fovea_blend]
        self.coords = self.coords.to(device=device)
        return self

-    def synthesis(self, layers: List[torch.Tensor], fovea_center: Tuple[float, float],
-                  shifts: List[int] = None,
+    def synthesis(self, layers: list[torch.Tensor], fovea_center: tuple[float, float],
+                  shifts: list[int] = None,
                  do_blend: bool = True,
                  crop_mode: bool = False) -> torch.Tensor:
        """
@@ -40,6 +40,7 @@ class Foveation(object):
        """
        output: torch.Tensor = nn_f.interpolate(layers[-1], self.out_res,
                                                mode='bilinear', align_corners=False)
+        #output.fill_(0)  # TODO: debug
        if shifts is not None:
            output = img.horizontal_shift(output, shifts[-1])
        c = torch.tensor([
@@ -99,11 +100,11 @@ class Foveation(object):
        """
        size = self.get_layer_size_in_final_image(i)
        R = size / 2
-        p = misc.meshgrid(size, size).to(device=self.device)  # (size, size, 2)
+        p = misc.grid2d(size, device=self.device)  # (size, size, 2)
        r = torch.norm(p - R, dim=2)  # (size, size, 2)
        return misc.smooth_step(R, R * self.blend, r)

-    def get_layers_mask(self, gaze=None) -> List[torch.Tensor]:
+    def get_layers_mask(self, gaze=None) -> list[torch.Tensor]:
        """
        Generate mask images for layers[:-1]
        the meaning of values in mask images:
@@ -127,8 +128,7 @@ class Foveation(object):
            else:
                c = torch.tensor([0.5, 0.5], device=self.device)
            layers_mask.append(torch.ones(*self.layers_res[i], device=self.device) * -1)
-            coord = misc.meshgrid(
-                *self.layers_res[i]).to(device=self.device) / self.layers_res[i][0]
+            coord = misc.grid2d(*self.layers_res[i], device=self.device) / self.layers_res[i][0]
            r = 2 * torch.norm(coord - c, dim=-1)
            inner_radius = self.get_source_layer_cover_size_in_target_layer(
                self.layers_fov[i - 1], self.layers_fov[i], self.layers_res[i][0]) / self.layers_res[i][0] \
        ...
@@ -7,7 +7,7 @@ from utils import math

class GuideRefinement(object):
    def __init__(self, guides_image, guides_view: view.Trans,
-                 guides_cam: view.CameraParam, net) -> None:
+                 guides_cam: view.Camera, net) -> None:
        rays_o, rays_d = guides_cam.get_global_rays(guides_view, flatten=True)
        guides_inferred = torch.stack([
            net(rays_o[i], rays_d[i]).view(
        ...
from typing import SupportsFloat
from model import Model
from utils.view import *
from utils.types import *
from utils import math
def render(model: Model, cam: Camera, view: Trans, *output_types: str,
gaze: tuple[float, float] = (0, 0), extra_input: dict = None,
layer_mask: torch.Tensor = None, batch_size: int = None) -> ReturnData:
if len(output_types) == 0:
raise ValueError("'output_types' is empty")
local_rays = cam.local_rays + cam.local_rays.new_tensor([*gaze, 0]) # (H*W, 3)
rays_d = view.trans_vector(local_rays) # (B..., H*W, 3)
rays_o = view.t[..., None, :].expand_as(rays_d)
input = Rays(rays_o=rays_o, rays_d=rays_d, **extra_input or {}) # (B..., H*W)
if layer_mask is not None:
selector = layer_mask.flatten().ge(0).nonzero()
input = input.transform(lambda value: value.index_select(len(input.shape), selector))
input = input.flatten() # (B..., X) -> (N)
output = ReturnData() # will be (N)
n = input.shape[0]
batch_size = batch_size or n
for offset in range(0, n, batch_size):
batch_slice = slice(offset, min(offset + batch_size, n))
batch_output = model(input.select(batch_slice), *output_types)
for key, value in batch_output.items():
if key == "rays_filter":
continue
match value:
case torch.Tensor():
if key not in output:
output[key] = value.new_full([n, *value.shape[1:]],
math.huge * (key == "depth"))
if 'rays_filter' in batch_output:
output[key][batch_slice][batch_output['rays_filter']] = batch_output[key]
else:
output[key][batch_slice] = batch_output[key]
case SupportsFloat():
output[key] = output.get(key, 0) + value
case _:
output[key] = output.get(key, []) + [value]
output = output.reshape(*view.shape, -1) # (N) -> (B..., X)
if layer_mask is not None:
output = output.transform(lambda value:
value.new_zeros(*view.shape, local_rays.shape[0],
*value.shape[len(view.shape) + 1:])
.index_copy(len(view.shape), selector, value))
return output.reshape(*view.shape, *cam.res) # (B..., H*W) -> (B..., H, W)
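At its core, render() implements a chunked-inference pattern: flatten the per-pixel rays, run the model over fixed-size batches, and stitch the batch outputs back into image-shaped tensors. A self-contained toy sketch of that pattern follows; toy_model and the tensor names are stand-ins, not the repo's Model/Rays API.

import torch

def toy_model(rays_o: torch.Tensor, rays_d: torch.Tensor) -> dict[str, torch.Tensor]:
    # stand-in for a NeRF-style network: one color and one depth value per ray
    return {"color": rays_d.abs(), "depth": rays_o.norm(dim=-1)}

def chunked_render(rays_o: torch.Tensor, rays_d: torch.Tensor,
                   batch_size: int = 4096) -> dict[str, torch.Tensor]:
    n = rays_o.shape[0]
    out: dict[str, torch.Tensor] = {}
    for offset in range(0, n, batch_size):
        sl = slice(offset, min(offset + batch_size, n))
        batch_out = toy_model(rays_o[sl], rays_d[sl])
        for key, value in batch_out.items():
            if key not in out:                    # allocate the full-size buffer lazily
                out[key] = value.new_empty(n, *value.shape[1:])
            out[key][sl] = value                  # write this batch back into place
    return out

H, W = 4, 6
rays_o = torch.zeros(H * W, 3)
rays_d = torch.randn(H * W, 3)
result = chunked_render(rays_o, rays_d, batch_size=8)
color_image = result["color"].reshape(H, W, 3)    # back to image layout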
@@ -12,7 +12,7 @@
        },
        "sample_range": [1, 7],
        "n_samples": 64,
-        "multi_nets": 4
+        "multi_nets": 4,
        "density_regularization_weight": 1e-4,
        "density_regularization_scale": 1e4
    },
    ...
model=FsNeRF
; n_samples=64
; n_fields=1
; depth=8
; width=256
; skips=[4]
; act=relu
; ln=false
; xfreqs=6
; raw_noise_std=0.
; near: float # from dataset
; far: float # from dataset
; white_bg: bool # from dataset
; trainer=Trainer
; max_iters=200000
max_epochs=50
checkpoint_interval=10
; batch_size=4096
; loss=[Color_L2]
; lr=5e-4
lr_decay=0.9999954
; profile_iters=0
\ No newline at end of file
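These *.ini-style snippets hold one key=value override per line; lines starting with ';' are commented-out defaults that document the model's built-in values. A hedged sketch of reading such a snippet with the standard library (the [args] section header is added only because the files have none; how the repo actually parses them is not shown in this commit):

import configparser

text = '''
model=FsNeRF
; n_samples=64
max_epochs=50
checkpoint_interval=10
lr_decay=0.9999954
'''

parser = configparser.ConfigParser()
parser.read_string("[args]\n" + text)   # ';' lines are treated as comments
print(dict(parser["args"]))             # {'model': 'FsNeRF', 'max_epochs': '50', ...}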
model=NeRF
; color=rgb
; n_samples=64
sample_mode=spherical_radius
perturb_sampling=true
; depth=8
; width=256
; skips=[4]
; act=relu
; ln=false
; color_decoder=NeRF
n_importance=128
; fine_depth=8
; fine_width=256
; fine_skips=[4]
; xfreqs=10
; dfreqs=4
; raw_noise_std=0.
; near: float # from dataset
; far: float # from dataset
; white_bg: bool # from dataset
; coord: str # from dataset
; trainer=Trainer
; max_iters=200000
max_epochs=20
checkpoint_interval=5
; batch_size=4096
; loss=[Color_L2, CoarseColor_L2]
; lr=5e-4
lr_decay=0.9999954
; profile_iters=0
\ No newline at end of file
model=NeRF
n_samples=64
; sample_mode=xyz
perturb_sampling=true
; depth=8
; width=256
; skips=[4]
; act=relu
; ln=false
; color_decoder=NeRF
; n_importance=128
; fine_depth=8
; fine_width=256
; fine_skips=[4]
; xfreqs=10
; dfreqs=4
; raw_noise_std=0.
; near: float # from dataset
; far: float # from dataset
; white_bg: bool # from dataset
; trainer=Trainer
; max_iters=200000
max_epochs=10
checkpoint_interval=5
; batch_size=4096
; loss=[Color_L2]
; lr=5e-4
lr_decay=0.9999954
; profile_iters=0
\ No newline at end of file
model=FsNeRF
; n_samples=64
n_fields=4
; depth=8
; width=256
skips=[]
; act=relu
; ln=false
; xfreqs=6
; raw_noise_std=0.
; near: float # from dataset
; far: float # from dataset
; white_bg: bool # from dataset
; trainer=Trainer
; max_iters=200000
max_epochs=30
; checkpoint_interval=10000(iters) or 10(epochs)
; batch_size=4096
; loss=[Color_L2]
; lr=5e-4
lr_decay=0.9999954
; profile_iters=0
\ No newline at end of file
model=FsNeRF
; color=rgb
; n_samples=64
n_fields=4
; depth=8
; width=256
; skips=[4]
; act=relu
; ln=false
; xfreqs=6
; raw_noise_std=0.
; near: float # from dataset
; far: float # from dataset
; white_bg: bool # from dataset
; coord: str # from dataset
; trainer=Trainer
; max_iters=200000
max_epochs=30
checkpoint_interval=10
; batch_size=4096
; loss=[Color_L2]
; lr=5e-4
lr_decay=0.9999954
; profile_iters=0
\ No newline at end of file
model=FsNeRF
; color=rgb
n_samples=256
n_fields=2
; depth=8
; width=256
; skips=[4]
; act=relu
; ln=false
; xfreqs=6
; raw_noise_std=0.
; near: float # from dataset
; far: float # from dataset
; white_bg: bool # from dataset
; coord: str # from dataset
; trainer=Trainer
; max_iters=200000
max_epochs=30
checkpoint_interval=10
; batch_size=4096
loss=[Color_L2]
; lr=5e-4
; lr_decay=0.9999954
; profile_iters=0
\ No newline at end of file
model=FsNeRF
; color=rgb
; n_samples=64
n_fields=4
; depth=8
; width=256
; skips=[4]
; act=relu
; ln=false
; xfreqs=6
; raw_noise_std=0.
; near: float # from dataset
; far: float # from dataset
; white_bg: bool # from dataset
; coord: str # from dataset
; trainer=Trainer
; max_iters=200000
max_epochs=30
checkpoint_interval=50
; batch_size=4096
loss=[Color_L2]
; lr=5e-4
; lr_decay=0.9999954
; profile_iters=0
\ No newline at end of file
model=NeRF
; color=rgb
; n_samples=64
; sample_mode=xyz
perturb_sampling=true
; depth=8
; width=256
; skips=[4]
; act=relu
; ln=false
; color_decoder=NeRF
n_importance=128
; fine_depth=8
; fine_width=256
; fine_skips=[4]
; xfreqs=10
; dfreqs=4
; raw_noise_std=0.
; near: float # from dataset
; far: float # from dataset
; white_bg: bool # from dataset
; coord: str # from dataset
; trainer=Trainer
; max_iters=200000
max_epochs=20
checkpoint_interval=5
; batch_size=4096
; loss=[Color_L2, CoarseColor_L2]
; lr=5e-4
lr_decay=0.9999954
; profile_iters=0
\ No newline at end of file
model=NeRF
; color=rgb
; n_samples=64
; sample_mode=xyz
perturb_sampling=true
; depth=8
; width=256
; skips=[4]
; act=relu
; ln=false
; color_decoder=NeRF
n_importance=128
; fine_depth=8
; fine_width=256
; fine_skips=[4]
; xfreqs=10
; dfreqs=4
; raw_noise_std=0.
; near: float # from dataset
; far: float # from dataset
; white_bg: bool # from dataset
; coord: str # from dataset
; trainer=Trainer
max_iters=200000
; max_epochs=20
; checkpoint_interval=5
; batch_size=4096
; loss=[Color_L2, CoarseColor_L2]
; lr=5e-4
lr_decay=0.9999954
; profile_iters=0
\ No newline at end of file
model=NeRF
; color=rgb
; n_samples=64
sample_mode=xyz_disp
perturb_sampling=true
; depth=8
; width=256
; skips=[4]
; act=relu
; ln=false
; color_decoder=NeRF
n_importance=128
; fine_depth=8
; fine_width=256
; fine_skips=[4]
; xfreqs=10
; dfreqs=4
; raw_noise_std=1e0
; near: float # from dataset
; far: float # from dataset
; white_bg: bool # from dataset
; coord: str # from dataset
; trainer=Trainer
; max_iters=200000
max_epochs=50
checkpoint_interval=10
; batch_size=4096
; loss=[Color_L2, CoarseColor_L2]
; lr=5e-4
lr_decay=0.9999908
; profile_iters=0
\ No newline at end of file