Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
Nianchen Deng
deeplightfield
Commits
c5434e97
Commit
c5434e97
authored
Dec 19, 2020
by
BobYeah
Browse files
Updatae CNN main
parent
95fb58a7
Changes
1
Hide whitespace changes
Inline
Side-by-side
main_lf_syn.py
0 → 100644
View file @
c5434e97
import
torch
import
argparse
import
os
import
glob
import
numpy
as
np
import
torchvision.transforms
as
transforms
from
torchvision.utils
import
save_image
from
torchvision
import
datasets
from
torch.utils.data
import
DataLoader
from
torch.autograd
import
Variable
import
cv2
from
loss
import
*
import
json
from
baseline
import
*
from
data
import
*
import
torch.autograd.profiler
as
profiler
# param
BATCH_SIZE
=
2
NUM_EPOCH
=
1001
INTERLEAVE_RATE
=
2
IM_H
=
540
IM_W
=
376
Retinal_IM_H
=
540
Retinal_IM_W
=
376
N
=
4
# number of input light field stack
M
=
1
# number of display layers
DATA_FILE
=
"/home/yejiannan/Project/LightField/data/lf_syn"
DATA_JSON
=
"/home/yejiannan/Project/LightField/data/data_lf_syn_full.json"
# DATA_VAL_JSON = "/home/yejiannan/Project/LightField/data/data_gaze_fovea_val.json"
OUTPUT_DIR
=
"/home/yejiannan/Project/LightField/outputE/lf_syn_full_perc"
OUT_CHANNELS_RB
=
128
KERNEL_SIZE_RB
=
3
KERNEL_SIZE
=
3
LAST_LAYER_CHANNELS
=
3
*
INTERLEAVE_RATE
**
2
FIRSST_LAYER_CHANNELS
=
12
*
INTERLEAVE_RATE
**
2
from
weight_init
import
weight_init_normal
def
save_checkpoints
(
file_path
,
epoch_idx
,
model
,
model_solver
):
print
(
'[INFO] Saving checkpoint to %s ...'
%
(
file_path
))
checkpoint
=
{
'epoch_idx'
:
epoch_idx
,
'model_state_dict'
:
model
.
state_dict
(),
'model_solver_state_dict'
:
model_solver
.
state_dict
()
}
torch
.
save
(
checkpoint
,
file_path
)
mode
=
"Silence"
#"Perf"
w_frame
=
0.9
loss1
=
PerceptionReconstructionLoss
()
if
__name__
==
"__main__"
:
#train
train_data_loader
=
torch
.
utils
.
data
.
DataLoader
(
dataset
=
lightFieldSynDataLoader
(
DATA_FILE
,
DATA_JSON
),
batch_size
=
BATCH_SIZE
,
num_workers
=
8
,
pin_memory
=
True
,
shuffle
=
True
,
drop_last
=
False
)
#Data loader test
print
(
len
(
train_data_loader
))
lf_model
=
model
(
FIRSST_LAYER_CHANNELS
,
LAST_LAYER_CHANNELS
,
OUT_CHANNELS_RB
,
KERNEL_SIZE
,
KERNEL_SIZE_RB
,
INTERLEAVE_RATE
,
RNN
=
False
)
lf_model
.
apply
(
weight_init_normal
)
lf_model
.
train
()
epoch_begin
=
0
if
torch
.
cuda
.
is_available
():
# lf_model = torch.nn.DataParallel(lf_model).cuda()
lf_model
=
lf_model
.
to
(
'cuda:3'
)
optimizer
=
torch
.
optim
.
Adam
(
lf_model
.
parameters
(),
lr
=
5e-3
,
betas
=
(
0.9
,
0.999
))
# lf_model.output_layer.register_backward_hook(hook_fn_back)
if
mode
==
"Perf"
:
start
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
end
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
start
.
record
()
print
(
"begin training...."
)
for
epoch
in
range
(
epoch_begin
,
NUM_EPOCH
):
for
batch_idx
,
(
image_set
,
gt
,
pos_row
,
pos_col
)
in
enumerate
(
train_data_loader
):
if
mode
==
"Perf"
:
end
.
record
()
torch
.
cuda
.
synchronize
()
print
(
"load:"
,
start
.
elapsed_time
(
end
))
start
.
record
()
#reshape for input
image_set
=
image_set
.
permute
(
0
,
1
,
4
,
2
,
3
)
# N LF C H W
image_set
=
image_set
.
reshape
(
image_set
.
shape
[
0
],
-
1
,
image_set
.
shape
[
3
],
image_set
.
shape
[
4
])
# N LFxC H W
image_set
=
var_or_cuda
(
image_set
)
gt
=
gt
.
permute
(
0
,
3
,
1
,
2
)
# BS C H W
gt
=
var_or_cuda
(
gt
)
if
mode
==
"Perf"
:
end
.
record
()
torch
.
cuda
.
synchronize
()
print
(
"data prepare:"
,
start
.
elapsed_time
(
end
))
start
.
record
()
output
=
lf_model
(
image_set
,
pos_row
,
pos_col
)
# 2 6 376 540
if
mode
==
"Perf"
:
end
.
record
()
torch
.
cuda
.
synchronize
()
print
(
"forward:"
,
start
.
elapsed_time
(
end
))
start
.
record
()
optimizer
.
zero_grad
()
# print("output:",output.shape," gt:",gt.shape)
loss1_value
=
loss1
(
output
,
gt
)
loss
=
(
w_frame
*
loss1_value
)
if
mode
==
"Perf"
:
end
.
record
()
torch
.
cuda
.
synchronize
()
print
(
"compute loss:"
,
start
.
elapsed_time
(
end
))
start
.
record
()
loss
.
backward
()
if
mode
==
"Perf"
:
end
.
record
()
torch
.
cuda
.
synchronize
()
print
(
"backward:"
,
start
.
elapsed_time
(
end
))
start
.
record
()
optimizer
.
step
()
if
mode
==
"Perf"
:
end
.
record
()
torch
.
cuda
.
synchronize
()
print
(
"update:"
,
start
.
elapsed_time
(
end
))
print
(
"Epoch:"
,
epoch
,
",Iter:"
,
batch_idx
,
",loss:"
,
loss
.
item
())
# exit(0)
########################### Save #####################
if
((
epoch
%
10
==
0
and
epoch
!=
0
)
or
epoch
==
2
):
# torch.Size([2, 5, 160, 160, 3])
for
i
in
range
(
gt
.
size
()[
0
]):
save_image
(
output
[
i
].
data
,
os
.
path
.
join
(
OUTPUT_DIR
,
"out_%.5f_%.5f.png"
%
(
pos_col
[
i
].
data
,
pos_row
[
i
].
data
)))
save_image
(
gt
[
i
].
data
,
os
.
path
.
join
(
OUTPUT_DIR
,
"gt_%.5f_%.5f.png"
%
(
pos_col
[
i
].
data
,
pos_row
[
i
].
data
)))
if
((
epoch
%
100
==
0
)
and
epoch
!=
0
and
batch_idx
==
len
(
train_data_loader
)
-
1
):
save_checkpoints
(
os
.
path
.
join
(
OUTPUT_DIR
,
'ckpt-epoch-%04d.pth'
%
(
epoch
)),
epoch
,
lf_model
,
optimizer
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment