Commit 1bc644a1 authored by Nianchen Deng

sync

parent 6294701e
#include "utils.h"
void multires_hash_encode_kernel_wrapper_fullp(const uint n, const uint levels,
const uint coarse_levels, const uint dims,
const uint feature_dims, const float *x,
const uint *res_list, const void *hash_table,
const uint *hash_table_offsets, void *o_encoded,
const bool requires_grad, float *o_weights,
uint *o_indices);
at::Tensor multires_hash_encode(const int levels, const int coarse_levels, at::Tensor x,
at::Tensor res_list, at::Tensor hash_table,
at::Tensor hash_table_offsets) {
CHECK_CUDA_CONT_TENSOR(FLOAT, x);
CHECK_CUDA_CONT_TENSOR(FLOAT, hash_table);
CHECK_CUDA_CONT_TENSOR(INT, res_list);
CHECK_CUDA_CONT_TENSOR(INT, hash_table_offsets);
const uint n = x.size(0);
const uint dims = x.size(1);
const uint feature_dims = hash_table.size(-1);
at::Tensor encoded =
torch::empty({levels, n, feature_dims}, at::device(x.device()).dtype(hash_table.dtype()));
multires_hash_encode_kernel_wrapper_fullp(
n, (uint)levels, (uint)coarse_levels, dims, feature_dims, x.data_ptr<float>(),
(uint *)res_list.data_ptr(), hash_table.data_ptr(), (uint *)hash_table_offsets.data_ptr(),
encoded.data_ptr(), false, nullptr, nullptr);
return encoded;
}
std::tuple<at::Tensor, at::Tensor, at::Tensor>
multires_hash_encode_with_grad(const int levels, const int coarse_levels, at::Tensor x,
at::Tensor res_list, at::Tensor hash_table,
at::Tensor hash_table_offsets) {
CHECK_CUDA_CONT_TENSOR(FLOAT, x);
CHECK_CUDA_CONT_TENSOR(FLOAT, hash_table);
CHECK_CUDA_CONT_TENSOR(INT, res_list);
CHECK_CUDA_CONT_TENSOR(INT, hash_table_offsets);
const uint n = x.size(0);
const uint dims = x.size(1);
const uint feature_dims = hash_table.size(-1);
at::Tensor encoded =
torch::empty({levels, n, feature_dims}, at::device(x.device()).dtype(hash_table.dtype()));
at::Tensor weights =
torch::empty({levels, n, 1 << dims}, at::device(x.device()).dtype(at::kFloat));
at::Tensor indices =
torch::empty({levels, n, 1 << dims}, at::device(x.device()).dtype(at::kInt));
multires_hash_encode_kernel_wrapper_fullp(
n, (uint)levels, (uint)coarse_levels, dims, feature_dims, x.data_ptr<float>(),
(uint *)res_list.data_ptr(), hash_table.data_ptr(), (uint *)hash_table_offsets.data_ptr(),
encoded.data_ptr(), true, weights.data_ptr<float>(), (uint *)indices.data_ptr());
return std::make_tuple(encoded, weights, indices);
}
\ No newline at end of file
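The binding above only fixes the tensor contract (CUDA, contiguous, float/int dtypes, shapes). A minimal sketch of how it might be driven from Python follows; the module name `clib` and the build step are assumptions, since this commit does not show how the extension is compiled. Only the argument order, shapes and dtypes are taken from the binding itself.

```python
# Hypothetical usage of the multires_hash_encode binding (module name `clib` is assumed).
import torch
import clib  # the compiled extension exposing multires_hash_encode

levels, coarse_levels, dims, feature_dims = 4, 1, 3, 2
n = 1024

# Per-level grid resolutions, int32, shape (levels, dims).
res_list = torch.tensor([[16] * dims, [32] * dims, [64] * dims, [128] * dims],
                        dtype=torch.int32, device="cuda")

# One flat table for all levels; offsets[l] is the first entry of level l,
# offsets[levels] is the total number of entries.
sizes = [16 ** dims, 2 ** 14, 2 ** 14, 2 ** 14]   # dense for the coarse level, hashed above
offsets = torch.tensor([0, *torch.tensor(sizes).cumsum(0).tolist()],
                       dtype=torch.int32, device="cuda")
hash_table = torch.randn(int(offsets[-1]), feature_dims, device="cuda") * 1e-4

x = torch.rand(n, dims, device="cuda")            # positions normalized to [0, 1]
encoded = clib.multires_hash_encode(levels, coarse_levels, x, res_list,
                                    hash_table, offsets)
print(encoded.shape)                              # (levels, n, feature_dims)
```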
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "cuda_utils.h"
#include "cutil_math.h" // required for float3 vector math
#include "utils.h"
namespace debug {
template <uint DIMS> __device__ uint fast_hash(const uvec<DIMS> gpos, const uint hashmap_size) {
static_assert(DIMS <= 7, "fast_hash can only hash up to 7 dimensions.");
// While 1 is technically not a good prime for hashing (or a prime at all), it helps memory
// coherence and is sufficient for our use case of obtaining a uniformly colliding index from
// high-dimensional coordinates.
constexpr uint primes[7] = {1, 2654435761, 805459861, 3674653429,
2097192037, 1434869437, 2165219737};
uint result = gpos[0];
#pragma unroll
for (uint dim = 1; dim < DIMS; ++dim)
result ^= gpos[dim] * primes[dim];
return result % hashmap_size;
}
template <uint DIMS> __device__ uint gidx(const uvec<DIMS> gpos, const uvec<DIMS> res) {
uint index = gpos[0] * res[1] + gpos[1];
#pragma unroll
for (uint dim = 2; dim < DIMS; ++dim)
index = index * res[dim] + gpos[dim];
return index;
}
__global__ void multires_hash_encode_kernel(const uint n, const uint coarse_levels,
const uvec<3> *res_list, const fvec<2> *hash_table,
const uint *hash_table_offsets, const fvec<3> *x,
fvec<2> *o_encoded, fvec<3> *o_local_pos,
uint *o_idx) {
const uint i = blockDim.x * blockIdx.x + threadIdx.x;
if (i >= n)
return;
const uint level = blockIdx.y;
const uint hash_table_offset = hash_table_offsets[level];
const uint hash_table_size = hash_table_offsets[level + 1] - hash_table_offset;
const uvec<3> res = res_list[level];
hash_table += hash_table_offset;
fvec<3> pos = x[i];
uvec<3> gpos;
#pragma unroll
for (uint dim = 0; dim < 3; ++dim) {
pos[dim] *= res[dim] - 1;
gpos[dim] = (uint)floor(pos[dim]);
pos[dim] -= gpos[dim];
}
// TODO: debug code
o_local_pos[n * level + i] = pos;
auto grid_idx = [&](const uvec<3> gpos) {
uint idx;
if (level >= coarse_levels)
idx = fast_hash(gpos, hash_table_size);
else
idx = gidx(gpos, res);
return idx;
};
// N-linear interpolation
fvec<2> result = {};
#pragma unroll
for (uint corner_idx = 0; corner_idx < (1 << 3); ++corner_idx) {
float weight = 1;
uvec<3> corner_gpos;
#pragma unroll
for (uint dim = 0; dim < 3; ++dim) {
if ((corner_idx & (1 << dim)) == 0) {
weight *= 1 - pos[dim];
corner_gpos[dim] = gpos[dim];
} else {
weight *= pos[dim];
corner_gpos[dim] = min(gpos[dim] + 1, res[dim] - 1);
}
}
auto idx = grid_idx(corner_gpos);
auto val = hash_table[idx];
o_idx[level * n * 8 + i * 8 + corner_idx] = idx;
#pragma unroll
for (uint feature = 0; feature < 2; ++feature) {
result[feature] += weight * val[feature];
}
}
o_encoded[level * n + i] = result;
}
} // namespace debug
std::tuple<at::Tensor, at::Tensor, at::Tensor>
multires_hash_encode_debug(const int levels, const int coarse_levels, at::Tensor x,
at::Tensor res_list, at::Tensor hash_table,
at::Tensor hash_table_offsets) {
const uint n = x.size(0);
const uint dims = x.size(1);
const uint feature_dims = hash_table.size(-1);
res_list = res_list.to(at::kInt);
hash_table_offsets = hash_table_offsets.to(at::kInt);
at::Tensor encoded =
torch::empty({levels, n, feature_dims}, at::device(x.device()).dtype(hash_table.dtype()));
at::Tensor local_pos =
torch::empty({levels, n, dims}, at::device(x.device()).dtype(at::kFloat));
at::Tensor idxs = torch::empty({levels, n, 8}, at::device(x.device()).dtype(at::kInt));
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const uint threads = opt_n_threads(n);
const dim3 blocks = {(uint)ceil((float)n / threads), (uint)levels, 1};
debug::multires_hash_encode_kernel<<<blocks, threads, 0, stream>>>(
n, (uint)coarse_levels, (uvec<3> *)res_list.data_ptr(), (fvec<2> *)hash_table.data_ptr(),
(uint *)hash_table_offsets.data_ptr(), (fvec<3> *)x.data_ptr(),
(fvec<2> *)encoded.data_ptr(), (fvec<3> *)local_pos.data_ptr(),
(uint *)idxs.data_ptr());
return std::make_tuple(encoded.transpose(0, 1).reshape({n, -1}),
local_pos.transpose(0, 1).unsqueeze(-2),
idxs.to(at::kLong).transpose(0, 1));
}
\ No newline at end of file
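For reference, the per-level lookup this debug kernel performs (scale to the grid, densely index or hash the eight corners, blend trilinearly) can be restated in plain PyTorch. The sketch below mirrors the kernel for DIMS = 3 as an aid for checking its output; the uint32 wrap-around of the CUDA hash is approximated with a 32-bit mask, and tensor layouts are assumptions.

```python
# Plain-PyTorch restatement of the per-level hash-grid lookup (a checking aid, not the kernel).
import torch

PRIMES = (1, 2654435761, 805459861)

def hash_encode_level(x, res, table, coarse):
    """x: (n, 3) in [0, 1]; res: (3,) ints; table: (size, F) for this level."""
    size = table.shape[0]
    pos = x * (res.to(x) - 1)                  # scale to grid coordinates
    gpos = pos.floor().long()                  # lower corner of the cell
    frac = pos - gpos                          # fractional position inside the cell
    out = torch.zeros(x.shape[0], table.shape[1], device=x.device)
    for corner in range(8):
        w = torch.ones(x.shape[0], device=x.device)
        cpos = gpos.clone()
        for dim in range(3):
            if corner & (1 << dim):
                w = w * frac[:, dim]
                cpos[:, dim] = torch.clamp(gpos[:, dim] + 1, max=int(res[dim]) - 1)
            else:
                w = w * (1 - frac[:, dim])
        if coarse:                             # dense grid indexing for coarse levels
            idx = (cpos[:, 0] * res[1] + cpos[:, 1]) * res[2] + cpos[:, 2]
        else:                                  # spatial hash for fine levels
            idx = cpos[:, 0]
            for dim in range(1, 3):
                idx = idx ^ ((cpos[:, dim] * PRIMES[dim]) & 0xFFFFFFFF)  # emulate uint32 wrap
            idx = idx % size
        out += w[:, None] * table[idx]         # trilinear blend of the corner features
    return out
```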
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "cuda_utils.h"
#include "cutil_math.h" // required for float3 vector math
template <uint DIMS> __device__ uint fast_hash(const uvec<DIMS> gpos, const uint hashmap_size) {
static_assert(DIMS <= 7, "fast_hash can only hash up to 7 dimensions.");
// While 1 is technically not a good prime for hashing (or a prime at all), it helps memory
// coherence and is sufficient for our use case of obtaining a uniformly colliding index from
// high-dimensional coordinates.
constexpr uint primes[7] = {1, 2654435761, 805459861, 3674653429,
2097192037, 1434869437, 2165219737};
uint result = gpos[0];
#pragma unroll
for (uint dim = 1; dim < DIMS; ++dim)
result ^= gpos[dim] * primes[dim];
return result % hashmap_size;
}
template <uint DIMS> __device__ uint gidx(const uvec<DIMS> gpos, const uvec<DIMS> res) {
uint index = gpos[0] * res[1] + gpos[1];
#pragma unroll
for (uint dim = 2; dim < DIMS; ++dim)
index = index * res[dim] + gpos[dim];
return index;
}
template <typename T, uint DIMS, uint FEATURE_DIMS>
__global__ void multires_hash_encode_kernel(const uint n, const uint coarse_levels,
const uvec<DIMS> *__restrict__ res_list,
const vec<T, FEATURE_DIMS> *__restrict__ hash_table,
const uint *__restrict__ hash_table_offsets,
const fvec<DIMS> *__restrict__ x,
vec<T, FEATURE_DIMS> *__restrict__ o_encoded,
const bool requires_grad, float *__restrict__ o_weights,
uint *__restrict__ o_indices) {
const uint i = blockDim.x * blockIdx.x + threadIdx.x;
if (i >= n)
return;
const uint level = blockIdx.y;
const uint hash_table_offset = hash_table_offsets[level];
const uint hash_table_size = hash_table_offsets[level + 1] - hash_table_offset;
const uvec<DIMS> res = res_list[level];
hash_table += hash_table_offset;
fvec<DIMS> pos = x[i];
uvec<DIMS> gpos;
#pragma unroll
for (uint dim = 0; dim < DIMS; ++dim) {
pos[dim] *= res[dim] - 1;
gpos[dim] = (uint)floor(pos[dim]);
pos[dim] -= gpos[dim];
}
auto hash_idx = [&](const uvec<DIMS> gpos) {
uint idx;
if (level >= coarse_levels)
idx = fast_hash(gpos, hash_table_size);
else
idx = gidx(gpos, res);
return idx;
};
// N-linear interpolation
vec<T, FEATURE_DIMS> result = {};
auto n_corners = (1 << DIMS);
#pragma unroll
for (uint corner_idx = 0; corner_idx < n_corners; ++corner_idx) {
float weight = 1;
uvec<DIMS> corner_gpos;
#pragma unroll
for (uint dim = 0; dim < DIMS; ++dim) {
if ((corner_idx & (1 << dim)) == 0) {
weight *= 1 - pos[dim];
corner_gpos[dim] = gpos[dim];
} else {
weight *= pos[dim];
corner_gpos[dim] = gpos[dim] + 1;
}
}
auto idx = hash_idx(corner_gpos);
auto val = hash_table[idx];
#pragma unroll
for (uint feature = 0; feature < FEATURE_DIMS; ++feature) {
result[feature] += (T)(weight * (float)val[feature]);
}
// For backward
if (requires_grad) {
auto j = (level * n + i) * n_corners + corner_idx;
o_indices[j] = idx + hash_table_offset;
o_weights[j] = weight;
}
}
o_encoded[level * n + i] = result;
}
template <typename T, uint FEATURE_DIMS>
void multires_hash_encode_kernel_wrapper(const uint n, const uint levels, const uint coarse_levels,
const uint dims, const float *x, const uint *res_list,
const vec<T, FEATURE_DIMS> *hash_table,
const uint *hash_table_offsets,
vec<T, FEATURE_DIMS> *o_encoded, const bool requires_grad,
float *o_weights, uint *o_indices) {
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const uint threads = opt_n_threads(n);
const dim3 blocks = {(uint)ceil((float)n / threads), levels, 1};
#define DISPATCH_KERNEL_CASE(__DIMS__) \
case __DIMS__: \
multires_hash_encode_kernel<<<blocks, threads, 0, stream>>>( \
n, coarse_levels, (const uvec<__DIMS__> *)res_list, hash_table, hash_table_offsets, \
(const fvec<__DIMS__> *)x, o_encoded, requires_grad, o_weights, o_indices); \
break;
switch (dims) {
DISPATCH_KERNEL_CASE(2)
DISPATCH_KERNEL_CASE(3)
default:
throw std::invalid_argument("'dims' should be 2 or 3");
}
CUDA_CHECK_ERRORS();
#undef DISPATCH_KERNEL_CASE
}
template <typename T>
void multires_hash_encode_kernel_wrapper(const uint n, const uint levels, const uint coarse_levels,
const uint dims, const uint feature_dims, const float *x,
const uint *res_list, const void *hash_table,
const uint *hash_table_offsets, void *o_encoded,
const bool requires_grad, float *o_weights,
uint *o_indices) {
#define KERNEL_WRAPPER_CASE(__FEATURE_DIMS__) \
case __FEATURE_DIMS__: \
multires_hash_encode_kernel_wrapper( \
n, levels, coarse_levels, dims, x, res_list, \
(const vec<T, __FEATURE_DIMS__> *)hash_table, hash_table_offsets, \
(vec<T, __FEATURE_DIMS__> *)o_encoded, requires_grad, o_weights, o_indices); \
break;
switch (feature_dims) {
KERNEL_WRAPPER_CASE(1)
KERNEL_WRAPPER_CASE(2)
KERNEL_WRAPPER_CASE(4)
KERNEL_WRAPPER_CASE(8)
KERNEL_WRAPPER_CASE(16)
default:
throw std::invalid_argument("'feature_dims' should be 1, 2, 4, 8, 16");
}
#undef KERNEL_WRAPPER_CASE
}
#if !defined(__CUDA_NO_HALF_CONVERSIONS__)
void multires_hash_encode_kernel_wrapper_halfp(const uint n, const uint levels,
const uint coarse_levels, const uint dims,
const uint feature_dims, const float *x,
const uint *res_list, const void *hash_table,
const uint *hash_table_offsets, void *o_encoded,
const bool requires_grad, float *o_weights,
uint *o_indices) {
multires_hash_encode_kernel_wrapper<__half>(n, levels, coarse_levels, dims, feature_dims, x,
res_list, hash_table, hash_table_offsets, o_encoded,
requires_grad, o_weights, o_indices);
}
#endif
void multires_hash_encode_kernel_wrapper_fullp(const uint n, const uint levels,
const uint coarse_levels, const uint dims,
const uint feature_dims, const float *x,
const uint *res_list, const void *hash_table,
const uint *hash_table_offsets, void *o_encoded,
const bool requires_grad, float *o_weights,
uint *o_indices) {
multires_hash_encode_kernel_wrapper<float>(n, levels, coarse_levels, dims, feature_dims, x,
res_list, hash_table, hash_table_offsets, o_encoded,
requires_grad, o_weights, o_indices);
}
\ No newline at end of file
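When `requires_grad` is set, the kernel records, per sample, level and corner, the flat hash-table row (`o_indices`, already shifted by the level offset) and the interpolation weight (`o_weights`). Since the encoding is linear in the table entries, a backward pass only needs a weighted scatter of the output gradient. The sketch below illustrates that idea in PyTorch; the function name is hypothetical and the project's actual autograd wrapper is not part of this excerpt.

```python
# Idea sketch: gradient w.r.t. the hash table from the saved weights/indices.
import torch

def hash_table_grad(grad_encoded, weights, indices, table_rows, feature_dims):
    """grad_encoded: (levels, n, F); weights/indices: (levels, n, 2**dims)."""
    grad_table = grad_encoded.new_zeros(table_rows, feature_dims)
    # Each corner contributes weight * grad of its (level, sample) to its table row.
    contrib = weights.unsqueeze(-1) * grad_encoded.unsqueeze(-2)   # (levels, n, corners, F)
    grad_table.index_add_(0, indices.reshape(-1).long(),
                          contrib.reshape(-1, feature_dims))
    return grad_table
```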
 import torch
 import torch.nn.functional as nn_f
-from typing import Any, List, Mapping, Tuple
+from typing import Any
 from torch import nn
 from utils.view import *
 from utils import math
@@ -10,28 +10,29 @@ from .foveation import Foveation
 class FoveatedNeuralRenderer(object):
-    def __init__(self, layers_fov: List[float],
-                 layers_res: List[Tuple[int, int]],
+    def __init__(self, layers_fov: list[float],
+                 layers_res: list[tuple[int, int]],
                  layers_net: nn.ModuleList,
-                 output_res: Tuple[int, int], *,
+                 output_res: tuple[int, int], *,
+                 coord_sys: str = "gl",
                  device: torch.device = None):
         super().__init__()
         self.layers_net = layers_net.to(device=device)
         self.layers_cam = [
-            CameraParam({
+            Camera.create({
                 'fov': layers_fov[i],
                 'cx': 0.5,
                 'cy': 0.5,
                 'normalized': True
-            }, layers_res[i], device=device)
+            }, layers_res[i], coord_sys=coord_sys, device=device)
             for i in range(len(layers_fov))
         ]
-        self.cam = CameraParam({
+        self.cam = Camera.create({
             'fov': layers_fov[-1],
             'cx': 0.5,
             'cy': 0.5,
             'normalized': True
-        }, output_res, device=device)
+        }, output_res, coord_sys=coord_sys, device=device)
         self.foveation = Foveation(layers_fov, layers_res, output_res, device=device)
         self.device = device
@@ -44,14 +45,11 @@ class FoveatedNeuralRenderer(object):
         self.device = device
         return self
-    def __call__(self, *args: Any, **kwds: Any) -> Any:
-        return self.render(*args, **kwds)
-    def render(self, view: Trans, gaze, right_gaze=None, *,
-               stereo_disparity=0,
-               using_mask=True,
-               mono_periph_mode=0,
-               ret_raw=False) -> Union[Mapping[str, torch.Tensor], Tuple[Mapping[str, torch.Tensor]]]:
+    def __call__(self, view: Trans, gaze, right_gaze=None, *,
+                 stereo_disparity: float = 0,
+                 using_mask: bool = True,
+                 mono_periph_mode: int = 0,
+                 ret_raw: bool = False) -> dict[str, torch.Tensor] | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
         if stereo_disparity > math.tiny:
             left_view = Trans(
                 view.trans_point(torch.tensor([-stereo_disparity / 2, 0, 0], device=self.device)),
@@ -71,7 +69,7 @@ class FoveatedNeuralRenderer(object):
                                       layer_mask=layers_mask[0])['color']
             fovea_right = self._render(self.layers_net[0], self.layers_cam[0], right_view, right_gaze,
                                        layer_mask=layers_mask[0])['color']
-            if mono_periph_mode == 3:
+            if mono_periph_mode == 3 or mono_periph_mode == 4:
                 mid = self._render(self.layers_net[1], self.layers_cam[1], view,
                                    ((left_gaze[0] + right_gaze[0]) // 2, left_gaze[1]),
                                    layer_mask=layers_mask[1])['color']
@@ -79,8 +77,8 @@ class FoveatedNeuralRenderer(object):
                 raw_left = [fovea_left, mid, periph]
                 raw_right = [fovea_right, mid, periph]
                 shift = int(left_gaze[0] - right_gaze[0]) // 2
-                left_shifts = [0, 0, shift]
-                right_shifts = [0, 0, -shift]
+                left_shifts = [0, 0, shift if mono_periph_mode == 3 else 0]
+                right_shifts = [0, 0, -shift if mono_periph_mode == 3 else 0]
             else:
                 mid_left = self._render_mid(self.layers_net[1], self.layers_cam[1], left_view, left_gaze,
                                             layer_mask=layers_mask[1], mono_view=view,
@@ -115,17 +113,22 @@ class FoveatedNeuralRenderer(object):
         ]
         return self._gen_output(res_raw, gaze, ret_raw=ret_raw)
-    def _render(self, net, cam: CameraParam, view: Trans, gaze=None, *,
-                ret_depth=False,
-                layer_mask=None) -> Mapping[str, torch.Tensor]:
+    def _render(self, net, cam: Camera, view: Trans, gaze=None, *,
+                ret_depth=False, layer_mask=None) -> dict[str, torch.Tensor]:
+        output_types = ["color"]
+        if ret_depth:
+            output_types.append("depth")
         if gaze is not None:
             cam = self._adjust_cam(cam, gaze)
-        rays_o, rays_d = cam.get_global_rays(view, False) # (1, H, W, 3)
+        rays_d = view.trans_vector(cam.local_rays.reshape(*cam.res, -1)) # (1, H, W, 3)
+        rays_o = view.t.broadcast_to(rays_d.shape)
         if layer_mask is not None:
             infer_mask = layer_mask >= 0
-            rays_o = rays_o[:, infer_mask]
-            rays_d = rays_d[:, infer_mask]
-            net_output = net(rays_o.view(-1, 3), rays_d.view(-1, 3), ret_depth=ret_depth)
+            net_input = Rays({
+                "rays_o": rays_o[:, infer_mask].reshape(-1, 3),
+                "rays_d": rays_d[:, infer_mask].reshape(-1, 3)
+            })
+            net_output = net(net_input, *output_types)
             ret = {
                 'color': torch.zeros(1, cam.res[0], cam.res[1], 3, device=self.device)
             }
@@ -136,17 +139,21 @@ class FoveatedNeuralRenderer(object):
             ret['depth'][:, infer_mask] = net_output['depth']
             return ret
         else:
-            net_output = net(rays_o.view(-1, 3), rays_d.view(-1, 3), ret_depth=ret_depth)
+            net_input = Rays({
+                "rays_o": rays_o.reshape(-1, 3),
+                "rays_d": rays_d.reshape(-1, 3)
+            })
+            net_output = net(net_input, *output_types)
             return {
                 'color': net_output['color'].view(1, cam.res[0], cam.res[1], -1).permute(0, 3, 1, 2),
                 'depth': net_output['depth'].view(1, cam.res[0], cam.res[1]) if ret_depth else None
             }
-    def _render_mid(self, net, cam: CameraParam, view: Trans, gaze=None, *,
+    def _render_mid(self, net, cam: Camera, view: Trans, gaze=None, *,
                     layer_mask: torch.Tensor,
                     mono_view: Trans,
                     blend_view: bool,
-                    ret_depth=False) -> Mapping[str, torch.Tensor]:
+                    ret_depth=False) -> dict[str, torch.Tensor]:
         """
         [summary]
@@ -159,18 +166,23 @@ class FoveatedNeuralRenderer(object):
         :param ret_depth: [description], defaults to False
         :return: [description]
         """
+        output_types = ["color"]
+        if ret_depth:
+            output_types.append("depth")
         if gaze is not None:
             cam = self._adjust_cam(cam, gaze)
         k = layer_mask[None, ..., None].clamp(1 if blend_view else 2, 2) - 1 # (1, H, W, 1)
         rays_o = (1 - k) * view.t + k * mono_view.t # (1, H, W, 3)
-        rays_d = view.trans_vector(cam.get_local_rays()) # (1, H, W, 3)
+        rays_d = view.trans_vector(cam.local_rays.reshape(*cam.res, -1)) # (1, H, W, 3)
         if layer_mask is not None:
             infer_mask = layer_mask >= 0
-            rays_o = rays_o[:, infer_mask]
-            rays_d = rays_d[:, infer_mask]
-            net_output = net(rays_o.view(-1, 3), rays_d.view(-1, 3), ret_depth=ret_depth)
+            net_input = Rays({
+                "rays_o": rays_o[:, infer_mask].reshape(-1, 3),
+                "rays_d": rays_d[:, infer_mask].reshape(-1, 3)
+            })
+            net_output = net(net_input, *output_types)
             ret = {
                 'color': torch.zeros(1, cam.res[0], cam.res[1], 3, device=self.device)
             }
@@ -181,13 +193,18 @@ class FoveatedNeuralRenderer(object):
             ret['depth'][:, infer_mask] = net_output['depth']
             return ret
         else:
-            net_output = net(rays_o.view(-1, 3), rays_d.view(-1, 3), ret_depth=ret_depth)
+            net_input = {
+                "rays_o": rays_o.reshape(-1, 3),
+                "rays_d": rays_d.reshape(-1, 3)
+            }
+            net_output = net(net_input, *output_types)
             return {
                 'color': net_output['color'].view(1, cam.res[0], cam.res[1], -1).permute(0, 3, 1, 2),
                 'depth': net_output['depth'].view(1, cam.res[0], cam.res[1]) if ret_depth else None
             }
-    def _gen_output(self, layers_img: List[torch.Tensor], gaze: Tuple[float, float], shifts=None, ret_raw=False) -> Mapping[str, torch.Tensor]:
+    def _gen_output(self, layers_img: list[torch.Tensor], gaze: tuple[float, float], shifts=None,
+                    ret_raw=False) -> dict[str, torch.Tensor]:
         refined = self._post_process(layers_img)
         blended = self.foveation.synthesis(refined, gaze, shifts)
         ret = {
@@ -196,10 +213,10 @@ class FoveatedNeuralRenderer(object):
         }
         if ret_raw:
             ret['layers_raw'] = layers_img
-            ret['blended_raw'] = self.foveation.synthesis(layers_img, gaze)
+            ret['blended_raw'] = self.foveation.synthesis(layers_img, gaze, shifts)
         return ret
-    def _post_process(self, layers_img: List[torch.Tensor]) -> List[torch.Tensor]:
+    def _post_process(self, layers_img: list[torch.Tensor]) -> list[torch.Tensor]:
         return [
             #grad_aware_median(constrast_enhance(layers_img[0], 3, 0.2), 3, 3, True),
             constrast_enhance(layers_img[0], 3, 0.2),
@@ -207,20 +224,18 @@ class FoveatedNeuralRenderer(object):
             constrast_enhance(layers_img[2], 5, 0.2)
         ]
-    def _adjust_cam(self, layer_cam: CameraParam, gaze: Tuple[float, float]) -> CameraParam:
+    def _adjust_cam(self, layer_cam: Camera, gaze: tuple[float, float]) -> Camera:
         fovea_offset = (
             (gaze[0]) / self.cam.f[0].item() * layer_cam.f[0].item(),
             (gaze[1]) / self.cam.f[1].item() * layer_cam.f[1].item()
         )
-        return CameraParam({
-            'fx': layer_cam.f[0].item(),
-            'fy': layer_cam.f[1].item(),
-            'cx': layer_cam.c[0].item() - fovea_offset[0],
-            'cy': layer_cam.c[1].item() - fovea_offset[1]
-        }, layer_cam.res, device=self.device)
+        return Camera.create({
+            'f': [layer_cam.f[0].item(), layer_cam.f[1].item()],
+            'c': [layer_cam.c[0].item() - fovea_offset[0], layer_cam.c[1].item() - fovea_offset[1]]
+        }, layer_cam.res, coord_sys=layer_cam.coord_sys, device=self.device)
     def _warp(self, trans: Trans, trans0: Trans,
-              cam: CameraParam, z_list: torch.Tensor,
+              cam: Camera, z_list: torch.Tensor,
               image: torch.Tensor, depthmap: torch.Tensor) -> torch.Tensor:
         """
         [summary]
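The recurring change across the hunks above is the network call convention: instead of positional `rays_o`/`rays_d` tensors plus a `ret_depth` flag, the net now receives a single ray container and the names of the requested outputs. The fragment below shows that convention in isolation, with a plain dict standing in for the repo's `Rays` wrapper.

```python
# Ray-batch convention introduced by this commit (dict stands in for the repo's Rays container).
import torch

H, W = 64, 64
rays_o = torch.zeros(1, H, W, 3)   # ray origins,    (1, H, W, 3)
rays_d = torch.randn(1, H, W, 3)   # ray directions, (1, H, W, 3)

net_input = {                      # the updated renderer wraps this in Rays({...})
    "rays_o": rays_o.reshape(-1, 3),
    "rays_d": rays_d.reshape(-1, 3),
}
# Old convention: net(rays_o.view(-1, 3), rays_d.view(-1, 3), ret_depth=True)
# New convention: net(Rays(net_input), "color", "depth")
```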
@@ -8,8 +8,8 @@ from utils import math
 class Foveation(object):
-    def __init__(self, layers_fov: List[float], layers_res: List[Tuple[float, float]],
-                 out_res: Tuple[int, int], *, blend: float = 0.6, device: torch.device = None):
+    def __init__(self, layers_fov: list[float], layers_res: list[tuple[float, float]],
+                 out_res: tuple[int, int], *, blend: float = 0.6, device: torch.device = None):
         self.layers_fov = layers_fov
         self.layers_res = layers_res
         self.out_res = out_res
@@ -20,15 +20,15 @@ class Foveation(object):
             self._gen_layer_blendmap(i)
             for i in range(self.n_layers - 1)
         ] # blend maps of fovea layers
-        self.coords = misc.meshgrid(*out_res).to(device=device)
+        self.coords = misc.grid2d(*out_res, device=device)
     def to(self, device: torch.device):
         self.eye_fovea_blend = [x.to(device=device) for x in self.eye_fovea_blend]
         self.coords = self.coords.to(device=device)
         return self
-    def synthesis(self, layers: List[torch.Tensor], fovea_center: Tuple[float, float],
-                  shifts: List[int] = None,
+    def synthesis(self, layers: list[torch.Tensor], fovea_center: tuple[float, float],
+                  shifts: list[int] = None,
                   do_blend: bool = True,
                   crop_mode: bool = False) -> torch.Tensor:
         """
@@ -40,6 +40,7 @@ class Foveation(object):
         """
         output: torch.Tensor = nn_f.interpolate(layers[-1], self.out_res,
                                                 mode='bilinear', align_corners=False)
+        #output.fill_(0) # TODO: debug
         if shifts is not None:
             output = img.horizontal_shift(output, shifts[-1])
         c = torch.tensor([
@@ -99,11 +100,11 @@ class Foveation(object):
         """
         size = self.get_layer_size_in_final_image(i)
         R = size / 2
-        p = misc.meshgrid(size, size).to(device=self.device) # (size, size, 2)
+        p = misc.grid2d(size, device=self.device) # (size, size, 2)
         r = torch.norm(p - R, dim=2) # (size, size, 2)
         return misc.smooth_step(R, R * self.blend, r)
-    def get_layers_mask(self, gaze=None) -> List[torch.Tensor]:
+    def get_layers_mask(self, gaze=None) -> list[torch.Tensor]:
         """
         Generate mask images for layers[:-1]
         the meaning of values in mask images:
@@ -127,8 +128,7 @@ class Foveation(object):
             else:
                 c = torch.tensor([0.5, 0.5], device=self.device)
             layers_mask.append(torch.ones(*self.layers_res[i], device=self.device) * -1)
-            coord = misc.meshgrid(
-                *self.layers_res[i]).to(device=self.device) / self.layers_res[i][0]
+            coord = misc.grid2d(*self.layers_res[i], device=self.device) / self.layers_res[i][0]
             r = 2 * torch.norm(coord - c, dim=-1)
             inner_radius = self.get_source_layer_cover_size_in_target_layer(
                 self.layers_fov[i - 1], self.layers_fov[i], self.layers_res[i][0]) / self.layers_res[i][0] \
@@ -7,7 +7,7 @@ from utils import math
 class GuideRefinement(object):
     def __init__(self, guides_image, guides_view: view.Trans,
-                 guides_cam: view.CameraParam, net) -> None:
+                 guides_cam: view.Camera, net) -> None:
         rays_o, rays_d = guides_cam.get_global_rays(guides_view, flatten=True)
         guides_inferred = torch.stack([
             net(rays_o[i], rays_d[i]).view(
from typing import SupportsFloat
from model import Model
from utils.view import *
from utils.types import *


def render(model: Model, cam: Camera, view: Trans, *output_types: str,
           gaze: tuple[float, float] = (0, 0), extra_input: dict = None,
           layer_mask: torch.Tensor = None, batch_size: int = None) -> ReturnData:
    if len(output_types) == 0:
        raise ValueError("'output_types' is empty")
    local_rays = cam.local_rays + cam.local_rays.new_tensor([*gaze, 0]) # (H*W, 3)
    rays_d = view.trans_vector(local_rays) # (B..., H*W, 3)
    rays_o = view.t[..., None, :].expand_as(rays_d)
    # print(cam.local_rays)  # leftover debug statements; enabling them would abort rendering
    # exit()
    input = Rays(rays_o=rays_o, rays_d=rays_d, **extra_input or {}) # (B..., H*W)
    if layer_mask is not None:
        selector = layer_mask.flatten().ge(0).nonzero()
        input = input.transform(lambda value: value.index_select(len(input.shape), selector))
    input = input.flatten() # (B..., X) -> (N)
    output = ReturnData() # will be (N)
    n = input.shape[0]
    batch_size = batch_size or n
    for offset in range(0, n, batch_size):
        batch_slice = slice(offset, min(offset + batch_size, n))
        batch_output = model(input.select(batch_slice), *output_types)
        for key, value in batch_output.items():
            if key == "rays_filter":
                continue
            match value:
                case torch.Tensor():
                    if key not in output:
                        output[key] = value.new_full([n, *value.shape[1:]],
                                                     math.huge * (key == "depth"))
                    if 'rays_filter' in batch_output:
                        output[key][batch_slice][batch_output['rays_filter']] = batch_output[key]
                    else:
                        output[key][batch_slice] = batch_output[key]
                case SupportsFloat():
                    output[key] = output.get(key, 0) + value
                case _:
                    output[key] = output.get(key, []) + [value]
    output = output.reshape(*view.shape, -1) # (N) -> (B..., X)
    if layer_mask is not None:
        output = output.transform(lambda value:
                                  value.new_zeros(*view.shape, local_rays.shape[0],
                                                  *value.shape[len(view.shape) + 1:])
                                  .index_copy(len(view.shape), selector, value))
    return output.reshape(*view.shape, *cam.res) # (B..., H*W) -> (B..., H, W)
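A hedged usage sketch for `render()`: the `Camera.create` call mirrors the renderer changes above, while the model and the pose (`view`) are assumed to come from elsewhere in the repo, so they are left as comments rather than invented here.

```python
# Hypothetical call into render(); only Camera.create usage is taken from this commit.
import torch
from utils.view import Camera   # assumed import location

device = torch.device("cuda")
cam = Camera.create({"fov": 40.0, "cx": 0.5, "cy": 0.5, "normalized": True},
                    (256, 256), coord_sys="gl", device=device)
# model = ...   # however the repo constructs/restores a trained Model
# view = ...    # a Trans holding the camera pose
# out = render(model, cam, view, "color", "depth", batch_size=4096)
# out["color"] and out["depth"] come back reshaped to the camera resolution.
```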
@@ -12,7 +12,7 @@
     },
     "sample_range": [1, 7],
     "n_samples": 64,
-    "multi_nets": 4
+    "multi_nets": 4,
     "density_regularization_weight": 1e-4,
     "density_regularization_scale": 1e4
 },
model=FsNeRF
; n_samples=64
; n_fields=1
; depth=8
; width=256
; skips=[4]
; act=relu
; ln=false
; xfreqs=6
; raw_noise_std=0.
; near: float # from dataset
; far: float # from dataset
; white_bg: bool # from dataset
; trainer=Trainer
; max_iters=200000
max_epochs=50
checkpoint_interval=10
; batch_size=4096
; loss=[Color_L2]
; lr=5e-4
lr_decay=0.9999954
; profile_iters=0
\ No newline at end of file
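These experiment configs are flat key=value files in which lines starting with `;` are commented-out defaults kept for documentation. The project's actual loader is not shown in this commit; the snippet below is just one way such a sectionless file could be read with the standard library.

```python
# Sketch of reading a sectionless, ;-commented config file (not the project's loader).
import configparser

def load_flat_ini(path: str) -> dict[str, str]:
    parser = configparser.ConfigParser(inline_comment_prefixes=(";",))
    with open(path) as f:
        parser.read_string("[root]\n" + f.read())   # fake header: these files have no sections
    return dict(parser["root"])

# For the file above this would yield something like:
# {'model': 'FsNeRF', 'max_epochs': '50', 'checkpoint_interval': '10', 'lr_decay': '0.9999954'}
```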
model=NeRF
; color=rgb
; n_samples=64
sample_mode=spherical_radius
perturb_sampling=true
; depth=8
; width=256
; skips=[4]
; act=relu
; ln=false
; color_decoder=NeRF
n_importance=128
; fine_depth=8
; fine_width=256
; fine_skips=[4]
; xfreqs=10
; dfreqs=4
; raw_noise_std=0.
; near: float # from dataset
; far: float # from dataset
; white_bg: bool # from dataset
; coord: str # from dataset
; trainer=Trainer
; max_iters=200000
max_epochs=20
checkpoint_interval=5
; batch_size=4096
; loss=[Color_L2, CoarseColor_L2]
; lr=5e-4
lr_decay=0.9999954
; profile_iters=0
\ No newline at end of file
model=NeRF
n_samples=64
; sample_mode=xyz
perturb_sampling=true
; depth=8
; width=256
; skips=[4]
; act=relu
; ln=false
; color_decoder=NeRF
; n_importance=128
; fine_depth=8
; fine_width=256
; fine_skips=[4]
; xfreqs=10
; dfreqs=4
; raw_noise_std=0.
; near: float # from dataset
; far: float # from dataset
; white_bg: bool # from dataset
; trainer=Trainer
; max_iters=200000
max_epochs=10
checkpoint_interval=5
; batch_size=4096
; loss=[Color_L2]
; lr=5e-4
lr_decay=0.9999954
; profile_iters=0
\ No newline at end of file
model=FsNeRF
; n_samples=64
n_fields=4
; depth=8
; width=256
skips=[]
; act=relu
; ln=false
; xfreqs=6
; raw_noise_std=0.
; near: float # from dataset
; far: float # from dataset
; white_bg: bool # from dataset
; trainer=Trainer
; max_iters=200000
max_epochs=30
; checkpoint_interval=10000(iters) or 10(epochs)
; batch_size=4096
; loss=[Color_L2]
; lr=5e-4
lr_decay=0.9999954
; profile_iters=0
\ No newline at end of file
model=FsNeRF
; color=rgb
; n_samples=64
n_fields=4
; depth=8
; width=256
; skips=[4]
; act=relu
; ln=false
; xfreqs=6
; raw_noise_std=0.
; near: float # from dataset
; far: float # from dataset
; white_bg: bool # from dataset
; coord: str # from dataset
; trainer=Trainer
; max_iters=200000
max_epochs=30
checkpoint_interval=10
; batch_size=4096
; loss=[Color_L2]
; lr=5e-4
lr_decay=0.9999954
; profile_iters=0
\ No newline at end of file
model=FsNeRF
; color=rgb
n_samples=256
n_fields=2
; depth=8
; width=256
; skips=[4]
; act=relu
; ln=false
; xfreqs=6
; raw_noise_std=0.
; near: float # from dataset
; far: float # from dataset
; white_bg: bool # from dataset
; coord: str # from dataset
; trainer=Trainer
; max_iters=200000
max_epochs=30
checkpoint_interval=10
; batch_size=4096
loss=[Color_L2]
; lr=5e-4
; lr_decay=0.9999954
; profile_iters=0
\ No newline at end of file
model=FsNeRF
; color=rgb
; n_samples=64
n_fields=4
; depth=8
; width=256
; skips=[4]
; act=relu
; ln=false
; xfreqs=6
; raw_noise_std=0.
; near: float # from dataset
; far: float # from dataset
; white_bg: bool # from dataset
; coord: str # from dataset
; trainer=Trainer
; max_iters=200000
max_epochs=30
checkpoint_interval=50
; batch_size=4096
loss=[Color_L2]
; lr=5e-4
; lr_decay=0.9999954
; profile_iters=0
\ No newline at end of file
model=NeRF
; color=rgb
; n_samples=64
; sample_mode=xyz
perturb_sampling=true
; depth=8
; width=256
; skips=[4]
; act=relu
; ln=false
; color_decoder=NeRF
n_importance=128
; fine_depth=8
; fine_width=256
; fine_skips=[4]
; xfreqs=10
; dfreqs=4
; raw_noise_std=0.
; near: float # from dataset
; far: float # from dataset
; white_bg: bool # from dataset
; coord: str # from dataset
; trainer=Trainer
; max_iters=200000
max_epochs=20
checkpoint_interval=5
; batch_size=4096
; loss=[Color_L2, CoarseColor_L2]
; lr=5e-4
lr_decay=0.9999954
; profile_iters=0
\ No newline at end of file
model=NeRF
; color=rgb
; n_samples=64
; sample_mode=xyz
perturb_sampling=true
; depth=8
; width=256
; skips=[4]
; act=relu
; ln=false
; color_decoder=NeRF
n_importance=128
; fine_depth=8
; fine_width=256
; fine_skips=[4]
; xfreqs=10
; dfreqs=4
; raw_noise_std=0.
; near: float # from dataset
; far: float # from dataset
; white_bg: bool # from dataset
; coord: str # from dataset
; trainer=Trainer
max_iters=200000
; max_epochs=20
; checkpoint_interval=5
; batch_size=4096
; loss=[Color_L2, CoarseColor_L2]
; lr=5e-4
lr_decay=0.9999954
; profile_iters=0
\ No newline at end of file
model=NeRF
; color=rgb
; n_samples=64
sample_mode=xyz_disp
perturb_sampling=true
; depth=8
; width=256
; skips=[4]
; act=relu
; ln=false
; color_decoder=NeRF
n_importance=128
; fine_depth=8
; fine_width=256
; fine_skips=[4]
; xfreqs=10
; dfreqs=4
; raw_noise_std=1e0
; near: float # from dataset
; far: float # from dataset
; white_bg: bool # from dataset
; coord: str # from dataset
; trainer=Trainer
; max_iters=200000
max_epochs=50
checkpoint_interval=10
; batch_size=4096
; loss=[Color_L2, CoarseColor_L2]
; lr=5e-4
lr_decay=0.9999908
; profile_iters=0
\ No newline at end of file