Commit 408738c8 authored by Nianchen Deng

sync

parent 3554ba52
{
"python.pythonPath": "/home/dengnc/miniconda3/envs/pytorch/bin/python",
"files.watcherExclude": {
"**/data/**": true
}
}
\ No newline at end of file
[Net Parameters]
aaa = 1
bbb = 'abc'
ccc = (2,3)
from ..my import color_mode
def update_config(config):
# Dataset settings
config.COLOR = color_mode.RGB
# Net parameters
config.NET_TYPE = 'nmsl'
config.N_ENCODE_DIM = 10
config.FC_PARAMS.update({
'nf': 256,
'n_layers': 4
})
config.SAMPLE_PARAMS.update({
'depth_range': (1, 50),
'n_samples': 32
})
@@ -2,6 +2,10 @@ import os
import importlib
from os.path import join
from ..my import color_mode
from ..nets.msl_net import MslNet
from ..nets.msl_net_new import NewMslNet
from ..nets.spher_net import SpherNet
class SphericalViewSynConfig(object):
@@ -14,6 +18,9 @@ class SphericalViewSynConfig(object):
# Net parameters
self.NET_TYPE = 'msl'
self.N_ENCODE_DIM = 10
self.NORMALIZE = False
self.DIR_AS_INPUT = False
self.OPT_DECAY = 0
self.FC_PARAMS = {
'nf': 256,
'n_layers': 8,
@@ -49,16 +56,22 @@ class SphericalViewSynConfig(object):
'%d' % val
for val in self.FC_PARAMS['skips']
]) if len(self.FC_PARAMS['skips']) > 0 else ""
depth_id = "_d%d-%d" % (self.SAMPLE_PARAMS['depth_range'][0],
depth_id = "_d%.2f-%.2f" % (self.SAMPLE_PARAMS['depth_range'][0],
self.SAMPLE_PARAMS['depth_range'][1])
samples_id = '_s%d' % self.SAMPLE_PARAMS['n_samples']
opt_decay_id = '_decay%.1e' % self.OPT_DECAY if self.OPT_DECAY > 1e-5 else ''
neg_flags = '%s%s%s' % (
'p' if not self.SAMPLE_PARAMS['perturb_sample'] else '',
'l' if not self.SAMPLE_PARAMS['lindisp'] else '',
'i' if not self.SAMPLE_PARAMS['inverse_r'] else ''
)
neg_flags = '_~' + neg_flags if neg_flags != '' else ''
return "%s@%s%s%s%s%s%s%s" % (self.name, net_type_id, encode_id, fc_id, skip_id, depth_id, samples_id, neg_flags)
pos_flags = '%s%s' % (
'n' if self.NORMALIZE else '',
'd' if self.DIR_AS_INPUT else '',
)
pos_flags = '_+' + pos_flags if pos_flags != '' else ''
return "%s@%s%s%s%s%s%s%s%s%s" % (self.name, net_type_id, encode_id, fc_id, skip_id, depth_id, samples_id, opt_decay_id, neg_flags, pos_flags)
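Note: for reference, a sketch of the ID string the updated to_id produces; the scene name and parameter values below are made up for illustration, but the field order and separators follow the code above.
# Hypothetical example of the ID convention (values are illustrative only).
name = 'fovea_rgb'
net_type_id, encode_id, fc_id, skip_id = 'nmsl-rgb', '_e10', '_fc128x4', ''
depth_id, samples_id = '_d1.00-50.00', '_s32'
opt_decay_id, neg_flags, pos_flags = '_decay1.0e-02', '_~pl', '_+nd'
print("%s@%s%s%s%s%s%s%s%s%s" % (name, net_type_id, encode_id, fc_id, skip_id,
                                 depth_id, samples_id, opt_decay_id, neg_flags, pos_flags))
# -> fovea_rgb@nmsl-rgb_e10_fc128x4_d1.00-50.00_s32_decay1.0e-02_~pl_+nd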
def from_id(self, id: str):
id_splited = id.split('@')
@@ -66,9 +79,6 @@ class SphericalViewSynConfig(object):
self.name = id_splited[0]
segs = id_splited[-1].split('_')
for i, seg in enumerate(segs):
if seg.startswith('e'): # Encode
self.N_ENCODE_DIM = int(seg[1:])
continue
if seg.startswith('fc'): # Fully-connected network parameters
self.FC_PARAMS['nf'], self.FC_PARAMS['n_layers'] = (
int(str) for str in seg[2:].split('x'))
@@ -77,6 +87,12 @@ class SphericalViewSynConfig(object):
self.FC_PARAMS['skips'] = [int(str)
for str in seg[4:].split(',')]
continue
if seg.startswith('decay'):
self.OPT_DECAY = float(seg[5:])
continue
if seg.startswith('e'): # Encode
self.N_ENCODE_DIM = int(seg[1:])
continue
if seg.startswith('d'): # Depth range
self.SAMPLE_PARAMS['depth_range'] = tuple(
float(str) for str in seg[1:].split('-'))
@@ -85,9 +101,18 @@ class SphericalViewSynConfig(object):
self.SAMPLE_PARAMS['n_samples'] = int(seg[1:])
continue
if seg.startswith('~'): # Negative flags
self.SAMPLE_PARAMS['perturb_sample'] = (seg.find('p') < 0)
self.SAMPLE_PARAMS['lindisp'] = (seg.find('l') < 0)
self.SAMPLE_PARAMS['inverse_r'] = (seg.find('i') < 0)
if seg.find('p') >= 0:
self.SAMPLE_PARAMS['perturb_sample'] = False
if seg.find('l') >= 0:
self.SAMPLE_PARAMS['lindisp'] = False
if seg.find('i') >= 0:
self.SAMPLE_PARAMS['inverse_r'] = False
continue
if seg.startswith('+'): # Positive flags
if seg.find('n') >= 0:
self.NORMALIZE = True
if seg.find('d') >= 0:
self.DIR_AS_INPUT = True
continue
if i == 0: # NetType
self.NET_TYPE, color_str = seg.split('-')
@@ -98,6 +123,41 @@ class SphericalViewSynConfig(object):
print('==== Config %s ====' % self.name)
print('Net type: ', self.NET_TYPE)
print('Encode dim: ', self.N_ENCODE_DIM)
print('Optimizer decay: ', self.OPT_DECAY)
print('Normalize: ', self.NORMALIZE)
print('Direction as input: ', self.DIR_AS_INPUT)
print('Fully-connected network parameters:', self.FC_PARAMS)
print('Sample parameters:', self.SAMPLE_PARAMS)
print('==========================')
def create_net(self):
return net_builder[self.NET_TYPE](self)
net_builder = {
'msl': lambda config: MslNet(
fc_params=config.FC_PARAMS,
sampler_params=(config.SAMPLE_PARAMS.update(
{'spherical': True}), config.SAMPLE_PARAMS)[1],
normalize_coord=config.NORMALIZE,
dir_as_input=config.DIR_AS_INPUT,
color=config.COLOR,
encode_to_dim=config.N_ENCODE_DIM),
'nmsl': lambda config: NewMslNet(
fc_params=config.FC_PARAMS,
sampler_params=(config.SAMPLE_PARAMS.update(
{'spherical': True}), config.SAMPLE_PARAMS)[1],
normalize_coord=config.NORMALIZE,
dir_as_input=config.DIR_AS_INPUT,
color=config.COLOR,
encode_to_dim=config.N_ENCODE_DIM),
'nnmsl': lambda config: NewMslNet(
fc_params=config.FC_PARAMS,
sampler_params=(config.SAMPLE_PARAMS.update(
{'spherical': True}), config.SAMPLE_PARAMS)[1],
normalize_coord=config.NORMALIZE,
dir_as_input=config.DIR_AS_INPUT,
not_same_net=True,
color=config.COLOR,
encode_to_dim=config.N_ENCODE_DIM)
}
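Note: the sampler_params arguments above rely on dict.update returning None, so the (d.update(...), d)[1] tuple evaluates the update for its side effect and then yields the updated dict. A minimal sketch:
# dict.update() returns None; the tuple runs it for its side effect,
# then [1] picks the (now updated) dict as the argument value.
d = {'n_samples': 32}
value = (d.update({'spherical': True}), d)[1]
print(value)  # {'n_samples': 32, 'spherical': True}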
from ..my import color_mode
def update_config(config):
# Dataset settings
config.COLOR = color_mode.RGB
# Net parameters
config.NET_TYPE = 'nmsl'
config.N_ENCODE_DIM = 10
config.FC_PARAMS.update({
'nf': 128,
'n_layers': 4
})
config.SAMPLE_PARAMS.update({
'depth_range': (1, 50),
'n_samples': 32
})
from ..my import color_mode
def update_config(config):
# Dataset settings
config.COLOR = color_mode.RGB
# Net parameters
config.NET_TYPE = 'nmsl'
config.N_ENCODE_DIM = 10
config.FC_PARAMS.update({
'nf': 64,
'n_layers': 4
})
config.SAMPLE_PARAMS.update({
'depth_range': (1, 50),
'n_samples': 16
})
\ No newline at end of file
from ..my import color_mode
def update_config(config):
# Dataset settings
config.COLOR = color_mode.RGB
# Net parameters
config.NET_TYPE = 'nnmsl'
config.N_ENCODE_DIM = 10
config.FC_PARAMS.update({
'nf': 64,
'n_layers': 4
})
config.SAMPLE_PARAMS.update({
'depth_range': (1, 50),
'n_samples': 16
})
\ No newline at end of file
@@ -85,7 +85,7 @@ class SphericalViewSynDataset(object):
if calculate_rays:
# rays_o & rays_d are both (N, H, W, 3)
self.rays_o, self.rays_d = self.cam_params.get_global_rays(
self.view_centers, self.view_rots)
view.Trans(self.view_centers, self.view_rots))
self.patched_rays_o = self.rays_o
self.patched_rays_d = self.rays_d
......
import torch
import torch.nn.functional as nn_f
from . import view
def get_warp(rays_o, rays_d, depthmap, tgt_o, tgt_r, tgt_cam):
print(rays_o.size(), rays_d.size(), depthmap.size())
pcloud = rays_o + rays_d * depthmap[..., None]
print(rays_o.size(), rays_d.size(), depthmap.size(), pcloud.size())
pcloud_in_tgt = view.trans_point(
pcloud, tgt_o, tgt_r, inverse=True)
print(pcloud_in_tgt.size())
pixel_positions = tgt_cam.proj(pcloud_in_tgt)
pixel_positions[..., 0] /= tgt_cam.res[1] * 0.5
pixel_positions[..., 1] /= tgt_cam.res[0] * 0.5
pixel_positions -= 1
return pixel_positions
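Note: get_warp normalizes the projected pixel positions to [-1, 1], the grid convention torch.nn.functional.grid_sample expects (used by refine below). A minimal sketch of that convention, independent of the code above:
# An identity grid in [-1, 1] leaves an image unchanged under grid_sample;
# get_warp's output follows the same convention (x first, then y).
img = torch.arange(16.0).view(1, 1, 4, 4)  # (N, C, H, W)
ys, xs = torch.meshgrid(torch.linspace(-1, 1, 4), torch.linspace(-1, 1, 4))
grid = torch.stack([xs, ys], -1)[None]  # (N, H, W, 2)
out = nn_f.grid_sample(img, grid, align_corners=True)
assert torch.allclose(out, img)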
def refine(image, depthmap, rays_o, rays_d, bounds_img, bounds_o,
bounds_r, ref_cam: view.CameraParam, net, is_lr):
if is_lr:
image = nn_f.upsample(
image[None, ...], scale_factor=2, mode='bicubic')[0]
depthmap = nn_f.upsample(
depthmap[None, None, ...], scale_factor=2, mode='bicubic')[0, 0]
bounds_rays_o, bounds_rays_d = ref_cam.get_global_rays(
bounds_o, bounds_r, flatten=True)
bounds_inferred = torch.stack([
net(bounds_rays_o[i], bounds_rays_d[i]).view(
ref_cam.res[0], ref_cam.res[1], -1).permute(2, 0, 1)
for i in range(bounds_img.size(0))
], 0)
bounds_diff = (bounds_img - bounds_inferred) / (bounds_inferred + 1e-5)
bounds_warp = get_warp(rays_o, rays_d, depthmap,
bounds_o, bounds_r, ref_cam)
warped_diff = nn_f.grid_sample(bounds_diff, bounds_warp)
print(bounds_warp.size(), warped_diff.size())
avg_diff = torch.mean(warped_diff, 0)
return image * (1 + avg_diff)
@@ -26,7 +26,7 @@ class Foveation(object):
return self
def synthesis(self, layers: List[torch.Tensor],
normalized_fovea_center: Tuple[float, float]) -> torch.Tensor:
fovea_center: Tuple[float, float]) -> torch.Tensor:
"""
Generate foveated retinal image by blending fovea layers
**Note: the current implementation only supports two fovea layers**
@@ -37,8 +37,8 @@ class Foveation(object):
output: torch.Tensor = nn_f.interpolate(layers[-1], self.out_res,
mode='bilinear', align_corners=False)
c = torch.tensor([
normalized_fovea_center[0] * self.out_res[1],
normalized_fovea_center[1] * self.out_res[0]
fovea_center[0] + self.out_res[1] / 2,
fovea_center[1] + self.out_res[0] / 2
], device=self.coords.device)
for i in range(self.n_layers - 2, -1, -1):
if layers[i] is None:
@@ -61,24 +61,6 @@ class Foveation(object):
k = length_i / length
return int(math.ceil(self.out_res[0] * k))
def get_layer_region_in_final_image(self, i: int,
normalized_fovea_center: Tuple[float, float]) -> Tuple[slice, slice]:
"""
Get region of fovea layer i in final image
:param i: index of layer
:return: tuple of slice objects storing the start and end of the region in the horizontal and vertical directions
"""
roi_size = self.get_layer_size_in_final_image(i)
roi_center = (int(self.out_res[1] * normalized_fovea_center[0]),
int(self.out_res[0] * normalized_fovea_center[1]))
roi_offset_y = roi_center[1] - roi_size // 2
roi_offset_x = roi_center[0] - roi_size // 2
return (...,
slice(roi_offset_y, roi_offset_y + roi_size),
slice(roi_offset_x, roi_offset_x + roi_size)
)
def _gen_layer_blendmap(self, i: int) -> torch.Tensor:
"""
Generate blend map for fovea layer i
......
import torch
from torch import nn
from typing import List, Mapping, Tuple
from . import view
from . import refine
from .foveation import Foveation
from .simple_perf import SimplePerf
class GenFinal(object):
def __init__(self, layers_fov: List[float],
layers_res: List[Tuple[int, int]],
full_res: Tuple[int, int],
fovea_net: nn.Module,
periph_net: nn.Module,
device: torch.device = None) -> None:
super().__init__()
self.layer_cams = [
view.CameraParam({
'fov': layers_fov[i],
'cx': 0.5,
'cy': 0.5,
'normalized': True
}, layers_res[i], device=device)
for i in range(len(layers_fov))
]
self.full_cam = view.CameraParam({
'fov': layers_fov[-1],
'cx': 0.5,
'cy': 0.5,
'normalized': True
}, full_res, device=device)
self.fovea_net = fovea_net.to(device)
self.periph_net = periph_net.to(device)
self.foveation = Foveation(
layers_fov, full_res, device=device)
self.device = device
def to(self, device: torch.device):
self.fovea_net.to(device)
self.periph_net.to(device)
self.foveation.to(device)
self.full_cam.to(device)
for cam in self.layer_cams:
cam.to(device)
self.device = device
return self
def gen(self, gaze, trans, ret_raw=False, perf_time=False) -> Mapping[str, torch.Tensor]:
fovea_cam = self._adjust_cam(self.layer_cams[0], self.full_cam, gaze)
mid_cam = self._adjust_cam(self.layer_cams[1], self.full_cam, gaze)
periph_cam = self.layer_cams[2]
perf = SimplePerf(True, True) if perf_time else None
# x_rays_o, x_rays_d: (Hx, Wx, 3)
fovea_rays_o, fovea_rays_d = fovea_cam.get_global_rays(trans, True)
mid_rays_o, mid_rays_d = mid_cam.get_global_rays(trans, True)
periph_rays_o, periph_rays_d = periph_cam.get_global_rays(trans, True)
mid_periph_rays_o = torch.cat([mid_rays_o, periph_rays_o], 1)
mid_periph_rays_d = torch.cat([mid_rays_d, periph_rays_d], 1)
if perf_time:
perf.Checkpoint('Get rays')
perf1 = SimplePerf(True, True) if perf_time else None
fovea_inferred = self.fovea_net(fovea_rays_o[0], fovea_rays_d[0]).view(
1, fovea_cam.res[0], fovea_cam.res[1], -1).permute(0, 3, 1, 2) # (1, C, H_fovea, W_fovea)
if perf_time:
perf1.Checkpoint('Infer fovea')
periph_mid_inferred = self.periph_net(mid_periph_rays_o[0], mid_periph_rays_d[0])
mid_inferred = periph_mid_inferred[:mid_cam.res[0] * mid_cam.res[1], :].view(
1, mid_cam.res[0], mid_cam.res[1], -1).permute(0, 3, 1, 2)
periph_inferred = periph_mid_inferred[mid_cam.res[0] * mid_cam.res[1]:, :].view(
1, periph_cam.res[0], periph_cam.res[1], -1).permute(0, 3, 1, 2)
if perf_time:
perf1.Checkpoint('Infer mid & periph')
perf.Checkpoint('Infer')
fovea_refined = refine.constrast_enhance(fovea_inferred, 3, 0.2)
mid_refined = refine.constrast_enhance(mid_inferred, 5, 0.2)
periph_refined = refine.constrast_enhance(periph_inferred, 5, 0.2)
if perf_time:
perf.Checkpoint('Refine')
blended = self.foveation.synthesis([
fovea_refined,
mid_refined,
periph_refined
], (gaze[0], gaze[1]))
if perf_time:
perf.Checkpoint('Blend')
if ret_raw:
return {
'fovea': fovea_refined,
'mid': mid_refined,
'periph': periph_refined,
'blended': blended,
'fovea_raw': fovea_inferred,
'mid_raw': mid_inferred,
'periph_raw': periph_inferred,
'blended_raw': self.foveation.synthesis([
fovea_inferred,
mid_inferred,
periph_inferred
], (gaze[0], gaze[1]))
}
return {
'fovea': fovea_refined,
'mid': mid_refined,
'periph': periph_refined,
'blended': blended
}
def _adjust_cam(self, cam: view.CameraParam, full_cam: view.CameraParam,
gaze: Tuple[float, float]) -> view.CameraParam:
fovea_offset = (
(gaze[0]) / full_cam.f[0].item() * cam.f[0].item(),
(gaze[1]) / full_cam.f[1].item() * cam.f[1].item()
)
return view.CameraParam({
'fx': cam.f[0].item(),
'fy': cam.f[1].item(),
'cx': cam.c[0].item() - fovea_offset[0],
'cy': cam.c[1].item() - fovea_offset[1]
}, cam.res, device=self.device)
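Note: for context, a hypothetical usage sketch of GenFinal; the FOV/resolution values, the trained fovea_net/periph_net objects, and the pose are assumptions for illustration, not values from this commit.
# Hypothetical usage (all concrete values are assumptions).
device = torch.device('cuda:0')
gen = GenFinal(layers_fov=[20.0, 45.0, 110.0],
               layers_res=[(128, 128), (256, 256), (230, 256)],
               full_res=(1440, 1600),
               fovea_net=fovea_net,    # trained fovea model (assumed to exist)
               periph_net=periph_net,  # trained mid/periph model (assumed to exist)
               device=device)
trans = view.Trans(torch.zeros(3, device=device), torch.eye(3, device=device))
out = gen.gen((0.0, 0.0), trans)  # gaze at the image center
blended = out['blended']          # (1, C, H, W) foveated frame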
import sys
import time
TOTAL_BAR_LENGTH = 80
TOTAL_BAR_LENGTH = 50
LAST_T = time.time()
BEGIN_T = LAST_T
def progress_bar(current, total, msg=None):
def progress_bar(current, total, msg=None, premsg=None):
global LAST_T, BEGIN_T
if current == 0:
BEGIN_T = time.time() # Reset for new bar.
@@ -14,21 +14,25 @@ def progress_bar(current, total, msg=None):
current_len = int(TOTAL_BAR_LENGTH * (current + 1) / total)
rest_len = int(TOTAL_BAR_LENGTH - current_len) - 1
sys.stdout.write(' %d/%d' % (current + 1, total))
sys.stdout.write(' [')
if premsg:
sys.stdout.write(premsg)
sys.stdout.write(' ')
sys.stdout.write('[')
for i in range(current_len):
sys.stdout.write('=')
sys.stdout.write('>')
if current_len < TOTAL_BAR_LENGTH:
sys.stdout.write('>')
for i in range(rest_len):
sys.stdout.write('.')
sys.stdout.write(']')
sys.stdout.write(' %d/%d' % (current + 1, total))
current_time = time.time()
step_time = current_time - LAST_T
LAST_T = current_time
total_time = current_time - BEGIN_T
time_used = ' Step: %s' % format_time(step_time)
time_used = ' | Step: %s' % format_time(step_time)
time_used += ' | Tot: %s' % format_time(total_time)
if msg:
time_used += ' | ' + msg
@@ -37,9 +41,9 @@ def progress_bar(current, total, msg=None):
sys.stdout.write(msg)
if current < total - 1:
sys.stdout.write('\r')
sys.stdout.write(' \r')
else:
sys.stdout.write('\n')
sys.stdout.write(' \n')
sys.stdout.flush()
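Note: a usage sketch of the extended signature; the loop and messages are illustrative.
# premsg prints before the bar, msg is appended after the timing info.
for epoch in range(1, 3):
    for i in range(100):
        progress_bar(i, 100, msg='loss: %.3f' % 0.123, premsg='Epoch %d' % epoch)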
@@ -67,10 +71,10 @@ def format_time(seconds):
output += str(minutes) + 'm'
time_index += 1
if seconds_final > 0 and time_index <= 2:
output += str(seconds_final) + 's'
output += '%02ds' % seconds_final
time_index += 1
if millis > 0 and time_index <= 2:
output += str(millis) + 'ms'
output += '%03dms' % millis
time_index += 1
if output == '':
output = '0ms'
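Note: the switch to '%02ds' and '%03dms' keeps each time field at a constant width so the progress line does not jitter between steps. Assuming only minutes and seconds are nonzero:
# 62.005 s formats as '1m02s' with the padding (previously '1m2s').
print(format_time(62.005))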
......
import torch
import numpy as np
import torch.nn.functional as nn_f
from . import view
class GuideRefinement(object):
def __init__(self, guides_image, guides_view: view.Trans,
guides_cam: view.CameraParam, net) -> None:
rays_o, rays_d = guides_cam.get_global_rays(guides_view, flatten=True)
guides_inferred = torch.stack([
net(rays_o[i], rays_d[i]).view(
guides_cam.res[0], guides_cam.res[1], -1).permute(2, 0, 1)
for i in range(guides_image.size(0))
], 0)
self.guides_diff = (guides_image - guides_inferred) / \
(guides_inferred + 1e-5)
self.guides_view = guides_view
self.guides_cam = guides_cam
def get_warp(self, rays_o, rays_d, depthmap, tgt_trans: view.Trans, tgt_cam):
rays_size = list(depthmap.size()) + [3]
rays_o = rays_o.view(rays_size)
rays_d = rays_d.view(rays_size)
#print(rays_o.size(), rays_d.size(), depthmap.size())
pcloud = rays_o + rays_d * depthmap[..., None]
#print('pcloud', pcloud.size())
pcloud_in_tgt = tgt_trans.trans_point(pcloud, inverse=True)
#print(pcloud_in_tgt.size())
pixel_positions = tgt_cam.proj(pcloud_in_tgt)
pixel_positions[..., 0] /= tgt_cam.res[1] * 0.5
pixel_positions[..., 1] /= tgt_cam.res[0] * 0.5
pixel_positions -= 1
return pixel_positions
def refine_by_guide(self, image, depthmap, rays_o, rays_d, is_lr):
if is_lr:
image = nn_f.upsample(
image[None, ...], scale_factor=2, mode='bicubic')[0]
depthmap = nn_f.upsample(
depthmap[None, None, ...], scale_factor=2, mode='bicubic')[0, 0]
warp = self.get_warp(rays_o, rays_d, depthmap,
self.guides_view, self.guides_cam)
warped_diff = nn_f.grid_sample(self.guides_diff, warp)
print(warp.size(), warped_diff.size())
avg_diff = torch.mean(warped_diff, 0)
return image * (1 + avg_diff)
def constrast_enhance(image, sigma, fe):
kernel = torch.ones(1, 1, 3, 3, device=image.device) / 9
mean = torch.cat([
nn_f.conv2d(image[:, 0:1], kernel, padding=1),
nn_f.conv2d(image[:, 1:2], kernel, padding=1),
nn_f.conv2d(image[:, 2:3], kernel, padding=1)
], 1)
cScale = 1.0 + sigma * fe
return torch.clamp(mean + (image - mean) * cScale, 0, 1)
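Note: constrast_enhance boosts local contrast: a 3x3 box filter estimates the local mean, and the deviation from that mean is scaled by 1 + sigma * fe before clamping to [0, 1]. A usage sketch with the same arguments gen passes for the fovea layer:
# With sigma=3, fe=0.2, deviations from the local mean are scaled by
# cScale = 1 + 3 * 0.2 = 1.6, then clamped to [0, 1].
img = torch.rand(1, 3, 64, 64)
enhanced = constrast_enhance(img, 3, 0.2)
print(enhanced.shape)  # torch.Size([1, 3, 64, 64])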
def get_grad(image):
kernel = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], device=image.device, dtype=torch.float32).view(1, 1, 3, 3)
x_grad = torch.cat([
nn_f.conv2d(image[:, 0:1], kernel, padding=1),
nn_f.conv2d(image[:, 1:2], kernel, padding=1),
nn_f.conv2d(image[:, 2:3], kernel, padding=1)
], 1)
kernel = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], device=image.device, dtype=torch.float32).view(1, 1, 3, 3)
y_grad = torch.cat([
nn_f.conv2d(image[:, 0:1], kernel, padding=1),
nn_f.conv2d(image[:, 1:2], kernel, padding=1),
nn_f.conv2d(image[:, 2:3], kernel, padding=1)
], 1)
return (x_grad ** 2 + y_grad ** 2).sqrt().clamp(0, 1)
def getGaussianKernel(ksize, sigma=0):
if sigma <= 0:
# Compute the default sigma from the kernel size, consistent with OpenCV
sigma = 0.3 * ((ksize - 1) * 0.5 - 1) + 0.8
center = ksize // 2
xs = (np.arange(ksize, dtype=np.float32) - center) # signed distance of each element from the kernel center
kernel1d = np.exp(-(xs ** 2) / (2 * sigma ** 2)) # 1D Gaussian kernel
# The 2D Gaussian is separable: build the 2D kernel as an outer product of the 1D kernel
kernel = kernel1d[..., None] @ kernel1d[None, ...]
kernel = torch.from_numpy(kernel)
kernel = kernel / kernel.sum() # normalize to unit sum
return kernel.view(1, 1, ksize, ksize) # was view(1, 1, 3, 3), which only works for ksize == 3
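Note: a quick sanity check of the kernel; for ksize=5 the OpenCV-style default sigma is 0.3*((5-1)*0.5-1)+0.8 = 1.1.
# The kernel sums to 1 and is shaped (1, 1, ksize, ksize).
k = getGaussianKernel(5)
print(k.shape)         # torch.Size([1, 1, 5, 5])
print(float(k.sum()))  # ~1.0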
def grad_aware_gaussian(image, ksize, sigma=0):
kernel = getGaussianKernel(ksize, sigma).to(image.device)
print(kernel.size())
blur = torch.cat([
nn_f.conv2d(image[:, 0:1], kernel, padding=ksize // 2),
nn_f.conv2d(image[:, 1:2], kernel, padding=ksize // 2),
nn_f.conv2d(image[:, 2:3], kernel, padding=ksize // 2)
], 1) # padding was hard-coded to 1, which only preserves size for ksize == 3
grad = get_grad(image)
return image * grad + blur * (1 - grad)
def bilateral_filter(batch_img, ksize, sigmaColor=None, sigmaSpace=None):
device = batch_img.device
if sigmaSpace is None:
sigmaSpace = 0.15 * ksize + 0.35 # 0.3 * ((ksize - 1) * 0.5 - 1) + 0.8
if sigmaColor is None:
sigmaColor = sigmaSpace
pad = (ksize - 1) // 2
batch_img_pad = nn_f.pad(batch_img, pad=[pad, pad, pad, pad], mode='reflect')
# batch_img has shape (B, C, H, W), so unfold along dims 2 and 3
# patches.shape: B x C x H x W x ksize x ksize
patches = batch_img_pad.unfold(2, ksize, 1).unfold(3, ksize, 1)
patch_dim = patches.dim() # 6
# Per-pixel intensity difference between each pixel and its neighborhood
diff_color = patches - batch_img.unsqueeze(-1).unsqueeze(-1)
# Range weights computed from the intensity differences
weights_color = torch.exp(-(diff_color ** 2) / (2 * sigmaColor ** 2))
# Normalize the range weights
weights_color = weights_color / weights_color.sum(dim=(-1, -2), keepdim=True)
# Get the Gaussian kernel and expand it to the same shape as weights_color
weights_space = getGaussianKernel(ksize, sigmaSpace).to(device)
weights_space_dim = (patch_dim - 2) * (1,) + (ksize, ksize)
weights_space = weights_space.view(*weights_space_dim).expand_as(weights_color)
# The total weight is the product of the spatial and range weights
weights = weights_space * weights_color
# Normalization factor of the total weights
weights_sum = weights.sum(dim=(-1, -2))
# Weighted average over each neighborhood
weighted_pix = (weights * patches).sum(dim=(-1, -2)) / weights_sum
return weighted_pix
\ No newline at end of file
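Note: a usage sketch of the bilateral filter; the input is assumed to be in [0, 1].
# Edges are preserved because the range weights suppress neighbors whose
# intensity differs strongly from the center pixel.
batch = torch.rand(2, 3, 32, 32)
smoothed = bilateral_filter(batch, ksize=5)  # both sigmas default from ksize
print(smoothed.shape)  # torch.Size([2, 3, 32, 32])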
from typing import Mapping, Tuple, Union
from typing import List, Mapping, Tuple, Union
import torch
from . import util
@@ -65,8 +65,7 @@ class CameraParam(object):
rays = rays.flatten(0, 1)
return rays
def get_global_rays(self, t: torch.Tensor, r: torch.Tensor,
flatten=False, norm=True) -> torch.Tensor:
def get_global_rays(self, trans, flatten=False, norm=True) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Get the rays of all pixels in global space, under the given view transform
@@ -77,9 +76,9 @@
:return: rays' origin and direction, each ```Tensor(N.., M.., 3)```
"""
rays = self.get_local_rays(flatten, norm) # (M.., 3)
rays_o, _ = torch.broadcast_tensors(t[..., None, :], rays) if flatten \
else torch.broadcast_tensors(t[..., None, None, :], rays) # (N.., M.., 3)
rays_d = trans_vector(rays, r)
rays_o, _ = torch.broadcast_tensors(trans.t[..., None, :], rays) if flatten \
else torch.broadcast_tensors(trans.t[..., None, None, :], rays) # (N.., M.., 3)
rays_d = trans.trans_vector(rays)
return rays_o, rays_d
def _convert_camera_params(self, input_camera_params: Mapping[str, Union[float, bool]],
@@ -114,6 +113,62 @@ class CameraParam(object):
return camera_params
class Trans(object):
def __init__(self, t: torch.Tensor, r: torch.Tensor) -> None:
self.t = t
self.r = r
if len(self.t.size()) == 1:
self.t = self.t[None, :]
self.r = self.r[None, :, :]
def trans_point(self, p: torch.Tensor, inverse=False) -> torch.Tensor:
"""
Transform points by this transform's translation vectors and rotation matrices
:param p ```Tensor(N.., 3)```: points to transform
:param inverse: whether to perform the inverse transform
:return ```Tensor(M.., N.., 3)```: transformed points
"""
size_N = list(p.size())[:-1]
size_M = list(self.r.size())[:-2]
out_size = size_M + size_N + [3]
t_size = size_M + [1 for _ in range(len(size_N))] + [3]
t = self.t.view(t_size) # (M.., 1.., 3)
if inverse:
p = (p - t).view(size_M + [-1, 3])
r = self.r
else:
p = p.view(-1, 3)
r = self.r.movedim(-1, -2) # Transpose rotation matrices
out = torch.matmul(p, r).view(out_size)
if not inverse:
out = out + t
return out
def trans_vector(self, v: torch.Tensor, inverse=False) -> torch.Tensor:
"""
Transform vectors by this transform's rotation matrices
:param v ```Tensor(N.., 3)```: vectors to transform
:param inverse: whether to perform the inverse transform
:return ```Tensor(M.., N.., 3)```: transformed vectors
"""
out_size = list(self.r.size())[:-2] + list(v.size())[:-1] + [3]
r = self.r if inverse else self.r.movedim(-1, -2) # Transpose rotation matrices
out = torch.matmul(v.view(-1, 3), r).view(out_size)
return out
def size(self) -> List[int]:
return list(self.t.size()[:-1])
def get(self, *index):
return Trans(self.t[index], self.r[index])
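Note: a round-trip sketch of Trans with an identity rotation (values are illustrative).
# Forward transform then inverse recovers the input point.
t = torch.tensor([1.0, 2.0, 3.0])
r = torch.eye(3)
trans = Trans(t, r)                  # t, r are auto-expanded to batch form
p = torch.tensor([[0.0, 0.0, 1.0]])  # (N=1, 3)
q = trans.trans_point(p)             # (M=1, N=1, 3): rotate by r, then add t
back = trans.trans_point(q[0], inverse=True)
assert torch.allclose(back[0], p)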
def trans_point(p: torch.Tensor, t: torch.Tensor, r: torch.Tensor, inverse=False) -> torch.Tensor:
"""
Transform points by given translation vectors and rotation matrices
......
from typing import Tuple
import math
import torch
import torch.nn as nn
from .my import net_modules
from .my import util
from .my import device
from .my import color_mode
from ..my import net_modules
from ..my import util
from ..my import device
from ..my import color_mode
rand_gen = torch.Generator(device=device.GetDevice())
@@ -112,7 +113,7 @@ class Rendering(nn.Module):
# dists: (N_rays, N_samples)
dists = z_vals[..., 1:] - z_vals[..., :-1]
last_dist = z_vals[..., 0:1] * 0 + 1e10
dists = torch.cat([
dists, last_dist
], -1)
@@ -145,19 +146,17 @@ class Sampler(nn.Module):
:param lindisp: If True, sample linearly in inverse depth rather than in depth
"""
super().__init__()
if lindisp:
self.r = 1 / torch.linspace(1 / depth_range[0], 1 / depth_range[1],
n_samples, device=device.GetDevice())
else:
self.r = torch.linspace(depth_range[0], depth_range[1],
n_samples, device=device.GetDevice())
self.lindisp = lindisp
if self.lindisp:
depth_range = (1 / depth_range[0], 1 / depth_range[1])
self.r = torch.linspace(depth_range[0], depth_range[1],
n_samples, device=device.GetDevice())
step = (depth_range[1] - depth_range[0]) / (n_samples - 1)
self.perturb_sample = perturb_sample
self.spherical = spherical
self.inverse_r = inverse_r
if perturb_sample:
mids = .5 * (self.r[1:] + self.r[:-1])
self.upper = torch.cat([mids, self.r[-1:]], -1)
self.lower = torch.cat([self.r[:1], mids], -1)
self.upper = torch.clamp_min(self.r + step / 2, 0)
self.lower = torch.clamp_min(self.r - step / 2, 0)
def forward(self, rays_o, rays_d):
"""
@@ -177,18 +176,26 @@
r = self.lower + (self.upper - self.lower) * t_rand
else:
r = self.r
if self.lindisp:
r = torch.reciprocal(r)
if self.spherical:
pts, depths = RaySphereIntersect(rays_o, rays_d, r)
sphers = util.CartesianToSpherical(pts, inverse_r=self.inverse_r)
return sphers, depths
return sphers, pts, depths
else:
pts = rays_o[..., None, :] + rays_d[..., None, :] * r[..., None]
return pts, pts, r  # coords == pts in the Cartesian case; matches the spherical branch's 3-tuple
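Note: with the rewrite above, r is kept in inverse-depth space when lindisp is set and converted back with a reciprocal only at sampling time, so the perturbation step stays uniform in that space. A minimal sketch of the resulting spacing:
# Uniform steps in inverse depth concentrate samples near the camera.
near, far, n = 1.0, 50.0, 8
r = torch.linspace(1 / near, 1 / far, n)  # uniform in disparity
depths = torch.reciprocal(r)              # dense near 1.0, sparse near 50.0
print(depths)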
# x>0, y>0 -> (y, -x)
# x<0, y>0 -> (-y, x)
# x<0, y<0 -> (y, -x)
# x>0, y<0 -> (-y, x)
class MslNet(nn.Module):
def __init__(self, fc_params, sampler_params,
normalize_coord: bool,
dir_as_input: bool,
color: int = color_mode.RGB,
encode_to_dim: int = 0,
export_mode: bool = False):
@@ -197,7 +204,8 @@ class MslNet(nn.Module):
:param fc_params: parameters for the fully-connected network
:param sampler_params: parameters for sampler
:param gray: is grayscale mode
:param normalize_coord: whether to normalize the spherical coords to [0, 2π] before encoding
:param color: color mode
:param encode_to_dim: number of dimensions to encode the input to
"""
super().__init__()
@@ -209,7 +217,10 @@
self.sampler = Sampler(**sampler_params)
self.rendering = Rendering()
self.export_mode = export_mode
if color == color_mode.YCbCr:
self.normalize_coord = normalize_coord
self.dir_as_input = dir_as_input
self.color = color
if self.color == color_mode.YCbCr:
self.net1 = net_modules.FcNet(
in_chns=fc_params['in_chns'],
out_chns=fc_params['nf'] + 2,
@@ -221,9 +232,52 @@
nf=fc_params['nf'],
n_layers=1)
self.net = None
elif self.dir_as_input:
self.input_encoder2 = net_modules.InputEncoder.Get(4, 2)
self.net1 = net_modules.FcNet(
in_chns=fc_params['in_chns'],
out_chns=fc_params['nf'],
nf=fc_params['nf'],
n_layers=fc_params['n_layers'])
self.net2 = net_modules.FcNet(
in_chns=fc_params['nf'] + self.input_encoder2.out_dim,
out_chns=fc_params['out_chns'],
nf=fc_params['nf'],
n_layers=1)
self.net = None
else:
self.net = net_modules.FcNet(**fc_params)
if self.normalize_coord:
self.register_buffer('angle_range', torch.tensor(
[[1e5, 1e5], [-1e5, -1e5]]))
self.register_buffer('depth_range', torch.tensor([
self.sampler.lower[0], self.sampler.upper[-1]
]))
def update_normalize_range(self, rays_o: torch.Tensor, rays_d: torch.Tensor):
coords, _, _ = self.sampler(rays_o, rays_d)
coords = coords[..., 1:].view(-1, 2)
self.angle_range = torch.stack([
torch.cat([coords, self.angle_range[0:1]]).amin(0),
torch.cat([coords, self.angle_range[1:2]]).amax(0)
])
def calc_local_dir(self, rays_d, coords, pts: torch.Tensor):
"""
Compute the sample points' view directions in their local tangent frames
:param rays_d ```Tensor(B, 3)```:
:param coords ```Tensor(B, N, 3)```:
:param pts ```Tensor(B, N, 3)```:
:return ```Tensor(B, N, 2)```
"""
local_z = pts / pts.norm(dim=-1, keepdim=True)
local_x = util.SphericalToCartesian(coords + torch.tensor([0, 0.1 / 180 * math.pi, 0], device=coords.device)) - pts
local_x = local_x / local_x.norm(dim=-1, keepdim=True)
local_y = torch.cross(local_x, local_z, -1)
local_rot = torch.stack([local_x, local_y, local_z], dim=-2) # (B, N, 3, 3)
return util.CartesianToSpherical(torch.matmul(rays_d[:, None, None, :], local_rot)).squeeze(-2)[..., 1:3]
def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor,
ret_depth: bool = False) -> torch.Tensor:
"""
@@ -233,19 +287,30 @@
:param rays_d ```Tensor(B, 3)```: rays' direction
:return: ```Tensor(B, C)```, inferred images/pixels
"""
coords, depths = self.sampler(rays_o, rays_d)
coords, pts, depths = self.sampler(rays_o, rays_d)
if self.dir_as_input:
dirs = self.calc_local_dir(rays_d, coords, pts)
if self.normalize_coord: # Normalize coords to [0, 2pi]
range = torch.cat([self.depth_range.view(2, 1), self.angle_range], 1)
coords = (coords - range[0]) / (range[1] - range[0]) * 2 * math.pi
encoded = self.input_encoder(coords)
if not self.net:
if self.color == color_mode.YCbCr:
mid_output = self.net1(encoded)
net2_output = self.net2(mid_output[..., :-2])
raw = torch.cat([
mid_output[..., -2:],
net2_output
], -1)
elif self.dir_as_input:
encoded_dirs = self.input_encoder2(dirs)
#print(encoded.size(), self.net1(encoded).size(), encoded_dirs.size())
raw = self.net2(torch.cat([self.net1(encoded), encoded_dirs], -1))
else:
raw = self.net(encoded)
if self.export_mode:
colors, alphas = self.rendering.raw2color(raw, depths)
return torch.cat([colors, alphas[..., None]], -1)
@@ -254,5 +319,5 @@
color_map, _, _, _, depth_map = self.rendering(
raw, depths, ret_extra=True)
return color_map, depth_map
return self.rendering(raw, depths)
from typing import Tuple
import math
import torch
import torch.nn as nn
from .my import net_modules
from .my import util
from .my import device
from ..my import net_modules
from ..my import util
from ..my import device
from ..my import color_mode
rand_gen = torch.Generator(device=device.GetDevice())
rand_gen.manual_seed(torch.seed())
@@ -110,7 +113,7 @@ class Rendering(nn.Module):
# dists: (N_rays, N_samples)
dists = z_vals[..., 1:] - z_vals[..., :-1]
last_dist = z_vals[..., 0:1] * 0 + 1e10
dists = torch.cat([
dists, last_dist
], -1)
@@ -143,19 +146,17 @@ class Sampler(nn.Module):
:param lindisp: If True, sample linearly in inverse depth rather than in depth
"""
super().__init__()
if lindisp:
self.r = 1 / torch.linspace(1 / depth_range[0], 1 / depth_range[1],
n_samples, device=device.GetDevice())
else:
self.r = torch.linspace(depth_range[0], depth_range[1],
n_samples, device=device.GetDevice())
self.lindisp = lindisp
if self.lindisp:
depth_range = (1 / depth_range[0], 1 / depth_range[1])
self.r = torch.linspace(depth_range[0], depth_range[1],
n_samples, device=device.GetDevice())
step = (depth_range[1] - depth_range[0]) / (n_samples - 1)
self.perturb_sample = perturb_sample
self.spherical = spherical
self.inverse_r = inverse_r
if perturb_sample:
mids = .5 * (self.r[1:] + self.r[:-1])
self.upper = torch.cat([mids, self.r[-1:]], -1)
self.lower = torch.cat([self.r[:1], mids], -1)
self.upper = torch.clamp_min(self.r + step / 2, 0)
self.lower = torch.clamp_min(self.r - step / 2, 0)
def forward(self, rays_o, rays_d):
"""
@@ -175,19 +176,24 @@
r = self.lower + (self.upper - self.lower) * t_rand
else:
r = self.r
if self.lindisp:
r = torch.reciprocal(r)
if self.spherical:
pts, depths = RaySphereIntersect(rays_o, rays_d, r)
sphers = util.CartesianToSpherical(pts, inverse_r=self.inverse_r)
return sphers, depths
return sphers, pts, depths
else:
pts = rays_o[..., None, :] + rays_d[..., None, :] * r[..., None]
return pts, pts, r  # coords == pts in the Cartesian case; matches the spherical branch's 3-tuple
class MslNet(nn.Module):
class NewMslNet(nn.Module):
def __init__(self, fc_params, sampler_params,
gray=False,
normalize_coord: bool,
dir_as_input: bool,
not_same_net: bool = False,
color: int = color_mode.RGB,
encode_to_dim: int = 0,
export_mode: bool = False):
"""
@@ -195,7 +201,8 @@ class MslNet(nn.Module):
:param fc_params: parameters for the fully-connected network
:param sampler_params: parameters for sampler
:param gray: is grayscale mode
:param normalize_coord: whether to normalize the spherical coords to [0, 2π] before encoding
:param color: color mode
:param encode_to_dim: number of dimensions to encode the input to
"""
super().__init__()
@@ -203,14 +210,36 @@
self.input_encoder = net_modules.InputEncoder.Get(
encode_to_dim, self.in_chns)
fc_params['in_chns'] = self.input_encoder.out_dim
fc_params['out_chns'] = 2 if gray else 4
fc_params['out_chns'] = 2 if color == color_mode.GRAY else 4
self.sampler = Sampler(**sampler_params)
self.net = net_modules.FcNet(**fc_params)
self.rendering = Rendering()
self.export_mode = export_mode
def forward(self, view_centers: torch.Tensor, view_rots: torch.Tensor, local_rays: torch.Tensor,
self.normalize_coord = normalize_coord
self.nets = nn.ModuleList([
net_modules.FcNet(**fc_params),
net_modules.FcNet(in_chns=fc_params['in_chns'],
out_chns=fc_params['out_chns'],
nf=128, n_layers=4) if not_same_net
else net_modules.FcNet(**fc_params)
])
self.n_samples = sampler_params['n_samples']
if self.normalize_coord:
self.register_buffer('angle_range', torch.tensor(
[[1e5, 1e5], [-1e5, -1e5]]))
self.register_buffer('depth_range', torch.tensor([
[self.sampler.lower[0], self.sampler.lower[self.n_samples // 2]],
[self.sampler.upper[self.n_samples // 2 - 1], self.sampler.upper[-1]]
]))
def update_normalize_range(self, rays_o: torch.Tensor, rays_d: torch.Tensor):
coords, _, _ = self.sampler(rays_o, rays_d)
coords = coords[..., 1:].view(-1, 2)
self.angle_range = torch.stack([
torch.cat([coords, self.angle_range[0:1]]).amin(0),
torch.cat([coords, self.angle_range[1:2]]).amax(0)
])
def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor,
ret_depth: bool = False) -> torch.Tensor:
"""
rays -> colors
@@ -219,18 +248,28 @@
:param rays_d ```Tensor(B, 3)```: rays' direction
:return: ```Tensor(B, C)```, inferred images/pixels
"""
rays_o = local_rays * 0 + view_centers
rays_d = torch.matmul(local_rays.flatten(0, -2), r).view(out_size)
coords, depths = self.sampler(rays_o, rays_d)
coords, pts, depths = self.sampler(rays_o, rays_d)
if self.normalize_coord: # Normalize coords to [0, 2pi]
range = torch.cat([self.depth_range[:, 0:1], self.angle_range], 1)
coords[:, :self.n_samples // 2] = (coords[:, :self.n_samples // 2] - range[0]) / (
range[1] - range[0]) * 2 * math.pi
range = torch.cat([self.depth_range[:, 1:2], self.angle_range], 1)
coords[:, self.n_samples // 2:] = (coords[:, self.n_samples // 2:] - range[0]) / (
range[1] - range[0]) * 2 * math.pi
encoded = self.input_encoder(coords)
raw = torch.cat([
self.nets[0](encoded[:, :self.n_samples // 2]),
self.nets[1](encoded[:, self.n_samples // 2:]),
], 1)
if self.export_mode:
colors, alphas = self.rendering.raw2color(self.net(encoded), depths)
colors, alphas = self.rendering.raw2color(raw, depths)
return torch.cat([colors, alphas[..., None]], -1)
if ret_depth:
color_map, _, _, _, depth_map = self.rendering(
self.net(encoded), depths, ret_extra=True)
raw, depths, ret_extra=True)
return color_map, depth_map
return self.rendering(self.net(encoded), depths)
return self.rendering(raw, depths)
import torch
import torch.nn as nn
from .my import net_modules
from .my import util
from ..my import net_modules
from ..my import util
class SpherNet(nn.Module):
......
from typing import List
import torch
import torch.nn as nn
from .pytorch_prototyping.pytorch_prototyping import *
from .my import util
from .my import device
from ..pytorch_prototyping.pytorch_prototyping import *
from ..my import util
from ..my import device
class Encoder(nn.Module):
......
@@ -18,28 +18,29 @@
"from ..my import util\n",
"\n",
"path_patts = [\n",
" '/home/dengnc/deep_view_syn/data/gas_fovea_2020.12.31/upsampling_test/gt/view_%04d.png',\n",
" '/home/dengnc/deep_view_syn/data/gas_fovea_2020.12.31/fovea_rgb@msl-rgb_e10_fc256x4_d1-50_s16/output/model-epoch_500/train/out_view_%04d.png',\n",
" '/home/dengnc/deep_view_syn/data/gas_fovea_2020.12.31/upsampling_test/input/out_view_%04d.png',\n",
" '/home/dengnc/deep_view_syn/data/gas_fovea_2020.12.31/upsampling_test/output/view_%04d.png'\n",
" '/home/dengnc/deep_view_syn/data/gas_fovea_mid_trans_2021.01.11/test/view_%04d.png',\n",
" '/home/dengnc/deep_view_syn/data/gas_fovea_mid_trans_2021.01.11/fovea_rgb@msl-rgb_e10_fc256x4_d1-50_s16/output/model-epoch_500/test/out_view_%04d.png',\n",
" '/home/dengnc/deep_view_syn/data/gas_fovea_mid_trans_2021.01.11/new_fovea_rgb@nmsl-rgb_e10_fc256x4_d1-50_s16/output/model-epoch_300/test/out_view_%04d.png',\n",
" # '/out_view_%04d.png',\n",
" '/home/dengnc/deep_view_syn/data/gas_fovea_mid_trans_2021.01.11/new_fovea_rgb@nmsl-rgb_e10_fc128x4_d1-50_s32/output/model-epoch_200/test/out_view_%04d.png'\n",
"]\n",
"titles = [ 'Ground truth', 'Normal', 'Low Res', 'Upsampling']\n",
"show_range = range(5)\n",
"titles = ['Ground truth', 'Baseline', 'NMSL', 'With Dir', 'NMSL_32']\n",
"#show_range = range(20)\n",
"show_range = [ 2, 4, 5, 10, 12, 13, 14 ]\n",
"\n",
"#os.chdir('/home/dengnc/deep_view_syn/data/')\n",
"# os.chdir('/home/dengnc/deep_view_syn/data/')\n",
"image_seqs = [\n",
" util.ReadImageTensor([path_patt % i for i in show_range])\n",
" for path_patt in path_patts\n",
"]\n",
"\n",
"for i in show_range:\n",
" plt.figure(facecolor='white', figsize=(12, 4))\n",
" plt.figure(facecolor='white', figsize=(4 * len(image_seqs), 4))\n",
" plt.suptitle('View %d' % i)\n",
" for j in range(len(image_seqs)):\n",
" plt.subplot(1, len(image_seqs), j + 1)\n",
" plt.title(titles[j])\n",
" util.PlotImageTensor(image_seqs[j][i])\n",
" \n"
" util.PlotImageTensor(image_seqs[j][i])\n"
]
},
{
@@ -53,11 +54,19 @@
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9-final"
},
"orig_nbformat": 2
......