solver.py 3.98 KB
Newer Older
BobYeah's avatar
sync    
BobYeah committed
1
from __future__ import print_function
Nianchen Deng's avatar
sync    
Nianchen Deng committed
2

BobYeah's avatar
sync    
BobYeah committed
3
4
5
6
7
8
from math import log10
import sys

import torch
import torch.backends.cudnn as cudnn
import torchvision
Nianchen Deng's avatar
sync    
Nianchen Deng committed
9

BobYeah's avatar
sync    
BobYeah committed
10
from .model import Net
Nianchen Deng's avatar
sync    
Nianchen Deng committed
11
from my.progress_bar import progress_bar
BobYeah's avatar
sync    
BobYeah committed
12
13


Nianchen Deng's avatar
sync    
Nianchen Deng committed
14
class SubPixelTrainer(object):
    """Train and evaluate the sub-pixel ``Net`` model.

    Typical use is ``SubPixelTrainer(config, train_loader, test_loader).run()``;
    ``build_model``, ``train``, ``test`` and ``save`` can also be driven
    manually by a caller that wants its own epoch loop.
    """

    def __init__(self, config, training_loader, testing_loader, writer=None):
        """Store hyper-parameters and data loaders; no model is built yet.

        Args:
            config: object exposing ``lr``, ``nEpochs``, ``seed`` and
                ``upscale_factor`` attributes.
            training_loader: sized iterable of training batches.
            testing_loader: sized iterable of evaluation batches.
            writer: optional logger; only ``add_scalar`` and ``add_image``
                are called on it (presumably a TensorBoard ``SummaryWriter``
                -- confirm against the caller).
        """
        super(SubPixelTrainer, self).__init__()
        self.CUDA = torch.cuda.is_available()
        self.device = torch.device('cuda' if self.CUDA else 'cpu')
        self.model = None
        self.lr = config.lr
        self.nEpochs = config.nEpochs
        self.criterion = None
        self.optimizer = None
        self.scheduler = None
        self.seed = config.seed
        self.upscale_factor = config.upscale_factor
        self.training_loader = training_loader
        self.testing_loader = testing_loader
        self.writer = writer

    @staticmethod
    def _unpack_batch(batch):
        """Return ``(data, target)`` taken from the last two items of *batch*.

        Accepts both ``(data, target)`` and ``(extra, data, target)`` batches
        so that training and evaluation agree on the batch layout.
        """
        data, target = tuple(batch)[-2:]
        return data, target

    @staticmethod
    def _psnr(mse):
        """Return the PSNR in dB for a mean squared error value.

        The formula uses a peak signal value of 1 (assumes images are scaled
        to [0, 1] -- TODO confirm). A zero error yields ``inf`` instead of a
        ``ZeroDivisionError``.
        """
        if mse <= 0:
            return float('inf')
        return 10 * log10(1 / mse)

    def build_model(self, num_channels):
        """Create the model, loss, optimizer and learning-rate scheduler.

        Args:
            num_channels: number of image channels; only ``1`` is accepted.

        Raises:
            ValueError: if ``num_channels`` is not 1.
        """
        if num_channels != 1:
            raise ValueError('num_channels must be 1')
        self.model = Net(upscale_factor=self.upscale_factor).to(self.device)
        self.criterion = torch.nn.MSELoss()
        # NOTE(review): the seed is applied after the model is constructed,
        # so it does not govern the initial weights -- confirm this is
        # intended.
        torch.manual_seed(self.seed)

        if self.CUDA:
            torch.cuda.manual_seed(self.seed)
            # Let cuDNN auto-tune its kernels for the input sizes it sees.
            cudnn.benchmark = True
            self.criterion.cuda()

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        # Halve the learning rate at epochs 50, 75 and 100.
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=[50, 75, 100], gamma=0.5)

    def save(self, model_out_path="model_path.pth"):
        """Serialize the whole model object to *model_out_path*.

        Args:
            model_out_path: destination file; defaults to the path that was
                previously hard-coded.
        """
        torch.save(self.model, model_out_path)
        print("Checkpoint saved to {}".format(model_out_path))

    def train(self, epoch, iters, channels=None):
        """Run one training epoch over ``training_loader``.

        Args:
            epoch: epoch number, used only for the progress display.
            iters: global iteration counter at the start of this epoch.
            channels: optional index applied to the third-from-last
                dimension of both data and target; skipped when falsy.

        Returns:
            The global iteration counter after this epoch.
        """
        self.model.train()
        train_loss = 0
        for batch_num, batch in enumerate(self.training_loader):
            data, target = self._unpack_batch(batch)
            if channels:
                data = data[..., channels, :, :]
                target = target[..., channels, :, :]
            data, target = data.to(self.device), target.to(self.device)
            self.optimizer.zero_grad()
            out = self.model(data)
            loss = self.criterion(out, target)
            train_loss += loss.item()
            loss.backward()
            self.optimizer.step()
            sys.stdout.write('Epoch %d: ' % epoch)
            progress_bar(batch_num, len(self.training_loader),
                         'Loss: %.4f' % (train_loss / (batch_num + 1)))
            if self.writer:
                self.writer.add_scalar("loss", loss, iters)
                if iters % 100 == 0:
                    # Interleave outputs with their targets so each grid row
                    # (nrow=2) shows one prediction next to its ground truth.
                    output_vs_gt = torch.stack([out, target], 1) \
                        .flatten(0, 1).detach()
                    self.writer.add_image(
                        "Output_vs_gt",
                        torchvision.utils.make_grid(
                            output_vs_gt, nrow=2).cpu().numpy(),
                        iters)
            iters += 1

        print("    Average Loss: {:.4f}".format(
            train_loss / len(self.training_loader)))
        return iters

    def test(self):
        """Evaluate on ``testing_loader`` and print the average PSNR."""
        self.model.eval()
        avg_psnr = 0

        with torch.no_grad():
            for batch_num, batch in enumerate(self.testing_loader):
                data, target = self._unpack_batch(batch)
                data, target = data.to(self.device), target.to(self.device)
                prediction = self.model(data)
                mse = self.criterion(prediction, target)
                avg_psnr += self._psnr(mse.item())
                progress_bar(batch_num, len(self.testing_loader),
                             'PSNR: %.4f' % (avg_psnr / (batch_num + 1)))

        print("    Average PSNR: {:.4f} dB".format(
            avg_psnr / len(self.testing_loader)))

    def run(self):
        """Build the model, then train and evaluate for ``nEpochs`` epochs.

        A checkpoint is written after the final epoch.
        """
        # build_model only accepts single-channel input.
        self.build_model(1)
        iters = 0
        for epoch in range(1, self.nEpochs + 1):
            print("\n===> Epoch {} starts:".format(epoch))
            iters = self.train(epoch, iters)
            self.test()
            # One step per epoch; the explicit epoch argument is deprecated.
            self.scheduler.step()
            if epoch == self.nEpochs:
                self.save()