"""Train or evaluate the TransUnet light-field view-synthesis model.

Run as a script; the ``__main__`` guard at the bottom selects the action.
"""
import sys
# NOTE(review): machine-specific absolute path appended to the import search
# path — confirm/adjust for the environment this runs in.
sys.path.append('/e/dengnc')
# Sets this script's package context (affects relative imports in this file
# only). NOTE(review): the imports below are absolute — confirm this is still
# needed.
__package__ = "deep_view_syn"

import os
import torch
import torch.optim
import torchvision
from tensorboardX import SummaryWriter

from utils.loss import PerceptionReconstructionLoss
from utils import netio
from utils import misc
from utils import device
from utils import img
from utils.perf import Perf
from data.lf_syn import LightFieldSynDataset
from nets.trans_unet import TransUnet


# Pin this process to CUDA device 2 before any tensors are allocated.
torch.cuda.set_device(2)
print("Set CUDA:%d as current device." % torch.cuda.current_device())

# Resolve the data directory from the absolute script location. On
# Python < 3.9 a script started as `python run_lf_syn.py` has a bare
# ``__file__``; its dirname is '' and the path would collapse to the
# filesystem-root path '/data/...'.
DATA_DIR = os.path.dirname(os.path.abspath(__file__)) + '/data/lf_syn_2020.12.23'
TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'
OUTPUT_DIR = DATA_DIR + '/output_bat2'  # test() writes rendered views here
RUN_DIR = DATA_DIR + '/run_bat2'  # checkpoints and TensorBoard logs
BATCH_SIZE = 8
TEST_BATCH_SIZE = 10
NUM_EPOCH = 1000
MODE = "Silence"  # "Perf" enables Perf timing checkpoints in train()
# Epoch to resume from; when > 0, train() loads
# RUN_DIR/model-epoch_<EPOCH_BEGIN>.pth first.
EPOCH_BEGIN = 600


def train():
    """Train the TransUnet light-field synthesis model on the training set.

    When ``EPOCH_BEGIN > 0`` the model and optimizer are first restored from
    ``RUN_DIR/model-epoch_<EPOCH_BEGIN>.pth``. Training then runs epochs
    ``EPOCH_BEGIN .. NUM_EPOCH - 1``.

    Logging and checkpoints:
      * Every batch: the loss is printed and logged as the TensorBoard
        scalar "loss" in ``RUN_DIR``.
      * Once per epoch (on its last batch): an output-vs-ground-truth image
        grid is logged.
      * Every 50 epochs: a checkpoint is written to ``RUN_DIR``.
    """
    # 1. Initialize data loader
    print("Load dataset: " + TRAIN_DATA_DESC_FILE)
    train_dataset = LightFieldSynDataset(TRAIN_DATA_DESC_FILE)
    train_data_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=BATCH_SIZE,
        pin_memory=True,
        shuffle=True,
        drop_last=False)
    num_batches = len(train_data_loader)
    print("Batches per epoch: %d" % num_batches)

    # 2. Initialize components
    model = TransUnet(cam_params=train_dataset.cam_params,
                      view_images=train_dataset.sparse_view_images,
                      view_depths=train_dataset.sparse_view_depths,
                      view_positions=train_dataset.sparse_view_positions,
                      diopter_of_layers=train_dataset.diopter_of_layers).to(device.default())
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss = PerceptionReconstructionLoss()

    if EPOCH_BEGIN > 0:
        netio.load('%s/model-epoch_%d.pth' % (RUN_DIR, EPOCH_BEGIN), model,
                   solver=optimizer)

    # 3. Train
    model.train()
    # `iters` counts training samples seen: it advances by BATCH_SIZE per
    # batch and is the x-axis of the TensorBoard logs.
    iters = EPOCH_BEGIN * num_batches * BATCH_SIZE

    misc.create_dir(RUN_DIR)

    perf = Perf(enable=(MODE == "Perf"), start=True)
    writer = SummaryWriter(RUN_DIR)
    try:
        print("Begin training...")
        for epoch in range(EPOCH_BEGIN, NUM_EPOCH):
            for batch_idx, batch in enumerate(train_data_loader):
                _, view_images, _, view_positions = batch
                view_images = view_images.to(device.default())
                perf.checkpoint("Load")

                out_view_images = model(view_positions)
                perf.checkpoint("Forward")

                optimizer.zero_grad()
                loss_value = loss(out_view_images, view_images)
                perf.checkpoint("Compute loss")

                loss_value.backward()
                perf.checkpoint("Backward")

                optimizer.step()
                perf.checkpoint("Update")

                loss_scalar = loss_value.item()
                print("Epoch: ", epoch, ", Iter: ", iters,
                      ", Loss: ", loss_scalar)

                iters = iters + BATCH_SIZE

                # Write tensorboard logs.
                writer.add_scalar("loss", loss_scalar, iters)
                # One image grid per epoch, taken from its last batch.
                if batch_idx == num_batches - 1:
                    output_vs_gt = torch.cat(
                        [out_view_images.detach(), view_images], dim=0)
                    writer.add_image("Output_vs_gt", torchvision.utils.make_grid(
                        output_vs_gt, scale_each=True, normalize=False)
                        .cpu().numpy(), iters)

            # Save checkpoint
            if (epoch + 1) % 50 == 0:
                # NOTE(review): load() above restores an optimizer via
                # solver=, but save() is not given one — confirm netio
                # persists optimizer state.
                netio.save('%s/model-epoch_%d.pth' % (RUN_DIR, epoch + 1),
                           model, iters)
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        writer.close()

    print("Train finished")


def test(net_file: str):
    """Render every training view with a trained model and save the images.

    Args:
        net_file: checkpoint path passed to ``netio.load``.

    Output files:
      * ``OUTPUT_DIR/gt_viewNN.png``: the ground-truth image of each sample.
      * ``OUTPUT_DIR/out_viewNN.png``: the model output for each sample.

    ``NN`` is the view index yielded by the dataset.
    """
    # 1. Load train dataset
    print("Load dataset: " + TRAIN_DATA_DESC_FILE)
    train_dataset = LightFieldSynDataset(TRAIN_DATA_DESC_FILE)
    train_data_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=TEST_BATCH_SIZE,
        pin_memory=True,
        shuffle=False,
        drop_last=False)

    # 2. Load trained model
    model = TransUnet(cam_params=train_dataset.cam_params,
                      view_images=train_dataset.sparse_view_images,
                      view_depths=train_dataset.sparse_view_depths,
                      view_positions=train_dataset.sparse_view_positions,
                      diopter_of_layers=train_dataset.diopter_of_layers).to(device.default())
    netio.load(net_file, model)
    # Inference only: switch layers to evaluation behaviour.
    model.eval()

    # 3. Test on train dataset
    print("Begin test on train dataset...")
    misc.create_dir(OUTPUT_DIR)
    # No gradients are needed, so skip building the autograd graph.
    with torch.no_grad():
        for view_idxs, view_images, _, view_positions in train_data_loader:
            out_view_images = model(view_positions)
            # A bare generator expression is only legal as a call's sole
            # argument, so the path lists are built explicitly.
            gt_paths = ['%s/gt_view%02d.png' % (OUTPUT_DIR, int(i))
                        for i in view_idxs]
            out_paths = ['%s/out_view%02d.png' % (OUTPUT_DIR, int(i))
                         for i in view_idxs]
            img.save(view_images, gt_paths)
            img.save(out_view_images, out_paths)


if __name__ == "__main__":
    # Default action: evaluate the epoch-1000 checkpoint. Call train() here
    # instead to (re)train the model.
    final_checkpoint = RUN_DIR + '/model-epoch_1000.pth'
    test(final_checkpoint)