Commit c5434e97 authored by BobYeah's avatar BobYeah
Browse files

Updatae CNN main

parent 95fb58a7
import torch
import argparse
import os
import glob
import numpy as np
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torchvision import datasets
from torch.utils.data import DataLoader
from torch.autograd import Variable
import cv2
from loss import *
import json
from baseline import *
from data import *
import torch.autograd.profiler as profiler
# param
BATCH_SIZE = 2
NUM_EPOCH = 1001
INTERLEAVE_RATE = 2
IM_H = 540
IM_W = 376
Retinal_IM_H = 540
Retinal_IM_W = 376
N = 4 # number of input light field stack
M = 1 # number of display layers
DATA_FILE = "/home/yejiannan/Project/LightField/data/lf_syn"
DATA_JSON = "/home/yejiannan/Project/LightField/data/data_lf_syn_full.json"
# DATA_VAL_JSON = "/home/yejiannan/Project/LightField/data/data_gaze_fovea_val.json"
OUTPUT_DIR = "/home/yejiannan/Project/LightField/outputE/lf_syn_full_perc"
OUT_CHANNELS_RB = 128
KERNEL_SIZE_RB = 3
KERNEL_SIZE = 3
LAST_LAYER_CHANNELS = 3 * INTERLEAVE_RATE**2
FIRSST_LAYER_CHANNELS = 12 * INTERLEAVE_RATE**2
from weight_init import weight_init_normal
def save_checkpoints(file_path, epoch_idx, model, model_solver):
print('[INFO] Saving checkpoint to %s ...' % ( file_path))
checkpoint = {
'epoch_idx': epoch_idx,
'model_state_dict': model.state_dict(),
'model_solver_state_dict': model_solver.state_dict()
}
torch.save(checkpoint, file_path)
mode = "Silence" #"Perf"
w_frame = 0.9
loss1 = PerceptionReconstructionLoss()
if __name__ == "__main__":
#train
train_data_loader = torch.utils.data.DataLoader(dataset=lightFieldSynDataLoader(DATA_FILE,DATA_JSON),
batch_size=BATCH_SIZE,
num_workers=8,
pin_memory=True,
shuffle=True,
drop_last=False)
#Data loader test
print(len(train_data_loader))
lf_model = model(FIRSST_LAYER_CHANNELS,LAST_LAYER_CHANNELS,OUT_CHANNELS_RB,KERNEL_SIZE,KERNEL_SIZE_RB,INTERLEAVE_RATE,RNN=False)
lf_model.apply(weight_init_normal)
lf_model.train()
epoch_begin = 0
if torch.cuda.is_available():
# lf_model = torch.nn.DataParallel(lf_model).cuda()
lf_model = lf_model.to('cuda:3')
optimizer = torch.optim.Adam(lf_model.parameters(),lr=5e-3,betas=(0.9,0.999))
# lf_model.output_layer.register_backward_hook(hook_fn_back)
if mode=="Perf":
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
print("begin training....")
for epoch in range(epoch_begin, NUM_EPOCH):
for batch_idx, (image_set, gt, pos_row, pos_col) in enumerate(train_data_loader):
if mode=="Perf":
end.record()
torch.cuda.synchronize()
print("load:",start.elapsed_time(end))
start.record()
#reshape for input
image_set = image_set.permute(0,1,4,2,3) # N LF C H W
image_set = image_set.reshape(image_set.shape[0],-1,image_set.shape[3],image_set.shape[4]) # N LFxC H W
image_set = var_or_cuda(image_set)
gt = gt.permute(0,3,1,2) # BS C H W
gt = var_or_cuda(gt)
if mode=="Perf":
end.record()
torch.cuda.synchronize()
print("data prepare:",start.elapsed_time(end))
start.record()
output = lf_model(image_set,pos_row, pos_col) # 2 6 376 540
if mode=="Perf":
end.record()
torch.cuda.synchronize()
print("forward:",start.elapsed_time(end))
start.record()
optimizer.zero_grad()
# print("output:",output.shape," gt:",gt.shape)
loss1_value = loss1(output,gt)
loss = (w_frame * loss1_value)
if mode=="Perf":
end.record()
torch.cuda.synchronize()
print("compute loss:",start.elapsed_time(end))
start.record()
loss.backward()
if mode=="Perf":
end.record()
torch.cuda.synchronize()
print("backward:",start.elapsed_time(end))
start.record()
optimizer.step()
if mode=="Perf":
end.record()
torch.cuda.synchronize()
print("update:",start.elapsed_time(end))
print("Epoch:",epoch,",Iter:",batch_idx,",loss:",loss.item())
# exit(0)
########################### Save #####################
if ((epoch%10== 0 and epoch != 0) or epoch == 2): # torch.Size([2, 5, 160, 160, 3])
for i in range(gt.size()[0]):
save_image(output[i].data,os.path.join(OUTPUT_DIR,"out_%.5f_%.5f.png"%(pos_col[i].data,pos_row[i].data)))
save_image(gt[i].data,os.path.join(OUTPUT_DIR,"gt_%.5f_%.5f.png"%(pos_col[i].data,pos_row[i].data)))
if ((epoch%100 == 0) and epoch != 0 and batch_idx==len(train_data_loader)-1):
save_checkpoints(os.path.join(OUTPUT_DIR, 'ckpt-epoch-%04d.pth' % (epoch)),epoch,lf_model,optimizer)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment