run_upsampling.py 3.82 KB
Newer Older
BobYeah's avatar
sync    
BobYeah committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from __future__ import print_function

import argparse
import os
import sys
import torch
from torch.utils.data import DataLoader
from tensorboardX.writer import SummaryWriter

sys.path.append(os.path.abspath(sys.path[0] + '/../'))
__package__ = "deeplightfield"

# ===========================================================
# Training settings
# ===========================================================
parser = argparse.ArgumentParser(description='PyTorch Super Res Example')
# hyper-parameters
parser.add_argument('--device', type=int, default=3,
                    help='Which CUDA device to use.')
parser.add_argument('--batchSize', type=int, default=1,
                    help='training batch size')
parser.add_argument('--testBatchSize', type=int,
                    default=1, help='testing batch size')
parser.add_argument('--nEpochs', type=int, default=20,
                    help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.01,
                    help='Learning Rate. Default=0.01')
parser.add_argument('--seed', type=int, default=123,
                    help='random seed to use. Default=123')
parser.add_argument('--dataset', type=str, required=True,
                    help='dataset directory')
parser.add_argument('--test', type=str, help='path of model to test')
parser.add_argument('--testOutPatt', type=str, help='test output path pattern')

# model configuration
parser.add_argument('--upscale_factor', '-uf', type=int,
                    default=2, help="super resolution upscale factor")

args = parser.parse_args()
# test() builds its output paths from --testOutPatt, so test mode cannot run
# without it. Report the problem as a usage error now rather than letting
# test() fail later on a None path.
if args.test and not args.testOutPatt:
    parser.error('--testOutPatt is required when --test is given')

# Select device
# NOTE(review): the package-local imports below are deliberately placed after
# set_device, presumably because they read the current CUDA device when first
# imported (TODO confirm against my/device.py).
torch.cuda.set_device(args.device)
print("Set CUDA:%d as current device." % torch.cuda.current_device())

from .my import util
from .my import netio
from .my import device
from .SRGAN.solver import SRGANTrainer as Solver
from .data.upsampling import UpsamplingDataset
from .data.loader import FastDataLoader

# From here on every relative path is resolved against the dataset directory.
# This covers the image patterns used by the datasets, run_dir, and the
# --test / --testOutPatt arguments.
os.chdir(args.dataset)
print('Change working directory to ' + os.getcwd())
# Directory (relative to the dataset directory) for TensorBoard logs and
# saved generator checkpoints.
run_dir = 'run/'


def train(pretrain_epochs=20):
    """Pretrain, then train, the SRGAN generator and save its final weights.

    Builds an ``UpsamplingDataset`` from the ``out_view_%04d.png`` /
    ``gt_view_%04d.png`` image patterns (``gray=True``) in the current working
    directory. It logs through a ``SummaryWriter`` rooted at ``run_dir`` and
    finally saves ``trainer.netG`` to ``run_dir + 'model-epoch_<nEpochs>.pth'``.

    Args:
        pretrain_epochs: number of ``trainer.pretrain()`` calls made before
            the main training loop. Defaults to 20.
    """
    util.CreateDirIfNeed(run_dir)
    train_set = UpsamplingDataset('.', 'out_view_%04d.png',
                                  'gt_view_%04d.png', gray=True)
    training_data_loader = FastDataLoader(dataset=train_set,
                                          batch_size=args.batchSize,
                                          shuffle=True,
                                          drop_last=False)
    # The same loader is passed for both of the solver's loader arguments;
    # no separate held-out split is used.
    trainer = Solver(args, training_data_loader, training_data_loader,
                     SummaryWriter(run_dir))
    trainer.build_model()
    # Generator pretraining. The printed total is the loop's own bound, so
    # the progress message always matches the number of epochs actually run.
    for epoch in range(1, pretrain_epochs + 1):
        trainer.pretrain()
        print("{}/{} pretrained".format(epoch, pretrain_epochs))
    # Main training loop. trainer.train returns an updated iteration counter
    # that is carried into the next epoch.
    iters = 0
    for epoch in range(1, args.nEpochs + 1):
        print("\n===> Epoch {} starts:".format(epoch))
        iters = trainer.train(epoch, iters)
    netio.SaveNet(run_dir + 'model-epoch_%d.pth' % args.nEpochs, trainer.netG)


def test():
    """Run the generator loaded from ``--test`` over the test inputs.

    Builds an ``UpsamplingDataset`` over ``out_view_%04d.png`` images
    (``gray=True``, no ground-truth pattern) in the current working
    directory. It calls ``util.WriteImageTensor`` once per sample, with path
    ``args.testOutPatt % sample_index``.
    """
    out_dir = os.path.dirname(args.testOutPatt)
    # A pattern with no directory part targets the current directory, which
    # already exists, so there is nothing to create.
    if out_dir:
        util.CreateDirIfNeed(out_dir)
    test_set = UpsamplingDataset('.', 'out_view_%04d.png', None, gray=True)
    test_loader = FastDataLoader(dataset=test_set,
                                 batch_size=args.testBatchSize,
                                 shuffle=False,
                                 drop_last=False)
    # The solver's constructor takes two loaders and a writer even though
    # only its generator (netG) is used here.
    trainer = Solver(args, test_loader, test_loader, SummaryWriter(run_dir))
    trainer.build_model()
    netio.LoadNet(args.test, trainer.netG)
    # Inference only: switch mode-dependent layers (e.g. batch norm,
    # dropout) to evaluation behaviour and skip autograd bookkeeping.
    trainer.netG.eval()
    with torch.no_grad():
        for idx, inputs, _ in test_loader:
            outputs = trainer.netG(inputs)
            # Write every sample of the batch under its own index. Slicing
            # keeps the batch dimension, so with batch size 1 this is the
            # same call as writing the whole batch.
            for i in range(outputs.size(0)):
                util.WriteImageTensor(outputs[i:i + 1],
                                      args.testOutPatt % int(idx[i]))


def main():
    """Script entry point: run inference if ``--test`` was given, else train."""
    if not args.test:
        train()
        return
    test()


if __name__ == '__main__':
    main()