run_upsampling.py 4.43 KB
Newer Older
BobYeah's avatar
sync    
BobYeah committed
1
2
3
4
5
6
from __future__ import print_function

import argparse
import os
import sys
import torch
Nianchen Deng's avatar
sync    
Nianchen Deng committed
7
import torch.nn.functional as nn_f
BobYeah's avatar
sync    
BobYeah committed
8
9
10
11
from torch.utils.data import DataLoader
from tensorboardX.writer import SummaryWriter

sys.path.append(os.path.abspath(sys.path[0] + '/../'))
Nianchen Deng's avatar
sync    
Nianchen Deng committed
12
__package__ = "deep_view_syn"
BobYeah's avatar
sync    
BobYeah committed
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34

# ===========================================================
# Training settings
# ===========================================================
# Command-line interface. Supplying --test (a model path) switches the script
# from training to inference; see main().
parser = argparse.ArgumentParser(description='PyTorch Super Res Example')
# hyper-parameters
parser.add_argument('--device', type=int, default=3,
                    help='Which CUDA device to use.')
parser.add_argument('--batchSize', type=int, default=1,
                    help='training batch size')
parser.add_argument('--testBatchSize', type=int,
                    default=1, help='testing batch size')
parser.add_argument('--nEpochs', type=int, default=20,
                    help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.01,
                    help='Learning Rate. Default=0.01')
parser.add_argument('--seed', type=int, default=123,
                    help='random seed to use. Default=123')
# The script chdirs into this directory before training/testing, so every
# other relative path (including --test and --testOutPatt) resolves against it.
parser.add_argument('--dataset', type=str, required=True,
                    help='dataset directory')
parser.add_argument('--test', type=str, help='path of model to test')
# Formatted with the dataset item index (`pattern % idx`) for each output image.
parser.add_argument('--testOutPatt', type=str, help='test output path pattern')
# Raw string; converted to a color_mode value once the project modules load.
parser.add_argument('--color', type=str, default='rgb',
                    help='color')

# model configuration
parser.add_argument('--upscale_factor', '-uf', type=int,
                    default=2, help="super resolution upscale factor")
#parser.add_argument('--model', '-m', type=str, default='srgan', help='choose which model is going to use')

args = parser.parse_args()

# Test mode writes every result through `args.testOutPatt % idx`, so the
# pattern is mandatory there. Reject the combination now, before CUDA setup
# and the chdir below, instead of failing deep inside test().
if args.test and not args.testOutPatt:
    parser.error('--testOutPatt is required when --test is given')

# Select device: make the requested GPU current for all later CUDA work,
# then echo the device torch actually reports as current.
torch.cuda.set_device(args.device)
print("Set CUDA:{:d} as current device.".format(torch.cuda.current_device()))

from .my import util
from .my import netio
from .my import device
Nianchen Deng's avatar
sync    
Nianchen Deng committed
52
53
54
from .my import color_mode
#from .upsampling.SubPixelCNN.solver import SubPixelTrainer as Solver
from .upsampling.SRCNN.solver import SRCNNTrainer as Solver
BobYeah's avatar
sync    
BobYeah committed
55
56
57
58
59
60
from .data.upsampling import UpsamplingDataset
from .data.loader import FastDataLoader

# Output directory for TensorBoard logs and saved models, relative to the
# dataset directory entered just below.
run_dir = 'run/'
# All relative paths used later (dataset image patterns, run_dir, model and
# output paths from the command line) are resolved against the dataset dir.
os.chdir(args.dataset)
print('Change working directory to {}'.format(os.getcwd()))
# Replace the raw --color string with the parsed color_mode value.
args.color = color_mode.from_str(args.color)
BobYeah's avatar
sync    
BobYeah committed
62
63
64
65


def train():
    """Train the upsampling network on the current (dataset) directory.

    Inputs ``input/out_view_%04d.png`` are paired with targets
    ``gt/view_%04d.png``. The solver trains for ``args.nEpochs`` epochs and
    the final weights are saved as ``run_dir + 'model-epoch_<nEpochs>.pth'``.
    """
    util.CreateDirIfNeed(run_dir)
    dataset = UpsamplingDataset('.', 'input/out_view_%04d.png',
                                'gt/view_%04d.png', color=args.color)
    loader = FastDataLoader(dataset=dataset,
                            batch_size=args.batchSize,
                            shuffle=True,
                            drop_last=False)
    # The same loader fills both loader slots of the solver.
    trainer = Solver(args, loader, loader, SummaryWriter(run_dir))
    model_channels = 3 if args.color == color_mode.RGB else 1
    trainer.build_model(model_channels)

    # In YCbCr mode only channel 2 is trained on (presumably Y: test() treats
    # the last channel as Y and the first two as CbCr); otherwise all channels.
    if args.color == color_mode.YCbCr:
        train_channels = slice(2, 3)
    else:
        train_channels = None

    # trainer.train returns the running iteration count, threaded through epochs.
    step = 0
    for epoch in range(1, args.nEpochs + 1):
        print("\n===> Epoch {} starts:".format(epoch))
        step = trainer.train(epoch, step, channels=train_channels)
    netio.SaveNet(run_dir + 'model-epoch_%d.pth' % args.nEpochs,
                  trainer.model)
BobYeah's avatar
sync    
BobYeah committed
82
83
84
85


def test():
    """Run a trained model over every input image and write the results.

    Loads weights from ``args.test`` and writes one image per dataset item to
    ``args.testOutPatt % idx``. In YCbCr mode the network is applied to the
    last channel only (Y); the first two channels (Cb, Cr) are upsampled with
    nearest-neighbour interpolation by ``args.upscale_factor`` and the stacked
    result is converted back to RGB before writing.

    Raises:
        ValueError: if ``args.testOutPatt`` is not set.
    """
    if not args.testOutPatt:
        raise ValueError('--testOutPatt is required in test mode')
    out_dir = os.path.dirname(args.testOutPatt)
    # A bare file pattern has no directory component; the working directory
    # already exists, so only create a directory when one is named.
    if out_dir:
        util.CreateDirIfNeed(out_dir)

    test_set = UpsamplingDataset(
        '.', 'input/out_view_%04d.png', None, color=args.color)
    test_data_loader = FastDataLoader(dataset=test_set,
                                      batch_size=args.testBatchSize,
                                      shuffle=False,
                                      drop_last=False)
    trainer = Solver(args, test_data_loader, test_data_loader,
                     SummaryWriter(run_dir))
    trainer.build_model(3 if args.color == color_mode.RGB else 1)
    netio.LoadNet(args.test, trainer.model)

    # Inference only: switch off training-mode layer behaviour and skip
    # building autograd graphs for every forward pass.
    trainer.model.eval()
    with torch.no_grad():
        # NOTE(review): `testOutPatt % idx` assumes idx formats as a single
        # integer, i.e. testBatchSize == 1 — confirm for larger batches.
        for idx, lr_input, _ in test_data_loader:
            if args.color == color_mode.YCbCr:
                output_y = trainer.model(lr_input[:, -1:])
                # Scale chroma by the configured factor (presumably the one
                # the Solver builds the model for, since it receives args) so
                # both planes match spatially for torch.cat; a hard-coded 2
                # breaks any other --upscale_factor.
                output_cbcr = nn_f.interpolate(
                    lr_input[:, 0:2], scale_factor=args.upscale_factor)
                output = util.ycbcr2rgb(
                    torch.cat([output_cbcr, output_y], -3))
            else:
                output = trainer.model(lr_input)
            util.WriteImageTensor(output, args.testOutPatt % idx)


def main():
    """Entry point: run inference when --test names a model, else train."""
    action = test if args.test else train
    action()


if __name__ == '__main__':
    main()