Commit 3554ba52 authored by Nianchen Deng's avatar Nianchen Deng
Browse files

sync

parent f7038e26
......@@ -5,31 +5,41 @@ import argparse
import torch
import torch.optim
import torchvision
import numpy as np
from tensorboardX import SummaryWriter
from torch import nn
sys.path.append(os.path.abspath(sys.path[0] + '/../'))
__package__ = "deeplightfield"
__package__ = "deep_view_syn"
parser = argparse.ArgumentParser()
parser.add_argument('--device', type=int, default=3,
help='Which CUDA device to use.')
parser.add_argument('--config', type=str,
help='Net config files')
parser.add_argument('--config-id', type=str,
help='Net config id')
parser.add_argument('--dataset', type=str, required=True,
help='Dataset description file')
parser.add_argument('--epochs', type=int,
help='Max epochs for train')
parser.add_argument('--test', type=str,
help='Test net file')
parser.add_argument('--test-samples', type=int,
help='Samples used for test')
parser.add_argument('--res', type=str,
help='Resolution')
parser.add_argument('--output-gt', action='store_true',
help='Output ground truth images if exist')
parser.add_argument('--output-alongside', action='store_true',
help='Output generated image alongside ground truth image')
parser.add_argument('--output-video', action='store_true',
help='Output test results as video')
parser.add_argument('--perf', action='store_true',
help='Test performance')
opt = parser.parse_args()
if opt.res:
opt.res = tuple(int(s) for s in opt.res.split('x'))
# Select device
torch.cuda.set_device(opt.device)
......@@ -58,8 +68,8 @@ EVAL_TIME_PERFORMANCE = False
# Train
BATCH_SIZE = 4096
EPOCH_RANGE = range(0, 500)
SAVE_INTERVAL = 20
EPOCH_RANGE = range(0, opt.epochs if opt.epochs else 500)
SAVE_INTERVAL = 50
# Test
TEST_BATCH_SIZE = 1
......@@ -67,13 +77,14 @@ TEST_MAX_RAYS = 32768
# Paths
data_desc_path = opt.dataset
data_desc_name = os.path.split(data_desc_path)[1]
data_desc_name = os.path.splitext(os.path.basename(data_desc_path))[0]
if opt.test:
test_net_path = opt.test
test_net_name = os.path.splitext(os.path.basename(test_net_path))[0]
run_dir = os.path.dirname(test_net_path) + '/'
run_id = os.path.basename(run_dir[:-1])
output_dir = run_dir + 'output/%s/%s/' % (test_net_name, data_desc_name)
output_dir = run_dir + 'output/%s/%s%s/' % (test_net_name, data_desc_name,
'_%dx%d' % (opt.res[0], opt.res[1]) if opt.res else '')
config.from_id(run_id)
train_mode = False
if opt.test_samples:
......@@ -83,6 +94,8 @@ if opt.test:
else:
if opt.config:
config.load(opt.config)
if opt.config_id:
config.from_id(opt.config_id)
data_dir = os.path.dirname(data_desc_path) + '/'
run_id = config.to_id()
run_dir = data_dir + run_id + '/'
......@@ -105,17 +118,17 @@ NETS = {
fc_params=config.FC_PARAMS,
sampler_params=(config.SAMPLE_PARAMS.update(
{'spherical': True}), config.SAMPLE_PARAMS)[1],
gray=config.GRAY,
color=config.COLOR,
encode_to_dim=config.N_ENCODE_DIM),
'nerf': lambda: MslNet(
fc_params=config.FC_PARAMS,
sampler_params=(config.SAMPLE_PARAMS.update(
{'spherical': False}), config.SAMPLE_PARAMS)[1],
gray=config.GRAY,
color=config.COLOR,
encode_to_dim=config.N_ENCODE_DIM),
'spher': lambda: SpherNet(
fc_params=config.FC_PARAMS,
gray=config.GRAY,
color=config.COLOR,
translation=not ROT_ONLY,
encode_to_dim=config.N_ENCODE_DIM)
}
......@@ -146,7 +159,11 @@ def train_loop(data_loader, optimizer, loss, perf, writer, epoch, iters):
perf.Checkpoint("Forward")
optimizer.zero_grad()
loss_mse_value = loss_mse(out, gt)
if config.COLOR == color_mode.YCbCr:
loss_mse_value = 0.3 * loss_mse(out[..., 0:2], gt[..., 0:2]) + \
0.7 * loss_mse(out[..., 2], gt[..., 2])
else:
loss_mse_value = loss_mse(out, gt)
loss_grad_value = loss_grad(out, gt) if patch else None
loss_value = loss_mse_value # + 0.5 * loss_grad_value if patch \
# else loss_mse_value
......@@ -183,7 +200,8 @@ def train_loop(data_loader, optimizer, loss, perf, writer, epoch, iters):
def train():
# 1. Initialize data loader
print("Load dataset: " + data_desc_path)
train_dataset = SphericalViewSynDataset(data_desc_path, gray=config.GRAY)
train_dataset = SphericalViewSynDataset(
data_desc_path, color=config.COLOR, res=opt.res)
train_dataset.set_patch_size(1)
train_data_loader = FastDataLoader(
dataset=train_dataset,
......@@ -194,7 +212,7 @@ def train():
# 2. Initialize components
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
loss = 0#LOSSES[config.LOSS]().to(device.GetDevice())
loss = 0 # LOSSES[config.LOSS]().to(device.GetDevice())
if EPOCH_RANGE.start > 0:
iters = netio.LoadNet('%smodel-epoch_%d.pth' % (run_dir, EPOCH_RANGE.start),
......@@ -223,15 +241,80 @@ def train():
netio.SaveNet('%smodel-epoch_%d.pth' % (run_dir, epoch + 1), model,
solver=optimizer, iters=iters)
print("Train finished")
netio.SaveNet('%smodel-epoch_%d.pth' % (run_dir, epoch + 1), model,
solver=optimizer, iters=iters)
def perf():
with torch.no_grad():
# 1. Load dataset
print("Load dataset: " + data_desc_path)
test_dataset = SphericalViewSynDataset(data_desc_path,
load_images=True,
color=config.COLOR, res=opt.res)
test_data_loader = FastDataLoader(
dataset=test_dataset,
batch_size=1,
shuffle=False,
drop_last=False,
pin_memory=True)
# 2. Load trained model
netio.LoadNet(test_net_path, model)
# 3. Test on dataset
print("Begin perf, batch size is %d" % TEST_BATCH_SIZE)
perf = SimplePerf(True, start=True)
loss = nn.MSELoss()
i = 0
n = test_dataset.n_views
chns = 1 if config.COLOR == color_mode.GRAY else 3
out_view_images = torch.empty(n, chns, test_dataset.view_res[0],
test_dataset.view_res[1],
device=device.GetDevice())
perf_times = torch.empty(n)
perf_errors = torch.empty(n)
for view_idxs, gt, rays_o, rays_d in test_data_loader:
perf.Checkpoint("%d - Load" % i)
rays_o = rays_o.to(device.GetDevice()).view(-1, 3)
rays_d = rays_d.to(device.GetDevice()).view(-1, 3)
n_rays = rays_o.size(0)
chunk_size = min(n_rays, TEST_MAX_RAYS)
out_pixels = torch.empty(n_rays, chns, device=device.GetDevice())
for offset in range(0, n_rays, chunk_size):
idx = slice(offset, offset + chunk_size)
out_pixels[idx] = model(rays_o[idx], rays_d[idx])
if config.COLOR == color_mode.YCbCr:
out_pixels = util.ycbcr2rgb(out_pixels)
out_view_images[view_idxs] = out_pixels.view(
TEST_BATCH_SIZE, test_dataset.view_res[0],
test_dataset.view_res[1], -1).permute(0, 3, 1, 2)
perf_times[view_idxs] = perf.Checkpoint("%d - Infer" % i)
if config.COLOR == color_mode.YCbCr:
gt = util.ycbcr2rgb(gt)
error = loss(out_view_images[view_idxs], gt).item()
print("%d - Error: %f" % (i, error))
perf_errors[view_idxs] = error
i += 1
# 4. Save results
perf_mean_time = torch.mean(perf_times).item()
perf_mean_error = torch.mean(perf_errors).item()
with open(run_dir + 'perf_%s_%s_%.1fms_%.2e.txt' % (test_net_name, data_desc_name, perf_mean_time, perf_mean_error), 'w') as fp:
fp.write('View, Time, Error\n')
fp.writelines(['%d, %f, %f\n' % (
i, perf_times[i].item(), perf_errors[i].item()) for i in range(n)])
def test():
with torch.no_grad():
# 1. Load train dataset
# 1. Load dataset
print("Load dataset: " + data_desc_path)
test_dataset = SphericalViewSynDataset(data_desc_path,
load_images=opt.output_gt or opt.output_alongside,
gray=config.GRAY)
color=config.COLOR,
res=opt.res)
test_data_loader = FastDataLoader(
dataset=test_dataset,
batch_size=1,
......@@ -242,14 +325,14 @@ def test():
# 2. Load trained model
netio.LoadNet(test_net_path, model)
# 3. Test on train dataset
print("Begin test on train dataset, batch size is %d" % TEST_BATCH_SIZE)
# 3. Test on dataset
print("Begin test, batch size is %d" % TEST_BATCH_SIZE)
util.CreateDirIfNeed(output_dir)
perf = SimplePerf(True, start=True)
i = 0
n = test_dataset.n_views
chns = 1 if config.GRAY else 3
chns = 1 if config.COLOR == color_mode.GRAY else 3
out_view_images = torch.empty(n, chns, test_dataset.view_res[0],
test_dataset.view_res[1],
device=device.GetDevice())
......@@ -263,6 +346,8 @@ def test():
for offset in range(0, n_rays, chunk_size):
idx = slice(offset, offset + chunk_size)
out_pixels[idx] = model(rays_o[idx], rays_d[idx])
if config.COLOR == color_mode.YCbCr:
out_pixels = util.ycbcr2rgb(out_pixels)
out_view_images[view_idxs] = out_pixels.view(
TEST_BATCH_SIZE, test_dataset.view_res[0],
test_dataset.view_res[1], -1).permute(0, 3, 1, 2)
......@@ -297,5 +382,7 @@ def test():
if __name__ == "__main__":
if train_mode:
train()
elif opt.perf:
perf()
else:
test()
......@@ -4,11 +4,12 @@ import argparse
import os
import sys
import torch
import torch.nn.functional as nn_f
from torch.utils.data import DataLoader
from tensorboardX.writer import SummaryWriter
sys.path.append(os.path.abspath(sys.path[0] + '/../'))
__package__ = "deeplightfield"
__package__ = "deep_view_syn"
# ===========================================================
# Training settings
......@@ -31,6 +32,8 @@ parser.add_argument('--dataset', type=str, required=True,
help='dataset directory')
parser.add_argument('--test', type=str, help='path of model to test')
parser.add_argument('--testOutPatt', type=str, help='test output path pattern')
parser.add_argument('--color', type=str, default='rgb',
help='color')
# model configuration
parser.add_argument('--upscale_factor', '-uf', type=int,
......@@ -46,51 +49,57 @@ print("Set CUDA:%d as current device." % torch.cuda.current_device())
from .my import util
from .my import netio
from .my import device
from .SRGAN.solver import SRGANTrainer as Solver
from .my import color_mode
#from .upsampling.SubPixelCNN.solver import SubPixelTrainer as Solver
from .upsampling.SRCNN.solver import SRCNNTrainer as Solver
from .data.upsampling import UpsamplingDataset
from .data.loader import FastDataLoader
os.chdir(args.dataset)
print('Change working directory to ' + os.getcwd())
run_dir = 'run/'
args.color = color_mode.from_str(args.color)
def train():
util.CreateDirIfNeed(run_dir)
train_set = UpsamplingDataset('.', 'out_view_%04d.png',
'gt_view_%04d.png', gray=True)
train_set = UpsamplingDataset('.', 'input/out_view_%04d.png',
'gt/view_%04d.png', color=args.color)
training_data_loader = FastDataLoader(dataset=train_set,
batch_size=args.batchSize,
shuffle=True,
drop_last=False)
trainer = Solver(args, training_data_loader, training_data_loader,
SummaryWriter(run_dir))
trainer.build_model()
# ===
for epoch in range(1, 20 + 1):
trainer.pretrain()
print("{}/{} pretrained".format(epoch, trainer.epoch_pretrain))
# ===
trainer.build_model(3 if args.color == color_mode.RGB else 1)
iters = 0
for epoch in range(1, args.nEpochs + 1):
print("\n===> Epoch {} starts:".format(epoch))
iters = trainer.train(epoch, iters)
netio.SaveNet(run_dir + 'model-epoch_%d.pth' % args.nEpochs, trainer.netG)
iters = trainer.train(epoch, iters,
channels=slice(2, 3) if args.color == color_mode.YCbCr
else None)
netio.SaveNet(run_dir + 'model-epoch_%d.pth' % args.nEpochs, trainer.model)
def test():
util.CreateDirIfNeed(os.path.dirname(args.testOutPatt))
train_set = UpsamplingDataset('.', 'out_view_%04d.png', None, gray=True)
train_set = UpsamplingDataset(
'.', 'input/out_view_%04d.png', None, color=args.color)
training_data_loader = FastDataLoader(dataset=train_set,
batch_size=args.testBatchSize,
shuffle=False,
drop_last=False)
trainer = Solver(args, training_data_loader, training_data_loader,
SummaryWriter(run_dir))
trainer.build_model()
netio.LoadNet(args.test, trainer.netG)
trainer.build_model(3 if args.color == color_mode.RGB else 1)
netio.LoadNet(args.test, trainer.model)
for idx, input, _ in training_data_loader:
output = trainer.netG(input)
if args.color == color_mode.YCbCr:
output_y = trainer.model(input[:, -1:])
output_cbcr = nn_f.upsample(input[:, 0:2], scale_factor=2)
output = util.ycbcr2rgb(torch.cat([output_cbcr, output_y], -3))
else:
output = trainer.model(input)
util.WriteImageTensor(output, args.testOutPatt % idx)
......
......@@ -8,7 +8,7 @@ import torch.backends.cudnn as cudnn
import torchvision
from .model import Net
from ..my.progress_bar import progress_bar
from my.progress_bar import progress_bar
class SRCNNTrainer(object):
......@@ -28,8 +28,8 @@ class SRCNNTrainer(object):
self.testing_loader = testing_loader
self.writer = writer
def build_model(self):
self.model = Net(num_channels=1, base_filter=64, upscale_factor=self.upscale_factor).to(self.device)
def build_model(self, num_channels):
self.model = Net(num_channels=num_channels, base_filter=64, upscale_factor=self.upscale_factor).to(self.device)
self.model.weight_init(mean=0.0, std=0.01)
self.criterion = torch.nn.MSELoss()
torch.manual_seed(self.seed)
......@@ -47,11 +47,15 @@ class SRCNNTrainer(object):
torch.save(self.model, model_out_path)
print("Checkpoint saved to {}".format(model_out_path))
def train(self, epoch, iters):
def train(self, epoch, iters, channels = None):
self.model.train()
train_loss = 0
for batch_num, (_, data, target) in enumerate(self.training_loader):
data, target = data.to(self.device), target.to(self.device)
if channels:
data = data[..., channels, :, :]
target = target[..., channels, :, :]
data =data.to(self.device)
target = target.to(self.device)
self.optimizer.zero_grad()
out = self.model(data)
loss = self.criterion(out, target)
......
......@@ -8,7 +8,7 @@ import torch.backends.cudnn as cudnn
import torchvision
from .model import Net
from ..my.progress_bar import progress_bar
from my.progress_bar import progress_bar
class SubPixelTrainer(object):
......@@ -28,7 +28,9 @@ class SubPixelTrainer(object):
self.testing_loader = testing_loader
self.writer = writer
def build_model(self):
def build_model(self, num_channels):
if num_channels != 1:
raise ValueError('num_channels must be 1')
self.model = Net(upscale_factor=self.upscale_factor).to(self.device)
self.criterion = torch.nn.MSELoss()
torch.manual_seed(self.seed)
......@@ -39,17 +41,21 @@ class SubPixelTrainer(object):
self.criterion.cuda()
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[50, 75, 100], gamma=0.5) # lr decay
self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
self.optimizer, milestones=[50, 75, 100], gamma=0.5) # lr decay
def save(self):
model_out_path = "model_path.pth"
torch.save(self.model, model_out_path)
print("Checkpoint saved to {}".format(model_out_path))
def train(self, epoch, iters):
def train(self, epoch, iters, channels=None):
self.model.train()
train_loss = 0
for batch_num, (_, data, target) in enumerate(self.training_loader):
if channels:
data = data[..., channels, :, :]
target = target[..., channels, :, :]
data, target = data.to(self.device), target.to(self.device)
self.optimizer.zero_grad()
out = self.model(data)
......@@ -58,7 +64,8 @@ class SubPixelTrainer(object):
loss.backward()
self.optimizer.step()
sys.stdout.write('Epoch %d: ' % epoch)
progress_bar(batch_num, len(self.training_loader), 'Loss: %.4f' % (train_loss / (batch_num + 1)))
progress_bar(batch_num, len(self.training_loader),
'Loss: %.4f' % (train_loss / (batch_num + 1)))
if self.writer:
self.writer.add_scalar("loss", loss, iters)
if iters % 100 == 0:
......@@ -66,11 +73,13 @@ class SubPixelTrainer(object):
.flatten(0, 1).detach()
self.writer.add_image(
"Output_vs_gt",
torchvision.utils.make_grid(output_vs_gt, nrow=2).cpu().numpy(),
torchvision.utils.make_grid(
output_vs_gt, nrow=2).cpu().numpy(),
iters)
iters += 1
print(" Average Loss: {:.4f}".format(train_loss / len(self.training_loader)))
print(" Average Loss: {:.4f}".format(
train_loss / len(self.training_loader)))
return iters
def test(self):
......@@ -84,9 +93,11 @@ class SubPixelTrainer(object):
mse = self.criterion(prediction, target)
psnr = 10 * log10(1 / mse.item())
avg_psnr += psnr
progress_bar(batch_num, len(self.testing_loader), 'PSNR: %.4f' % (avg_psnr / (batch_num + 1)))
progress_bar(batch_num, len(self.testing_loader),
'PSNR: %.4f' % (avg_psnr / (batch_num + 1)))
print(" Average PSNR: {:.4f} dB".format(avg_psnr / len(self.testing_loader)))
print(" Average PSNR: {:.4f} dB".format(
avg_psnr / len(self.testing_loader)))
def run(self):
self.build_model()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment