Commit 3554ba52 authored by Nianchen Deng's avatar Nianchen Deng
Browse files

sync

parent f7038e26
import sys import sys
import os import os
sys.path.append(os.path.abspath(sys.path[0] + '/../')) sys.path.append(os.path.abspath(sys.path[0] + '/../'))
__package__ = "deeplightfield" __package__ = "deep_view_syn"
import argparse import argparse
from PIL import Image from PIL import Image
......
...@@ -4,6 +4,8 @@ import torch.nn as nn ...@@ -4,6 +4,8 @@ import torch.nn as nn
from .my import net_modules from .my import net_modules
from .my import util from .my import util
from .my import device from .my import device
from .my import color_mode
rand_gen = torch.Generator(device=device.GetDevice()) rand_gen = torch.Generator(device=device.GetDevice())
rand_gen.manual_seed(torch.seed()) rand_gen.manual_seed(torch.seed())
...@@ -187,7 +189,7 @@ class Sampler(nn.Module): ...@@ -187,7 +189,7 @@ class Sampler(nn.Module):
class MslNet(nn.Module): class MslNet(nn.Module):
def __init__(self, fc_params, sampler_params, def __init__(self, fc_params, sampler_params,
gray=False, color: int = color_mode.RGB,
encode_to_dim: int = 0, encode_to_dim: int = 0,
export_mode: bool = False): export_mode: bool = False):
""" """
...@@ -203,11 +205,24 @@ class MslNet(nn.Module): ...@@ -203,11 +205,24 @@ class MslNet(nn.Module):
self.input_encoder = net_modules.InputEncoder.Get( self.input_encoder = net_modules.InputEncoder.Get(
encode_to_dim, self.in_chns) encode_to_dim, self.in_chns)
fc_params['in_chns'] = self.input_encoder.out_dim fc_params['in_chns'] = self.input_encoder.out_dim
fc_params['out_chns'] = 2 if gray else 4 fc_params['out_chns'] = 2 if color == color_mode.GRAY else 4
self.sampler = Sampler(**sampler_params) self.sampler = Sampler(**sampler_params)
self.net = net_modules.FcNet(**fc_params)
self.rendering = Rendering() self.rendering = Rendering()
self.export_mode = export_mode self.export_mode = export_mode
if color == color_mode.YCbCr:
self.net1 = net_modules.FcNet(
in_chns=fc_params['in_chns'],
out_chns=fc_params['nf'] + 2,
nf=fc_params['nf'],
n_layers=fc_params['n_layers'] - 2)
self.net2 = net_modules.FcNet(
in_chns=fc_params['nf'],
out_chns=2,
nf=fc_params['nf'],
n_layers=1)
self.net = None
else:
self.net = net_modules.FcNet(**fc_params)
def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor,
ret_depth: bool = False) -> torch.Tensor: ret_depth: bool = False) -> torch.Tensor:
...@@ -221,13 +236,23 @@ class MslNet(nn.Module): ...@@ -221,13 +236,23 @@ class MslNet(nn.Module):
coords, depths = self.sampler(rays_o, rays_d) coords, depths = self.sampler(rays_o, rays_d)
encoded = self.input_encoder(coords) encoded = self.input_encoder(coords)
if not self.net:
mid_output = self.net1(encoded)
net2_output = self.net2(mid_output[..., :-2])
raw = torch.cat([
mid_output[..., -2:],
net2_output
], -1)
else:
raw = self.net(encoded)
if self.export_mode: if self.export_mode:
colors, alphas = self.rendering.raw2color(self.net(encoded), depths) colors, alphas = self.rendering.raw2color(raw, depths)
return torch.cat([colors, alphas[..., None]], -1) return torch.cat([colors, alphas[..., None]], -1)
if ret_depth: if ret_depth:
color_map, _, _, _, depth_map = self.rendering( color_map, _, _, _, depth_map = self.rendering(
self.net(encoded), depths, ret_extra=True) raw, depths, ret_extra=True)
return color_map, depth_map return color_map, depth_map
return self.rendering(self.net(encoded), depths) return self.rendering(raw, depths)
...@@ -209,7 +209,8 @@ class MslNet(nn.Module): ...@@ -209,7 +209,8 @@ class MslNet(nn.Module):
self.rendering = Rendering() self.rendering = Rendering()
self.export_mode = export_mode self.export_mode = export_mode
def forward(self, fc_input: torch.Tensor,
def forward(self, view_centers: torch.Tensor, view_rots: torch.Tensor, local_rays: torch.Tensor,
ret_depth: bool = False) -> torch.Tensor: ret_depth: bool = False) -> torch.Tensor:
""" """
rays -> colors rays -> colors
...@@ -218,10 +219,13 @@ class MslNet(nn.Module): ...@@ -218,10 +219,13 @@ class MslNet(nn.Module):
:param rays_d ```Tensor(B, 3)```: rays' direction :param rays_d ```Tensor(B, 3)```: rays' direction
:return: ```Tensor(B, C)``, inferred images/pixels :return: ```Tensor(B, C)``, inferred images/pixels
""" """
depths = torch.ones(4096, 16, device="cuda") rays_o = local_rays * 0 + view_centers
rays_d = torch.matmul(local_rays.flatten(0, -2), r).view(out_size)
coords, depths = self.sampler(rays_o, rays_d)
encoded = self.input_encoder(coords)
if self.export_mode: if self.export_mode:
colors, alphas = self.rendering.raw2color(self.net(fc_input), depths) colors, alphas = self.rendering.raw2color(self.net(encoded), depths)
return torch.cat([colors, alphas[..., None]], -1) return torch.cat([colors, alphas[..., None]], -1)
if ret_depth: if ret_depth:
......
RGB = 0
GRAY = 1
YCbCr = 2
def to_str(color_mode):
return "gray" if color_mode == GRAY \
else ("ybr" if color_mode == YCbCr
else "rgb")
def from_str(color_str):
return GRAY if color_str == 'gray' \
else (YCbCr if color_str == 'ybr'
else RGB)
import torch
import torch.nn.functional as nn_f
from . import view
def get_warp(rays_o, rays_d, depthmap, tgt_o, tgt_r, tgt_cam):
print(rays_o.size(), rays_d.size(), depthmap.size())
pcloud = rays_o + rays_d * depthmap[..., None]
print(rays_o.size(), rays_d.size(), depthmap.size(), pcloud.size())
pcloud_in_tgt = view.trans_point(
pcloud, tgt_o, tgt_r, inverse=True)
print(pcloud_in_tgt.size())
pixel_positions = tgt_cam.proj(pcloud_in_tgt)
pixel_positions[..., 0] /= tgt_cam.res[1] * 0.5
pixel_positions[..., 1] /= tgt_cam.res[0] * 0.5
pixel_positions -= 1
return pixel_positions
def refine(image, depthmap, rays_o, rays_d, bounds_img, bounds_o,
bounds_r, ref_cam: view.CameraParam, net, is_lr):
if is_lr:
image = nn_f.upsample(
image[None, ...], scale_factor=2, mode='bicubic')[0]
depthmap = nn_f.upsample(
depthmap[None, None, ...], scale_factor=2, mode='bicubic')[0, 0]
bounds_rays_o, bounds_rays_d = ref_cam.get_global_rays(
bounds_o, bounds_r, flatten=True)
bounds_inferred = torch.stack([
net(bounds_rays_o[i], bounds_rays_d[i]).view(
ref_cam.res[0], ref_cam.res[1], -1).permute(2, 0, 1)
for i in range(bounds_img.size(0))
], 0)
bounds_diff = (bounds_img - bounds_inferred) / (bounds_inferred + 1e-5)
bounds_warp = get_warp(rays_o, rays_d, depthmap,
bounds_o, bounds_r, ref_cam)
warped_diff = nn_f.grid_sample(bounds_diff, bounds_warp)
print(bounds_warp.size(), warped_diff.size())
avg_diff = torch.mean(warped_diff, 0)
return image * (1 + avg_diff)
...@@ -17,12 +17,16 @@ class Foveation(object): ...@@ -17,12 +17,16 @@ class Foveation(object):
self._gen_layer_blendmap(i) self._gen_layer_blendmap(i)
for i in range(self.n_layers - 1) for i in range(self.n_layers - 1)
] # blend maps of fovea layers ] # blend maps of fovea layers
self.coords = util.MeshGrid(out_res).to(device=device)
def to(self, device): def to(self, device):
self.eye_fovea_blend = [x.to(device=device) for x in self.eye_fovea_blend] self.eye_fovea_blend = [x.to(device=device)
for x in self.eye_fovea_blend]
self.coords = self.coords.to(device=device)
return self return self
def synthesis(self, layers: List[torch.Tensor]) -> torch.Tensor: def synthesis(self, layers: List[torch.Tensor],
normalized_fovea_center: Tuple[float, float]) -> torch.Tensor:
""" """
Generate foveated retinal image by blending fovea layers Generate foveated retinal image by blending fovea layers
**Note: current implementation only support two fovea layers** **Note: current implementation only support two fovea layers**
...@@ -32,12 +36,17 @@ class Foveation(object): ...@@ -32,12 +36,17 @@ class Foveation(object):
""" """
output: torch.Tensor = nn_f.interpolate(layers[-1], self.out_res, output: torch.Tensor = nn_f.interpolate(layers[-1], self.out_res,
mode='bilinear', align_corners=False) mode='bilinear', align_corners=False)
c = torch.tensor([
normalized_fovea_center[0] * self.out_res[1],
normalized_fovea_center[1] * self.out_res[0]
], device=self.coords.device)
for i in range(self.n_layers - 2, -1, -1): for i in range(self.n_layers - 2, -1, -1):
output_roi = output[self.get_layer_region_in_final_image(i)] if layers[i] == None:
image = nn_f.interpolate(layers[i], output_roi.size()[-2:], continue
mode='bilinear', align_corners=False) R = self.get_layer_size_in_final_image(i) / 2
blend = self.eye_fovea_blend[i] grid = ((self.coords - c) / R)[None, ...]
output_roi.mul_(1 - blend).add_(image * blend) blend = nn_f.grid_sample(self.eye_fovea_blend[i][None, None, ...], grid) # (1, 1, H:out, W:out)
output.mul_(1 - blend).add_(nn_f.grid_sample(layers[i], grid) * blend)
return output return output
def get_layer_size_in_final_image(self, i: int) -> int: def get_layer_size_in_final_image(self, i: int) -> int:
...@@ -52,7 +61,8 @@ class Foveation(object): ...@@ -52,7 +61,8 @@ class Foveation(object):
k = length_i / length k = length_i / length
return int(math.ceil(self.out_res[0] * k)) return int(math.ceil(self.out_res[0] * k))
def get_layer_region_in_final_image(self, i: int) -> Tuple[slice, slice]: def get_layer_region_in_final_image(self, i: int,
normalized_fovea_center: Tuple[float, float]) -> Tuple[slice, slice]:
""" """
Get region of fovea layer i in final image Get region of fovea layer i in final image
...@@ -60,8 +70,10 @@ class Foveation(object): ...@@ -60,8 +70,10 @@ class Foveation(object):
:return: tuple of slice objects stores the start and end of region in horizontal and vertical :return: tuple of slice objects stores the start and end of region in horizontal and vertical
""" """
roi_size = self.get_layer_size_in_final_image(i) roi_size = self.get_layer_size_in_final_image(i)
roi_offset_y = (self.out_res[0] - roi_size) // 2 roi_center = (int(self.out_res[1] * normalized_fovea_center[0]),
roi_offset_x = (self.out_res[1] - roi_size) // 2 int(self.out_res[0] * normalized_fovea_center[1]))
roi_offset_y = roi_center[1] - roi_size // 2
roi_offset_x = roi_center[0] - roi_size // 2
return (..., return (...,
slice(roi_offset_y, roi_offset_y + roi_size), slice(roi_offset_y, roi_offset_y + roi_size),
slice(roi_offset_x, roi_offset_x + roi_size) slice(roi_offset_x, roi_offset_x + roi_size)
...@@ -76,6 +88,7 @@ class Foveation(object): ...@@ -76,6 +88,7 @@ class Foveation(object):
""" """
size = self.get_layer_size_in_final_image(i) size = self.get_layer_size_in_final_image(i)
R = size / 2 R = size / 2
p = util.MeshGrid((size, size)).to(device=self.device) # (size, size, 2) p = util.MeshGrid((size, size)).to(
device=self.device) # (size, size, 2)
r = torch.norm(p - R, dim=2) # (size, size, 2) r = torch.norm(p - R, dim=2) # (size, size, 2)
return util.SmoothStep(R, R * 0.6, r) return util.SmoothStep(R, R * 0.6, r)
...@@ -24,6 +24,8 @@ class SimplePerf(object): ...@@ -24,6 +24,8 @@ class SimplePerf(object):
return return
self.end_event.record() self.end_event.record()
torch.cuda.synchronize() torch.cuda.synchronize()
print('%s: %.1fms' % (name, self.start_event.elapsed_time(self.end_event))) duration = self.start_event.elapsed_time(self.end_event)
print('%s: %.1fms' % (name, duration))
if not end: if not end:
self.start_event.record() self.start_event.record()
return duration
\ No newline at end of file
...@@ -308,3 +308,74 @@ def view_like(input: torch.Tensor, ref: torch.Tensor) -> torch.Tensor: ...@@ -308,3 +308,74 @@ def view_like(input: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
out_shape = list(ref.size()) out_shape = list(ref.size())
out_shape[-1] = -1 out_shape[-1] = -1
return input.view(out_shape) return input.view(out_shape)
def rgb2ycbcr(input: torch.Tensor) -> torch.Tensor:
"""
Convert input tensor from RGB to YCbCr
:param input ```Tensor(..., 3) | Tensor(..., 3, H, W)```:
:return ```Tensor(..., 3) | Tensor(..., 3, H, W)```:
"""
if input.size(-1) == 3:
r = input[..., 0:1]
g = input[..., 1:2]
b = input[..., 2:3]
dim_c = -1
else:
r = input[..., 0:1, :, :]
g = input[..., 1:2, :, :]
b = input[..., 2:3, :, :]
dim_c = -3
y = r * 0.25678824 + g * 0.50412941 + b * 0.09790588 + 0.0625
cb = r * -0.14822353 + g * -0.29099216 + b * 0.43921569 + 0.5
cr = r * 0.43921569 + g * -0.36778824 + b * -0.07142745 + 0.5
return torch.cat([y, cb, cr], dim_c)
def rgb2ycbcr(input: torch.Tensor) -> torch.Tensor:
"""
Convert input tensor from RGB to YCbCr
:param input ```Tensor(..., 3) | Tensor(..., 3, H, W)```:
:return ```Tensor(..., 3) | Tensor(..., 3, H, W)```:
"""
if input.size(-1) == 3:
r = input[..., 0:1]
g = input[..., 1:2]
b = input[..., 2:3]
dim_c = -1
else:
r = input[..., 0:1, :, :]
g = input[..., 1:2, :, :]
b = input[..., 2:3, :, :]
dim_c = -3
y = r * 0.257 + g * 0.504 + b * 0.098 + 0.0625
cb = r * -0.148 + g * -0.291 + b * 0.439 + 0.5
cr = r * 0.439 + g * -0.368 + b * -0.071 + 0.5
return torch.cat([cb, cr, y], dim_c)
def ycbcr2rgb(input: torch.Tensor) -> torch.Tensor:
"""
Convert input tensor from YCbCr to RGB
:param input ```Tensor(..., 3) | Tensor(..., 3, H, W)```:
:return ```Tensor(..., 3) | Tensor(..., 3, H, W)```:
"""
if input.size(-1) == 3:
cb = input[..., 0:1]
cr = input[..., 1:2]
y = input[..., 2:3]
dim_c = -1
else:
cb = input[..., 0:1, :, :]
cr = input[..., 1:2, :, :]
y = input[..., 2:3, :, :]
dim_c = -3
y = y - 0.0625
cb = cb - 0.5
cr = cr - 0.5
r = y * 1.164 + cr * 1.596
g = y * 1.164 + cb * -0.392 + cr * -0.813
b = y * 1.164 + cb * 2.017
return torch.cat([r, g, b], dim_c)
\ No newline at end of file
...@@ -18,6 +18,13 @@ class CameraParam(object): ...@@ -18,6 +18,13 @@ class CameraParam(object):
self.c = self.c.to(device) self.c = self.c.to(device)
return self return self
def resize(self, res: Tuple[int, int]):
self.f[0] = self.f[0] / self.res[1] * res[1]
self.f[1] = self.f[1] / self.res[0] * res[0]
self.c[0] = self.c[0] / self.res[1] * res[1]
self.c[1] = self.c[1] / self.res[0] * res[0]
self.res = res
def proj(self, p: torch.Tensor) -> torch.Tensor: def proj(self, p: torch.Tensor) -> torch.Tensor:
""" """
Project positions in local space to image plane Project positions in local space to image plane
...@@ -70,8 +77,8 @@ class CameraParam(object): ...@@ -70,8 +77,8 @@ class CameraParam(object):
:return: [description] :return: [description]
""" """
rays = self.get_local_rays(flatten, norm) # (M.., 3) rays = self.get_local_rays(flatten, norm) # (M.., 3)
rays_o, _ = torch.broadcast_tensors( rays_o, _ = torch.broadcast_tensors(t[..., None, :], rays) if flatten \
t[..., None, None, :], rays) # (N.., M.., 3) else torch.broadcast_tensors(t[..., None, None, :], rays) # (N.., M.., 3)
rays_d = trans_vector(rays, r) rays_d = trans_vector(rays, r)
return rays_o, rays_d return rays_o, rays_d
...@@ -87,8 +94,11 @@ class CameraParam(object): ...@@ -87,8 +94,11 @@ class CameraParam(object):
input_is_normalized = bool(input_camera_params.get('normalized')) input_is_normalized = bool(input_camera_params.get('normalized'))
camera_params = {} camera_params = {}
if 'fov' in input_camera_params: if 'fov' in input_camera_params:
camera_params['fx'] = camera_params['fy'] = \ if input_is_normalized:
(1 if input_is_normalized else view_res[0]) / \ camera_params['fy'] = 1 / util.Fov2Length(input_camera_params['fov'])
camera_params['fx'] = camera_params['fy'] / view_res[1] * view_res[0]
else:
camera_params['fx'] = camera_params['fy'] = view_res[0] / \
util.Fov2Length(input_camera_params['fov']) util.Fov2Length(input_camera_params['fov'])
camera_params['fy'] *= -1 camera_params['fy'] *= -1
else: else:
...@@ -114,15 +124,17 @@ def trans_point(p: torch.Tensor, t: torch.Tensor, r: torch.Tensor, inverse=False ...@@ -114,15 +124,17 @@ def trans_point(p: torch.Tensor, t: torch.Tensor, r: torch.Tensor, inverse=False
:param inverse: whether perform inverse transform :param inverse: whether perform inverse transform
:return ```Tensor(M.., N.., 3)```: transformed points :return ```Tensor(M.., N.., 3)```: transformed points
""" """
out_size = list(r.size())[0:-2] + list(p.size())[0:-1] + [3] size_N = list(p.size())[0:-1]
t_size = list(t.size()[0:-1]) + \ size_M = list(r.size())[0:-2]
[1 for _ in range(len(p.size()[0:-1]))] + [3] out_size = size_M + size_N + [3]
t_size = size_M + [1 for _ in range(len(size_N))] + [3]
t = t.view(t_size) t = t.view(t_size)
if not inverse: if not inverse:
r = r.movedim(-1, -2) # Transpose rotation matrices r = r.movedim(-1, -2) # Transpose rotation matrices
else: else:
p = p - t p = p - t
out = torch.matmul(p.flatten(0, -2), r).view(out_size) out = torch.matmul(p.view(size_M + [-1, 3]), r)
out = out.view(out_size)
if not inverse: if not inverse:
out = out + t out = out + t
return out return out
......
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"import os\n",
"import torch\n",
"import matplotlib.pyplot as plt\n",
"import torchvision.transforms.functional as trans_f\n",
"\n",
"sys.path.append(os.path.abspath(sys.path[0] + '/../../'))\n",
"__package__ = \"deep_view_syn.notebook\"\n",
"\n",
"from ..my import util\n",
"\n",
"path_patts = [\n",
" '/home/dengnc/deep_view_syn/data/gas_fovea_2020.12.31/upsampling_test/gt/view_%04d.png',\n",
" '/home/dengnc/deep_view_syn/data/gas_fovea_2020.12.31/fovea_rgb@msl-rgb_e10_fc256x4_d1-50_s16/output/model-epoch_500/train/out_view_%04d.png',\n",
" '/home/dengnc/deep_view_syn/data/gas_fovea_2020.12.31/upsampling_test/input/out_view_%04d.png',\n",
" '/home/dengnc/deep_view_syn/data/gas_fovea_2020.12.31/upsampling_test/output/view_%04d.png'\n",
"]\n",
"titles = [ 'Ground truth', 'Normal', 'Low Res', 'Upsampling']\n",
"show_range = range(5)\n",
"\n",
"#os.chdir('/home/dengnc/deep_view_syn/data/')\n",
"image_seqs = [\n",
" util.ReadImageTensor([path_patt % i for i in show_range])\n",
" for path_patt in path_patts\n",
"]\n",
"\n",
"for i in show_range:\n",
" plt.figure(facecolor='white', figsize=(12, 4))\n",
" plt.suptitle('View %d' % i)\n",
" for j in range(len(image_seqs)):\n",
" plt.subplot(1, len(image_seqs), j + 1)\n",
" plt.title(titles[j])\n",
" util.PlotImageTensor(image_seqs[j][i])\n",
" \n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python",
"nbconvert_exporter": "python",
"version": "3.7.9-final"
},
"orig_nbformat": 2
},
"nbformat": 4,
"nbformat_minor": 2
}
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
...@@ -13,9 +13,9 @@ ...@@ -13,9 +13,9 @@
"import torch\n", "import torch\n",
"from torch import nn\n", "from torch import nn\n",
"import matplotlib.pyplot as plt\n", "import matplotlib.pyplot as plt\n",
"from deeplightfield.data.lf_syn import LightFieldSynDataset\n", "from deep_view_syn.data.lf_syn import LightFieldSynDataset\n",
"from deeplightfield.my import util\n", "from deep_view_syn.my import util\n",
"from deeplightfield.trans_unet import LatentSpaceTransformer\n", "from deep_view_syn.trans_unet import LatentSpaceTransformer\n",
"\n", "\n",
"device = torch.device(\"cuda:2\")\n" "device = torch.device(\"cuda:2\")\n"
] ]
......
This diff is collapsed.
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
"cells": [ "cells": [
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 5,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -10,16 +10,14 @@ ...@@ -10,16 +10,14 @@
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Set CUDA:2 as current device.\n", "Set CUDA:2 as current device.\n",
"Change working directory to /e/dengnc/deeplightfield/data/sp_view_syn_2020.12.31_fovea\n" "Change working directory to /home/dengnc/deep_view_syn/data/sp_view_syn_2020.12.31_fovea\n"
] ]
}, },
{ {
"data": { "data": {
"text/plain": [ "text/plain": "<torch.autograd.grad_mode.set_grad_enabled at 0x7f6824144910>"
"<torch.autograd.grad_mode.set_grad_enabled at 0x7fea6b9c2d50>"
]
}, },
"execution_count": 4, "execution_count": 5,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
...@@ -34,16 +32,17 @@ ...@@ -34,16 +32,17 @@
"\n", "\n",
"\n", "\n",
"sys.path.append(os.path.abspath(sys.path[0] + '/../../'))\n", "sys.path.append(os.path.abspath(sys.path[0] + '/../../'))\n",
"__package__ = \"deep_view_syn.notebook\"\n",
"torch.cuda.set_device(2)\n", "torch.cuda.set_device(2)\n",
"print(\"Set CUDA:%d as current device.\" % torch.cuda.current_device())\n", "print(\"Set CUDA:%d as current device.\" % torch.cuda.current_device())\n",
"\n", "\n",
"from deeplightfield.data.spherical_view_syn import *\n", "from ..data.spherical_view_syn import *\n",
"from deeplightfield.msl_net import MslNet\n", "from ..msl_net import MslNet\n",
"from deeplightfield.configs.spherical_view_syn import SphericalViewSynConfig\n", "from ..configs.spherical_view_syn import SphericalViewSynConfig\n",
"from deeplightfield.my import netio\n", "from ..my import netio\n",
"from deeplightfield.my import util\n", "from ..my import util\n",
"from deeplightfield.my import device\n", "from ..my import device\n",
"from deeplightfield.my import view\n", "from ..my import view\n",
"\n", "\n",
"\n", "\n",
"os.chdir(sys.path[0] + '/../data/sp_view_syn_2020.12.31_fovea')\n", "os.chdir(sys.path[0] + '/../data/sp_view_syn_2020.12.31_fovea')\n",
......
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
"import math\n", "import math\n",
"import matplotlib.pyplot as plt\n", "import matplotlib.pyplot as plt\n",
"import numpy as np\n", "import numpy as np\n",
"from deeplightfield.my import util\n", "from deep_view_syn.my import util\n",
"from deeplightfield.msl_net import *\n", "from deep_view_syn.msl_net import *\n",
"\n", "\n",
"# Select device\n", "# Select device\n",
"torch.cuda.set_device(2)\n", "torch.cuda.set_device(2)\n",
...@@ -121,8 +121,8 @@ ...@@ -121,8 +121,8 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from deeplightfield.data.spherical_view_syn import FastSphericalViewSynDataset\n", "from deep_view_syn.data.spherical_view_syn import FastSphericalViewSynDataset\n",
"from deeplightfield.data.spherical_view_syn import FastDataLoader\n", "from deep_view_syn.data.spherical_view_syn import FastDataLoader\n",
"\n", "\n",
"DATA_DIR = '../data/sp_view_syn_2020.12.28'\n", "DATA_DIR = '../data/sp_view_syn_2020.12.28'\n",
"TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'\n", "TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'\n",
...@@ -149,7 +149,7 @@ ...@@ -149,7 +149,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from deeplightfield.data.spherical_view_syn import SphericalViewSynDataset\n", "from deep_view_syn.data.spherical_view_syn import SphericalViewSynDataset\n",
"\n", "\n",
"DATA_DIR = '../data/sp_view_syn_2020.12.26'\n", "DATA_DIR = '../data/sp_view_syn_2020.12.26'\n",
"TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'\n", "TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'\n",
...@@ -241,7 +241,7 @@ ...@@ -241,7 +241,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from deeplightfield.data.spherical_view_syn import SphericalViewSynDataset\n", "from deep_view_syn.data.spherical_view_syn import SphericalViewSynDataset\n",
"\n", "\n",
"DATA_DIR = '../data/sp_view_syn_2020.12.29_finetrans'\n", "DATA_DIR = '../data/sp_view_syn_2020.12.29_finetrans'\n",
"TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'\n", "TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'\n",
...@@ -304,7 +304,7 @@ ...@@ -304,7 +304,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from deeplightfield.data.spherical_view_syn import SphericalViewSynDataset\n", "from deep_view_syn.data.spherical_view_syn import SphericalViewSynDataset\n",
"\n", "\n",
"DATA_DIR = '../data/sp_view_syn_2020.12.26_rotonly'\n", "DATA_DIR = '../data/sp_view_syn_2020.12.26_rotonly'\n",
"TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'\n", "TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'\n",
...@@ -381,9 +381,9 @@ ...@@ -381,9 +381,9 @@
"source": [ "source": [
"import ipywidgets as widgets # 控件库\n", "import ipywidgets as widgets # 控件库\n",
"from IPython.display import display # 显示控件的方法\n", "from IPython.display import display # 显示控件的方法\n",
"from deeplightfield.data.spherical_view_syn import SphericalViewSynDataset\n", "from deep_view_syn.data.spherical_view_syn import SphericalViewSynDataset\n",
"from deeplightfield.spher_net import SpherNet\n", "from deep_view_syn.spher_net import SpherNet\n",
"from deeplightfield.my import netio\n", "from deep_view_syn.my import netio\n",
"\n", "\n",
"DATA_DIR = '../data/sp_view_syn_2020.12.28_small'\n", "DATA_DIR = '../data/sp_view_syn_2020.12.28_small'\n",
"DATA_DESC_FILE = DATA_DIR + '/train.json'\n", "DATA_DESC_FILE = DATA_DIR + '/train.json'\n",
......
This diff is collapsed.
from __future__ import print_function
import sys
import torch
import torchvision
import torch.backends.cudnn as cudnn
import torch.nn as nn
from math import log10
from my.progress_bar import progress_bar
from my import color_mode
class Net(torch.nn.Module):
def __init__(self, color, base_filter):
super(Net, self).__init__()
self.color = color
if color == color_mode.GRAY:
self.layers = torch.nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=base_filter, kernel_size=9, stride=1, padding=4, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=base_filter, out_channels=base_filter // 2, kernel_size=1, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=base_filter // 2, out_channels=1, kernel_size=5, stride=1, padding=2, bias=True),
#nn.PixelShuffle(upscale_factor)
)
else:
self.net_1 = torch.nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=base_filter, kernel_size=9, stride=1, padding=4, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=base_filter, out_channels=base_filter // 2, kernel_size=1, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=base_filter // 2, out_channels=1, kernel_size=5, stride=1, padding=2, bias=True),
#nn.PixelShuffle(upscale_factor)
)
self.net_2 = torch.nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=base_filter, kernel_size=9, stride=1, padding=4, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=base_filter, out_channels=base_filter // 2, kernel_size=1, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=base_filter // 2, out_channels=1, kernel_size=5, stride=1, padding=2, bias=True),
#nn.PixelShuffle(upscale_factor)
)
self.net_3 = torch.nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=base_filter, kernel_size=9, stride=1, padding=4, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=base_filter, out_channels=base_filter // 2, kernel_size=1, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=base_filter // 2, out_channels=1, kernel_size=5, stride=1, padding=2, bias=True),
#nn.PixelShuffle(upscale_factor)
)
def forward(self, x):
if self.color == color_mode.GRAY:
out = self.layers(x)
else:
out = torch.cat([
self.net_1(x[:, 0:1]),
self.net_2(x[:, 1:2]),
self.net_3(x[:, 2:3])
], dim=1)
return out
def weight_init(self, mean, std):
for m in self._modules:
normal_init(self._modules[m], mean, std)
def normal_init(m, mean, std):
if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d):
m.weight.data.normal_(mean, std)
m.bias.data.zero_()
class Solver(object):
def __init__(self, config, training_loader, testing_loader, writer=None):
super(Solver, self).__init__()
self.CUDA = torch.cuda.is_available()
self.device = torch.device('cuda' if self.CUDA else 'cpu')
self.model = None
self.lr = config.lr
self.nEpochs = config.nEpochs
self.criterion = None
self.optimizer = None
self.scheduler = None
self.seed = config.seed
self.upscale_factor = config.upscale_factor
self.training_loader = training_loader
self.testing_loader = testing_loader
self.writer = writer
self.color = config.color
def build_model(self):
self.model = Net(color=self.color, base_filter=64).to(self.device)
self.model.weight_init(mean=0.0, std=0.01)
self.criterion = torch.nn.MSELoss()
torch.manual_seed(self.seed)
if self.CUDA:
torch.cuda.manual_seed(self.seed)
cudnn.benchmark = True
self.criterion.cuda()
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[50, 75, 100], gamma=0.5)
def save_model(self):
model_out_path = "model_path.pth"
torch.save(self.model, model_out_path)
print("Checkpoint saved to {}".format(model_out_path))
def train(self, epoch, iters, channels = None):
self.model.train()
train_loss = 0
for batch_num, (_, data, target) in enumerate(self.training_loader):
if channels:
data = data[..., channels, :, :]
target = target[..., channels, :, :]
data =data.to(self.device)
target = target.to(self.device)
self.optimizer.zero_grad()
out = self.model(data)
loss = self.criterion(out, target)
train_loss += loss.item()
loss.backward()
self.optimizer.step()
sys.stdout.write('Epoch %d: ' % epoch)
progress_bar(batch_num, len(self.training_loader), 'Loss: %.4f' % (train_loss / (batch_num + 1)))
if self.writer:
self.writer.add_scalar("loss", loss, iters)
if iters % 100 == 0:
output_vs_gt = torch.stack([out, target], 1) \
.flatten(0, 1).detach()
self.writer.add_image(
"Output_vs_gt",
torchvision.utils.make_grid(output_vs_gt, nrow=2).cpu().numpy(),
iters)
iters += 1
print(" Average Loss: {:.4f}".format(train_loss / len(self.training_loader)))
return iters
\ No newline at end of file
import sys import sys
sys.path.append('/e/dengnc') sys.path.append('/e/dengnc')
__package__ = "deeplightfield" __package__ = "deep_view_syn"
import os import os
import torch import torch
......
from __future__ import print_function
import argparse
import os
import sys
import torch
import torch.nn.functional as nn_f
from tensorboardX.writer import SummaryWriter
sys.path.append(os.path.abspath(sys.path[0] + '/../'))
__package__ = "deep_view_syn"
# ===========================================================
# Training settings
# ===========================================================
parser = argparse.ArgumentParser(description='PyTorch Super Res Example')
# hyper-parameters
parser.add_argument('--device', type=int, default=3,
help='Which CUDA device to use.')
parser.add_argument('--batchSize', type=int, default=1,
help='training batch size')
parser.add_argument('--testBatchSize', type=int,
default=1, help='testing batch size')
parser.add_argument('--nEpochs', type=int, default=20,
help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.01,
help='Learning Rate. Default=0.01')
parser.add_argument('--seed', type=int, default=123,
help='random seed to use. Default=123')
parser.add_argument('--dataset', type=str, required=True,
help='dataset directory')
parser.add_argument('--test', type=str, help='path of model to test')
parser.add_argument('--testOutPatt', type=str, help='test output path pattern')
parser.add_argument('--color', type=str, default='rgb',
help='color')
# model configuration
parser.add_argument('--upscale_factor', '-uf', type=int,
default=2, help="super resolution upscale factor")
#parser.add_argument('--model', '-m', type=str, default='srgan', help='choose which model is going to use')
args = parser.parse_args()
# Select device
torch.cuda.set_device(args.device)
print("Set CUDA:%d as current device." % torch.cuda.current_device())
from .my import util
from .my import netio
from .my import device
from .my import color_mode
from .refine_net import *
from .data.upsampling import UpsamplingDataset
from .data.loader import FastDataLoader
os.chdir(args.dataset)
print('Change working directory to ' + os.getcwd())
run_dir = 'run/'
args.color = color_mode.from_str(args.color)
def train():
util.CreateDirIfNeed(run_dir)
train_set = UpsamplingDataset('.', 'input/out_view_%04d.png',
'gt/view_%04d.png', color=args.color)
training_data_loader = FastDataLoader(dataset=train_set,
batch_size=args.batchSize,
shuffle=True,
drop_last=False)
trainer = Solver(args, training_data_loader, training_data_loader,
SummaryWriter(run_dir))
trainer.build_model()
iters = 0
for epoch in range(1, args.nEpochs + 1):
print("\n===> Epoch {} starts:".format(epoch))
iters = trainer.train(epoch, iters,
channels=slice(2, 3) if args.color == color_mode.YCbCr
else None)
netio.SaveNet(run_dir + 'model-epoch_%d.pth' % args.nEpochs, trainer.model)
def test():
util.CreateDirIfNeed(os.path.dirname(args.testOutPatt))
train_set = UpsamplingDataset(
'.', 'input/out_view_%04d.png', None, color=args.color)
training_data_loader = FastDataLoader(dataset=train_set,
batch_size=args.testBatchSize,
shuffle=False,
drop_last=False)
trainer = Solver(args, training_data_loader, training_data_loader,
SummaryWriter(run_dir))
trainer.build_model()
netio.LoadNet(args.test, trainer.model)
for idx, input, _ in training_data_loader:
if args.color == color_mode.YCbCr:
output_y = trainer.model(input[:, -1:])
output_cbcr = input[:, 0:2]
output = util.ycbcr2rgb(torch.cat([output_cbcr, output_y], -3))
else:
output = trainer.model(input)
util.WriteImageTensor(output, args.testOutPatt % idx)
def main():
if (args.test):
test()
else:
train()
if __name__ == '__main__':
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment