From 67c4de9e492225912f4c1120181c76abfb1fcf82 Mon Sep 17 00:00:00 2001 From: BobYeah <635596704@qq.com> Date: Fri, 25 Dec 2020 14:41:17 +0800 Subject: [PATCH] sync --- baseline.py | 159 ------------- gen_image.py | 331 --------------------------- main.py | 592 ------------------------------------------------- main_lf_syn.py | 147 ------------ my/device.py | 7 + run_lf_syn.py | 24 +- trans_unet.py | 17 +- weight_init.py | 12 - 8 files changed, 28 insertions(+), 1261 deletions(-) delete mode 100644 baseline.py delete mode 100644 gen_image.py delete mode 100644 main.py delete mode 100644 main_lf_syn.py create mode 100644 my/device.py delete mode 100644 weight_init.py diff --git a/baseline.py b/baseline.py deleted file mode 100644 index 0ba91ff..0000000 --- a/baseline.py +++ /dev/null @@ -1,159 +0,0 @@ -import torch -def var_or_cuda(x): - if torch.cuda.is_available(): - # x = x.cuda(non_blocking=True) - x = x.to('cuda:1') - return x - -class residual_block(torch.nn.Module): - def __init__(self, OUT_CHANNELS_RB, delta_channel_dim,KERNEL_SIZE_RB,RNN=False): - super(residual_block,self).__init__() - self.delta_channel_dim = delta_channel_dim - self.out_channels_rb = OUT_CHANNELS_RB - self.hidden = None - self.RNN = RNN - if self.RNN: - self.layer1 = torch.nn.Sequential( - torch.nn.Conv2d((OUT_CHANNELS_RB+delta_channel_dim)*2,OUT_CHANNELS_RB+delta_channel_dim,KERNEL_SIZE_RB,stride=1,padding = 1), - torch.nn.BatchNorm2d(OUT_CHANNELS_RB+delta_channel_dim), - torch.nn.ELU() - ) - self.layer2 = torch.nn.Sequential( - torch.nn.Conv2d(OUT_CHANNELS_RB+delta_channel_dim,OUT_CHANNELS_RB+delta_channel_dim,KERNEL_SIZE_RB,stride=1,padding = 1), - torch.nn.BatchNorm2d(OUT_CHANNELS_RB+delta_channel_dim), - torch.nn.ELU() - ) - else: - self.layer1 = torch.nn.Sequential( - torch.nn.Conv2d(OUT_CHANNELS_RB+delta_channel_dim,OUT_CHANNELS_RB+delta_channel_dim,KERNEL_SIZE_RB,stride=1,padding = 1), - torch.nn.BatchNorm2d(OUT_CHANNELS_RB+delta_channel_dim), - torch.nn.ELU() - ) - self.layer2 = torch.nn.Sequential( - torch.nn.Conv2d(OUT_CHANNELS_RB+delta_channel_dim,OUT_CHANNELS_RB+delta_channel_dim,KERNEL_SIZE_RB,stride=1,padding = 1), - torch.nn.BatchNorm2d(OUT_CHANNELS_RB+delta_channel_dim), - torch.nn.ELU() - ) - - def forward(self,input): - if self.RNN: - # print("input:",input.shape,"hidden:",self.hidden.shape) - inp = torch.cat((input,self.hidden.detach()),dim=1) - # print(inp.shape) - output = self.layer1(inp) - output = self.layer2(output) - output = input+output - self.hidden = output - else: - output = self.layer1(input) - output = self.layer2(output) - output = input+output - return output - - def reset_hidden(self, inp): - size = list(inp.size()) - size[1] = self.delta_channel_dim + self.out_channels_rb - size[2] = size[2]//2 - size[3] = size[3]//2 - hidden = torch.zeros(*(size)) - self.hidden = var_or_cuda(hidden) - -class deinterleave(torch.nn.Module): - def __init__(self, block_size): - super(deinterleave, self).__init__() - self.block_size = block_size - self.block_size_sq = block_size*block_size - - def forward(self, input): - output = input.permute(0, 2, 3, 1) - (batch_size, d_height, d_width, d_depth) = output.size() - s_depth = int(d_depth / self.block_size_sq) - s_width = int(d_width * self.block_size) - s_height = int(d_height * self.block_size) - t_1 = output.reshape(batch_size, d_height, d_width, self.block_size_sq, s_depth) - spl = t_1.split(self.block_size, 3) - stack = [t_t.reshape(batch_size, d_height, s_width, s_depth) for t_t in spl] - output = torch.stack(stack,0).transpose(0,1).permute(0,2,1,3,4).reshape(batch_size, s_height, s_width, s_depth) - output = output.permute(0, 3, 1, 2) - return output - -class interleave(torch.nn.Module): - def __init__(self, block_size): - super(interleave, self).__init__() - self.block_size = block_size - self.block_size_sq = block_size*block_size - - def forward(self, input): - output = input.permute(0, 2, 3, 1) - (batch_size, s_height, s_width, s_depth) = output.size() - d_depth = s_depth * self.block_size_sq - d_width = int(s_width / self.block_size) - d_height = int(s_height / self.block_size) - t_1 = output.split(self.block_size, 2) - stack = [t_t.reshape(batch_size, d_height, d_depth) for t_t in t_1] - output = torch.stack(stack, 1) - output = output.permute(0, 2, 1, 3) - output = output.permute(0, 3, 1, 2) - return output - -class model(torch.nn.Module): - def __init__(self,FIRSST_LAYER_CHANNELS,LAST_LAYER_CHANNELS,OUT_CHANNELS_RB,KERNEL_SIZE,KERNEL_SIZE_RB,INTERLEAVE_RATE,RNN=False): - super(model, self).__init__() - self.interleave = interleave(INTERLEAVE_RATE) - - self.first_layer = torch.nn.Sequential( - torch.nn.Conv2d(FIRSST_LAYER_CHANNELS,OUT_CHANNELS_RB,KERNEL_SIZE,stride=1,padding=1), - torch.nn.BatchNorm2d(OUT_CHANNELS_RB), - torch.nn.ELU() - ) - - self.residual_block1 = residual_block(OUT_CHANNELS_RB,0,KERNEL_SIZE_RB,False) - self.residual_block2 = residual_block(OUT_CHANNELS_RB,2,KERNEL_SIZE_RB,False) - self.residual_block3 = residual_block(OUT_CHANNELS_RB,2,KERNEL_SIZE_RB,False) - # if RNN: - # self.residual_block3 = residual_block(OUT_CHANNELS_RB,6,KERNEL_SIZE_RB,True) - # self.residual_block4 = residual_block(OUT_CHANNELS_RB,6,KERNEL_SIZE_RB,True) - # self.residual_block5 = residual_block(OUT_CHANNELS_RB,6,KERNEL_SIZE_RB,True) - # else: - # self.residual_block3 = residual_block(OUT_CHANNELS_RB,6,KERNEL_SIZE_RB,False) - # self.residual_block4 = residual_block(OUT_CHANNELS_RB,6,KERNEL_SIZE_RB,False) - # self.residual_block5 = residual_block(OUT_CHANNELS_RB,6,KERNEL_SIZE_RB,False) - - self.output_layer = torch.nn.Sequential( - torch.nn.Conv2d(OUT_CHANNELS_RB+2,LAST_LAYER_CHANNELS,KERNEL_SIZE,stride=1,padding=1), - torch.nn.BatchNorm2d(LAST_LAYER_CHANNELS), - torch.nn.Sigmoid() - ) - self.deinterleave = deinterleave(INTERLEAVE_RATE) - - def reset_hidden(self,inp): - self.residual_block3.reset_hidden(inp) - self.residual_block4.reset_hidden(inp) - self.residual_block5.reset_hidden(inp) - - def forward(self, lightfield_images, pos_row, pos_col): - # lightfield_images: torch.Size([batch_size, channels * D, H, W]) - # channels : RGB*D: 3*9, H:256, W:256 - # print("lightfield_images:",lightfield_images.shape) - input_to_net = self.interleave(lightfield_images) - # print("after interleave:",input_to_net.shape) - input_to_rb = self.first_layer(input_to_net) - - # print("input_to_rb1:",input_to_rb.shape) - output = self.residual_block1(input_to_rb) - - pos_row_layer = torch.ones((input_to_rb.shape[0],1,input_to_rb.shape[2],input_to_rb.shape[3])) - pos_col_layer = torch.ones((input_to_rb.shape[0],1,input_to_rb.shape[2],input_to_rb.shape[3])) - for i in range(pos_row.shape[0]): - pos_row_layer[i] *= pos_row[i] - pos_col_layer[i] *= pos_col[i] - # print(depth_layer.shape) - pos_row_layer = var_or_cuda(pos_row_layer) - pos_col_layer = var_or_cuda(pos_col_layer) - - output = torch.cat((output,pos_row_layer,pos_col_layer),dim=1) - output = self.residual_block2(output) - output = self.residual_block3(output) - output = self.output_layer(output) - output = self.deinterleave(output) - return output \ No newline at end of file diff --git a/gen_image.py b/gen_image.py deleted file mode 100644 index d592396..0000000 --- a/gen_image.py +++ /dev/null @@ -1,331 +0,0 @@ -import matplotlib.pyplot as plt -import numpy as np -import torch -import glm -import time -from .my import util -from .my import sample_in_pupil - -class RetinalGen(object): - ''' - Class for retinal generation process - - Properties - -------- - conf - multi-layers' parameters configuration - u - M x 3 tensor, M sample positions in pupil - p_r - H_r x W_r x 3 tensor, retinal pixel grid, [H_r, W_r] is the retinal resolution - Phi - N x H_r x W_r x M x 2 tensor, retinal to layers mapping, N is number of layers - mask - N x H_r x W_r x M x 2 tensor, indicates invalid (out-of-range) mapping - - Methods - -------- - ''' - def __init__(self, conf): - ''' - Initialize retinal generator instance - - Parameters - -------- - conf - multi-layers' parameters configuration - u - a M x 3 tensor stores M sample positions in pupil - ''' - self.conf = conf - self.u = sample_in_pupil.CircleGen(conf.pupil_size, 5) - # self.u = u.to(cuda_dev) - # self.u = u # M x 3 M sample positions - self.D_r = conf.retinal_res # retinal res 480 x 640 - self.N = conf.GetNLayers() # 2 - self.M = self.u.size()[0] # samples - # p_rx, p_ry = torch.meshgrid(torch.tensor(range(0, self.D_r[0])), - # torch.tensor(range(0, self.D_r[1]))) - # self.p_r = torch.cat([ - # ((torch.stack([p_rx, p_ry], 2) + 0.5) / self.D_r - 0.5) * conf.GetEyeViewportSize(), # 眼球视野 - # torch.ones(self.D_r[0], self.D_r[1], 1) - # ], 2) - - self.p_r = torch.cat([ - ((util.MeshGrid(self.D_r) + 0.5) / self.D_r - 0.5) * conf.GetEyeViewportSize(), - torch.ones(self.D_r[0], self.D_r[1], 1) - ], 2) - - # self.Phi = torch.empty(N, D_r[0], D_r[1], M, 2, device=cuda_dev, dtype=torch.long) - # self.mask = torch.empty(self.N, self.D_r[0], self.D_r[1], self.M, 2, dtype=torch.float) # 2 x 480 x 640 x 41 x 2 - - def CalculateRetinal2LayerMappings(self, position, gaze_dir, df): - ''' - Calculate the mapping matrix from retinal to layers. - - Parameters - -------- - position - 1 x 3 tensor, eye's position - gaze_dir - 1 x 2 tensor, gaze forward vector (with z normalized) - df - focus distance - - Returns - -------- - phi - N x H_r x W_r x M x 2, retinal to layers mapping, N is number of layers - phi_invalid - N x H_r x W_r x M x 1, indicates invalid (out-of-range) mapping - retinal_invalid - 1 x H_r x W_r, indicates invalid pixels in retinal image - ''' - D = self.conf.layer_res - c = torch.tensor([ D[1] / 2, D[0] / 2 ]) # c: Center of layers (pixel) - - D_r = self.conf.retinal_res # D_r: Resolution of retinal 480 640 - V = self.conf.GetEyeViewportSize() # V: Viewport size of eye - p_f = self.p_r * df # p_f: H x W x 3, focus positions of retinal pixels on focus plane - - # Calculate transformation from eye to display - gvec_lookat = glm.dvec3(gaze_dir[0], -gaze_dir[1], 1) - gmat_eye = glm.inverse(glm.lookAtLH(glm.dvec3(), gvec_lookat, glm.dvec3(0, 1, 0))) - eye_rot = util.Glm2Tensor(glm.dmat3(gmat_eye)) - eye_center = torch.tensor([ position[0], -position[1], position[2] ]) - - u_rot = torch.mm(self.u, eye_rot) - v_rot = torch.matmul(p_f, eye_rot).unsqueeze(2).expand( - -1, -1, self.M, -1) - u_rot # v_rot: H x W x M x 3, rotated rays' direction vector - u_rot.add_(eye_center) # translate by eye's center - v_rot = v_rot.div(v_rot[:, :, :, 2].unsqueeze(3)) # make z = 1 for each direction vector in v_rot - - phi = torch.empty(self.N, self.D_r[0], self.D_r[1], self.M, 2, dtype=torch.long) - - for i in range(0, self.N): - dp_i = self.conf.GetPixelSizeOfLayer(i) # dp_i: Pixel size of layer i - d_i = self.conf.d_layer[i] # d_i: Distance of layer i - k = (d_i - u_rot[:, 2]).unsqueeze(1) - pi_r = (u_rot[:, 0:2] + v_rot[:, :, :, 0:2] * k) / dp_i # pi_r: H x W x M x 2, rays' pixel coord on layer i - phi[i, :, :, :, :] = torch.floor(pi_r + c) - - # Calculate invalid mask (out-of-range elements in phi) and reduced to retinal - phi_invalid = (phi[:, :, :, :, 0] < 0) | (phi[:, :, :, :, 0] >= D[1]) | \ - (phi[:, :, :, :, 1] < 0) | (phi[:, :, :, :, 1] >= D[0]) - phi_invalid = phi_invalid.unsqueeze(4) - # print("phi_invalid:",phi_invalid.shape) - retinal_invalid = phi_invalid.amax((0, 3)).squeeze().unsqueeze(0) - # print("retinal_invalid:",retinal_invalid.shape) - # Fix invalid elements in phi - phi[phi_invalid.expand(-1, -1, -1, -1, 2)] = 0 - - return [ phi, phi_invalid, retinal_invalid ] - - - def GenRetinalFromLayers(self, layers, Phi): - ''' - Generate retinal image from layers, using precalculated mapping matrix - - Parameters - -------- - layers - 3N x H x W, stacked layer images, with 3 channels in each layer - phi - N x H_r x W_r x M x 2, retinal to layers mapping, N is number of layers - - Returns - -------- - 3 x H_r x W_r, 3 channels retinal image - ''' - # FOR GRAYSCALE 1 FOR RGB 3 - mapped_layers = torch.empty(self.N, 3, self.D_r[0], self.D_r[1], self.M) # 2 x 3 x 480 x 640 x 41 - # print("mapped_layers:",mapped_layers.shape) - for i in range(0, Phi.size()[0]): - # torch.Size([3, 2, 320, 320, 2]) - # print("gather layers:",layers[(i * 3) : (i * 3 + 3),Phi[i, :, :, :, 0],Phi[i, :, :, :, 1]].shape) - mapped_layers[i, :, :, :, :] = layers[(i * 3) : (i * 3 + 3), - Phi[i, :, :, :, 1], - Phi[i, :, :, :, 0]] - # print("mapped_layers:",mapped_layers.shape) - retinal = mapped_layers.prod(0).sum(3).div(Phi.size()[3]) - # print("retinal:",retinal.shape) - return retinal - - def GenRetinalFromLayersBatch(self, layers, Phi): - ''' - Generate retinal image from layers, using precalculated mapping matrix - - Parameters - -------- - layers - 3N x H_l x W_l tensor, stacked layer images, with 3 channels in each layer - - Returns - -------- - 3 x H_r x W_r tensor, 3 channels retinal image - H_r x W_r tensor, retinal image mask, indicates pixels valid or not - - ''' - mapped_layers = torch.empty(layers.size()[0], self.N, 3, self.D_r[0], self.D_r[1], self.M) #BS x Layers x C x H x W x Sample - - # truth = torch.empty(layers.size()[0], self.N, 3, self.D_r[0], self.D_r[1], self.M) - # layers_truth = layers.clone() - # Phi_truth = Phi.clone() - layers = torch.stack((layers[:,0:3,:,:],layers[:,3:6,:,:]),dim=1) ## torch.Size([BS, Layer, RGB 3, 320, 320]) - - # Phi = Phi[:,:,None,:,:,:,:].expand(-1,-1,3,-1,-1,-1,-1) - # print("mapped_layers:",mapped_layers.shape) #torch.Size([2, 2, 3, 320, 320, 41]) - # print("input layers:",layers.shape) ## torch.Size([2, 2, 3, 320, 320]) - # print("input Phi:",Phi.shape) #torch.Size([2, 2, 320, 320, 41, 2]) - - # #没优化 - - # for i in range(0, Phi_truth.size()[0]): - # for j in range(0, Phi_truth.size()[1]): - # truth[i, j, :, :, :, :] = layers_truth[i, (j * 3) : (j * 3 + 3), - # Phi_truth[i, j, :, :, :, 0], - # Phi_truth[i, j, :, :, :, 1]] - - #优化2 - # start = time.time() - mapped_layers_op1 = mapped_layers.reshape(-1, - mapped_layers.shape[2],mapped_layers.shape[3],mapped_layers.shape[4],mapped_layers.shape[5]) - # BatchSizexLayer Channel 3 320 320 41 - layers_op1 = layers.reshape(-1,layers.shape[2],layers.shape[3],layers.shape[4]) # 2x2 3 320 320 - Phi_op1 = Phi.reshape(-1,Phi.shape[2],Phi.shape[3],Phi.shape[4],Phi.shape[5]) # 2x2 320 320 41 2 - x = Phi_op1[:,:,:,:,0] # 2x2 320 320 41 - y = Phi_op1[:,:,:,:,1] # 2x2 320 320 41 - # print("reshape:",time.time() - start) - - # start = time.time() - mapped_layers_op1 = layers_op1[torch.arange(layers_op1.shape[0])[:, None, None, None], :, y, x] # x,y 切换 - #2x2 320 320 41 3 - # print("mapping one step:",time.time() - start) - - # print("mapped_layers:",mapped_layers_op1.shape) # torch.Size([4, 3, 320, 320, 41]) - # start = time.time() - mapped_layers_op1 = mapped_layers_op1.permute(0,4,1,2,3) - mapped_layers = mapped_layers_op1.reshape(mapped_layers.shape[0],mapped_layers.shape[1], - mapped_layers.shape[2],mapped_layers.shape[3],mapped_layers.shape[4],mapped_layers.shape[5]) - # print("reshape end:",time.time() - start) - - # print("test:") - # print((truth.cpu() == mapped_layers.cpu()).all()) - #优化1 - # start = time.time() - # mapped_layers_op1 = mapped_layers.reshape(-1, - # mapped_layers.shape[2],mapped_layers.shape[3],mapped_layers.shape[4],mapped_layers.shape[5]) - # layers_op1 = layers.reshape(-1,layers.shape[2],layers.shape[3],layers.shape[4]) - # Phi_op1 = Phi.reshape(-1,Phi.shape[2],Phi.shape[3],Phi.shape[4],Phi.shape[5]) - # print("reshape:",time.time() - start) - - - # for i in range(0, Phi_op1.size()[0]): - # start = time.time() - # mapped_layers_op1[i, :, :, :, :] = layers_op1[i,:, - # Phi_op1[i, :, :, :, 0], - # Phi_op1[i, :, :, :, 1]] - # print("mapping one step:",time.time() - start) - # print("mapped_layers:",mapped_layers_op1.shape) # torch.Size([4, 3, 320, 320, 41]) - # start = time.time() - # mapped_layers = mapped_layers_op1.reshape(mapped_layers.shape[0],mapped_layers.shape[1], - # mapped_layers.shape[2],mapped_layers.shape[3],mapped_layers.shape[4],mapped_layers.shape[5]) - # print("reshape end:",time.time() - start) - - # print("mapped_layers:",mapped_layers.shape) # torch.Size([2, 2, 3, 320, 320, 41]) - retinal = mapped_layers.prod(1).sum(4).div(Phi.size()[4]) - # print("retinal:",retinal.shape) # torch.Size([BatchSize, 3, 320, 320]) - return retinal - - ## TO BE CHECK - def GenFoveaLayers(self, b_retinal, is_mask): - ''' - Generate foveated layers for retinal images or masks - - Parameters - -------- - b_retinal - B x C x H_r x W_r, Batch of retinal images/masks - is_mask - Whether b_retinal is masks or images - - Returns - -------- - b_fovea_layers - N_f x (B x C x H[f] x W[f]) list of batch of foveated layers - ''' - b_fovea_layers = [] - for i in range(0, len(self.conf.eye_fovea_angles)): - k = self.conf.eye_fovea_downsamples[i] - region = self.conf.GetRegionOfFoveaLayer(i) - b_roi = b_retinal[:, :, region, region] - if k == 1: - b_fovea_layers.append(b_roi) - elif is_mask: - b_fovea_layers.append(torch.nn.functional.max_pool2d(b_roi.to(torch.float), k).to(torch.bool)) - else: - b_fovea_layers.append(torch.nn.functional.avg_pool2d(b_roi, k)) - return b_fovea_layers - # fovea_layers = [] - # fovea_layer_masks = [] - # fov = self.conf.eye_fovea_angles[-1] - # retinal_res = int(self.conf.retinal_res[0]) - # for i in range(0, len(self.conf.eye_fovea_angles)): - # angle = self.conf.eye_fovea_angles[i] - # k = self.conf.eye_fovea_downsamples[i] - # roi_size = int(np.ceil(retinal_res * angle / fov)) - # roi_offset = int((retinal_res - roi_size) / 2) - # roi_img = retinal[:, roi_offset:(roi_offset + roi_size), roi_offset:(roi_offset + roi_size)] - # roi_mask = retinal_mask[roi_offset:(roi_offset + roi_size), roi_offset:(roi_offset + roi_size)] - # if k == 1: - # fovea_layers.append(roi_img) - # fovea_layer_masks.append(roi_mask) - # else: - # fovea_layers.append(torch.nn.functional.avg_pool2d(roi_img.unsqueeze(0), k).squeeze(0)) - # fovea_layer_masks.append(1 - torch.nn.functional.max_pool2d((1 - roi_mask).unsqueeze(0), k).squeeze(0)) - # return [ fovea_layers, fovea_layer_masks ] - - ## TO BE CHECK - def GenFoveaLayersBatch(self, retinal, retinal_mask): - ''' - Generate foveated layers and corresponding masks - - Parameters - -------- - retinal - Retinal image generated by GenRetinalFromLayers() - retinal_mask - Mask of retinal image, also generated by GenRetinalFromLayers() - - Returns - -------- - fovea_layers - list of foveated layers - fovea_layer_masks - list of mask images, corresponding to foveated layers - ''' - fovea_layers = [] - fovea_layer_masks = [] - fov = self.conf.eye_fovea_angles[-1] - # print("fov:",fov) - retinal_res = int(self.conf.retinal_res[0]) - # print("retinal_res:",retinal_res) - # print("len(self.conf.eye_fovea_angles):",len(self.conf.eye_fovea_angles)) - for i in range(0, len(self.conf.eye_fovea_angles)): - angle = self.conf.eye_fovea_angles[i] - k = self.conf.eye_fovea_downsamples[i] - roi_size = int(np.ceil(retinal_res * angle / fov)) - roi_offset = int((retinal_res - roi_size) / 2) - # [2, 3, 320, 320] - roi_img = retinal[:, :, roi_offset:(roi_offset + roi_size), roi_offset:(roi_offset + roi_size)] - # print("roi_img:",roi_img.shape) - # [2, 320, 320] - roi_mask = retinal_mask[:, roi_offset:(roi_offset + roi_size), roi_offset:(roi_offset + roi_size)] - # print("roi_mask:",roi_mask.shape) - if k == 1: - fovea_layers.append(roi_img) - fovea_layer_masks.append(roi_mask) - else: - fovea_layers.append(torch.nn.functional.avg_pool2d(roi_img, k)) - fovea_layer_masks.append(1 - torch.nn.functional.max_pool2d((1 - roi_mask), k)) - return [ fovea_layers, fovea_layer_masks ] - - ## TO BE CHECK - def GenFoveaRetinal(self, b_fovea_layers): - ''' - Generate foveated retinal image by blending fovea layers - **Note: current implementation only support two fovea layers** - - Parameters - -------- - b_fovea_layers - N_f x (B x 3 x H[f] x W[f]), list of batch of (masked) foveated layers - - Returns - -------- - B x 3 x H_r x W_r, batch of foveated retinal images - ''' - b_fovea_retinal = torch.nn.functional.interpolate(b_fovea_layers[1], - scale_factor=self.conf.eye_fovea_downsamples[1], - mode='bilinear', align_corners=False) - region = self.conf.GetRegionOfFoveaLayer(0) - blend = self.conf.eye_fovea_blend[0] - b_roi = b_fovea_retinal[:, :, region, region] - b_roi.mul_(1 - blend).add_(b_fovea_layers[0] * blend) - return b_fovea_retinal diff --git a/main.py b/main.py deleted file mode 100644 index 46d3800..0000000 --- a/main.py +++ /dev/null @@ -1,592 +0,0 @@ -import torch -import argparse -import os -import glob -import numpy as np -import torchvision.transforms as transforms -from torchvision.utils import save_image - -from torchvision import datasets -from torch.utils.data import DataLoader -from torch.autograd import Variable - -import cv2 -from .gen_image import * -from .loss import * -import json -from .conf import Conf - -from .baseline import * -from .data import * - -import torch.autograd.profiler as profiler -# param -BATCH_SIZE = 1 -NUM_EPOCH = 1001 -INTERLEAVE_RATE = 2 -IM_H = 320 -IM_W = 320 -Retinal_IM_H = 320 -Retinal_IM_W = 320 -N = 25 # number of input light field stack -M = 2 # number of display layers -DATA_FILE = "/home/yejiannan/Project/LightField/data/FlowRPG1211" -DATA_JSON = "/home/yejiannan/Project/LightField/data/data_gaze_fovea_seq_flow_RPG.json" -# DATA_VAL_JSON = "/home/yejiannan/Project/LightField/data/data_gaze_fovea_val.json" -OUTPUT_DIR = "/home/yejiannan/Project/LightField/outputE/gaze_fovea_seq_flow_RPG_seq5_same_loss" -OUT_CHANNELS_RB = 128 -KERNEL_SIZE_RB = 3 -KERNEL_SIZE = 3 -LAST_LAYER_CHANNELS = 6 * INTERLEAVE_RATE**2 -FIRSST_LAYER_CHANNELS = 75 * INTERLEAVE_RATE**2 - -def GenRetinalFromLayersBatch(layers, gen, sample_idx, phi_dict, mask_dict): - # layers: batchsize, 2*color, height, width - # Phi:torch.Size([batchsize, 480, 640, 2, 41, 2]) - # df : batchsize,.. - - # retinal bs x color x height x width - retinal = torch.zeros(layers.shape[0], 3, Retinal_IM_H, Retinal_IM_W) - mask = [] # mask shape 480 x 640 - for i in range(0, layers.size()[0]): - phi = phi_dict[int(sample_idx[i].data)] - # print("phi_i:",phi.shape) - phi = var_or_cuda(phi) - phi.requires_grad = False - # print("layers[i]:",layers[i].shape) - # print("retinal[i]:",retinal[i].shape) - retinal[i] = gen.GenRetinalFromLayers(layers[i],phi) - mask.append(mask_dict[int(sample_idx[i].data)]) - retinal = var_or_cuda(retinal) - mask = torch.stack(mask,dim = 0).unsqueeze(1) # batch x 1 x height x width - return retinal, mask - -def GenRetinalGazeFromLayersBatch(layers, gen, sample_idx, phi_dict, mask_dict): - # layers: batchsize, 2*color, height, width - # Phi:torch.Size([batchsize, 480, 640, 2, 41, 2]) - # df : batchsize,.. - - # retinal bs x color x height x width - retinal_fovea = torch.empty(layers.shape[0], 6, 160, 160) - mask_fovea = torch.empty(layers.shape[0], 2, 160, 160) - - for i in range(0, layers.size()[0]): - phi = phi_dict[int(sample_idx[i].data)] - # print("phi_i:",phi.shape) - phi = var_or_cuda(phi) - phi.requires_grad = False - mask_i = var_or_cuda(mask_dict[int(sample_idx[i].data)]) - mask_i.requires_grad = False - # print("layers[i]:",layers[i].shape) - # print("retinal[i]:",retinal[i].shape) - retinal_i = gen.GenRetinalFromLayers(layers[i],phi) - fovea_layers, fovea_layer_masks = gen.GenFoveaLayers(retinal_i,mask_i) - retinal_fovea[i] = torch.cat([fovea_layers[0],fovea_layers[1]],dim=0) - mask_fovea[i] = torch.stack([fovea_layer_masks[0],fovea_layer_masks[1]],dim=0) - - - retinal_fovea = var_or_cuda(retinal_fovea) - mask_fovea = var_or_cuda(mask_fovea) # batch x 2 x height x width - # mask = torch.stack(mask,dim = 0).unsqueeze(1) - return retinal_fovea, mask_fovea - -import time -def GenRetinalGazeFromLayersBatchSpeed(layers, gen, phi, phi_invalid, retinal_invalid): - # layers: batchsize, 2*color, height, width - # Phi:torch.Size([batchsize, Layer, h, w, 41, 2]) - # df : batchsize,.. - # start1 = time.time() - # retinal bs x color x height x width - retinal_fovea = torch.empty((layers.shape[0], 6, 160, 160),device="cuda:2") - mask_fovea = torch.empty((layers.shape[0], 2, 160, 160),device="cuda:2") - # start = time.time() - retinal = gen.GenRetinalFromLayersBatch(layers,phi_batch) - # print("retinal:",retinal.shape) #retinal: torch.Size([2, 3, 320, 320]) - # print("t2:",time.time() - start) - - # start = time.time() - fovea_layers, fovea_layer_masks = gen.GenFoveaLayersBatch(retinal,mask_batch) - - mask_fovea = torch.stack([fovea_layer_masks[0],fovea_layer_masks[1]],dim=1) - retinal_fovea = torch.cat([fovea_layers[0],fovea_layers[1]],dim=1) - # print("t3:",time.time() - start) - - retinal_fovea = var_or_cuda(retinal_fovea) - mask_fovea = var_or_cuda(mask_fovea) # batch x 2 x height x width - # mask = torch.stack(mask,dim = 0).unsqueeze(1) - return retinal_fovea, mask_fovea - -def MergeBatchSpeed(layers, gen, phi, phi_invalid, retinal_invalid): - # layers: batchsize, 2*color, height, width - # Phi:torch.Size([batchsize, Layer, h, w, 41, 2]) - # df : batchsize,.. - # start1 = time.time() - # retinal bs x color x height x width - # retinal_fovea = torch.empty((layers.shape[0], 6, 160, 160),device="cuda:2") - # mask_fovea = torch.empty((layers.shape[0], 2, 160, 160),device="cuda:2") - # start = time.time() - retinal = gen.GenRetinalFromLayersBatch(layers,phi) #retinal: torch.Size([BatchSize , 3, 320, 320]) - retinal.mul_(~retinal_invalid.to("cuda:2")) - # print("retinal:",retinal.shape) - # print("t2:",time.time() - start) - - # start = time.time() - # fovea_layers, fovea_layer_masks = gen.GenFoveaLayersBatch(retinal,mask_batch) - - # mask_fovea = torch.stack([fovea_layer_masks[0],fovea_layer_masks[1]],dim=1) - # retinal_fovea = torch.cat([fovea_layers[0],fovea_layers[1]],dim=1) - # print("t3:",time.time() - start) - - # retinal_fovea = var_or_cuda(retinal_fovea) - # mask_fovea = var_or_cuda(mask_fovea) # batch x 2 x height x width - # mask = torch.stack(mask,dim = 0).unsqueeze(1) - return retinal - -def GenRetinalFromLayersBatch_Online(layers, gen, phi, mask): - # layers: batchsize, 2*color, height, width - # Phi:torch.Size([batchsize, 480, 640, 2, 41, 2]) - # df : batchsize,.. - - # retinal bs x color x height x width - # retinal = torch.zeros(layers.shape[0], 3, Retinal_IM_H, Retinal_IM_W) - # retinal = var_or_cuda(retinal) - phi = var_or_cuda(phi) - phi.requires_grad = False - retinal = gen.GenRetinalFromLayers(layers[0],phi) - retinal = var_or_cuda(retinal) - mask_out = mask.unsqueeze(0).unsqueeze(0) - # print("maskOUt:",mask_out.shape) # 1,1,240,320 - # mask_out = torch.stack(mask,dim = 0).unsqueeze(1) # batch x 1 x height x width - return retinal.unsqueeze(0), mask_out -#### Image Gen End - -from weight_init import weight_init_normal - -def save_checkpoints(file_path, epoch_idx, model, model_solver): - print('[INFO] Saving checkpoint to %s ...' % ( file_path)) - checkpoint = { - 'epoch_idx': epoch_idx, - 'model_state_dict': model.state_dict(), - 'model_solver_state_dict': model_solver.state_dict() - } - torch.save(checkpoint, file_path) - -# import pickle -# def save_obj(obj, name ): -# # with open('./outputF/dict/'+ name + '.pkl', 'wb') as f: -# # pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) -# torch.save(obj,'./outputF/dict/'+ name + '.pkl') -# def load_obj(name): -# # with open('./outputF/dict/' + name + '.pkl', 'rb') as f: -# # return pickle.load(f) -# return torch.load('./outputF/dict/'+ name + '.pkl') - -# def generatePhiMaskDict(data_json, generator): -# phi_dict = {} -# mask_dict = {} -# idx_info_dict = {} -# with open(data_json, encoding='utf-8') as file: -# dataset_desc = json.loads(file.read()) -# for i in range(len(dataset_desc["focaldepth"])): -# # if i == 2: -# # break -# idx = dataset_desc["idx"][i] -# focaldepth = dataset_desc["focaldepth"][i] -# gazeX = dataset_desc["gazeX"][i] -# gazeY = dataset_desc["gazeY"][i] -# print("focaldepth:",focaldepth," idx:",idx," gazeX:",gazeX," gazeY:",gazeY) -# phi,mask = generator.CalculateRetinal2LayerMappings(focaldepth,torch.tensor([gazeX, gazeY])) -# phi_dict[idx]=phi -# mask_dict[idx]=mask -# idx_info_dict[idx]=[idx,focaldepth,gazeX,gazeY] -# return phi_dict,mask_dict,idx_info_dict - -# def generatePhiMaskDictNew(data_json, generator): -# phi_dict = {} -# mask_dict = {} -# idx_info_dict = {} -# with open(data_json, encoding='utf-8') as file: -# dataset_desc = json.loads(file.read()) -# for i in range(len(dataset_desc["seq"])): -# for j in dataset_desc["seq"][i]: -# idx = dataset_desc["idx"][j] -# focaldepth = dataset_desc["focaldepth"][j] -# gazeX = dataset_desc["gazeX"][j] -# gazeY = dataset_desc["gazeY"][j] -# print("focaldepth:",focaldepth," idx:",idx," gazeX:",gazeX," gazeY:",gazeY) -# phi,mask = generator.CalculateRetinal2LayerMappings(focaldepth,torch.tensor([gazeX, gazeY])) -# phi_dict[idx]=phi -# mask_dict[idx]=mask -# idx_info_dict[idx]=[idx,focaldepth,gazeX,gazeY] -# return phi_dict,mask_dict,idx_info_dict - -mode = "Silence" #"Perf" -model_type = "RNN" #"RNN" -w_frame = 0.9 -w_inter_frame = 0.1 -batch_model = "NoSingle" -loss1 = ReconstructionLoss() -loss2 = ReconstructionLoss() -if __name__ == "__main__": - ############################## generate phi and mask in pre-training - - # print("generating phi and mask...") - # phi_dict,mask_dict,idx_info_dict = generatePhiMaskDictNew(DATA_JSON,gen) - # # save_obj(phi_dict,"phi_1204") - # # save_obj(mask_dict,"mask_1204") - # # save_obj(idx_info_dict,"idx_info_1204") - # print("generating phi and mask end.") - # exit(0) - ############################# load phi and mask in pre-training - # print("loading phi and mask ...") - # phi_dict = load_obj("phi_1204") - # mask_dict = load_obj("mask_1204") - # idx_info_dict = load_obj("idx_info_1204") - # print(len(phi_dict)) - # print(len(mask_dict)) - # print("loading phi and mask end") - - #### Image Gen and conf - conf = Conf() - gen = RetinalGen(conf) - - #train - train_data_loader = torch.utils.data.DataLoader(dataset=lightFieldFlowSeqDataLoader(DATA_FILE,DATA_JSON, gen, conf), - batch_size=BATCH_SIZE, - num_workers=8, - pin_memory=True, - shuffle=True, - drop_last=False) - #Data loader test - print(len(train_data_loader)) - - # # lightfield_images, gt, flow, fd, gazeX, gazeY, posX, posY, sample_idx, phi, mask - # # lightfield_images, gt, phi, phi_invalid, retinal_invalid, flow, fd, gazeX, gazeY, posX, posY, posZ, sample_idx - # for batch_idx, (image_set, gt,phi, phi_invalid, retinal_invalid, flow, df, gazeX, gazeY, posX, posY, posZ, sample_idx) in enumerate(train_data_loader): - # print(image_set.shape,type(image_set)) - # print(gt.shape,type(gt)) - # print(phi.shape,type(phi)) - # print(phi_invalid.shape,type(phi_invalid)) - # print(retinal_invalid.shape,type(retinal_invalid)) - # print(flow.shape,type(flow)) - # print(df.shape,type(df)) - # print(gazeX.shape,type(gazeX)) - # print(posX.shape,type(posX)) - # print(sample_idx.shape,type(sample_idx)) - # print("test train dataloader.") - # exit(0) - #Data loader test end - - - - ################################################ val ######################################################### - # val_data_loader = torch.utils.data.DataLoader(dataset=lightFieldValDataLoader(DATA_FILE,DATA_VAL_JSON), - # batch_size=1, - # num_workers=0, - # pin_memory=True, - # shuffle=False, - # drop_last=False) - - # print(len(val_data_loader)) - - # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - # lf_model = baseline.model() - # if torch.cuda.is_available(): - # lf_model = torch.nn.DataParallel(lf_model).cuda() - - # checkpoint = torch.load(os.path.join(OUTPUT_DIR,"gaze-ckpt-epoch-0201.pth")) - # lf_model.load_state_dict(checkpoint["model_state_dict"]) - # lf_model.eval() - - # print("Eval::") - # for sample_idx, (image_set, df, gazeX, gazeY, sample_idx) in enumerate(val_data_loader): - # print("sample_idx::",sample_idx) - # with torch.no_grad(): - - # #reshape for input - # image_set = image_set.permute(0,1,4,2,3) # N LF C H W - # image_set = image_set.reshape(image_set.shape[0],-1,image_set.shape[3],image_set.shape[4]) # N, LFxC, H, W - # image_set = var_or_cuda(image_set) - - # # print("Epoch:",epoch,",Iter:",batch_idx,",Input shape:",image_set.shape, ",Input gt:",gt.shape) - # output = lf_model(image_set,df,gazeX,gazeY) - # output1,mask = GenRetinalGazeFromLayersBatch(output, gen, sample_idx, phi_dict, mask_dict) - - # for i in range(0, 2): - # output1[:,i*3:i*3+3].mul_(mask[:,i:i+1]) - # output1[:,i*3:i*3+3].clamp_(0., 1.) - - # print("output:",output.shape," df:",df[0].data, ",gazeX:",gazeX[0].data,",gazeY:", gazeY[0].data) - # for i in range(output1.size()[0]): - # save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_fac1_o_%.5f_%.5f_%.5f.png"%(df[i].data,gazeX[i].data,gazeY[i].data))) - # save_image(output[i][3:6].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_fac2_o_%.5f_%.5f_%.5f.png"%(df[i].data,gazeX[i].data,gazeY[i].data))) - # save_image(output1[i][0:3].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_out1_o_%.5f_%.5f_%.5f.png"%(df[i].data,gazeX[i].data,gazeY[i].data))) - # save_image(output1[i][3:6].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_out2_o_%.5f_%.5f_%.5f.png"%(df[i].data,gazeX[i].data,gazeY[i].data))) - - # # save_image(output[0][0:3].data,os.path.join(OUTPUT_DIR,"gaze_fovea_interp_l1_%.5f.png"%(df[0].data))) - # # save_image(output[0][3:6].data,os.path.join(OUTPUT_DIR,"gaze_fovea_interp_l2_%.5f.png"%(df[0].data))) - # # output = GenRetinalFromLayersBatch(output,conf,df,v,u) - # # save_image(output[0][0:3].data,os.path.join(OUTPUT_DIR,"1113_interp_o%.5f.png"%(df[0].data))) - # exit() - - ################################################ train ######################################################### - if model_type == "RNN": - lf_model = model(FIRSST_LAYER_CHANNELS,LAST_LAYER_CHANNELS,OUT_CHANNELS_RB,KERNEL_SIZE,KERNEL_SIZE_RB,INTERLEAVE_RATE) - else: - lf_model = model(FIRSST_LAYER_CHANNELS,LAST_LAYER_CHANNELS,OUT_CHANNELS_RB,KERNEL_SIZE,KERNEL_SIZE_RB,INTERLEAVE_RATE,RNN=False) - lf_model.apply(weight_init_normal) - lf_model.train() - epoch_begin = 0 - - ################################ load model file - # WEIGHTS = os.path.join(OUTPUT_DIR, 'gaze-ckpt-epoch-%04d.pth' % (101)) - # print('[INFO] Recovering from %s ...' % (WEIGHTS)) - # checkpoint = torch.load(WEIGHTS) - # init_epoch = checkpoint['epoch_idx'] - # lf_model.load_state_dict(checkpoint['model_state_dict']) - # epoch_begin = init_epoch + 1 - # print(lf_model) - ############################################################ - - if torch.cuda.is_available(): - # lf_model = torch.nn.DataParallel(lf_model).cuda() - lf_model = lf_model.to('cuda:2') - - optimizer = torch.optim.Adam(lf_model.parameters(),lr=5e-3,betas=(0.9,0.999)) - - # lf_model.output_layer.register_backward_hook(hook_fn_back) - if mode=="Perf": - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - print("begin training....") - for epoch in range(epoch_begin, NUM_EPOCH): - for batch_idx, (image_set, gt,phi, phi_invalid, retinal_invalid, flow, flow_invalid_mask, df, gazeX, gazeY, posX, posY, posZ, sample_idx) in enumerate(train_data_loader): - # print(sample_idx.shape,df.shape,gazeX.shape,gazeY.shape) # torch.Size([2, 5]) - # print(image_set.shape,gt.shape,gt2.shape) #torch.Size([2, 5, 9, 320, 320, 3]) torch.Size([2, 5, 160, 160, 3]) torch.Size([2, 5, 160, 160, 3]) - # print(delta.shape) # delta: torch.Size([2, 4, 160, 160, 3]) - if mode=="Perf": - end.record() - torch.cuda.synchronize() - print("load:",start.elapsed_time(end)) - - start.record() - #reshape for input - image_set = image_set.permute(0,1,2,5,3,4) # N Seq 5 LF 25 C 3 H W - image_set = image_set.reshape(image_set.shape[0],image_set.shape[1],-1,image_set.shape[4],image_set.shape[5]) # N, LFxC, H, W - # N, Seq 5, LF 25 C 3, H, W - image_set = var_or_cuda(image_set) - # print(image_set.shape) #torch.Size([2, 5, 75, 320, 320]) - - gt = gt.permute(0,1,4,2,3) # BS Seq 5 C 3 H W - gt = var_or_cuda(gt) - - flow = var_or_cuda(flow) #BS,Seq-1,H,W,2 - - # gt2 = gt2.permute(0,1,4,2,3) - # gt2 = var_or_cuda(gt2) - - gen1 = torch.empty(gt.shape) # BS Seq C H W - gen1 = var_or_cuda(gen1) - # print(gen1.shape) #torch.Size([2, 5, 3, 320, 320]) - - # gen2 = torch.empty(gt2.shape) - # gen2 = var_or_cuda(gen2) - - #BS, Seq - 1, C, H, W - warped = torch.empty(gt.shape[0],gt.shape[1]-1,gt.shape[2],gt.shape[3],gt.shape[4]) - warped = var_or_cuda(warped) - gen_temp = torch.empty(warped.shape) - gen_temp = var_or_cuda(gen_temp) - # print("warped:",warped.shape) #warped: torch.Size([2, 4, 3, 320, 320]) - if mode=="Perf": - end.record() - torch.cuda.synchronize() - print("data prepare:",start.elapsed_time(end)) - - start.record() - if model_type == "RNN": - if batch_model != "Single": - for k in range(image_set.shape[1]): - if k == 0: - lf_model.reset_hidden(image_set[:,k]) - output = lf_model(image_set[:,k],df[:,k],gazeX[:,k],gazeY[:,k],posX[:,k],posY[:,k],posZ[:,k]) # batchsize, layer_num x 2 = 6, layer_res: 320, layer_res: 320 - output1 = MergeBatchSpeed(output, gen, phi[:,k], phi_invalid[:,k], retinal_invalid[:,k]) - gen1[:,k] = output1[:,0:3] - gt[:,k] = gt[:,k].mul_(~retinal_invalid[:,k].to("cuda:2")) - - if ((epoch%10 == 0 and epoch != 0) or epoch == 2): - for i in range(output.shape[0]): - save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"gaze_fac1_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data,posX[i][k].data,posY[i][k].data))) - save_image(output[i][3:6].data,os.path.join(OUTPUT_DIR,"gaze_fac2_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data,posX[i][k].data,posY[i][k].data))) - - for i in range(1,gt.shape[1]): - # print(flow_invalid_mask.shape) #torch.Size([2, 4, 320, 320]) - # print(FlowMap(gen1[:,i-1],flow[:,i-1]).shape) #torch.Size([2, 3, 320, 320]) - warped[:,i-1] = FlowMap(gen1[:,i-1],flow[:,i-1]).mul(~flow_invalid_mask[:,i-1].unsqueeze(1).to("cuda:2")) - gen_temp[:,i-1] = gen1[:,i].mul(~flow_invalid_mask[:,i-1].unsqueeze(1).to("cuda:2")) - else: - for k in range(image_set.shape[1]): - if k == 0: - lf_model.reset_hidden(image_set[:,k]) - output = lf_model(image_set[:,k],df[:,k],gazeX[:,k],gazeY[:,k],posX[:,k],posY[:,k],posZ[:,k]) # batchsize, layer_num x 2 = 6, layer_res: 320, layer_res: 320 - output1 = MergeBatchSpeed(output, gen, phi[:,k], phi_invalid[:,k], retinal_invalid[:,k]) - gen1[:,k] = output1[:,0:3] - gt[:,k] = gt[:,k].mul_(~retinal_invalid[:,k].to("cuda:2")) - - if k != image_set.shape[1]-1: - warped[:,k] = FlowMap(output1.detach(),flow[:,k]).mul(~flow_invalid_mask[:,k].unsqueeze(1).to("cuda:2")) - loss1 = loss_without_perc(output1,gt[:,k]) - loss = (w_frame * loss1) - if k==0: - loss.backward(retain_graph=False) - optimizer.step() - lf_model.zero_grad() - optimizer.zero_grad() - print("Epoch:",epoch,",Iter:",batch_idx,",Seq:",k,",loss:",loss.item()) - else: - output1mask = output1.mul(~flow_invalid_mask[:,k-1].unsqueeze(1).to("cuda:2")) - loss2 = l1loss(output1mask,warped[:,k-1]) - loss += (w_inter_frame * loss2) - loss.backward(retain_graph=False) - optimizer.step() - lf_model.zero_grad() - optimizer.zero_grad() - # print("Epoch:",epoch,",Iter:",batch_idx,",Seq:",k,",loss:",loss.item()) - print("Epoch:",epoch,",Iter:",batch_idx,",Seq:",k,",frame loss:",loss1.item(),",inter loss:",w_inter_frame * loss2.item()) - - - if ((epoch%10 == 0 and epoch != 0) or epoch == 2): - for i in range(output.shape[0]): - save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"gaze_fac1_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data,posX[i][k].data,posY[i][k].data))) - save_image(output[i][3:6].data,os.path.join(OUTPUT_DIR,"gaze_fac2_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data,posX[i][k].data,posY[i][k].data))) - - # BSxSeq C H W - gen1 = gen1.reshape(-1,gen1.shape[2],gen1.shape[3],gen1.shape[4]) - gt = gt.reshape(-1,gt.shape[2],gt.shape[3],gt.shape[4]) - - if ((epoch%10== 0 and epoch != 0) or epoch == 2): # torch.Size([2, 5, 160, 160, 3]) - for i in range(gt.size()[0]): - save_image(gen1[i].data,os.path.join(OUTPUT_DIR,"gaze_out1_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data,posX[i//5][i%5].data,posY[i//5][i%5].data))) - save_image(gt[i].data,os.path.join(OUTPUT_DIR,"gaze_test1_gt0_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data,posX[i//5][i%5].data,posY[i//5][i%5].data))) - if ((epoch%100 == 0) and epoch != 0 and batch_idx==len(train_data_loader)-1): - save_checkpoints(os.path.join(OUTPUT_DIR, 'gaze-ckpt-epoch-%04d.pth' % (epoch)),epoch,lf_model,optimizer) - else: - if batch_model != "Single": - for k in range(image_set.shape[1]): - output = lf_model(image_set[:,k],df[:,k],gazeX[:,k],gazeY[:,k],posX[:,k],posY[:,k],posZ[:,k]) # batchsize, layer_num x 2 = 6, layer_res: 320, layer_res: 320 - output1 = MergeBatchSpeed(output, gen, phi[:,k], phi_invalid[:,k], retinal_invalid[:,k]) - gen1[:,k] = output1[:,0:3] - gt[:,k] = gt[:,k].mul_(~retinal_invalid[:,k].to("cuda:2")) - - if ((epoch%10 == 0 and epoch != 0) or epoch == 2): - for i in range(output.shape[0]): - save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"gaze_fac1_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data,posX[i][k].data,posY[i][k].data))) - save_image(output[i][3:6].data,os.path.join(OUTPUT_DIR,"gaze_fac2_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data,posX[i][k].data,posY[i][k].data))) - for i in range(1,gt.shape[1]): - # print(flow_invalid_mask.shape) #torch.Size([2, 4, 320, 320]) - # print(FlowMap(gen1[:,i-1],flow[:,i-1]).shape) #torch.Size([2, 3, 320, 320]) - warped[:,i-1] = FlowMap(gen1[:,i-1],flow[:,i-1]).mul(~flow_invalid_mask[:,i-1].unsqueeze(1).to("cuda:2")) - gen_temp[:,i-1] = gen1[:,i].mul(~flow_invalid_mask[:,i-1].unsqueeze(1).to("cuda:2")) - else: - for k in range(image_set.shape[1]): - output = lf_model(image_set[:,k],df[:,k],gazeX[:,k],gazeY[:,k],posX[:,k],posY[:,k],posZ[:,k]) # batchsize, layer_num x 2 = 6, layer_res: 320, layer_res: 320 - output1 = MergeBatchSpeed(output, gen, phi[:,k], phi_invalid[:,k], retinal_invalid[:,k]) - gen1[:,k] = output1[:,0:3] - gt[:,k] = gt[:,k].mul_(~retinal_invalid[:,k].to("cuda:2")) - - # print(output1.shape) #torch.Size([BS, 3, 320, 320]) - loss1 = loss_without_perc(output1,gt[:,k]) - loss = (w_frame * loss1) - # print("loss:",loss1.item()) - loss.backward(retain_graph=False) - optimizer.step() - lf_model.zero_grad() - optimizer.zero_grad() - - print("Epoch:",epoch,",Iter:",batch_idx,",Seq:",k,",loss:",loss.item()) - if ((epoch%10 == 0 and epoch != 0) or epoch == 2): - for i in range(output.shape[0]): - save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"gaze_fac1_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data,posX[i][k].data,posY[i][k].data))) - save_image(output[i][3:6].data,os.path.join(OUTPUT_DIR,"gaze_fac2_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data,posX[i][k].data,posY[i][k].data))) - for i in range(1,gt.shape[1]): - warped[:,i-1] = FlowMap(gen1[:,i-1],flow[:,i-1]).mul(~flow_invalid_mask[:,i-1].unsqueeze(1).to("cuda:2")) - gen_temp[:,i-1] = gen1[:,i].mul(~flow_invalid_mask[:,i-1].unsqueeze(1).to("cuda:2")) - - warped = warped.reshape(-1,warped.shape[2],warped.shape[3],warped.shape[4]) - gen_temp = gen_temp.reshape(-1,gen_temp.shape[2],gen_temp.shape[3],gen_temp.shape[4]) - loss3 = l1loss(warped,gen_temp) - print("Epoch:",epoch,",Iter:",batch_idx,",inter-frame loss:",w_inter_frame *loss3.item()) - - # BSxSeq C H W - gen1 = gen1.reshape(-1,gen1.shape[2],gen1.shape[3],gen1.shape[4]) - gt = gt.reshape(-1,gt.shape[2],gt.shape[3],gt.shape[4]) - - if ((epoch%10== 0 and epoch != 0) or epoch == 2): # torch.Size([2, 5, 160, 160, 3]) - for i in range(gt.size()[0]): - save_image(gen1[i].data,os.path.join(OUTPUT_DIR,"gaze_out1_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data,posX[i//5][i%5].data,posY[i//5][i%5].data))) - save_image(gt[i].data,os.path.join(OUTPUT_DIR,"gaze_test1_gt0_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data,posX[i//5][i%5].data,posY[i//5][i%5].data))) - if ((epoch%100 == 0) and epoch != 0 and batch_idx==len(train_data_loader)-1): - save_checkpoints(os.path.join(OUTPUT_DIR, 'gaze-ckpt-epoch-%04d.pth' % (epoch)),epoch,lf_model,optimizer) - - - if batch_model != "Single": - if mode=="Perf": - end.record() - torch.cuda.synchronize() - print("forward:",start.elapsed_time(end)) - - start.record() - optimizer.zero_grad() - - - # BSxSeq C H W - gen1 = gen1.reshape(-1,gen1.shape[2],gen1.shape[3],gen1.shape[4]) - # gen2 = gen2.reshape(-1,gen2.shape[2],gen2.shape[3],gen2.shape[4]) - # BSxSeq C H W - gt = gt.reshape(-1,gt.shape[2],gt.shape[3],gt.shape[4]) - # gt2 = gt2.reshape(-1,gt2.shape[2],gt2.shape[3],gt2.shape[4]) - - # BSx(Seq-1) C H W - warped = warped.reshape(-1,warped.shape[2],warped.shape[3],warped.shape[4]) - gen_temp = gen_temp.reshape(-1,gen_temp.shape[2],gen_temp.shape[3],gen_temp.shape[4]) - - loss1_value = loss1(gen1,gt) - loss2_value = loss2(warped,gen_temp) - if model_type == "RNN": - loss = (w_frame * loss1_value)+ (w_inter_frame * loss2_value) - else: - loss = (w_frame * loss1_value) - - if mode=="Perf": - end.record() - torch.cuda.synchronize() - print("compute loss:",start.elapsed_time(end)) - - start.record() - loss.backward() - if mode=="Perf": - end.record() - torch.cuda.synchronize() - print("backward:",start.elapsed_time(end)) - - start.record() - optimizer.step() - if mode=="Perf": - end.record() - torch.cuda.synchronize() - print("update:",start.elapsed_time(end)) - - print("Epoch:",epoch,",Iter:",batch_idx,",loss:",loss.item(),",frame loss:",loss1_value.item(),",inter-frame loss:",loss2_value.item()) - - # exit(0) - ########################### Save ##################### - if ((epoch%10== 0 and epoch != 0) or epoch == 2): # torch.Size([2, 5, 160, 160, 3]) - for i in range(gt.size()[0]): - # df 2,5 - save_image(gen1[i].data,os.path.join(OUTPUT_DIR,"gaze_out1_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data,posX[i//5][i%5].data,posY[i//5][i%5].data))) - # save_image(gen2[i].data,os.path.join(OUTPUT_DIR,"gaze_out2_o_%.5f_%.5f_%.5f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data))) - save_image(gt[i].data,os.path.join(OUTPUT_DIR,"gaze_test1_gt0_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data,posX[i//5][i%5].data,posY[i//5][i%5].data))) - # save_image(gt2[i].data,os.path.join(OUTPUT_DIR,"gaze_test1_gt1_%.5f_%.5f_%.5f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data))) - if ((epoch%100 == 0) and epoch != 0 and batch_idx==len(train_data_loader)-1): - save_checkpoints(os.path.join(OUTPUT_DIR, 'gaze-ckpt-epoch-%04d.pth' % (epoch)),epoch,lf_model,optimizer) \ No newline at end of file diff --git a/main_lf_syn.py b/main_lf_syn.py deleted file mode 100644 index 6a9e613..0000000 --- a/main_lf_syn.py +++ /dev/null @@ -1,147 +0,0 @@ -import torch -import argparse -import os -import glob -import numpy as np -import torchvision.transforms as transforms -from torchvision.utils import save_image - -from torchvision import datasets -from torch.utils.data import DataLoader -from torch.autograd import Variable - -import cv2 -from loss import * -import json - -from baseline import * -from data import * - -import torch.autograd.profiler as profiler -# param -BATCH_SIZE = 2 -NUM_EPOCH = 1001 -INTERLEAVE_RATE = 2 -IM_H = 540 -IM_W = 376 -Retinal_IM_H = 540 -Retinal_IM_W = 376 -N = 4 # number of input light field stack -M = 1 # number of display layers -DATA_FILE = "/home/yejiannan/Project/deeplightfield/data/lf_syn" -DATA_JSON = "/home/yejiannan/Project/deeplightfield/data/data_lf_syn_full.json" -# DATA_VAL_JSON = "/home/yejiannan/Project/LightField/data/data_gaze_fovea_val.json" -OUTPUT_DIR = "/home/yejiannan/Project/deeplightfield/outputE/lf_syn_full1219" -OUT_CHANNELS_RB = 128 -KERNEL_SIZE_RB = 3 -KERNEL_SIZE = 3 -LAST_LAYER_CHANNELS = 3 * INTERLEAVE_RATE**2 -FIRSST_LAYER_CHANNELS = 12 * INTERLEAVE_RATE**2 - -from weight_init import weight_init_normal - -def save_checkpoints(file_path, epoch_idx, model, model_solver): - print('[INFO] Saving checkpoint to %s ...' % ( file_path)) - checkpoint = { - 'epoch_idx': epoch_idx, - 'model_state_dict': model.state_dict(), - 'model_solver_state_dict': model_solver.state_dict() - } - torch.save(checkpoint, file_path) - -mode = "Silence" #"Perf" -w_frame = 1.0 -loss1 = PerceptionReconstructionLoss() -if __name__ == "__main__": - #train - train_data_loader = torch.utils.data.DataLoader(dataset=lightFieldSynDataLoader(DATA_FILE,DATA_JSON), - batch_size=BATCH_SIZE, - num_workers=8, - pin_memory=True, - shuffle=True, - drop_last=False) - #Data loader test - print(len(train_data_loader)) - - lf_model = model(FIRSST_LAYER_CHANNELS,LAST_LAYER_CHANNELS,OUT_CHANNELS_RB,KERNEL_SIZE,KERNEL_SIZE_RB,INTERLEAVE_RATE,RNN=False) - lf_model.apply(weight_init_normal) - lf_model.train() - epoch_begin = 0 - - if torch.cuda.is_available(): - # lf_model = torch.nn.DataParallel(lf_model).cuda() - lf_model = lf_model.to('cuda:1') - - optimizer = torch.optim.Adam(lf_model.parameters(),lr=5e-3,betas=(0.9,0.999)) - - # lf_model.output_layer.register_backward_hook(hook_fn_back) - if mode=="Perf": - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - print("begin training....") - for epoch in range(epoch_begin, NUM_EPOCH): - for batch_idx, (image_set, gt, pos_row, pos_col) in enumerate(train_data_loader): - if mode=="Perf": - end.record() - torch.cuda.synchronize() - print("load:",start.elapsed_time(end)) - - start.record() - #reshape for input - image_set = image_set.permute(0,1,4,2,3) # N LF C H W - image_set = image_set.reshape(image_set.shape[0],-1,image_set.shape[3],image_set.shape[4]) # N LFxC H W - image_set = var_or_cuda(image_set) - - gt = gt.permute(0,3,1,2) # BS C H W - gt = var_or_cuda(gt) - - if mode=="Perf": - end.record() - torch.cuda.synchronize() - print("data prepare:",start.elapsed_time(end)) - - start.record() - - output = lf_model(image_set,pos_row, pos_col) # 2 6 376 540 - - if mode=="Perf": - end.record() - torch.cuda.synchronize() - print("forward:",start.elapsed_time(end)) - - start.record() - optimizer.zero_grad() - # print("output:",output.shape," gt:",gt.shape) - loss1_value = loss1(output,gt) - loss = (w_frame * loss1_value) - - if mode=="Perf": - end.record() - torch.cuda.synchronize() - print("compute loss:",start.elapsed_time(end)) - - start.record() - loss.backward() - if mode=="Perf": - end.record() - torch.cuda.synchronize() - print("backward:",start.elapsed_time(end)) - - start.record() - optimizer.step() - if mode=="Perf": - end.record() - torch.cuda.synchronize() - print("update:",start.elapsed_time(end)) - - print("Epoch:",epoch,",Iter:",batch_idx,",loss:",loss.item()) - - # exit(0) - ########################### Save ##################### - if ((epoch%10== 0 and epoch != 0) or epoch == 2): # torch.Size([2, 5, 160, 160, 3]) - for i in range(gt.size()[0]): - save_image(output[i].data,os.path.join(OUTPUT_DIR,"out_%.5f_%.5f.png"%(pos_col[i].data,pos_row[i].data))) - save_image(gt[i].data,os.path.join(OUTPUT_DIR,"gt_%.5f_%.5f.png"%(pos_col[i].data,pos_row[i].data))) - if ((epoch%100 == 0) and epoch != 0 and batch_idx==len(train_data_loader)-1): - save_checkpoints(os.path.join(OUTPUT_DIR, 'ckpt-epoch-%04d.pth' % (epoch)),epoch,lf_model,optimizer) \ No newline at end of file diff --git a/my/device.py b/my/device.py new file mode 100644 index 0000000..c537a22 --- /dev/null +++ b/my/device.py @@ -0,0 +1,7 @@ +import torch + + +def GetDevice(): + if torch.cuda.is_available(): + return torch.device('cuda') + return torch.device('cpu') \ No newline at end of file diff --git a/run_lf_syn.py b/run_lf_syn.py index fc49bd2..b221d86 100644 --- a/run_lf_syn.py +++ b/run_lf_syn.py @@ -10,21 +10,24 @@ from tensorboardX import SummaryWriter from .loss.loss import PerceptionReconstructionLoss from .my import netio from .my import util +from .my import device from .my.simple_perf import SimplePerf from .data.lf_syn import LightFieldSynDataset from .trans_unet import TransUnet -device = torch.device("cuda:2") +torch.cuda.set_device(2) +print("Set CUDA:%d as current device." % torch.cuda.current_device()) + DATA_DIR = os.path.dirname(__file__) + '/data/lf_syn_2020.12.23' TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json' -OUTPUT_DIR = DATA_DIR + '/output_low_lr' -RUN_DIR = DATA_DIR + '/run_low_lr' -BATCH_SIZE = 1 +OUTPUT_DIR = DATA_DIR + '/output_bat2' +RUN_DIR = DATA_DIR + '/run_bat2' +BATCH_SIZE = 8 TEST_BATCH_SIZE = 10 NUM_EPOCH = 1000 MODE = "Silence" # "Perf" -EPOCH_BEGIN = 500 +EPOCH_BEGIN = 0 def train(): @@ -44,7 +47,7 @@ def train(): view_images=train_dataset.sparse_view_images, view_depths=train_dataset.sparse_view_depths, view_positions=train_dataset.sparse_view_positions, - diopter_of_layers=train_dataset.diopter_of_layers).to(device) + diopter_of_layers=train_dataset.diopter_of_layers).to(device.GetDevice()) optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) loss = PerceptionReconstructionLoss() @@ -66,7 +69,7 @@ def train(): for epoch in range(EPOCH_BEGIN, NUM_EPOCH): for _, view_images, _, view_positions in train_data_loader: - view_images = view_images.to(device) + view_images = view_images.to(device.GetDevice()) perf.Checkpoint("Load") @@ -106,7 +109,6 @@ def train(): solver=optimizer) print("Train finished") - netio.SaveNet('%s/model-epoch_%d.pth' % (RUN_DIR, epoch + 1), model) def test(net_file: str): @@ -125,7 +127,7 @@ def test(net_file: str): view_images=train_dataset.sparse_view_images, view_depths=train_dataset.sparse_view_depths, view_positions=train_dataset.sparse_view_positions, - diopter_of_layers=train_dataset.diopter_of_layers).to(device) + diopter_of_layers=train_dataset.diopter_of_layers).to(device.GetDevice()) netio.LoadNet(net_file, model) # 3. Test on train dataset @@ -142,5 +144,5 @@ def test(net_file: str): if __name__ == "__main__": - #train() - test(RUN_DIR + '/model-epoch_1000.pth') + train() + #test(RUN_DIR + '/model-epoch_1000.pth') diff --git a/trans_unet.py b/trans_unet.py index 45211a8..db9c72d 100644 --- a/trans_unet.py +++ b/trans_unet.py @@ -3,8 +3,7 @@ import torch import torch.nn as nn from .pytorch_prototyping.pytorch_prototyping import * from .my import util - -device = torch.device("cuda:2") +from .my import device class Encoder(nn.Module): @@ -66,15 +65,15 @@ class LatentSpaceTransformer(nn.Module): self.n_views = view_positions.size()[0] self.diopter_of_layers = diopter_of_layers self.feat_coords = util.MeshGrid( - (feat_dim, feat_dim)).to(device=device) + (feat_dim, feat_dim)).to(device.GetDevice()) def forward(self, feats: torch.Tensor, feat_depths: torch.Tensor, novel_views: torch.Tensor) -> torch.Tensor: - trans_feats = torch.zeros(novel_views.size()[0], feats.size()[0], - feats.size()[1], feats.size()[ - 2], feats.size()[3], - device=device) + trans_feats = torch.zeros(novel_views.size()[0], + feats.size()[0], feats.size()[1], + feats.size()[2], feats.size()[3], + device=device.GetDevice()) for i in range(novel_views.size()[0]): for v in range(self.n_views): for l in range(len(self.diopter_of_layers)): @@ -151,8 +150,8 @@ class TransUnet(nn.Module): latent_sidelength = 64 # The dimensions of the latent space image_sidelength = view_images.size()[2] - self.view_images = view_images.to(device) - self.view_depths = view_depths.to(device) + self.view_images = view_images.to(device.GetDevice()) + self.view_depths = view_depths.to(device.GetDevice()) self.n_views = view_images.size()[0] self.encoder = Encoder(nf0=nf0, out_channels=nf, diff --git a/weight_init.py b/weight_init.py deleted file mode 100644 index 9c0a20a..0000000 --- a/weight_init.py +++ /dev/null @@ -1,12 +0,0 @@ -import torch -weightVarScale = 0.25 -bias_stddev = 0.01 - -def weight_init_normal(m): - classname = m.__class__.__name__ - if classname.find("Conv") != -1: - torch.nn.init.xavier_normal_(m.weight.data) - torch.nn.init.normal_(m.bias.data,mean = 0.0, std=bias_stddev) - elif classname.find("BatchNorm2d") != -1: - torch.nn.init.normal_(m.weight.data, 1.0, 0.02) - torch.nn.init.constant_(m.bias.data, 0.0) \ No newline at end of file -- GitLab