Commit 408738c8 authored by Nianchen Deng

sync

parent 3554ba52
{
    "python.pythonPath": "/home/dengnc/miniconda3/envs/pytorch/bin/python",
    "files.watcherExclude": {
        "**/data/**": true
    }
}
\ No newline at end of file
[Net Parameters]
aaa = 1
bbb = 'abc'
ccc = (2,3)
from ..my import color_mode


def update_config(config):
    # Dataset settings
    config.COLOR = color_mode.RGB

    # Net parameters
    config.NET_TYPE = 'nmsl'
    config.N_ENCODE_DIM = 10
    config.FC_PARAMS.update({
        'nf': 256,
        'n_layers': 4
    })
    config.SAMPLE_PARAMS.update({
        'depth_range': (1, 50),
        'n_samples': 32
    })
@@ -2,6 +2,10 @@ import os
 import importlib
 from os.path import join
 from ..my import color_mode
+from ..nets.msl_net import MslNet
+from ..nets.msl_net_new import NewMslNet
+from ..nets.spher_net import SpherNet


 class SphericalViewSynConfig(object):
@@ -14,6 +18,9 @@ class SphericalViewSynConfig(object):
         # Net parameters
         self.NET_TYPE = 'msl'
         self.N_ENCODE_DIM = 10
+        self.NORMALIZE = False
+        self.DIR_AS_INPUT = False
+        self.OPT_DECAY = 0
         self.FC_PARAMS = {
             'nf': 256,
             'n_layers': 8,
@@ -49,16 +56,22 @@ class SphericalViewSynConfig(object):
             '%d' % val
             for val in self.FC_PARAMS['skips']
         ]) if len(self.FC_PARAMS['skips']) > 0 else ""
-        depth_id = "_d%d-%d" % (self.SAMPLE_PARAMS['depth_range'][0],
-                                self.SAMPLE_PARAMS['depth_range'][1])
+        depth_id = "_d%.2f-%.2f" % (self.SAMPLE_PARAMS['depth_range'][0],
+                                    self.SAMPLE_PARAMS['depth_range'][1])
         samples_id = '_s%d' % self.SAMPLE_PARAMS['n_samples']
+        opt_decay_id = '_decay%.1e' % self.OPT_DECAY if self.OPT_DECAY > 1e-5 else ''
         neg_flags = '%s%s%s' % (
             'p' if not self.SAMPLE_PARAMS['perturb_sample'] else '',
             'l' if not self.SAMPLE_PARAMS['lindisp'] else '',
             'i' if not self.SAMPLE_PARAMS['inverse_r'] else ''
         )
         neg_flags = '_~' + neg_flags if neg_flags != '' else ''
-        return "%s@%s%s%s%s%s%s%s" % (self.name, net_type_id, encode_id, fc_id, skip_id, depth_id, samples_id, neg_flags)
+        pos_flags = '%s%s' % (
+            'n' if self.NORMALIZE else '',
+            'd' if self.DIR_AS_INPUT else '',
+        )
+        pos_flags = '_+' + pos_flags if pos_flags != '' else ''
+        return "%s@%s%s%s%s%s%s%s%s%s" % (self.name, net_type_id, encode_id, fc_id, skip_id, depth_id, samples_id, opt_decay_id, neg_flags, pos_flags)

     def from_id(self, id: str):
         id_splited = id.split('@')
@@ -66,9 +79,6 @@ class SphericalViewSynConfig(object):
         self.name = id_splited[0]
         segs = id_splited[-1].split('_')
         for i, seg in enumerate(segs):
-            if seg.startswith('e'):  # Encode
-                self.N_ENCODE_DIM = int(seg[1:])
-                continue
             if seg.startswith('fc'):  # Full-connected network parameters
                 self.FC_PARAMS['nf'], self.FC_PARAMS['n_layers'] = (
                     int(str) for str in seg[2:].split('x'))
@@ -77,6 +87,12 @@ class SphericalViewSynConfig(object):
                 self.FC_PARAMS['skips'] = [int(str)
                                            for str in seg[4:].split(',')]
                 continue
+            if seg.startswith('decay'):
+                self.OPT_DECAY = float(seg[5:])
+                continue
+            if seg.startswith('e'):  # Encode
+                self.N_ENCODE_DIM = int(seg[1:])
+                continue
             if seg.startswith('d'):  # Depth range
                 self.SAMPLE_PARAMS['depth_range'] = tuple(
                     float(str) for str in seg[1:].split('-'))
@@ -85,9 +101,18 @@ class SphericalViewSynConfig(object):
                 self.SAMPLE_PARAMS['n_samples'] = int(seg[1:])
                 continue
             if seg.startswith('~'):  # Negative flags
-                self.SAMPLE_PARAMS['perturb_sample'] = (seg.find('p') < 0)
-                self.SAMPLE_PARAMS['lindisp'] = (seg.find('l') < 0)
-                self.SAMPLE_PARAMS['inverse_r'] = (seg.find('i') < 0)
+                if seg.find('p') >= 0:
+                    self.SAMPLE_PARAMS['perturb_sample'] = False
+                if seg.find('l') >= 0:
+                    self.SAMPLE_PARAMS['lindisp'] = False
+                if seg.find('i') >= 0:
+                    self.SAMPLE_PARAMS['inverse_r'] = False
+                continue
+            if seg.startswith('+'):  # Positive flags
+                if seg.find('n') >= 0:
+                    self.NORMALIZE = True
+                if seg.find('d') >= 0:
+                    self.DIR_AS_INPUT = True
                 continue
             if i == 0:  # NetType
                 self.NET_TYPE, color_str = seg.split('-')
@@ -98,6 +123,41 @@ class SphericalViewSynConfig(object):
         print('==== Config %s ====' % self.name)
         print('Net type: ', self.NET_TYPE)
         print('Encode dim: ', self.N_ENCODE_DIM)
+        print('Optimizer decay: ', self.OPT_DECAY)
+        print('Normalize: ', self.NORMALIZE)
+        print('Direction as input: ', self.DIR_AS_INPUT)
         print('Full-connected network parameters:', self.FC_PARAMS)
         print('Sample parameters', self.SAMPLE_PARAMS)
         print('==========================')
+
+    def create_net(self):
+        return net_builder[self.NET_TYPE](self)
+
+
+net_builder = {
+    'msl': lambda config: MslNet(
+        fc_params=config.FC_PARAMS,
+        sampler_params=(config.SAMPLE_PARAMS.update(
+            {'spherical': True}), config.SAMPLE_PARAMS)[1],
+        normalize_coord=config.NORMALIZE,
+        dir_as_input=config.DIR_AS_INPUT,
+        color=config.COLOR,
+        encode_to_dim=config.N_ENCODE_DIM),
+    'nmsl': lambda config: NewMslNet(
+        fc_params=config.FC_PARAMS,
+        sampler_params=(config.SAMPLE_PARAMS.update(
+            {'spherical': True}), config.SAMPLE_PARAMS)[1],
+        normalize_coord=config.NORMALIZE,
+        dir_as_input=config.DIR_AS_INPUT,
+        color=config.COLOR,
+        encode_to_dim=config.N_ENCODE_DIM),
+    'nnmsl': lambda config: NewMslNet(
+        fc_params=config.FC_PARAMS,
+        sampler_params=(config.SAMPLE_PARAMS.update(
+            {'spherical': True}), config.SAMPLE_PARAMS)[1],
+        normalize_coord=config.NORMALIZE,
+        dir_as_input=config.DIR_AS_INPUT,
+        not_same_net=True,
+        color=config.COLOR,
+        encode_to_dim=config.N_ENCODE_DIM)
+}
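For orientation, a minimal sketch (not part of the commit) of the ID string the serializer above produces; segment order and prefixes follow its format string and the from_id() parser:

# Hypothetical example: assembling a config ID by hand
name, net, color = 'fovea_rgb', 'nmsl', 'rgb'
segs = [
    '%s-%s' % (net, color),  # first segment: net type + color mode
    'e10',                   # N_ENCODE_DIM = 10
    'fc128x4',               # FC_PARAMS: nf=128, n_layers=4
    'd%.2f-%.2f' % (1, 50),  # depth range, now printed with two decimals
    's32',                   # SAMPLE_PARAMS: n_samples = 32
    '+d',                    # positive flags: DIR_AS_INPUT enabled
]
print('%s@%s' % (name, '_'.join(segs)))
# -> fovea_rgb@nmsl-rgb_e10_fc128x4_d1.00-50.00_s32_+d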
from ..my import color_mode


def update_config(config):
    # Dataset settings
    config.COLOR = color_mode.RGB

    # Net parameters
    config.NET_TYPE = 'nmsl'
    config.N_ENCODE_DIM = 10
    config.FC_PARAMS.update({
        'nf': 128,
        'n_layers': 4
    })
    config.SAMPLE_PARAMS.update({
        'depth_range': (1, 50),
        'n_samples': 32
    })
from ..my import color_mode


def update_config(config):
    # Dataset settings
    config.COLOR = color_mode.RGB

    # Net parameters
    config.NET_TYPE = 'nmsl'
    config.N_ENCODE_DIM = 10
    config.FC_PARAMS.update({
        'nf': 64,
        'n_layers': 4
    })
    config.SAMPLE_PARAMS.update({
        'depth_range': (1, 50),
        'n_samples': 16
    })
\ No newline at end of file
from ..my import color_mode


def update_config(config):
    # Dataset settings
    config.COLOR = color_mode.RGB

    # Net parameters
    config.NET_TYPE = 'nnmsl'
    config.N_ENCODE_DIM = 10
    config.FC_PARAMS.update({
        'nf': 64,
        'n_layers': 4
    })
    config.SAMPLE_PARAMS.update({
        'depth_range': (1, 50),
        'n_samples': 16
    })
\ No newline at end of file
@@ -85,7 +85,7 @@ class SphericalViewSynDataset(object):
         if calculate_rays:
             # rays_o & rays_d are both (N, H, W, 3)
             self.rays_o, self.rays_d = self.cam_params.get_global_rays(
-                self.view_centers, self.view_rots)
+                view.Trans(self.view_centers, self.view_rots))
             self.patched_rays_o = self.rays_o
             self.patched_rays_d = self.rays_d
...
import torch
import torch.nn.functional as nn_f

from . import view


def get_warp(rays_o, rays_d, depthmap, tgt_o, tgt_r, tgt_cam):
    print(rays_o.size(), rays_d.size(), depthmap.size())
    pcloud = rays_o + rays_d * depthmap[..., None]
    print(rays_o.size(), rays_d.size(), depthmap.size(), pcloud.size())
    pcloud_in_tgt = view.trans_point(
        pcloud, tgt_o, tgt_r, inverse=True)
    print(pcloud_in_tgt.size())
    pixel_positions = tgt_cam.proj(pcloud_in_tgt)
    pixel_positions[..., 0] /= tgt_cam.res[1] * 0.5
    pixel_positions[..., 1] /= tgt_cam.res[0] * 0.5
    pixel_positions -= 1
    return pixel_positions


def refine(image, depthmap, rays_o, rays_d, bounds_img, bounds_o,
           bounds_r, ref_cam: view.CameraParam, net, is_lr):
    if is_lr:
        image = nn_f.upsample(
            image[None, ...], scale_factor=2, mode='bicubic')[0]
        depthmap = nn_f.upsample(
            depthmap[None, None, ...], scale_factor=2, mode='bicubic')[0, 0]
    bounds_rays_o, bounds_rays_d = ref_cam.get_global_rays(
        bounds_o, bounds_r, flatten=True)
    bounds_inferred = torch.stack([
        net(bounds_rays_o[i], bounds_rays_d[i]).view(
            ref_cam.res[0], ref_cam.res[1], -1).permute(2, 0, 1)
        for i in range(bounds_img.size(0))
    ], 0)
    bounds_diff = (bounds_img - bounds_inferred) / (bounds_inferred + 1e-5)
    bounds_warp = get_warp(rays_o, rays_d, depthmap,
                           bounds_o, bounds_r, ref_cam)
    warped_diff = nn_f.grid_sample(bounds_diff, bounds_warp)
    print(bounds_warp.size(), warped_diff.size())
    avg_diff = torch.mean(warped_diff, 0)
    return image * (1 + avg_diff)
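A minimal usage sketch (shapes assumed, not from the commit): get_warp returns sampling positions in grid_sample's [-1, 1] convention, so a target/guide view can be resampled at the pixels where the source point cloud projects:

# rays_o, rays_d: (H, W, 3) source-view rays; depthmap: (H, W)
# image_tgt: (1, C, H, W) target view to sample from
grid = get_warp(rays_o, rays_d, depthmap, tgt_o, tgt_r, tgt_cam)  # (H, W, 2), range [-1, 1]
warped = nn_f.grid_sample(image_tgt, grid[None])  # (1, C, H, W): target content at source pixels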
@@ -26,7 +26,7 @@ class Foveation(object):
         return self

     def synthesis(self, layers: List[torch.Tensor],
-                  normalized_fovea_center: Tuple[float, float]) -> torch.Tensor:
+                  fovea_center: Tuple[float, float]) -> torch.Tensor:
         """
         Generate foveated retinal image by blending fovea layers
         **Note: current implementation only support two fovea layers**
@@ -37,8 +37,8 @@ class Foveation(object):
         output: torch.Tensor = nn_f.interpolate(layers[-1], self.out_res,
                                                 mode='bilinear', align_corners=False)
         c = torch.tensor([
-            normalized_fovea_center[0] * self.out_res[1],
-            normalized_fovea_center[1] * self.out_res[0]
+            fovea_center[0] + self.out_res[1] / 2,
+            fovea_center[1] + self.out_res[0] / 2
         ], device=self.coords.device)
         for i in range(self.n_layers - 2, -1, -1):
             if layers[i] == None:
@@ -61,24 +61,6 @@ class Foveation(object):
             k = length_i / length
         return int(math.ceil(self.out_res[0] * k))

-    def get_layer_region_in_final_image(self, i: int,
-                                        normalized_fovea_center: Tuple[float, float]) -> Tuple[slice, slice]:
-        """
-        Get region of fovea layer i in final image
-
-        :param i: index of layer
-        :return: tuple of slice objects stores the start and end of region in horizontal and vertical
-        """
-        roi_size = self.get_layer_size_in_final_image(i)
-        roi_center = (int(self.out_res[1] * normalized_fovea_center[0]),
-                      int(self.out_res[0] * normalized_fovea_center[1]))
-        roi_offset_y = roi_center[1] - roi_size // 2
-        roi_offset_x = roi_center[0] - roi_size // 2
-        return (...,
-                slice(roi_offset_y, roi_offset_y + roi_size),
-                slice(roi_offset_x, roi_offset_x + roi_size))

     def _gen_layer_blendmap(self, i: int) -> torch.Tensor:
         """
         Generate blend map for fovea layer i
...
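The synthesis() change above also switches the fovea-center convention: the old argument was normalized to [0, 1], while the new one is apparently in pixels relative to the image center (matching the gaze values GenFinal passes through). A small conversion helper, assuming you still hold normalized coordinates:

def normalized_to_centered(norm_xy, out_res):
    # (0.5, 0.5) normalized -> (0, 0), i.e. the image center
    return ((norm_xy[0] - 0.5) * out_res[1],
            (norm_xy[1] - 0.5) * out_res[0])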
import torch
from torch import nn
from typing import List, Mapping, Tuple

from . import view
from . import refine
from .foveation import Foveation
from .simple_perf import SimplePerf


class GenFinal(object):

    def __init__(self, layers_fov: List[float],
                 layers_res: List[Tuple[int, int]],
                 full_res: Tuple[int, int],
                 fovea_net: nn.Module,
                 periph_net: nn.Module,
                 device: torch.device = None) -> None:
        super().__init__()
        self.layer_cams = [
            view.CameraParam({
                'fov': layers_fov[i],
                'cx': 0.5,
                'cy': 0.5,
                'normalized': True
            }, layers_res[i], device=device)
            for i in range(len(layers_fov))
        ]
        self.full_cam = view.CameraParam({
            'fov': layers_fov[-1],
            'cx': 0.5,
            'cy': 0.5,
            'normalized': True
        }, full_res, device=device)
        self.fovea_net = fovea_net.to(device)
        self.periph_net = periph_net.to(device)
        self.foveation = Foveation(
            layers_fov, full_res, device=device)
        self.device = device

    def to(self, device: torch.device):
        self.fovea_net.to(device)
        self.periph_net.to(device)
        self.foveation.to(device)
        self.full_cam.to(device)
        for cam in self.layer_cams:
            cam.to(device)
        self.device = device
        return self

    def gen(self, gaze, trans, ret_raw=False, perf_time=False) -> Mapping[str, torch.Tensor]:
        fovea_cam = self._adjust_cam(self.layer_cams[0], self.full_cam, gaze)
        mid_cam = self._adjust_cam(self.layer_cams[1], self.full_cam, gaze)
        periph_cam = self.layer_cams[2]
        perf = SimplePerf(True, True) if perf_time else None

        # x_rays_o, x_rays_d: (Hx, Wx, 3)
        fovea_rays_o, fovea_rays_d = fovea_cam.get_global_rays(trans, True)
        mid_rays_o, mid_rays_d = mid_cam.get_global_rays(trans, True)
        periph_rays_o, periph_rays_d = periph_cam.get_global_rays(trans, True)
        mid_periph_rays_o = torch.cat([mid_rays_o, periph_rays_o], 1)
        mid_periph_rays_d = torch.cat([mid_rays_d, periph_rays_d], 1)
        if perf_time:
            perf.Checkpoint('Get rays')

        perf1 = SimplePerf(True, True) if perf_time else None
        fovea_inferred = self.fovea_net(fovea_rays_o[0], fovea_rays_d[0]).view(
            1, fovea_cam.res[0], fovea_cam.res[1], -1).permute(0, 3, 1, 2)  # (1, C, H_fovea, W_fovea)
        if perf_time:
            perf1.Checkpoint('Infer fovea')
        periph_mid_inferred = self.periph_net(mid_periph_rays_o[0], mid_periph_rays_d[0])
        mid_inferred = periph_mid_inferred[:mid_cam.res[0] * mid_cam.res[1], :].view(
            1, mid_cam.res[0], mid_cam.res[1], -1).permute(0, 3, 1, 2)
        periph_inferred = periph_mid_inferred[mid_cam.res[0] * mid_cam.res[1]:, :].view(
            1, periph_cam.res[0], periph_cam.res[1], -1).permute(0, 3, 1, 2)
        if perf_time:
            perf1.Checkpoint('Infer mid & periph')
            perf.Checkpoint('Infer')

        fovea_refined = refine.constrast_enhance(fovea_inferred, 3, 0.2)
        mid_refined = refine.constrast_enhance(mid_inferred, 5, 0.2)
        periph_refined = refine.constrast_enhance(periph_inferred, 5, 0.2)
        if perf_time:
            perf.Checkpoint('Refine')

        blended = self.foveation.synthesis([
            fovea_refined,
            mid_refined,
            periph_refined
        ], (gaze[0], gaze[1]))
        if perf_time:
            perf.Checkpoint('Blend')

        if ret_raw:
            return {
                'fovea': fovea_refined,
                'mid': mid_refined,
                'periph': periph_refined,
                'blended': blended,
                'fovea_raw': fovea_inferred,
                'mid_raw': mid_inferred,
                'periph_raw': periph_inferred,
                'blended_raw': self.foveation.synthesis([
                    fovea_inferred,
                    mid_inferred,
                    periph_inferred
                ], (gaze[0], gaze[1]))
            }
        return {
            'fovea': fovea_refined,
            'mid': mid_refined,
            'periph': periph_refined,
            'blended': blended
        }

    def _adjust_cam(self, cam: view.CameraParam, full_cam: view.CameraParam,
                    gaze: Tuple[float, float]) -> view.CameraParam:
        fovea_offset = (
            (gaze[0]) / full_cam.f[0].item() * cam.f[0].item(),
            (gaze[1]) / full_cam.f[1].item() * cam.f[1].item()
        )
        return view.CameraParam({
            'fx': cam.f[0].item(),
            'fy': cam.f[1].item(),
            'cx': cam.c[0].item() - fovea_offset[0],
            'cy': cam.c[1].item() - fovea_offset[1]
        }, cam.res, device=self.device)
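A minimal usage sketch of the new GenFinal pipeline; the FOVs, resolutions and networks below are placeholders, not values from the commit:

import torch
from . import view
from .gen_final import GenFinal  # assumed module path

dev = torch.device('cuda')
fovea_net = periph_net = ...  # trained ray-marching nets, loaded elsewhere
gen = GenFinal(layers_fov=[20, 45, 110],  # fovea / mid / periph FOV in degrees (placeholders)
               layers_res=[(128, 128), (128, 128), (128, 128)],
               full_res=(720, 720),
               fovea_net=fovea_net, periph_net=periph_net,
               device=dev)
pose = view.Trans(torch.zeros(3, device=dev), torch.eye(3, device=dev))
out = gen.gen((0.0, 0.0), pose)  # gaze in pixels from the image center
image = out['blended']           # (1, C, H, W) foveated result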
 import sys
 import time

-TOTAL_BAR_LENGTH = 80
+TOTAL_BAR_LENGTH = 50
 LAST_T = time.time()
 BEGIN_T = LAST_T


-def progress_bar(current, total, msg=None):
+def progress_bar(current, total, msg=None, premsg=None):
     global LAST_T, BEGIN_T
     if current == 0:
         BEGIN_T = time.time()  # Reset for new bar.
@@ -14,21 +14,25 @@ def progress_bar(current, total, msg=None):
     current_len = int(TOTAL_BAR_LENGTH * (current + 1) / total)
     rest_len = int(TOTAL_BAR_LENGTH - current_len) - 1

-    sys.stdout.write(' %d/%d' % (current + 1, total))
-    sys.stdout.write(' [')
+    if premsg:
+        sys.stdout.write(premsg)
+        sys.stdout.write(' ')
+    sys.stdout.write('[')
     for i in range(current_len):
         sys.stdout.write('=')
-    sys.stdout.write('>')
+    if current_len < TOTAL_BAR_LENGTH:
+        sys.stdout.write('>')
     for i in range(rest_len):
         sys.stdout.write('.')
     sys.stdout.write(']')
+    sys.stdout.write(' %d/%d' % (current + 1, total))

     current_time = time.time()
     step_time = current_time - LAST_T
     LAST_T = current_time
     total_time = current_time - BEGIN_T
-    time_used = '  Step: %s' % format_time(step_time)
+    time_used = ' | Step: %s' % format_time(step_time)
     time_used += ' | Tot: %s' % format_time(total_time)
     if msg:
         time_used += ' | ' + msg
@@ -37,9 +41,9 @@ def progress_bar(current, total, msg=None):
         sys.stdout.write(msg)

     if current < total - 1:
-        sys.stdout.write('\r')
+        sys.stdout.write(' \r')
     else:
-        sys.stdout.write('\n')
+        sys.stdout.write(' \n')
     sys.stdout.flush()
@@ -67,10 +71,10 @@ def format_time(seconds):
         output += str(minutes) + 'm'
         time_index += 1
     if seconds_final > 0 and time_index <= 2:
-        output += str(seconds_final) + 's'
+        output += '%02ds' % seconds_final
         time_index += 1
     if millis > 0 and time_index <= 2:
-        output += str(millis) + 'ms'
+        output += '%03dms' % millis
         time_index += 1
     if output == '':
         output = '0ms'
...
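Usage sketch for the extended signature (premsg prefixes the bar; train_one_batch, loss and epoch are placeholders):

for batch_idx in range(100):
    loss = train_one_batch()  # stand-in for the real work
    progress_bar(batch_idx, 100,
                 msg='loss: %.4f' % loss,
                 premsg='Epoch %d' % epoch)
# e.g.: Epoch 3 [=====================>....] 42/100 | Step: 12ms | Tot: 1s250ms | loss: 0.0123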
import torch
import numpy as np
import torch.nn.functional as nn_f

from . import view


class GuideRefinement(object):

    def __init__(self, guides_image, guides_view: view.Trans,
                 guides_cam: view.CameraParam, net) -> None:
        rays_o, rays_d = guides_cam.get_global_rays(guides_view, flatten=True)
        guides_inferred = torch.stack([
            net(rays_o[i], rays_d[i]).view(
                guides_cam.res[0], guides_cam.res[1], -1).permute(2, 0, 1)
            for i in range(guides_image.size(0))
        ], 0)
        self.guides_diff = (guides_image - guides_inferred) / \
            (guides_inferred + 1e-5)
        self.guides_view = guides_view
        self.guides_cam = guides_cam

    def get_warp(self, rays_o, rays_d, depthmap, tgt_trans: view.Trans, tgt_cam):
        rays_size = list(depthmap.size()) + [3]
        rays_o = rays_o.view(rays_size)
        rays_d = rays_d.view(rays_size)
        #print(rays_o.size(), rays_d.size(), depthmap.size())
        pcloud = rays_o + rays_d * depthmap[..., None]
        #print('pcloud', pcloud.size())
        pcloud_in_tgt = tgt_trans.trans_point(pcloud, inverse=True)
        #print(pcloud_in_tgt.size())
        pixel_positions = tgt_cam.proj(pcloud_in_tgt)
        pixel_positions[..., 0] /= tgt_cam.res[1] * 0.5
        pixel_positions[..., 1] /= tgt_cam.res[0] * 0.5
        pixel_positions -= 1
        return pixel_positions

    def refine_by_guide(self, image, depthmap, rays_o, rays_d, is_lr):
        if is_lr:
            image = nn_f.upsample(
                image[None, ...], scale_factor=2, mode='bicubic')[0]
            depthmap = nn_f.upsample(
                depthmap[None, None, ...], scale_factor=2, mode='bicubic')[0, 0]
        warp = self.get_warp(rays_o, rays_d, depthmap,
                             self.guides_view, self.guides_cam)
        warped_diff = nn_f.grid_sample(self.guides_diff, warp)
        print(warp.size(), warped_diff.size())
        avg_diff = torch.mean(warped_diff, 0)
        return image * (1 + avg_diff)
def constrast_enhance(image, sigma, fe):
    kernel = torch.ones(1, 1, 3, 3, device=image.device) / 9
    mean = torch.cat([
        nn_f.conv2d(image[:, 0:1], kernel, padding=1),
        nn_f.conv2d(image[:, 1:2], kernel, padding=1),
        nn_f.conv2d(image[:, 2:3], kernel, padding=1)
    ], 1)
    cScale = 1.0 + sigma * fe
    return torch.clamp(mean + (image - mean) * cScale, 0, 1)


def get_grad(image):
    kernel = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
                          device=image.device, dtype=torch.float32).view(1, 1, 3, 3)
    x_grad = torch.cat([
        nn_f.conv2d(image[:, 0:1], kernel, padding=1),
        nn_f.conv2d(image[:, 1:2], kernel, padding=1),
        nn_f.conv2d(image[:, 2:3], kernel, padding=1)
    ], 1)
    kernel = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
                          device=image.device, dtype=torch.float32).view(1, 1, 3, 3)
    y_grad = torch.cat([
        nn_f.conv2d(image[:, 0:1], kernel, padding=1),
        nn_f.conv2d(image[:, 1:2], kernel, padding=1),
        nn_f.conv2d(image[:, 2:3], kernel, padding=1)
    ], 1)
    return (x_grad ** 2 + y_grad ** 2).sqrt().clamp(0, 1)


def getGaussianKernel(ksize, sigma=0):
    if sigma <= 0:
        # Compute the default sigma from the kernel size, consistent with OpenCV
        sigma = 0.3 * ((ksize - 1) * 0.5 - 1) + 0.8
    center = ksize // 2
    xs = (np.arange(ksize, dtype=np.float32) - center)  # horizontal distance of each element from the center
    kernel1d = np.exp(-(xs ** 2) / (2 * sigma ** 2))  # the 1D Gaussian kernel
    # The 2D kernel is separable: build it as the outer product of the 1D kernel with itself
    kernel = kernel1d[..., None] @ kernel1d[None, ...]
    kernel = torch.from_numpy(kernel)
    kernel = kernel / kernel.sum()  # normalize
    return kernel.view(1, 1, ksize, ksize)  # shape (1, 1, ksize, ksize) for conv2d / broadcasting


def grad_aware_gaussian(image, ksize, sigma=0):
    kernel = getGaussianKernel(ksize, sigma).to(image.device)
    print(kernel.size())
    blur = torch.cat([
        nn_f.conv2d(image[:, 0:1], kernel, padding=1),
        nn_f.conv2d(image[:, 1:2], kernel, padding=1),
        nn_f.conv2d(image[:, 2:3], kernel, padding=1)
    ], 1)
    grad = get_grad(image)
    return image * grad + blur * (1 - grad)


def bilateral_filter(batch_img, ksize, sigmaColor=None, sigmaSpace=None):
    device = batch_img.device
    if sigmaSpace is None:
        sigmaSpace = 0.15 * ksize + 0.35  # 0.3 * ((ksize - 1) * 0.5 - 1) + 0.8
    if sigmaColor is None:
        sigmaColor = sigmaSpace

    pad = (ksize - 1) // 2
    batch_img_pad = nn_f.pad(batch_img, pad=[pad, pad, pad, pad], mode='reflect')

    # batch_img has shape (B, C, H, W), so unfold along dims 2 and 3
    # patches.shape: B x C x H x W x ksize x ksize
    patches = batch_img_pad.unfold(2, ksize, 1).unfold(3, ksize, 1)
    patch_dim = patches.dim()  # 6
    # Intensity difference between each pixel and its neighborhood
    diff_color = patches - batch_img.unsqueeze(-1).unsqueeze(-1)
    # Range weights computed from the intensity differences
    weights_color = torch.exp(-(diff_color ** 2) / (2 * sigmaColor ** 2))
    # Normalize the range weights
    weights_color = weights_color / weights_color.sum(dim=(-1, -2), keepdim=True)

    # Get the spatial Gaussian kernel and expand it to the shape of weights_color
    weights_space = getGaussianKernel(ksize, sigmaSpace).to(device)
    weights_space_dim = (patch_dim - 2) * (1,) + (ksize, ksize)
    weights_space = weights_space.view(*weights_space_dim).expand_as(weights_color)

    # The total weight is the product of the spatial and range weights
    weights = weights_space * weights_color
    # Normalization factor for the total weights
    weights_sum = weights.sum(dim=(-1, -2))
    # Weighted average over each neighborhood
    weighted_pix = (weights * patches).sum(dim=(-1, -2)) / weights_sum
    return weighted_pix
\ No newline at end of file
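Usage sketch for the filters above (random tensors as stand-ins for real images; inputs are (B, C, H, W) in [0, 1]):

import torch

img = torch.rand(2, 3, 64, 64)
smoothed = bilateral_filter(img, ksize=7)          # edge-preserving smoothing
boosted = constrast_enhance(img, sigma=3, fe=0.2)  # local contrast boost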
-from typing import Mapping, Tuple, Union
+from typing import List, Mapping, Tuple, Union
 import torch
 from . import util
@@ -65,8 +65,7 @@ class CameraParam(object):
             rays = rays.flatten(0, 1)
         return rays

-    def get_global_rays(self, t: torch.Tensor, r: torch.Tensor,
-                        flatten=False, norm=True) -> torch.Tensor:
+    def get_global_rays(self, trans, flatten=False, norm=True) -> torch.Tensor:
         """
         [summary]
@@ -77,9 +76,9 @@ class CameraParam(object):
         :return: [description]
         """
         rays = self.get_local_rays(flatten, norm)  # (M.., 3)
-        rays_o, _ = torch.broadcast_tensors(t[..., None, :], rays) if flatten \
-            else torch.broadcast_tensors(t[..., None, None, :], rays)  # (N.., M.., 3)
-        rays_d = trans_vector(rays, r)
+        rays_o, _ = torch.broadcast_tensors(trans.t[..., None, :], rays) if flatten \
+            else torch.broadcast_tensors(trans.t[..., None, None, :], rays)  # (N.., M.., 3)
+        rays_d = trans.trans_vector(rays)
         return rays_o, rays_d

     def _convert_camera_params(self, input_camera_params: Mapping[str, Union[float, bool]],
@@ -114,6 +113,62 @@ class CameraParam(object):
         return camera_params


+class Trans(object):
+
+    def __init__(self, t: torch.Tensor, r: torch.Tensor) -> None:
+        self.t = t
+        self.r = r
+        if len(self.t.size()) == 1:
+            self.t = self.t[None, :]
+            self.r = self.r[None, :, :]
+
+    def trans_point(self, p: torch.Tensor, inverse=False) -> torch.Tensor:
+        """
+        Transform points by the stored translation vectors and rotation matrices
+
+        :param p ```Tensor(N.., 3)```: points to transform
+        :param inverse: whether to perform the inverse transform
+        :return ```Tensor(M.., N.., 3)```: transformed points
+        """
+        size_N = list(p.size())[:-1]
+        size_M = list(self.r.size())[:-2]
+        out_size = size_M + size_N + [3]
+        t_size = size_M + [1 for _ in range(len(size_N))] + [3]
+        t = self.t.view(t_size)  # (M.., 1.., 3)
+        if inverse:
+            p = (p - t).view(size_M + [-1, 3])
+            r = self.r
+        else:
+            p = p.view(-1, 3)
+            r = self.r.movedim(-1, -2)  # Transpose rotation matrices
+        out = torch.matmul(p, r).view(out_size)
+        if not inverse:
+            out = out + t
+        return out
+
+    def trans_vector(self, v: torch.Tensor, inverse=False) -> torch.Tensor:
+        """
+        Transform vectors by the stored rotation matrices
+
+        :param v ```Tensor(N.., 3)```: vectors to transform
+        :param inverse: whether to perform the inverse transform
+        :return ```Tensor(M.., N.., 3)```: transformed vectors
+        """
+        out_size = list(self.r.size())[:-2] + list(v.size())[:-1] + [3]
+        r = self.r if inverse else self.r.movedim(-1, -2)  # Transpose rotation matrices
+        out = torch.matmul(v.view(-1, 3), r).view(out_size)
+        return out
+
+    def size(self) -> List[int]:
+        return list(self.t.size()[:-1])
+
+    def get(self, *index):
+        return Trans(self.t[index], self.r[index])
+
+
 def trans_point(p: torch.Tensor, t: torch.Tensor, r: torch.Tensor, inverse=False) -> torch.Tensor:
     """
     Transform points by given translation vectors and rotation matrices
...
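A small sketch of the new pose-object API (placeholder tensors); this is the same Trans that SphericalViewSynDataset now passes to get_global_rays:

import torch

pose = Trans(torch.zeros(1, 3),             # (M, 3) translations
             torch.eye(3).expand(1, 3, 3))  # (M, 3, 3) rotations
p = torch.rand(5, 3)                        # (N, 3) points in world space
p_local = pose.trans_point(p, inverse=True)  # (M, N, 3): world -> local frame
d_world = pose.trans_vector(p)               # (M, N, 3): rotate directions only
# rays_o, rays_d = cam.get_global_rays(pose, flatten=True)  # new signature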
 from typing import Tuple
+import math
 import torch
 import torch.nn as nn
-from .my import net_modules
-from .my import util
-from .my import device
-from .my import color_mode
+from ..my import net_modules
+from ..my import util
+from ..my import device
+from ..my import color_mode

 rand_gen = torch.Generator(device=device.GetDevice())
@@ -145,19 +146,17 @@ class Sampler(nn.Module):
         :param lindisp: If True, sample linearly in inverse depth rather than in depth
         """
         super().__init__()
-        if lindisp:
-            self.r = 1 / torch.linspace(1 / depth_range[0], 1 / depth_range[1],
-                                        n_samples, device=device.GetDevice())
-        else:
-            self.r = torch.linspace(depth_range[0], depth_range[1],
-                                    n_samples, device=device.GetDevice())
+        self.lindisp = lindisp
+        if self.lindisp:
+            depth_range = (1 / depth_range[0], 1 / depth_range[1])
+        self.r = torch.linspace(depth_range[0], depth_range[1],
+                                n_samples, device=device.GetDevice())
+        step = (depth_range[1] - depth_range[0]) / (n_samples - 1)
         self.perturb_sample = perturb_sample
         self.spherical = spherical
         self.inverse_r = inverse_r
-        if perturb_sample:
-            mids = .5 * (self.r[1:] + self.r[:-1])
-            self.upper = torch.cat([mids, self.r[-1:]], -1)
-            self.lower = torch.cat([self.r[:1], mids], -1)
+        self.upper = torch.clamp_min(self.r + step / 2, 0)
+        self.lower = torch.clamp_min(self.r - step / 2, 0)

     def forward(self, rays_o, rays_d):
         """
@@ -177,18 +176,26 @@ class Sampler(nn.Module):
             r = self.lower + (self.upper - self.lower) * t_rand
         else:
             r = self.r
+        if self.lindisp:
+            r = torch.reciprocal(r)

         if self.spherical:
             pts, depths = RaySphereIntersect(rays_o, rays_d, r)
             sphers = util.CartesianToSpherical(pts, inverse_r=self.inverse_r)
-            return sphers, depths
+            return sphers, pts, depths
         else:
             return rays_o[..., None, :] + rays_d[..., None, :] * r[..., None], r


+# x>0, y>0 -> (y, -x)
+# x<0, y>0 -> (-y, x)
+# x<0, y<0 -> (y, -x)
+# x>0, y<0 -> (-y, x)

 class MslNet(nn.Module):

     def __init__(self, fc_params, sampler_params,
+                 normalize_coord: bool,
+                 dir_as_input: bool,
                  color: int = color_mode.RGB,
                  encode_to_dim: int = 0,
                  export_mode: bool = False):
@@ -197,7 +204,8 @@ class MslNet(nn.Module):
         :param fc_params: parameters for full-connection network
         :param sampler_params: parameters for sampler
-        :param gray: is grayscale mode
+        :param normalize_coord: whether normalize the spherical coords to [0, 2pi] before encode
+        :param color: color mode
         :param encode_to_dim: encode input to number of dimensions
         """
         super().__init__()
@@ -209,7 +217,10 @@ class MslNet(nn.Module):
         self.sampler = Sampler(**sampler_params)
         self.rendering = Rendering()
         self.export_mode = export_mode
-        if color == color_mode.YCbCr:
+        self.normalize_coord = normalize_coord
+        self.dir_as_input = dir_as_input
+        self.color = color
+        if self.color == color_mode.YCbCr:
             self.net1 = net_modules.FcNet(
                 in_chns=fc_params['in_chns'],
                 out_chns=fc_params['nf'] + 2,
@@ -221,8 +232,51 @@ class MslNet(nn.Module):
                 nf=fc_params['nf'],
                 n_layers=1)
             self.net = None
+        elif self.dir_as_input:
+            self.input_encoder2 = net_modules.InputEncoder.Get(4, 2)
+            self.net1 = net_modules.FcNet(
+                in_chns=fc_params['in_chns'],
+                out_chns=fc_params['nf'],
+                nf=fc_params['nf'],
+                n_layers=fc_params['n_layers'])
+            self.net2 = net_modules.FcNet(
+                in_chns=fc_params['nf'] + self.input_encoder2.out_dim,
+                out_chns=fc_params['out_chns'],
+                nf=fc_params['nf'],
+                n_layers=1)
+            self.net = None
         else:
             self.net = net_modules.FcNet(**fc_params)
+        if self.normalize_coord:
+            self.register_buffer('angle_range', torch.tensor(
+                [[1e5, 1e5], [-1e5, -1e5]]))
+            self.register_buffer('depth_range', torch.tensor([
+                self.sampler.lower[0], self.sampler.upper[-1]
+            ]))

+    def update_normalize_range(self, rays_o: torch.Tensor, rays_d: torch.Tensor):
+        coords, _, _ = self.sampler(rays_o, rays_d)
+        coords = coords[..., 1:].view(-1, 2)
+        self.angle_range = torch.stack([
+            torch.cat([coords, self.angle_range[0:1]]).amin(0),
+            torch.cat([coords, self.angle_range[1:2]]).amax(0)
+        ])

+    def calc_local_dir(self, rays_d, coords, pts: torch.Tensor):
+        """
+        [summary]
+
+        :param rays_d ```Tensor(B, 3)```:
+        :param coords ```Tensor(B, N, 3)```:
+        :param pts ```Tensor(B, N, 3)```:
+        :return ```Tensor(B, N, 2)```
+        """
+        local_z = pts / pts.norm(dim=-1, keepdim=True)
+        local_x = util.SphericalToCartesian(
+            coords + torch.tensor([0, 0.1 / 180 * math.pi, 0], device=coords.device)) - pts
+        local_x = local_x / local_x.norm(dim=-1, keepdim=True)
+        local_y = torch.cross(local_x, local_z, -1)
+        local_rot = torch.stack([local_x, local_y, local_z], dim=-2)  # (B, N, 3, 3)
+        return util.CartesianToSpherical(
+            torch.matmul(rays_d[:, None, None, :], local_rot)).squeeze(-2)[..., 1:3]

     def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor,
                 ret_depth: bool = False) -> torch.Tensor:
@@ -233,16 +287,27 @@ class MslNet(nn.Module):
         :param rays_d ```Tensor(B, 3)```: rays' direction
         :return: ```Tensor(B, C)``, inferred images/pixels
         """
-        coords, depths = self.sampler(rays_o, rays_d)
+        coords, pts, depths = self.sampler(rays_o, rays_d)
+        if self.dir_as_input:
+            dirs = self.calc_local_dir(rays_d, coords, pts)
+        if self.normalize_coord:  # Normalize coords to [0, 2pi]
+            range = torch.cat([self.depth_range.view(2, 1), self.angle_range], 1)
+            coords = (coords - range[0]) / (range[1] - range[0]) * 2 * math.pi
         encoded = self.input_encoder(coords)
-        if not self.net:
+        if self.color == color_mode.YCbCr:
             mid_output = self.net1(encoded)
             net2_output = self.net2(mid_output[..., :-2])
             raw = torch.cat([
                 mid_output[..., -2:],
                 net2_output
             ], -1)
+        elif self.dir_as_input:
+            encoded_dirs = self.input_encoder2(dirs)
+            #print(encoded.size(), self.net1(encoded).size(), encoded_dirs.size())
+            raw = self.net2(torch.cat([self.net1(encoded), encoded_dirs], -1))
         else:
             raw = self.net(encoded)
...
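The rewritten Sampler keeps self.r in disparity space when lindisp is set and only takes the reciprocal at sampling time, so the half-step jitter bounds (upper/lower) stay uniform in the space being sampled. A standalone numeric sketch of that logic:

import torch

depth_range, n_samples = (1.0, 50.0), 4
d0, d1 = 1 / depth_range[0], 1 / depth_range[1]  # lindisp: work in 1/depth
r = torch.linspace(d0, d1, n_samples)            # 1.0000, 0.6733, 0.3467, 0.0200
step = (d1 - d0) / (n_samples - 1)               # uniform step in disparity space
upper = torch.clamp_min(r + step / 2, 0)         # half-step bounds for jittered sampling
lower = torch.clamp_min(r - step / 2, 0)
depths = torch.reciprocal(r)                     # 1.00, 1.49, 2.88, 50.00: denser near the camera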
 from typing import Tuple
+import math
 import torch
 import torch.nn as nn
-from .my import net_modules
-from .my import util
-from .my import device
+from ..my import net_modules
+from ..my import util
+from ..my import device
+from ..my import color_mode

 rand_gen = torch.Generator(device=device.GetDevice())
 rand_gen.manual_seed(torch.seed())
@@ -143,19 +146,17 @@ class Sampler(nn.Module):
         :param lindisp: If True, sample linearly in inverse depth rather than in depth
         """
         super().__init__()
-        if lindisp:
-            self.r = 1 / torch.linspace(1 / depth_range[0], 1 / depth_range[1],
-                                        n_samples, device=device.GetDevice())
-        else:
-            self.r = torch.linspace(depth_range[0], depth_range[1],
-                                    n_samples, device=device.GetDevice())
+        self.lindisp = lindisp
+        if self.lindisp:
+            depth_range = (1 / depth_range[0], 1 / depth_range[1])
+        self.r = torch.linspace(depth_range[0], depth_range[1],
+                                n_samples, device=device.GetDevice())
+        step = (depth_range[1] - depth_range[0]) / (n_samples - 1)
         self.perturb_sample = perturb_sample
         self.spherical = spherical
         self.inverse_r = inverse_r
-        if perturb_sample:
-            mids = .5 * (self.r[1:] + self.r[:-1])
-            self.upper = torch.cat([mids, self.r[-1:]], -1)
-            self.lower = torch.cat([self.r[:1], mids], -1)
+        self.upper = torch.clamp_min(self.r + step / 2, 0)
+        self.lower = torch.clamp_min(self.r - step / 2, 0)

     def forward(self, rays_o, rays_d):
         """
@@ -175,19 +176,24 @@ class Sampler(nn.Module):
             r = self.lower + (self.upper - self.lower) * t_rand
         else:
             r = self.r
+        if self.lindisp:
+            r = torch.reciprocal(r)

         if self.spherical:
             pts, depths = RaySphereIntersect(rays_o, rays_d, r)
             sphers = util.CartesianToSpherical(pts, inverse_r=self.inverse_r)
-            return sphers, depths
+            return sphers, pts, depths
         else:
             return rays_o[..., None, :] + rays_d[..., None, :] * r[..., None], r


-class MslNet(nn.Module):
+class NewMslNet(nn.Module):

     def __init__(self, fc_params, sampler_params,
-                 gray=False,
+                 normalize_coord: bool,
+                 dir_as_input: bool,
+                 not_same_net: bool = False,
+                 color: int = color_mode.RGB,
                  encode_to_dim: int = 0,
                  export_mode: bool = False):
         """
@@ -195,7 +201,8 @@ class MslNet(nn.Module):
         :param fc_params: parameters for full-connection network
         :param sampler_params: parameters for sampler
-        :param gray: is grayscale mode
+        :param normalize_coord: whether normalize the spherical coords to [0, 2pi] before encode
+        :param color: color mode
         :param encode_to_dim: encode input to number of dimensions
         """
         super().__init__()
@@ -203,14 +210,36 @@ class MslNet(nn.Module):
         self.input_encoder = net_modules.InputEncoder.Get(
             encode_to_dim, self.in_chns)
         fc_params['in_chns'] = self.input_encoder.out_dim
-        fc_params['out_chns'] = 2 if gray else 4
+        fc_params['out_chns'] = 2 if color == color_mode.GRAY else 4
         self.sampler = Sampler(**sampler_params)
-        self.net = net_modules.FcNet(**fc_params)
         self.rendering = Rendering()
         self.export_mode = export_mode
+        self.normalize_coord = normalize_coord
+        self.nets = nn.ModuleList([
+            net_modules.FcNet(**fc_params),
+            net_modules.FcNet(in_chns=fc_params['in_chns'],
+                              out_chns=fc_params['out_chns'],
+                              nf=128, n_layers=4) if not_same_net
+            else net_modules.FcNet(**fc_params)
+        ])
+        self.n_samples = sampler_params['n_samples']
+        if self.normalize_coord:
+            self.register_buffer('angle_range', torch.tensor(
+                [[1e5, 1e5], [-1e5, -1e5]]))
+            self.register_buffer('depth_range', torch.tensor([
+                [self.sampler.lower[0], self.sampler.lower[self.n_samples // 2]],
+                [self.sampler.upper[self.n_samples // 2 - 1], self.sampler.upper[-1]]
+            ]))

+    def update_normalize_range(self, rays_o: torch.Tensor, rays_d: torch.Tensor):
+        coords, _, _ = self.sampler(rays_o, rays_d)
+        coords = coords[..., 1:].view(-1, 2)
+        self.angle_range = torch.stack([
+            torch.cat([coords, self.angle_range[0:1]]).amin(0),
+            torch.cat([coords, self.angle_range[1:2]]).amax(0)
+        ])

-    def forward(self, view_centers: torch.Tensor, view_rots: torch.Tensor, local_rays: torch.Tensor,
+    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor,
                 ret_depth: bool = False) -> torch.Tensor:
         """
         rays -> colors
@@ -219,18 +248,28 @@ class MslNet(nn.Module):
         :param rays_d ```Tensor(B, 3)```: rays' direction
         :return: ```Tensor(B, C)``, inferred images/pixels
         """
-        rays_o = local_rays * 0 + view_centers
-        rays_d = torch.matmul(local_rays.flatten(0, -2), r).view(out_size)
-        coords, depths = self.sampler(rays_o, rays_d)
+        coords, pts, depths = self.sampler(rays_o, rays_d)
+        if self.normalize_coord:  # Normalize coords to [0, 2pi]
+            range = torch.cat([self.depth_range[:, 0:1], self.angle_range], 1)
+            coords[:, :self.n_samples // 2] = (coords[:, :self.n_samples // 2] - range[0]) / (
+                range[1] - range[0]) * 2 * math.pi
+            range = torch.cat([self.depth_range[:, 1:2], self.angle_range], 1)
+            coords[:, self.n_samples // 2:] = (coords[:, self.n_samples // 2:] - range[0]) / (
+                range[1] - range[0]) * 2 * math.pi
         encoded = self.input_encoder(coords)
+        raw = torch.cat([
+            self.nets[0](encoded[:, :self.n_samples // 2]),
+            self.nets[1](encoded[:, self.n_samples // 2:]),
+        ], 1)

         if self.export_mode:
-            colors, alphas = self.rendering.raw2color(self.net(encoded), depths)
+            colors, alphas = self.rendering.raw2color(raw, depths)
            return torch.cat([colors, alphas[..., None]], -1)

         if ret_depth:
             color_map, _, _, _, depth_map = self.rendering(
-                self.net(encoded), depths, ret_extra=True)
+                raw, depths, ret_extra=True)
             return color_map, depth_map

-        return self.rendering(self.net(encoded), depths)
+        return self.rendering(raw, depths)
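NewMslNet's key difference from MslNet: the samples along each ray are split in half, each half is handled by its own FC net (the second one smaller when not_same_net is set), and the raw outputs are concatenated before rendering. A shape-level sketch with hypothetical sizes:

import torch

B, n_samples, enc_dim = 4, 32, 63
encoded = torch.rand(B, n_samples, enc_dim)          # positional-encoded sample coords
near = encoded[:, :n_samples // 2]                   # first half of the samples -> nets[0]
far = encoded[:, n_samples // 2:]                    # second half -> nets[1]
# raw = torch.cat([nets[0](near), nets[1](far)], 1)  # (B, n_samples, out_chns), then rendered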
 import torch
 import torch.nn as nn
-from .my import net_modules
-from .my import util
+from ..my import net_modules
+from ..my import util


 class SpherNet(nn.Module):
...
 from typing import List
 import torch
 import torch.nn as nn
-from .pytorch_prototyping.pytorch_prototyping import *
-from .my import util
-from .my import device
+from ..pytorch_prototyping.pytorch_prototyping import *
+from ..my import util
+from ..my import device


 class Encoder(nn.Module):
...
@@ -18,28 +18,29 @@
     "from ..my import util\n",
     "\n",
     "path_patts = [\n",
-    "    '/home/dengnc/deep_view_syn/data/gas_fovea_2020.12.31/upsampling_test/gt/view_%04d.png',\n",
-    "    '/home/dengnc/deep_view_syn/data/gas_fovea_2020.12.31/fovea_rgb@msl-rgb_e10_fc256x4_d1-50_s16/output/model-epoch_500/train/out_view_%04d.png',\n",
-    "    '/home/dengnc/deep_view_syn/data/gas_fovea_2020.12.31/upsampling_test/input/out_view_%04d.png',\n",
-    "    '/home/dengnc/deep_view_syn/data/gas_fovea_2020.12.31/upsampling_test/output/view_%04d.png'\n",
+    "    '/home/dengnc/deep_view_syn/data/gas_fovea_mid_trans_2021.01.11/test/view_%04d.png',\n",
+    "    '/home/dengnc/deep_view_syn/data/gas_fovea_mid_trans_2021.01.11/fovea_rgb@msl-rgb_e10_fc256x4_d1-50_s16/output/model-epoch_500/test/out_view_%04d.png',\n",
+    "    '/home/dengnc/deep_view_syn/data/gas_fovea_mid_trans_2021.01.11/new_fovea_rgb@nmsl-rgb_e10_fc256x4_d1-50_s16/output/model-epoch_300/test/out_view_%04d.png',\n",
+    "    # '/out_view_%04d.png',\n",
+    "    '/home/dengnc/deep_view_syn/data/gas_fovea_mid_trans_2021.01.11/new_fovea_rgb@nmsl-rgb_e10_fc128x4_d1-50_s32/output/model-epoch_200/test/out_view_%04d.png'\n",
     "]\n",
-    "titles = [ 'Ground truth', 'Normal', 'Low Res', 'Upsampling']\n",
-    "show_range = range(5)\n",
+    "titles = ['Ground truth', 'Baseline', 'NMSL', 'With Dir', 'NMSL_32']\n",
+    "#show_range = range(20)\n",
+    "show_range = [ 2, 4, 5, 10, 12, 13, 14 ]\n",
     "\n",
-    "#os.chdir('/home/dengnc/deep_view_syn/data/')\n",
+    "# os.chdir('/home/dengnc/deep_view_syn/data/')\n",
     "image_seqs = [\n",
     "    util.ReadImageTensor([path_patt % i for i in show_range])\n",
     "    for path_patt in path_patts\n",
     "]\n",
     "\n",
     "for i in show_range:\n",
-    "    plt.figure(facecolor='white', figsize=(12, 4))\n",
+    "    plt.figure(facecolor='white', figsize=(4 * len(image_seqs), 4))\n",
     "    plt.suptitle('View %d' % i)\n",
     "    for j in range(len(image_seqs)):\n",
     "        plt.subplot(1, len(image_seqs), j + 1)\n",
     "        plt.title(titles[j])\n",
-    "        util.PlotImageTensor(image_seqs[j][i])\n",
-    "    \n"
+    "        util.PlotImageTensor(image_seqs[j][i])\n"
    ]
   },
   {
@@ -53,11 +54,19 @@
  "metadata": {
   "kernelspec": {
    "display_name": "Python 3",
+   "language": "python",
    "name": "python3"
   },
   "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
    "name": "python",
    "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
    "version": "3.7.9-final"
   },
   "orig_nbformat": 2
...