Commit 408738c8 authored by Nianchen Deng's avatar Nianchen Deng

sync

parent 3554ba52
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Set CUDA:2 as current device.\n"
]
}
],
"source": [
"import sys\n",
"import os\n",
"import torch\n",
"\n",
"sys.path.append(os.path.abspath(sys.path[0] + '/../../'))\n",
"__package__ = \"deep_view_syn.notebook\"\n",
"torch.cuda.set_device(2)\n",
"print(\"Set CUDA:%d as current device.\" % torch.cuda.current_device())\n",
"\n",
"from ..data.spherical_view_syn import *\n",
"from ..configs.spherical_view_syn import SphericalViewSynConfig\n",
"from ..my import netio\n",
"from ..my import util\n",
"from ..my import device\n",
"from ..my import view\n",
"from ..my.gen_final import GenFinal\n",
"\n",
"\n",
"def load_net(path):\n",
" config = SphericalViewSynConfig()\n",
" config.from_id(path[:-4])\n",
" config.SAMPLE_PARAMS['perturb_sample'] = False\n",
" config.print()\n",
" net = config.create_net().to(device.GetDevice())\n",
" netio.LoadNet(path, net)\n",
" return net\n",
"\n",
"\n",
"def find_file(prefix):\n",
" for path in os.listdir():\n",
" if path.startswith(prefix):\n",
" return path\n",
" return None\n",
"\n",
"\n",
"def load_views(data_desc_file) -> view.Trans:\n",
" with open(data_desc_file, 'r', encoding='utf-8') as file:\n",
" data_desc = json.loads(file.read())\n",
" view_centers = torch.tensor(\n",
" data_desc['view_centers'], device=device.GetDevice()).view(-1, 3)\n",
" view_rots = torch.tensor(\n",
" data_desc['view_rots'], device=device.GetDevice()).view(-1, 3, 3)\n",
" return view.Trans(view_centers, view_rots)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Change working directory to /home/dengnc/deep_view_syn/data/bedroom_all_in_one\n",
"==== Config fovea ====\n",
"Net type: nmsl\n",
"Encode dim: 10\n",
"Optimizer decay: 0\n",
"Normalize: False\n",
"Direction as input: False\n",
"Full-connected network parameters: {'nf': 256, 'n_layers': 4, 'skips': []}\n",
"Sample parameters {'spherical': True, 'depth_range': (1.0, 50.0), 'n_samples': 32, 'perturb_sample': False, 'lindisp': True, 'inverse_r': True}\n",
"==========================\n",
"Load net from fovea@nmsl-rgb_e10_fc256x4_d1.00-50.00_s32.pth ...\n",
"==== Config periph ====\n",
"Net type: nnmsl\n",
"Encode dim: 10\n",
"Optimizer decay: 0\n",
"Normalize: False\n",
"Direction as input: False\n",
"Full-connected network parameters: {'nf': 64, 'n_layers': 4, 'skips': []}\n",
"Sample parameters {'spherical': True, 'depth_range': (1.0, 50.0), 'n_samples': 16, 'perturb_sample': False, 'lindisp': True, 'inverse_r': True}\n",
"==========================\n",
"Load net from periph@nnmsl-rgb_e10_fc64x4_d1.00-50.00_s16.pth ...\n",
"Dataset loaded.\n",
"views: [13]\n"
]
}
],
"source": [
"#os.chdir(sys.path[0] + '/../data/__0_user_study/us_gas_all_in_one')\n",
"os.chdir(sys.path[0] + '/../data/bedroom_all_in_one')\n",
"print('Change working directory to ', os.getcwd())\n",
"torch.autograd.set_grad_enabled(False)\n",
"\n",
"fovea_net = load_net(find_file('fovea'))\n",
"periph_net = load_net(find_file('periph'))\n",
"\n",
"# Load Dataset\n",
"views = load_views('nerf_views.json')\n",
"print('Dataset loaded.')\n",
"\n",
"print('views:', views.size())\n",
"#print('ref views:', ref_dataset.samples)\n",
"\n",
"fov_list = [20, 45, 110]\n",
"res_list = [(128, 128), (256, 256), (256, 230)] # (192,256)]\n",
"res_full = (1600, 1440)\n",
"gen = GenFinal(fov_list, res_list, res_full, fovea_net, periph_net,\n",
" device=device.GetDevice())\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"for view_idx in range(8):\n",
" center = (0, 0)\n",
" test_view = views.get(view_idx)\n",
" images = gen.gen(center, test_view, True)\n",
" #plot_figures(images, center)\n",
"\n",
" util.CreateDirIfNeed('output/eval')\n",
" util.WriteImageTensor(images['blended'], 'output/eval/view%04d.png' % view_idx)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
},
"orig_nbformat": 2
},
"nbformat": 4,
"nbformat_minor": 2
}
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Set CUDA:2 as current device.\n"
]
}
],
"source": [
"import sys\n",
"import os\n",
"import torch\n",
"import matplotlib.pyplot as plt\n",
"import torchvision.transforms.functional as trans_f\n",
"\n",
"sys.path.append(os.path.abspath(sys.path[0] + '/../../'))\n",
"__package__ = \"deep_view_syn.notebook\"\n",
"torch.cuda.set_device(2)\n",
"print(\"Set CUDA:%d as current device.\" % torch.cuda.current_device())\n",
"\n",
"from ..data.spherical_view_syn import *\n",
"from ..configs.spherical_view_syn import SphericalViewSynConfig\n",
"from ..my import netio\n",
"from ..my import util\n",
"from ..my import device\n",
"from ..my import view\n",
"from ..my.gen_final import GenFinal\n",
"\n",
"\n",
"def load_net(path):\n",
" config = SphericalViewSynConfig()\n",
" config.from_id(path[:-4])\n",
" config.SAMPLE_PARAMS['perturb_sample'] = False\n",
" config.print()\n",
" net = config.create_net().to(device.GetDevice())\n",
" netio.LoadNet(path, net)\n",
" return net\n",
"\n",
"\n",
"def find_file(prefix):\n",
" for path in os.listdir():\n",
" if path.startswith(prefix):\n",
" return path\n",
" return None\n",
"\n",
"\n",
"def load_views(data_desc_file) -> view.Trans:\n",
" with open(data_desc_file, 'r', encoding='utf-8') as file:\n",
" data_desc = json.loads(file.read())\n",
" view_centers = torch.tensor(\n",
" data_desc['view_centers'], device=device.GetDevice()).view(-1, 3)\n",
" view_rots = torch.tensor(\n",
" data_desc['view_rots'], device=device.GetDevice()).view(-1, 3, 3)\n",
" return view.Trans(view_centers, view_rots)\n",
"\n",
"\n",
"def plot_figures(images, center):\n",
" plt.figure(figsize=(8, 4))\n",
" plt.subplot(121)\n",
" util.PlotImageTensor(images['fovea_raw'])\n",
" plt.subplot(122)\n",
" util.PlotImageTensor(images['fovea'])\n",
"\n",
" plt.figure(figsize=(8, 4))\n",
" plt.subplot(121)\n",
" util.PlotImageTensor(images['mid_raw'])\n",
" plt.subplot(122)\n",
" util.PlotImageTensor(images['mid'])\n",
"\n",
" plt.figure(figsize=(8, 4))\n",
" plt.subplot(121)\n",
" util.PlotImageTensor(images['periph_raw'])\n",
" plt.subplot(122)\n",
" util.PlotImageTensor(images['periph'])\n",
"\n",
" # Plot Blended\n",
" plt.figure(figsize=(12, 6))\n",
" plt.subplot(121)\n",
" util.PlotImageTensor(images['blended_raw'])\n",
" plt.subplot(122)\n",
" util.PlotImageTensor(images['blended'])\n",
" plt.plot([(res_full[1] - 1) / 2 + center[0] - 5, (res_full[1] - 1) / 2 + center[0] + 5],\n",
" [(res_full[0] - 1) / 2 + center[1],\n",
" (res_full[0] - 1) / 2 + center[1]],\n",
" color=[0, 1, 0])\n",
" plt.plot([(res_full[1] - 1) / 2 + center[0], (res_full[1] - 1) / 2 + center[0]],\n",
" [(res_full[0] - 1) / 2 + center[1] - 5,\n",
" (res_full[0] - 1) / 2 + center[1] + 5],\n",
" color=[0, 1, 0])"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Change working directory to /home/dengnc/deep_view_syn/data/__0_user_study/us_gas_all_in_one\n",
"==== Config fovea ====\n",
"Net type: nmsl\n",
"Encode dim: 10\n",
"Optimizer decay: 0\n",
"Normalize: False\n",
"Direction as input: False\n",
"Full-connected network parameters: {'nf': 128, 'n_layers': 4, 'skips': []}\n",
"Sample parameters {'spherical': True, 'depth_range': (1.0, 50.0), 'n_samples': 32, 'perturb_sample': False, 'lindisp': True, 'inverse_r': True}\n",
"==========================\n",
"Load net from fovea@nmsl-rgb_e10_fc128x4_d1-50_s32.pth ...\n",
"==== Config periph ====\n",
"Net type: nnmsl\n",
"Encode dim: 10\n",
"Optimizer decay: 0\n",
"Normalize: False\n",
"Direction as input: False\n",
"Full-connected network parameters: {'nf': 64, 'n_layers': 4, 'skips': []}\n",
"Sample parameters {'spherical': True, 'depth_range': (1.0, 50.0), 'n_samples': 16, 'perturb_sample': False, 'lindisp': True, 'inverse_r': True}\n",
"==========================\n",
"Load net from periph@nnmsl-rgb_e10_fc64x4_d1-50_s16.pth ...\n",
"Dataset loaded.\n",
"views: [13]\n"
]
}
],
"source": [
"os.chdir(sys.path[0] + '/../data/__0_user_study/us_gas_all_in_one')\n",
"#os.chdir(sys.path[0] + '/../data/__0_user_study/us_mc_all_in_one')\n",
"print('Change working directory to ', os.getcwd())\n",
"torch.autograd.set_grad_enabled(False)\n",
"\n",
"fovea_net = load_net(find_file('fovea'))\n",
"periph_net = load_net(find_file('periph'))\n",
"\n",
"# Load Dataset\n",
"views = load_views('for_teaser.json')\n",
"print('Dataset loaded.')\n",
"\n",
"print('views:', views.size())\n",
"#print('ref views:', ref_dataset.samples)\n",
"\n",
"fov_list = [20, 45, 110]\n",
"res_list = [(128, 128), (256, 256), (256, 230)] # (192,256)]\n",
"res_full = (1600, 1440)\n",
"gen = GenFinal(fov_list, res_list, res_full, fovea_net, periph_net,\n",
" device=device.GetDevice())\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"center_x_list = [-100, -100, 10, 10, 100]\n",
"for view_idx in range(1):\n",
" center = (center_x_list[view_idx], -center_x_list[view_idx] * 1600 / 1440)\n",
" test_view = views.get(view_idx)\n",
" images = gen.gen(center, test_view, True)\n",
" plot_figures(images, center)\n",
"\n",
" #util.CreateDirIfNeed('output/teasers')\n",
" #for key in images:\n",
" # util.WriteImageTensor(\n",
" # images[key], 'output/teasers/view%04d_%s.png' % (view_idx, key))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
},
"orig_nbformat": 2
},
"nbformat": 4,
"nbformat_minor": 2
}
from __future__ import print_function
import sys
import torch
import torchvision
import torch.backends.cudnn as cudnn
import torch.nn as nn
from math import log10
from my.progress_bar import progress_bar
from my import color_mode
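
# SRCNN-style refinement network for rendered views: a small conv stack that
# sharpens images at constant resolution (no upsampling unless the PixelShuffle
# layer below is re-enabled).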
class Net(torch.nn.Module):
    def __init__(self, color, base_filter):
        super(Net, self).__init__()
        self.color = color
        if color == color_mode.GRAY:
            # Single branch for grayscale input
            self.layers = self._make_branch(base_filter)
        else:
            # One independent branch per color channel
            self.net_1 = self._make_branch(base_filter)
            self.net_2 = self._make_branch(base_filter)
            self.net_3 = self._make_branch(base_filter)

    @staticmethod
    def _make_branch(base_filter):
        # 9-1-5 conv stack (SRCNN layout) operating at constant resolution
        return torch.nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=base_filter,
                      kernel_size=9, stride=1, padding=4, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=base_filter, out_channels=base_filter // 2,
                      kernel_size=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=base_filter // 2, out_channels=1,
                      kernel_size=5, stride=1, padding=2, bias=True),
            # nn.PixelShuffle(upscale_factor)
        )

    def forward(self, x):
        if self.color == color_mode.GRAY:
            out = self.layers(x)
        else:
            # Process each channel through its own branch, then re-assemble
            out = torch.cat([
                self.net_1(x[:, 0:1]),
                self.net_2(x[:, 1:2]),
                self.net_3(x[:, 2:3])
            ], dim=1)
        return out

    def weight_init(self, mean, std):
        # Iterate all submodules recursively: self._modules only holds the
        # Sequential containers, which normal_init would silently skip
        for m in self.modules():
            normal_init(m, mean, std)


def normal_init(m, mean, std):
    if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d):
        m.weight.data.normal_(mean, std)
        m.bias.data.zero_()
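

# Usage sketch (hypothetical; color_mode.RGB is assumed to denote the
# three-channel mode alongside GRAY and YCbCr):
#   net = Net(color=color_mode.RGB, base_filter=64)
#   refined = net(torch.rand(4, 3, 256, 256))  # -> (4, 3, 256, 256), same resolution
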
class Solver(object):
    def __init__(self, config, training_loader, testing_loader, writer=None):
        super(Solver, self).__init__()
        self.CUDA = torch.cuda.is_available()
        self.device = torch.device('cuda' if self.CUDA else 'cpu')
        self.model = None
        self.lr = config.lr
        self.nEpochs = config.nEpochs
        self.criterion = None
        self.optimizer = None
        self.scheduler = None
        self.seed = config.seed
        self.upscale_factor = config.upscale_factor
        self.training_loader = training_loader
        self.testing_loader = testing_loader
        self.writer = writer
        self.color = config.color

    def build_model(self):
        self.model = Net(color=self.color, base_filter=64).to(self.device)
        self.model.weight_init(mean=0.0, std=0.01)
        self.criterion = torch.nn.MSELoss()
        torch.manual_seed(self.seed)
        if self.CUDA:
            torch.cuda.manual_seed(self.seed)
            cudnn.benchmark = True
            self.criterion.cuda()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        # Halve the learning rate at epochs 50, 75 and 100
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=[50, 75, 100], gamma=0.5)

    def save_model(self):
        model_out_path = "model_path.pth"
        torch.save(self.model, model_out_path)
        print("Checkpoint saved to {}".format(model_out_path))

    def train(self, epoch, iters, channels=None):
        self.model.train()
        train_loss = 0
        for batch_num, (_, data, target) in enumerate(self.training_loader):
            if channels:
                # Optionally train on a subset of channels (e.g. Y only for YCbCr)
                data = data[..., channels, :, :]
                target = target[..., channels, :, :]
            data = data.to(self.device)
            target = target.to(self.device)
            self.optimizer.zero_grad()
            out = self.model(data)
            loss = self.criterion(out, target)
            train_loss += loss.item()
            loss.backward()
            self.optimizer.step()
            sys.stdout.write('Epoch %d: ' % epoch)
            progress_bar(batch_num, len(self.training_loader),
                         'Loss: %.4f' % (train_loss / (batch_num + 1)))
            if self.writer:
                self.writer.add_scalar("loss", loss, iters)
                if iters % 100 == 0:
                    output_vs_gt = torch.stack([out, target], 1) \
                        .flatten(0, 1).detach()
                    self.writer.add_image(
                        "Output_vs_gt",
                        torchvision.utils.make_grid(output_vs_gt, nrow=2).cpu().numpy(),
                        iters)
            iters += 1
        print(" Average Loss: {:.4f}".format(train_loss / len(self.training_loader)))
        return iters
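

# Minimal driving loop (sketch): `config` only needs lr, nEpochs, seed,
# upscale_factor and color attributes, exactly as the runner's argparse args provide.
#   solver = Solver(config, train_loader, train_loader)
#   solver.build_model()
#   iters = 0
#   for epoch in range(1, config.nEpochs + 1):
#       iters = solver.train(epoch, iters)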
from __future__ import print_function
import argparse
import os
import sys
import torch
import torch.nn.functional as nn_f
from tensorboardX.writer import SummaryWriter
sys.path.append(os.path.abspath(sys.path[0] + '/../'))
__package__ = "deep_view_syn"
# ===========================================================
# Training settings
# ===========================================================
parser = argparse.ArgumentParser(description='PyTorch Super Res Example')
# hyper-parameters
parser.add_argument('--device', type=int, default=3,
                    help='Which CUDA device to use.')
parser.add_argument('--batchSize', type=int, default=1,
                    help='training batch size')
parser.add_argument('--testBatchSize', type=int,
                    default=1, help='testing batch size')
parser.add_argument('--nEpochs', type=int, default=20,
                    help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.01,
                    help='Learning Rate. Default=0.01')
parser.add_argument('--seed', type=int, default=123,
                    help='random seed to use. Default=123')
parser.add_argument('--dataset', type=str, required=True,
                    help='dataset directory')
parser.add_argument('--test', type=str, help='path of model to test')
parser.add_argument('--testOutPatt', type=str, help='test output path pattern')
parser.add_argument('--color', type=str, default='rgb',
                    help='color mode: rgb, gray or ycbcr')
# model configuration
parser.add_argument('--upscale_factor', '-uf', type=int,
                    default=2, help="super resolution upscale factor")
#parser.add_argument('--model', '-m', type=str, default='srgan', help='choose which model is going to use')

args = parser.parse_args()
# Select device
torch.cuda.set_device(args.device)
print("Set CUDA:%d as current device." % torch.cuda.current_device())
from .my import util
from .my import netio
from .my import device
from .my import color_mode
from .refine_net import *
from .data.upsampling import UpsamplingDataset
from .data.loader import FastDataLoader
os.chdir(args.dataset)
print('Change working directory to ' + os.getcwd())
run_dir = 'run/'
args.color = color_mode.from_str(args.color)
def train():
    util.CreateDirIfNeed(run_dir)
    train_set = UpsamplingDataset('.', 'input/out_view_%04d.png',
                                  'gt/view_%04d.png', color=args.color)
    training_data_loader = FastDataLoader(dataset=train_set,
                                          batch_size=args.batchSize,
                                          shuffle=True,
                                          drop_last=False)
    trainer = Solver(args, training_data_loader, training_data_loader,
                     SummaryWriter(run_dir))
    trainer.build_model()
    iters = 0
    for epoch in range(1, args.nEpochs + 1):
        print("\n===> Epoch {} starts:".format(epoch))
        iters = trainer.train(epoch, iters,
                              channels=slice(2, 3) if args.color == color_mode.YCbCr
                              else None)
    netio.SaveNet(run_dir + 'model-epoch_%d.pth' % args.nEpochs, trainer.model)
def test():
    util.CreateDirIfNeed(os.path.dirname(args.testOutPatt))
    train_set = UpsamplingDataset(
        '.', 'input/out_view_%04d.png', None, color=args.color)
    training_data_loader = FastDataLoader(dataset=train_set,
                                          batch_size=args.testBatchSize,
                                          shuffle=False,
                                          drop_last=False)
    trainer = Solver(args, training_data_loader, training_data_loader,
                     SummaryWriter(run_dir))
    trainer.build_model()
    netio.LoadNet(args.test, trainer.model)
    for idx, input, _ in training_data_loader:
        if args.color == color_mode.YCbCr:
            # Refine only the Y channel; keep Cb/Cr and convert back to RGB
            output_y = trainer.model(input[:, -1:])
            output_cbcr = input[:, 0:2]
            output = util.ycbcr2rgb(torch.cat([output_cbcr, output_y], -3))
        else:
            output = trainer.model(input)
        util.WriteImageTensor(output, args.testOutPatt % idx)
def main():
    if args.test:
        test()
    else:
        train()


if __name__ == '__main__':
    main()
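
# Example invocations (sketch; the script and dataset names here are hypothetical):
#   python upsampling_run.py --dataset data/us_gas_all_in_one --color ycbcr
#   python upsampling_run.py --dataset data/us_gas_all_in_one \
#       --test run/model-epoch_20.pth --testOutPatt output/view_%04d.png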
 import math
 import sys
 import os
 import argparse
 import torch
 import torch.optim
 import torchvision
 import numpy as np
 from tensorboardX import SummaryWriter
 from torch import nn
@@ -21,6 +18,8 @@ parser.add_argument('--config-id', type=str,
                     help='Net config id')
 parser.add_argument('--dataset', type=str, required=True,
                     help='Dataset description file')
+parser.add_argument('--cont', type=str,
+                    help='Continue train on model file')
 parser.add_argument('--epochs', type=int,
                     help='Max epochs for train')
 parser.add_argument('--test', type=str,
@@ -37,6 +36,8 @@ parser.add_argument('--output-video', action='store_true',
                     help='Output test results as video')
 parser.add_argument('--perf', action='store_true',
                     help='Test performance')
+parser.add_argument('--simple-log', action='store_true', help='Simple log')

 opt = parser.parse_args()

 if opt.res:
     opt.res = tuple(int(s) for s in opt.res.split('x'))
@@ -49,11 +50,10 @@ from .my import netio
 from .my import util
 from .my import device
 from .my import loss
+from .my.progress_bar import progress_bar
 from .my.simple_perf import SimplePerf
 from .data.spherical_view_syn import *
 from .data.loader import FastDataLoader
-from .msl_net import MslNet
-from .spher_net import SpherNet
 from .configs.spherical_view_syn import SphericalViewSynConfig
@@ -68,8 +68,8 @@ EVAL_TIME_PERFORMANCE = False
 # Train
 BATCH_SIZE = 4096
-EPOCH_RANGE = range(0, opt.epochs if opt.epochs else 500)
-SAVE_INTERVAL = 50
+EPOCH_RANGE = range(0, opt.epochs if opt.epochs else 300)
+SAVE_INTERVAL = 10

 # Test
 TEST_BATCH_SIZE = 1
@@ -92,13 +92,20 @@ if opt.test:
     output_dir = run_dir + 'output/%s/%s_s%d/' % \
         (test_net_name, data_desc_name, opt.test_samples)
 else:
-    if opt.config:
-        config.load(opt.config)
-    if opt.config_id:
-        config.from_id(opt.config_id)
     data_dir = os.path.dirname(data_desc_path) + '/'
-    run_id = config.to_id()
-    run_dir = data_dir + run_id + '/'
+    if opt.cont:
+        train_net_name = os.path.splitext(os.path.basename(opt.cont))[0]
+        EPOCH_RANGE = range(int(train_net_name[12:]), EPOCH_RANGE.stop)
+        run_dir = os.path.dirname(opt.cont) + '/'
+        run_id = os.path.basename(run_dir[:-1])
+        config.from_id(run_id)
+    else:
+        if opt.config:
+            config.load(opt.config)
+        if opt.config_id:
+            config.from_id(opt.config_id)
+        run_id = config.to_id()
+        run_dir = data_dir + run_id + '/'
     log_dir = run_dir + 'log/'
     output_dir = None
     train_mode = True
@@ -113,26 +120,6 @@ if not train_mode:
 config.SAMPLE_PARAMS['perturb_sample'] = \
     config.SAMPLE_PARAMS['perturb_sample'] and train_mode

-NETS = {
-    'msl': lambda: MslNet(
-        fc_params=config.FC_PARAMS,
-        sampler_params=(config.SAMPLE_PARAMS.update(
-            {'spherical': True}), config.SAMPLE_PARAMS)[1],
-        color=config.COLOR,
-        encode_to_dim=config.N_ENCODE_DIM),
-    'nerf': lambda: MslNet(
-        fc_params=config.FC_PARAMS,
-        sampler_params=(config.SAMPLE_PARAMS.update(
-            {'spherical': False}), config.SAMPLE_PARAMS)[1],
-        color=config.COLOR,
-        encode_to_dim=config.N_ENCODE_DIM),
-    'spher': lambda: SpherNet(
-        fc_params=config.FC_PARAMS,
-        color=config.COLOR,
-        translation=not ROT_ONLY,
-        encode_to_dim=config.N_ENCODE_DIM)
-}
 LOSSES = {
     'mse': lambda: nn.MSELoss(),
     'mse_grad': lambda: loss.CombinedLoss(
@@ -140,7 +127,7 @@ LOSSES = {
 }

 # Initialize model
-model = NETS[config.NET_TYPE]().to(device.GetDevice())
+model = config.create_net().to(device.GetDevice())
 loss_mse = nn.MSELoss().to(device.GetDevice())
 loss_grad = loss.GradLoss().to(device.GetDevice())
@@ -148,6 +135,10 @@ loss_grad = loss.GradLoss().to(device.GetDevice())

 def train_loop(data_loader, optimizer, loss, perf, writer, epoch, iters):
     sub_iters = 0
     iters_in_epoch = len(data_loader)
+    loss_min = 1e5
+    loss_max = 0
+    loss_avg = 0
+    perf = SimplePerf(opt.simple_log)
     for _, gt, rays_o, rays_d in data_loader:
         patch = (len(gt.size()) == 4)
         gt = gt.to(device.GetDevice())
@@ -175,25 +166,26 @@ def train_loop(data_loader, optimizer, loss, perf, writer, epoch, iters):
         optimizer.step()
         perf.Checkpoint("Update")

-        if patch:
-            print("Epoch: %d, Iter: %d(%d/%d), Loss MSE: %f, Loss Grad: %f" %
-                  (epoch, iters, sub_iters, iters_in_epoch,
-                   loss_mse_value.item(), loss_grad_value.item()))
-        else:
-            print("Epoch: %d, Iter: %d(%d/%d), Loss MSE: %f" %
-                  (epoch, iters, sub_iters, iters_in_epoch, loss_mse_value.item()))
+        loss_value = loss_value.item()
+        loss_min = min(loss_min, loss_value)
+        loss_max = max(loss_max, loss_value)
+        loss_avg = (loss_avg * sub_iters + loss_value) / (sub_iters + 1)
+        if not opt.simple_log:
+            progress_bar(sub_iters, iters_in_epoch,
+                         "Loss: %.2e (%.2e/%.2e/%.2e)" % (loss_value, loss_min, loss_avg, loss_max),
+                         "Epoch {:<3d}".format(epoch))

         # Write tensorboard logs.
-        writer.add_scalar("loss mse", loss_mse_value, iters)
-        if patch:
-            writer.add_scalar("loss grad", loss_grad_value, iters)
-        if patch and iters % 100 == 0:
-            output_vs_gt = torch.cat([out[0:4], gt[0:4]], 0).detach()
-            writer.add_image("Output_vs_gt", torchvision.utils.make_grid(
-                output_vs_gt, nrow=4).cpu().numpy(), iters)
+        writer.add_scalar("loss mse", loss_value, iters)
+        # if patch and iters % 100 == 0:
+        #     output_vs_gt = torch.cat([out[0:4], gt[0:4]], 0).detach()
+        #     writer.add_image("Output_vs_gt", torchvision.utils.make_grid(
+        #         output_vs_gt, nrow=4).cpu().numpy(), iters)
         iters += 1
         sub_iters += 1
+    if opt.simple_log:
+        perf.Checkpoint('Epoch %d (%.2e/%.2e/%.2e)' % (epoch, loss_min, loss_avg, loss_max), True)
     return iters
@@ -211,13 +203,18 @@ def train():
                                       pin_memory=True)

     # 2. Initialize components
-    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
+    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=config.OPT_DECAY)
     loss = 0 # LOSSES[config.LOSS]().to(device.GetDevice())

     if EPOCH_RANGE.start > 0:
         iters = netio.LoadNet('%smodel-epoch_%d.pth' % (run_dir, EPOCH_RANGE.start),
                               model, solver=optimizer)
     else:
+        if config.NORMALIZE:
+            for _, _, rays_o, rays_d in train_data_loader:
+                model.update_normalize_range(rays_o, rays_d)
+            print('Depth/diopter range: ', model.depth_range)
+            print('Angle range: ', model.angle_range / 3.14159 * 180)
         iters = 0
     epoch = None
@@ -228,12 +225,10 @@ def train():
     util.CreateDirIfNeed(log_dir)

     perf = SimplePerf(EVAL_TIME_PERFORMANCE, start=True)
-    perf_epoch = SimplePerf(True, start=True)
     writer = SummaryWriter(log_dir)

     print("Begin training...")
     for epoch in EPOCH_RANGE:
-        perf_epoch.Checkpoint("Epoch")
         iters = train_loop(train_data_loader, optimizer, loss,
                            perf, writer, epoch, iters)
         # Save checkpoint
......
@@ -5,8 +5,8 @@ import torch
 import torch.optim
 from torch import onnx

-sys.path.append(os.path.abspath(sys.path[0] + '/../'))
-__package__ = "deep_view_syn"
+sys.path.append(os.path.abspath(sys.path[0] + '/../../'))
+__package__ = "deep_view_syn.tools"

 parser = argparse.ArgumentParser()
 parser.add_argument('--device', type=int, default=0,
@@ -23,11 +23,10 @@ opt = parser.parse_args()
 torch.cuda.set_device(opt.device)
 print("Set CUDA:%d as current device." % torch.cuda.current_device())

-from .msl_net import MslNet
-from .configs.spherical_view_syn import SphericalViewSynConfig
-from .my import device
-from .my import netio
-from .my import util
+from ..configs.spherical_view_syn import SphericalViewSynConfig
+from ..my import device
+from ..my import netio
+from ..my import util

 dir_path, model_file = os.path.split(opt.model)
 batch_size = eval(opt.batch_size)
@@ -42,8 +41,7 @@ def load_net(path):
     config.SAMPLE_PARAMS['perturb_sample'] = False
     config.SAMPLE_PARAMS['n_samples'] = 4
     config.print()
-    net = MslNet(config.FC_PARAMS, config.SAMPLE_PARAMS, config.GRAY,
-                 config.N_ENCODE_DIM, export_mode=True).to(device.GetDevice())
+    net = config.create_net().to(device.GetDevice())
     netio.LoadNet(path, net)
     return net, name
......
 import sys
 import os

-sys.path.append(os.path.abspath(sys.path[0] + '/../'))
-__package__ = "deep_view_syn"
+sys.path.append(os.path.abspath(sys.path[0] + '/../../'))
+__package__ = "deep_view_syn.tools"

 import argparse
 from PIL import Image

-from .my import util
+from ..my import util


 def batch_scale(src, target, size):
......
from typing import Tuple
import torch
import torch.nn as nn

from .my import net_modules
from .my import util
from .my import device


class UpsamplingNet(nn.Module):
    def __init__(self, fc_params, sampler_params, gray=False,
                 encode_to_dim: int = 0):
        """
        Initialize a multi-sphere-layer upsampling net

        :param fc_params: parameters for full-connection network
        :param sampler_params: parameters for sampler
        :param gray: is grayscale mode
        :param encode_to_dim: encode input to number of dimensions
        """
        super().__init__()
        self.in_chns = 3
        self.input_encoder = net_modules.InputEncoder.Get(
            encode_to_dim, self.in_chns)
        fc_params['in_chns'] = self.input_encoder.out_dim
        fc_params['out_chns'] = 2 if gray else 4
        # Sampler and Rendering are expected to be the same components the
        # MSL net uses for ray sampling and volume rendering
        self.sampler = Sampler(**sampler_params)
        self.net = net_modules.FcNet(**fc_params)
        self.rendering = Rendering()

    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor) -> torch.Tensor:
        """
        rays -> colors

        :param rays_o ```Tensor(B, ..., 3)```: rays' origin
        :param rays_d ```Tensor(B, ..., 3)```: rays' direction
        :return: Tensor(B, 1|3, ...), inferred images/pixels
        """
        p = rays_o.view(-1, 3)
        v = rays_d.view(-1, 3)
        coords, depths = self.sampler(p, v)
        encoded = self.input_encoder(coords)
        color_map = self.rendering(self.net(encoded), depths)
        # Unflatten according to input shape
        out_shape = list(rays_d.size())
        out_shape[-1] = -1
        return color_map.view(out_shape).movedim(-1, 1)
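
# Usage sketch (assumes Sampler/Rendering and net_modules.FcNet behave as in the
# MSL net; the parameter dicts mirror SphericalViewSynConfig's FC_PARAMS/SAMPLE_PARAMS):
#   net = UpsamplingNet(fc_params={'nf': 64, 'n_layers': 4, 'skips': []},
#                       sampler_params={'spherical': True, 'depth_range': (1.0, 50.0),
#                                       'n_samples': 16, 'perturb_sample': False,
#                                       'lindisp': True, 'inverse_r': True}).to(device.GetDevice())
#   colors = net(rays_o, rays_d)  # rays_*: (B, ..., 3) -> colors: (B, 1|3, ...)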