import sys
import os
import argparse
import torch
import torch.optim
from tensorboardX import SummaryWriter
from torch import nn

sys.path.append(os.path.abspath(sys.path[0] + '/../'))
__package__ = "deeplightfield"
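# The two lines above add the project root to sys.path and set the package name
# so that the relative imports further down resolve when this file is run
# directly as a script.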

parser = argparse.ArgumentParser()
parser.add_argument('--device', type=int, default=3,
                    help='Which CUDA device to use.')
parser.add_argument('--config', type=str,
                    help='Net config file')
parser.add_argument('--config-id', type=str,
                    help='Net config id')
parser.add_argument('--dataset', type=str, required=True,
                    help='Dataset description file')
parser.add_argument('--cont', type=str,
                    help='Continue training from the given model file')
parser.add_argument('--epochs', type=int,
                    help='Max epochs for train')
parser.add_argument('--test', type=str,
                    help='Test net file')
parser.add_argument('--test-samples', type=int,
                    help='Samples used for test')
parser.add_argument('--res', type=str,
                    help='Resolution, e.g. 256x256')
parser.add_argument('--output-gt', action='store_true',
                    help='Output ground truth images if they exist')
parser.add_argument('--output-alongside', action='store_true',
                    help='Output generated image alongside ground truth image')
parser.add_argument('--output-video', action='store_true',
                    help='Output test results as video')
parser.add_argument('--perf', action='store_true',
                    help='Test performance')
parser.add_argument('--simple-log', action='store_true', help='Simple log')
opt = parser.parse_args()
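# Example invocations (paths and ids below are illustrative placeholders):
#   python run_spherical_view_syn.py --dataset <dataset_desc_file> --config-id <config_id>
#   python run_spherical_view_syn.py --dataset <dataset_desc_file> --test <run_dir>/model-epoch_300.pth
#   python run_spherical_view_syn.py --dataset <dataset_desc_file> --test <run_dir>/model-epoch_300.pth --perf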
if opt.res:
    opt.res = tuple(int(s) for s in opt.res.split('x'))
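    # opt.res is now a tuple of ints, e.g. '256x256' -> (256, 256); it is passed
    # unchanged to the dataset's 'res' argument.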

# Select device
torch.cuda.set_device(opt.device)
print("Set CUDA:%d as current device." % torch.cuda.current_device())

from .my import netio
from .my import util
from .my import device
from .my import loss
from .my.progress_bar import progress_bar
from .my.simple_perf import SimplePerf
from .data.spherical_view_syn import *
from .data.loader import FastDataLoader
from .configs.spherical_view_syn import SphericalViewSynConfig


config = SphericalViewSynConfig()

# Toggles
ROT_ONLY = False
EVAL_TIME_PERFORMANCE = False
# ========
#ROT_ONLY = True
#EVAL_TIME_PERFORMANCE = True

# Train
BATCH_SIZE = 4096
EPOCH_RANGE = range(0, opt.epochs if opt.epochs else 300)
SAVE_INTERVAL = 10
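# EPOCH_RANGE.start is moved forward below when training is resumed via --cont.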

# Test
TEST_BATCH_SIZE = 1
TEST_MAX_RAYS = 32768

# Paths
data_desc_path = opt.dataset
data_desc_name = os.path.splitext(os.path.basename(data_desc_path))[0]
if opt.test:
    test_net_path = opt.test
    test_net_name = os.path.splitext(os.path.basename(test_net_path))[0]
    run_dir = os.path.dirname(test_net_path) + '/'
    run_id = os.path.basename(run_dir[:-1])
    output_dir = run_dir + 'output/%s/%s%s/' % (test_net_name, data_desc_name,
                                                '_%dx%d' % (opt.res[0], opt.res[1]) if opt.res else '')
    config.from_id(run_id)
    train_mode = False
    if opt.test_samples:
        config.SAMPLE_PARAMS['n_samples'] = opt.test_samples
        output_dir = run_dir + 'output/%s/%s_s%d/' % \
            (test_net_name, data_desc_name, opt.test_samples)
else:
    data_dir = os.path.dirname(data_desc_path) + '/'
    if opt.cont:
        train_net_name = os.path.splitext(os.path.basename(opt.cont))[0]
        EPOCH_RANGE = range(int(train_net_name[12:]), EPOCH_RANGE.stop)
        run_dir = os.path.dirname(opt.cont) + '/'
        run_id = os.path.basename(run_dir[:-1])
        config.from_id(run_id)
    else:
        if opt.config:
            config.load(opt.config)
        if opt.config_id:
            config.from_id(opt.config_id)
        run_id = config.to_id()
        run_dir = data_dir + run_id + '/'
    log_dir = run_dir + 'log/'
    output_dir = None
    train_mode = True

config.print()
print("dataset: ", data_desc_path)
print("train_mode: ", train_mode)
print("run_dir: ", run_dir)
if not train_mode:
    print("output_dir", output_dir)

config.SAMPLE_PARAMS['perturb_sample'] = \
    config.SAMPLE_PARAMS['perturb_sample'] and train_mode
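# Sample perturbation is only kept on in train mode so that test output stays
# deterministic.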

LOSSES = {
    'mse': lambda: nn.MSELoss(),
    'mse_grad': lambda: loss.CombinedLoss(
        [nn.MSELoss(), loss.GradLoss()], [1.0, 0.5])
}
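# Note: train() currently hard-codes its loss (see 'loss = 0' below) and uses
# loss_mse / loss_grad directly, so this mapping is kept for reference only.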

# Initialize model
model = config.create_net().to(device.GetDevice())
loss_mse = nn.MSELoss().to(device.GetDevice())
loss_grad = loss.GradLoss().to(device.GetDevice())


def train_loop(data_loader, optimizer, loss, perf, writer, epoch, iters):
    sub_iters = 0
    iters_in_epoch = len(data_loader)
    loss_min = 1e5
    loss_max = 0
    loss_avg = 0
    perf1 = SimplePerf(opt.simple_log, True)
    for _, gt, rays_o, rays_d in data_loader:
        patch = (len(gt.size()) == 4)
        gt = gt.to(device.GetDevice())
        rays_o = rays_o.to(device.GetDevice())
        rays_d = rays_d.to(device.GetDevice())
        perf.Checkpoint("Load")

        out = model(rays_o, rays_d)
        perf.Checkpoint("Forward")

        optimizer.zero_grad()
        if config.COLOR == color_mode.YCbCr:
            loss_mse_value = 0.3 * loss_mse(out[..., 0:2], gt[..., 0:2]) + \
                0.7 * loss_mse(out[..., 2], gt[..., 2])
        else:
            loss_mse_value = loss_mse(out, gt)
        loss_grad_value = loss_grad(out, gt) if patch else None
        loss_value = loss_mse_value  # + 0.5 * loss_grad_value if patch \
        # else loss_mse_value
        perf.Checkpoint("Compute loss")

        loss_value.backward()
        perf.Checkpoint("Backward")

        optimizer.step()
        perf.Checkpoint("Update")

        loss_value = loss_value.item()
        loss_min = min(loss_min, loss_value)
        loss_max = max(loss_max, loss_value)
        loss_avg = (loss_avg * sub_iters + loss_value) / (sub_iters + 1)
        if not opt.simple_log:
            progress_bar(sub_iters, iters_in_epoch,
                         "Loss: %.2e (%.2e/%.2e/%.2e)" % (loss_value, loss_min, loss_avg, loss_max),
                         "Epoch {:<3d}".format(epoch))

        # Write tensorboard logs.
        writer.add_scalar("loss mse", loss_value, iters)
        # if patch and iters % 100 == 0:
        #    output_vs_gt = torch.cat([out[0:4], gt[0:4]], 0).detach()
        #    writer.add_image("Output_vs_gt", torchvision.utils.make_grid(
        #        output_vs_gt, nrow=4).cpu().numpy(), iters)

        iters += 1
        sub_iters += 1
    if opt.simple_log:
        perf1.Checkpoint('Epoch %d (%.2e/%.2e/%.2e)' % (epoch, loss_min, loss_avg, loss_max), True)
    return iters


def train():
    # 1. Initialize data loader
    print("Load dataset: " + data_desc_path)
    train_dataset = SphericalViewSynDataset(
        data_desc_path, color=config.COLOR, res=opt.res)
    train_dataset.set_patch_size(1)
    train_data_loader = FastDataLoader(
        dataset=train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=False,
        pin_memory=True)

    # 2. Initialize components
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=config.OPT_DECAY)
    loss = 0  # LOSSES[config.LOSS]().to(device.GetDevice())

    if EPOCH_RANGE.start > 0:
        iters = netio.LoadNet('%smodel-epoch_%d.pth' % (run_dir, EPOCH_RANGE.start),
                              model, solver=optimizer)
    else:
        if config.NORMALIZE:
            for _, _, rays_o, rays_d in train_data_loader:
                model.update_normalize_range(rays_o, rays_d)
            print('Depth/diopter range: ', model.depth_range)
            print('Angle range: ', model.angle_range / 3.14159 * 180)
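            # The pass above lets the model record the depth and angle ranges it
            # uses to normalize its inputs before training starts.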
        iters = 0
    epoch = None

    # 3. Train
    model.train()

    util.CreateDirIfNeed(run_dir)
    util.CreateDirIfNeed(log_dir)

    perf = SimplePerf(EVAL_TIME_PERFORMANCE, start=True)
    writer = SummaryWriter(log_dir)

    print("Begin training...")
    for epoch in EPOCH_RANGE:
        iters = train_loop(train_data_loader, optimizer, loss,
                           perf, writer, epoch, iters)
        # Save checkpoint
        if ((epoch + 1) % SAVE_INTERVAL == 0):
            netio.SaveNet('%smodel-epoch_%d.pth' % (run_dir, epoch + 1), model,
                          solver=optimizer, iters=iters)
    print("Train finished")
    netio.SaveNet('%smodel-epoch_%d.pth' % (run_dir, epoch + 1), model,
                  solver=optimizer, iters=iters)


def perf():
    with torch.no_grad():
        # 1. Load dataset
        print("Load dataset: " + data_desc_path)
        test_dataset = SphericalViewSynDataset(data_desc_path,
                                               load_images=True,
                                               color=config.COLOR, res=opt.res)
        test_data_loader = FastDataLoader(
            dataset=test_dataset,
            batch_size=1,
            shuffle=False,
            drop_last=False,
            pin_memory=True)

        # 2. Load trained model
        netio.LoadNet(test_net_path, model)

        # 3. Test on dataset
        print("Begin perf, batch size is %d" % TEST_BATCH_SIZE)

        perf = SimplePerf(True, start=True)
        loss = nn.MSELoss()
        i = 0
        n = test_dataset.n_views
        chns = 1 if config.COLOR == color_mode.GRAY else 3
        out_view_images = torch.empty(n, chns, test_dataset.view_res[0],
                                      test_dataset.view_res[1],
                                      device=device.GetDevice())
        perf_times = torch.empty(n)
        perf_errors = torch.empty(n)
        for view_idxs, gt, rays_o, rays_d in test_data_loader:
            perf.Checkpoint("%d - Load" % i)
            rays_o = rays_o.to(device.GetDevice()).view(-1, 3)
            rays_d = rays_d.to(device.GetDevice()).view(-1, 3)
            n_rays = rays_o.size(0)
            chunk_size = min(n_rays, TEST_MAX_RAYS)
            out_pixels = torch.empty(n_rays, chns, device=device.GetDevice())
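            # Evaluate rays in chunks of at most TEST_MAX_RAYS to bound peak GPU
            # memory for a full view.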
            for offset in range(0, n_rays, chunk_size):
                idx = slice(offset, offset + chunk_size)
                out_pixels[idx] = model(rays_o[idx], rays_d[idx])
            if config.COLOR == color_mode.YCbCr:
                out_pixels = util.ycbcr2rgb(out_pixels)
            out_view_images[view_idxs] = out_pixels.view(
                TEST_BATCH_SIZE, test_dataset.view_res[0],
                test_dataset.view_res[1], -1).permute(0, 3, 1, 2)
            perf_times[view_idxs] = perf.Checkpoint("%d - Infer" % i)
            if config.COLOR == color_mode.YCbCr:
                gt = util.ycbcr2rgb(gt)
            error = loss(out_view_images[view_idxs], gt).item()
            print("%d - Error: %f" % (i, error))
            perf_errors[view_idxs] = error
            i += 1

        # 4. Save results
        perf_mean_time = torch.mean(perf_times).item()
        perf_mean_error = torch.mean(perf_errors).item()
        with open(run_dir + 'perf_%s_%s_%.1fms_%.2e.txt' % (test_net_name, data_desc_name, perf_mean_time, perf_mean_error), 'w') as fp:
            fp.write('View, Time, Error\n')
            fp.writelines(['%d, %f, %f\n' % (
                i, perf_times[i].item(), perf_errors[i].item()) for i in range(n)])


def test():
    with torch.no_grad():
        # 1. Load dataset
        print("Load dataset: " + data_desc_path)
        test_dataset = SphericalViewSynDataset(data_desc_path,
                                               load_images=opt.output_gt or opt.output_alongside,
                                               color=config.COLOR,
                                               res=opt.res)
        test_data_loader = FastDataLoader(
            dataset=test_dataset,
            batch_size=1,
            shuffle=False,
            drop_last=False,
            pin_memory=True)

        # 2. Load trained model
        netio.LoadNet(test_net_path, model)

        # 3. Test on dataset
        print("Begin test, batch size is %d" % TEST_BATCH_SIZE)
        util.CreateDirIfNeed(output_dir)

        perf = SimplePerf(True, start=True)
        i = 0
        n = test_dataset.n_views
        chns = 1 if config.COLOR == color_mode.GRAY else 3
        out_view_images = torch.empty(n, chns, test_dataset.view_res[0],
                                      test_dataset.view_res[1],
                                      device=device.GetDevice())
        for view_idxs, _, rays_o, rays_d in test_data_loader:
            perf.Checkpoint("%d - Load" % i)
            rays_o = rays_o.to(device.GetDevice()).view(-1, 3)
            rays_d = rays_d.to(device.GetDevice()).view(-1, 3)
            n_rays = rays_o.size(0)
            chunk_size = min(n_rays, TEST_MAX_RAYS)
            out_pixels = torch.empty(n_rays, chns, device=device.GetDevice())
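            # Same chunked evaluation as in perf(): at most TEST_MAX_RAYS rays
            # per forward pass.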
            for offset in range(0, n_rays, chunk_size):
                idx = slice(offset, offset + chunk_size)
                out_pixels[idx] = model(rays_o[idx], rays_d[idx])
            if config.COLOR == color_mode.YCbCr:
                out_pixels = util.ycbcr2rgb(out_pixels)
            out_view_images[view_idxs] = out_pixels.view(
                TEST_BATCH_SIZE, test_dataset.view_res[0],
                test_dataset.view_res[1], -1).permute(0, 3, 1, 2)
            perf.Checkpoint("%d - Infer" % i)
            i += 1

        # 4. Save results
        if opt.output_video:
            util.generate_video(out_view_images, output_dir +
                                'out.mp4', 24, 3, True)
        else:
            gt_paths = [
                '%sgt_view_%04d.png' % (output_dir, i) for i in range(n)
            ]
            out_paths = [
                '%sout_view_%04d.png' % (output_dir, i) for i in range(n)
            ]
            if test_dataset.load_images:
                if opt.output_alongside:
                    util.WriteImageTensor(
                        torch.cat([
                            test_dataset.view_images,
                            out_view_images
                        ], 3), out_paths)
                else:
                    util.WriteImageTensor(out_view_images, out_paths)
                    util.WriteImageTensor(test_dataset.view_images, gt_paths)
            else:
                util.WriteImageTensor(out_view_images, out_paths)


if __name__ == "__main__":
    if train_mode:
        train()
    elif opt.perf:
        perf()
    else:
        test()