Commit 72636b3e authored by BobYeah's avatar BobYeah
Browse files

sync

parent b7fae973
## Fast Super Resolution CNN
![pic](http://mmlab.ie.cuhk.edu.hk/projects/FSRCNN/img/framework.png)
### note
this model has high possibility to diverge after 20 epochs.
\ No newline at end of file
import torch
import torch.nn as nn
class Net(torch.nn.Module):
def __init__(self, num_channels, upscale_factor, d=64, s=12, m=4):
super(Net, self).__init__()
self.first_part = nn.Sequential(
nn.Conv2d(in_channels=num_channels, out_channels=d,
kernel_size=5, stride=1, padding=2),
nn.PReLU()
)
self.layers = []
self.layers += [
nn.Conv2d(in_channels=d, out_channels=s,
kernel_size=1, stride=1, padding=0),
nn.PReLU()
]
for _ in range(m):
self.layers += [
nn.Conv2d(in_channels=s, out_channels=s,
kernel_size=3, stride=1, padding=1),
nn.PReLU()
]
self.layers += [
nn.Conv2d(in_channels=s, out_channels=d,
kernel_size=1, stride=1, padding=0),
nn.PReLU()
]
self.mid_part = nn.Sequential(*self.layers)
# Deconvolution
if upscale_factor % 2:
self.last_part = nn.ConvTranspose2d(
in_channels=d, out_channels=num_channels, kernel_size=9,
stride=upscale_factor, padding=5 - (upscale_factor + 1) // 2)
else:
self.last_part = nn.ConvTranspose2d(
in_channels=d, out_channels=num_channels, kernel_size=9,
stride=upscale_factor, padding=5 - upscale_factor // 2,
output_padding=1)
def forward(self, x):
out = self.first_part(x)
out = self.mid_part(out)
out = self.last_part(out)
return out
def weight_init(self, mean=0.0, std=0.02):
for m in self.modules():
if isinstance(m, nn.Conv2d):
m.weight.data.normal_(mean, std)
if m.bias is not None:
m.bias.data.zero_()
if isinstance(m, nn.ConvTranspose2d):
m.weight.data.normal_(0.0, 0.0001)
if m.bias is not None:
m.bias.data.zero_()
from __future__ import print_function
from math import log10
import sys
import torch
import torch.backends.cudnn as cudnn
import torchvision
from .model import Net
from ..my.progress_bar import progress_bar
class FSRCNNTrainer(object):
def __init__(self, config, training_loader, testing_loader, writer=None):
super(FSRCNNTrainer, self).__init__()
self.CUDA = torch.cuda.is_available()
self.device = torch.device('cuda' if self.CUDA else 'cpu')
self.model = None
self.lr = config.lr
self.nEpochs = config.nEpochs
self.criterion = None
self.optimizer = None
self.scheduler = None
self.seed = config.seed
self.upscale_factor = config.upscale_factor
self.training_loader = training_loader
self.testing_loader = testing_loader
self.writer = writer
def build_model(self):
self.model = Net(
num_channels=1, upscale_factor=self.upscale_factor).to(self.device)
self.model.weight_init(mean=0.0, std=0.2)
self.criterion = torch.nn.MSELoss()
torch.manual_seed(self.seed)
if self.CUDA:
torch.cuda.manual_seed(self.seed)
cudnn.benchmark = True
self.criterion.cuda()
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
self.optimizer, milestones=[50, 75, 100], gamma=0.5) # lr decay
def save_model(self):
model_out_path = "model_path.pth"
torch.save(self.model, model_out_path)
print("Checkpoint saved to {}".format(model_out_path))
def train(self, epoch, iters):
self.model.train()
train_loss = 0
for batch_num, (_, data, target) in enumerate(self.training_loader):
data, target = data.to(self.device), target.to(self.device)
self.optimizer.zero_grad()
out = self.model(data)
loss = self.criterion(out, target)
train_loss += loss.item()
loss.backward()
self.optimizer.step()
sys.stdout.write('Epoch %d: ' % epoch)
progress_bar(batch_num, len(self.training_loader),
'Loss: %.4f' % (train_loss / (batch_num + 1)))
if self.writer:
self.writer.add_scalar("loss", loss, iters)
if iters % 100 == 0:
output_vs_gt = torch.stack([out, target], 1) \
.flatten(0, 1).detach()
self.writer.add_image(
"Output_vs_gt",
torchvision.utils.make_grid(output_vs_gt, nrow=2).cpu().numpy(),
iters)
iters += 1
print(" Average Loss: {:.4f}".format(
train_loss / len(self.training_loader)))
return iters
def test(self):
self.model.eval()
avg_psnr = 0
with torch.no_grad():
for batch_num, (data, target) in enumerate(self.testing_loader):
data, target = data.to(self.device), target.to(self.device)
prediction = self.model(data)
mse = self.criterion(prediction, target)
psnr = 10 * log10(1 / mse.item())
avg_psnr += psnr
progress_bar(batch_num, len(self.testing_loader),
'PSNR: %.4f' % (avg_psnr / (batch_num + 1)))
print(" Average PSNR: {:.4f} dB".format(
avg_psnr / len(self.testing_loader)))
def run(self):
self.build_model()
for epoch in range(1, self.nEpochs + 1):
print("\n===> Epoch {} starts:".format(epoch))
self.train()
self.test()
self.scheduler.step(epoch)
if epoch == self.nEpochs:
self.save_model()
## Super Resolution CNN
The authors of the SRCNN describe their network, pointing out the equivalence of their method to the sparse-coding method, which is a widely used learning method for image SR. This is an important and educational aspect of their work, because it shows how example-based learning methods can be adapted and generalized to CNN models.
The SRCNN consists of the following operations:
1. **Preprocessing**: Up-scales LR image to desired HR size.
2. **Feature extraction**: Extracts a set of feature maps from the up-scaled LR image.
3. **Non-linear mapping**: Maps the feature maps representing LR to HR patches.
4. **Reconstruction**: Produces the HR image from HR patches.
Operations 2–4 above can be cast as a convolutional layer in a CNN that accepts as input the preprocessed images from step 1 above, and outputs the HR image
import torch
import torch.nn as nn
class Net(torch.nn.Module):
def __init__(self, num_channels, base_filter, upscale_factor=2):
super(Net, self).__init__()
self.layers = torch.nn.Sequential(
nn.Conv2d(in_channels=num_channels, out_channels=base_filter, kernel_size=9, stride=1, padding=4, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=base_filter, out_channels=base_filter // 2, kernel_size=1, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=base_filter // 2, out_channels=num_channels * (upscale_factor ** 2), kernel_size=5, stride=1, padding=2, bias=True),
nn.PixelShuffle(upscale_factor)
)
def forward(self, x):
out = self.layers(x)
return out
def weight_init(self, mean, std):
for m in self._modules:
normal_init(self._modules[m], mean, std)
def normal_init(m, mean, std):
if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d):
m.weight.data.normal_(mean, std)
m.bias.data.zero_()
from __future__ import print_function
from math import log10
import sys
import torch
import torch.backends.cudnn as cudnn
import torchvision
from .model import Net
from ..my.progress_bar import progress_bar
class SRCNNTrainer(object):
def __init__(self, config, training_loader, testing_loader, writer=None):
super(SRCNNTrainer, self).__init__()
self.CUDA = torch.cuda.is_available()
self.device = torch.device('cuda' if self.CUDA else 'cpu')
self.model = None
self.lr = config.lr
self.nEpochs = config.nEpochs
self.criterion = None
self.optimizer = None
self.scheduler = None
self.seed = config.seed
self.upscale_factor = config.upscale_factor
self.training_loader = training_loader
self.testing_loader = testing_loader
self.writer = writer
def build_model(self):
self.model = Net(num_channels=1, base_filter=64, upscale_factor=self.upscale_factor).to(self.device)
self.model.weight_init(mean=0.0, std=0.01)
self.criterion = torch.nn.MSELoss()
torch.manual_seed(self.seed)
if self.CUDA:
torch.cuda.manual_seed(self.seed)
cudnn.benchmark = True
self.criterion.cuda()
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[50, 75, 100], gamma=0.5)
def save_model(self):
model_out_path = "model_path.pth"
torch.save(self.model, model_out_path)
print("Checkpoint saved to {}".format(model_out_path))
def train(self, epoch, iters):
self.model.train()
train_loss = 0
for batch_num, (_, data, target) in enumerate(self.training_loader):
data, target = data.to(self.device), target.to(self.device)
self.optimizer.zero_grad()
out = self.model(data)
loss = self.criterion(out, target)
train_loss += loss.item()
loss.backward()
self.optimizer.step()
sys.stdout.write('Epoch %d: ' % epoch)
progress_bar(batch_num, len(self.training_loader), 'Loss: %.4f' % (train_loss / (batch_num + 1)))
if self.writer:
self.writer.add_scalar("loss", loss, iters)
if iters % 100 == 0:
output_vs_gt = torch.stack([out, target], 1) \
.flatten(0, 1).detach()
self.writer.add_image(
"Output_vs_gt",
torchvision.utils.make_grid(output_vs_gt, nrow=2).cpu().numpy(),
iters)
iters += 1
print(" Average Loss: {:.4f}".format(train_loss / len(self.training_loader)))
return iters
def test(self):
self.model.eval()
avg_psnr = 0
with torch.no_grad():
for batch_num, (data, target) in enumerate(self.testing_loader):
data, target = data.to(self.device), target.to(self.device)
prediction = self.model(data)
mse = self.criterion(prediction, target)
psnr = 10 * log10(1 / mse.item())
avg_psnr += psnr
progress_bar(batch_num, len(self.testing_loader), 'PSNR: %.4f' % (avg_psnr / (batch_num + 1)))
print(" Average PSNR: {:.4f} dB".format(avg_psnr / len(self.testing_loader)))
def run(self):
self.build_model()
for epoch in range(1, self.nEpochs + 1):
print("\n===> Epoch {} starts:".format(epoch))
self.train()
self.test()
self.scheduler.step(epoch)
if epoch == self.nEpochs:
self.save_model()
# SRGAN: Super-Resolution using GANs
This is a complete Pytorch implementation of [Christian Ledig et al: "Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network"](https://arxiv.org/abs/1609.04802),
reproducing their results.
This paper's main result is that through using an adversarial and a content loss, a convolutional neural network is able to produce sharp, almost photo-realistic upsamplings of images.
The implementation tries to be as faithful as possible to the original paper.
See [implementation details](#method-and-implementation-details) for a closer look.
## Method and Implementation Details
Architecture diagram of the super-resolution and discriminator networks by Ledig et al:
<p align='center'>
<img src='https://github.com/mseitzer/srgan/blob/master/images/architecture.png' width=580>
</p>
The implementation tries to stay as close as possible to the details given in the paper.
As such, the pretrained SRGAN is also trained with 1e6 and 1e5 update steps.
The high amount of update steps proved to be essential for performance, which pretty much monotonically increases with training time.
Some further implementation choices where the paper does not give any details:
- Initialization: orthogonal for the super-resolution network, randomly from a normal distribution with std=0.02 for the discriminator network
- Padding: reflection padding (instead of the more commonly used zero padding)
## Batch-size
batch size of 2 is recommended if GPU has only 8G RAM.
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F
def swish(x):
return x * torch.sigmoid(x)
class ResidualBlock(nn.Module):
def __init__(self, in_channels, kernel, out_channels, stride):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel, stride=stride, padding=kernel // 2)
self.bn1 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=kernel, stride=stride, padding=kernel // 2)
self.bn2 = nn.BatchNorm2d(out_channels)
def forward(self, x):
y = swish(self.bn1(self.conv1(x)))
return self.bn2(self.conv2(y)) + x
class UpsampleBlock(nn.Module):
# Implements resize-convolution
def __init__(self, in_channels):
super(UpsampleBlock, self).__init__()
self.conv = nn.Conv2d(in_channels, in_channels * 4, kernel_size=3, stride=1, padding=1)
self.shuffler = nn.PixelShuffle(2)
def forward(self, x):
return swish(self.shuffler(self.conv(x)))
class Generator(nn.Module):
def __init__(self, n_residual_blocks, upsample_factor, num_channel=1, base_filter=64):
super(Generator, self).__init__()
self.n_residual_blocks = n_residual_blocks
self.upsample_factor = upsample_factor
self.conv1 = nn.Conv2d(num_channel, base_filter, kernel_size=9, stride=1, padding=4)
for i in range(self.n_residual_blocks):
self.add_module('residual_block' + str(i + 1), ResidualBlock(in_channels=base_filter, out_channels=base_filter, kernel=3, stride=1))
self.conv2 = nn.Conv2d(base_filter, base_filter, kernel_size=3, stride=1, padding=1)
self.bn2 = nn.BatchNorm2d(base_filter)
for i in range(self.upsample_factor // 2):
self.add_module('upsample' + str(i + 1), UpsampleBlock(base_filter))
self.conv3 = nn.Conv2d(base_filter, num_channel, kernel_size=9, stride=1, padding=4)
def forward(self, x):
x = swish(self.conv1(x))
y = x.clone()
for i in range(self.n_residual_blocks):
y = self.__getattr__('residual_block' + str(i + 1))(y)
x = self.bn2(self.conv2(y)) + x
for i in range(self.upsample_factor // 2):
x = self.__getattr__('upsample' + str(i + 1))(x)
return self.conv3(x)
def weight_init(self, mean=0.0, std=0.02):
for m in self._modules:
normal_init(self._modules[m], mean, std)
class Discriminator(nn.Module):
def __init__(self, num_channel=1, base_filter=64):
super(Discriminator, self).__init__()
self.conv1 = nn.Conv2d(num_channel, base_filter, kernel_size=3, stride=1, padding=1)
self.conv2 = nn.Conv2d(base_filter, base_filter, kernel_size=3, stride=2, padding=1)
self.bn2 = nn.BatchNorm2d(base_filter)
self.conv3 = nn.Conv2d(base_filter, base_filter * 2, kernel_size=3, stride=1, padding=1)
self.bn3 = nn.BatchNorm2d(base_filter * 2)
self.conv4 = nn.Conv2d(base_filter * 2, base_filter * 2, kernel_size=3, stride=2, padding=1)
self.bn4 = nn.BatchNorm2d(base_filter * 2)
self.conv5 = nn.Conv2d(base_filter * 2, base_filter * 4, kernel_size=3, stride=1, padding=1)
self.bn5 = nn.BatchNorm2d(base_filter * 4)
self.conv6 = nn.Conv2d(base_filter * 4, base_filter * 4, kernel_size=3, stride=2, padding=1)
self.bn6 = nn.BatchNorm2d(base_filter * 4)
self.conv7 = nn.Conv2d(base_filter * 4, base_filter * 8, kernel_size=3, stride=1, padding=1)
self.bn7 = nn.BatchNorm2d(base_filter * 8)
self.conv8 = nn.Conv2d(base_filter * 8, base_filter * 8, kernel_size=3, stride=2, padding=1)
self.bn8 = nn.BatchNorm2d(base_filter * 8)
# Replaced original paper FC layers with FCN
self.conv9 = nn.Conv2d(base_filter * 8, num_channel, kernel_size=1, stride=1, padding=0)
def forward(self, x):
x = swish(self.conv1(x))
x = swish(self.bn2(self.conv2(x)))
x = swish(self.bn3(self.conv3(x)))
x = swish(self.bn4(self.conv4(x)))
x = swish(self.bn5(self.conv5(x)))
x = swish(self.bn6(self.conv6(x)))
x = swish(self.bn7(self.conv7(x)))
x = swish(self.bn8(self.conv8(x)))
x = self.conv9(x)
return torch.sigmoid(F.avg_pool2d(x, x.size()[2:])).view(x.size()[0], -1)
def weight_init(self, mean=0.0, std=0.02):
for m in self._modules:
normal_init(self._modules[m], mean, std)
def normal_init(m, mean, std):
if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d):
m.weight.data.normal_(mean, std)
m.bias.data.zero_()
from __future__ import print_function
from math import log10
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torchvision
from torchvision.models.vgg import vgg16
from .model import Generator, Discriminator
from ..my.progress_bar import progress_bar
class SRGANTrainer(object):
def __init__(self, config, training_loader, testing_loader, writer):
super(SRGANTrainer, self).__init__()
self.GPU_IN_USE = torch.cuda.is_available()
self.device = torch.device('cuda' if self.GPU_IN_USE else 'cpu')
self.netG = None
self.netD = None
self.lr = config.lr
self.nEpochs = config.nEpochs
self.epoch_pretrain = 10
self.criterionG = None
self.criterionD = None
self.optimizerG = None
self.optimizerD = None
self.feature_extractor = None
self.scheduler = None
self.seed = config.seed
self.upscale_factor = config.upscale_factor
self.num_residuals = 16
self.training_loader = training_loader
self.testing_loader = testing_loader
self.writer = writer
def build_model(self):
self.netG = Generator(n_residual_blocks=self.num_residuals, upsample_factor=self.upscale_factor, base_filter=64, num_channel=1).to(self.device)
self.netD = Discriminator(base_filter=64, num_channel=1).to(self.device)
self.feature_extractor = vgg16(pretrained=True)
self.netG.weight_init(mean=0.0, std=0.2)
self.netD.weight_init(mean=0.0, std=0.2)
self.criterionG = nn.MSELoss()
self.criterionD = nn.BCELoss()
torch.manual_seed(self.seed)
if self.GPU_IN_USE:
torch.cuda.manual_seed(self.seed)
self.feature_extractor.cuda()
cudnn.benchmark = True
self.criterionG.cuda()
self.criterionD.cuda()
self.optimizerG = optim.Adam(self.netG.parameters(), lr=self.lr, betas=(0.9, 0.999))
self.optimizerD = optim.SGD(self.netD.parameters(), lr=self.lr / 100, momentum=0.9, nesterov=True)
self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizerG, milestones=[50, 75, 100], gamma=0.5) # lr decay
self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizerD, milestones=[50, 75, 100], gamma=0.5) # lr decay
@staticmethod
def to_data(x):
if torch.cuda.is_available():
x = x.cpu()
return x.data
def save(self):
g_model_out_path = "SRGAN_Generator_model_path.pth"
d_model_out_path = "SRGAN_Discriminator_model_path.pth"
torch.save(self.netG, g_model_out_path)
torch.save(self.netD, d_model_out_path)
print("Checkpoint saved to {}".format(g_model_out_path))
print("Checkpoint saved to {}".format(d_model_out_path))
def pretrain(self):
self.netG.train()
for batch_num, (_, data, target) in enumerate(self.training_loader):
data, target = data.to(self.device), target.to(self.device)
self.netG.zero_grad()
loss = self.criterionG(self.netG(data), target)
loss.backward()
self.optimizerG.step()
def train(self, epoch, iters):
# models setup
self.netG.train()
self.netD.train()
g_train_loss = 0
d_train_loss = 0
for batch_num, (_, data, target) in enumerate(self.training_loader):
# setup noise
real_label = torch.ones(data.size(0), data.size(1)).to(self.device)
fake_label = torch.zeros(data.size(0), data.size(1)).to(self.device)
data, target = data.to(self.device), target.to(self.device)
# Train Discriminator
self.optimizerD.zero_grad()
d_real = self.netD(target)
d_real_loss = self.criterionD(d_real, real_label)
d_fake = self.netD(self.netG(data))
d_fake_loss = self.criterionD(d_fake, fake_label)
d_total = d_real_loss + d_fake_loss
d_train_loss += d_total.item()
d_total.backward()
self.optimizerD.step()
# Train generator
self.optimizerG.zero_grad()
g_real = self.netG(data)
g_fake = self.netD(g_real)
gan_loss = self.criterionD(g_fake, real_label)
mse_loss = self.criterionG(g_real, target)
g_total = mse_loss + 1e-3 * gan_loss
g_train_loss += g_total.item()
g_total.backward()
self.optimizerG.step()
sys.stdout.write('Epoch %d: ' % epoch)
progress_bar(batch_num, len(self.training_loader), 'G_Loss: %.4f | D_Loss: %.4f' % (g_train_loss / (batch_num + 1), d_train_loss / (batch_num + 1)))
if self.writer:
self.writer.add_scalar("G_Loss", g_train_loss / (batch_num + 1), iters)
self.writer.add_scalar("D_Loss", d_train_loss / (batch_num + 1), iters)
if iters % 100 == 0:
output_vs_gt = torch.stack([g_real, target], 1) \
.flatten(0, 1).detach()
self.writer.add_image(
"Output_vs_gt",
torchvision.utils.make_grid(output_vs_gt, nrow=2).cpu().numpy(),
iters)
iters += 1
print(" Average G_Loss: {:.4f}".format(g_train_loss / len(self.training_loader)))
return iters
def test(self):
self.netG.eval()
avg_psnr = 0
with torch.no_grad():
for batch_num, (data, target) in enumerate(self.testing_loader):
data, target = data.to(self.device), target.to(self.device)
prediction = self.netG(data)
mse = self.criterionG(prediction, target)
psnr = 10 * log10(1 / mse.item())
avg_psnr += psnr
progress_bar(batch_num, len(self.testing_loader), 'PSNR: %.4f' % (avg_psnr / (batch_num + 1)))
print(" Average PSNR: {:.4f} dB".format(avg_psnr / len(self.testing_loader)))
def run(self):
self.build_model()
for epoch in range(1, self.epoch_pretrain + 1):
self.pretrain()
print("{}/{} pretrained".format(epoch, self.epoch_pretrain))
for epoch in range(1, self.nEpochs + 1):
print("\n===> Epoch {} starts:".format(epoch))
self.train()
self.test()
self.scheduler.step(epoch)
if epoch == self.nEpochs:
self.save()
import torch.nn as nn
import torch.nn.init as init
class Net(nn.Module):
def __init__(self, upscale_factor):
super(Net, self).__init__()
self.relu = nn.ReLU()
self.conv1 = nn.Conv2d(1, 64, kernel_size=5, stride=1, padding=2)
self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
self.conv3 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)
self.conv4 = nn.Conv2d(32, upscale_factor ** 2, kernel_size=3, stride=1, padding=1)
self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
self._initialize_weights()
def _initialize_weights(self):
init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))
init.orthogonal_(self.conv2.weight, init.calculate_gain('relu'))
init.orthogonal_(self.conv3.weight, init.calculate_gain('relu'))
init.orthogonal_(self.conv4.weight)
def forward(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.relu(x)
x = self.conv3(x)
x = self.relu(x)
x = self.conv4(x)
x = self.pixel_shuffle(x)
return x
from __future__ import print_function
from math import log10
import sys
import torch
import torch.backends.cudnn as cudnn
import torchvision
from .model import Net
from ..my.progress_bar import progress_bar
class SubPixelTrainer(object):
def __init__(self, config, training_loader, testing_loader, writer=None):
super(SubPixelTrainer, self).__init__()
self.CUDA = torch.cuda.is_available()
self.device = torch.device('cuda' if self.CUDA else 'cpu')
self.model = None
self.lr = config.lr
self.nEpochs = config.nEpochs
self.criterion = None
self.optimizer = None
self.scheduler = None
self.seed = config.seed
self.upscale_factor = config.upscale_factor
self.training_loader = training_loader
self.testing_loader = testing_loader
self.writer = writer
def build_model(self):
self.model = Net(upscale_factor=self.upscale_factor).to(self.device)
self.criterion = torch.nn.MSELoss()
torch.manual_seed(self.seed)
if self.CUDA:
torch.cuda.manual_seed(self.seed)
cudnn.benchmark = True
self.criterion.cuda()
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[50, 75, 100], gamma=0.5) # lr decay
def save(self):
model_out_path = "model_path.pth"
torch.save(self.model, model_out_path)
print("Checkpoint saved to {}".format(model_out_path))
def train(self, epoch, iters):
self.model.train()
train_loss = 0
for batch_num, (_, data, target) in enumerate(self.training_loader):
data, target = data.to(self.device), target.to(self.device)
self.optimizer.zero_grad()
out = self.model(data)
loss = self.criterion(out, target)
train_loss += loss.item()
loss.backward()
self.optimizer.step()
sys.stdout.write('Epoch %d: ' % epoch)
progress_bar(batch_num, len(self.training_loader), 'Loss: %.4f' % (train_loss / (batch_num + 1)))
if self.writer:
self.writer.add_scalar("loss", loss, iters)
if iters % 100 == 0:
output_vs_gt = torch.stack([out, target], 1) \
.flatten(0, 1).detach()
self.writer.add_image(
"Output_vs_gt",
torchvision.utils.make_grid(output_vs_gt, nrow=2).cpu().numpy(),
iters)
iters += 1
print(" Average Loss: {:.4f}".format(train_loss / len(self.training_loader)))
return iters
def test(self):
self.model.eval()
avg_psnr = 0
with torch.no_grad():
for batch_num, (data, target) in enumerate(self.testing_loader):
data, target = data.to(self.device), target.to(self.device)
prediction = self.model(data)
mse = self.criterion(prediction, target)
psnr = 10 * log10(1 / mse.item())
avg_psnr += psnr
progress_bar(batch_num, len(self.testing_loader), 'PSNR: %.4f' % (avg_psnr / (batch_num + 1)))
print(" Average PSNR: {:.4f} dB".format(avg_psnr / len(self.testing_loader)))
def run(self):
self.build_model()
for epoch in range(1, self.nEpochs + 1):
print("\n===> Epoch {} starts:".format(epoch))
self.train()
self.test()
self.scheduler.step(epoch)
if epoch == self.nEpochs:
self.save()
...@@ -5,13 +5,12 @@ def update_config(config): ...@@ -5,13 +5,12 @@ def update_config(config):
# Net parameters # Net parameters
config.NET_TYPE = 'msl' config.NET_TYPE = 'msl'
config.N_ENCODE_DIM = 10 config.N_ENCODE_DIM = 10
config.FC_PARAMS = { config.FC_PARAMS.update({
'nf': 128, 'nf': 128,
'n_layers': 8, 'n_layers': 8,
'skips': [4] 'skips': [4]
} })
config.SAMPLE_PARAMS = { config.SAMPLE_PARAMS.update({
'depth_range': (1, 50), 'depth_range': (1, 50),
'n_samples': 8, 'n_samples': 8
'perturb_sample': True })
} \ No newline at end of file
\ No newline at end of file
...@@ -5,13 +5,11 @@ def update_config(config): ...@@ -5,13 +5,11 @@ def update_config(config):
# Net parameters # Net parameters
config.NET_TYPE = 'msl' config.NET_TYPE = 'msl'
config.N_ENCODE_DIM = 10 config.N_ENCODE_DIM = 10
config.FC_PARAMS = { config.FC_PARAMS.update({
'nf': 64, 'nf': 64,
'n_layers': 12, 'n_layers': 12
'skips': [] })
} config.SAMPLE_PARAMS.update({
config.SAMPLE_PARAMS = {
'depth_range': (1, 20), 'depth_range': (1, 20),
'n_samples': 16, 'n_samples': 16
'perturb_sample': True })
} \ No newline at end of file
\ No newline at end of file
...@@ -5,13 +5,11 @@ def update_config(config): ...@@ -5,13 +5,11 @@ def update_config(config):
# Net parameters # Net parameters
config.NET_TYPE = 'msl' config.NET_TYPE = 'msl'
config.N_ENCODE_DIM = 10 config.N_ENCODE_DIM = 10
config.FC_PARAMS = { config.FC_PARAMS.update({
'nf': 64, 'nf': 64,
'n_layers': 12, 'n_layers': 12
'skips': [] })
} config.SAMPLE_PARAMS.update({
config.SAMPLE_PARAMS = {
'depth_range': (1, 20), 'depth_range': (1, 20),
'n_samples': 16, 'n_samples': 16
'perturb_sample': True })
} \ No newline at end of file
\ No newline at end of file
...@@ -5,13 +5,12 @@ def update_config(config): ...@@ -5,13 +5,12 @@ def update_config(config):
# Net parameters # Net parameters
config.NET_TYPE = 'msl' config.NET_TYPE = 'msl'
config.N_ENCODE_DIM = 10 config.N_ENCODE_DIM = 10
config.FC_PARAMS = { config.FC_PARAMS.update({
'nf': 256, 'nf': 256,
'n_layers': 8, 'n_layers': 8,
'skips': [4] 'skips': [4]
} })
config.SAMPLE_PARAMS = { config.SAMPLE_PARAMS.update({
'depth_range': (1, 50), 'depth_range': (1, 50),
'n_samples': 32, 'n_samples': 32
'perturb_sample': True })
} \ No newline at end of file
\ No newline at end of file
...@@ -5,13 +5,12 @@ def update_config(config): ...@@ -5,13 +5,12 @@ def update_config(config):
# Net parameters # Net parameters
config.NET_TYPE = 'msl' config.NET_TYPE = 'msl'
config.N_ENCODE_DIM = 10 config.N_ENCODE_DIM = 10
config.FC_PARAMS = { config.FC_PARAMS.update({
'nf': 256, 'nf': 256,
'n_layers': 8, 'n_layers': 8,
'skips': [4] 'skips': [4]
} })
config.SAMPLE_PARAMS = { config.SAMPLE_PARAMS.update({
'depth_range': (1, 20), 'depth_range': (1, 20),
'n_samples': 16, 'n_samples': 16
'perturb_sample': True })
} \ No newline at end of file
\ No newline at end of file
def update_config(config):
# Dataset settings
config.GRAY = True
# Net parameters
config.NET_TYPE = 'msl'
config.N_ENCODE_DIM = 10
config.FC_PARAMS.update({
'nf': 64,
'n_layers': 8
})
config.SAMPLE_PARAMS.update({
'depth_range': (1, 50),
'n_samples': 4
})
\ No newline at end of file
def update_config(config): def update_config(config):
# Dataset settings # Dataset settings
config.GRAY = False config.GRAY = True
# Net parameters # Net parameters
config.NET_TYPE = 'msl' config.NET_TYPE = 'msl'
config.N_ENCODE_DIM = 10 config.N_ENCODE_DIM = 20
config.FC_PARAMS = { config.FC_PARAMS.update({
'nf': 64, 'nf': 64,
'n_layers': 12, 'n_layers': 12,
'skips': [] })
} config.SAMPLE_PARAMS.update({
config.SAMPLE_PARAMS = {
'depth_range': (1, 20), 'depth_range': (1, 20),
'n_samples': 16, 'n_samples': 16
'perturb_sample': True })
} \ No newline at end of file
config.LOSS = 'mse_grad'
\ No newline at end of file
...@@ -5,13 +5,12 @@ def update_config(config): ...@@ -5,13 +5,12 @@ def update_config(config):
# Net parameters # Net parameters
config.NET_TYPE = 'msl' config.NET_TYPE = 'msl'
config.N_ENCODE_DIM = 10 config.N_ENCODE_DIM = 10
config.FC_PARAMS = { config.FC_PARAMS.update({
'nf': 64, 'nf': 64,
'n_layers': 8, 'n_layers': 8,
'skips': [4] 'skips': [4]
} })
config.SAMPLE_PARAMS = { config.SAMPLE_PARAMS.update({
'depth_range': (1, 50), 'depth_range': (1, 50),
'n_samples': 4, 'n_samples': 4
'perturb_sample': True })
} \ No newline at end of file
\ No newline at end of file
def update_config(config):
# Dataset settings
config.GRAY = False
# Net parameters
config.NET_TYPE = 'msl'
config.N_ENCODE_DIM = 10
config.FC_PARAMS.update({
'nf': 128,
'n_layers': 6
})
config.SAMPLE_PARAMS.update({
'depth_range': (1, 50),
'n_samples': 16
})
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment