Commit 69e1d015 authored by BobYeah's avatar BobYeah

Update1205ForHPC

parent 5069f8ae
import torch
from gen_image import *

class Conf(object):
    def __init__(self):
        self.pupil_size = 0.02
        self.retinal_res = torch.tensor([ 320, 320 ])
        self.layer_res = torch.tensor([ 320, 320 ])
        self.layer_hfov = 90          # layers' horizontal FOV
        self.eye_hfov = 85            # eye's horizontal FOV (ignored in foveated rendering)
        self.eye_enable_fovea = True  # enable foveated rendering
        self.eye_fovea_angles = [ 40, 80 ]     # angles of the eye's foveation layers
        self.eye_fovea_downsamples = [ 1, 2 ]  # downsample factors of the foveation layers
        self.d_layer = [ 1, 3 ]                # layers' distances

    def GetNLayers(self):
        return len(self.d_layer)

    def GetLayerSize(self, i):
        w = Fov2Length(self.layer_hfov)
        h = w * self.layer_res[0] / self.layer_res[1]
        return torch.tensor([ h, w ]) * self.d_layer[i]

    def GetEyeViewportSize(self):
        fov = self.eye_fovea_angles[-1] if self.eye_enable_fovea else self.eye_hfov
        w = Fov2Length(fov)
        h = w * self.retinal_res[0] / self.retinal_res[1]
        return torch.tensor([ h, w ])
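
A minimal usage sketch for Conf. Fov2Length comes from gen_image, which this commit does not show; the comments below assume it maps a FOV in degrees to a viewport extent at unit distance (e.g. 2 * tan(fov / 2)), which is an assumption rather than something confirmed by the diff:

conf = Conf()
print(conf.GetNLayers())           # 2 display layers, at distances 1 and 3
print(conf.GetLayerSize(0))        # [h, w] extent of the near layer (assumed unit-distance FOV scaling)
print(conf.GetEyeViewportSize())   # viewport implied by the outer foveation angle (80 degrees)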
@@ -159,3 +159,36 @@ class RetinalGen(object):
        retinal = mapped_layers.prod(0).sum(3).div(Phi.size()[3])
        return retinal

    def GenFoveaLayers(self, retinal, retinal_mask):
        '''
        Generate foveated layers and corresponding masks

        Parameters
        --------
        retinal      - retinal image generated by GenRetinalFromLayers()
        retinal_mask - mask of the retinal image, also generated by GenRetinalFromLayers()

        Returns
        --------
        fovea_layers      - list of foveated layers
        fovea_layer_masks - list of mask images corresponding to the foveated layers
        '''
        fovea_layers = []
        fovea_layer_masks = []
        fov = self.conf.eye_fovea_angles[-1]
        retinal_res = int(self.conf.retinal_res[0])
        for i in range(0, len(self.conf.eye_fovea_angles)):
            angle = self.conf.eye_fovea_angles[i]
            k = self.conf.eye_fovea_downsamples[i]
            # crop a centered ROI covering this foveation angle
            roi_size = int(np.ceil(retinal_res * angle / fov))
            roi_offset = int((retinal_res - roi_size) / 2)
            roi_img = retinal[:, roi_offset:(roi_offset + roi_size), roi_offset:(roi_offset + roi_size)]
            roi_mask = retinal_mask[roi_offset:(roi_offset + roi_size), roi_offset:(roi_offset + roi_size)]
            if k == 1:
                fovea_layers.append(roi_img)
                fovea_layer_masks.append(roi_mask)
            else:
                # average-pool the image ROI; max-pool the inverted mask so that any
                # masked-out pixel in a k x k block also masks the pooled pixel
                fovea_layers.append(torch.nn.functional.avg_pool2d(roi_img.unsqueeze(0), k).squeeze(0))
                fovea_layer_masks.append(1 - torch.nn.functional.max_pool2d((1 - roi_mask).unsqueeze(0), k).squeeze(0))
        return [ fovea_layers, fovea_layer_masks ]
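
With the default Conf above (retinal_res 320, fovea angles [40, 80], downsamples [1, 2]), both foveated layers come out at 160 x 160. A minimal check of the ROI arithmetic, mirroring the loop body:

import numpy as np
fov, res = 80, 320
for angle, k in [(40, 1), (80, 2)]:
    roi_size = int(np.ceil(res * angle / fov))
    roi_offset = int((res - roi_size) / 2)
    print(angle, roi_offset, roi_size, roi_size // k)
# 40 -> offset 80, 160 x 160 crop kept at full resolution
# 80 -> offset 0, 320 x 320 crop average-pooled 2x down to 160 x 160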
import torch

def var_or_cuda(x):
    if torch.cuda.is_available():
        # x = x.cuda(non_blocking=True)
        x = x.to('cuda:1')
    return x
class residual_block(torch.nn.Module):
    def __init__(self, OUT_CHANNELS_RB, delta_channel_dim, KERNEL_SIZE_RB, RNN=False):
        super(residual_block, self).__init__()
        self.delta_channel_dim = delta_channel_dim
        self.out_channels_rb = OUT_CHANNELS_RB
        self.hidden = None
        self.RNN = RNN
        if self.RNN:
            # recurrent variant: layer1 consumes the input concatenated with the hidden state
            self.layer1 = torch.nn.Sequential(
                torch.nn.Conv2d((OUT_CHANNELS_RB + delta_channel_dim) * 2, OUT_CHANNELS_RB + delta_channel_dim, KERNEL_SIZE_RB, stride=1, padding=1),
                torch.nn.BatchNorm2d(OUT_CHANNELS_RB + delta_channel_dim),
                torch.nn.ELU()
            )
            self.layer2 = torch.nn.Sequential(
                torch.nn.Conv2d(OUT_CHANNELS_RB + delta_channel_dim, OUT_CHANNELS_RB + delta_channel_dim, KERNEL_SIZE_RB, stride=1, padding=1),
                torch.nn.BatchNorm2d(OUT_CHANNELS_RB + delta_channel_dim),
                torch.nn.ELU()
            )
        else:
            self.layer1 = torch.nn.Sequential(
                torch.nn.Conv2d(OUT_CHANNELS_RB + delta_channel_dim, OUT_CHANNELS_RB + delta_channel_dim, KERNEL_SIZE_RB, stride=1, padding=1),
                torch.nn.BatchNorm2d(OUT_CHANNELS_RB + delta_channel_dim),
                torch.nn.ELU()
            )
            self.layer2 = torch.nn.Sequential(
                torch.nn.Conv2d(OUT_CHANNELS_RB + delta_channel_dim, OUT_CHANNELS_RB + delta_channel_dim, KERNEL_SIZE_RB, stride=1, padding=1),
                torch.nn.BatchNorm2d(OUT_CHANNELS_RB + delta_channel_dim),
                torch.nn.ELU()
            )

    def forward(self, input):
        if self.RNN:
            inp = torch.cat((input, self.hidden), dim=1)
            output = self.layer1(inp)
            output = self.layer2(output)
            output = input + output
            self.hidden = output
        else:
            output = self.layer1(input)
            output = self.layer2(output)
            output = input + output
        return output

    def reset_hidden(self, inp):
        # the hidden state lives at half the spatial resolution of the raw input
        # (the network runs after an interleave with rate 2)
        size = list(inp.size())
        size[1] = self.delta_channel_dim + self.out_channels_rb
        size[2] = size[2] // 2
        size[3] = size[3] // 2
        hidden = torch.zeros(*size)
        self.hidden = var_or_cuda(hidden)
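
A CPU-only sanity check of the recurrent variant (hypothetical sizes; with CUDA available, var_or_cuda would move the hidden state to cuda:1 and the input would have to live there too). reset_hidden takes the raw, pre-interleave tensor and allocates the hidden state at half its spatial resolution:

rb = residual_block(8, 3, 3, RNN=True)
rb.reset_hidden(torch.zeros(1, 5, 64, 64))   # hidden: (1, 8 + 3, 32, 32)
y = rb(torch.zeros(1, 11, 32, 32))           # one recurrent step
print(y.shape)                               # torch.Size([1, 11, 32, 32])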
class deinterleave(torch.nn.Module):
    # depth-to-space: (B, C * r^2, H, W) -> (B, C, H * r, W * r), r = block_size
    def __init__(self, block_size):
        super(deinterleave, self).__init__()
        self.block_size = block_size
        self.block_size_sq = block_size * block_size

    def forward(self, input):
        output = input.permute(0, 2, 3, 1)
        (batch_size, d_height, d_width, d_depth) = output.size()
        s_depth = int(d_depth / self.block_size_sq)
        s_width = int(d_width * self.block_size)
        s_height = int(d_height * self.block_size)
        t_1 = output.reshape(batch_size, d_height, d_width, self.block_size_sq, s_depth)
        spl = t_1.split(self.block_size, 3)
        stack = [t_t.reshape(batch_size, d_height, s_width, s_depth) for t_t in spl]
        output = torch.stack(stack, 0).transpose(0, 1).permute(0, 2, 1, 3, 4).reshape(batch_size, s_height, s_width, s_depth)
        output = output.permute(0, 3, 1, 2)
        return output
class interleave(torch.nn.Module):
    # space-to-depth: (B, C, H, W) -> (B, C * r^2, H / r, W / r), r = block_size
    def __init__(self, block_size):
        super(interleave, self).__init__()
        self.block_size = block_size
        self.block_size_sq = block_size * block_size

    def forward(self, input):
        output = input.permute(0, 2, 3, 1)
        (batch_size, s_height, s_width, s_depth) = output.size()
        d_depth = s_depth * self.block_size_sq
        d_width = int(s_width / self.block_size)
        d_height = int(s_height / self.block_size)
        t_1 = output.split(self.block_size, 2)
        stack = [t_t.reshape(batch_size, d_height, d_depth) for t_t in t_1]
        output = torch.stack(stack, 1)
        output = output.permute(0, 2, 1, 3)
        output = output.permute(0, 3, 1, 2)
        return output
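
A quick round-trip check that deinterleave exactly inverts interleave (both are pure reshapes and permutes, so torch.equal holds; the spatial dims must be divisible by the block size):

x = torch.arange(2 * 3 * 8 * 8, dtype=torch.float32).reshape(2, 3, 8, 8)
y = interleave(2)(x)                       # (2, 12, 4, 4)
assert torch.equal(deinterleave(2)(y), x)  # deinterleave undoes interleave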
class model(torch.nn.Module):
    def __init__(self, FIRST_LAYER_CHANNELS, LAST_LAYER_CHANNELS, OUT_CHANNELS_RB, KERNEL_SIZE, KERNEL_SIZE_RB, INTERLEAVE_RATE):
        super(model, self).__init__()
        self.interleave = interleave(INTERLEAVE_RATE)
        self.first_layer = torch.nn.Sequential(
            torch.nn.Conv2d(FIRST_LAYER_CHANNELS, OUT_CHANNELS_RB, KERNEL_SIZE, stride=1, padding=1),
            torch.nn.BatchNorm2d(OUT_CHANNELS_RB),
            torch.nn.ELU()
        )
        self.residual_block1 = residual_block(OUT_CHANNELS_RB, 0, KERNEL_SIZE_RB, False)
        self.residual_block2 = residual_block(OUT_CHANNELS_RB, 3, KERNEL_SIZE_RB, False)
        self.residual_block3 = residual_block(OUT_CHANNELS_RB, 3, KERNEL_SIZE_RB, True)
        self.residual_block4 = residual_block(OUT_CHANNELS_RB, 3, KERNEL_SIZE_RB, True)
        self.residual_block5 = residual_block(OUT_CHANNELS_RB, 3, KERNEL_SIZE_RB, True)
        self.output_layer = torch.nn.Sequential(
            torch.nn.Conv2d(OUT_CHANNELS_RB + 3, LAST_LAYER_CHANNELS, KERNEL_SIZE, stride=1, padding=1),
            torch.nn.BatchNorm2d(LAST_LAYER_CHANNELS),
            torch.nn.Sigmoid()
        )
        self.deinterleave = deinterleave(INTERLEAVE_RATE)

    def reset_hidden(self, inp):
        self.residual_block3.reset_hidden(inp)
        self.residual_block4.reset_hidden(inp)
        self.residual_block5.reset_hidden(inp)

    def forward(self, lightfield_images, focal_length, gazeX, gazeY):
        # lightfield_images: torch.Size([batch_size, channels * D, H, W])
        # channels: RGB * D = 3 * 9, H: 256, W: 256
        input_to_net = self.interleave(lightfield_images)
        input_to_rb = self.first_layer(input_to_net)
        output = self.residual_block1(input_to_rb)
        # broadcast the per-sample focal length and gaze coordinates
        # into constant feature maps
        depth_layer = torch.ones((input_to_rb.shape[0], 1, input_to_rb.shape[2], input_to_rb.shape[3]))
        gazeX_layer = torch.ones((input_to_rb.shape[0], 1, input_to_rb.shape[2], input_to_rb.shape[3]))
        gazeY_layer = torch.ones((input_to_rb.shape[0], 1, input_to_rb.shape[2], input_to_rb.shape[3]))
        for i in range(focal_length.shape[0]):
            depth_layer[i] *= 1. / focal_length[i]
            gazeX_layer[i] *= (gazeX[i] - (-3.333)) / (3.333 * 2)   # normalize gaze to [0, 1]
            gazeY_layer[i] *= (gazeY[i] - (-3.333)) / (3.333 * 2)
        depth_layer = var_or_cuda(depth_layer)
        gazeX_layer = var_or_cuda(gazeX_layer)
        gazeY_layer = var_or_cuda(gazeY_layer)
        output = torch.cat((output, depth_layer, gazeX_layer, gazeY_layer), dim=1)
        output = self.residual_block2(output)
        output = self.residual_block3(output)
        output = self.residual_block4(output)
        output = self.residual_block5(output)
        # output = output + input_to_net
        output = self.output_layer(output)
        output = self.deinterleave(output)
        return output
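
A hedged instantiation sketch consistent with the shape comments in forward; the hyperparameter values are illustrative assumptions, not taken from this commit. Nine RGB views (27 channels) interleaved at rate 2 give 27 * 4 = 108 input channels, and reset_hidden expects the raw, pre-interleave tensor. Assumes a CPU-only run, since var_or_cuda would otherwise move tensors to cuda:1:

net = model(108, 24, 64, 3, 3, 2)    # FIRST_LAYER_CHANNELS .. INTERLEAVE_RATE (assumed values)
lf = torch.zeros(1, 27, 256, 256)    # one sample: 9 RGB light-field views
net.reset_hidden(lf)                 # hidden states: (1, 64 + 3, 128, 128)
out = net(lf, torch.tensor([1.0]), torch.tensor([0.0]), torch.tensor([0.0]))
print(out.shape)                     # torch.Size([1, 6, 256, 256]) after deinterleave (24 // 4 channels)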
import torch, os, sys, cv2
import torch.nn as nn
from torch.nn import init
import functools
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as func
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
class RecurrentBlock(nn.Module):
    def __init__(self, input_nc, output_nc, downsampling=False, bottleneck=False, upsampling=False):
        super(RecurrentBlock, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.downsampling = downsampling
        self.upsampling = upsampling
        self.bottleneck = bottleneck
        self.hidden = None
        if self.downsampling:
            self.l1 = nn.Sequential(
                nn.Conv2d(input_nc, output_nc, 3, padding=1),
                nn.LeakyReLU(negative_slope=0.1)
            )
            self.l2 = nn.Sequential(
                nn.Conv2d(2 * output_nc, output_nc, 3, padding=1),
                nn.LeakyReLU(negative_slope=0.1),
                nn.Conv2d(output_nc, output_nc, 3, padding=1),
                nn.LeakyReLU(negative_slope=0.1),
            )
        elif self.upsampling:
            # upsampling blocks are feed-forward: no l2 and no hidden state
            self.l1 = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='nearest'),
                nn.Conv2d(2 * input_nc, output_nc, 3, padding=1),
                nn.LeakyReLU(negative_slope=0.1),
                nn.Conv2d(output_nc, output_nc, 3, padding=1),
                nn.LeakyReLU(negative_slope=0.1),
            )
        elif self.bottleneck:
            self.l1 = nn.Sequential(
                nn.Conv2d(input_nc, output_nc, 3, padding=1),
                nn.LeakyReLU(negative_slope=0.1)
            )
            self.l2 = nn.Sequential(
                nn.Conv2d(2 * output_nc, output_nc, 3, padding=1),
                nn.LeakyReLU(negative_slope=0.1),
                nn.Conv2d(output_nc, output_nc, 3, padding=1),
                nn.LeakyReLU(negative_slope=0.1),
            )

    def forward(self, inp):
        if self.downsampling:
            op1 = self.l1(inp)
            op2 = self.l2(torch.cat((op1, self.hidden), dim=1))
            self.hidden = op2
            return op2
        elif self.upsampling:
            return self.l1(inp)
        elif self.bottleneck:
            op1 = self.l1(inp)
            op2 = self.l2(torch.cat((op1, self.hidden), dim=1))
            self.hidden = op2
            return op2

    def reset_hidden(self, inp, dfac):
        size = list(inp.size())
        size[1] = self.output_nc
        size[2] //= dfac   # integer division: torch.zeros requires int sizes
        size[3] //= dfac
        self.hidden_size = size
        self.hidden = torch.zeros(*size).to('cuda:0')
class RecurrentAE(nn.Module):
    def __init__(self, input_nc):
        super(RecurrentAE, self).__init__()
        self.d1 = RecurrentBlock(input_nc=input_nc, output_nc=32, downsampling=True)
        self.d2 = RecurrentBlock(input_nc=32, output_nc=43, downsampling=True)
        self.d3 = RecurrentBlock(input_nc=43, output_nc=57, downsampling=True)
        self.d4 = RecurrentBlock(input_nc=57, output_nc=76, downsampling=True)
        self.d5 = RecurrentBlock(input_nc=76, output_nc=101, downsampling=True)
        self.bottleneck = RecurrentBlock(input_nc=101, output_nc=101, bottleneck=True)
        self.u5 = RecurrentBlock(input_nc=101, output_nc=76, upsampling=True)
        self.u4 = RecurrentBlock(input_nc=76, output_nc=57, upsampling=True)
        self.u3 = RecurrentBlock(input_nc=57, output_nc=43, upsampling=True)
        self.u2 = RecurrentBlock(input_nc=43, output_nc=32, upsampling=True)
        self.u1 = RecurrentBlock(input_nc=32, output_nc=3, upsampling=True)

    def set_input(self, inp):
        self.inp = inp['A']

    def forward(self):
        # U-Net style encoder-decoder with skip connections
        d1 = func.max_pool2d(input=self.d1(self.inp), kernel_size=2)
        d2 = func.max_pool2d(input=self.d2(d1), kernel_size=2)
        d3 = func.max_pool2d(input=self.d3(d2), kernel_size=2)
        d4 = func.max_pool2d(input=self.d4(d3), kernel_size=2)
        d5 = func.max_pool2d(input=self.d5(d4), kernel_size=2)
        b = self.bottleneck(d5)
        u5 = self.u5(torch.cat((b, d5), dim=1))
        u4 = self.u4(torch.cat((u5, d4), dim=1))
        u3 = self.u3(torch.cat((u4, d3), dim=1))
        u2 = self.u2(torch.cat((u3, d2), dim=1))
        u1 = self.u1(torch.cat((u2, d1), dim=1))
        return u1

    def reset_hidden(self):
        self.d1.reset_hidden(self.inp, dfac=1)
        self.d2.reset_hidden(self.inp, dfac=2)
        self.d3.reset_hidden(self.inp, dfac=4)
        self.d4.reset_hidden(self.inp, dfac=8)
        self.d5.reset_hidden(self.inp, dfac=16)
        self.bottleneck.reset_hidden(self.inp, dfac=32)
        # the upsampling blocks keep no hidden state, so their dfac values are unused
        self.u4.reset_hidden(self.inp, dfac=16)
        self.u3.reset_hidden(self.inp, dfac=8)
        self.u5.reset_hidden(self.inp, dfac=4)
        self.u2.reset_hidden(self.inp, dfac=2)
        self.u1.reset_hidden(self.inp, dfac=1)
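
A hedged usage sketch: reset_hidden hard-codes 'cuda:0', so the module and input must live there too, and the input side length must be divisible by 32 for the five pooling/upsampling stages to round-trip:

net = RecurrentAE(input_nc=3).to('cuda:0')
net.set_input({'A': torch.zeros(1, 3, 256, 256, device='cuda:0')})
net.reset_hidden()      # zero all recurrent states before a new sequence
out = net.forward()
print(out.shape)        # torch.Size([1, 3, 256, 256])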