import sys import os sys.path.append(os.path.abspath(sys.path[0] + '/../')) __package__ = "deeplightfield" import argparse import torch import torch.optim import torchvision from tensorboardX import SummaryWriter from torch import nn from .my import netio from .my import util from .my import device from .my.simple_perf import SimplePerf from .data.spherical_view_syn import SphericalViewSynDataset from .msl_net import MslNet from .spher_net import SpherNet parser = argparse.ArgumentParser() parser.add_argument('--device', type=int, default=3, help='Which CUDA device to use.') opt = parser.parse_args() # Select device torch.cuda.set_device(opt.device) print("Set CUDA:%d as current device." % torch.cuda.current_device()) # Toggles GRAY = False ROT_ONLY = False TRAIN_MODE = True EVAL_TIME_PERFORMANCE = False RAY_AS_ITEM = True # ======== GRAY = True #ROT_ONLY = True #TRAIN_MODE = False #EVAL_TIME_PERFORMANCE = True #RAY_AS_ITEM = False # Net parameters DEPTH_RANGE = (1, 10) N_DEPTH_LAYERS = 10 N_ENCODE_DIM = 10 FC_PARAMS = { 'nf': 128, 'n_layers': 8, 'skips': [4] } # Train TRAIN_DATA_DESC_FILE = 'train.json' BATCH_SIZE = 2048 if RAY_AS_ITEM else 4 EPOCH_RANGE = range(0, 500) SAVE_INTERVAL = 20 # Test TEST_NET_NAME = 'model-epoch_500' TEST_DATA_DESC_FILE = 'test_fovea.json' TEST_BATCH_SIZE = 5 # Paths DATA_DIR = sys.path[0] + '/data/sp_view_syn_2020.12.28/' RUN_ID = '%s_ray_b%d_encode%d_fc%dx%d%s' % ('gray' if GRAY else 'rgb', BATCH_SIZE, N_ENCODE_DIM, FC_PARAMS['nf'], FC_PARAMS['n_layers'], '_skip_%d' % FC_PARAMS['skips'][0] if len(FC_PARAMS['skips']) > 0 else '') RUN_DIR = DATA_DIR + RUN_ID + '/' OUTPUT_DIR = RUN_DIR + 'output/' LOG_DIR = RUN_DIR + 'log/' def train(): # 1. Initialize data loader print("Load dataset: " + DATA_DIR + TRAIN_DATA_DESC_FILE) train_dataset = SphericalViewSynDataset(DATA_DIR + TRAIN_DATA_DESC_FILE, gray=GRAY, ray_as_item=RAY_AS_ITEM) train_data_loader = torch.utils.data.DataLoader( dataset=train_dataset, batch_size=BATCH_SIZE, pin_memory=True, shuffle=True, drop_last=False) print('Data loaded. %d iters per epoch.' % len(train_data_loader)) # 2. Initialize components if ROT_ONLY: model = SpherNet(cam_params=train_dataset.cam_params, fc_params=FC_PARAMS, out_res=train_dataset.view_res, gray=GRAY, encode_to_dim=N_ENCODE_DIM).to(device.GetDevice()) else: model = MslNet(cam_params=train_dataset.cam_params, fc_params=FC_PARAMS, sphere_layers=util.GetDepthLayers( DEPTH_RANGE, N_DEPTH_LAYERS), out_res=train_dataset.view_res, gray=GRAY, encode_to_dim=N_ENCODE_DIM).to(device.GetDevice()) optimizer = torch.optim.Adam(model.parameters(), lr=5e-4) loss = nn.MSELoss() if EPOCH_RANGE.start > 0: netio.LoadNet('%smodel-epoch_%d.pth' % (RUN_DIR, EPOCH_RANGE.start), model, solver=optimizer) # 3. Train model.train() epoch = None iters = EPOCH_RANGE.start * len(train_data_loader) util.CreateDirIfNeed(RUN_DIR) util.CreateDirIfNeed(LOG_DIR) perf = SimplePerf(EVAL_TIME_PERFORMANCE, start=True) perf_epoch = SimplePerf(True, start=True) writer = SummaryWriter(LOG_DIR) print("Begin training...") for epoch in EPOCH_RANGE: for _, gt, ray_positions, ray_directions in train_data_loader: gt = gt.to(device.GetDevice()) ray_positions = ray_positions.to(device.GetDevice()) ray_directions = ray_directions.to(device.GetDevice()) perf.Checkpoint("Load") out = model(ray_positions, ray_directions) perf.Checkpoint("Forward") optimizer.zero_grad() loss_value = loss(out, gt) perf.Checkpoint("Compute loss") loss_value.backward() perf.Checkpoint("Backward") optimizer.step() perf.Checkpoint("Update") print("Epoch: ", epoch, ", Iter: ", iters, ", Loss: ", loss_value.item()) # Write tensorboard logs. writer.add_scalar("loss", loss_value, iters) if not RAY_AS_ITEM and iters % 100 == 0: output_vs_gt = torch.cat([out, gt], dim=0) writer.add_image("Output_vs_gt", torchvision.utils.make_grid( output_vs_gt, scale_each=True, normalize=False) .cpu().detach().numpy(), iters) iters += 1 perf_epoch.Checkpoint("Epoch") # Save checkpoint if ((epoch + 1) % SAVE_INTERVAL == 0): netio.SaveNet('%smodel-epoch_%d.pth' % (RUN_DIR, epoch + 1), model, solver=optimizer) print("Train finished") def test(net_file: str): # 1. Load train dataset print("Load dataset: " + DATA_DIR + TEST_DATA_DESC_FILE) test_dataset = SphericalViewSynDataset(DATA_DIR + TEST_DATA_DESC_FILE, load_images=True, gray=GRAY) test_data_loader = torch.utils.data.DataLoader( dataset=test_dataset, batch_size=TEST_BATCH_SIZE, pin_memory=True, shuffle=False, drop_last=False) # 2. Load trained model if ROT_ONLY: model = SpherNet(cam_params=test_dataset.cam_params, fc_params=FC_PARAMS, out_res=test_dataset.view_res, gray=GRAY, encode_to_dim=N_ENCODE_DIM).to(device.GetDevice()) else: model = MslNet(cam_params=test_dataset.cam_params, sphere_layers=util.GetDepthLayers( DEPTH_RANGE, N_DEPTH_LAYERS), out_res=test_dataset.view_res, gray=GRAY).to(device.GetDevice()) netio.LoadNet(net_file, model) # 3. Test on train dataset print("Begin test on train dataset, batch size is %d" % TEST_BATCH_SIZE) output_dir = '%s%s/%s/' % (OUTPUT_DIR, TEST_NET_NAME, TEST_DATA_DESC_FILE) util.CreateDirIfNeed(output_dir) perf = SimplePerf(True, start=True) i = 0 for view_idxs, view_images, ray_positions, ray_directions in test_data_loader: ray_positions = ray_positions.to(device.GetDevice()) ray_directions = ray_directions.to(device.GetDevice()) perf.Checkpoint("%d - Load" % i) out_view_images = model(ray_positions, ray_directions) perf.Checkpoint("%d - Infer" % i) if test_dataset.load_images: util.WriteImageTensor( view_images, ['%sgt_view_%04d.png' % (output_dir, i) for i in view_idxs]) util.WriteImageTensor( out_view_images, ['%sout_view_%04d.png' % (output_dir, i) for i in view_idxs]) perf.Checkpoint("%d - Write" % i) i += 1 if __name__ == "__main__": if TRAIN_MODE: train() else: test(RUN_DIR + TEST_NET_NAME + '.pth')