Commit 3554ba52 authored by Nianchen Deng

sync

parent f7038e26
@@ -5,31 +5,41 @@ import argparse
 import torch
 import torch.optim
 import torchvision
+import numpy as np
 from tensorboardX import SummaryWriter
 from torch import nn
 
 sys.path.append(os.path.abspath(sys.path[0] + '/../'))
-__package__ = "deeplightfield"
+__package__ = "deep_view_syn"
 
 parser = argparse.ArgumentParser()
 parser.add_argument('--device', type=int, default=3,
                     help='Which CUDA device to use.')
 parser.add_argument('--config', type=str,
                     help='Net config files')
+parser.add_argument('--config-id', type=str,
+                    help='Net config id')
 parser.add_argument('--dataset', type=str, required=True,
                     help='Dataset description file')
+parser.add_argument('--epochs', type=int,
+                    help='Max epochs for train')
 parser.add_argument('--test', type=str,
                     help='Test net file')
 parser.add_argument('--test-samples', type=int,
                     help='Samples used for test')
+parser.add_argument('--res', type=str,
+                    help='Resolution')
 parser.add_argument('--output-gt', action='store_true',
                     help='Output ground truth images if exist')
 parser.add_argument('--output-alongside', action='store_true',
                     help='Output generated image alongside ground truth image')
 parser.add_argument('--output-video', action='store_true',
                     help='Output test results as video')
+parser.add_argument('--perf', action='store_true',
+                    help='Test performance')
 opt = parser.parse_args()
+if opt.res:
+    opt.res = tuple(int(s) for s in opt.res.split('x'))
 
 # Select device
 torch.cuda.set_device(opt.device)
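Note: the new --res flag takes a 'WIDTHxHEIGHT' string and the expression above turns it into an integer tuple. A quick illustration with a made-up value:

res_str = '1440x1600'  # hypothetical command-line value
res = tuple(int(s) for s in res_str.split('x'))
print(res)             # (1440, 1600)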
@@ -58,8 +68,8 @@ EVAL_TIME_PERFORMANCE = False
 # Train
 BATCH_SIZE = 4096
-EPOCH_RANGE = range(0, 500)
+EPOCH_RANGE = range(0, opt.epochs if opt.epochs else 500)
-SAVE_INTERVAL = 20
+SAVE_INTERVAL = 50
 
 # Test
 TEST_BATCH_SIZE = 1
@@ -67,13 +77,14 @@ TEST_MAX_RAYS = 32768
 # Paths
 data_desc_path = opt.dataset
-data_desc_name = os.path.split(data_desc_path)[1]
+data_desc_name = os.path.splitext(os.path.basename(data_desc_path))[0]
 
 if opt.test:
     test_net_path = opt.test
     test_net_name = os.path.splitext(os.path.basename(test_net_path))[0]
     run_dir = os.path.dirname(test_net_path) + '/'
     run_id = os.path.basename(run_dir[:-1])
-    output_dir = run_dir + 'output/%s/%s/' % (test_net_name, data_desc_name)
+    output_dir = run_dir + 'output/%s/%s%s/' % (test_net_name, data_desc_name,
+                                                '_%dx%d' % (opt.res[0], opt.res[1]) if opt.res else '')
     config.from_id(run_id)
     train_mode = False
     if opt.test_samples:
@@ -83,6 +94,8 @@ if opt.test:
 else:
     if opt.config:
         config.load(opt.config)
+    if opt.config_id:
+        config.from_id(opt.config_id)
     data_dir = os.path.dirname(data_desc_path) + '/'
     run_id = config.to_id()
     run_dir = data_dir + run_id + '/'
@@ -105,17 +118,17 @@ NETS = {
         fc_params=config.FC_PARAMS,
         sampler_params=(config.SAMPLE_PARAMS.update(
            {'spherical': True}), config.SAMPLE_PARAMS)[1],
-        gray=config.GRAY,
+        color=config.COLOR,
         encode_to_dim=config.N_ENCODE_DIM),
     'nerf': lambda: MslNet(
         fc_params=config.FC_PARAMS,
         sampler_params=(config.SAMPLE_PARAMS.update(
            {'spherical': False}), config.SAMPLE_PARAMS)[1],
-        gray=config.GRAY,
+        color=config.COLOR,
         encode_to_dim=config.N_ENCODE_DIM),
     'spher': lambda: SpherNet(
         fc_params=config.FC_PARAMS,
-        gray=config.GRAY,
+        color=config.COLOR,
         translation=not ROT_ONLY,
         encode_to_dim=config.N_ENCODE_DIM)
 }
@@ -146,7 +159,11 @@ def train_loop(data_loader, optimizer, loss, perf, writer, epoch, iters):
     perf.Checkpoint("Forward")
 
     optimizer.zero_grad()
-    loss_mse_value = loss_mse(out, gt)
+    if config.COLOR == color_mode.YCbCr:
+        loss_mse_value = 0.3 * loss_mse(out[..., 0:2], gt[..., 0:2]) + \
+            0.7 * loss_mse(out[..., 2], gt[..., 2])
+    else:
+        loss_mse_value = loss_mse(out, gt)
     loss_grad_value = loss_grad(out, gt) if patch else None
     loss_value = loss_mse_value  # + 0.5 * loss_grad_value if patch \
     # else loss_mse_value
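Note: the loss above branches on config.COLOR using constants from a color_mode helper module that is not part of this diff. The following is only a guess at its shape, inferred from the calls made in this commit (GRAY/RGB/YCbCr constants plus a from_str parser used later by the upsampling script):

# Hypothetical sketch of the color_mode helpers referenced in this commit;
# the real module is not shown here.
GRAY, RGB, YCbCr = 0, 1, 2

def from_str(name):
    """Map a command-line string such as 'rgb' or 'ycbcr' to a constant."""
    return {'gray': GRAY, 'rgb': RGB, 'ycbcr': YCbCr}[name.lower()]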
@@ -183,7 +200,8 @@ def train_loop(data_loader, optimizer, loss, perf, writer, epoch, iters):
 def train():
     # 1. Initialize data loader
     print("Load dataset: " + data_desc_path)
-    train_dataset = SphericalViewSynDataset(data_desc_path, gray=config.GRAY)
+    train_dataset = SphericalViewSynDataset(
+        data_desc_path, color=config.COLOR, res=opt.res)
     train_dataset.set_patch_size(1)
     train_data_loader = FastDataLoader(
         dataset=train_dataset,
@@ -194,7 +212,7 @@ def train():
     # 2. Initialize components
     optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
-    loss = 0#LOSSES[config.LOSS]().to(device.GetDevice())
+    loss = 0  # LOSSES[config.LOSS]().to(device.GetDevice())
     if EPOCH_RANGE.start > 0:
         iters = netio.LoadNet('%smodel-epoch_%d.pth' % (run_dir, EPOCH_RANGE.start),
@@ -223,15 +241,80 @@ def train():
         netio.SaveNet('%smodel-epoch_%d.pth' % (run_dir, epoch + 1), model,
                       solver=optimizer, iters=iters)
     print("Train finished")
+    netio.SaveNet('%smodel-epoch_%d.pth' % (run_dir, epoch + 1), model,
+                  solver=optimizer, iters=iters)
+
+
+def perf():
+    with torch.no_grad():
+        # 1. Load dataset
+        print("Load dataset: " + data_desc_path)
+        test_dataset = SphericalViewSynDataset(data_desc_path,
+                                               load_images=True,
+                                               color=config.COLOR, res=opt.res)
+        test_data_loader = FastDataLoader(
+            dataset=test_dataset,
+            batch_size=1,
+            shuffle=False,
+            drop_last=False,
+            pin_memory=True)
+
+        # 2. Load trained model
+        netio.LoadNet(test_net_path, model)
+
+        # 3. Test on dataset
+        print("Begin perf, batch size is %d" % TEST_BATCH_SIZE)
+        perf = SimplePerf(True, start=True)
+        loss = nn.MSELoss()
+        i = 0
+        n = test_dataset.n_views
+        chns = 1 if config.COLOR == color_mode.GRAY else 3
+        out_view_images = torch.empty(n, chns, test_dataset.view_res[0],
+                                      test_dataset.view_res[1],
+                                      device=device.GetDevice())
+        perf_times = torch.empty(n)
+        perf_errors = torch.empty(n)
+        for view_idxs, gt, rays_o, rays_d in test_data_loader:
+            perf.Checkpoint("%d - Load" % i)
+            rays_o = rays_o.to(device.GetDevice()).view(-1, 3)
+            rays_d = rays_d.to(device.GetDevice()).view(-1, 3)
+            n_rays = rays_o.size(0)
+            chunk_size = min(n_rays, TEST_MAX_RAYS)
+            out_pixels = torch.empty(n_rays, chns, device=device.GetDevice())
+            for offset in range(0, n_rays, chunk_size):
+                idx = slice(offset, offset + chunk_size)
+                out_pixels[idx] = model(rays_o[idx], rays_d[idx])
+            if config.COLOR == color_mode.YCbCr:
+                out_pixels = util.ycbcr2rgb(out_pixels)
+            out_view_images[view_idxs] = out_pixels.view(
+                TEST_BATCH_SIZE, test_dataset.view_res[0],
+                test_dataset.view_res[1], -1).permute(0, 3, 1, 2)
+            perf_times[view_idxs] = perf.Checkpoint("%d - Infer" % i)
+            if config.COLOR == color_mode.YCbCr:
+                gt = util.ycbcr2rgb(gt)
+            error = loss(out_view_images[view_idxs], gt).item()
+            print("%d - Error: %f" % (i, error))
+            perf_errors[view_idxs] = error
+            i += 1
+
+        # 4. Save results
+        perf_mean_time = torch.mean(perf_times).item()
+        perf_mean_error = torch.mean(perf_errors).item()
+        with open(run_dir + 'perf_%s_%s_%.1fms_%.2e.txt' % (test_net_name, data_desc_name, perf_mean_time, perf_mean_error), 'w') as fp:
+            fp.write('View, Time, Error\n')
+            fp.writelines(['%d, %f, %f\n' % (
+                i, perf_times[i].item(), perf_errors[i].item()) for i in range(n)])
 
 
 def test():
     with torch.no_grad():
-        # 1. Load train dataset
+        # 1. Load dataset
         print("Load dataset: " + data_desc_path)
         test_dataset = SphericalViewSynDataset(data_desc_path,
                                                load_images=opt.output_gt or opt.output_alongside,
-                                               gray=config.GRAY)
+                                               color=config.COLOR,
+                                               res=opt.res)
         test_data_loader = FastDataLoader(
             dataset=test_dataset,
             batch_size=1,
@@ -242,14 +325,14 @@ def test():
         # 2. Load trained model
         netio.LoadNet(test_net_path, model)
 
-        # 3. Test on train dataset
+        # 3. Test on dataset
-        print("Begin test on train dataset, batch size is %d" % TEST_BATCH_SIZE)
+        print("Begin test, batch size is %d" % TEST_BATCH_SIZE)
         util.CreateDirIfNeed(output_dir)
 
         perf = SimplePerf(True, start=True)
         i = 0
         n = test_dataset.n_views
-        chns = 1 if config.GRAY else 3
+        chns = 1 if config.COLOR == color_mode.GRAY else 3
         out_view_images = torch.empty(n, chns, test_dataset.view_res[0],
                                       test_dataset.view_res[1],
                                       device=device.GetDevice())
@@ -263,6 +346,8 @@ def test():
             for offset in range(0, n_rays, chunk_size):
                 idx = slice(offset, offset + chunk_size)
                 out_pixels[idx] = model(rays_o[idx], rays_d[idx])
+            if config.COLOR == color_mode.YCbCr:
+                out_pixels = util.ycbcr2rgb(out_pixels)
             out_view_images[view_idxs] = out_pixels.view(
                 TEST_BATCH_SIZE, test_dataset.view_res[0],
                 test_dataset.view_res[1], -1).permute(0, 3, 1, 2)
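Note: both perf() and test() now call util.ycbcr2rgb to bring network output back to RGB before images are compared or written out; that helper is not part of this diff. A rough sketch for the channels-last case, assuming the (Cb, Cr, Y) ordering implied by the out[..., 0:2] / out[..., 2] slicing and full-range BT.601 coefficients:

import torch

def ycbcr2rgb(x):
    # Assumed layout: x[..., 0] = Cb, x[..., 1] = Cr, x[..., 2] = Y, all in [0, 1].
    cb, cr, y = x[..., 0] - 0.5, x[..., 1] - 0.5, x[..., 2]
    r = y + 1.402 * cr
    g = y - 0.344136 * cb - 0.714136 * cr
    b = y + 1.772 * cb
    return torch.stack([r, g, b], dim=-1).clamp(0, 1)

The project's real helper presumably also handles channels-first tensors, since the upsampling script below passes it an NCHW batch.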
@@ -297,5 +382,7 @@ def test():
 if __name__ == "__main__":
     if train_mode:
         train()
+    elif opt.perf:
+        perf()
     else:
         test()
@@ -4,11 +4,12 @@ import argparse
 import os
 import sys
 import torch
+import torch.nn.functional as nn_f
 from torch.utils.data import DataLoader
 from tensorboardX.writer import SummaryWriter
 
 sys.path.append(os.path.abspath(sys.path[0] + '/../'))
-__package__ = "deeplightfield"
+__package__ = "deep_view_syn"
 
 # ===========================================================
 # Training settings
@@ -31,6 +32,8 @@ parser.add_argument('--dataset', type=str, required=True,
                     help='dataset directory')
 parser.add_argument('--test', type=str, help='path of model to test')
 parser.add_argument('--testOutPatt', type=str, help='test output path pattern')
+parser.add_argument('--color', type=str, default='rgb',
+                    help='color')
 
 # model configuration
 parser.add_argument('--upscale_factor', '-uf', type=int,
@@ -46,51 +49,57 @@ print("Set CUDA:%d as current device." % torch.cuda.current_device())
 from .my import util
 from .my import netio
 from .my import device
-from .SRGAN.solver import SRGANTrainer as Solver
+from .my import color_mode
+# from .upsampling.SubPixelCNN.solver import SubPixelTrainer as Solver
+from .upsampling.SRCNN.solver import SRCNNTrainer as Solver
 from .data.upsampling import UpsamplingDataset
 from .data.loader import FastDataLoader
 
 os.chdir(args.dataset)
 print('Change working directory to ' + os.getcwd())
 run_dir = 'run/'
+args.color = color_mode.from_str(args.color)
 
 
 def train():
     util.CreateDirIfNeed(run_dir)
-    train_set = UpsamplingDataset('.', 'out_view_%04d.png',
-                                  'gt_view_%04d.png', gray=True)
+    train_set = UpsamplingDataset('.', 'input/out_view_%04d.png',
+                                  'gt/view_%04d.png', color=args.color)
     training_data_loader = FastDataLoader(dataset=train_set,
                                           batch_size=args.batchSize,
                                           shuffle=True,
                                           drop_last=False)
     trainer = Solver(args, training_data_loader, training_data_loader,
                      SummaryWriter(run_dir))
-    trainer.build_model()
+    trainer.build_model(3 if args.color == color_mode.RGB else 1)
-    # ===
-    for epoch in range(1, 20 + 1):
-        trainer.pretrain()
-        print("{}/{} pretrained".format(epoch, trainer.epoch_pretrain))
-    # ===
     iters = 0
     for epoch in range(1, args.nEpochs + 1):
         print("\n===> Epoch {} starts:".format(epoch))
-        iters = trainer.train(epoch, iters)
+        iters = trainer.train(epoch, iters,
+                              channels=slice(2, 3) if args.color == color_mode.YCbCr
+                              else None)
-    netio.SaveNet(run_dir + 'model-epoch_%d.pth' % args.nEpochs, trainer.netG)
+    netio.SaveNet(run_dir + 'model-epoch_%d.pth' % args.nEpochs, trainer.model)
 
 
 def test():
     util.CreateDirIfNeed(os.path.dirname(args.testOutPatt))
-    train_set = UpsamplingDataset('.', 'out_view_%04d.png', None, gray=True)
+    train_set = UpsamplingDataset(
+        '.', 'input/out_view_%04d.png', None, color=args.color)
     training_data_loader = FastDataLoader(dataset=train_set,
                                           batch_size=args.testBatchSize,
                                           shuffle=False,
                                           drop_last=False)
     trainer = Solver(args, training_data_loader, training_data_loader,
                      SummaryWriter(run_dir))
-    trainer.build_model()
+    trainer.build_model(3 if args.color == color_mode.RGB else 1)
-    netio.LoadNet(args.test, trainer.netG)
+    netio.LoadNet(args.test, trainer.model)
     for idx, input, _ in training_data_loader:
-        output = trainer.netG(input)
+        if args.color == color_mode.YCbCr:
+            output_y = trainer.model(input[:, -1:])
+            output_cbcr = nn_f.upsample(input[:, 0:2], scale_factor=2)
+            output = util.ycbcr2rgb(torch.cat([output_cbcr, output_y], -3))
+        else:
+            output = trainer.model(input)
         util.WriteImageTensor(output, args.testOutPatt % idx)
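Note: in the YCbCr branch of test() above, only the luma plane goes through the super-resolution network while the chroma planes are upsampled bilinearly, since most perceptible detail lives in luma. A condensed sketch of that recombination, assuming a (Cb, Cr, Y) channel order and a hypothetical one-channel model that upscales by the same factor:

import torch
import torch.nn.functional as F

def sr_ycbcr(model, lr, scale=2):
    y_hr = model(lr[:, -1:])                      # super-resolve the Y plane only
    cbcr_hr = F.interpolate(lr[:, 0:2], scale_factor=scale,
                            mode='bilinear', align_corners=False)  # cheap chroma upsample
    return torch.cat([cbcr_hr, y_hr], dim=1)      # back to (Cb, Cr, Y) layout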
@@ -8,7 +8,7 @@ import torch.backends.cudnn as cudnn
 import torchvision
 
 from .model import Net
-from ..my.progress_bar import progress_bar
+from my.progress_bar import progress_bar
 
 
 class SRCNNTrainer(object):
@@ -28,8 +28,8 @@ class SRCNNTrainer(object):
         self.testing_loader = testing_loader
         self.writer = writer
 
-    def build_model(self):
-        self.model = Net(num_channels=1, base_filter=64, upscale_factor=self.upscale_factor).to(self.device)
+    def build_model(self, num_channels):
+        self.model = Net(num_channels=num_channels, base_filter=64, upscale_factor=self.upscale_factor).to(self.device)
         self.model.weight_init(mean=0.0, std=0.01)
         self.criterion = torch.nn.MSELoss()
         torch.manual_seed(self.seed)
@@ -47,11 +47,15 @@ class SRCNNTrainer(object):
         torch.save(self.model, model_out_path)
         print("Checkpoint saved to {}".format(model_out_path))
 
-    def train(self, epoch, iters):
+    def train(self, epoch, iters, channels=None):
         self.model.train()
         train_loss = 0
         for batch_num, (_, data, target) in enumerate(self.training_loader):
-            data, target = data.to(self.device), target.to(self.device)
+            if channels:
+                data = data[..., channels, :, :]
+                target = target[..., channels, :, :]
+            data = data.to(self.device)
+            target = target.to(self.device)
             self.optimizer.zero_grad()
             out = self.model(data)
             loss = self.criterion(out, target)
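Note: the new channels argument lets the caller train on a subset of planes; with channels=slice(2, 3) and the (Cb, Cr, Y) layout assumed above, only the luma plane reaches the network. A small illustration of the indexing on dummy data:

import torch

batch = torch.rand(4, 3, 32, 32)        # N, C, H, W (made-up batch)
luma = batch[..., slice(2, 3), :, :]    # keep channel 2 only
print(luma.shape)                       # torch.Size([4, 1, 32, 32])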
@@ -8,7 +8,7 @@ import torch.backends.cudnn as cudnn
 import torchvision
 
 from .model import Net
-from ..my.progress_bar import progress_bar
+from my.progress_bar import progress_bar
 
 
 class SubPixelTrainer(object):
@@ -28,7 +28,9 @@ class SubPixelTrainer(object):
         self.testing_loader = testing_loader
         self.writer = writer
 
-    def build_model(self):
+    def build_model(self, num_channels):
+        if num_channels != 1:
+            raise ValueError('num_channels must be 1')
         self.model = Net(upscale_factor=self.upscale_factor).to(self.device)
         self.criterion = torch.nn.MSELoss()
         torch.manual_seed(self.seed)
@@ -39,17 +41,21 @@ class SubPixelTrainer(object):
             self.criterion.cuda()
 
         self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
-        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[50, 75, 100], gamma=0.5)  # lr decay
+        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
+            self.optimizer, milestones=[50, 75, 100], gamma=0.5)  # lr decay
 
     def save(self):
         model_out_path = "model_path.pth"
         torch.save(self.model, model_out_path)
         print("Checkpoint saved to {}".format(model_out_path))
 
-    def train(self, epoch, iters):
+    def train(self, epoch, iters, channels=None):
         self.model.train()
         train_loss = 0
         for batch_num, (_, data, target) in enumerate(self.training_loader):
+            if channels:
+                data = data[..., channels, :, :]
+                target = target[..., channels, :, :]
             data, target = data.to(self.device), target.to(self.device)
             self.optimizer.zero_grad()
             out = self.model(data)
@@ -58,7 +64,8 @@ class SubPixelTrainer(object):
             loss.backward()
             self.optimizer.step()
             sys.stdout.write('Epoch %d: ' % epoch)
-            progress_bar(batch_num, len(self.training_loader), 'Loss: %.4f' % (train_loss / (batch_num + 1)))
+            progress_bar(batch_num, len(self.training_loader),
+                         'Loss: %.4f' % (train_loss / (batch_num + 1)))
             if self.writer:
                 self.writer.add_scalar("loss", loss, iters)
                 if iters % 100 == 0:
@@ -66,11 +73,13 @@ class SubPixelTrainer(object):
                         .flatten(0, 1).detach()
                     self.writer.add_image(
                         "Output_vs_gt",
-                        torchvision.utils.make_grid(output_vs_gt, nrow=2).cpu().numpy(),
+                        torchvision.utils.make_grid(
+                            output_vs_gt, nrow=2).cpu().numpy(),
                         iters)
             iters += 1
 
-        print(" Average Loss: {:.4f}".format(train_loss / len(self.training_loader)))
+        print(" Average Loss: {:.4f}".format(
+            train_loss / len(self.training_loader)))
         return iters
 
     def test(self):
@@ -84,9 +93,11 @@ class SubPixelTrainer(object):
                 mse = self.criterion(prediction, target)
                 psnr = 10 * log10(1 / mse.item())
                 avg_psnr += psnr
-                progress_bar(batch_num, len(self.testing_loader), 'PSNR: %.4f' % (avg_psnr / (batch_num + 1)))
+                progress_bar(batch_num, len(self.testing_loader),
+                             'PSNR: %.4f' % (avg_psnr / (batch_num + 1)))
 
-        print(" Average PSNR: {:.4f} dB".format(avg_psnr / len(self.testing_loader)))
+        print(" Average PSNR: {:.4f} dB".format(
+            avg_psnr / len(self.testing_loader)))
 
     def run(self):
         self.build_model()