Commit 3554ba52 authored by Nianchen Deng

sync

parent f7038e26
import sys
import os
sys.path.append(os.path.abspath(sys.path[0] + '/../'))
__package__ = "deeplightfield"
__package__ = "deep_view_syn"
import argparse
from PIL import Image
......
......@@ -4,6 +4,8 @@ import torch.nn as nn
from .my import net_modules
from .my import util
from .my import device
from .my import color_mode
rand_gen = torch.Generator(device=device.GetDevice())
rand_gen.manual_seed(torch.seed())
......@@ -187,7 +189,7 @@ class Sampler(nn.Module):
class MslNet(nn.Module):
def __init__(self, fc_params, sampler_params,
gray=False,
color: int = color_mode.RGB,
encode_to_dim: int = 0,
export_mode: bool = False):
"""
......@@ -203,11 +205,24 @@ class MslNet(nn.Module):
self.input_encoder = net_modules.InputEncoder.Get(
encode_to_dim, self.in_chns)
fc_params['in_chns'] = self.input_encoder.out_dim
fc_params['out_chns'] = 2 if gray else 4
fc_params['out_chns'] = 2 if color == color_mode.GRAY else 4
self.sampler = Sampler(**sampler_params)
self.net = net_modules.FcNet(**fc_params)
self.rendering = Rendering()
self.export_mode = export_mode
if color == color_mode.YCbCr:
self.net1 = net_modules.FcNet(
in_chns=fc_params['in_chns'],
out_chns=fc_params['nf'] + 2,
nf=fc_params['nf'],
n_layers=fc_params['n_layers'] - 2)
self.net2 = net_modules.FcNet(
in_chns=fc_params['nf'],
out_chns=2,
nf=fc_params['nf'],
n_layers=1)
self.net = None
else:
self.net = net_modules.FcNet(**fc_params)
def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor,
ret_depth: bool = False) -> torch.Tensor:
......@@ -221,13 +236,23 @@ class MslNet(nn.Module):
coords, depths = self.sampler(rays_o, rays_d)
encoded = self.input_encoder(coords)
if not self.net:
mid_output = self.net1(encoded)
net2_output = self.net2(mid_output[..., :-2])
raw = torch.cat([
mid_output[..., -2:],
net2_output
], -1)
else:
raw = self.net(encoded)
if self.export_mode:
colors, alphas = self.rendering.raw2color(self.net(encoded), depths)
colors, alphas = self.rendering.raw2color(raw, depths)
return torch.cat([colors, alphas[..., None]], -1)
if ret_depth:
color_map, _, _, _, depth_map = self.rendering(
self.net(encoded), depths, ret_extra=True)
raw, depths, ret_extra=True)
return color_map, depth_map
return self.rendering(self.net(encoded), depths)
return self.rendering(raw, depths)
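In the YCbCr configuration the output is produced by two heads: net1 emits fc_params['nf'] features plus 2 channels, and the single-layer net2 maps those features to 2 more channels; forward then concatenates the two pairs into the same 4-channel raw tensor the single-net RGB path feeds to Rendering. A shape-only sketch (nf = 256 is a hypothetical value):
# Shape sketch of the two-head YCbCr path, assuming nf = 256:
#   encoded                                   -> (..., input_encoder.out_dim)
#   mid = net1(encoded)                       -> (..., 256 + 2)
#   net2(mid[..., :-2])                       -> (..., 2)
#   raw = cat([mid[..., -2:], net2_out], -1)  -> (..., 4), consumed by self.rendering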
......@@ -209,7 +209,8 @@ class MslNet(nn.Module):
self.rendering = Rendering()
self.export_mode = export_mode
def forward(self, fc_input: torch.Tensor,
def forward(self, view_centers: torch.Tensor, view_rots: torch.Tensor, local_rays: torch.Tensor,
ret_depth: bool = False) -> torch.Tensor:
"""
rays -> colors
......@@ -218,10 +219,13 @@ class MslNet(nn.Module):
:param rays_d ```Tensor(B, 3)```: rays' direction
:return: ```Tensor(B, C)```, inferred images/pixels
"""
depths = torch.ones(4096, 16, device="cuda")
rays_o = local_rays * 0 + view_centers
rays_d = torch.matmul(local_rays.flatten(0, -2), view_rots).view(local_rays.size())
coords, depths = self.sampler(rays_o, rays_d)
encoded = self.input_encoder(coords)
if self.export_mode:
colors, alphas = self.rendering.raw2color(self.net(fc_input), depths)
colors, alphas = self.rendering.raw2color(self.net(encoded), depths)
return torch.cat([colors, alphas[..., None]], -1)
if ret_depth:
......
RGB = 0
GRAY = 1
YCbCr = 2
def to_str(color_mode):
return "gray" if color_mode == GRAY \
else ("ybr" if color_mode == YCbCr
else "rgb")
def from_str(color_str):
return GRAY if color_str == 'gray' \
else (YCbCr if color_str == 'ybr'
else RGB)
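A minimal usage sketch of the constants and the string round-trip above (runs inside this module's namespace; any unrecognized string falls back to RGB):
mode = from_str('ybr')              # -> YCbCr (2)
assert to_str(mode) == 'ybr'
assert from_str('gray') == GRAY
assert from_str('anything-else') == RGB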
import torch
import torch.nn.functional as nn_f
from . import view
def get_warp(rays_o, rays_d, depthmap, tgt_o, tgt_r, tgt_cam):
print(rays_o.size(), rays_d.size(), depthmap.size())
pcloud = rays_o + rays_d * depthmap[..., None]
print(rays_o.size(), rays_d.size(), depthmap.size(), pcloud.size())
pcloud_in_tgt = view.trans_point(
pcloud, tgt_o, tgt_r, inverse=True)
print(pcloud_in_tgt.size())
pixel_positions = tgt_cam.proj(pcloud_in_tgt)
pixel_positions[..., 0] /= tgt_cam.res[1] * 0.5
pixel_positions[..., 1] /= tgt_cam.res[0] * 0.5
pixel_positions -= 1
return pixel_positions
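The last three lines map projected pixel coordinates from [0, res] into [-1, 1], the range nn_f.grid_sample expects. A standalone check of that mapping with a hypothetical 640x480 target camera:
import torch
res = (480, 640)                                            # (H, W), hypothetical
px = torch.tensor([[0., 0.], [640., 480.], [320., 240.]])   # (x, y) pixel positions
px[..., 0] /= res[1] * 0.5
px[..., 1] /= res[0] * 0.5
px -= 1
print(px)   # [[-1, -1], [1, 1], [0, 0]]: image corners and center in grid_sample coordinates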
def refine(image, depthmap, rays_o, rays_d, bounds_img, bounds_o,
bounds_r, ref_cam: view.CameraParam, net, is_lr):
if is_lr:
image = nn_f.upsample(
image[None, ...], scale_factor=2, mode='bicubic')[0]
depthmap = nn_f.upsample(
depthmap[None, None, ...], scale_factor=2, mode='bicubic')[0, 0]
bounds_rays_o, bounds_rays_d = ref_cam.get_global_rays(
bounds_o, bounds_r, flatten=True)
bounds_inferred = torch.stack([
net(bounds_rays_o[i], bounds_rays_d[i]).view(
ref_cam.res[0], ref_cam.res[1], -1).permute(2, 0, 1)
for i in range(bounds_img.size(0))
], 0)
bounds_diff = (bounds_img - bounds_inferred) / (bounds_inferred + 1e-5)
bounds_warp = get_warp(rays_o, rays_d, depthmap,
bounds_o, bounds_r, ref_cam)
warped_diff = nn_f.grid_sample(bounds_diff, bounds_warp)
print(bounds_warp.size(), warped_diff.size())
avg_diff = torch.mean(warped_diff, 0)
return image * (1 + avg_diff)
......@@ -17,12 +17,16 @@ class Foveation(object):
self._gen_layer_blendmap(i)
for i in range(self.n_layers - 1)
] # blend maps of fovea layers
self.coords = util.MeshGrid(out_res).to(device=device)
def to(self, device):
self.eye_fovea_blend = [x.to(device=device) for x in self.eye_fovea_blend]
self.eye_fovea_blend = [x.to(device=device)
for x in self.eye_fovea_blend]
self.coords = self.coords.to(device=device)
return self
def synthesis(self, layers: List[torch.Tensor]) -> torch.Tensor:
def synthesis(self, layers: List[torch.Tensor],
normalized_fovea_center: Tuple[float, float]) -> torch.Tensor:
"""
Generate foveated retinal image by blending fovea layers
**Note: the current implementation only supports two fovea layers**
......@@ -32,12 +36,17 @@ class Foveation(object):
"""
output: torch.Tensor = nn_f.interpolate(layers[-1], self.out_res,
mode='bilinear', align_corners=False)
c = torch.tensor([
normalized_fovea_center[0] * self.out_res[1],
normalized_fovea_center[1] * self.out_res[0]
], device=self.coords.device)
for i in range(self.n_layers - 2, -1, -1):
output_roi = output[self.get_layer_region_in_final_image(i)]
image = nn_f.interpolate(layers[i], output_roi.size()[-2:],
mode='bilinear', align_corners=False)
blend = self.eye_fovea_blend[i]
output_roi.mul_(1 - blend).add_(image * blend)
if layers[i] is None:
continue
R = self.get_layer_size_in_final_image(i) / 2
grid = ((self.coords - c) / R)[None, ...]
blend = nn_f.grid_sample(self.eye_fovea_blend[i][None, None, ...], grid) # (1, 1, H:out, W:out)
output.mul_(1 - blend).add_(nn_f.grid_sample(layers[i], grid) * blend)
return output
def get_layer_size_in_final_image(self, i: int) -> int:
......@@ -52,7 +61,8 @@ class Foveation(object):
k = length_i / length
return int(math.ceil(self.out_res[0] * k))
def get_layer_region_in_final_image(self, i: int) -> Tuple[slice, slice]:
def get_layer_region_in_final_image(self, i: int,
normalized_fovea_center: Tuple[float, float]) -> Tuple[slice, slice]:
"""
Get region of fovea layer i in final image
......@@ -60,8 +70,10 @@ class Foveation(object):
:return: tuple of slice objects storing the start and end of the region along the horizontal and vertical axes
"""
roi_size = self.get_layer_size_in_final_image(i)
roi_offset_y = (self.out_res[0] - roi_size) // 2
roi_offset_x = (self.out_res[1] - roi_size) // 2
roi_center = (int(self.out_res[1] * normalized_fovea_center[0]),
int(self.out_res[0] * normalized_fovea_center[1]))
roi_offset_y = roi_center[1] - roi_size // 2
roi_offset_x = roi_center[0] - roi_size // 2
return (...,
slice(roi_offset_y, roi_offset_y + roi_size),
slice(roi_offset_x, roi_offset_x + roi_size)
......@@ -76,6 +88,7 @@ class Foveation(object):
"""
size = self.get_layer_size_in_final_image(i)
R = size / 2
p = util.MeshGrid((size, size)).to(device=self.device) # (size, size, 2)
p = util.MeshGrid((size, size)).to(
device=self.device) # (size, size, 2)
r = torch.norm(p - R, dim=2) # (size, size)
return util.SmoothStep(R, R * 0.6, r)
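The blend map is radially symmetric: SmoothStep(R, 0.6R, r) is 1 for radii below 0.6R and falls to 0 at the layer boundary R. A standalone sketch of that profile, assuming util.SmoothStep is the usual Hermite smoothstep between its two edges:
import torch

def smoothstep(edge0, edge1, x):   # stand-in for util.SmoothStep (assumed to be a Hermite smoothstep)
    t = ((x - edge0) / (edge1 - edge0)).clamp(0, 1)
    return t * t * (3 - 2 * t)

R = 128.0
r = torch.linspace(0, R, 5)        # radii measured from the layer center
print(smoothstep(R, R * 0.6, r))   # ~[1, 1, 1, 0.68, 0]: opaque core, smooth falloff toward the rim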
......@@ -24,6 +24,8 @@ class SimplePerf(object):
return
self.end_event.record()
torch.cuda.synchronize()
print('%s: %.1fms' % (name, self.start_event.elapsed_time(self.end_event)))
duration = self.start_event.elapsed_time(self.end_event)
print('%s: %.1fms' % (name, duration))
if not end:
self.start_event.record()
return duration
\ No newline at end of file
......@@ -308,3 +308,74 @@ def view_like(input: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
out_shape = list(ref.size())
out_shape[-1] = -1
return input.view(out_shape)
def rgb2ycbcr(input: torch.Tensor) -> torch.Tensor:
"""
Convert input tensor from RGB to YCbCr
:param input ```Tensor(..., 3) | Tensor(..., 3, H, W)```:
:return ```Tensor(..., 3) | Tensor(..., 3, H, W)```:
"""
if input.size(-1) == 3:
r = input[..., 0:1]
g = input[..., 1:2]
b = input[..., 2:3]
dim_c = -1
else:
r = input[..., 0:1, :, :]
g = input[..., 1:2, :, :]
b = input[..., 2:3, :, :]
dim_c = -3
y = r * 0.25678824 + g * 0.50412941 + b * 0.09790588 + 0.0625
cb = r * -0.14822353 + g * -0.29099216 + b * 0.43921569 + 0.5
cr = r * 0.43921569 + g * -0.36778824 + b * -0.07142745 + 0.5
return torch.cat([y, cb, cr], dim_c)
def rgb2ycbcr(input: torch.Tensor) -> torch.Tensor:
"""
Convert input tensor from RGB to YCbCr
:param input ```Tensor(..., 3) | Tensor(..., 3, H, W)```:
:return ```Tensor(..., 3) | Tensor(..., 3, H, W)```:
"""
if input.size(-1) == 3:
r = input[..., 0:1]
g = input[..., 1:2]
b = input[..., 2:3]
dim_c = -1
else:
r = input[..., 0:1, :, :]
g = input[..., 1:2, :, :]
b = input[..., 2:3, :, :]
dim_c = -3
y = r * 0.257 + g * 0.504 + b * 0.098 + 0.0625
cb = r * -0.148 + g * -0.291 + b * 0.439 + 0.5
cr = r * 0.439 + g * -0.368 + b * -0.071 + 0.5
return torch.cat([cb, cr, y], dim_c)
def ycbcr2rgb(input: torch.Tensor) -> torch.Tensor:
"""
Convert input tensor from YCbCr to RGB
:param input ```Tensor(..., 3) | Tensor(..., 3, H, W)```:
:return ```Tensor(..., 3) | Tensor(..., 3, H, W)```:
"""
if input.size(-1) == 3:
cb = input[..., 0:1]
cr = input[..., 1:2]
y = input[..., 2:3]
dim_c = -1
else:
cb = input[..., 0:1, :, :]
cr = input[..., 1:2, :, :]
y = input[..., 2:3, :, :]
dim_c = -3
y = y - 0.0625
cb = cb - 0.5
cr = cr - 0.5
r = y * 1.164 + cr * 1.596
g = y * 1.164 + cb * -0.392 + cr * -0.813
b = y * 1.164 + cb * 2.017
return torch.cat([r, g, b], dim_c)
\ No newline at end of file
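A quick round-trip sanity check for ycbcr2rgb and the rgb2ycbcr variant directly above it, whose cb, cr, y channel order it matches (standalone; the residual error is nonzero because the coefficients are rounded to three decimals):
import torch
rgb = torch.rand(8, 3)                 # arbitrary pixels in [0, 1]
back = ycbcr2rgb(rgb2ycbcr(rgb))
print((back - rgb).abs().max())        # expected to be small, roughly on the order of 1e-3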
......@@ -18,6 +18,13 @@ class CameraParam(object):
self.c = self.c.to(device)
return self
def resize(self, res: Tuple[int, int]):
self.f[0] = self.f[0] / self.res[1] * res[1]
self.f[1] = self.f[1] / self.res[0] * res[0]
self.c[0] = self.c[0] / self.res[1] * res[1]
self.c[1] = self.c[1] / self.res[0] * res[0]
self.res = res
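resize scales the focal lengths and principal point linearly with the resolution change along each axis: f[0] and c[0] follow the width (res[1]), f[1] and c[1] follow the height (res[0]). A numeric sketch of the rule with hypothetical values:
old_res, new_res = (1440, 1600), (720, 800)   # stored as (H, W); halving both dimensions
fx, cy = 1000.0, 700.0
fx_new = fx / old_res[1] * new_res[1]         # 500.0 -- fx scales with width
cy_new = cy / old_res[0] * new_res[0]         # 350.0 -- cy scales with height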
def proj(self, p: torch.Tensor) -> torch.Tensor:
"""
Project positions in local space to image plane
......@@ -70,8 +77,8 @@ class CameraParam(object):
:return: [description]
"""
rays = self.get_local_rays(flatten, norm) # (M.., 3)
rays_o, _ = torch.broadcast_tensors(
t[..., None, None, :], rays) # (N.., M.., 3)
rays_o, _ = torch.broadcast_tensors(t[..., None, :], rays) if flatten \
else torch.broadcast_tensors(t[..., None, None, :], rays) # (N.., M.., 3)
rays_d = trans_vector(rays, r)
return rays_o, rays_d
......@@ -87,8 +94,11 @@ class CameraParam(object):
input_is_normalized = bool(input_camera_params.get('normalized'))
camera_params = {}
if 'fov' in input_camera_params:
camera_params['fx'] = camera_params['fy'] = \
(1 if input_is_normalized else view_res[0]) / \
if input_is_normalized:
camera_params['fy'] = 1 / util.Fov2Length(input_camera_params['fov'])
camera_params['fx'] = camera_params['fy'] / view_res[1] * view_res[0]
else:
camera_params['fx'] = camera_params['fy'] = view_res[0] / \
util.Fov2Length(input_camera_params['fov'])
camera_params['fy'] *= -1
else:
......@@ -114,15 +124,17 @@ def trans_point(p: torch.Tensor, t: torch.Tensor, r: torch.Tensor, inverse=False
:param inverse: whether perform inverse transform
:return ```Tensor(M.., N.., 3)```: transformed points
"""
out_size = list(r.size())[0:-2] + list(p.size())[0:-1] + [3]
t_size = list(t.size()[0:-1]) + \
[1 for _ in range(len(p.size()[0:-1]))] + [3]
size_N = list(p.size())[0:-1]
size_M = list(r.size())[0:-2]
out_size = size_M + size_N + [3]
t_size = size_M + [1 for _ in range(len(size_N))] + [3]
t = t.view(t_size)
if not inverse:
r = r.movedim(-1, -2) # Transpose rotation matrices
else:
p = p - t
out = torch.matmul(p.flatten(0, -2), r).view(out_size)
out = torch.matmul(p.view(size_M + [-1, 3]), r)
out = out.view(out_size)
if not inverse:
out = out + t
return out
......
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"import os\n",
"import torch\n",
"import matplotlib.pyplot as plt\n",
"import torchvision.transforms.functional as trans_f\n",
"\n",
"sys.path.append(os.path.abspath(sys.path[0] + '/../../'))\n",
"__package__ = \"deep_view_syn.notebook\"\n",
"\n",
"from ..my import util\n",
"\n",
"path_patts = [\n",
" '/home/dengnc/deep_view_syn/data/gas_fovea_2020.12.31/upsampling_test/gt/view_%04d.png',\n",
" '/home/dengnc/deep_view_syn/data/gas_fovea_2020.12.31/fovea_rgb@msl-rgb_e10_fc256x4_d1-50_s16/output/model-epoch_500/train/out_view_%04d.png',\n",
" '/home/dengnc/deep_view_syn/data/gas_fovea_2020.12.31/upsampling_test/input/out_view_%04d.png',\n",
" '/home/dengnc/deep_view_syn/data/gas_fovea_2020.12.31/upsampling_test/output/view_%04d.png'\n",
"]\n",
"titles = [ 'Ground truth', 'Normal', 'Low Res', 'Upsampling']\n",
"show_range = range(5)\n",
"\n",
"#os.chdir('/home/dengnc/deep_view_syn/data/')\n",
"image_seqs = [\n",
" util.ReadImageTensor([path_patt % i for i in show_range])\n",
" for path_patt in path_patts\n",
"]\n",
"\n",
"for i in show_range:\n",
" plt.figure(facecolor='white', figsize=(12, 4))\n",
" plt.suptitle('View %d' % i)\n",
" for j in range(len(image_seqs)):\n",
" plt.subplot(1, len(image_seqs), j + 1)\n",
" plt.title(titles[j])\n",
" util.PlotImageTensor(image_seqs[j][i])\n",
" \n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python",
"nbconvert_exporter": "python",
"version": "3.7.9-final"
},
"orig_nbformat": 2
},
"nbformat": 4,
"nbformat_minor": 2
}
\ No newline at end of file
......@@ -13,9 +13,9 @@
"import torch\n",
"from torch import nn\n",
"import matplotlib.pyplot as plt\n",
"from deeplightfield.data.lf_syn import LightFieldSynDataset\n",
"from deeplightfield.my import util\n",
"from deeplightfield.trans_unet import LatentSpaceTransformer\n",
"from deep_view_syn.data.lf_syn import LightFieldSynDataset\n",
"from deep_view_syn.my import util\n",
"from deep_view_syn.trans_unet import LatentSpaceTransformer\n",
"\n",
"device = torch.device(\"cuda:2\")\n"
]
......
......@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"metadata": {},
"outputs": [
{
......@@ -10,16 +10,14 @@
"output_type": "stream",
"text": [
"Set CUDA:2 as current device.\n",
"Change working directory to /e/dengnc/deeplightfield/data/sp_view_syn_2020.12.31_fovea\n"
"Change working directory to /home/dengnc/deep_view_syn/data/sp_view_syn_2020.12.31_fovea\n"
]
},
{
"data": {
"text/plain": [
"<torch.autograd.grad_mode.set_grad_enabled at 0x7fea6b9c2d50>"
]
"text/plain": "<torch.autograd.grad_mode.set_grad_enabled at 0x7f6824144910>"
},
"execution_count": 4,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
......@@ -34,16 +32,17 @@
"\n",
"\n",
"sys.path.append(os.path.abspath(sys.path[0] + '/../../'))\n",
"__package__ = \"deep_view_syn.notebook\"\n",
"torch.cuda.set_device(2)\n",
"print(\"Set CUDA:%d as current device.\" % torch.cuda.current_device())\n",
"\n",
"from deeplightfield.data.spherical_view_syn import *\n",
"from deeplightfield.msl_net import MslNet\n",
"from deeplightfield.configs.spherical_view_syn import SphericalViewSynConfig\n",
"from deeplightfield.my import netio\n",
"from deeplightfield.my import util\n",
"from deeplightfield.my import device\n",
"from deeplightfield.my import view\n",
"from ..data.spherical_view_syn import *\n",
"from ..msl_net import MslNet\n",
"from ..configs.spherical_view_syn import SphericalViewSynConfig\n",
"from ..my import netio\n",
"from ..my import util\n",
"from ..my import device\n",
"from ..my import view\n",
"\n",
"\n",
"os.chdir(sys.path[0] + '/../data/sp_view_syn_2020.12.31_fovea')\n",
......
......@@ -14,8 +14,8 @@
"import math\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from deeplightfield.my import util\n",
"from deeplightfield.msl_net import *\n",
"from deep_view_syn.my import util\n",
"from deep_view_syn.msl_net import *\n",
"\n",
"# Select device\n",
"torch.cuda.set_device(2)\n",
......@@ -121,8 +121,8 @@
"metadata": {},
"outputs": [],
"source": [
"from deeplightfield.data.spherical_view_syn import FastSphericalViewSynDataset\n",
"from deeplightfield.data.spherical_view_syn import FastDataLoader\n",
"from deep_view_syn.data.spherical_view_syn import FastSphericalViewSynDataset\n",
"from deep_view_syn.data.spherical_view_syn import FastDataLoader\n",
"\n",
"DATA_DIR = '../data/sp_view_syn_2020.12.28'\n",
"TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'\n",
......@@ -149,7 +149,7 @@
"metadata": {},
"outputs": [],
"source": [
"from deeplightfield.data.spherical_view_syn import SphericalViewSynDataset\n",
"from deep_view_syn.data.spherical_view_syn import SphericalViewSynDataset\n",
"\n",
"DATA_DIR = '../data/sp_view_syn_2020.12.26'\n",
"TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'\n",
......@@ -241,7 +241,7 @@
"metadata": {},
"outputs": [],
"source": [
"from deeplightfield.data.spherical_view_syn import SphericalViewSynDataset\n",
"from deep_view_syn.data.spherical_view_syn import SphericalViewSynDataset\n",
"\n",
"DATA_DIR = '../data/sp_view_syn_2020.12.29_finetrans'\n",
"TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'\n",
......@@ -304,7 +304,7 @@
"metadata": {},
"outputs": [],
"source": [
"from deeplightfield.data.spherical_view_syn import SphericalViewSynDataset\n",
"from deep_view_syn.data.spherical_view_syn import SphericalViewSynDataset\n",
"\n",
"DATA_DIR = '../data/sp_view_syn_2020.12.26_rotonly'\n",
"TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'\n",
......@@ -381,9 +381,9 @@
"source": [
"import ipywidgets as widgets # 控件库\n",
"from IPython.display import display # 显示控件的方法\n",
"from deeplightfield.data.spherical_view_syn import SphericalViewSynDataset\n",
"from deeplightfield.spher_net import SpherNet\n",
"from deeplightfield.my import netio\n",
"from deep_view_syn.data.spherical_view_syn import SphericalViewSynDataset\n",
"from deep_view_syn.spher_net import SpherNet\n",
"from deep_view_syn.my import netio\n",
"\n",
"DATA_DIR = '../data/sp_view_syn_2020.12.28_small'\n",
"DATA_DESC_FILE = DATA_DIR + '/train.json'\n",
......
from __future__ import print_function
import sys
import torch
import torchvision
import torch.backends.cudnn as cudnn
import torch.nn as nn
from math import log10
from my.progress_bar import progress_bar
from my import color_mode
class Net(torch.nn.Module):
def __init__(self, color, base_filter):
super(Net, self).__init__()
self.color = color
if color == color_mode.GRAY:
self.layers = torch.nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=base_filter, kernel_size=9, stride=1, padding=4, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=base_filter, out_channels=base_filter // 2, kernel_size=1, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=base_filter // 2, out_channels=1, kernel_size=5, stride=1, padding=2, bias=True),
#nn.PixelShuffle(upscale_factor)
)
else:
self.net_1 = torch.nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=base_filter, kernel_size=9, stride=1, padding=4, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=base_filter, out_channels=base_filter // 2, kernel_size=1, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=base_filter // 2, out_channels=1, kernel_size=5, stride=1, padding=2, bias=True),
#nn.PixelShuffle(upscale_factor)
)
self.net_2 = torch.nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=base_filter, kernel_size=9, stride=1, padding=4, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=base_filter, out_channels=base_filter // 2, kernel_size=1, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=base_filter // 2, out_channels=1, kernel_size=5, stride=1, padding=2, bias=True),
#nn.PixelShuffle(upscale_factor)
)
self.net_3 = torch.nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=base_filter, kernel_size=9, stride=1, padding=4, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=base_filter, out_channels=base_filter // 2, kernel_size=1, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=base_filter // 2, out_channels=1, kernel_size=5, stride=1, padding=2, bias=True),
#nn.PixelShuffle(upscale_factor)
)
def forward(self, x):
if self.color == color_mode.GRAY:
out = self.layers(x)
else:
out = torch.cat([
self.net_1(x[:, 0:1]),
self.net_2(x[:, 1:2]),
self.net_3(x[:, 2:3])
], dim=1)
return out
def weight_init(self, mean, std):
for m in self._modules:
normal_init(self._modules[m], mean, std)
def normal_init(m, mean, std):
if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d):
m.weight.data.normal_(mean, std)
m.bias.data.zero_()
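A minimal smoke test of the two Net variants (batch and patch sizes are hypothetical; the color constants come from the my.color_mode module imported above):
import torch
net_gray = Net(color=color_mode.GRAY, base_filter=64)
net_rgb = Net(color=color_mode.RGB, base_filter=64)
print(net_gray(torch.rand(2, 1, 32, 32)).shape)   # torch.Size([2, 1, 32, 32])
print(net_rgb(torch.rand(2, 3, 32, 32)).shape)    # torch.Size([2, 3, 32, 32]) -- one branch per channel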
class Solver(object):
def __init__(self, config, training_loader, testing_loader, writer=None):
super(Solver, self).__init__()
self.CUDA = torch.cuda.is_available()
self.device = torch.device('cuda' if self.CUDA else 'cpu')
self.model = None
self.lr = config.lr
self.nEpochs = config.nEpochs
self.criterion = None
self.optimizer = None
self.scheduler = None
self.seed = config.seed
self.upscale_factor = config.upscale_factor
self.training_loader = training_loader
self.testing_loader = testing_loader
self.writer = writer
self.color = config.color
def build_model(self):
self.model = Net(color=self.color, base_filter=64).to(self.device)
self.model.weight_init(mean=0.0, std=0.01)
self.criterion = torch.nn.MSELoss()
torch.manual_seed(self.seed)
if self.CUDA:
torch.cuda.manual_seed(self.seed)
cudnn.benchmark = True
self.criterion.cuda()
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[50, 75, 100], gamma=0.5)
def save_model(self):
model_out_path = "model_path.pth"
torch.save(self.model, model_out_path)
print("Checkpoint saved to {}".format(model_out_path))
def train(self, epoch, iters, channels=None):
self.model.train()
train_loss = 0
for batch_num, (_, data, target) in enumerate(self.training_loader):
if channels:
data = data[..., channels, :, :]
target = target[..., channels, :, :]
data = data.to(self.device)
target = target.to(self.device)
self.optimizer.zero_grad()
out = self.model(data)
loss = self.criterion(out, target)
train_loss += loss.item()
loss.backward()
self.optimizer.step()
sys.stdout.write('Epoch %d: ' % epoch)
progress_bar(batch_num, len(self.training_loader), 'Loss: %.4f' % (train_loss / (batch_num + 1)))
if self.writer:
self.writer.add_scalar("loss", loss, iters)
if iters % 100 == 0:
output_vs_gt = torch.stack([out, target], 1) \
.flatten(0, 1).detach()
self.writer.add_image(
"Output_vs_gt",
torchvision.utils.make_grid(output_vs_gt, nrow=2).cpu().numpy(),
iters)
iters += 1
print(" Average Loss: {:.4f}".format(train_loss / len(self.training_loader)))
return iters
\ No newline at end of file
import sys
sys.path.append('/e/dengnc')
__package__ = "deeplightfield"
__package__ = "deep_view_syn"
import os
import torch
......
from __future__ import print_function
import argparse
import os
import sys
import torch
import torch.nn.functional as nn_f
from tensorboardX.writer import SummaryWriter
sys.path.append(os.path.abspath(sys.path[0] + '/../'))
__package__ = "deep_view_syn"
# ===========================================================
# Training settings
# ===========================================================
parser = argparse.ArgumentParser(description='PyTorch Super Res Example')
# hyper-parameters
parser.add_argument('--device', type=int, default=3,
help='Which CUDA device to use.')
parser.add_argument('--batchSize', type=int, default=1,
help='training batch size')
parser.add_argument('--testBatchSize', type=int,
default=1, help='testing batch size')
parser.add_argument('--nEpochs', type=int, default=20,
help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.01,
help='Learning Rate. Default=0.01')
parser.add_argument('--seed', type=int, default=123,
help='random seed to use. Default=123')
parser.add_argument('--dataset', type=str, required=True,
help='dataset directory')
parser.add_argument('--test', type=str, help='path of model to test')
parser.add_argument('--testOutPatt', type=str, help='test output path pattern')
parser.add_argument('--color', type=str, default='rgb',
help='color')
# model configuration
parser.add_argument('--upscale_factor', '-uf', type=int,
default=2, help="super resolution upscale factor")
#parser.add_argument('--model', '-m', type=str, default='srgan', help='choose which model is going to use')
args = parser.parse_args()
# Select device
torch.cuda.set_device(args.device)
print("Set CUDA:%d as current device." % torch.cuda.current_device())
from .my import util
from .my import netio
from .my import device
from .my import color_mode
from .refine_net import *
from .data.upsampling import UpsamplingDataset
from .data.loader import FastDataLoader
os.chdir(args.dataset)
print('Change working directory to ' + os.getcwd())
run_dir = 'run/'
args.color = color_mode.from_str(args.color)
def train():
util.CreateDirIfNeed(run_dir)
train_set = UpsamplingDataset('.', 'input/out_view_%04d.png',
'gt/view_%04d.png', color=args.color)
training_data_loader = FastDataLoader(dataset=train_set,
batch_size=args.batchSize,
shuffle=True,
drop_last=False)
trainer = Solver(args, training_data_loader, training_data_loader,
SummaryWriter(run_dir))
trainer.build_model()
iters = 0
for epoch in range(1, args.nEpochs + 1):
print("\n===> Epoch {} starts:".format(epoch))
iters = trainer.train(epoch, iters,
channels=slice(2, 3) if args.color == color_mode.YCbCr
else None)
netio.SaveNet(run_dir + 'model-epoch_%d.pth' % args.nEpochs, trainer.model)
def test():
util.CreateDirIfNeed(os.path.dirname(args.testOutPatt))
train_set = UpsamplingDataset(
'.', 'input/out_view_%04d.png', None, color=args.color)
training_data_loader = FastDataLoader(dataset=train_set,
batch_size=args.testBatchSize,
shuffle=False,
drop_last=False)
trainer = Solver(args, training_data_loader, training_data_loader,
SummaryWriter(run_dir))
trainer.build_model()
netio.LoadNet(args.test, trainer.model)
for idx, input, _ in training_data_loader:
if args.color == color_mode.YCbCr:
output_y = trainer.model(input[:, -1:])
output_cbcr = input[:, 0:2]
output = util.ycbcr2rgb(torch.cat([output_cbcr, output_y], -3))
else:
output = trainer.model(input)
util.WriteImageTensor(output, args.testOutPatt % idx)
def main():
if args.test:
test()
else:
train()
if __name__ == '__main__':
main()
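Hypothetical command lines for this script (the script filename and dataset path are placeholders; the flags are those defined by the parser above):
# Train, writing run/model-epoch_20.pth inside the dataset directory:
#   python <this_script>.py --device 0 --dataset /path/to/upsampling_data --color ybr --nEpochs 20
# Test a saved model, writing images to the given pattern:
#   python <this_script>.py --device 0 --dataset /path/to/upsampling_data --color ybr \
#       --test run/model-epoch_20.pth --testOutPatt output/view_%04d.png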