From a22cfc90006c0bb8fa7c165273ec461d804169da Mon Sep 17 00:00:00 2001
From: BobYeah <635596704@qq.com>
Date: Sat, 19 Dec 2020 15:32:04 +0800
Subject: [PATCH] Update CNN scripts

---
 Flow.py                          |  83 ++++
 model/baseline.py => baseline.py |  54 +-
 conf.py                          |  55 +-
 data.py                          | 305 ++++++++++++
 gen_image.py                     | 287 +++++++++--
 loss.py                          |  80 +++
 main.py                          | 826 +++++++++++++++----------------
 main_lf_syn.py                   |   6 +-
 model/__init__.py                |   0
 model/recurrent.py               | 146 ------
 util.py                          | 104 ++++
 weight_init.py                   |  12 +
 12 files changed, 1283 insertions(+), 675 deletions(-)
 create mode 100644 Flow.py
 rename model/baseline.py => baseline.py (75%)
 create mode 100644 loss.py
 delete mode 100644 model/__init__.py
 delete mode 100644 model/recurrent.py
 create mode 100644 util.py
 create mode 100644 weight_init.py

diff --git a/Flow.py b/Flow.py
new file mode 100644
index 0000000..83a6b6a
--- /dev/null
+++ b/Flow.py
@@ -0,0 +1,83 @@
+import matplotlib.pyplot as plt
+import torch
+import util
+import numpy as np
+def FlowMap(b_last_frame, b_map):
+    '''
+    Map images using the flow data.
+    
+    Parameters
+    --------
+    b_last_frame - B x 3 x H x W tensor, batch of images
+    b_map - B x H x W x 2, batch of map data records pixel coords in last frames
+    
+    Returns
+    --------
+    B x 3 x H x W tensor, batch of images mapped by flow data
+    '''
+    return torch.nn.functional.grid_sample(b_last_frame, b_map, align_corners=False)
+    
+class Flow(object):
+    '''
+    Class representating optical flow
+    
+    Properties
+    --------
+    b_data         - B x H x W x 2, batch of flow data
+    b_map          - B x H x W x 2, batch of map data records pixel coords in last frames
+    b_invalid_mask - B x H x W, batch of masks, indicate invalid elements in corresponding flow data
+    '''
+    def Load(paths):
+        '''
+        Create a Flow instance using a batch of encoded data images loaded from paths
+        
+        Parameters
+        --------
+        paths - list of encoded data image paths
+        
+        Returns
+        --------
+        Flow instance
+        '''
+        b_encoded_image = util.ReadImageTensor(paths, rgb_only=False, permute=False, batch_dim=True)
+        return Flow(b_encoded_image)
+
+    def __init__(self, b_encoded_image):
+        '''
+        Initialize a Flow instance from a batch of encoded data images
+        
+        Parameters
+        --------
+        b_encoded_image - batch of encoded data images
+        '''
+        b_encoded_image = b_encoded_image.mul(255)
+        # print("b_encoded_image:",b_encoded_image.shape)
+        self.b_invalid_mask = (b_encoded_image[:, :, :, 0] == 255)
+        self.b_data = (b_encoded_image[:, :, :, 0:2] / 254 + b_encoded_image[:, :, :, 2:4] - 127) / 127
+        self.b_data[:, :, :, 1] = -self.b_data[:, :, :, 1]
+        D = self.b_data.size()
+        grid = util.MeshGrid((D[1], D[2]), True)
+        self.b_map = (grid - self.b_data - 0.5) * 2
+        self.b_map[self.b_invalid_mask] = torch.tensor([ -2.0, -2.0 ])
+    
+    def getMap(self):
+        return self.b_map
+
+    def Visualize(self, scale_factor = 1):
+        '''
+        Visualize the flow data by "color wheel".
+        
+        Parameters
+        --------
+        scale_factor - scale factor of flow data to visualize, default is 1
+        
+        Returns
+        --------
+        B x 3 x H x W tensor, visualization of flow data
+        '''
+        try:
+            Flow.b_color_wheel
+        except AttributeError:
+            Flow.b_color_wheel = util.ReadImageTensor('color_wheel.png')
+        return torch.nn.functional.grid_sample(Flow.b_color_wheel.expand(self.b_data.size()[0], -1, -1, -1),
+            (self.b_data * scale_factor), align_corners=False)
\ No newline at end of file
diff --git a/model/baseline.py b/baseline.py
similarity index 75%
rename from model/baseline.py
rename to baseline.py
index d829464..0ba91ff 100644
--- a/model/baseline.py
+++ b/baseline.py
@@ -38,7 +38,7 @@ class residual_block(torch.nn.Module):
     def forward(self,input):
         if self.RNN:
             # print("input:",input.shape,"hidden:",self.hidden.shape)
-            inp = torch.cat((input,self.hidden),dim=1)
+            inp = torch.cat((input,self.hidden.detach()),dim=1)
             # print(inp.shape)
             output = self.layer1(inp)
             output = self.layer2(output)
@@ -97,7 +97,7 @@ class interleave(torch.nn.Module):
         return output
 
 class model(torch.nn.Module):
-    def __init__(self,FIRSST_LAYER_CHANNELS,LAST_LAYER_CHANNELS,OUT_CHANNELS_RB,KERNEL_SIZE,KERNEL_SIZE_RB,INTERLEAVE_RATE):
+    def __init__(self,FIRSST_LAYER_CHANNELS,LAST_LAYER_CHANNELS,OUT_CHANNELS_RB,KERNEL_SIZE,KERNEL_SIZE_RB,INTERLEAVE_RATE,RNN=False):
         super(model, self).__init__()
         self.interleave = interleave(INTERLEAVE_RATE)
 
@@ -108,13 +108,19 @@ class model(torch.nn.Module):
         )
         
         self.residual_block1 = residual_block(OUT_CHANNELS_RB,0,KERNEL_SIZE_RB,False)
-        self.residual_block2 = residual_block(OUT_CHANNELS_RB,3,KERNEL_SIZE_RB,False)
-        self.residual_block3 = residual_block(OUT_CHANNELS_RB,3,KERNEL_SIZE_RB,True)
-        self.residual_block4 = residual_block(OUT_CHANNELS_RB,3,KERNEL_SIZE_RB,True)
-        self.residual_block5 = residual_block(OUT_CHANNELS_RB,3,KERNEL_SIZE_RB,True)
+        self.residual_block2 = residual_block(OUT_CHANNELS_RB,2,KERNEL_SIZE_RB,False)
+        self.residual_block3 = residual_block(OUT_CHANNELS_RB,2,KERNEL_SIZE_RB,False)
+        # if RNN:
+        #     self.residual_block3 = residual_block(OUT_CHANNELS_RB,6,KERNEL_SIZE_RB,True)
+        #     self.residual_block4 = residual_block(OUT_CHANNELS_RB,6,KERNEL_SIZE_RB,True)
+        #     self.residual_block5 = residual_block(OUT_CHANNELS_RB,6,KERNEL_SIZE_RB,True)
+        # else:
+        #     self.residual_block3 = residual_block(OUT_CHANNELS_RB,6,KERNEL_SIZE_RB,False)
+        #     self.residual_block4 = residual_block(OUT_CHANNELS_RB,6,KERNEL_SIZE_RB,False)
+        #     self.residual_block5 = residual_block(OUT_CHANNELS_RB,6,KERNEL_SIZE_RB,False)
 
         self.output_layer = torch.nn.Sequential(
-            torch.nn.Conv2d(OUT_CHANNELS_RB+3,LAST_LAYER_CHANNELS,KERNEL_SIZE,stride=1,padding=1),
+            torch.nn.Conv2d(OUT_CHANNELS_RB+2,LAST_LAYER_CHANNELS,KERNEL_SIZE,stride=1,padding=1),
             torch.nn.BatchNorm2d(LAST_LAYER_CHANNELS),
             torch.nn.Sigmoid()
         )
@@ -125,7 +131,7 @@ class model(torch.nn.Module):
         self.residual_block4.reset_hidden(inp)
         self.residual_block5.reset_hidden(inp)
 
-    def forward(self, lightfield_images, focal_length, gazeX, gazeY):
+    def forward(self, lightfield_images, pos_row, pos_col):
         # lightfield_images: torch.Size([batch_size, channels * D, H, W]) 
         # channels : RGB*D: 3*9, H:256, W:256
         # print("lightfield_images:",lightfield_images.shape)
@@ -136,32 +142,18 @@ class model(torch.nn.Module):
         # print("input_to_rb1:",input_to_rb.shape)
         output = self.residual_block1(input_to_rb)
 
-        depth_layer = torch.ones((input_to_rb.shape[0],1,input_to_rb.shape[2],input_to_rb.shape[3]))
-        gazeX_layer = torch.ones((input_to_rb.shape[0],1,input_to_rb.shape[2],input_to_rb.shape[3]))
-        gazeY_layer = torch.ones((input_to_rb.shape[0],1,input_to_rb.shape[2],input_to_rb.shape[3]))
-        # print("depth_layer:",depth_layer.shape)
-        # print("focal_depth:",focal_length," gazeX:",gazeX," gazeY:",gazeY, " gazeX norm:",(gazeX[0] - (-3.333)) / (3.333*2))
-        for i in range(focal_length.shape[0]):
-            depth_layer[i] *= 1. / focal_length[i]
-            gazeX_layer[i] *= (gazeX[i] - (-3.333)) / (3.333*2)
-            gazeY_layer[i] *= (gazeY[i] - (-3.333)) / (3.333*2)
+        pos_row_layer = torch.ones((input_to_rb.shape[0],1,input_to_rb.shape[2],input_to_rb.shape[3]))
+        pos_col_layer = torch.ones((input_to_rb.shape[0],1,input_to_rb.shape[2],input_to_rb.shape[3]))
+        for i in range(pos_row.shape[0]):
+            pos_row_layer[i] *= pos_row[i]
+            pos_col_layer[i] *= pos_col[i]
             # print(depth_layer.shape)
-        depth_layer = var_or_cuda(depth_layer)
-        gazeX_layer = var_or_cuda(gazeX_layer)
-        gazeY_layer = var_or_cuda(gazeY_layer)
-
-        output = torch.cat((output,depth_layer,gazeX_layer,gazeY_layer),dim=1)
-        # output = torch.cat((output,depth_layer),dim=1)
-        # print("output to rb2:",output.shape)
+        pos_row_layer = var_or_cuda(pos_row_layer)
+        pos_col_layer = var_or_cuda(pos_col_layer)
 
+        output = torch.cat((output,pos_row_layer,pos_col_layer),dim=1)
         output = self.residual_block2(output)
-        # print("output to rb3:",output.shape)
         output = self.residual_block3(output)
-        # print("output to rb4:",output.shape)
-        output = self.residual_block4(output)
-        # print("output to rb5:",output.shape)
-        output = self.residual_block5(output)
-        # output = output + input_to_net
         output = self.output_layer(output)
         output = self.deinterleave(output)
-        return output
+        return output
\ No newline at end of file
diff --git a/conf.py b/conf.py
index a6aea74..2bb9180 100644
--- a/conf.py
+++ b/conf.py
@@ -1,27 +1,68 @@
 import torch
-from gen_image import *
+import util
+import numpy as np
 class Conf(object):
     def __init__(self):
         self.pupil_size = 0.02
         self.retinal_res = torch.tensor([ 320, 320 ])
         self.layer_res = torch.tensor([ 320, 320 ])
         self.layer_hfov = 90                  # layers' horizontal FOV
-        self.eye_hfov = 85                    # eye's horizontal FOV (ignored in foveated rendering)
-        self.eye_enable_fovea = True          # enable foveated rendering
+        self.eye_hfov = 80                    # eye's horizontal FOV (ignored in foveated rendering)
+        self.eye_enable_fovea = False          # enable foveated rendering
         self.eye_fovea_angles = [ 40, 80 ]    # eye's foveation layers' angles
         self.eye_fovea_downsamples = [ 1, 2 ] # eye's foveation layers' downsamples
         self.d_layer = [ 1, 3 ]               # layers' distance
-        
+        self.eye_fovea_blend = [ self._GenFoveaLayerBlend(0) ]
+                                                # blend maps of fovea layers
+        self.light_field_dim = 5
     def GetNLayers(self):
         return len(self.d_layer)
     
     def GetLayerSize(self, i):
-        w = Fov2Length(self.layer_hfov)
+        w = util.Fov2Length(self.layer_hfov)
         h = w * self.layer_res[0] / self.layer_res[1]
         return torch.tensor([ h, w ]) * self.d_layer[i]
 
+    def GetPixelSizeOfLayer(self, i):
+        '''
+        Get pixel size of layer i
+        '''
+        return util.Fov2Length(self.layer_hfov) * self.d_layer[i] / self.layer_res[0]
+
     def GetEyeViewportSize(self):
         fov = self.eye_fovea_angles[-1] if self.eye_enable_fovea else self.eye_hfov
-        w = Fov2Length(fov)
+        w = util.Fov2Length(fov)
         h = w * self.retinal_res[0] / self.retinal_res[1]
-        return torch.tensor([ h, w ])
\ No newline at end of file
+        return torch.tensor([ h, w ])
+
+    def GetRegionOfFoveaLayer(self, i):
+        '''
+        Get region of fovea layer i in retinal image
+        
+        Returns
+        --------
+        slice object stores the start and end of region
+        '''
+        roi_size = int(np.ceil(self.retinal_res[0] * self.eye_fovea_angles[i] / self.eye_fovea_angles[-1]))
+        roi_offset = int((self.retinal_res[0] - roi_size) / 2)
+        return slice(roi_offset, roi_offset + roi_size)
+    
+    def _GenFoveaLayerBlend(self, i):
+        '''
+        Generate blend map for fovea layer i
+        
+        Parameters
+        --------
+        i - index of fovea layer
+        
+        Returns
+        --------
+        H[i] x W[i], blend map
+        
+        '''
+        region = self.GetRegionOfFoveaLayer(i)
+        width = region.stop - region.start
+        R = width / 2
+        p = util.MeshGrid([ width, width ])
+        r = torch.linalg.norm(p - R, 2, dim=2, keepdim=False)
+        return util.SmoothStep(R, R * 0.6, r)
diff --git a/data.py b/data.py
index e69de29..14b5a2c 100644
--- a/data.py
+++ b/data.py
@@ -0,0 +1,305 @@
+import torch
+import os
+import glob
+import numpy as np
+import torchvision.transforms as transforms
+
+from torchvision import datasets
+from torch.utils.data import DataLoader 
+
+import cv2
+import json
+from Flow import *
+from gen_image import *
+import util
+
+import time
+
+
+class lightFieldSynDataLoader(torch.utils.data.dataset.Dataset):
+    def __init__(self, file_dir_path,file_json):
+        self.file_dir_path = file_dir_path
+        with open(file_json, encoding='utf-8') as file:
+            self.dataset_desc = json.loads(file.read())
+        self.input_img = []
+        for i in self.dataset_desc["train"]:
+            lf_element = os.path.join(self.file_dir_path,i)
+            lf_element = cv2.imread(lf_element, -cv2.IMREAD_ANYDEPTH)[:, :, 0:3]
+            lf_element = cv2.cvtColor(lf_element, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.
+            lf_element = lf_element[:,:-1,:]
+            self.input_img.append(lf_element)
+        self.input_img = np.asarray(self.input_img)
+    def __len__(self):
+        return len(self.dataset_desc["gt"])
+
+    def __getitem__(self, idx):
+        gt, pos_row, pos_col  = self.get_datum(idx)
+        return (self.input_img, gt, pos_row, pos_col)
+
+    def get_datum(self, idx):
+        fd_gt_path = os.path.join(self.file_dir_path, self.dataset_desc["gt"][idx])
+        gt = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
+        gt = cv2.cvtColor(gt,cv2.COLOR_BGR2RGB)
+        gt = gt[:,:-1,:]
+        pos_col = self.dataset_desc["x"][idx]
+        pos_row = self.dataset_desc["y"][idx]
+        return gt, pos_row, pos_col
+
+class lightFieldDataLoader(torch.utils.data.dataset.Dataset):
+    def __init__(self, file_dir_path, file_json, transforms=None):
+        self.file_dir_path = file_dir_path
+        self.transforms = transforms
+        # self.datum_list = glob.glob(os.path.join(file_dir_path,"*"))
+        with open(file_json, encoding='utf-8') as file:
+            self.dataset_desc = json.loads(file.read())
+
+    def __len__(self):
+        return len(self.dataset_desc["focaldepth"])
+
+    def __getitem__(self, idx):
+        lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx = self.get_datum(idx)
+        if self.transforms:
+            lightfield_images = self.transforms(lightfield_images)
+        # print(lightfield_images.shape,gt.shape,fd,gazeX,gazeY,sample_idx)
+        return (lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx)
+
+    def get_datum(self, idx):
+        lf_image_paths = os.path.join(self.file_dir_path, self.dataset_desc["train"][idx])
+        # print(lf_image_paths)
+        fd_gt_path = os.path.join(self.file_dir_path, self.dataset_desc["gt"][idx])
+        fd_gt_path2 = os.path.join(self.file_dir_path, self.dataset_desc["gt2"][idx])
+        # print(fd_gt_path)
+        lf_images = []
+        lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
+        lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB)
+
+        ## IF GrayScale
+        # lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
+        # lf_image_big = np.expand_dims(lf_image_big, axis=-1)
+        # print(lf_image_big.shape)
+
+        for i in range(9):
+            lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:3]
+            ## IF GrayScale
+            # lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1]
+            # print(lf_image.shape)
+            lf_images.append(lf_image)
+        gt = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
+        gt = cv2.cvtColor(gt,cv2.COLOR_BGR2RGB)
+        gt2 = cv2.imread(fd_gt_path2, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
+        gt2 = cv2.cvtColor(gt2,cv2.COLOR_BGR2RGB)
+        ## IF GrayScale
+        # gt = cv2.imread(fd_gt_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
+        # gt = np.expand_dims(gt, axis=-1)
+
+        fd = self.dataset_desc["focaldepth"][idx]
+        gazeX = self.dataset_desc["gazeX"][idx]
+        gazeY = self.dataset_desc["gazeY"][idx]
+        sample_idx = self.dataset_desc["idx"][idx]
+        return np.asarray(lf_images),gt,gt2,fd,gazeX,gazeY,sample_idx
+
+class lightFieldValDataLoader(torch.utils.data.dataset.Dataset):
+    def __init__(self, file_dir_path, file_json, transforms=None):
+        self.file_dir_path = file_dir_path
+        self.transforms = transforms
+        # self.datum_list = glob.glob(os.path.join(file_dir_path,"*"))
+        with open(file_json, encoding='utf-8') as file:
+            self.dataset_desc = json.loads(file.read())
+
+    def __len__(self):
+        return len(self.dataset_desc["focaldepth"])
+
+    def __getitem__(self, idx):
+        lightfield_images, fd, gazeX, gazeY, sample_idx = self.get_datum(idx)
+        if self.transforms:
+            lightfield_images = self.transforms(lightfield_images)
+        # print(lightfield_images.shape,gt.shape,fd,gazeX,gazeY,sample_idx)
+        return (lightfield_images, fd, gazeX, gazeY, sample_idx)
+
+    def get_datum(self, idx):
+        lf_image_paths = os.path.join(self.file_dir_path, self.dataset_desc["train"][idx])
+        # print(fd_gt_path)
+        lf_images = []
+        lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
+        lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB)
+
+        ## IF GrayScale
+        # lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
+        # lf_image_big = np.expand_dims(lf_image_big, axis=-1)
+        # print(lf_image_big.shape)
+
+        for i in range(9):
+            lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:3]
+            ## IF GrayScale
+            # lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1]
+            # print(lf_image.shape)
+            lf_images.append(lf_image)
+        ## IF GrayScale
+        # gt = cv2.imread(fd_gt_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
+        # gt = np.expand_dims(gt, axis=-1)
+
+        fd = self.dataset_desc["focaldepth"][idx]
+        gazeX = self.dataset_desc["gazeX"][idx]
+        gazeY = self.dataset_desc["gazeY"][idx]
+        sample_idx = self.dataset_desc["idx"][idx]
+        return np.asarray(lf_images),fd,gazeX,gazeY,sample_idx
+
+class lightFieldSeqDataLoader(torch.utils.data.dataset.Dataset):
+    def __init__(self, file_dir_path, file_json, transforms=None):
+        self.file_dir_path = file_dir_path
+        self.transforms = transforms
+        with open(file_json, encoding='utf-8') as file:
+            self.dataset_desc = json.loads(file.read())
+
+    def __len__(self):
+        return len(self.dataset_desc["seq"])
+
+    def __getitem__(self, idx):
+        lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx = self.get_datum(idx)
+        fd = fd.astype(np.float32)
+        gazeX = gazeX.astype(np.float32)
+        gazeY = gazeY.astype(np.float32)
+        sample_idx = sample_idx.astype(np.int64)
+        # print(fd)
+        # print(gazeX)
+        # print(gazeY)
+        # print(sample_idx)
+
+        # print(lightfield_images.dtype,gt.dtype, gt2.dtype, fd.dtype, gazeX.dtype, gazeY.dtype, sample_idx.dtype, delta.dtype)
+        # print(lightfield_images.shape,gt.shape, gt2.shape, fd.shape, gazeX.shape, gazeY.shape, sample_idx.shape, delta.shape)
+        if self.transforms:
+            lightfield_images = self.transforms(lightfield_images)
+        return (lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx)
+
+    def get_datum(self, idx):
+        indices = self.dataset_desc["seq"][idx]
+        # print("indices:",indices)
+        lf_images = []
+        fd = []
+        gazeX = []
+        gazeY = []
+        sample_idx = []
+        gt = []
+        gt2 = []
+        for i in range(len(indices)):
+            lf_image_paths = os.path.join(self.file_dir_path, self.dataset_desc["train"][indices[i]])
+            fd_gt_path = os.path.join(self.file_dir_path, self.dataset_desc["gt"][indices[i]])
+            fd_gt_path2 = os.path.join(self.file_dir_path, self.dataset_desc["gt2"][indices[i]])
+            lf_image_one_sample = []
+            lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
+            lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB)
+
+            for j in range(9):
+                lf_image = lf_image_big[j//3*IM_H:j//3*IM_H+IM_H,j%3*IM_W:j%3*IM_W+IM_W,0:3]
+                lf_image_one_sample.append(lf_image)
+
+            gt_i = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
+            gt.append(cv2.cvtColor(gt_i,cv2.COLOR_BGR2RGB))
+            gt2_i = cv2.imread(fd_gt_path2, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
+            gt2.append(cv2.cvtColor(gt2_i,cv2.COLOR_BGR2RGB))
+
+            # print("indices[i]:",indices[i])
+            fd.append([self.dataset_desc["focaldepth"][indices[i]]])
+            gazeX.append([self.dataset_desc["gazeX"][indices[i]]])
+            gazeY.append([self.dataset_desc["gazeY"][indices[i]]])
+            sample_idx.append([self.dataset_desc["idx"][indices[i]]])
+            lf_images.append(lf_image_one_sample)
+        #lf_images: 5,9,320,320
+
+        return np.asarray(lf_images),np.asarray(gt),np.asarray(gt2),np.asarray(fd),np.asarray(gazeX),np.asarray(gazeY),np.asarray(sample_idx)
+
+class lightFieldFlowSeqDataLoader(torch.utils.data.dataset.Dataset):
+    def __init__(self, file_dir_path, file_json, gen, conf, transforms=None):
+        self.file_dir_path = file_dir_path
+        self.file_json = file_json
+        self.gen = gen
+        self.conf = conf
+        self.transforms = transforms
+        with open(file_json, encoding='utf-8') as file:
+            self.dataset_desc = json.loads(file.read())
+
+    def __len__(self):
+        return len(self.dataset_desc["seq"])
+
+    def __getitem__(self, idx):
+        # start = time.time()
+        lightfield_images, gt, phi, phi_invalid, retinal_invalid, flow, flow_invalid_mask, fd, gazeX, gazeY, posX, posY, posZ, sample_idx = self.get_datum(idx)
+        fd = fd.astype(np.float32)
+        gazeX = gazeX.astype(np.float32)
+        gazeY = gazeY.astype(np.float32)
+        posX = posX.astype(np.float32)
+        posY = posY.astype(np.float32)
+        posZ = posZ.astype(np.float32)
+        sample_idx = sample_idx.astype(np.int64)
+
+        if self.transforms:
+            lightfield_images = self.transforms(lightfield_images)
+        # print("read once:",time.time() - start) # 8 ms
+        return (lightfield_images, gt, phi, phi_invalid, retinal_invalid, flow, flow_invalid_mask, fd, gazeX, gazeY, posX, posY, posZ, sample_idx)
+
+    def get_datum(self, idx):
+        IM_H = 320
+        IM_W = 320
+        indices = self.dataset_desc["seq"][idx]
+        # print("indices:",indices)
+        lf_images = []
+        fd = []
+        gazeX = []
+        gazeY = []
+        posX = []
+        posY = []
+        posZ = []
+        sample_idx = []
+        gt = []
+        # gt2 = []
+        phi = []
+        phi_invalid = []
+        retinal_invalid = []
+        for i in range(len(indices)): # 5
+            lf_image_paths = os.path.join(self.file_dir_path, self.dataset_desc["train"][indices[i]])
+            fd_gt_path = os.path.join(self.file_dir_path, self.dataset_desc["gt"][indices[i]])
+            # fd_gt_path2 = os.path.join(self.file_dir_path, self.dataset_desc["gt2"][indices[i]])
+            lf_image_one_sample = []
+            lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
+            lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB)
+
+            lf_dim = int(self.conf.light_field_dim)
+            for j in range(lf_dim**2):
+                lf_image = lf_image_big[j//lf_dim*IM_H:j//lf_dim*IM_H+IM_H,j%lf_dim*IM_W:j%lf_dim*IM_W+IM_W,0:3]
+                lf_image_one_sample.append(lf_image)
+
+            gt_i = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
+            gt.append(cv2.cvtColor(gt_i,cv2.COLOR_BGR2RGB))
+            # gt2_i = cv2.imread(fd_gt_path2, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
+            # gt2.append(cv2.cvtColor(gt2_i,cv2.COLOR_BGR2RGB))
+
+            # print("indices[i]:",indices[i])
+            fd.append([self.dataset_desc["focaldepth"][indices[i]]])
+            gazeX.append([self.dataset_desc["gazeX"][indices[i]]])
+            gazeY.append([self.dataset_desc["gazeY"][indices[i]]])
+            posX.append([self.dataset_desc["x"][indices[i]]])
+            posY.append([self.dataset_desc["y"][indices[i]]])
+            posZ.append([0.0])
+            sample_idx.append([self.dataset_desc["idx"][indices[i]]])
+            lf_images.append(lf_image_one_sample)
+
+            idx_i = sample_idx[i][0]
+            focaldepth_i = fd[i][0]
+            gazeX_i = gazeX[i][0]
+            gazeY_i = gazeY[i][0]
+            posX_i = posX[i][0]
+            posY_i = posY[i][0]
+            posZ_i = posZ[i][0]
+            # print("x:%.3f,y:%.3f,z:%.3f;gaze:%.4f,%.4f,focaldepth:%.3f."%(posX_i,posY_i,posZ_i,gazeX_i,gazeY_i,focaldepth_i))
+            phi_i,phi_invalid_i,retinal_invalid_i =  self.gen.CalculateRetinal2LayerMappings(torch.tensor([posX_i, posY_i,posZ_i]),torch.tensor([gazeX_i, gazeY_i]),focaldepth_i)
+
+            phi.append(phi_i)
+            phi_invalid.append(phi_invalid_i)
+            retinal_invalid.append(retinal_invalid_i)
+        #lf_images: 5,9,320,320
+        flow = Flow.Load([os.path.join(self.file_dir_path, self.dataset_desc["flow"][indices[i-1]]) for i in range(1,len(indices))])
+        flow_map = flow.getMap()
+        flow_invalid_mask = flow.b_invalid_mask
+        # print("flow:",flow_map.shape)
+
+        return np.asarray(lf_images),np.asarray(gt), torch.stack(phi,dim=0), torch.stack(phi_invalid,dim=0),torch.stack(retinal_invalid,dim=0), flow_map, flow_invalid_mask, np.asarray(fd),np.asarray(gazeX),np.asarray(gazeY),np.asarray(posX),np.asarray(posY),np.asarray(posZ),np.asarray(sample_idx)
diff --git a/gen_image.py b/gen_image.py
index 7ae1de8..83e7f4d 100644
--- a/gen_image.py
+++ b/gen_image.py
@@ -2,13 +2,8 @@ import matplotlib.pyplot as plt
 import numpy as np
 import torch
 import glm
-
-def Fov2Length(angle):
-    '''
-
-    '''
-    return np.tan(angle * np.pi / 360) * 2
-
+import time
+import util
 
 def RandomGenSamplesInPupil(pupil_size, n_samples):
     '''
@@ -21,9 +16,9 @@ def RandomGenSamplesInPupil(pupil_size, n_samples):
     
     Returns
     --------
-    a n_samples x 2 tensor with 2D sample position in each row
+    a n_samples x 3 tensor with 3D sample position in each row
     '''
-    samples = torch.empty(n_samples, 2)
+    samples = torch.empty(n_samples, 3)
     i = 0
     while i < n_samples:
         s = (torch.rand(2) - 0.5) * pupil_size
@@ -44,7 +39,7 @@ def GenSamplesInPupil(pupil_size, circles):
     
     Returns
     --------
-    a n_samples x 2 tensor with 2D sample position in each row
+    a n_samples x 3 tensor with 3D sample position in each row
     '''
     samples = torch.zeros(1, 3)
     for i in range(1, circles):
@@ -70,7 +65,7 @@ class RetinalGen(object):
     Methods
     --------
     '''
-    def __init__(self, conf, u):
+    def __init__(self, conf):
         '''
         Initialize retinal generator instance
 
@@ -80,58 +75,83 @@ class RetinalGen(object):
         u    - a M x 3 tensor stores M sample positions in pupil
         '''
         self.conf = conf
+        self.u = GenSamplesInPupil(conf.pupil_size, 5)
         # self.u = u.to(cuda_dev)
-        self.u = u # M x 3 M sample positions 
+        # self.u = u # M x 3 M sample positions 
         self.D_r = conf.retinal_res # retinal res 480 x 640 
         self.N = conf.GetNLayers() # 2 
-        self.M = u.size()[0] # samples
-        p_rx, p_ry = torch.meshgrid(torch.tensor(range(0, self.D_r[0])),
-                                    torch.tensor(range(0, self.D_r[1])))
+        self.M = self.u.size()[0] # samples
+        # p_rx, p_ry = torch.meshgrid(torch.tensor(range(0, self.D_r[0])),
+        #                             torch.tensor(range(0, self.D_r[1])))
+        # self.p_r = torch.cat([
+        #     ((torch.stack([p_rx, p_ry], 2) + 0.5) / self.D_r - 0.5) * conf.GetEyeViewportSize(), # 鐪肩悆瑙嗛噹
+        #     torch.ones(self.D_r[0], self.D_r[1], 1)
+        # ], 2)
+
         self.p_r = torch.cat([
-            ((torch.stack([p_rx, p_ry], 2) + 0.5) / self.D_r - 0.5) * conf.GetEyeViewportSize(), # 鐪肩悆瑙嗛噹
+            ((util.MeshGrid(self.D_r) + 0.5) / self.D_r - 0.5) * conf.GetEyeViewportSize(),
             torch.ones(self.D_r[0], self.D_r[1], 1)
         ], 2)
 
         # self.Phi = torch.empty(N, D_r[0], D_r[1], M, 2, device=cuda_dev, dtype=torch.long)
         # self.mask = torch.empty(self.N, self.D_r[0], self.D_r[1], self.M, 2, dtype=torch.float) # 2 x 480 x 640 x 41 x 2
         
-    def CalculateRetinal2LayerMappings(self, df, gaze):
+    def CalculateRetinal2LayerMappings(self, position, gaze_dir, df):
         '''
         Calculate the mapping matrix from retinal to layers.
 
         Parameters
         --------
-        df   - focus distance
-        gaze - 2 x 1 tensor, eye rotation angle (degs) in horizontal and vertical direction
+        position - 1 x 3 tensor, eye's position
+        gaze_dir - 1 x 2 tensor, gaze forward vector (with z normalized)
+        df       - focus distance
 
+        Returns
+        --------
+        phi             - N x H_r x W_r x M x 2, retinal to layers mapping, N is number of layers
+        phi_invalid     - N x H_r x W_r x M x 1, indicates invalid (out-of-range) mapping
+        retinal_invalid - 1 x H_r x W_r, indicates invalid pixels in retinal image
         '''
-        Phi = torch.empty(self.N, self.D_r[0], self.D_r[1], self.M, 2, dtype=torch.long) # 2 x 480 x 640 x 41 x 2
-        mask = torch.empty(self.N, self.D_r[0], self.D_r[1], self.M, 2, dtype=torch.float)
+        D = self.conf.layer_res 
+        c = torch.tensor([ D[1] / 2, D[0] / 2 ])     # c: Center of layers (pixel)
+
         D_r = self.conf.retinal_res        # D_r: Resolution of retinal 480 640
         V = self.conf.GetEyeViewportSize() # V: Viewport size of eye 
-        c = (self.conf.layer_res / 2)      # c: Center of layers (pixel)
         p_f = self.p_r * df                # p_f: H x W x 3, focus positions of retinal pixels on focus plane
-        rot_forward = glm.dvec3(glm.tan(glm.radians(glm.dvec2(gaze[1], -gaze[0]))), 1)
-        rot_mat = torch.from_numpy(np.array(
-            glm.dmat3(glm.lookAtLH(glm.dvec3(), rot_forward, glm.dvec3(0, 1, 0)))))
-        rot_mat = rot_mat.float()
-        u_rot = torch.mm(self.u, rot_mat)
-        v_rot = torch.matmul(p_f, rot_mat).unsqueeze(2).expand(
-            -1, -1, self.u.size()[0], -1) - u_rot # v_rot: H x W x M x 3, rotated rays' direction vector
-        v_rot.div_(v_rot[:, :, :, 2].unsqueeze(3))             # make z = 1 for each direction vector in v_rot
-        
-        for i in range(0, self.conf.GetNLayers()):
-            dp_i = self.conf.GetLayerSize(i)[0] / self.conf.layer_res[0] # dp_i: Pixel size of layer i
-            d_i = self.conf.d_layer[i]                                        # d_i: Distance of layer i
+        
+        # Calculate transformation from eye to display
+        gvec_lookat = glm.dvec3(gaze_dir[0], -gaze_dir[1], 1)
+        gmat_eye = glm.inverse(glm.lookAtLH(glm.dvec3(), gvec_lookat, glm.dvec3(0, 1, 0)))
+        eye_rot = util.Glm2Tensor(glm.dmat3(gmat_eye))
+        eye_center = torch.tensor([ position[0], -position[1], position[2] ])
+
+        u_rot = torch.mm(self.u, eye_rot)
+        v_rot = torch.matmul(p_f, eye_rot).unsqueeze(2).expand(
+            -1, -1, self.M, -1) - u_rot # v_rot: H x W x M x 3, rotated rays' direction vector
+        u_rot.add_(eye_center)                            # translate by eye's center
+        v_rot = v_rot.div(v_rot[:, :, :, 2].unsqueeze(3)) # make z = 1 for each direction vector in v_rot
+        
+        phi = torch.empty(self.N, self.D_r[0], self.D_r[1], self.M, 2, dtype=torch.long)
+
+        for i in range(0, self.N):
+            dp_i = self.conf.GetPixelSizeOfLayer(i)     # dp_i: Pixel size of layer i
+            d_i = self.conf.d_layer[i]                  # d_i: Distance of layer i
             k = (d_i - u_rot[:, 2]).unsqueeze(1)
             pi_r = (u_rot[:, 0:2] + v_rot[:, :, :, 0:2] * k) / dp_i      # pi_r: H x W x M x 2, rays' pixel coord on layer i
-            Phi[i, :, :, :, :] = torch.floor(pi_r + c)
-        mask[:, :, :, :, 0] = ((Phi[:, :, :, :, 0] >= 0) & (Phi[:, :, :, :, 0] < self.conf.layer_res[0])).float()
-        mask[:, :, :, :, 1] = ((Phi[:, :, :, :, 1] >= 0) & (Phi[:, :, :, :, 1] < self.conf.layer_res[1])).float()
-        Phi[:, :, :, :, 0].clamp_(0, self.conf.layer_res[0] - 1)
-        Phi[:, :, :, :, 1].clamp_(0, self.conf.layer_res[1] - 1)
-        retinal_mask = mask.prod(0).prod(2).prod(2)
-        return [ Phi, retinal_mask ]
+            phi[i, :, :, :, :] = torch.floor(pi_r + c)
+        
+        # Calculate invalid mask (out-of-range elements in phi) and reduced to retinal
+        phi_invalid = (phi[:, :, :, :, 0] < 0) | (phi[:, :, :, :, 0] >= D[1]) | \
+                       (phi[:, :, :, :, 1] < 0) | (phi[:, :, :, :, 1] >= D[0])
+        phi_invalid = phi_invalid.unsqueeze(4)
+        # print("phi_invalid:",phi_invalid.shape) 
+        retinal_invalid = phi_invalid.amax((0, 3)).squeeze().unsqueeze(0)
+        # print("retinal_invalid:",retinal_invalid.shape)
+        # Fix invalid elements in phi
+        phi[phi_invalid.expand(-1, -1, -1, -1, 2)] = 0
+
+        return [ phi, phi_invalid, retinal_invalid  ]
+    
     
     def GenRetinalFromLayers(self, layers, Phi):
         '''
@@ -139,28 +159,159 @@ class RetinalGen(object):
         
         Parameters
         --------
-        layers - 3N x H_l x W_l tensor, stacked layer images, with 3 channels in each layer
+        layers       - 3N x H x W, stacked layer images, with 3 channels in each layer
+        phi          - N x H_r x W_r x M x 2, retinal to layers mapping, N is number of layers
         
         Returns
         --------
-        3 x H_r x W_r tensor, 3 channels retinal image
-        H_r x W_r tensor, retinal image mask, indicates pixels valid or not
-        
+        3 x H_r x W_r, 3 channels retinal image
         '''
         # FOR GRAYSCALE 1 FOR RGB 3
         mapped_layers = torch.empty(self.N, 3, self.D_r[0], self.D_r[1], self.M) # 2 x 3 x 480 x 640 x 41
         # print("mapped_layers:",mapped_layers.shape)
         for i in range(0, Phi.size()[0]):
+            # torch.Size([3, 2, 320, 320, 2])
             # print("gather layers:",layers[(i * 3) : (i * 3 + 3),Phi[i, :, :, :, 0],Phi[i, :, :, :, 1]].shape)
             mapped_layers[i, :, :, :, :] = layers[(i * 3) : (i * 3 + 3),
-                                                    Phi[i, :, :, :, 0],
-                                                    Phi[i, :, :, :, 1]]
+                                                    Phi[i, :, :, :, 1],
+                                                    Phi[i, :, :, :, 0]]
         # print("mapped_layers:",mapped_layers.shape)
         retinal = mapped_layers.prod(0).sum(3).div(Phi.size()[3])
         # print("retinal:",retinal.shape)
         return retinal
+
+    def GenRetinalFromLayersBatch(self, layers, Phi):
+        '''
+        Generate retinal image from layers, using precalculated mapping matrix
+        
+        Parameters
+        --------
+        layers - 3N x H_l x W_l tensor, stacked layer images, with 3 channels in each layer
+        
+        Returns
+        --------
+        3 x H_r x W_r tensor, 3 channels retinal image
+        H_r x W_r tensor, retinal image mask, indicates pixels valid or not
+        
+        '''
+        mapped_layers = torch.empty(layers.size()[0], self.N, 3, self.D_r[0], self.D_r[1], self.M) #BS x Layers x C x H x W x Sample
+        
+        # truth = torch.empty(layers.size()[0], self.N, 3, self.D_r[0], self.D_r[1], self.M)
+        # layers_truth = layers.clone()
+        # Phi_truth = Phi.clone()
+        layers = torch.stack((layers[:,0:3,:,:],layers[:,3:6,:,:]),dim=1) ## torch.Size([BS, Layer, RGB 3, 320, 320])
+        
+        # Phi = Phi[:,:,None,:,:,:,:].expand(-1,-1,3,-1,-1,-1,-1)
+        # print("mapped_layers:",mapped_layers.shape) #torch.Size([2, 2, 3, 320, 320, 41])
+        # print("input layers:",layers.shape) ## torch.Size([2, 2, 3, 320, 320])
+        # print("input Phi:",Phi.shape) #torch.Size([2, 2, 320, 320, 41, 2])
+        
+        # #娌′紭鍖�
+
+        # for i in range(0, Phi_truth.size()[0]):
+        #     for j in range(0, Phi_truth.size()[1]):
+        #         truth[i, j, :, :, :, :] = layers_truth[i, (j * 3) : (j * 3 + 3),
+        #                                                 Phi_truth[i, j, :, :, :, 0],
+        #                                                 Phi_truth[i, j, :, :, :, 1]]
+
+        #浼樺寲2
+        # start = time.time() 
+        mapped_layers_op1 = mapped_layers.reshape(-1,
+                mapped_layers.shape[2],mapped_layers.shape[3],mapped_layers.shape[4],mapped_layers.shape[5])
+                # BatchSizexLayer Channel 3 320 320 41
+        layers_op1 = layers.reshape(-1,layers.shape[2],layers.shape[3],layers.shape[4]) # 2x2 3 320 320
+        Phi_op1 = Phi.reshape(-1,Phi.shape[2],Phi.shape[3],Phi.shape[4],Phi.shape[5]) # 2x2 320 320 41 2
+        x = Phi_op1[:,:,:,:,0] # 2x2 320 320 41
+        y = Phi_op1[:,:,:,:,1] # 2x2 320 320 41
+        # print("reshape:",time.time() - start)
+
+        # start = time.time()
+        mapped_layers_op1 = layers_op1[torch.arange(layers_op1.shape[0])[:, None, None, None], :, y, x] # x,y 鍒囨崲
+        #2x2 320 320 41 3
+        # print("mapping one step:",time.time() - start)
+        
+        # print("mapped_layers:",mapped_layers_op1.shape) # torch.Size([4, 3, 320, 320, 41])
+        # start = time.time()
+        mapped_layers_op1 = mapped_layers_op1.permute(0,4,1,2,3)
+        mapped_layers = mapped_layers_op1.reshape(mapped_layers.shape[0],mapped_layers.shape[1],
+                    mapped_layers.shape[2],mapped_layers.shape[3],mapped_layers.shape[4],mapped_layers.shape[5])
+        # print("reshape end:",time.time() - start)
+
+        # print("test:")
+        # print((truth.cpu() == mapped_layers.cpu()).all())
+        #浼樺寲1
+        # start = time.time()
+        # mapped_layers_op1 = mapped_layers.reshape(-1,
+        #         mapped_layers.shape[2],mapped_layers.shape[3],mapped_layers.shape[4],mapped_layers.shape[5])
+        # layers_op1 = layers.reshape(-1,layers.shape[2],layers.shape[3],layers.shape[4])
+        # Phi_op1 = Phi.reshape(-1,Phi.shape[2],Phi.shape[3],Phi.shape[4],Phi.shape[5])
+        # print("reshape:",time.time() - start)
+
+
+        # for i in range(0, Phi_op1.size()[0]):
+        #     start = time.time()
+        #     mapped_layers_op1[i, :, :, :, :] = layers_op1[i,:,
+        #                                             Phi_op1[i, :, :, :, 0],
+        #                                             Phi_op1[i, :, :, :, 1]]
+        #     print("mapping one step:",time.time() - start)
+        # print("mapped_layers:",mapped_layers_op1.shape) # torch.Size([4, 3, 320, 320, 41])
+        # start = time.time()
+        # mapped_layers = mapped_layers_op1.reshape(mapped_layers.shape[0],mapped_layers.shape[1],
+        #             mapped_layers.shape[2],mapped_layers.shape[3],mapped_layers.shape[4],mapped_layers.shape[5])
+        # print("reshape end:",time.time() - start)
+
+        # print("mapped_layers:",mapped_layers.shape) # torch.Size([2, 2, 3, 320, 320, 41])
+        retinal = mapped_layers.prod(1).sum(4).div(Phi.size()[4])
+        # print("retinal:",retinal.shape) # torch.Size([BatchSize, 3, 320, 320])
+        return retinal
+        
+    ## TO BE CHECK
+    def GenFoveaLayers(self, b_retinal, is_mask):
+        '''
+        Generate foveated layers for retinal images or masks
+        
+        Parameters
+        --------
+        b_retinal - B x C x H_r x W_r, Batch of retinal images/masks
+        is_mask   - Whether b_retinal is masks or images
         
-    def GenFoveaLayers(self, retinal, retinal_mask):
+        Returns
+        --------
+        b_fovea_layers - N_f x (B x C x H[f] x W[f]) list of batch of foveated layers
+        '''
+        b_fovea_layers = []
+        for i in range(0, len(self.conf.eye_fovea_angles)):
+            k = self.conf.eye_fovea_downsamples[i]
+            region = self.conf.GetRegionOfFoveaLayer(i)
+            b_roi = b_retinal[:, :, region, region]
+            if k == 1:
+                b_fovea_layers.append(b_roi)
+            elif is_mask:
+                b_fovea_layers.append(torch.nn.functional.max_pool2d(b_roi.to(torch.float), k).to(torch.bool))
+            else:
+                b_fovea_layers.append(torch.nn.functional.avg_pool2d(b_roi, k))
+        return b_fovea_layers
+        # fovea_layers = []
+        # fovea_layer_masks = []
+        # fov = self.conf.eye_fovea_angles[-1]
+        # retinal_res = int(self.conf.retinal_res[0])
+        # for i in range(0, len(self.conf.eye_fovea_angles)):
+        #     angle = self.conf.eye_fovea_angles[i]
+        #     k = self.conf.eye_fovea_downsamples[i]
+        #     roi_size = int(np.ceil(retinal_res * angle / fov))
+        #     roi_offset = int((retinal_res - roi_size) / 2)
+        #     roi_img = retinal[:, roi_offset:(roi_offset + roi_size), roi_offset:(roi_offset + roi_size)]
+        #     roi_mask = retinal_mask[roi_offset:(roi_offset + roi_size), roi_offset:(roi_offset + roi_size)]
+        #     if k == 1:
+        #         fovea_layers.append(roi_img)
+        #         fovea_layer_masks.append(roi_mask)
+        #     else:
+        #         fovea_layers.append(torch.nn.functional.avg_pool2d(roi_img.unsqueeze(0), k).squeeze(0))
+        #         fovea_layer_masks.append(1 - torch.nn.functional.max_pool2d((1 - roi_mask).unsqueeze(0), k).squeeze(0))
+        # return [ fovea_layers, fovea_layer_masks ]
+
+    ## TO BE CHECK
+    def GenFoveaLayersBatch(self, retinal, retinal_mask):
         '''
         Generate foveated layers and corresponding masks
         
@@ -177,18 +328,48 @@ class RetinalGen(object):
         fovea_layers = []
         fovea_layer_masks = []
         fov = self.conf.eye_fovea_angles[-1]
+        # print("fov:",fov)
         retinal_res = int(self.conf.retinal_res[0])
+        # print("retinal_res:",retinal_res)
+        # print("len(self.conf.eye_fovea_angles):",len(self.conf.eye_fovea_angles))
         for i in range(0, len(self.conf.eye_fovea_angles)):
             angle = self.conf.eye_fovea_angles[i]
             k = self.conf.eye_fovea_downsamples[i]
             roi_size = int(np.ceil(retinal_res * angle / fov))
             roi_offset = int((retinal_res - roi_size) / 2)
-            roi_img = retinal[:, roi_offset:(roi_offset + roi_size), roi_offset:(roi_offset + roi_size)]
-            roi_mask = retinal_mask[roi_offset:(roi_offset + roi_size), roi_offset:(roi_offset + roi_size)]
+            # [2, 3, 320, 320]
+            roi_img = retinal[:, :, roi_offset:(roi_offset + roi_size), roi_offset:(roi_offset + roi_size)]
+            # print("roi_img:",roi_img.shape)
+            # [2, 320, 320]
+            roi_mask = retinal_mask[:, roi_offset:(roi_offset + roi_size), roi_offset:(roi_offset + roi_size)]
+            # print("roi_mask:",roi_mask.shape)
             if k == 1:
                 fovea_layers.append(roi_img)
                 fovea_layer_masks.append(roi_mask)
             else:
-                fovea_layers.append(torch.nn.functional.avg_pool2d(roi_img.unsqueeze(0), k).squeeze(0))
-                fovea_layer_masks.append(1 - torch.nn.functional.max_pool2d((1 - roi_mask).unsqueeze(0), k).squeeze(0))
-        return [ fovea_layers, fovea_layer_masks ]
\ No newline at end of file
+                fovea_layers.append(torch.nn.functional.avg_pool2d(roi_img, k))
+                fovea_layer_masks.append(1 - torch.nn.functional.max_pool2d((1 - roi_mask), k))
+        return [ fovea_layers, fovea_layer_masks ]
+    
+    ## TO BE CHECK
+    def GenFoveaRetinal(self, b_fovea_layers):
+        '''
+        Generate foveated retinal image by blending fovea layers
+        **Note: current implementation only support two fovea layers**
+        
+        Parameters
+        --------
+        b_fovea_layers - N_f x (B x 3 x H[f] x W[f]), list of batch of (masked) foveated layers
+        
+        Returns
+        --------
+        B x 3 x H_r x W_r, batch of foveated retinal images
+        '''
+        b_fovea_retinal = torch.nn.functional.interpolate(b_fovea_layers[1],
+            scale_factor=self.conf.eye_fovea_downsamples[1],
+            mode='bilinear', align_corners=False)
+        region = self.conf.GetRegionOfFoveaLayer(0)
+        blend = self.conf.eye_fovea_blend[0]
+        b_roi = b_fovea_retinal[:, :, region, region]
+        b_roi.mul_(1 - blend).add_(b_fovea_layers[0] * blend)
+        return b_fovea_retinal
diff --git a/loss.py b/loss.py
new file mode 100644
index 0000000..7307a67
--- /dev/null
+++ b/loss.py
@@ -0,0 +1,80 @@
+import torch
+from ssim import *
+from perc_loss import * 
+
+l1loss = torch.nn.L1Loss()
+perc_loss = VGGPerceptualLoss() 
+perc_loss = perc_loss.to("cuda:1")
+
+##### LOSS #####
+def calImageGradients(images):
+    # x is a 4-D tensor
+    dx = images[:, :, 1:, :] - images[:, :, :-1, :]
+    dy = images[:, :, :, 1:] - images[:, :, :, :-1]
+    return dx, dy
+
+def loss_new(generated, gt):
+    mse_loss = torch.nn.MSELoss()
+    rmse_intensity = mse_loss(generated, gt)
+    psnr_intensity = torch.log10(rmse_intensity)
+    # print("psnr:",psnr_intensity)
+    # ssim_intensity = ssim(generated, gt)
+    labels_dx, labels_dy = calImageGradients(gt)
+    # print("generated:",generated.shape)
+    preds_dx, preds_dy = calImageGradients(generated)
+    rmse_grad_x, rmse_grad_y = mse_loss(labels_dx, preds_dx), mse_loss(labels_dy, preds_dy)
+    psnr_grad_x, psnr_grad_y = torch.log10(rmse_grad_x), torch.log10(rmse_grad_y)
+    # print("psnr x&y:",psnr_grad_x," ",psnr_grad_y)
+    p_loss = perc_loss(generated,gt)
+    # print("-psnr:",-psnr_intensity,",0.5*(psnr_grad_x + psnr_grad_y):",0.5*(psnr_grad_x + psnr_grad_y),",perc_loss:",p_loss)
+    total_loss = psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y) + p_loss
+    # total_loss = rmse_intensity + 0.5*(rmse_grad_x + rmse_grad_y) # + p_loss
+    return total_loss
+
+def loss_without_perc(generated, gt): 
+    mse_loss = torch.nn.MSELoss()
+    rmse_intensity = mse_loss(generated, gt)
+    psnr_intensity = torch.log10(rmse_intensity)
+    # print("psnr:",psnr_intensity)
+    # ssim_intensity = ssim(generated, gt)
+    labels_dx, labels_dy = calImageGradients(gt)
+    # print("generated:",generated.shape)
+    preds_dx, preds_dy = calImageGradients(generated)
+    rmse_grad_x, rmse_grad_y = mse_loss(labels_dx, preds_dx), mse_loss(labels_dy, preds_dy)
+    psnr_grad_x, psnr_grad_y = torch.log10(rmse_grad_x), torch.log10(rmse_grad_y)
+    # print("psnr x&y:",psnr_grad_x," ",psnr_grad_y)
+    # print("-psnr:",-psnr_intensity,",0.5*(psnr_grad_x + psnr_grad_y):",0.5*(psnr_grad_x + psnr_grad_y),",perc_loss:",p_loss)
+    total_loss = psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y)
+    # total_loss = rmse_intensity + 0.5*(rmse_grad_x + rmse_grad_y) # + p_loss
+    return total_loss
+##### LOSS #####
+
+
+class ReconstructionLoss(torch.nn.Module):
+    def __init__(self):
+        super(ReconstructionLoss, self).__init__()
+
+    def forward(self, generated, gt):
+        rmse_intensity = torch.nn.functional.mse_loss(generated, gt)
+        psnr_intensity = torch.log10(rmse_intensity)
+        labels_dx, labels_dy = calImageGradients(gt)
+        preds_dx, preds_dy = calImageGradients(generated)
+        rmse_grad_x, rmse_grad_y = torch.nn.functional.mse_loss(labels_dx, preds_dx), torch.nn.functional.mse_loss(labels_dy, preds_dy)
+        psnr_grad_x, psnr_grad_y = torch.log10(rmse_grad_x), torch.log10(rmse_grad_y)
+        total_loss = psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y)
+        return total_loss
+
+class PerceptionReconstructionLoss(torch.nn.Module):
+    def __init__(self):
+        super(PerceptionReconstructionLoss, self).__init__()
+
+    def forward(self, generated, gt):
+        rmse_intensity = torch.nn.functional.mse_loss(generated, gt)
+        psnr_intensity = torch.log10(rmse_intensity)
+        labels_dx, labels_dy = calImageGradients(gt)
+        preds_dx, preds_dy = calImageGradients(generated)
+        rmse_grad_x, rmse_grad_y = torch.nn.functional.mse_loss(labels_dx, preds_dx), torch.nn.functional.mse_loss(labels_dy, preds_dy)
+        psnr_grad_x, psnr_grad_y = torch.log10(rmse_grad_x), torch.log10(rmse_grad_y)
+        p_loss = perc_loss(generated,gt)
+        total_loss = psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y) + p_loss
+        return total_loss
diff --git a/main.py b/main.py
index 278ad74..104c174 100644
--- a/main.py
+++ b/main.py
@@ -12,212 +12,33 @@ from torch.autograd import Variable
 
 import cv2
 from gen_image import *
+from loss import *
 import json
-from ssim import *
-from perc_loss import * 
 from conf import Conf
 
-from model.baseline import *
+from baseline import *
+from data import * 
 
 import torch.autograd.profiler as profiler
 # param
-BATCH_SIZE = 2
-NUM_EPOCH = 300
-
+BATCH_SIZE = 1
+NUM_EPOCH = 1001
 INTERLEAVE_RATE = 2
-
 IM_H = 320
 IM_W = 320
-
 Retinal_IM_H = 320
 Retinal_IM_W = 320
-
-N = 9 # number of input light field stack
+N = 25 # number of input light field stack
 M = 2 # number of display layers
-
-DATA_FILE = "/home/yejiannan/Project/LightField/data/gaze_fovea"
-DATA_JSON = "/home/yejiannan/Project/LightField/data/data_gaze_fovea_seq.json"
-DATA_VAL_JSON = "/home/yejiannan/Project/LightField/data/data_gaze_fovea_val.json"
-OUTPUT_DIR = "/home/yejiannan/Project/LightField/outputE/gaze_fovea_seq"
-
-
+DATA_FILE = "/home/yejiannan/Project/LightField/data/FlowRPG1211"
+DATA_JSON = "/home/yejiannan/Project/LightField/data/data_gaze_fovea_seq_flow_RPG.json"
+# DATA_VAL_JSON = "/home/yejiannan/Project/LightField/data/data_gaze_fovea_val.json"
+OUTPUT_DIR = "/home/yejiannan/Project/LightField/outputE/gaze_fovea_seq_flow_RPG_seq5_same_loss"
 OUT_CHANNELS_RB = 128
 KERNEL_SIZE_RB = 3
 KERNEL_SIZE = 3
-
 LAST_LAYER_CHANNELS = 6 * INTERLEAVE_RATE**2
-FIRSST_LAYER_CHANNELS = 27 * INTERLEAVE_RATE**2
-
-class lightFieldDataLoader(torch.utils.data.dataset.Dataset):
-    def __init__(self, file_dir_path, file_json, transforms=None):
-        self.file_dir_path = file_dir_path
-        self.transforms = transforms
-        # self.datum_list = glob.glob(os.path.join(file_dir_path,"*"))
-        with open(file_json, encoding='utf-8') as file:
-            self.dataset_desc = json.loads(file.read())
-
-    def __len__(self):
-        return len(self.dataset_desc["focaldepth"])
-
-    def __getitem__(self, idx):
-        lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx = self.get_datum(idx)
-        if self.transforms:
-            lightfield_images = self.transforms(lightfield_images)
-        # print(lightfield_images.shape,gt.shape,fd,gazeX,gazeY,sample_idx)
-        return (lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx)
-
-    def get_datum(self, idx):
-        lf_image_paths = os.path.join(DATA_FILE, self.dataset_desc["train"][idx])
-        # print(lf_image_paths)
-        fd_gt_path = os.path.join(DATA_FILE, self.dataset_desc["gt"][idx])
-        fd_gt_path2 = os.path.join(DATA_FILE, self.dataset_desc["gt2"][idx])
-        # print(fd_gt_path)
-        lf_images = []
-        lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
-        lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB)
-
-        ## IF GrayScale
-        # lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
-        # lf_image_big = np.expand_dims(lf_image_big, axis=-1)
-        # print(lf_image_big.shape)
-
-        for i in range(9):
-            lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:3]
-            ## IF GrayScale
-            # lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1]
-            # print(lf_image.shape)
-            lf_images.append(lf_image)
-        gt = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
-        gt = cv2.cvtColor(gt,cv2.COLOR_BGR2RGB)
-        gt2 = cv2.imread(fd_gt_path2, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
-        gt2 = cv2.cvtColor(gt2,cv2.COLOR_BGR2RGB)
-        ## IF GrayScale
-        # gt = cv2.imread(fd_gt_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
-        # gt = np.expand_dims(gt, axis=-1)
-
-        fd = self.dataset_desc["focaldepth"][idx]
-        gazeX = self.dataset_desc["gazeX"][idx]
-        gazeY = self.dataset_desc["gazeY"][idx]
-        sample_idx = self.dataset_desc["idx"][idx]
-        return np.asarray(lf_images),gt,gt2,fd,gazeX,gazeY,sample_idx
-
-class lightFieldValDataLoader(torch.utils.data.dataset.Dataset):
-    def __init__(self, file_dir_path, file_json, transforms=None):
-        self.file_dir_path = file_dir_path
-        self.transforms = transforms
-        # self.datum_list = glob.glob(os.path.join(file_dir_path,"*"))
-        with open(file_json, encoding='utf-8') as file:
-            self.dataset_desc = json.loads(file.read())
-
-    def __len__(self):
-        return len(self.dataset_desc["focaldepth"])
-
-    def __getitem__(self, idx):
-        lightfield_images, fd, gazeX, gazeY, sample_idx = self.get_datum(idx)
-        if self.transforms:
-            lightfield_images = self.transforms(lightfield_images)
-        # print(lightfield_images.shape,gt.shape,fd,gazeX,gazeY,sample_idx)
-        return (lightfield_images, fd, gazeX, gazeY, sample_idx)
-
-    def get_datum(self, idx):
-        lf_image_paths = os.path.join(DATA_FILE, self.dataset_desc["train"][idx])
-        # print(fd_gt_path)
-        lf_images = []
-        lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
-        lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB)
-
-        ## IF GrayScale
-        # lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
-        # lf_image_big = np.expand_dims(lf_image_big, axis=-1)
-        # print(lf_image_big.shape)
-
-        for i in range(9):
-            lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:3]
-            ## IF GrayScale
-            # lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1]
-            # print(lf_image.shape)
-            lf_images.append(lf_image)
-        ## IF GrayScale
-        # gt = cv2.imread(fd_gt_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
-        # gt = np.expand_dims(gt, axis=-1)
-
-        fd = self.dataset_desc["focaldepth"][idx]
-        gazeX = self.dataset_desc["gazeX"][idx]
-        gazeY = self.dataset_desc["gazeY"][idx]
-        sample_idx = self.dataset_desc["idx"][idx]
-        return np.asarray(lf_images),fd,gazeX,gazeY,sample_idx
-
-class lightFieldSeqDataLoader(torch.utils.data.dataset.Dataset):
-    def __init__(self, file_dir_path, file_json, transforms=None):
-        self.file_dir_path = file_dir_path
-        self.transforms = transforms
-        with open(file_json, encoding='utf-8') as file:
-            self.dataset_desc = json.loads(file.read())
-
-    def __len__(self):
-        return len(self.dataset_desc["seq"])
-
-    def __getitem__(self, idx):
-        lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx = self.get_datum(idx)
-        fd = fd.astype(np.float32)
-        gazeX = gazeX.astype(np.float32)
-        gazeY = gazeY.astype(np.float32)
-        sample_idx = sample_idx.astype(np.int64)
-        # print(fd)
-        # print(gazeX)
-        # print(gazeY)
-        # print(sample_idx)
-
-        # print(lightfield_images.dtype,gt.dtype, gt2.dtype, fd.dtype, gazeX.dtype, gazeY.dtype, sample_idx.dtype, delta.dtype)
-        # print(lightfield_images.shape,gt.shape, gt2.shape, fd.shape, gazeX.shape, gazeY.shape, sample_idx.shape, delta.shape)
-        if self.transforms:
-            lightfield_images = self.transforms(lightfield_images)
-        return (lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx)
-
-    def get_datum(self, idx):
-        indices = self.dataset_desc["seq"][idx]
-        # print("indices:",indices)
-        lf_images = []
-        fd = []
-        gazeX = []
-        gazeY = []
-        sample_idx = []
-        gt = []
-        gt2 = []
-        for i in range(len(indices)):
-            lf_image_paths = os.path.join(DATA_FILE, self.dataset_desc["train"][indices[i]])
-            fd_gt_path = os.path.join(DATA_FILE, self.dataset_desc["gt"][indices[i]])
-            fd_gt_path2 = os.path.join(DATA_FILE, self.dataset_desc["gt2"][indices[i]])
-            lf_image_one_sample = []
-            lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
-            lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB)
-
-            for j in range(9):
-                lf_image = lf_image_big[j//3*IM_H:j//3*IM_H+IM_H,j%3*IM_W:j%3*IM_W+IM_W,0:3]
-                ## IF GrayScale
-                # lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1]
-                # print(lf_image.shape)
-                lf_image_one_sample.append(lf_image)
-
-            gt_i = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
-            gt.append(cv2.cvtColor(gt_i,cv2.COLOR_BGR2RGB))
-            gt2_i = cv2.imread(fd_gt_path2, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
-            gt2.append(cv2.cvtColor(gt2_i,cv2.COLOR_BGR2RGB))
-
-            # print("indices[i]:",indices[i])
-            fd.append([self.dataset_desc["focaldepth"][indices[i]]])
-            gazeX.append([self.dataset_desc["gazeX"][indices[i]]])
-            gazeY.append([self.dataset_desc["gazeY"][indices[i]]])
-            sample_idx.append([self.dataset_desc["idx"][indices[i]]])
-            lf_images.append(lf_image_one_sample)
-        #lf_images: 5,9,320,320
-
-        return np.asarray(lf_images),np.asarray(gt),np.asarray(gt2),np.asarray(fd),np.asarray(gazeX),np.asarray(gazeY),np.asarray(sample_idx)
-
-#### Image Gen
-conf = Conf()
-u = GenSamplesInPupil(conf.pupil_size, 5)
-gen = RetinalGen(conf, u)
+FIRSST_LAYER_CHANNELS = 75 * INTERLEAVE_RATE**2
 
 def GenRetinalFromLayersBatch(layers, gen, sample_idx, phi_dict, mask_dict):
     # layers: batchsize, 2*color, height, width 
@@ -248,6 +69,7 @@ def GenRetinalGazeFromLayersBatch(layers, gen, sample_idx, phi_dict, mask_dict):
     #  retinal bs x color x height x width
     retinal_fovea = torch.empty(layers.shape[0], 6, 160, 160)
     mask_fovea = torch.empty(layers.shape[0], 2, 160, 160)
+
     for i in range(0, layers.size()[0]):
         phi = phi_dict[int(sample_idx[i].data)]
         # print("phi_i:",phi.shape)
@@ -261,12 +83,65 @@ def GenRetinalGazeFromLayersBatch(layers, gen, sample_idx, phi_dict, mask_dict):
         fovea_layers, fovea_layer_masks = gen.GenFoveaLayers(retinal_i,mask_i)
         retinal_fovea[i] = torch.cat([fovea_layers[0],fovea_layers[1]],dim=0)
         mask_fovea[i] = torch.stack([fovea_layer_masks[0],fovea_layer_masks[1]],dim=0)
-        
+    
+
+    retinal_fovea = var_or_cuda(retinal_fovea)
+    mask_fovea = var_or_cuda(mask_fovea) # batch x 2 x height x width
+    # mask = torch.stack(mask,dim = 0).unsqueeze(1) 
+    return retinal_fovea, mask_fovea
+
+import time
+def GenRetinalGazeFromLayersBatchSpeed(layers, gen, phi, phi_invalid, retinal_invalid):
+    # layers: batchsize, 2*color, height, width 
+    # Phi:torch.Size([batchsize, Layer, h, w,  41, 2])
+    # df : batchsize,..
+    # start1 = time.time()
+    #  retinal bs x color x height x width
+    retinal_fovea = torch.empty((layers.shape[0], 6, 160, 160),device="cuda:2")
+    mask_fovea = torch.empty((layers.shape[0], 2, 160, 160),device="cuda:2")
+    # start = time.time()
+    retinal = gen.GenRetinalFromLayersBatch(layers,phi_batch)
+    # print("retinal:",retinal.shape) #retinal: torch.Size([2, 3, 320, 320])
+    # print("t2:",time.time() - start)
+
+    # start = time.time()
+    fovea_layers, fovea_layer_masks = gen.GenFoveaLayersBatch(retinal,mask_batch)
+
+    mask_fovea = torch.stack([fovea_layer_masks[0],fovea_layer_masks[1]],dim=1)
+    retinal_fovea = torch.cat([fovea_layers[0],fovea_layers[1]],dim=1)
+    # print("t3:",time.time() - start)
+
     retinal_fovea = var_or_cuda(retinal_fovea)
     mask_fovea = var_or_cuda(mask_fovea) # batch x 2 x height x width
     # mask = torch.stack(mask,dim = 0).unsqueeze(1) 
     return retinal_fovea, mask_fovea
 
+def MergeBatchSpeed(layers, gen, phi, phi_invalid, retinal_invalid):
+    # layers: batchsize, 2*color, height, width 
+    # Phi:torch.Size([batchsize, Layer, h, w,  41, 2])
+    # df : batchsize,..
+    # start1 = time.time()
+    #  retinal bs x color x height x width
+    # retinal_fovea = torch.empty((layers.shape[0], 6, 160, 160),device="cuda:2")
+    # mask_fovea = torch.empty((layers.shape[0], 2, 160, 160),device="cuda:2")
+    # start = time.time()
+    retinal = gen.GenRetinalFromLayersBatch(layers,phi) #retinal: torch.Size([BatchSize , 3, 320, 320])
+    retinal.mul_(~retinal_invalid.to("cuda:2"))
+    # print("retinal:",retinal.shape) 
+    # print("t2:",time.time() - start)
+    
+    # start = time.time()
+    # fovea_layers, fovea_layer_masks = gen.GenFoveaLayersBatch(retinal,mask_batch)
+
+    # mask_fovea = torch.stack([fovea_layer_masks[0],fovea_layer_masks[1]],dim=1)
+    # retinal_fovea = torch.cat([fovea_layers[0],fovea_layers[1]],dim=1)
+    # print("t3:",time.time() - start)
+
+    # retinal_fovea = var_or_cuda(retinal_fovea)
+    # mask_fovea = var_or_cuda(mask_fovea) # batch x 2 x height x width
+    # mask = torch.stack(mask,dim = 0).unsqueeze(1) 
+    return retinal
+
 def GenRetinalFromLayersBatch_Online(layers, gen, phi, mask):
     # layers: batchsize, 2*color, height, width 
     # Phi:torch.Size([batchsize, 480, 640, 2, 41, 2])
@@ -285,47 +160,7 @@ def GenRetinalFromLayersBatch_Online(layers, gen, phi, mask):
     return retinal.unsqueeze(0), mask_out
 #### Image Gen End
 
-weightVarScale = 0.25
-bias_stddev = 0.01
-
-def weight_init_normal(m):
-    classname = m.__class__.__name__
-    if classname.find("Conv") != -1:
-        torch.nn.init.xavier_normal_(m.weight.data)
-        torch.nn.init.normal_(m.bias.data,mean = 0.0, std=bias_stddev)
-    elif classname.find("BatchNorm2d") != -1:
-        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
-        torch.nn.init.constant_(m.bias.data, 0.0)
-
-
-def calImageGradients(images):
-    # x is a 4-D tensor
-    dx = images[:, :, 1:, :] - images[:, :, :-1, :]
-    dy = images[:, :, :, 1:] - images[:, :, :, :-1]
-    return dx, dy
-
-
-perc_loss = VGGPerceptualLoss() 
-perc_loss = perc_loss.to("cuda:1")
-
-def loss_new(generated, gt):
-    mse_loss = torch.nn.MSELoss()
-    rmse_intensity = mse_loss(generated, gt)
-    
-    psnr_intensity = torch.log10(rmse_intensity)
-    # print("psnr:",psnr_intensity)
-    # ssim_intensity = ssim(generated, gt)
-    labels_dx, labels_dy = calImageGradients(gt)
-    # print("generated:",generated.shape)
-    preds_dx, preds_dy = calImageGradients(generated)
-    rmse_grad_x, rmse_grad_y = mse_loss(labels_dx, preds_dx), mse_loss(labels_dy, preds_dy)
-    psnr_grad_x, psnr_grad_y = torch.log10(rmse_grad_x), torch.log10(rmse_grad_y)
-    # print("psnr x&y:",psnr_grad_x," ",psnr_grad_y)
-    p_loss = perc_loss(generated,gt)
-    # print("-psnr:",-psnr_intensity,",0.5*(psnr_grad_x + psnr_grad_y):",0.5*(psnr_grad_x + psnr_grad_y),",perc_loss:",p_loss)
-    total_loss = 10 + psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y) + p_loss
-    # total_loss = rmse_intensity + 0.5*(rmse_grad_x + rmse_grad_y) # + p_loss
-    return total_loss
+from weight_init import weight_init_normal
 
 def save_checkpoints(file_path, epoch_idx, model, model_solver):
     print('[INFO] Saving checkpoint to %s ...' % ( file_path))
@@ -336,93 +171,112 @@ def save_checkpoints(file_path, epoch_idx, model, model_solver):
     }
     torch.save(checkpoint, file_path)
 
-mode = "train"
-
-import pickle
-def save_obj(obj, name ):
-    # with open('./outputF/dict/'+ name + '.pkl', 'wb') as f:
-    #     pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
-    torch.save(obj,'./outputF/dict/'+ name + '.pkl')
-def load_obj(name):
-    # with open('./outputF/dict/' + name + '.pkl', 'rb') as f:
-    #     return pickle.load(f)
-    return torch.load('./outputF/dict/'+ name + '.pkl')
-
-def hook_fn_back(m, i, o):
-  for grad in i:
-    try:
-      print("Input Grad:",m,grad.shape,grad.sum())
-    except AttributeError: 
-      print ("None found for Gradient")
-  for grad in o:  
-    try:
-      print("Output Grad:",m,grad.shape,grad.sum())
-    except AttributeError: 
-      print ("None found for Gradient")
-  print("\n")
-
-def hook_fn_for(m, i, o):
-  for grad in i:
-    try:
-      print("Input Feats:",m,grad.shape,grad.sum())
-    except AttributeError: 
-      print ("None found for Gradient")
-  for grad in o:  
-    try:
-      print("Output Feats:",m,grad.shape,grad.sum())
-    except AttributeError: 
-      print ("None found for Gradient")
-  print("\n")
-
-def generatePhiMaskDict(data_json, generator):
-    phi_dict = {}
-    mask_dict = {}
-    idx_info_dict = {}
-    with open(data_json, encoding='utf-8') as file:
-        dataset_desc = json.loads(file.read())
-        for i in range(len(dataset_desc["focaldepth"])):
-            # if i == 2:
-            #     break
-            idx = dataset_desc["idx"][i] 
-            focaldepth = dataset_desc["focaldepth"][i]
-            gazeX = dataset_desc["gazeX"][i]
-            gazeY = dataset_desc["gazeY"][i]
-            print("focaldepth:",focaldepth," idx:",idx," gazeX:",gazeX," gazeY:",gazeY)
-            phi,mask =  generator.CalculateRetinal2LayerMappings(focaldepth,torch.tensor([gazeX, gazeY]))
-            phi_dict[idx]=phi
-            mask_dict[idx]=mask
-            idx_info_dict[idx]=[idx,focaldepth,gazeX,gazeY]
-    return phi_dict,mask_dict,idx_info_dict
-
+# import pickle
+# def save_obj(obj, name ):
+#     # with open('./outputF/dict/'+ name + '.pkl', 'wb') as f:
+#     #     pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
+#     torch.save(obj,'./outputF/dict/'+ name + '.pkl')
+# def load_obj(name):
+#     # with open('./outputF/dict/' + name + '.pkl', 'rb') as f:
+#     #     return pickle.load(f)
+#     return torch.load('./outputF/dict/'+ name + '.pkl')
+
+# def generatePhiMaskDict(data_json, generator):
+#     phi_dict = {}
+#     mask_dict = {}
+#     idx_info_dict = {}
+#     with open(data_json, encoding='utf-8') as file:
+#         dataset_desc = json.loads(file.read())
+#         for i in range(len(dataset_desc["focaldepth"])):
+#             # if i == 2:
+#             #     break
+#             idx = dataset_desc["idx"][i] 
+#             focaldepth = dataset_desc["focaldepth"][i]
+#             gazeX = dataset_desc["gazeX"][i]
+#             gazeY = dataset_desc["gazeY"][i]
+#             print("focaldepth:",focaldepth," idx:",idx," gazeX:",gazeX," gazeY:",gazeY)
+#             phi,mask =  generator.CalculateRetinal2LayerMappings(focaldepth,torch.tensor([gazeX, gazeY]))
+#             phi_dict[idx]=phi
+#             mask_dict[idx]=mask
+#             idx_info_dict[idx]=[idx,focaldepth,gazeX,gazeY]
+#     return phi_dict,mask_dict,idx_info_dict
+
+# def generatePhiMaskDictNew(data_json, generator):
+#     phi_dict = {}
+#     mask_dict = {}
+#     idx_info_dict = {}
+#     with open(data_json, encoding='utf-8') as file:
+#         dataset_desc = json.loads(file.read())
+#         for i in range(len(dataset_desc["seq"])):
+#             for j in dataset_desc["seq"][i]:
+#                 idx = dataset_desc["idx"][j] 
+#                 focaldepth = dataset_desc["focaldepth"][j]
+#                 gazeX = dataset_desc["gazeX"][j]
+#                 gazeY = dataset_desc["gazeY"][j]
+#                 print("focaldepth:",focaldepth," idx:",idx," gazeX:",gazeX," gazeY:",gazeY)
+#                 phi,mask =  generator.CalculateRetinal2LayerMappings(focaldepth,torch.tensor([gazeX, gazeY]))
+#                 phi_dict[idx]=phi
+#                 mask_dict[idx]=mask
+#                 idx_info_dict[idx]=[idx,focaldepth,gazeX,gazeY]
+#     return phi_dict,mask_dict,idx_info_dict
+
+mode = "Silence" #"Perf"
+model_type = "RNN" #"RNN"
+w_frame = 0.9
+w_inter_frame = 0.1
+batch_model = "NoSingle"
+loss1 = ReconstructionLoss()
+loss2 = ReconstructionLoss()
 if __name__ == "__main__":
     ############################## generate phi and mask in pre-training
     
     # print("generating phi and mask...")
-    # phi_dict,mask_dict,idx_info_dict = generatePhiMaskDict(DATA_JSON,gen)
-    # save_obj(phi_dict,"phi_1204")
-    # save_obj(mask_dict,"mask_1204")
-    # save_obj(idx_info_dict,"idx_info_1204")
+    # phi_dict,mask_dict,idx_info_dict = generatePhiMaskDictNew(DATA_JSON,gen)
+    # # save_obj(phi_dict,"phi_1204")
+    # # save_obj(mask_dict,"mask_1204")
+    # # save_obj(idx_info_dict,"idx_info_1204")
     # print("generating phi and mask end.")
     # exit(0)
     ############################# load phi and mask in pre-training
-    print("loading phi and mask ...")
-    phi_dict = load_obj("phi_1204")
-    mask_dict = load_obj("mask_1204")
-    idx_info_dict = load_obj("idx_info_1204")
-    print(len(phi_dict))
-    print(len(mask_dict))
-    print("loading phi and mask end") 
+    # print("loading phi and mask ...")
+    # phi_dict = load_obj("phi_1204")
+    # mask_dict = load_obj("mask_1204")
+    # idx_info_dict = load_obj("idx_info_1204")
+    # print(len(phi_dict))
+    # print(len(mask_dict))
+    # print("loading phi and mask end") 
+
+    #### Image Gen and conf 
+    conf = Conf()
+    gen = RetinalGen(conf)
 
     #train
-    train_data_loader = torch.utils.data.DataLoader(dataset=lightFieldSeqDataLoader(DATA_FILE,DATA_JSON),
+    train_data_loader = torch.utils.data.DataLoader(dataset=lightFieldFlowSeqDataLoader(DATA_FILE,DATA_JSON, gen, conf),
                                                     batch_size=BATCH_SIZE,
-                                                    num_workers=0,
+                                                    num_workers=8,
                                                     pin_memory=True,
                                                     shuffle=True,
                                                     drop_last=False)
+    #Data loader test
     print(len(train_data_loader))
 
-    # exit(0)
+    # # lightfield_images, gt, flow, fd, gazeX, gazeY, posX, posY, sample_idx, phi, mask
+    # # lightfield_images, gt, phi, phi_invalid, retinal_invalid, flow, fd, gazeX, gazeY, posX, posY, posZ, sample_idx
+    # for batch_idx, (image_set, gt,phi, phi_invalid, retinal_invalid, flow, df, gazeX, gazeY, posX, posY, posZ, sample_idx) in enumerate(train_data_loader):
+    #     print(image_set.shape,type(image_set))
+    #     print(gt.shape,type(gt))
+    #     print(phi.shape,type(phi))
+    #     print(phi_invalid.shape,type(phi_invalid))
+    #     print(retinal_invalid.shape,type(retinal_invalid))
+    #     print(flow.shape,type(flow))
+    #     print(df.shape,type(df))
+    #     print(gazeX.shape,type(gazeX))
+    #     print(posX.shape,type(posX))
+    #     print(sample_idx.shape,type(sample_idx))
+    #     print("test train dataloader.")
+    #     exit(0)
+    #Data loader test end
+    
 
 
     ################################################ val #########################################################
@@ -464,21 +318,24 @@ if __name__ == "__main__":
             
     #         print("output:",output.shape," df:",df[0].data, ",gazeX:",gazeX[0].data,",gazeY:", gazeY[0].data)
     #         for i in range(output1.size()[0]):
-    #             save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_fac1_o_%.3f_%.3f_%.3f.png"%(df[i].data,gazeX[i].data,gazeY[i].data)))
-    #             save_image(output[i][3:6].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_fac2_o_%.3f_%.3f_%.3f.png"%(df[i].data,gazeX[i].data,gazeY[i].data)))
-    #             save_image(output1[i][0:3].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_out1_o_%.3f_%.3f_%.3f.png"%(df[i].data,gazeX[i].data,gazeY[i].data)))
-    #             save_image(output1[i][3:6].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_out2_o_%.3f_%.3f_%.3f.png"%(df[i].data,gazeX[i].data,gazeY[i].data)))
+    #             save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_fac1_o_%.5f_%.5f_%.5f.png"%(df[i].data,gazeX[i].data,gazeY[i].data)))
+    #             save_image(output[i][3:6].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_fac2_o_%.5f_%.5f_%.5f.png"%(df[i].data,gazeX[i].data,gazeY[i].data)))
+    #             save_image(output1[i][0:3].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_out1_o_%.5f_%.5f_%.5f.png"%(df[i].data,gazeX[i].data,gazeY[i].data)))
+    #             save_image(output1[i][3:6].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_out2_o_%.5f_%.5f_%.5f.png"%(df[i].data,gazeX[i].data,gazeY[i].data)))
 
-    #         # save_image(output[0][0:3].data,os.path.join(OUTPUT_DIR,"gaze_fovea_interp_l1_%.3f.png"%(df[0].data)))
-    #         # save_image(output[0][3:6].data,os.path.join(OUTPUT_DIR,"gaze_fovea_interp_l2_%.3f.png"%(df[0].data)))
+    #         # save_image(output[0][0:3].data,os.path.join(OUTPUT_DIR,"gaze_fovea_interp_l1_%.5f.png"%(df[0].data)))
+    #         # save_image(output[0][3:6].data,os.path.join(OUTPUT_DIR,"gaze_fovea_interp_l2_%.5f.png"%(df[0].data)))
     #         # output = GenRetinalFromLayersBatch(output,conf,df,v,u)
-    #         # save_image(output[0][0:3].data,os.path.join(OUTPUT_DIR,"1113_interp_o%.3f.png"%(df[0].data)))
+    #         # save_image(output[0][0:3].data,os.path.join(OUTPUT_DIR,"1113_interp_o%.5f.png"%(df[0].data)))
     # exit()
 
     ################################################ train #########################################################
-    lf_model = model(FIRSST_LAYER_CHANNELS,LAST_LAYER_CHANNELS,OUT_CHANNELS_RB,KERNEL_SIZE,KERNEL_SIZE_RB,INTERLEAVE_RATE)
+    if model_type == "RNN": 
+        lf_model = model(FIRSST_LAYER_CHANNELS,LAST_LAYER_CHANNELS,OUT_CHANNELS_RB,KERNEL_SIZE,KERNEL_SIZE_RB,INTERLEAVE_RATE)
+    else:
+        lf_model = model(FIRSST_LAYER_CHANNELS,LAST_LAYER_CHANNELS,OUT_CHANNELS_RB,KERNEL_SIZE,KERNEL_SIZE_RB,INTERLEAVE_RATE,RNN=False)
     lf_model.apply(weight_init_normal)
-
+    lf_model.train()
     epoch_begin = 0
 
     ################################ load model file
@@ -493,144 +350,243 @@ if __name__ == "__main__":
 
     if torch.cuda.is_available():
         # lf_model = torch.nn.DataParallel(lf_model).cuda()
-        lf_model = lf_model.to('cuda:1')
-    lf_model.train()
-    optimizer = torch.optim.Adam(lf_model.parameters(),lr=1e-2,betas=(0.9,0.999))
-    l1loss = torch.nn.L1Loss()
+        lf_model = lf_model.to('cuda:2')
+    
+    optimizer = torch.optim.Adam(lf_model.parameters(),lr=5e-3,betas=(0.9,0.999))
+    
     # lf_model.output_layer.register_backward_hook(hook_fn_back)
+    if mode=="Perf":
+        start = torch.cuda.Event(enable_timing=True)
+        end = torch.cuda.Event(enable_timing=True)
+        start.record()
     print("begin training....")
     for epoch in range(epoch_begin, NUM_EPOCH):
-        for batch_idx, (image_set, gt, gt2, df, gazeX, gazeY, sample_idx) in enumerate(train_data_loader):
+        for batch_idx, (image_set, gt,phi, phi_invalid, retinal_invalid, flow, flow_invalid_mask, df, gazeX, gazeY, posX, posY, posZ, sample_idx) in enumerate(train_data_loader):
             # print(sample_idx.shape,df.shape,gazeX.shape,gazeY.shape) # torch.Size([2, 5])
             # print(image_set.shape,gt.shape,gt2.shape) #torch.Size([2, 5, 9, 320, 320, 3]) torch.Size([2, 5, 160, 160, 3]) torch.Size([2, 5, 160, 160, 3])
             # print(delta.shape) # delta: torch.Size([2, 4, 160, 160, 3])
-            
+            if mode=="Perf":
+                end.record()
+                torch.cuda.synchronize()
+                print("load:",start.elapsed_time(end))
+
+                start.record()
             #reshape for input
-            image_set = image_set.permute(0,1,2,5,3,4) # N S LF C H W
+            image_set = image_set.permute(0,1,2,5,3,4) # N Seq 5 LF 25 C 3 H W 
             image_set = image_set.reshape(image_set.shape[0],image_set.shape[1],-1,image_set.shape[4],image_set.shape[5]) # N, LFxC, H, W
+            # N, Seq 5, LF 25 C 3, H, W 
             image_set = var_or_cuda(image_set)
-            gt = gt.permute(0,1,4,2,3) # N S C H W
+            # print(image_set.shape) #torch.Size([2, 5, 75, 320, 320])
+
+            gt = gt.permute(0,1,4,2,3) # BS Seq 5 C 3 H W 
             gt = var_or_cuda(gt)
 
-            gt2 = gt2.permute(0,1,4,2,3)
-            gt2 = var_or_cuda(gt2)
+            flow = var_or_cuda(flow) #BS,Seq-1,H,W,2
+
+            # gt2 = gt2.permute(0,1,4,2,3)
+            # gt2 = var_or_cuda(gt2)
 
-            gen1 = torch.empty(gt.shape)
+            gen1 = torch.empty(gt.shape) # BS Seq C H W
             gen1 = var_or_cuda(gen1)
+            # print(gen1.shape) #torch.Size([2, 5, 3, 320, 320])
 
-            gen2 = torch.empty(gt2.shape)
-            gen2 = var_or_cuda(gen2)
+            # gen2 = torch.empty(gt2.shape)
+            # gen2 = var_or_cuda(gen2)
 
-            warped = torch.empty(gt2.shape[0],gt2.shape[1]-1,gt2.shape[2],gt2.shape[3],gt2.shape[4])
+            #BS, Seq - 1, C, H, W
+            warped = torch.empty(gt.shape[0],gt.shape[1]-1,gt.shape[2],gt.shape[3],gt.shape[4])
             warped = var_or_cuda(warped)
+            gen_temp = torch.empty(warped.shape)
+            gen_temp = var_or_cuda(gen_temp)
+            # print("warped:",warped.shape) #warped: torch.Size([2, 4, 3, 320, 320])
+            if mode=="Perf":
+                end.record()
+                torch.cuda.synchronize()
+                print("data prepare:",start.elapsed_time(end))
+
+                start.record()
+            if model_type == "RNN": 
+                if batch_model != "Single":
+                    for k in range(image_set.shape[1]):
+                        if k == 0:
+                            lf_model.reset_hidden(image_set[:,k])
+                        output = lf_model(image_set[:,k],df[:,k],gazeX[:,k],gazeY[:,k],posX[:,k],posY[:,k],posZ[:,k]) # batchsize, layer_num x 2 = 6, layer_res: 320, layer_res: 320
+                        output1 = MergeBatchSpeed(output, gen, phi[:,k], phi_invalid[:,k], retinal_invalid[:,k])
+                        gen1[:,k] = output1[:,0:3]
+                        gt[:,k] = gt[:,k].mul_(~retinal_invalid[:,k].to("cuda:2"))
+
+                        if ((epoch%10 == 0 and epoch != 0) or epoch == 2):
+                            for i in range(output.shape[0]):
+                                save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"gaze_fac1_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data,posX[i][k].data,posY[i][k].data)))
+                                save_image(output[i][3:6].data,os.path.join(OUTPUT_DIR,"gaze_fac2_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data,posX[i][k].data,posY[i][k].data)))
+                    
+                    for i in range(1,gt.shape[1]):
+                        # print(flow_invalid_mask.shape) #torch.Size([2, 4, 320, 320])
+                        # print(FlowMap(gen1[:,i-1],flow[:,i-1]).shape) #torch.Size([2, 3, 320, 320])
+                        warped[:,i-1] = FlowMap(gen1[:,i-1],flow[:,i-1]).mul(~flow_invalid_mask[:,i-1].unsqueeze(1).to("cuda:2"))
+                        gen_temp[:,i-1] = gen1[:,i].mul(~flow_invalid_mask[:,i-1].unsqueeze(1).to("cuda:2"))
+                else:
+                    for k in range(image_set.shape[1]):
+                        if k == 0:
+                            lf_model.reset_hidden(image_set[:,k])
+                        output = lf_model(image_set[:,k],df[:,k],gazeX[:,k],gazeY[:,k],posX[:,k],posY[:,k],posZ[:,k]) # batchsize, layer_num x 2 = 6, layer_res: 320, layer_res: 320
+                        output1 = MergeBatchSpeed(output, gen, phi[:,k], phi_invalid[:,k], retinal_invalid[:,k])
+                        gen1[:,k] = output1[:,0:3]
+                        gt[:,k] = gt[:,k].mul_(~retinal_invalid[:,k].to("cuda:2"))
+
+                        if k != image_set.shape[1]-1:
+                            warped[:,k] = FlowMap(output1.detach(),flow[:,k]).mul(~flow_invalid_mask[:,k].unsqueeze(1).to("cuda:2"))
+                        loss1 = loss_without_perc(output1,gt[:,k])
+                        loss = (w_frame * loss1)
+                        if k==0:
+                            loss.backward(retain_graph=False)
+                            optimizer.step()
+                            lf_model.zero_grad()
+                            optimizer.zero_grad()
+                            print("Epoch:",epoch,",Iter:",batch_idx,",Seq:",k,",loss:",loss.item())
+                        else:
+                            output1mask = output1.mul(~flow_invalid_mask[:,k-1].unsqueeze(1).to("cuda:2"))
+                            loss2 = l1loss(output1mask,warped[:,k-1])
+                            loss += (w_inter_frame * loss2)
+                            loss.backward(retain_graph=False)
+                            optimizer.step()
+                            lf_model.zero_grad()
+                            optimizer.zero_grad()
+                            # print("Epoch:",epoch,",Iter:",batch_idx,",Seq:",k,",loss:",loss.item())
+                            print("Epoch:",epoch,",Iter:",batch_idx,",Seq:",k,",frame loss:",loss1.item(),",inter loss:",w_inter_frame * loss2.item())
+
+                        
+                        if ((epoch%10 == 0 and epoch != 0) or epoch == 2):
+                            for i in range(output.shape[0]):
+                                save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"gaze_fac1_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data,posX[i][k].data,posY[i][k].data)))
+                                save_image(output[i][3:6].data,os.path.join(OUTPUT_DIR,"gaze_fac2_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data,posX[i][k].data,posY[i][k].data)))
+                    
+                    # BSxSeq C H W
+                    gen1 = gen1.reshape(-1,gen1.shape[2],gen1.shape[3],gen1.shape[4])
+                    gt = gt.reshape(-1,gt.shape[2],gt.shape[3],gt.shape[4])
+                    
+                    if ((epoch%10== 0 and epoch != 0) or epoch == 2): # torch.Size([2, 5, 160, 160, 3])
+                        for i in range(gt.size()[0]):
+                            save_image(gen1[i].data,os.path.join(OUTPUT_DIR,"gaze_out1_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data,posX[i//5][i%5].data,posY[i//5][i%5].data)))
+                            save_image(gt[i].data,os.path.join(OUTPUT_DIR,"gaze_test1_gt0_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data,posX[i//5][i%5].data,posY[i//5][i%5].data)))
+                    if ((epoch%100 == 0) and epoch != 0 and batch_idx==len(train_data_loader)-1):
+                        save_checkpoints(os.path.join(OUTPUT_DIR, 'gaze-ckpt-epoch-%04d.pth' % (epoch)),epoch,lf_model,optimizer)
+            else:
+                if batch_model != "Single":
+                    for k in range(image_set.shape[1]):
+                        output = lf_model(image_set[:,k],df[:,k],gazeX[:,k],gazeY[:,k],posX[:,k],posY[:,k],posZ[:,k]) # batchsize, layer_num x 2 = 6, layer_res: 320, layer_res: 320
+                        output1 = MergeBatchSpeed(output, gen, phi[:,k], phi_invalid[:,k], retinal_invalid[:,k])
+                        gen1[:,k] = output1[:,0:3]
+                        gt[:,k] = gt[:,k].mul_(~retinal_invalid[:,k].to("cuda:2"))
+                        
+                        if ((epoch%10 == 0 and epoch != 0) or epoch == 2):
+                            for i in range(output.shape[0]):
+                                save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"gaze_fac1_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data,posX[i][k].data,posY[i][k].data)))
+                                save_image(output[i][3:6].data,os.path.join(OUTPUT_DIR,"gaze_fac2_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data,posX[i][k].data,posY[i][k].data)))
+                    for i in range(1,gt.shape[1]):
+                        # print(flow_invalid_mask.shape) #torch.Size([2, 4, 320, 320])
+                        # print(FlowMap(gen1[:,i-1],flow[:,i-1]).shape) #torch.Size([2, 3, 320, 320])
+                        warped[:,i-1] = FlowMap(gen1[:,i-1],flow[:,i-1]).mul(~flow_invalid_mask[:,i-1].unsqueeze(1).to("cuda:2"))
+                        gen_temp[:,i-1] = gen1[:,i].mul(~flow_invalid_mask[:,i-1].unsqueeze(1).to("cuda:2"))
+                else:
+                    for k in range(image_set.shape[1]):
+                        output = lf_model(image_set[:,k],df[:,k],gazeX[:,k],gazeY[:,k],posX[:,k],posY[:,k],posZ[:,k]) # batchsize, layer_num x 2 = 6, layer_res: 320, layer_res: 320
+                        output1 = MergeBatchSpeed(output, gen, phi[:,k], phi_invalid[:,k], retinal_invalid[:,k])
+                        gen1[:,k] = output1[:,0:3]
+                        gt[:,k] = gt[:,k].mul_(~retinal_invalid[:,k].to("cuda:2"))
+                        
+                        # print(output1.shape) #torch.Size([BS, 3, 320, 320])
+                        loss1 = loss_without_perc(output1,gt[:,k])
+                        loss = (w_frame * loss1)
+                        # print("loss:",loss1.item())
+                        loss.backward(retain_graph=False)
+                        optimizer.step()
+                        lf_model.zero_grad()
+                        optimizer.zero_grad()
+
+                        print("Epoch:",epoch,",Iter:",batch_idx,",Seq:",k,",loss:",loss.item())
+                        if ((epoch%10 == 0 and epoch != 0) or epoch == 2):
+                            for i in range(output.shape[0]):
+                                save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"gaze_fac1_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data,posX[i][k].data,posY[i][k].data)))
+                                save_image(output[i][3:6].data,os.path.join(OUTPUT_DIR,"gaze_fac2_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data,posX[i][k].data,posY[i][k].data)))
+                    for i in range(1,gt.shape[1]):
+                        warped[:,i-1] = FlowMap(gen1[:,i-1],flow[:,i-1]).mul(~flow_invalid_mask[:,i-1].unsqueeze(1).to("cuda:2"))
+                        gen_temp[:,i-1] = gen1[:,i].mul(~flow_invalid_mask[:,i-1].unsqueeze(1).to("cuda:2"))
+
+                    warped = warped.reshape(-1,warped.shape[2],warped.shape[3],warped.shape[4])
+                    gen_temp = gen_temp.reshape(-1,gen_temp.shape[2],gen_temp.shape[3],gen_temp.shape[4])
+                    loss3 = l1loss(warped,gen_temp)
+                    print("Epoch:",epoch,",Iter:",batch_idx,",inter-frame loss:",w_inter_frame *loss3.item())
+                    
+                    # BSxSeq C H W
+                    gen1 = gen1.reshape(-1,gen1.shape[2],gen1.shape[3],gen1.shape[4])
+                    gt = gt.reshape(-1,gt.shape[2],gt.shape[3],gt.shape[4])
+                    
+                    if ((epoch%10== 0 and epoch != 0) or epoch == 2): # torch.Size([2, 5, 160, 160, 3])
+                        for i in range(gt.size()[0]):
+                            save_image(gen1[i].data,os.path.join(OUTPUT_DIR,"gaze_out1_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data,posX[i//5][i%5].data,posY[i//5][i%5].data)))
+                            save_image(gt[i].data,os.path.join(OUTPUT_DIR,"gaze_test1_gt0_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data,posX[i//5][i%5].data,posY[i//5][i%5].data)))
+                    if ((epoch%100 == 0) and epoch != 0 and batch_idx==len(train_data_loader)-1):
+                        save_checkpoints(os.path.join(OUTPUT_DIR, 'gaze-ckpt-epoch-%04d.pth' % (epoch)),epoch,lf_model,optimizer)
+
+
+            if batch_model != "Single":
+                if mode=="Perf":
+                    end.record()
+                    torch.cuda.synchronize()
+                    print("forward:",start.elapsed_time(end))
+
+                    start.record()
+                optimizer.zero_grad()
 
-            delta = torch.empty(gt2.shape[0],gt2.shape[1]-1,gt2.shape[2],gt2.shape[3],gt2.shape[4])
-            delta = var_or_cuda(delta)
-            
-            for k in range(image_set.shape[1]):
-                if k == 0:
-                    lf_model.reset_hidden(image_set[:,k])
                 
-                # start = torch.cuda.Event(enable_timing=True)
-                # end = torch.cuda.Event(enable_timing=True)
-                # start.record()
-                output = lf_model(image_set[:,k],df[:,k],gazeX[:,k],gazeY[:,k])
-                # end.record()
-                # torch.cuda.synchronize()
-                # print("Model Forward:",start.elapsed_time(end))
-                # print("output:",output.shape) # [2, 6, 320, 320]
-                # exit()
-                ########################### Use Pregen Phi and Mask ###################
-                # start.record()
-                output1,mask = GenRetinalGazeFromLayersBatch(output, gen, sample_idx[:,k], phi_dict, mask_dict)
-                # end.record()
-                # torch.cuda.synchronize()
-                # print("Merge:",start.elapsed_time(end))
-
-                # print("output1 shape:",output1.shape, "mask shape:",mask.shape)
-                # output1 shape: torch.Size([2, 6, 160, 160]) mask shape: torch.Size([2, 2, 160, 160])
-                for i in range(0, 2):
-                    output1[:,i*3:i*3+3].mul_(mask[:,i:i+1])
-                    if i == 0:
-                        gt[:,k].mul_(mask[:,i:i+1])
-                    if i == 1:
-                        gt2[:,k].mul_(mask[:,i:i+1])
+                # BSxSeq C H W
+                gen1 = gen1.reshape(-1,gen1.shape[2],gen1.shape[3],gen1.shape[4])
+                # gen2 = gen2.reshape(-1,gen2.shape[2],gen2.shape[3],gen2.shape[4])
+                # BSxSeq C H W
+                gt = gt.reshape(-1,gt.shape[2],gt.shape[3],gt.shape[4])
+                # gt2 = gt2.reshape(-1,gt2.shape[2],gt2.shape[3],gt2.shape[4])
+
+                # BSx(Seq-1) C H W
+                warped = warped.reshape(-1,warped.shape[2],warped.shape[3],warped.shape[4])
+                gen_temp = gen_temp.reshape(-1,gen_temp.shape[2],gen_temp.shape[3],gen_temp.shape[4])
                 
-                gen1[:,k] = output1[:,0:3]
-                gen2[:,k] = output1[:,3:6]
-                if ((epoch%5== 0) or epoch == 2):
-                    for i in range(output.shape[0]):
-                        save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"gaze_fac1_o_%.3f_%.3f_%.3f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data)))
-                        save_image(output[i][3:6].data,os.path.join(OUTPUT_DIR,"gaze_fac2_o_%.3f_%.3f_%.3f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data)))
-
-            ########################### Update ###################
-            for i in range(1,image_set.shape[1]):
-                delta[:,i-1] = gt2[:,i] - gt2[:,i]
-                warped[:,i-1] = gen2[:,i]-gen2[:,i-1]
-            
-            optimizer.zero_grad()
-
-            # # N S C H W
-            gen1 = gen1.reshape(-1,gen1.shape[2],gen1.shape[3],gen1.shape[4])
-            gen2 = gen2.reshape(-1,gen2.shape[2],gen2.shape[3],gen2.shape[4])
-            gt = gt.reshape(-1,gt.shape[2],gt.shape[3],gt.shape[4])
-            gt2 = gt2.reshape(-1,gt2.shape[2],gt2.shape[3],gt2.shape[4])
-            warped = warped.reshape(-1,warped.shape[2],warped.shape[3],warped.shape[4])
-            delta = delta.reshape(-1,delta.shape[2],delta.shape[3],delta.shape[4])
-
-
-            # start = torch.cuda.Event(enable_timing=True)
-            # end = torch.cuda.Event(enable_timing=True)
-            # start.record()
-            loss1 = loss_new(gen1,gt)
-            loss2 = loss_new(gen2,gt2)
-            loss3 = l1loss(warped,delta)
-            loss = loss1+loss2+loss3
-            # end.record()
-            # torch.cuda.synchronize()
-            # print("loss comp:",start.elapsed_time(end))
-
-            
-            # start.record()
-            loss.backward()
-            # end.record()
-            # torch.cuda.synchronize()
-            # print("backward:",start.elapsed_time(end))
-
-            # start.record()
-            optimizer.step()
-            # end.record()
-            # torch.cuda.synchronize()
-            # print("optimizer step:",start.elapsed_time(end))
-            
-            ## Update Prev
-            print("Epoch:",epoch,",Iter:",batch_idx,",loss:",loss)
-            ########################### Save #####################
-            if ((epoch%5== 0) or epoch == 2): # torch.Size([2, 5, 160, 160, 3])
-                for i in range(gt.size()[0]):
-                    # df 2,5 
-                    save_image(gen1[i].data,os.path.join(OUTPUT_DIR,"gaze_out1_o_%.3f_%.3f_%.3f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data)))
-                    save_image(gen2[i].data,os.path.join(OUTPUT_DIR,"gaze_out2_o_%.3f_%.3f_%.3f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data)))
-                    save_image(gt[i].data,os.path.join(OUTPUT_DIR,"gaze_test1_gt0_%.3f_%.3f_%.3f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data)))
-                    save_image(gt2[i].data,os.path.join(OUTPUT_DIR,"gaze_test1_gt1_%.3f_%.3f_%.3f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data)))
-            if ((epoch%100 == 0) and epoch != 0 and batch_idx==len(train_data_loader)-1):
-                save_checkpoints(os.path.join(OUTPUT_DIR, 'gaze-ckpt-epoch-%04d.pth' % (epoch + 1)),epoch,lf_model,optimizer)
-
-            ########################## test Phi and Mask ##########################
-            # phi,mask =  gen.CalculateRetinal2LayerMappings(df[0],torch.tensor([gazeX[0], gazeY[0]]))
-            # # print("gaze Online:",gazeX[0]," ,",gazeY[0])
-            # # print("df Online:",df[0])
-            # # print("idx:",int(sample_idx[0].data))
-            # phi_t = phi_dict[int(sample_idx[0].data)]
-            # mask_t = mask_dict[int(sample_idx[0].data)]
-            # # print("idx info:",idx_info_dict[int(sample_idx[0].data)])
-            # # print("phi online:", phi.shape, " phi_t:", phi_t.shape)
-            # # print("mask online:", mask.shape, " mask_t:", mask_t.shape)
-            # print("phi delta:", (phi-phi_t).sum()," mask delta:",(mask -mask_t).sum())
-            # exit(0)
-
-            ###########################Gen Batch 1 by 1###################
-            # phi,mask =  gen.CalculateRetinal2LayerMappings(df[0],torch.tensor([gazeX[0], gazeY[0]]))
-            # # print(phi.shape) # 2,240,320,41,2
-            # output1, mask = GenRetinalFromLayersBatch_Online(output, gen, phi, mask)
-            ###########################Gen Batch 1 by 1###################
\ No newline at end of file
+                loss1_value = loss1(gen1,gt)
+                loss2_value = loss2(warped,gen_temp)
+                if model_type == "RNN": 
+                    loss = (w_frame * loss1_value)+ (w_inter_frame * loss2_value)
+                else:
+                    loss = (w_frame * loss1_value)
+
+                if mode=="Perf":
+                    end.record()
+                    torch.cuda.synchronize()
+                    print("compute loss:",start.elapsed_time(end))
+
+                    start.record()
+                loss.backward()
+                if mode=="Perf":
+                    end.record()
+                    torch.cuda.synchronize()
+                    print("backward:",start.elapsed_time(end))
+
+                    start.record()
+                optimizer.step()
+                if mode=="Perf":
+                    end.record()
+                    torch.cuda.synchronize()
+                    print("update:",start.elapsed_time(end))
+
+                print("Epoch:",epoch,",Iter:",batch_idx,",loss:",loss.item(),",frame loss:",loss1_value.item(),",inter-frame loss:",loss2_value.item())
+                
+                # exit(0)
+                ########################### Save #####################
+                if ((epoch%10== 0 and epoch != 0) or epoch == 2): # torch.Size([2, 5, 160, 160, 3])
+                    for i in range(gt.size()[0]):
+                        # df 2,5 
+                        save_image(gen1[i].data,os.path.join(OUTPUT_DIR,"gaze_out1_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data,posX[i//5][i%5].data,posY[i//5][i%5].data)))
+                        # save_image(gen2[i].data,os.path.join(OUTPUT_DIR,"gaze_out2_o_%.5f_%.5f_%.5f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data)))
+                        save_image(gt[i].data,os.path.join(OUTPUT_DIR,"gaze_test1_gt0_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data,posX[i//5][i%5].data,posY[i//5][i%5].data)))
+                        # save_image(gt2[i].data,os.path.join(OUTPUT_DIR,"gaze_test1_gt1_%.5f_%.5f_%.5f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data)))
+                if ((epoch%100 == 0) and epoch != 0 and batch_idx==len(train_data_loader)-1):
+                    save_checkpoints(os.path.join(OUTPUT_DIR, 'gaze-ckpt-epoch-%04d.pth' % (epoch)),epoch,lf_model,optimizer)
\ No newline at end of file
diff --git a/main_lf_syn.py b/main_lf_syn.py
index f6de1e1..f84ab8f 100644
--- a/main_lf_syn.py
+++ b/main_lf_syn.py
@@ -31,7 +31,7 @@ M = 1 # number of display layers
 DATA_FILE = "/home/yejiannan/Project/LightField/data/lf_syn"
 DATA_JSON = "/home/yejiannan/Project/LightField/data/data_lf_syn_full.json"
 # DATA_VAL_JSON = "/home/yejiannan/Project/LightField/data/data_gaze_fovea_val.json"
-OUTPUT_DIR = "/home/yejiannan/Project/LightField/outputE/lf_syn_full_perc"
+OUTPUT_DIR = "/home/yejiannan/Project/LightField/outputE/lf_syn_full"
 OUT_CHANNELS_RB = 128
 KERNEL_SIZE_RB = 3
 KERNEL_SIZE = 3
@@ -50,7 +50,7 @@ def save_checkpoints(file_path, epoch_idx, model, model_solver):
     torch.save(checkpoint, file_path)
 
 mode = "Silence" #"Perf"
-w_frame = 0.9
+w_frame = 1.0
 loss1 = PerceptionReconstructionLoss()
 if __name__ == "__main__":
     #train
@@ -70,7 +70,7 @@ if __name__ == "__main__":
 
     if torch.cuda.is_available():
         # lf_model = torch.nn.DataParallel(lf_model).cuda()
-        lf_model = lf_model.to('cuda:3')
+        lf_model = lf_model.to('cuda:1')
     
     optimizer = torch.optim.Adam(lf_model.parameters(),lr=5e-3,betas=(0.9,0.999))
     
diff --git a/model/__init__.py b/model/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/model/recurrent.py b/model/recurrent.py
deleted file mode 100644
index 3aeaf6a..0000000
--- a/model/recurrent.py
+++ /dev/null
@@ -1,146 +0,0 @@
-import torch, os, sys, cv2
-import torch.nn as nn
-from torch.nn import init
-import functools
-import torch.optim as optim
-
-from torch.utils.data import Dataset, DataLoader
-from torch.nn import functional as func
-from PIL import Image
-
-import torchvision.transforms as transforms
-import numpy as np 
-import torch
-
-
-class RecurrentBlock(nn.Module):
-
-	def __init__(self, input_nc, output_nc, downsampling=False, bottleneck=False, upsampling=False):
-		super(RecurrentBlock, self).__init__()
-
-		self.input_nc = input_nc
-		self.output_nc = output_nc
-
-		self.downsampling = downsampling
-		self.upsampling = upsampling
-		self.bottleneck = bottleneck
-
-		self.hidden = None
-
-		if self.downsampling:
-			self.l1 = nn.Sequential(
-					nn.Conv2d(input_nc, output_nc, 3, padding=1),
-					nn.LeakyReLU(negative_slope=0.1)
-				)
-			self.l2 = nn.Sequential(
-					nn.Conv2d(2 * output_nc, output_nc, 3, padding=1),
-					nn.LeakyReLU(negative_slope=0.1),
-					nn.Conv2d(output_nc, output_nc, 3, padding=1),
-					nn.LeakyReLU(negative_slope=0.1),
-				)
-		elif self.upsampling:
-			self.l1 = nn.Sequential(
-					nn.Upsample(scale_factor=2, mode='nearest'),
-					nn.Conv2d(2 * input_nc, output_nc, 3, padding=1),
-					nn.LeakyReLU(negative_slope=0.1),
-					nn.Conv2d(output_nc, output_nc, 3, padding=1),
-					nn.LeakyReLU(negative_slope=0.1),
-				)
-		elif self.bottleneck:
-			self.l1 = nn.Sequential(
-					nn.Conv2d(input_nc, output_nc, 3, padding=1),
-					nn.LeakyReLU(negative_slope=0.1)
-				)
-			self.l2 = nn.Sequential(
-					nn.Conv2d(2 * output_nc, output_nc, 3, padding=1),
-					nn.LeakyReLU(negative_slope=0.1),
-					nn.Conv2d(output_nc, output_nc, 3, padding=1),
-					nn.LeakyReLU(negative_slope=0.1),
-				)
-
-	def forward(self, inp):
-
-		if self.downsampling:
-			op1 = self.l1(inp)
-			op2 = self.l2(torch.cat((op1, self.hidden), dim=1))
-
-			self.hidden = op2
-
-			return op2
-		elif self.upsampling:
-			op1 = self.l1(inp)
-
-			return op1
-		elif self.bottleneck:
-			op1 = self.l1(inp)
-			op2 = self.l2(torch.cat((op1, self.hidden), dim=1))
-
-			self.hidden = op2
-
-			return op2
-
-	def reset_hidden(self, inp, dfac):
-		size = list(inp.size())
-		size[1] = self.output_nc
-		size[2] /= dfac
-		size[3] /= dfac
-
-		self.hidden_size = size
-		self.hidden = torch.zeros(*(size)).to('cuda:0')
-
-
-
-class RecurrentAE(nn.Module):
-
-	def __init__(self, input_nc):
-		super(RecurrentAE, self).__init__()
-
-		self.d1 = RecurrentBlock(input_nc=input_nc, output_nc=32, downsampling=True)
-		self.d2 = RecurrentBlock(input_nc=32, output_nc=43, downsampling=True)
-		self.d3 = RecurrentBlock(input_nc=43, output_nc=57, downsampling=True)
-		self.d4 = RecurrentBlock(input_nc=57, output_nc=76, downsampling=True)
-		self.d5 = RecurrentBlock(input_nc=76, output_nc=101, downsampling=True)
-
-		self.bottleneck = RecurrentBlock(input_nc=101, output_nc=101, bottleneck=True)
-
-		self.u5 = RecurrentBlock(input_nc=101, output_nc=76, upsampling=True)
-		self.u4 = RecurrentBlock(input_nc=76, output_nc=57, upsampling=True)
-		self.u3 = RecurrentBlock(input_nc=57, output_nc=43, upsampling=True)
-		self.u2 = RecurrentBlock(input_nc=43, output_nc=32, upsampling=True)
-		self.u1 = RecurrentBlock(input_nc=32, output_nc=3, upsampling=True)
-
-	def set_input(self, inp):
-		self.inp = inp['A']
-
-	def forward(self):
-		d1 = func.max_pool2d(input=self.d1(self.inp), kernel_size=2)
-		d2 = func.max_pool2d(input=self.d2(d1), kernel_size=2)
-		d3 = func.max_pool2d(input=self.d3(d2), kernel_size=2)
-		d4 = func.max_pool2d(input=self.d4(d3), kernel_size=2)
-		d5 = func.max_pool2d(input=self.d5(d4), kernel_size=2)
-
-		b = self.bottleneck(d5)
-
-		u5 = self.u5(torch.cat((b, d5), dim=1))
-		u4 = self.u4(torch.cat((u5, d4), dim=1))
-		u3 = self.u3(torch.cat((u4, d3), dim=1))
-		u2 = self.u2(torch.cat((u3, d2), dim=1))
-		u1 = self.u1(torch.cat((u2, d1), dim=1))
-
-		return u1
-
-	def reset_hidden(self):
-		self.d1.reset_hidden(self.inp, dfac=1)
-		self.d2.reset_hidden(self.inp, dfac=2)
-		self.d3.reset_hidden(self.inp, dfac=4)
-		self.d4.reset_hidden(self.inp, dfac=8)
-		self.d5.reset_hidden(self.inp, dfac=16)
-
-		self.bottleneck.reset_hidden(self.inp, dfac=32)
-
-		self.u4.reset_hidden(self.inp, dfac=16)
-		self.u3.reset_hidden(self.inp, dfac=8)
-		self.u5.reset_hidden(self.inp, dfac=4)
-		self.u2.reset_hidden(self.inp, dfac=2)
-		self.u1.reset_hidden(self.inp, dfac=1)
-
diff --git a/util.py b/util.py
new file mode 100644
index 0000000..7e6846b
--- /dev/null
+++ b/util.py
@@ -0,0 +1,104 @@
+import numpy as np
+import torch
+import matplotlib.pyplot as plt
+import glm
+
+gvec_type = [ glm.dvec1, glm.dvec2, glm.dvec3, glm.dvec4 ]
+gmat_type = [ [ glm.dmat2, glm.dmat2x3, glm.dmat2x4 ], 
+              [ glm.dmat3x2, glm.dmat3, glm.dmat3x4 ],
+              [ glm.dmat4x2, glm.dmat4x3, glm.dmat4 ] ]
+
+def Fov2Length(angle):
+    return np.tan(angle * np.pi / 360) * 2
+
+def SmoothStep(x0, x1, x):
+    y = torch.clamp((x - x0) / (x1 - x0), 0, 1)
+    return y * y * (3 - 2 * y)
+
+def MatImg2Tensor(img, permute = True, batch_dim = True):
+    batch_input = len(img.shape) == 4
+    if permute:
+        t = torch.from_numpy(np.transpose(img,
+            [0, 3, 1, 2] if batch_input else [2, 0, 1]))
+    else:
+        t = torch.from_numpy(img)
+    if not batch_input and batch_dim:
+        t = t.unsqueeze(0)
+    return t
+
+def MatImg2Numpy(img, permute = True, batch_dim = True):
+    batch_input = len(img.shape) == 4
+    if permute:
+        t = np.transpose(img, [0, 3, 1, 2] if batch_input else [2, 0, 1])
+    else:
+        t = img
+    if not batch_input and batch_dim:
+        t = t.unsqueeze(0)
+    return t
+    
+def Tensor2MatImg(t):
+    img = t.squeeze().cpu().numpy()
+    batch_input = len(img.shape) == 4
+    if t.size()[batch_input] <= 4:
+        return np.transpose(img, [0, 2, 3, 1] if batch_input else [1, 2, 0])
+    return img
+
+def ReadImageTensor(path, permute = True, rgb_only = True, batch_dim = True):
+    channels = 3 if rgb_only else 4
+    if isinstance(path,list):
+        first_image = plt.imread(path[0])[:, :, 0:channels]
+        b_image = np.empty((len(path), first_image.shape[0], first_image.shape[1], channels), dtype=np.float32)
+        b_image[0] = first_image
+        for i in range(1, len(path)):
+            b_image[i] = plt.imread(path[i])[:, :, 0:channels]
+        return MatImg2Tensor(b_image, permute)
+    return MatImg2Tensor(plt.imread(path)[:, :, 0:channels], permute, batch_dim)
+
+def ReadImageNumpyArray(path, permute = True, rgb_only = True, batch_dim = True):
+    channels = 3 if rgb_only else 4
+    if isinstance(path,list):
+        first_image = plt.imread(path[0])[:, :, 0:channels]
+        b_image = np.empty((len(path), first_image.shape[0], first_image.shape[1], channels), dtype=np.float32)
+        b_image[0] = first_image
+        for i in range(1, len(path)):
+            b_image[i] = plt.imread(path[i])[:, :, 0:channels]
+        return MatImg2Numpy(b_image, permute)
+    return MatImg2Numpy(plt.imread(path)[:, :, 0:channels], permute, batch_dim)
+
+def WriteImageTensor(t, path):
+    image = Tensor2MatImg(t)
+    if isinstance(path,list):
+        if len(image.shape) != 4 or image.shape[0] != len(path):
+            raise ValueError
+        for i in range(len(path)):
+            plt.imsave(path[i], image[i])
+    else:
+        if len(image.shape) == 4 and image.shape[0] != 1:
+            raise ValueError
+        plt.imsave(path, image)
+    
+def PlotImageTensor(t):
+    plt.imshow(Tensor2MatImg(t))
+    
+def Tensor2Glm(t):
+    t = t.squeeze()
+    size = t.size()
+    if len(size) == 1:
+        if size[0] <= 0 or size[0] > 4:
+            raise ValueError
+        return gvec_type[size[0] - 1](t.cpu().numpy())
+    if len(size) == 2:
+        if size[0] <= 1 or size[0] > 4 or size[1] <= 1 or size[1] > 4:
+            raise ValueError
+        return gmat_type[size[1] - 2][size[0] - 2](t.cpu().numpy())
+    raise ValueError
+
+def Glm2Tensor(val):
+    return torch.from_numpy(np.array(val))
+
+def MeshGrid(size, normalize=False):
+    y,x = torch.meshgrid(torch.tensor(range(size[0])),
+                          torch.tensor(range(size[1])))
+    if normalize:
+        return torch.stack([x / (size[1] - 1.), y / (size[0] - 1.)], 2)
+    return torch.stack([x, y], 2)
\ No newline at end of file
diff --git a/weight_init.py b/weight_init.py
new file mode 100644
index 0000000..9c0a20a
--- /dev/null
+++ b/weight_init.py
@@ -0,0 +1,12 @@
+import torch
+weightVarScale = 0.25
+bias_stddev = 0.01
+
+def weight_init_normal(m):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        torch.nn.init.xavier_normal_(m.weight.data)
+        torch.nn.init.normal_(m.bias.data,mean = 0.0, std=bias_stddev)
+    elif classname.find("BatchNorm2d") != -1:
+        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
+        torch.nn.init.constant_(m.bias.data, 0.0)
\ No newline at end of file
-- 
GitLab