run_upsampling.py 4.32 KB
Newer Older
BobYeah's avatar
sync    
BobYeah committed
1
2
3
4
5
6
from __future__ import print_function

import argparse
import os
import sys
import torch
Nianchen Deng's avatar
sync    
Nianchen Deng committed
7
import torch.nn.functional as nn_f
BobYeah's avatar
sync    
BobYeah committed
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from tensorboardX.writer import SummaryWriter

# Make the parent of this script's directory importable so the project-local
# `utils`, `upsampling` and `data` packages imported further down resolve.
# os.path.join is used instead of string concatenation: when sys.path[0] is ''
# (e.g. interactive use), '' + '/../' would resolve to the filesystem root.
sys.path.append(os.path.abspath(os.path.join(sys.path[0], os.pardir)))

# ===========================================================
# Command-line interface
# ===========================================================
def _build_arg_parser():
    """Create the command-line parser for training / testing the network.

    Returns:
        argparse.ArgumentParser: parser holding the training hyper-parameters,
        the dataset / test-model paths, the color-space option and the model's
        upscale factor. ``--dataset`` is the only required option.
    """
    cli = argparse.ArgumentParser(description='PyTorch Super Res Example')

    # Hyper-parameters and run configuration.
    cli.add_argument('--device', help='Which CUDA device to use.',
                     type=int, default=3)
    cli.add_argument('--batchSize', help='training batch size',
                     type=int, default=1)
    cli.add_argument('--testBatchSize', help='testing batch size',
                     type=int, default=1)
    cli.add_argument('--nEpochs', help='number of epochs to train for',
                     type=int, default=20)
    cli.add_argument('--lr', help='Learning Rate. Default=0.01',
                     type=float, default=0.01)
    cli.add_argument('--seed', help='random seed to use. Default=123',
                     type=int, default=123)
    cli.add_argument('--dataset', help='dataset directory',
                     type=str, required=True)
    cli.add_argument('--test', help='path of model to test', type=str)
    cli.add_argument('--testOutPatt', help='test output path pattern',
                     type=str)
    cli.add_argument('--color', help='color', type=str, default='rgb')

    # Model configuration.
    cli.add_argument('--upscale_factor', '-uf',
                     help="super resolution upscale factor",
                     type=int, default=2)
    return cli


parser = _build_arg_parser()

args = parser.parse_args()

# Bind this process to the GPU chosen with --device and report the result.
torch.cuda.set_device(args.device)
_active_cuda_device = torch.cuda.current_device()
print("Set CUDA:{} as current device.".format(_active_cuda_device))

Nianchen Deng's avatar
sync    
Nianchen Deng committed
47
48
49
50
from utils import misc
from utils import netio
from utils import img
from utils import color
Nianchen Deng's avatar
sync    
Nianchen Deng committed
51
#from .upsampling.SubPixelCNN.solver import SubPixelTrainer as Solver
Nianchen Deng's avatar
sync    
Nianchen Deng committed
52
53
54
from upsampling.SRCNN.solver import SRCNNTrainer as Solver
from upsampling.upsampling_dataset import UpsamplingDataset
from data.loader import FastDataLoader
BobYeah's avatar
sync    
BobYeah committed
55
56
57
58

# Output directory for the TensorBoard writer and saved model files; it is a
# relative path, so it lives inside the dataset directory entered below.
run_dir = 'run/'
# The dataset-relative paths used by train()/test() (input/, gt/, run/) are
# resolved against the dataset directory, so switch into it first.
os.chdir(args.dataset)
print('Change working directory to {}'.format(os.getcwd()))
# Convert the --color string into the value compared against color.RGB /
# color.YCbCr elsewhere in this script.
args.color = color.from_str(args.color)
BobYeah's avatar
sync    
BobYeah committed
60
61
62


def train():
    """Train the upsampling network on the current dataset directory.

    Low-resolution inputs are read from ``input/out_view_%04d.png`` and targets
    from ``gt/view_%04d.png`` (both relative to the working directory). The
    model is trained for ``args.nEpochs`` epochs with a TensorBoard writer in
    ``run_dir``, then saved with ``netio.save`` as
    ``run_dir + 'model-epoch_<nEpochs>.pth'``.
    """
    misc.create_dir(run_dir)

    dataset = UpsamplingDataset('.', 'input/out_view_%04d.png',
                                'gt/view_%04d.png', color=args.color)
    loader = FastDataLoader(dataset=dataset, batch_size=args.batchSize,
                            shuffle=True, drop_last=False)
    # The same loader is handed to both of the Solver's data-loader slots.
    solver = Solver(args, loader, loader, SummaryWriter(run_dir))

    # Three input channels for RGB, a single channel otherwise.
    is_rgb = args.color == color.RGB
    solver.build_model(3 if is_rgb else 1)

    # In YCbCr mode only channel index 2 is selected — presumably the Y
    # (luma) channel, since test() feeds the last channel to the model.
    train_channels = slice(2, 3) if args.color == color.YCbCr else None

    step = 0
    for epoch in range(1, args.nEpochs + 1):
        print("\n===> Epoch {} starts:".format(epoch))
        step = solver.train(epoch, step, channels=train_channels)

    netio.save(run_dir + 'model-epoch_%d.pth' % args.nEpochs, solver.model)
BobYeah's avatar
sync    
BobYeah committed
80
81
82


def test():
    """Run a trained model over the dataset inputs and write the results.

    Builds the model, loads the weights stored at ``args.test`` into it, runs
    it on every ``input/out_view_%04d.png`` image of the working directory and
    saves each result to ``args.testOutPatt % idx``.

    Raises:
        ValueError: if ``--testOutPatt`` was not supplied.
    """
    if not args.testOutPatt:
        raise ValueError('--testOutPatt is required when --test is given')
    out_dir = os.path.dirname(args.testOutPatt)
    # A pattern with no directory part writes into the working directory, so
    # there is nothing to create (and create_dir('') would have nothing to do).
    if out_dir:
        misc.create_dir(out_dir)

    # Inputs only: no ground-truth pattern is needed for inference.
    test_set = UpsamplingDataset('.', 'input/out_view_%04d.png', None,
                                 color=args.color)
    test_loader = FastDataLoader(dataset=test_set,
                                 batch_size=args.testBatchSize,
                                 shuffle=False,
                                 drop_last=False)
    trainer = Solver(args, test_loader, test_loader, SummaryWriter(run_dir))
    trainer.build_model(3 if args.color == color.RGB else 1)
    netio.load(args.test, trainer.model)
    # Inference mode: switches layers such as dropout / batch-norm (if the
    # model has any) to their evaluation behaviour.
    # NOTE(review): assumes trainer.model is a torch.nn.Module — confirm.
    trainer.model.eval()

    # No gradients are needed for inference; this avoids building an autograd
    # graph (and holding its activations) for every image.
    with torch.no_grad():
        for idx, lr_img, _ in test_loader:
            if args.color == color.YCbCr:
                # The network only upsamples the last (Y) channel; the two
                # chroma channels are upscaled by the same factor with plain
                # nearest-neighbour interpolation (interpolate's default mode,
                # identical to the deprecated nn_f.upsample) and concatenated
                # back in CbCr-Y order before converting to RGB.
                out_y = trainer.model(lr_img[:, -1:])
                out_cbcr = nn_f.interpolate(lr_img[:, 0:2],
                                            scale_factor=args.upscale_factor)
                output = color.ycbcr2rgb(torch.cat([out_cbcr, out_y], -3))
            else:
                output = trainer.model(lr_img)
            # NOTE(review): `idx` comes straight from the loader; with
            # --testBatchSize > 1 it is presumably a batch of indices and the
            # '%' formatting below would not work — confirm before raising it.
            img.save(output, args.testOutPatt % idx)
BobYeah's avatar
sync    
BobYeah committed
102
103
104
105
106
107
108
109
110
111
112


def main():
    """Entry point: evaluate the model given by ``--test``, otherwise train."""
    action = test if args.test else train
    action()


if __name__ == '__main__':
    main()