Commit 408738c8 authored by Nianchen Deng's avatar Nianchen Deng
Browse files

sync

parent 3554ba52
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Set CUDA:2 as current device.\n"
]
}
],
"source": [
"import sys\n",
"import os\n",
"import torch\n",
"\n",
"sys.path.append(os.path.abspath(sys.path[0] + '/../../'))\n",
"__package__ = \"deep_view_syn.notebook\"\n",
"torch.cuda.set_device(2)\n",
"print(\"Set CUDA:%d as current device.\" % torch.cuda.current_device())\n",
"\n",
"from ..data.spherical_view_syn import *\n",
"from ..configs.spherical_view_syn import SphericalViewSynConfig\n",
"from ..my import netio\n",
"from ..my import util\n",
"from ..my import device\n",
"from ..my import view\n",
"from ..my.gen_final import GenFinal\n",
"\n",
"\n",
"def load_net(path):\n",
" config = SphericalViewSynConfig()\n",
" config.from_id(path[:-4])\n",
" config.SAMPLE_PARAMS['perturb_sample'] = False\n",
" config.print()\n",
" net = config.create_net().to(device.GetDevice())\n",
" netio.LoadNet(path, net)\n",
" return net\n",
"\n",
"\n",
"def find_file(prefix):\n",
" for path in os.listdir():\n",
" if path.startswith(prefix):\n",
" return path\n",
" return None\n",
"\n",
"\n",
"def load_views(data_desc_file) -> view.Trans:\n",
" with open(data_desc_file, 'r', encoding='utf-8') as file:\n",
" data_desc = json.loads(file.read())\n",
" view_centers = torch.tensor(\n",
" data_desc['view_centers'], device=device.GetDevice()).view(-1, 3)\n",
" view_rots = torch.tensor(\n",
" data_desc['view_rots'], device=device.GetDevice()).view(-1, 3, 3)\n",
" return view.Trans(view_centers, view_rots)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Change working directory to /home/dengnc/deep_view_syn/data/bedroom_all_in_one\n",
"==== Config fovea ====\n",
"Net type: nmsl\n",
"Encode dim: 10\n",
"Optimizer decay: 0\n",
"Normalize: False\n",
"Direction as input: False\n",
"Full-connected network parameters: {'nf': 256, 'n_layers': 4, 'skips': []}\n",
"Sample parameters {'spherical': True, 'depth_range': (1.0, 50.0), 'n_samples': 32, 'perturb_sample': False, 'lindisp': True, 'inverse_r': True}\n",
"==========================\n",
"Load net from fovea@nmsl-rgb_e10_fc256x4_d1.00-50.00_s32.pth ...\n",
"==== Config periph ====\n",
"Net type: nnmsl\n",
"Encode dim: 10\n",
"Optimizer decay: 0\n",
"Normalize: False\n",
"Direction as input: False\n",
"Full-connected network parameters: {'nf': 64, 'n_layers': 4, 'skips': []}\n",
"Sample parameters {'spherical': True, 'depth_range': (1.0, 50.0), 'n_samples': 16, 'perturb_sample': False, 'lindisp': True, 'inverse_r': True}\n",
"==========================\n",
"Load net from periph@nnmsl-rgb_e10_fc64x4_d1.00-50.00_s16.pth ...\n",
"Dataset loaded.\n",
"views: [13]\n"
]
}
],
"source": [
"#os.chdir(sys.path[0] + '/../data/__0_user_study/us_gas_all_in_one')\n",
"os.chdir(sys.path[0] + '/../data/bedroom_all_in_one')\n",
"print('Change working directory to ', os.getcwd())\n",
"torch.autograd.set_grad_enabled(False)\n",
"\n",
"fovea_net = load_net(find_file('fovea'))\n",
"periph_net = load_net(find_file('periph'))\n",
"\n",
"# Load Dataset\n",
"views = load_views('nerf_views.json')\n",
"print('Dataset loaded.')\n",
"\n",
"print('views:', views.size())\n",
"#print('ref views:', ref_dataset.samples)\n",
"\n",
"fov_list = [20, 45, 110]\n",
"res_list = [(128, 128), (256, 256), (256, 230)] # (192,256)]\n",
"res_full = (1600, 1440)\n",
"gen = GenFinal(fov_list, res_list, res_full, fovea_net, periph_net,\n",
" device=device.GetDevice())\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"for view_idx in range(8):\n",
" center = (0, 0)\n",
" test_view = views.get(view_idx)\n",
" images = gen.gen(center, test_view, True)\n",
" #plot_figures(images, center)\n",
"\n",
" util.CreateDirIfNeed('output/eval')\n",
" util.WriteImageTensor(images['blended'], 'output/eval/view%04d.png' % view_idx)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
},
"orig_nbformat": 2
},
"nbformat": 4,
"nbformat_minor": 2
}
\ No newline at end of file
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Set CUDA:2 as current device.\n"
]
}
],
"source": [
"import sys\n",
"import os\n",
"import torch\n",
"import matplotlib.pyplot as plt\n",
"import torchvision.transforms.functional as trans_f\n",
"\n",
"sys.path.append(os.path.abspath(sys.path[0] + '/../../'))\n",
"__package__ = \"deep_view_syn.notebook\"\n",
"torch.cuda.set_device(2)\n",
"print(\"Set CUDA:%d as current device.\" % torch.cuda.current_device())\n",
"\n",
"from ..data.spherical_view_syn import *\n",
"from ..configs.spherical_view_syn import SphericalViewSynConfig\n",
"from ..my import netio\n",
"from ..my import util\n",
"from ..my import device\n",
"from ..my import view\n",
"from ..my.gen_final import GenFinal\n",
"\n",
"\n",
"def load_net(path):\n",
" config = SphericalViewSynConfig()\n",
" config.from_id(path[:-4])\n",
" config.SAMPLE_PARAMS['perturb_sample'] = False\n",
" config.print()\n",
" net = config.create_net().to(device.GetDevice())\n",
" netio.LoadNet(path, net)\n",
" return net\n",
"\n",
"\n",
"def find_file(prefix):\n",
" for path in os.listdir():\n",
" if path.startswith(prefix):\n",
" return path\n",
" return None\n",
"\n",
"\n",
"def load_views(data_desc_file) -> view.Trans:\n",
" with open(data_desc_file, 'r', encoding='utf-8') as file:\n",
" data_desc = json.loads(file.read())\n",
" view_centers = torch.tensor(\n",
" data_desc['view_centers'], device=device.GetDevice()).view(-1, 3)\n",
" view_rots = torch.tensor(\n",
" data_desc['view_rots'], device=device.GetDevice()).view(-1, 3, 3)\n",
" return view.Trans(view_centers, view_rots)\n",
"\n",
"\n",
"def plot_figures(images, center):\n",
" plt.figure(figsize=(8, 4))\n",
" plt.subplot(121)\n",
" util.PlotImageTensor(images['fovea_raw'])\n",
" plt.subplot(122)\n",
" util.PlotImageTensor(images['fovea'])\n",
"\n",
" plt.figure(figsize=(8, 4))\n",
" plt.subplot(121)\n",
" util.PlotImageTensor(images['mid_raw'])\n",
" plt.subplot(122)\n",
" util.PlotImageTensor(images['mid'])\n",
"\n",
" plt.figure(figsize=(8, 4))\n",
" plt.subplot(121)\n",
" util.PlotImageTensor(images['periph_raw'])\n",
" plt.subplot(122)\n",
" util.PlotImageTensor(images['periph'])\n",
"\n",
" # Plot Blended\n",
" plt.figure(figsize=(12, 6))\n",
" plt.subplot(121)\n",
" util.PlotImageTensor(images['blended_raw'])\n",
" plt.subplot(122)\n",
" util.PlotImageTensor(images['blended'])\n",
" plt.plot([(res_full[1] - 1) / 2 + center[0] - 5, (res_full[1] - 1) / 2 + center[0] + 5],\n",
" [(res_full[0] - 1) / 2 + center[1],\n",
" (res_full[0] - 1) / 2 + center[1]],\n",
" color=[0, 1, 0])\n",
" plt.plot([(res_full[1] - 1) / 2 + center[0], (res_full[1] - 1) / 2 + center[0]],\n",
" [(res_full[0] - 1) / 2 + center[1] - 5,\n",
" (res_full[0] - 1) / 2 + center[1] + 5],\n",
" color=[0, 1, 0])"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Change working directory to /home/dengnc/deep_view_syn/data/__0_user_study/us_gas_all_in_one\n",
"==== Config fovea ====\n",
"Net type: nmsl\n",
"Encode dim: 10\n",
"Optimizer decay: 0\n",
"Normalize: False\n",
"Direction as input: False\n",
"Full-connected network parameters: {'nf': 128, 'n_layers': 4, 'skips': []}\n",
"Sample parameters {'spherical': True, 'depth_range': (1.0, 50.0), 'n_samples': 32, 'perturb_sample': False, 'lindisp': True, 'inverse_r': True}\n",
"==========================\n",
"Load net from fovea@nmsl-rgb_e10_fc128x4_d1-50_s32.pth ...\n",
"==== Config periph ====\n",
"Net type: nnmsl\n",
"Encode dim: 10\n",
"Optimizer decay: 0\n",
"Normalize: False\n",
"Direction as input: False\n",
"Full-connected network parameters: {'nf': 64, 'n_layers': 4, 'skips': []}\n",
"Sample parameters {'spherical': True, 'depth_range': (1.0, 50.0), 'n_samples': 16, 'perturb_sample': False, 'lindisp': True, 'inverse_r': True}\n",
"==========================\n",
"Load net from periph@nnmsl-rgb_e10_fc64x4_d1-50_s16.pth ...\n",
"Dataset loaded.\n",
"views: [13]\n"
]
}
],
"source": [
"os.chdir(sys.path[0] + '/../data/__0_user_study/us_gas_all_in_one')\n",
"#os.chdir(sys.path[0] + '/../data/__0_user_study/us_mc_all_in_one')\n",
"print('Change working directory to ', os.getcwd())\n",
"torch.autograd.set_grad_enabled(False)\n",
"\n",
"fovea_net = load_net(find_file('fovea'))\n",
"periph_net = load_net(find_file('periph'))\n",
"\n",
"# Load Dataset\n",
"views = load_views('for_teaser.json')\n",
"print('Dataset loaded.')\n",
"\n",
"print('views:', views.size())\n",
"#print('ref views:', ref_dataset.samples)\n",
"\n",
"fov_list = [20, 45, 110]\n",
"res_list = [(128, 128), (256, 256), (256, 230)] # (192,256)]\n",
"res_full = (1600, 1440)\n",
"gen = GenFinal(fov_list, res_list, res_full, fovea_net, periph_net,\n",
" device=device.GetDevice())\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"center_x_list = [-100, -100, 10, 10, 100]\n",
"for view_idx in range(1):\n",
" center = (center_x_list[view_idx], -center_x_list[view_idx] * 1600 / 1440)\n",
" test_view = views.get(view_idx)\n",
" images = gen.gen(center, test_view, True)\n",
" plot_figures(images, center)\n",
"\n",
" #util.CreateDirIfNeed('output/teasers')\n",
" #for key in images:\n",
" # util.WriteImageTensor(\n",
" # images[key], 'output/teasers/view%04d_%s.png' % (view_idx, key))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
},
"orig_nbformat": 2
},
"nbformat": 4,
"nbformat_minor": 2
}
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
from __future__ import print_function
import sys
import torch
import torchvision
import torch.backends.cudnn as cudnn
import torch.nn as nn
from math import log10
from my.progress_bar import progress_bar
from my import color_mode
class Net(torch.nn.Module):
def __init__(self, color, base_filter):
super(Net, self).__init__()
self.color = color
if color == color_mode.GRAY:
self.layers = torch.nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=base_filter, kernel_size=9, stride=1, padding=4, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=base_filter, out_channels=base_filter // 2, kernel_size=1, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=base_filter // 2, out_channels=1, kernel_size=5, stride=1, padding=2, bias=True),
#nn.PixelShuffle(upscale_factor)
)
else:
self.net_1 = torch.nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=base_filter, kernel_size=9, stride=1, padding=4, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=base_filter, out_channels=base_filter // 2, kernel_size=1, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=base_filter // 2, out_channels=1, kernel_size=5, stride=1, padding=2, bias=True),
#nn.PixelShuffle(upscale_factor)
)
self.net_2 = torch.nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=base_filter, kernel_size=9, stride=1, padding=4, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=base_filter, out_channels=base_filter // 2, kernel_size=1, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=base_filter // 2, out_channels=1, kernel_size=5, stride=1, padding=2, bias=True),
#nn.PixelShuffle(upscale_factor)
)
self.net_3 = torch.nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=base_filter, kernel_size=9, stride=1, padding=4, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=base_filter, out_channels=base_filter // 2, kernel_size=1, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=base_filter // 2, out_channels=1, kernel_size=5, stride=1, padding=2, bias=True),
#nn.PixelShuffle(upscale_factor)
)
def forward(self, x):
if self.color == color_mode.GRAY:
out = self.layers(x)
else:
out = torch.cat([
self.net_1(x[:, 0:1]),
self.net_2(x[:, 1:2]),
self.net_3(x[:, 2:3])
], dim=1)
return out
def weight_init(self, mean, std):
for m in self._modules:
normal_init(self._modules[m], mean, std)
def normal_init(m, mean, std):
if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d):
m.weight.data.normal_(mean, std)
m.bias.data.zero_()
class Solver(object):
def __init__(self, config, training_loader, testing_loader, writer=None):
super(Solver, self).__init__()
self.CUDA = torch.cuda.is_available()
self.device = torch.device('cuda' if self.CUDA else 'cpu')
self.model = None
self.lr = config.lr
self.nEpochs = config.nEpochs
self.criterion = None
self.optimizer = None
self.scheduler = None
self.seed = config.seed
self.upscale_factor = config.upscale_factor
self.training_loader = training_loader
self.testing_loader = testing_loader
self.writer = writer
self.color = config.color
def build_model(self):
self.model = Net(color=self.color, base_filter=64).to(self.device)
self.model.weight_init(mean=0.0, std=0.01)
self.criterion = torch.nn.MSELoss()
torch.manual_seed(self.seed)
if self.CUDA:
torch.cuda.manual_seed(self.seed)
cudnn.benchmark = True
self.criterion.cuda()
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[50, 75, 100], gamma=0.5)
def save_model(self):
model_out_path = "model_path.pth"
torch.save(self.model, model_out_path)
print("Checkpoint saved to {}".format(model_out_path))
def train(self, epoch, iters, channels = None):
self.model.train()
train_loss = 0
for batch_num, (_, data, target) in enumerate(self.training_loader):
if channels:
data = data[..., channels, :, :]
target = target[..., channels, :, :]
data =data.to(self.device)
target = target.to(self.device)
self.optimizer.zero_grad()
out = self.model(data)
loss = self.criterion(out, target)
train_loss += loss.item()
loss.backward()
self.optimizer.step()
sys.stdout.write('Epoch %d: ' % epoch)
progress_bar(batch_num, len(self.training_loader), 'Loss: %.4f' % (train_loss / (batch_num + 1)))
if self.writer:
self.writer.add_scalar("loss", loss, iters)
if iters % 100 == 0:
output_vs_gt = torch.stack([out, target], 1) \
.flatten(0, 1).detach()
self.writer.add_image(
"Output_vs_gt",
torchvision.utils.make_grid(output_vs_gt, nrow=2).cpu().numpy(),
iters)
iters += 1
print(" Average Loss: {:.4f}".format(train_loss / len(self.training_loader)))
return iters
\ No newline at end of file
from __future__ import print_function
import argparse
import os
import sys
import torch
import torch.nn.functional as nn_f
from tensorboardX.writer import SummaryWriter
sys.path.append(os.path.abspath(sys.path[0] + '/../'))
__package__ = "deep_view_syn"
# ===========================================================
# Training settings
# ===========================================================
parser = argparse.ArgumentParser(description='PyTorch Super Res Example')
# hyper-parameters
parser.add_argument('--device', type=int, default=3,
help='Which CUDA device to use.')
parser.add_argument('--batchSize', type=int, default=1,
help='training batch size')
parser.add_argument('--testBatchSize', type=int,
default=1, help='testing batch size')
parser.add_argument('--nEpochs', type=int, default=20,
help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.01,
help='Learning Rate. Default=0.01')
parser.add_argument('--seed', type=int, default=123,
help='random seed to use. Default=123')
parser.add_argument('--dataset', type=str, required=True,
help='dataset directory')
parser.add_argument('--test', type=str, help='path of model to test')
parser.add_argument('--testOutPatt', type=str, help='test output path pattern')
parser.add_argument('--color', type=str, default='rgb',
help='color')
# model configuration
parser.add_argument('--upscale_factor', '-uf', type=int,
default=2, help="super resolution upscale factor")
#parser.add_argument('--model', '-m', type=str, default='srgan', help='choose which model is going to use')
args = parser.parse_args()
# Select device
torch.cuda.set_device(args.device)
print("Set CUDA:%d as current device." % torch.cuda.current_device())
from .my import util
from .my import netio
from .my import device
from .my import color_mode
from .refine_net import *
from .data.upsampling import UpsamplingDataset
from .data.loader import FastDataLoader
os.chdir(args.dataset)
print('Change working directory to ' + os.getcwd())
run_dir = 'run/'
args.color = color_mode.from_str(args.color)
def train():
util.CreateDirIfNeed(run_dir)
train_set = UpsamplingDataset('.', 'input/out_view_%04d.png',
'gt/view_%04d.png', color=args.color)
training_data_loader = FastDataLoader(dataset=train_set,
batch_size=args.batchSize,
shuffle=True,
drop_last=False)
trainer = Solver(args, training_data_loader, training_data_loader,
SummaryWriter(run_dir))
trainer.build_model()
iters = 0
for epoch in range(1, args.nEpochs + 1):
print("\n===> Epoch {} starts:".format(epoch))
iters = trainer.train(epoch, iters,
channels=slice(2, 3) if args.color == color_mode.YCbCr
else None)
netio.SaveNet(run_dir + 'model-epoch_%d.pth' % args.nEpochs, trainer.model)
def test():
util.CreateDirIfNeed(os.path.dirname(args.testOutPatt))
train_set = UpsamplingDataset(
'.', 'input/out_view_%04d.png', None, color=args.color)
training_data_loader = FastDataLoader(dataset=train_set,
batch_size=args.testBatchSize,
shuffle=False,
drop_last=False)
trainer = Solver(args, training_data_loader, training_data_loader,
SummaryWriter(run_dir))
trainer.build_model()
netio.LoadNet(args.test, trainer.model)
for idx, input, _ in training_data_loader:
if args.color == color_mode.YCbCr:
output_y = trainer.model(input[:, -1:])
output_cbcr = input[:, 0:2]
output = util.ycbcr2rgb(torch.cat([output_cbcr, output_y], -3))
else:
output = trainer.model(input)
util.WriteImageTensor(output, args.testOutPatt % idx)
def main():
if (args.test):
test()
else:
train()
if __name__ == '__main__':
main()
import math
import sys import sys
import os import os
import argparse import argparse
import torch import torch
import torch.optim import torch.optim
import torchvision
import numpy as np
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
from torch import nn from torch import nn
...@@ -21,6 +18,8 @@ parser.add_argument('--config-id', type=str, ...@@ -21,6 +18,8 @@ parser.add_argument('--config-id', type=str,
help='Net config id') help='Net config id')
parser.add_argument('--dataset', type=str, required=True, parser.add_argument('--dataset', type=str, required=True,
help='Dataset description file') help='Dataset description file')
parser.add_argument('--cont', type=str,
help='Continue train on model file')
parser.add_argument('--epochs', type=int, parser.add_argument('--epochs', type=int,
help='Max epochs for train') help='Max epochs for train')
parser.add_argument('--test', type=str, parser.add_argument('--test', type=str,
...@@ -37,6 +36,8 @@ parser.add_argument('--output-video', action='store_true', ...@@ -37,6 +36,8 @@ parser.add_argument('--output-video', action='store_true',
help='Output test results as video') help='Output test results as video')
parser.add_argument('--perf', action='store_true', parser.add_argument('--perf', action='store_true',
help='Test performance') help='Test performance')
parser.add_argument('--simple-log', action='store_true', help='Simple log')
opt = parser.parse_args() opt = parser.parse_args()
if opt.res: if opt.res:
opt.res = tuple(int(s) for s in opt.res.split('x')) opt.res = tuple(int(s) for s in opt.res.split('x'))
...@@ -49,11 +50,10 @@ from .my import netio ...@@ -49,11 +50,10 @@ from .my import netio
from .my import util from .my import util
from .my import device from .my import device
from .my import loss from .my import loss
from .my.progress_bar import progress_bar
from .my.simple_perf import SimplePerf from .my.simple_perf import SimplePerf
from .data.spherical_view_syn import * from .data.spherical_view_syn import *
from .data.loader import FastDataLoader from .data.loader import FastDataLoader
from .msl_net import MslNet
from .spher_net import SpherNet
from .configs.spherical_view_syn import SphericalViewSynConfig from .configs.spherical_view_syn import SphericalViewSynConfig
...@@ -68,8 +68,8 @@ EVAL_TIME_PERFORMANCE = False ...@@ -68,8 +68,8 @@ EVAL_TIME_PERFORMANCE = False
# Train # Train
BATCH_SIZE = 4096 BATCH_SIZE = 4096
EPOCH_RANGE = range(0, opt.epochs if opt.epochs else 500) EPOCH_RANGE = range(0, opt.epochs if opt.epochs else 300)
SAVE_INTERVAL = 50 SAVE_INTERVAL = 10
# Test # Test
TEST_BATCH_SIZE = 1 TEST_BATCH_SIZE = 1
...@@ -92,11 +92,18 @@ if opt.test: ...@@ -92,11 +92,18 @@ if opt.test:
output_dir = run_dir + 'output/%s/%s_s%d/' % \ output_dir = run_dir + 'output/%s/%s_s%d/' % \
(test_net_name, data_desc_name, opt.test_samples) (test_net_name, data_desc_name, opt.test_samples)
else: else:
data_dir = os.path.dirname(data_desc_path) + '/'
if opt.cont:
train_net_name = os.path.splitext(os.path.basename(opt.cont))[0]
EPOCH_RANGE = range(int(train_net_name[12:]), EPOCH_RANGE.stop)
run_dir = os.path.dirname(opt.cont) + '/'
run_id = os.path.basename(run_dir[:-1])
config.from_id(run_id)
else:
if opt.config: if opt.config:
config.load(opt.config) config.load(opt.config)
if opt.config_id: if opt.config_id:
config.from_id(opt.config_id) config.from_id(opt.config_id)
data_dir = os.path.dirname(data_desc_path) + '/'
run_id = config.to_id() run_id = config.to_id()
run_dir = data_dir + run_id + '/' run_dir = data_dir + run_id + '/'
log_dir = run_dir + 'log/' log_dir = run_dir + 'log/'
...@@ -113,26 +120,6 @@ if not train_mode: ...@@ -113,26 +120,6 @@ if not train_mode:
config.SAMPLE_PARAMS['perturb_sample'] = \ config.SAMPLE_PARAMS['perturb_sample'] = \
config.SAMPLE_PARAMS['perturb_sample'] and train_mode config.SAMPLE_PARAMS['perturb_sample'] and train_mode
NETS = {
'msl': lambda: MslNet(
fc_params=config.FC_PARAMS,
sampler_params=(config.SAMPLE_PARAMS.update(
{'spherical': True}), config.SAMPLE_PARAMS)[1],
color=config.COLOR,
encode_to_dim=config.N_ENCODE_DIM),
'nerf': lambda: MslNet(
fc_params=config.FC_PARAMS,
sampler_params=(config.SAMPLE_PARAMS.update(
{'spherical': False}), config.SAMPLE_PARAMS)[1],
color=config.COLOR,
encode_to_dim=config.N_ENCODE_DIM),
'spher': lambda: SpherNet(
fc_params=config.FC_PARAMS,
color=config.COLOR,
translation=not ROT_ONLY,
encode_to_dim=config.N_ENCODE_DIM)
}
LOSSES = { LOSSES = {
'mse': lambda: nn.MSELoss(), 'mse': lambda: nn.MSELoss(),
'mse_grad': lambda: loss.CombinedLoss( 'mse_grad': lambda: loss.CombinedLoss(
...@@ -140,7 +127,7 @@ LOSSES = { ...@@ -140,7 +127,7 @@ LOSSES = {
} }
# Initialize model # Initialize model
model = NETS[config.NET_TYPE]().to(device.GetDevice()) model = config.create_net().to(device.GetDevice())
loss_mse = nn.MSELoss().to(device.GetDevice()) loss_mse = nn.MSELoss().to(device.GetDevice())
loss_grad = loss.GradLoss().to(device.GetDevice()) loss_grad = loss.GradLoss().to(device.GetDevice())
...@@ -148,6 +135,10 @@ loss_grad = loss.GradLoss().to(device.GetDevice()) ...@@ -148,6 +135,10 @@ loss_grad = loss.GradLoss().to(device.GetDevice())
def train_loop(data_loader, optimizer, loss, perf, writer, epoch, iters): def train_loop(data_loader, optimizer, loss, perf, writer, epoch, iters):
sub_iters = 0 sub_iters = 0
iters_in_epoch = len(data_loader) iters_in_epoch = len(data_loader)
loss_min = 1e5
loss_max = 0
loss_avg = 0
perf = SimplePerf(opt.simple_log)
for _, gt, rays_o, rays_d in data_loader: for _, gt, rays_o, rays_d in data_loader:
patch = (len(gt.size()) == 4) patch = (len(gt.size()) == 4)
gt = gt.to(device.GetDevice()) gt = gt.to(device.GetDevice())
...@@ -175,25 +166,26 @@ def train_loop(data_loader, optimizer, loss, perf, writer, epoch, iters): ...@@ -175,25 +166,26 @@ def train_loop(data_loader, optimizer, loss, perf, writer, epoch, iters):
optimizer.step() optimizer.step()
perf.Checkpoint("Update") perf.Checkpoint("Update")
if patch: loss_value = loss_value.item()
print("Epoch: %d, Iter: %d(%d/%d), Loss MSE: %f, Loss Grad: %f" % loss_min = min(loss_min, loss_value)
(epoch, iters, sub_iters, iters_in_epoch, loss_max = max(loss_max, loss_value)
loss_mse_value.item(), loss_grad_value.item())) loss_avg = (loss_avg * sub_iters + loss_value) / (sub_iters + 1)
else: if not opt.simple_log:
print("Epoch: %d, Iter: %d(%d/%d), Loss MSE: %f" % progress_bar(sub_iters, iters_in_epoch,
(epoch, iters, sub_iters, iters_in_epoch, loss_mse_value.item())) "Loss: %.2e (%.2e/%.2e/%.2e)" % (loss_value, loss_min, loss_avg, loss_max),
"Epoch {:<3d}".format(epoch))
# Write tensorboard logs. # Write tensorboard logs.
writer.add_scalar("loss mse", loss_mse_value, iters) writer.add_scalar("loss mse", loss_value, iters)
if patch: # if patch and iters % 100 == 0:
writer.add_scalar("loss grad", loss_grad_value, iters) # output_vs_gt = torch.cat([out[0:4], gt[0:4]], 0).detach()
if patch and iters % 100 == 0: # writer.add_image("Output_vs_gt", torchvision.utils.make_grid(
output_vs_gt = torch.cat([out[0:4], gt[0:4]], 0).detach() # output_vs_gt, nrow=4).cpu().numpy(), iters)
writer.add_image("Output_vs_gt", torchvision.utils.make_grid(
output_vs_gt, nrow=4).cpu().numpy(), iters)
iters += 1 iters += 1
sub_iters += 1 sub_iters += 1
if opt.simple_log:
perf.Checkpoint('Epoch %d (%.2e/%.2e/%.2e)' % (epoch, loss_min, loss_avg, loss_max), True)
return iters return iters
...@@ -211,13 +203,18 @@ def train(): ...@@ -211,13 +203,18 @@ def train():
pin_memory=True) pin_memory=True)
# 2. Initialize components # 2. Initialize components
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4) optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=config.OPT_DECAY)
loss = 0 # LOSSES[config.LOSS]().to(device.GetDevice()) loss = 0 # LOSSES[config.LOSS]().to(device.GetDevice())
if EPOCH_RANGE.start > 0: if EPOCH_RANGE.start > 0:
iters = netio.LoadNet('%smodel-epoch_%d.pth' % (run_dir, EPOCH_RANGE.start), iters = netio.LoadNet('%smodel-epoch_%d.pth' % (run_dir, EPOCH_RANGE.start),
model, solver=optimizer) model, solver=optimizer)
else: else:
if config.NORMALIZE:
for _, _, rays_o, rays_d in train_data_loader:
model.update_normalize_range(rays_o, rays_d)
print('Depth/diopter range: ', model.depth_range)
print('Angle range: ', model.angle_range / 3.14159 * 180)
iters = 0 iters = 0
epoch = None epoch = None
...@@ -228,12 +225,10 @@ def train(): ...@@ -228,12 +225,10 @@ def train():
util.CreateDirIfNeed(log_dir) util.CreateDirIfNeed(log_dir)
perf = SimplePerf(EVAL_TIME_PERFORMANCE, start=True) perf = SimplePerf(EVAL_TIME_PERFORMANCE, start=True)
perf_epoch = SimplePerf(True, start=True)
writer = SummaryWriter(log_dir) writer = SummaryWriter(log_dir)
print("Begin training...") print("Begin training...")
for epoch in EPOCH_RANGE: for epoch in EPOCH_RANGE:
perf_epoch.Checkpoint("Epoch")
iters = train_loop(train_data_loader, optimizer, loss, iters = train_loop(train_data_loader, optimizer, loss,
perf, writer, epoch, iters) perf, writer, epoch, iters)
# Save checkpoint # Save checkpoint
......
...@@ -5,8 +5,8 @@ import torch ...@@ -5,8 +5,8 @@ import torch
import torch.optim import torch.optim
from torch import onnx from torch import onnx
sys.path.append(os.path.abspath(sys.path[0] + '/../')) sys.path.append(os.path.abspath(sys.path[0] + '/../../'))
__package__ = "deep_view_syn" __package__ = "deep_view_syn.tools"
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--device', type=int, default=0, parser.add_argument('--device', type=int, default=0,
...@@ -23,11 +23,10 @@ opt = parser.parse_args() ...@@ -23,11 +23,10 @@ opt = parser.parse_args()
torch.cuda.set_device(opt.device) torch.cuda.set_device(opt.device)
print("Set CUDA:%d as current device." % torch.cuda.current_device()) print("Set CUDA:%d as current device." % torch.cuda.current_device())
from .msl_net import MslNet from ..configs.spherical_view_syn import SphericalViewSynConfig
from .configs.spherical_view_syn import SphericalViewSynConfig from ..my import device
from .my import device from ..my import netio
from .my import netio from ..my import util
from .my import util
dir_path, model_file = os.path.split(opt.model) dir_path, model_file = os.path.split(opt.model)
batch_size = eval(opt.batch_size) batch_size = eval(opt.batch_size)
...@@ -42,8 +41,7 @@ def load_net(path): ...@@ -42,8 +41,7 @@ def load_net(path):
config.SAMPLE_PARAMS['perturb_sample'] = False config.SAMPLE_PARAMS['perturb_sample'] = False
config.SAMPLE_PARAMS['n_samples'] = 4 config.SAMPLE_PARAMS['n_samples'] = 4
config.print() config.print()
net = MslNet(config.FC_PARAMS, config.SAMPLE_PARAMS, config.GRAY, net = config.create_net().to(device.GetDevice())
config.N_ENCODE_DIM, export_mode=True).to(device.GetDevice())
netio.LoadNet(path, net) netio.LoadNet(path, net)
return net, name return net, name
......
import sys import sys
import os import os
sys.path.append(os.path.abspath(sys.path[0] + '/../')) sys.path.append(os.path.abspath(sys.path[0] + '/../../'))
__package__ = "deep_view_syn" __package__ = "deep_view_syn.tools"
import argparse import argparse
from PIL import Image from PIL import Image
from .my import util from ..my import util
def batch_scale(src, target, size): def batch_scale(src, target, size):
......
from typing import Tuple
import torch
import torch.nn as nn
from .my import net_modules
from .my import util
from .my import device
class UpsamplingNet(nn.Module):
def __init__(self, inner_chns, gray=False,
encode_to_dim: int = 0):
"""
Initialize a multi-sphere-layer net
:param fc_params: parameters for full-connection network
:param sampler_params: parameters for sampler
:param gray: is grayscale mode
:param encode_to_dim: encode input to number of dimensions
"""
super().__init__()
self.in_chns = 3
self.input_encoder = net_modules.InputEncoder.Get(
encode_to_dim, self.in_chns)
fc_params['in_chns'] = self.input_encoder.out_dim
fc_params['out_chns'] = 2 if gray else 4
self.sampler = Sampler(**sampler_params)
self.net = net_modules.FcNet(**fc_params)
self.rendering = Rendering()
def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor) -> torch.Tensor:
"""
rays -> colors
:param rays_o ```Tensor(B, ..., 3)```: rays' origin
:param rays_d ```Tensor(B, ..., 3)```: rays' direction
:return: Tensor(B, 1|3, ...), inferred images/pixels
"""
p = rays_o.view(-1, 3)
v = rays_d.view(-1, 3)
coords, depths = self.sampler(p, v)
encoded = self.input_encoder(coords)
color_map = self.rendering(self.net(encoded), depths)
# Unflatten according to input shape
out_shape = list(rays_d.size())
out_shape[-1] = -1
return color_map.view(out_shape).movedim(-1, 1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment