Commit 5d1d329d authored by Nianchen Deng

sync

parent f6604bd2
from my import color_mode

def update_config(config):
    # Dataset settings
    config.COLOR = color_mode.RGB
    # Net parameters
    config.NET_TYPE = 'snerffast4'
    config.N_ENCODE_DIM = 6
    #config.N_DIR_ENCODE = 4
    config.FC_PARAMS.update({
        'nf': 256,
        'n_layers': 8
    })
    config.SAMPLE_PARAMS.update({
        'depth_range': (1, 50),
        'n_samples': 64,
        'perturb_sample': False
    })
from my import color_mode
def update_config(config):
# Dataset settings
config.COLOR = color_mode.RGB
# Net parameters
config.NET_TYPE = 'nnmsl'
config.N_ENCODE_DIM = 10
config.FC_PARAMS.update({
'nf': 64,
'n_layers': 4
......
def update_config(config):
# Net parameters
config.NET_TYPE = 'snerffast4'
config.N_ENCODE_DIM = 6
#config.N_DIR_ENCODE = 4
config.FC_PARAMS.update({
'nf': 128,
'n_layers': 4
})
config.SAMPLE_PARAMS.update({
'depth_range': (1, 7),
'n_samples': 64,
'perturb_sample': False
})
def update_config(config):
# Net parameters
config.NET_TYPE = 'snerffastx4'
config.N_ENCODE_DIM = 6
#config.N_DIR_ENCODE = 4
config.FC_PARAMS.update({
'nf': 512,
'n_layers': 8
})
config.SAMPLE_PARAMS.update({
'depth_range': (0.3, 7),
'n_samples': 128,
'perturb_sample': False
})
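The config files above all follow the same pattern: a module-level update_config(config) that overrides fields of a SphericalViewSynConfig instance (defined in the file below). In the repository this is driven by SphericalViewSynConfig.load(), which converts the file path into a module name; the following is only a minimal sketch of that pattern, and the module path 'configs.part0' used here is a hypothetical example, not a path confirmed by this commit.

# Minimal sketch of applying a config-override module; 'configs.part0' is a
# hypothetical module path and the real loader (load()) may differ in detail.
import importlib

from configs.spherical_view_syn import SphericalViewSynConfig

config = SphericalViewSynConfig()
overrides = importlib.import_module('configs.part0')
overrides.update_config(config)
print(config.NET_TYPE, config.FC_PARAMS['nf'], config.SAMPLE_PARAMS['n_samples'])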
import os
import importlib
from utils.constants import *
from utils import color
from nets.msl_net import MslNet
from nets.msl_net_new import NewMslNet
from nets.msl_ray import MslRay
from nets.msl_fast import MslFast
from nets.snerf_fast import SnerfFast
from nets.cnerf_v3 import CNerf
from nets.nerf import CascadeNerf
from nets.nerf import CascadeNerf2
from nets.nnerf import NNerf
from nets.nerf_depth import NerfDepth
from nets.bg_net import BgNet
from nets.oracle import Oracle
class SphericalViewSynConfig(object):
......@@ -10,18 +21,19 @@ class SphericalViewSynConfig(object):
def __init__(self):
self.name = 'default'
        self.COLOR = color.RGB
# Net parameters
self.NET_TYPE = 'msl'
self.N_ENCODE_DIM = 10
self.N_DIR_ENCODE = None
self.NORMALIZE = False
        self.DEPTH_REF = False
self.FC_PARAMS = {
'nf': 256,
'n_layers': 8,
            'skips': [],
            'activation': 'relu'
}
self.SAMPLE_PARAMS = {
'spherical': True,
......@@ -31,6 +43,12 @@ class SphericalViewSynConfig(object):
'lindisp': True,
'inverse_r': True,
}
self.NERF_FINE_NET_PARAMS = {
'enable': False,
'nf': 256,
'n_layers': 8,
'additional_samples': 64
}
def load(self, path):
module_name = os.path.splitext(path)[0].replace('/', '.')
......@@ -45,17 +63,19 @@ class SphericalViewSynConfig(object):
self.name = name
def to_id(self):
net_type_id = f"{self.NET_TYPE}-{color.to_str(self.COLOR)}"
encode_id = f"_e{self.N_ENCODE_DIM}"
dir_encode_id = f"_ed{self.N_DIR_ENCODE}" if self.N_DIR_ENCODE else ''
fc_id = f"_fc{self.FC_PARAMS['nf']}x{self.FC_PARAMS['n_layers']}"
skip_id = "_skip%s" % ','.join(['%d' % val for val in self.FC_PARAMS['skips']]) \
if len(self.FC_PARAMS['skips']) > 0 else ""
act_id = f"_*{self.FC_PARAMS['activation']}" if self.FC_PARAMS['activation'] != 'relu' else ''
depth_id = "_d%.2f-%.2f" % (self.SAMPLE_PARAMS['depth_range'][0],
self.SAMPLE_PARAMS['depth_range'][1])
samples_id = f"_s{self.SAMPLE_PARAMS['n_samples']}"
ffc_id = f"_ffc{self.NERF_FINE_NET_PARAMS['nf']}x{self.NERF_FINE_NET_PARAMS['n_layers']}"
fsamples_id = f"_fs{self.NERF_FINE_NET_PARAMS['additional_samples']}"
fine_id = f"{ffc_id}{fsamples_id}" if self.NERF_FINE_NET_PARAMS['enable'] else ''
neg_flags = '%s%s%s' % (
'p' if not self.SAMPLE_PARAMS['perturb_sample'] else '',
'l' if not self.SAMPLE_PARAMS['lindisp'] else '',
......@@ -64,10 +84,14 @@ class SphericalViewSynConfig(object):
neg_flags = '_~' + neg_flags if neg_flags != '' else ''
pos_flags = '%s%s' % (
'n' if self.NORMALIZE else '',
'd' if self.DEPTH_REF else ''
)
pos_flags = '_+' + pos_flags if pos_flags != '' else ''
return "%s@%s%s%s%s%s%s%s%s%s%s%s" % (self.name, net_type_id, encode_id, dir_encode_id,
fc_id, skip_id, act_id,
depth_id, samples_id,
fine_id,
neg_flags, pos_flags)
def from_id(self, id: str):
id_splited = id.split('@')
......@@ -75,6 +99,18 @@ class SphericalViewSynConfig(object):
self.name = id_splited[0]
segs = id_splited[-1].split('_')
for i, seg in enumerate(segs):
if seg.startswith('ffc'): # Full-connected network parameters
self.NERF_FINE_NET_PARAMS['nf'], self.NERF_FINE_NET_PARAMS['n_layers'] = (
int(str) for str in seg[3:].split('x'))
self.NERF_FINE_NET_PARAMS['enable'] = True
continue
if seg.startswith('fs'): # Number of samples
try:
self.NERF_FINE_NET_PARAMS['additional_samples'] = int(seg[2:])
self.NERF_FINE_NET_PARAMS['enable'] = True
continue
except ValueError:
pass
if seg.startswith('fc'): # Full-connected network parameters
self.FC_PARAMS['nf'], self.FC_PARAMS['n_layers'] = (
int(str) for str in seg[2:].split('x'))
......@@ -83,19 +119,30 @@ class SphericalViewSynConfig(object):
self.FC_PARAMS['skips'] = [int(str)
for str in seg[4:].split(',')]
continue
if seg.startswith('*'): # Activation
self.FC_PARAMS['activation'] = seg[1:]
continue
if seg.startswith('ed'): # Encode direction
self.N_DIR_ENCODE = int(seg[2:])
if self.N_DIR_ENCODE == 0:
self.N_DIR_ENCODE = None
continue
if seg.startswith('e'): # Encode
self.N_ENCODE_DIM = int(seg[1:])
continue
if seg.startswith('d'): # Depth range
try:
self.SAMPLE_PARAMS['depth_range'] = tuple(
float(str) for str in seg[1:].split('-'))
continue
except ValueError:
pass
if seg.startswith('s'): # Number of samples
try:
self.SAMPLE_PARAMS['n_samples'] = int(seg[1:])
continue
except ValueError:
pass
if seg.startswith('~'): # Negative flags
if seg.find('p') >= 0:
self.SAMPLE_PARAMS['perturb_sample'] = False
......@@ -103,27 +150,39 @@ class SphericalViewSynConfig(object):
self.SAMPLE_PARAMS['lindisp'] = False
if seg.find('i') >= 0:
self.SAMPLE_PARAMS['inverse_r'] = False
if seg.find('n') >= 0:
self.NORMALIZE = False
if seg.find('d') >= 0:
self.DEPTH_REF = False
continue
if seg.startswith('+'): # Positive flags
if seg.find('p') >= 0:
self.SAMPLE_PARAMS['perturb_sample'] = True
if seg.find('l') >= 0:
self.SAMPLE_PARAMS['lindisp'] = True
if seg.find('i') >= 0:
self.SAMPLE_PARAMS['inverse_r'] = True
if seg.find('n') >= 0:
self.NORMALIZE = True
if seg.find('d') >= 0:
self.DEPTH_REF = True
continue
if i == 0: # NetType
self.NET_TYPE, color_str = seg.split('-')
                self.COLOR = color.from_str(color_str)
                continue
def print(self):
print('==== Config %s ====' % self.name)
print('Net type: ', self.NET_TYPE)
print('Encode dim: ', self.N_ENCODE_DIM)
print('Normalize: ', self.NORMALIZE)
print('Train with depth: ', self.DEPTH_REF)
print('Support direction: ', False if self.N_DIR_ENCODE is None
else f'encode to {self.N_DIR_ENCODE}')
print('Full-connected network parameters:', self.FC_PARAMS)
print('Sample parameters', self.SAMPLE_PARAMS)
if self.NERF_FINE_NET_PARAMS['enable']:
print('NeRF fine network parameters', self.NERF_FINE_NET_PARAMS)
print('==========================')
def create_net(self):
......@@ -131,9 +190,105 @@ class SphericalViewSynConfig(object):
return MslNet(fc_params=self.FC_PARAMS,
sampler_params=self.SAMPLE_PARAMS,
normalize_coord=self.NORMALIZE,
c=self.COLOR,
encode_to_dim=self.N_ENCODE_DIM)
if self.NET_TYPE == 'mslray':
return MslRay(fc_params=self.FC_PARAMS,
sampler_params=self.SAMPLE_PARAMS,
normalize_coord=self.NORMALIZE,
c=self.COLOR,
encode_to_dim=self.N_ENCODE_DIM)
if self.NET_TYPE == 'mslfast':
return MslFast(fc_params=self.FC_PARAMS,
sampler_params=self.SAMPLE_PARAMS,
normalize_coord=self.NORMALIZE,
c=self.COLOR,
encode_to_dim=self.N_ENCODE_DIM)
if self.NET_TYPE == 'msl2fast':
return MslFast(fc_params=self.FC_PARAMS,
sampler_params=self.SAMPLE_PARAMS,
normalize_coord=self.NORMALIZE,
c=self.COLOR,
encode_to_dim=self.N_ENCODE_DIM,
include_r=True)
if self.NET_TYPE == 'nerf':
return CascadeNerf(fc_params=self.FC_PARAMS,
sampler_params=self.SAMPLE_PARAMS,
fine_params=self.NERF_FINE_NET_PARAMS,
normalize_coord=self.NORMALIZE,
c=self.COLOR,
coord_encode=self.N_ENCODE_DIM,
dir_encode=self.N_DIR_ENCODE)
if self.NET_TYPE == 'nerf2':
return CascadeNerf2(fc_params=self.FC_PARAMS,
sampler_params=self.SAMPLE_PARAMS,
normalize_coord=self.NORMALIZE,
c=self.COLOR,
coord_encode=self.N_ENCODE_DIM,
dir_encode=self.N_DIR_ENCODE)
if self.NET_TYPE == 'nerfbg':
return CascadeNerf(fc_params=self.FC_PARAMS,
sampler_params=self.SAMPLE_PARAMS,
fine_params=self.NERF_FINE_NET_PARAMS,
normalize_coord=self.NORMALIZE,
c=self.COLOR,
coord_encode=self.N_ENCODE_DIM,
dir_encode=self.N_DIR_ENCODE,
bg_layer=True)
if self.NET_TYPE == 'bgnet':
return BgNet(fc_params=self.FC_PARAMS,
encode=self.N_ENCODE_DIM,
c=self.COLOR)
if self.NET_TYPE.startswith('oracle'):
return Oracle(fc_params=self.FC_PARAMS,
sampler_params=self.SAMPLE_PARAMS,
normalize_coord=self.NORMALIZE,
coord_encode=self.N_ENCODE_DIM,
out_activation=self.NET_TYPE[6:] if len(self.NET_TYPE) > 6 else 'sigmoid')
if self.NET_TYPE.startswith('cnerf'):
return CNerf(fc_params=self.FC_PARAMS,
sampler_params=self.SAMPLE_PARAMS,
c=self.COLOR,
coord_encode=self.N_ENCODE_DIM,
n_bins=int(self.NET_TYPE[5:] if len(self.NET_TYPE) > 5 else 128))
if self.NET_TYPE.startswith('dnerfa'):
return NerfDepth(fc_params=self.FC_PARAMS,
sampler_params=self.SAMPLE_PARAMS,
c=self.COLOR,
coord_encode=self.N_ENCODE_DIM,
n_bins=int(self.NET_TYPE[7:] if len(self.NET_TYPE) > 7 else 128),
include_neighbor_bins=False)
if self.NET_TYPE.startswith('dnerf'):
return NerfDepth(fc_params=self.FC_PARAMS,
sampler_params=self.SAMPLE_PARAMS,
c=self.COLOR,
coord_encode=self.N_ENCODE_DIM,
n_bins=int(self.NET_TYPE[6:] if len(self.NET_TYPE) > 6 else 128))
if self.NET_TYPE.startswith('nnerf'):
return NNerf(fc_params=self.FC_PARAMS,
sampler_params=self.SAMPLE_PARAMS,
n_nets=int(self.NET_TYPE[5:] if len(self.NET_TYPE) > 5 else 1),
normalize_coord=self.NORMALIZE,
c=self.COLOR,
coord_encode=self.N_ENCODE_DIM,
dir_encode=self.N_DIR_ENCODE)
if self.NET_TYPE.startswith('snerffastx'):
return SnerfFast(fc_params=self.FC_PARAMS,
sampler_params=self.SAMPLE_PARAMS,
n_parts=int(self.NET_TYPE[10:] if len(self.NET_TYPE) > 10 else 1),
normalize_coord=self.NORMALIZE,
c=self.COLOR,
coord_encode=self.N_ENCODE_DIM,
dir_encode=self.N_DIR_ENCODE,
multiple_net=False)
if self.NET_TYPE.startswith('snerffast'):
return SnerfFast(fc_params=self.FC_PARAMS,
sampler_params=self.SAMPLE_PARAMS,
n_parts=int(self.NET_TYPE[9:] if len(self.NET_TYPE) > 9 else 1),
normalize_coord=self.NORMALIZE,
c=self.COLOR,
coord_encode=self.N_ENCODE_DIM,
dir_encode=self.N_DIR_ENCODE)
if self.NET_TYPE.startswith('nmsl'):
n_nets = int(self.NET_TYPE[4:]) if len(self.NET_TYPE) > 4 else 2
if self.SAMPLE_PARAMS['n_samples'] % n_nets != 0:
......@@ -142,15 +297,6 @@ class SphericalViewSynConfig(object):
sampler_params=self.SAMPLE_PARAMS,
normalize_coord=self.NORMALIZE,
n_nets=n_nets,
                             c=self.COLOR,
                             encode_to_dim=self.N_ENCODE_DIM)
        raise ValueError('Invalid net type')
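To illustrate the ID scheme that to_id() builds and from_id() parses above, here is a sketch of a round trip, using values taken from the config snippets at the top of this commit (the name 'fovea' is just an example). The exact string depends on color.to_str() and on which fields differ from their defaults, so the value in the comment is only indicative.

# Sketch of the to_id()/from_id() round trip. The printed ID is indicative only,
# e.g. something like 'fovea@snerffast4-rgb_e6_fc256x8_d1.00-50.00_s64_~p'.
from configs.spherical_view_syn import SphericalViewSynConfig

config = SphericalViewSynConfig()
config.name = 'fovea'
config.NET_TYPE = 'snerffast4'
config.N_ENCODE_DIM = 6
config.FC_PARAMS.update({'nf': 256, 'n_layers': 8})
config.SAMPLE_PARAMS.update({'depth_range': (1, 50), 'n_samples': 64, 'perturb_sample': False})
net_id = config.to_id()
print(net_id)

# from_id() maps such a string back to an equivalent config, which is how
# checkpoint names are turned into network settings before create_net().
config2 = SphericalViewSynConfig()
config2.from_id(net_id)
net = config2.create_net()  # NET_TYPE 'snerffast4' dispatches to SnerfFast with n_parts=4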
......@@ -51,7 +51,7 @@ else ifeq ($(TARGET), qnx)
CUCC = $(CUDA_INSTALL_DIR)/bin/nvcc -m64 -ccbin $(CC)
else ifeq ($(TARGET), android64)
ifeq ($(ANDROID_CC),)
$(error ANDROID_CC must be set to the clang compiler to build for android 64bit, for example /path/to/utils-toolchain/bin/aarch64-linux-android-clang++)
endif
CUDA_LIBDIR = lib
ANDROID_FLAGS = -DANDROID -D_GLIBCXX_USE_C99=1 -Wno-sign-compare -D__aarch64__ -Wno-strict-aliasing -Werror -pie -fPIE -Wno-unused-command-line-argument
......
#include "Encoder.h"
#include "thread_index.h"
/// idx3.z = 0: x, y, z, sin(x), sin(y), sin(z), cos(x), cos(y), cos(z)
/// idx3.z = 1: sin(2x), sin(2y), sin(2z), cos(2x), cos(2y), cos(2z)
/// ...
/// idx3.z = n_freq-1: sin(2^(n_freq-1)x), sin(2^(n_freq-1)y), sin(2^(n_freq-1)z),
/// cos(2^(n_freq-1)x), cos(2^(n_freq-1)y), cos(2^(n_freq-1)z)
/// Dispatch (n_batch, n_chns, n_freqs)
__global__ void cu_encode(float *output, float *input, float *freqs, uint n)
{
glm::uvec3 idx3 = IDX3;
if (idx3.x >= n)
return;
uint n = blockDim.x, inChns = blockDim.y, nFreqs = blockDim.z;
uint i = idx3.x, chn = idx3.y, freq = idx3.z;
uint elem = i * inChns + chn;
uint outChns = inChns * (nFreqs * 2 + 1);
uint base = i * outChns + chn;
output[base] = input[elem];
float x = freqs[freq] * input[elem];
float s, c;
__sincosf(x, &s, &c);
output[base + inChns * (freq * 2 + 1)] = s;
output[base + inChns * (freq * 2 + 2)] = c;
}
void Encoder::encode(sptr<CudaArray<float>> output, sptr<CudaArray<float>> input)
{
dim3 blkSize(1024 / _chns / _multires, _chns, _multires);
dim3 grdSize((uint)ceil(input->n() / (float)blkSize.x), 1, 1);
cu_encode<<<grdSize, blkSize>>>(output->getBuffer(), *input, *_freqs, input->n());
CHECK_EX(cudaGetLastError());
}
......
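The rewritten kernel stores, for every input channel, the raw value followed by interleaved sin/cos pairs per frequency, giving chns * (1 + 2 * n_freqs) output channels, which matches Encoder::outDim() in the header below. The NumPy sketch below is only a reference for that indexing and assumes freqs[k] = 2^k as described in the comment above.

# NumPy reference for the layout written by cu_encode:
#   out[i, chn]                       = x[i, chn]
#   out[i, chn + chns*(2*freq + 1)]   = sin(freqs[freq] * x[i, chn])
#   out[i, chn + chns*(2*freq + 2)]   = cos(freqs[freq] * x[i, chn])
# Assumes freqs[k] = 2**k, matching the comment block above.
import numpy as np

def encode_reference(x: np.ndarray, n_freqs: int) -> np.ndarray:
    n, chns = x.shape
    freqs = 2.0 ** np.arange(n_freqs)          # 1, 2, 4, ..., 2^(n_freqs-1)
    out = np.empty((n, chns * (1 + 2 * n_freqs)), x.dtype)
    out[:, :chns] = x
    for k, f in enumerate(freqs):
        out[:, chns * (2 * k + 1): chns * (2 * k + 2)] = np.sin(f * x)
        out[:, chns * (2 * k + 2): chns * (2 * k + 3)] = np.cos(f * x)
    return out

# 3 channels with 10 frequencies give 3 * (1 + 2*10) = 63 output channels,
# the configuration Encoder(10, 3) is constructed with further below.
assert encode_reference(np.zeros((4, 3), np.float32), 10).shape == (4, 63)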
......@@ -3,13 +3,14 @@
class Encoder {
public:
Encoder(uint multires, uint chns) : _multires(multires), _chns(chns) { _genFreqArray(); }
uint outDim() const { return _chns * (1 + _multires * 2); }
void encode(sptr<CudaArray<float>> output, sptr<CudaArray<float>> input);
private:
uint _multires;
uint _chns;
sptr<CudaArray<float>> _freqs;
void _genFreqArray();
......
......@@ -6,7 +6,7 @@ InferPipeline::InferPipeline(
uint samples) : _batchSize(batchSize),
_samples(samples),
_sampler(new Sampler({1.0f, 50.0f}, samples)),
_encoder(new Encoder(10, 3)),
_renderer(new Renderer()),
_net(isNmsl ? new Nmsl2(batchSize, samples) : new Msl(batchSize, samples))
{
......@@ -31,8 +31,10 @@ void InferPipeline::run(sptr<CudaArray<glm::vec4>> o_colors,
_sampler->sampleOnRays(_sphericalCoords, _depths, rays, rayOrigin);
cudaEventRecord(eSampled);
sptr<CudaArray<float>> coords(new CudaArray<float>((float *)_sphericalCoords->getBuffer(),
_sphericalCoords->n() * 3));
_encoder->encode(_encoded, coords);
cudaEventRecord(eEncoded);
......@@ -44,7 +46,8 @@ void InferPipeline::run(sptr<CudaArray<glm::vec4>> o_colors,
cudaEventRecord(eRendered);
if (showPerf)
{
CHECK_EX(cudaDeviceSynchronize());
float timeTotal, timeSample, timeEncode, timeInfer, timeRender;
......@@ -60,7 +63,7 @@ void InferPipeline::run(sptr<CudaArray<glm::vec4>> o_colors,
<< timeInfer << "ms, Render: " << timeRender << "ms)";
Logger::instance.info(sout.str());
}
/*
{
std::ostringstream sout;
sout << "Rays:" << std::endl;
......
import sys
import os
import argparse
import torch
......@@ -10,8 +9,6 @@ import plotly.express as px
import pandas as pd
from dash.dependencies import Input, Output
if __name__ == '__main__':
parser = argparse.ArgumentParser()
......@@ -24,13 +21,13 @@ if __name__ == '__main__':
print("Set CUDA:%d as current device." % torch.cuda.current_device())
torch.autograd.set_grad_enabled(False)
from data.spherical_view_syn import *
from configs.spherical_view_syn import SphericalViewSynConfig
from utils import netio
from utils import device
from utils import view
from utils import img
from nets.modules import Sampler
......@@ -49,8 +46,8 @@ def load_net(path):
config = SphericalViewSynConfig()
config.from_id(net_config)
config.SAMPLE_PARAMS['perturb_sample'] = False
net = config.create_net().to(device.default())
netio.load(path, net)
return net
......@@ -65,9 +62,9 @@ def load_views(data_desc_file) -> view.Trans:
with open(datadir + data_desc_file, 'r', encoding='utf-8') as file:
data_desc = json.loads(file.read())
view_centers = torch.tensor(
data_desc['view_centers'], device=device.default()).view(-1, 3)
view_rots = torch.tensor(
data_desc['view_rots'], device=device.default()).view(-1, 3, 3)
return view.Trans(view_centers, view_rots)
......@@ -76,7 +73,7 @@ cam = view.CameraParam({
'cx': 0.5,
'cy': 0.5,
'normalized': True
}, res, device=device.default())
net = load_net(net_path)
sampler = Sampler(depth_range=(1, 50), n_samples=32, perturb_sample=False,
spherical=True, lindisp=True, inverse_r=True)
......@@ -98,7 +95,7 @@ styles = {
'overflowX': 'scroll'
}
}
fig = px.imshow(img.torch2np(image))
fig1 = px.scatter(x=[0, 1, 2], y=[2, 0, 1])
fig2 = px.scatter(x=[0, 1, 2], y=[2, 0, 1])
app = dash.Dash(__name__, external_stylesheets=[
......@@ -139,10 +136,10 @@ def raw2color(raw: torch.Tensor, z_vals: torch.Tensor):
"""
Raw value inferred from model to color and alpha
:param raw `Tensor(N.rays, N.samples, 2|4)`: model's output
:param z_vals `Tensor(N.rays, N.samples)`: integration time
:return `Tensor(N.rays, N.samples, 1|3)`: color
:return `Tensor(N.rays, N.samples)`: alpha
"""
# Compute 'distance' (in time) between each integration time along a ray.
......@@ -163,7 +160,7 @@ def raw2color(raw: torch.Tensor, z_vals: torch.Tensor):
def draw_scatter():
global fig1, fig2
p = torch.tensor([x, y], device=device.default())
ray_d = test_view.trans_vector(cam.unproj(p))
ray_o = test_view.t
raw, depths = net.sample_and_infer(ray_o, ray_d, sampler=sampler)
......
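The raw2color docstring above describes the usual NeRF-style conversion from raw network output to per-sample color and alpha; since the function body is collapsed in this diff, the following is only a generic sketch of that conversion (sigmoid for color, alpha from density and sample spacing), not necessarily the exact implementation used here.

# Generic sketch of a NeRF-style raw -> (color, alpha) conversion; the collapsed
# body of raw2color above may differ in detail.
import torch

def raw2color_sketch(raw: torch.Tensor, z_vals: torch.Tensor):
    # 'Distance' between consecutive integration times along each ray;
    # the last sample gets a very large interval.
    dists = z_vals[..., 1:] - z_vals[..., :-1]
    dists = torch.cat([dists, torch.full_like(dists[..., :1], 1e10)], dim=-1)
    color = torch.sigmoid(raw[..., :-1])                         # (N.rays, N.samples, 1|3)
    alpha = 1.0 - torch.exp(-torch.relu(raw[..., -1]) * dists)   # (N.rays, N.samples)
    return color, alpha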
import torch
import math
from utils import device
class FastDataLoader(object):
......@@ -9,8 +9,8 @@ class FastDataLoader(object):
def __init__(self, dataset, batch_size, shuffle, drop_last) -> None:
super().__init__()
self.indices = torch.randperm(len(dataset), device=device.default()) \
if shuffle else torch.arange(len(dataset), device=device.default())
self.offset = 0
self.batch_size = batch_size
self.dataset = dataset
......@@ -23,7 +23,7 @@ class FastDataLoader(object):
self.offset += self.batch_size
return self.dataset[indices]
def __init__(self, dataset, batch_size, shuffle, drop_last, **kwargs) -> None:
def __init__(self, dataset, batch_size, shuffle, drop_last=False, **kwargs) -> None:
super().__init__()
self.dataset = dataset
self.batch_size = batch_size
......
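FastDataLoader above skips the per-item overhead of a generic DataLoader by permuting all indices once on the target device and slicing fixed-size batches out of that index tensor. A self-contained sketch of the same idea follows; the class name and the __iter__/__len__ details are assumptions, since those parts are collapsed in the diff above.

# Sketch of the permuted-index batching idea used by FastDataLoader; the class
# name and __iter__ details here are assumptions, not the original code.
import torch

class IndexBatcher:
    def __init__(self, dataset, batch_size, shuffle, drop_last=False, device=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        n = len(dataset)
        self.indices = torch.randperm(n, device=device) if shuffle \
            else torch.arange(n, device=device)

    def __iter__(self):
        self.offset = 0
        return self

    def __next__(self):
        needed = self.batch_size if self.drop_last else 1
        if self.offset + needed > len(self.indices):
            raise StopIteration()
        indices = self.indices[self.offset:self.offset + self.batch_size]
        self.offset += self.batch_size
        return self.dataset[indices]  # dataset must support tensor indexing, as above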
import os
import json
import torch
import torchvision.transforms.functional as trans_f
import glm
import torch.nn.functional as nn_f
from typing import Tuple, Union
from utils import img
from utils import device
from utils import view
from utils import color
class SphericalViewSynDataset(object):
......@@ -22,10 +22,11 @@ class SphericalViewSynDataset(object):
view_centers ```Tensor(N, 3)```: centers of views\n
view_rots ```Tensor(N, 3, 3)```: rotation matrices of views\n
view_images ```Tensor(N, 3, H, W)```: images of views\n
view_depths ```Tensor(N, H, W)```: depths of views\n
"""
def __init__(self, dataset_desc_path: str, load_images: bool = True,
load_depths: bool = False, load_bins: bool = False, c: int = color.RGB,
calculate_rays: bool = True, res: Tuple[int, int] = None):
"""
Initialize data loader for spherical view synthesis task
......@@ -41,46 +42,49 @@ class SphericalViewSynDataset(object):
:param dataset_desc_path ```str```: path to the data description file
:param load_images ```bool```: whether load view images and return in __getitem__()
:param load_depths ```bool```: whether load depth images and return in __getitem__()
:param c ```int```: color space to convert view images to
:param calculate_rays ```bool```: whether calculate rays
"""
super().__init__()
self.data_dir = os.path.dirname(dataset_desc_path)
self.load_images = load_images
self.load_depths = load_depths
self.load_bins = load_bins
# Load dataset description file
self._load_desc(dataset_desc_path, res)
# Load view images
if self.load_images:
self.view_images = color.cvt(
img.load(self.view_file % i for i in self.view_idxs).to(device.default()),
color.RGB, c)
if res:
self.view_images = nn_f.interpolate(self.view_images, res)
else:
self.view_images = None
# Load depthmaps
if self.load_depths:
self.view_depths = self._decode_depth_images(
img.load(self.depth_file % i for i in self.view_idxs).to(device.default()))
if res:
self.view_depths = nn_f.interpolate(self.view_depths, res)
else:
self.view_depths = None
        # Load bin maps
if self.load_bins:
self.view_bins = img.load([self.bins_file % i for i in self.view_idxs], permute=False) \
.to(device.default())
if res:
self.view_bins = nn_f.interpolate(self.view_bins, res)
else:
self.view_bins = None
self.patched_images = self.view_images
self.patched_depths = self.view_depths
self.patched_bins = self.view_bins
if calculate_rays:
# rays_o & rays_d are both (N, H, W, 3)
......@@ -89,49 +93,54 @@ class SphericalViewSynDataset(object):
self.patched_rays_o = self.rays_o
self.patched_rays_d = self.rays_d
def _decode_depth_images(self, input):
disp_range = (1 / self.depth_range[0], 1 / self.depth_range[1])
disp_val = (1 - input[..., 0, :, :]) * (disp_range[1] - disp_range[0]) + disp_range[0]
return torch.reciprocal(disp_val)
def _euler_to_matrix(self, euler):
q = glm.quat(glm.radians(glm.vec3(euler[0], euler[1], euler[2])))
return glm.transpose(glm.mat3_cast(q)).to_list()
def _load_desc(self, path, res=None):
with open(path, 'r', encoding='utf-8') as file:
data_desc = json.loads(file.read())
if not data_desc.get('view_file_pattern'):
self.load_images = False
else:
self.view_file = os.path.join(self.data_dir, data_desc['view_file_pattern'])
if not data_desc.get('depth_file_pattern'):
self.load_depths = False
else:
self.depth_file = os.path.join(self.data_dir, data_desc['depth_file_pattern'])
if not data_desc.get('bins_file_pattern'):
self.load_bins = False
else:
self.bins_file = os.path.join(self.data_dir, data_desc['bins_file_pattern'])
self.view_res = res if res else (data_desc['view_res']['y'], data_desc['view_res']['x'])
self.cam_params = view.CameraParam(data_desc['cam_params'], self.view_res,
device=device.default())
self.depth_range = [data_desc['depth_range']['min'], data_desc['depth_range']['max']] \
if 'depth_range' in data_desc else None
self.range = [data_desc['range']['min'], data_desc['range']['max']] \
if 'range' in data_desc else None
self.samples = data_desc['samples'] if 'samples' in data_desc else None
self.view_centers = torch.tensor(
data_desc['view_centers'], device=device.default()) # (N, 3)
self.view_rots = torch.tensor(
[self._euler_to_matrix([rot[1], rot[0], 0]) for rot in data_desc['view_rots']]
if len(data_desc['view_rots'][0]) == 2 else data_desc['view_rots'],
device=device.default()).view(-1, 3, 3) # (N, 3, 3)
#self.view_centers = self.view_centers[:6]
#self.view_rots = self.view_rots[:6]
self.n_views = self.view_centers.size(0)
self.n_pixels = self.n_views * self.view_res[0] * self.view_res[1]
self.view_idxs = data_desc['views'][:self.n_views] if 'views' in data_desc else range(self.n_views)
if 'gl_coord' in data_desc and data_desc['gl_coord'] == True:
print('Convert from OGL coordinate to DX coordinate (i. e. flip z axis)')
if not data_desc['cam_params'].get('normalized'):
self.cam_params.f[1] *= -1
self.view_centers[:, 2] *= -1
self.view_rots[..., 2] *= -1
......@@ -158,16 +167,26 @@ class SphericalViewSynDataset(object):
if patch_size[0] == 1 and patch_size[1] == 1:
self.patched_images = self.view_images[slices] \
.permute(0, 2, 3, 1).flatten(0, 2) if self.load_images else None
self.patched_depths = self.view_depths[slices].flatten() if self.load_depths else None
self.patched_bins = self.view_bins[slices].flatten(0, 2) if self.load_bins else None
self.patched_rays_o = self.rays_o[ray_slices].flatten(0, 2)
self.patched_rays_d = self.rays_d[ray_slices].flatten(0, 2)
elif patch_size[0] == self.view_res[0] and patch_size[1] == self.view_res[1]:
self.patched_images = self.view_images
self.patched_depths = self.view_depths
self.patched_bins = self.view_bins
self.patched_rays_o = self.rays_o
self.patched_rays_d = self.rays_d
else:
self.patched_images = self.view_images[slices] \
.view(self.n_views, -1, patches[0], patch_size[0], patches[1], patch_size[1]) \
.permute(0, 2, 4, 1, 3, 5).flatten(0, 2) if self.load_images else None
self.patched_depths = self.view_depths[slices] \
.view(self.n_views, patches[0], patch_size[0], patches[1], patch_size[1]) \
.permute(0, 1, 3, 2, 4).flatten(0, 2) if self.load_depths else None
self.patched_bins = self.view_bins[slices] \
.view(self.n_views, patches[0], patch_size[0], patches[1], patch_size[1], -1) \
.permute(0, 1, 3, 2, 4, 5).flatten(0, 2) if self.load_bins else None
self.patched_rays_o = self.rays_o[ray_slices] \
.view(self.n_views, patches[0], patch_size[0], patches[1], patch_size[1], -1) \
.permute(0, 1, 3, 2, 4, 5).flatten(0, 2)
......@@ -179,7 +198,5 @@ class SphericalViewSynDataset(object):
return self.patched_rays_o.size(0)
def __getitem__(self, idx):
return idx, self.patched_images[idx] if self.load_images else None, \
self.patched_rays_o[idx], self.patched_rays_d[idx]
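The new _decode_depth_images above stores depth as normalized disparity: a pixel value of 1 decodes to depth_range[0] (near) and 0 to depth_range[1] (far). A small numeric check of that mapping, using the (1, 50) depth range that appears elsewhere in this commit:

# Numeric check of the disparity decoding in _decode_depth_images above,
# with depth_range = (1, 50) as used elsewhere in this commit.
import torch

depth_range = (1.0, 50.0)
disp_range = (1 / depth_range[0], 1 / depth_range[1])            # (1.0, 0.02)
v = torch.tensor([1.0, 0.5, 0.0])                                # stored pixel values
disp = (1 - v) * (disp_range[1] - disp_range[0]) + disp_range[0]
depth = torch.reciprocal(disp)
print(depth)  # approximately [1.00, 1.96, 50.00]: 1 -> near plane, 0 -> far plane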
import sys
import tty
import termios
import select
import time
def readchar():
r, w, e = select.select([sys.stdin], [], [])
if sys.stdin in r:
ch = sys.stdin.read(1)
return ch
fd = sys.stdin.fileno()
oldtty = termios.tcgetattr(fd)
newtty = termios.tcgetattr(fd)
try:
termios.tcsetattr(fd, termios.TCSANOW, newtty)
tty.setraw(fd)
tty.setcbreak(fd)
while True:
print('Wait')
time.sleep(0.1)
key = readchar()
print('%d' % ord(key))
if key == 'w':
print('w')
if key == 'q':
break
finally:
termios.tcsetattr(fd, termios.TCSADRAIN, oldtty)
\ No newline at end of file