import torch.nn as nn
import torch.nn.init as init


class Net(nn.Module):
    """Four-layer convolutional network ending in a pixel-shuffle upscaler.

    The first convolution takes a single input channel. Every convolution
    uses stride 1 with padding chosen so the spatial size is preserved; the
    last one emits ``upscale_factor ** 2`` channels, which ``nn.PixelShuffle``
    rearranges into one channel whose height and width are multiplied by
    ``upscale_factor``.

    Args:
        upscale_factor: Spatial upscaling factor; must be at least 1.

    Raises:
        ValueError: If ``upscale_factor`` is less than 1.
    """

    def __init__(self, upscale_factor):
        super(Net, self).__init__()

        # Fail fast: a factor below 1 would otherwise only surface as an
        # error from the pixel-shuffle step during the first forward pass.
        if upscale_factor < 1:
            raise ValueError(
                "upscale_factor must be >= 1, got {!r}".format(upscale_factor)
            )

        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(
            32, upscale_factor ** 2, kernel_size=3, stride=1, padding=1
        )
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        self._initialize_weights()

    def _initialize_weights(self):
        """Apply orthogonal initialisation to every convolution weight.

        The three convolutions that feed a ReLU use the ReLU gain; the final
        convolution is followed by no activation and keeps the default gain.
        Biases are left at their ``nn.Conv2d`` defaults.
        """
        relu_gain = init.calculate_gain('relu')
        # Order is kept as conv1..conv4 so random-number consumption (and
        # therefore seeded reproducibility) matches the layer order.
        for layer in (self.conv1, self.conv2, self.conv3):
            init.orthogonal_(layer.weight, relu_gain)
        init.orthogonal_(self.conv4.weight)

    def forward(self, x):
        """Run the network on ``x`` and return the upscaled result.

        Args:
            x: Input tensor whose channel dimension is 1 (as required by
                ``conv1``).

        Returns:
            The output of ``pixel_shuffle`` applied to the last convolution.
        """
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.conv4(x)
        return self.pixel_shuffle(x)