"""Train or test a view-synthesis network on a spherical view dataset.

The script runs in one of two modes, selected by the command line:

* train mode (default): builds the network described by ``--config``, trains
  it on ``--dataset`` and writes checkpoints / tensorboard logs next to the
  dataset description file.
* test mode (``--test <net file>``): loads the given checkpoint, renders every
  view in ``--dataset`` and writes the results as images or as a video.
"""
import math
import sys
import os
import argparse
import torch
import torch.optim
import torchvision
import importlib
from tensorboardX import SummaryWriter
from torch import nn

# Make the parent directory importable and declare the package name so the
# relative imports below work when this file is executed as a script.
sys.path.append(os.path.abspath(sys.path[0] + '/../'))
__package__ = "deeplightfield"

parser = argparse.ArgumentParser()
parser.add_argument('--device', type=int, default=3,
                    help='Which CUDA device to use.')
parser.add_argument('--config', type=str,
                    help='Net config files')
parser.add_argument('--dataset', type=str, required=True,
                    help='Dataset description file')
parser.add_argument('--test', type=str,
                    help='Test net file')
parser.add_argument('--test-samples', type=int,
                    help='Samples used for test')
parser.add_argument('--output-gt', action='store_true',
                    help='Output ground truth images if exist')
parser.add_argument('--output-alongside', action='store_true',
                    help='Output generated image alongside ground truth image')
parser.add_argument('--output-video', action='store_true',
                    help='Output test results as video')
opt = parser.parse_args()

# Select device
torch.cuda.set_device(opt.device)
print("Set CUDA:%d as current device." % torch.cuda.current_device())

# NOTE(review): the project imports intentionally stay below the device
# selection, as in the original file — presumably some of them query the
# current CUDA device at import time; confirm before moving them.
from .my import netio
from .my import util
from .my import device
from .my.simple_perf import SimplePerf
from .data.spherical_view_syn import *
from .msl_net import MslNet
from .spher_net import SpherNet
from .my import loss


class Config(object):
    """Network / sampling / loss configuration with overridable defaults."""

    def __init__(self):
        self.name = 'default'

        self.GRAY = False

        # Net parameters
        self.NET_TYPE = 'msl'
        self.N_ENCODE_DIM = 10
        self.FC_PARAMS = {
            'nf': 256,
            'n_layers': 8,
            'skips': [4]
        }
        self.SAMPLE_PARAMS = {
            'depth_range': (1, 50),
            'n_samples': 32,
            'perturb_sample': True
        }
        self.LOSS = 'mse'

    def load(self, path):
        """Apply the config module found at ``path`` to this instance.

        ``path`` is a file path relative to the ``deeplightfield`` package
        (e.g. ``configs/foo.py``); its extension is dropped and ``/`` is
        turned into ``.`` to form the module name. The module must expose
        ``update_config(config)``.
        """
        module_name = os.path.splitext(path)[0].replace('/', '.')
        config_module = importlib.import_module(
            'deeplightfield.' + module_name)
        # Update *this* instance, not whichever object the module-level
        # name ``config`` happens to refer to.
        config_module.update_config(self)
        self.name = module_name.split('.')[-1]

    def load_by_name(self, name):
        """Apply the config module ``deeplightfield.configs.<name>``."""
        config_module = importlib.import_module(
            'deeplightfield.configs.' + name)
        config_module.update_config(self)
        self.name = name

    def print(self):
        """Print a human-readable summary of the configuration."""
        print('==== Config %s ====' % self.name)
        print('Net type: ', self.NET_TYPE)
        print('Encode dim: ', self.N_ENCODE_DIM)
        print('Full-connected network parameters:', self.FC_PARAMS)
        print('Sample parameters', self.SAMPLE_PARAMS)
        print('Loss', self.LOSS)
        print('==========================')


config = Config()

# Toggles
ROT_ONLY = False
EVAL_TIME_PERFORMANCE = False
# ========
#ROT_ONLY = True
#EVAL_TIME_PERFORMANCE = True

# Train
PATCH_SIZE = 1
BATCH_SIZE = 4096 // (PATCH_SIZE * PATCH_SIZE)
EPOCH_RANGE = range(0, 500)
SAVE_INTERVAL = 20

# Test
TEST_BATCH_SIZE = 1
TEST_CHUNKS = 1


def _dir_with_slash(path):
    """Return the directory part of ``path`` with a trailing '/'.

    A path without a directory component maps to './' rather than '/', so
    that outputs never land in the filesystem root by accident.
    """
    return (os.path.dirname(path) or '.') + '/'


# Paths
data_desc_path = opt.dataset
data_desc_name = os.path.split(data_desc_path)[1]
if opt.test:
    test_net_path = opt.test
    test_net_name = os.path.splitext(os.path.basename(test_net_path))[0]
    run_dir = _dir_with_slash(test_net_path)
    # abspath strips the trailing slash and resolves './' to the real
    # directory name, from which the config name is recovered.
    run_id = os.path.basename(os.path.abspath(run_dir))
    config_name = run_id.split('_b')[0]
    output_dir = run_dir + 'output/%s/%s/' % (test_net_name, data_desc_name)
    config.load_by_name(config_name)
    train_mode = False
    if opt.test_samples:
        config.SAMPLE_PARAMS['n_samples'] = opt.test_samples
        output_dir = run_dir + 'output/%s/%s_s%d/' % \
            (test_net_name, data_desc_name, opt.test_samples)
else:
    if opt.config:
        config.load(opt.config)
    data_dir = _dir_with_slash(data_desc_path)
    run_id = '%s_b%d[%d]' % (config.name, BATCH_SIZE, PATCH_SIZE)
    run_dir = data_dir + run_id + '/'
    log_dir = run_dir + 'log/'
    output_dir = None
    train_mode = True

config.print()
print("dataset: ", data_desc_path)
print("train_mode: ", train_mode)
print("run_dir: ", run_dir)
if not train_mode:
    print("output_dir", output_dir)

# Sample perturbation is only ever enabled while training.
config.SAMPLE_PARAMS['perturb_sample'] = \
    config.SAMPLE_PARAMS['perturb_sample'] and train_mode


def _sampler_params(spherical):
    """Return a copy of the configured sampler parameters with the
    'spherical' flag set to ``spherical``."""
    return {**config.SAMPLE_PARAMS, 'spherical': spherical}


# Factories for the supported network types, keyed by ``config.NET_TYPE``.
NETS = {
    'msl': lambda: MslNet(
        fc_params=config.FC_PARAMS,
        sampler_params=_sampler_params(True),
        gray=config.GRAY,
        encode_to_dim=config.N_ENCODE_DIM),
    'nerf': lambda: MslNet(
        fc_params=config.FC_PARAMS,
        sampler_params=_sampler_params(False),
        gray=config.GRAY,
        encode_to_dim=config.N_ENCODE_DIM),
    'spher': lambda: SpherNet(
        fc_params=config.FC_PARAMS,
        gray=config.GRAY,
        translation=not ROT_ONLY,
        encode_to_dim=config.N_ENCODE_DIM)
}

# Factories for the supported losses, keyed by ``config.LOSS``.
LOSSES = {
    'mse': lambda: nn.MSELoss(),
    'mse_grad': lambda: loss.CombinedLoss(
        [nn.MSELoss(), loss.GradLoss()], [1.0, 0.5])
}

# Initialize model
model = NETS[config.NET_TYPE]().to(device.GetDevice())


def train_loop(data_loader, optimizer, loss, perf, writer, epoch, iters):
    """Run one epoch over ``data_loader`` and return the updated iteration count.

    Args:
        data_loader: iterable yielding ``(_, gt, rays_o, rays_d)`` batches.
        optimizer: optimizer stepping the global ``model``'s parameters.
        loss: callable ``loss(out, gt)`` returning a scalar tensor.
        perf: performance recorder exposing ``Checkpoint(name)``.
        writer: tensorboard writer receiving the loss (and image grids).
        epoch: current epoch index, used only for console output.
        iters: global iteration counter at the start of this epoch.

    Returns:
        The global iteration counter after this epoch.
    """
    sub_iters = 0
    iters_in_epoch = len(data_loader)
    dev = device.GetDevice()
    for _, gt, rays_o, rays_d in data_loader:
        gt = gt.to(dev)
        rays_o = rays_o.to(dev)
        rays_d = rays_d.to(dev)
        perf.Checkpoint("Load")

        out = model(rays_o, rays_d)
        perf.Checkpoint("Forward")

        optimizer.zero_grad()
        loss_value = loss(out, gt)
        perf.Checkpoint("Compute loss")

        loss_value.backward()
        perf.Checkpoint("Backward")

        optimizer.step()
        perf.Checkpoint("Update")

        loss_scalar = loss_value.item()
        print("Epoch: %d, Iter: %d(%d/%d), Loss: %f" %
              (epoch, iters, sub_iters, iters_in_epoch, loss_scalar))

        # Write tensorboard logs.
        writer.add_scalar("loss", loss_scalar, iters)
        # Image grids are only logged when the batch is 4-dimensional.
        if gt.dim() == 4 and iters % 100 == 0:
            output_vs_gt = torch.cat([out[0:4], gt[0:4]], 0).detach()
            writer.add_image("Output_vs_gt", torchvision.utils.make_grid(
                output_vs_gt, nrow=4).cpu().numpy(), iters)

        iters += 1
        sub_iters += 1
    return iters


def train():
    """Train the global ``model`` over ``EPOCH_RANGE``, saving a checkpoint
    every ``SAVE_INTERVAL`` epochs into ``run_dir``."""
    # 1. Initialize data loader
    print("Load dataset: " + data_desc_path)
    train_dataset = FastSphericalViewSynDataset(data_desc_path,
                                                gray=config.GRAY)
    train_dataset.set_patch_size((PATCH_SIZE, PATCH_SIZE))
    train_data_loader = FastDataLoader(
        dataset=train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=False,
        pin_memory=True)

    # 2. Initialize components
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
    loss = LOSSES[config.LOSS]().to(device.GetDevice())

    # Resume from the checkpoint of the starting epoch, if any.
    if EPOCH_RANGE.start > 0:
        iters = netio.LoadNet('%smodel-epoch_%d.pth' % (run_dir, EPOCH_RANGE.start),
                              model, solver=optimizer)
    else:
        iters = 0

    # 3. Train
    model.train()

    util.CreateDirIfNeed(run_dir)
    util.CreateDirIfNeed(log_dir)

    perf = SimplePerf(EVAL_TIME_PERFORMANCE, start=True)
    perf_epoch = SimplePerf(True, start=True)
    writer = SummaryWriter(log_dir)

    print("Begin training...")
    try:
        for epoch in EPOCH_RANGE:
            perf_epoch.Checkpoint("Epoch")
            iters = train_loop(train_data_loader, optimizer, loss,
                               perf, writer, epoch, iters)
            # Save checkpoint
            if (epoch + 1) % SAVE_INTERVAL == 0:
                netio.SaveNet('%smodel-epoch_%d.pth' % (run_dir, epoch + 1),
                              model, solver=optimizer, iters=iters)
    finally:
        # Flush pending events and release the log file even if training
        # is interrupted.
        writer.close()
    print("Train finished")


def test():
    """Render every view of the dataset with the checkpoint at
    ``test_net_path`` and write the results into ``output_dir``."""
    torch.autograd.set_grad_enabled(False)

    # 1. Load dataset
    print("Load dataset: " + data_desc_path)
    test_dataset = SphericalViewSynDataset(data_desc_path,
                                           load_images=opt.output_gt,
                                           gray=config.GRAY)
    test_data_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=TEST_BATCH_SIZE,
        shuffle=False,
        drop_last=False,
        pin_memory=True)

    # 2. Load trained model
    netio.LoadNet(test_net_path, model)

    # 3. Run inference
    print("Begin test on train dataset, batch size is %d" % TEST_BATCH_SIZE)
    util.CreateDirIfNeed(output_dir)

    perf = SimplePerf(True, start=True)
    dev = device.GetDevice()
    i = 0
    n = test_dataset.view_rots.size(0)
    chns = 1 if config.GRAY else 3
    height, width = test_dataset.view_res[0], test_dataset.view_res[1]
    out_view_images = torch.empty(n, chns, height, width, device=dev)
    print(out_view_images.size())
    for view_idxs, _, rays_o, rays_d in test_data_loader:
        perf.Checkpoint("%d - Load" % i)
        rays_o = rays_o.to(dev).view(-1, 3)
        rays_d = rays_d.to(dev).view(-1, 3)
        n_rays = rays_o.size(0)
        # Ceiling division keeps the number of chunks <= TEST_CHUNKS and
        # never yields a zero step for range().
        chunk_size = max(1, math.ceil(n_rays / TEST_CHUNKS))
        out_pixels = torch.empty(n_rays, chns, device=dev)
        for offset in range(0, n_rays, chunk_size):
            rays_o_ = rays_o[offset:offset + chunk_size]
            rays_d_ = rays_d[offset:offset + chunk_size]
            out_pixels[offset:offset + chunk_size] = \
                model(rays_o_, rays_d_)
        # The batch dimension is inferred so a short final batch also works.
        out_view_images[view_idxs] = out_pixels.view(
            -1, height, width, chns).permute(0, 3, 1, 2)
        perf.Checkpoint("%d - Infer" % i)
        i += 1

    # 4. Write results
    if opt.output_video:
        util.generate_video(out_view_images, output_dir + 'out.mp4', 24, 3, True)
    else:
        gt_paths = ['%sgt_view_%04d.png' % (output_dir, idx) for idx in range(n)]
        out_paths = ['%sout_view_%04d.png' % (output_dir, idx) for idx in range(n)]
        if test_dataset.load_images:
            if opt.output_alongside:
                # Concatenate ground truth and output along the width axis.
                util.WriteImageTensor(
                    torch.cat([test_dataset.view_images,
                               out_view_images.cpu()], 3),
                    out_paths)
            else:
                util.WriteImageTensor(out_view_images, out_paths)
                util.WriteImageTensor(test_dataset.view_images, gt_paths)
        else:
            util.WriteImageTensor(out_view_images, out_paths)


if __name__ == "__main__":
    if train_mode:
        train()
    else:
        test()