Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
Nianchen Deng
deeplightfield
Commits
5069f8ae
Commit
5069f8ae
authored
Nov 26, 2020
by
BobYeah
Browse files
Gaze
parent
055dc0bb
Changes
2
Show whitespace changes
Inline
Side-by-side
gen_image.py
View file @
5069f8ae
import
matplotlib.pyplot
as
plt
import
matplotlib.pyplot
as
plt
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
glm
def
RandomGenSamplesInPupil
(
conf
,
n_samples
):
def
Fov2Length
(
angle
):
'''
'''
return
np
.
tan
(
angle
*
np
.
pi
/
360
)
*
2
def
RandomGenSamplesInPupil
(
pupil_size
,
n_samples
):
'''
'''
Random sample n_samples positions in pupil region
Random sample n_samples positions in pupil region
...
@@ -18,14 +26,14 @@ def RandomGenSamplesInPupil(conf, n_samples):
...
@@ -18,14 +26,14 @@ def RandomGenSamplesInPupil(conf, n_samples):
samples
=
torch
.
empty
(
n_samples
,
2
)
samples
=
torch
.
empty
(
n_samples
,
2
)
i
=
0
i
=
0
while
i
<
n_samples
:
while
i
<
n_samples
:
s
=
(
torch
.
rand
(
2
)
-
0.5
)
*
conf
.
pupil_size
s
=
(
torch
.
rand
(
2
)
-
0.5
)
*
pupil_size
if
np
.
linalg
.
norm
(
s
)
>
conf
.
pupil_size
/
2.
:
if
np
.
linalg
.
norm
(
s
)
>
pupil_size
/
2.
:
continue
continue
samples
[
i
,
:]
=
s
samples
[
i
,
:]
=
[
s
[
0
],
s
[
1
],
0
]
i
+=
1
i
+=
1
return
samples
return
samples
def
GenSamplesInPupil
(
conf
,
circles
):
def
GenSamplesInPupil
(
pupil_size
,
circles
):
'''
'''
Sample positions on circles in pupil region
Sample positions on circles in pupil region
...
@@ -38,68 +46,116 @@ def GenSamplesInPupil(conf, circles):
...
@@ -38,68 +46,116 @@ def GenSamplesInPupil(conf, circles):
--------
--------
a n_samples x 2 tensor with 2D sample position in each row
a n_samples x 2 tensor with 2D sample position in each row
'''
'''
samples
=
torch
.
tensor
([[
0.
,
0.
]]
)
samples
=
torch
.
zeros
(
1
,
3
)
for
i
in
range
(
1
,
circles
):
for
i
in
range
(
1
,
circles
):
r
=
conf
.
pupil_size
/
2.
/
(
circles
-
1
)
*
i
r
=
pupil_size
/
2.
/
(
circles
-
1
)
*
i
n
=
4
*
i
n
=
4
*
i
for
j
in
range
(
0
,
n
):
for
j
in
range
(
0
,
n
):
angle
=
2
*
np
.
pi
/
n
*
j
angle
=
2
*
np
.
pi
/
n
*
j
samples
=
torch
.
cat
(
(
samples
,
torch
.
tensor
([[
r
*
np
.
cos
(
angle
),
r
*
np
.
sin
(
angle
)
]])),
dim
=
0
)
samples
=
torch
.
cat
(
[
samples
,
torch
.
tensor
([[
r
*
np
.
cos
(
angle
),
r
*
np
.
sin
(
angle
)
,
0
]])
],
0
)
return
samples
return
samples
def
GenRetinal2LayerMappings
(
conf
,
df
,
v
,
u
):
class
RetinalGen
(
object
):
'''
'''
Generate the mapping matrix from retinal to layers.
Class for retinal generation process
Properties
--------
conf - multi-layers' parameters configuration
u - M x 3 tensor, M sample positions in pupil
p_r - H_r x W_r x 3 tensor, retinal pixel grid, [H_r, W_r] is the retinal resolution
Phi - N x H_r x W_r x M x 2 tensor, retinal to layers mapping, N is number of layers
mask - N x H_r x W_r x M x 2 tensor, indicates invalid (out-of-range) mapping
Methods
--------
'''
def
__init__
(
self
,
conf
,
u
):
'''
Initialize retinal generator instance
Parameters
Parameters
--------
--------
conf - multi-layers' parameters configuration
conf - multi-layers' parameters configuration
df - focal distance
u - a M x 3 tensor stores M sample positions in pupil
v - a 1 x 2 tensor stores half viewport
'''
u - a M x 2 tensor stores M sample positions on pupil
self
.
conf
=
conf
# self.u = u.to(cuda_dev)
self
.
u
=
u
# M x 3 M sample positions
self
.
D_r
=
conf
.
retinal_res
# retinal res 480 x 640
self
.
N
=
conf
.
GetNLayers
()
# 2
self
.
M
=
u
.
size
()[
0
]
# samples
p_rx
,
p_ry
=
torch
.
meshgrid
(
torch
.
tensor
(
range
(
0
,
self
.
D_r
[
0
])),
torch
.
tensor
(
range
(
0
,
self
.
D_r
[
1
])))
self
.
p_r
=
torch
.
cat
([
((
torch
.
stack
([
p_rx
,
p_ry
],
2
)
+
0.5
)
/
self
.
D_r
-
0.5
)
*
conf
.
GetEyeViewportSize
(),
# 眼球视野
torch
.
ones
(
self
.
D_r
[
0
],
self
.
D_r
[
1
],
1
)
],
2
)
Returns
# self.Phi = torch.empty(N, D_r[0], D_r[1], M, 2, device=cuda_dev, dtype=torch.long)
# self.mask = torch.empty(self.N, self.D_r[0], self.D_r[1], self.M, 2, dtype=torch.float) # 2 x 480 x 640 x 41 x 2
def
CalculateRetinal2LayerMappings
(
self
,
df
,
gaze
):
'''
Calculate the mapping matrix from retinal to layers.
Parameters
--------
--------
The mapping matrix
df - focus distance
gaze - 2 x 1 tensor, eye rotation angle (degs) in horizontal and vertical direction
'''
'''
H_r
=
conf
.
retinal_res
[
0
]
Phi
=
torch
.
empty
(
self
.
N
,
self
.
D_r
[
0
],
self
.
D_r
[
1
],
self
.
M
,
2
,
dtype
=
torch
.
long
)
# 2 x 480 x 640 x 41 x 2
W_r
=
conf
.
retinal_res
[
1
]
mask
=
torch
.
empty
(
self
.
N
,
self
.
D_r
[
0
],
self
.
D_r
[
1
],
self
.
M
,
2
,
dtype
=
torch
.
float
)
D_r
=
conf
.
retinal_res
.
double
()
D_r
=
self
.
conf
.
retinal_res
# D_r: Resolution of retinal 480 640
N
=
conf
.
n_layers
V
=
self
.
conf
.
GetEyeViewportSize
()
# V: Viewport size of eye
M
=
u
.
size
()[
0
]
#41
c
=
(
self
.
conf
.
layer_res
/
2
)
# c: Center of layers (pixel)
Phi
=
torch
.
empty
(
H_r
,
W_r
,
N
,
M
,
2
,
dtype
=
torch
.
long
)
p_f
=
self
.
p_r
*
df
# p_f: H x W x 3, focus positions of retinal pixels on focus plane
p_rx
,
p_ry
=
torch
.
meshgrid
(
torch
.
tensor
(
range
(
0
,
H_r
)),
rot_forward
=
glm
.
dvec3
(
glm
.
tan
(
glm
.
radians
(
glm
.
dvec2
(
gaze
[
1
],
-
gaze
[
0
]))),
1
)
torch
.
tensor
(
range
(
0
,
W_r
)))
rot_mat
=
torch
.
from_numpy
(
np
.
array
(
p_r
=
torch
.
stack
([
p_rx
,
p_ry
],
2
).
unsqueeze
(
2
).
expand
(
-
1
,
-
1
,
M
,
-
1
)
glm
.
dmat3
(
glm
.
lookAtLH
(
glm
.
dvec3
(),
rot_forward
,
glm
.
dvec3
(
0
,
1
,
0
)))))
# print(p_r.shape) #torch.Size([480, 640, 41, 2])
rot_mat
=
rot_mat
.
float
()
for
i
in
range
(
0
,
N
):
u_rot
=
torch
.
mm
(
self
.
u
,
rot_mat
)
dpi
=
conf
.
h_layer
[
i
]
/
conf
.
layer_res
[
0
]
# 1 / 480
v_rot
=
torch
.
matmul
(
p_f
,
rot_mat
).
unsqueeze
(
2
).
expand
(
ci
=
conf
.
layer_res
/
2
# [240,320]
-
1
,
-
1
,
self
.
u
.
size
()[
0
],
-
1
)
-
u_rot
# v_rot: H x W x M x 3, rotated rays' direction vector
di
=
conf
.
d_layer
[
i
]
# 深度
v_rot
.
div_
(
v_rot
[:,
:,
:,
2
].
unsqueeze
(
3
))
# make z = 1 for each direction vector in v_rot
pi_r
=
di
*
v
*
(
1.
/
D_r
*
(
p_r
+
0.5
)
-
0.5
)
/
dpi
# [480, 640, 41, 2]
wi
=
(
1
-
di
/
df
)
/
dpi
# (1 - 深度/聚焦) / dpi df = 2.625 di = 1.75
for
i
in
range
(
0
,
self
.
conf
.
GetNLayers
()):
pi
=
torch
.
floor
(
pi_r
+
ci
+
wi
*
u
)
dp_i
=
self
.
conf
.
GetLayerSize
(
i
)[
0
]
/
self
.
conf
.
layer_res
[
0
]
# dp_i: Pixel size of layer i
torch
.
clamp_
(
pi
[:,
:,
:,
0
],
0
,
conf
.
layer_res
[
0
]
-
1
)
d_i
=
self
.
conf
.
d_layer
[
i
]
# d_i: Distance of layer i
torch
.
clamp_
(
pi
[:,
:,
:,
1
],
0
,
conf
.
layer_res
[
1
]
-
1
)
k
=
(
d_i
-
u_rot
[:,
2
]).
unsqueeze
(
1
)
Phi
[:,
:,
i
,
:,
:]
=
pi
pi_r
=
(
u_rot
[:,
0
:
2
]
+
v_rot
[:,
:,
:,
0
:
2
]
*
k
)
/
dp_i
# pi_r: H x W x M x 2, rays' pixel coord on layer i
return
Phi
Phi
[
i
,
:,
:,
:,
:]
=
torch
.
floor
(
pi_r
+
c
)
mask
[:,
:,
:,
:,
0
]
=
((
Phi
[:,
:,
:,
:,
0
]
>=
0
)
&
(
Phi
[:,
:,
:,
:,
0
]
<
self
.
conf
.
layer_res
[
0
])).
float
()
def
GenRetinalFromLayers
(
layers
,
Phi
):
mask
[:,
:,
:,
:,
1
]
=
((
Phi
[:,
:,
:,
:,
1
]
>=
0
)
&
(
Phi
[:,
:,
:,
:,
1
]
<
self
.
conf
.
layer_res
[
1
])).
float
()
# layers: 2, color, height, width
Phi
[:,
:,
:,
:,
0
].
clamp_
(
0
,
self
.
conf
.
layer_res
[
0
]
-
1
)
# Phi:torch.Size([480, 640, 2, 41, 2])
Phi
[:,
:,
:,
:,
1
].
clamp_
(
0
,
self
.
conf
.
layer_res
[
1
]
-
1
)
M
=
Phi
.
size
()[
3
]
# 41
retinal_mask
=
mask
.
prod
(
0
).
prod
(
2
).
prod
(
2
)
N
=
Phi
.
size
()[
2
]
# 2
return
[
Phi
,
retinal_mask
]
# print(layers.shape)# torch.Size([2, 3, 480, 640])
# print(Phi.shape)# torch.Size([480, 640, 2, 41, 2])
def
GenRetinalFromLayers
(
self
,
layers
,
Phi
):
# retinal image: 3channels x retinal_size
'''
retinal
=
torch
.
zeros
(
3
,
Phi
.
size
()[
0
],
Phi
.
size
()[
1
])
Generate retinal image from layers, using precalculated mapping matrix
for
j
in
range
(
0
,
M
):
retinal_view
=
torch
.
zeros
(
3
,
Phi
.
size
()[
0
],
Phi
.
size
()[
1
])
for
i
in
range
(
0
,
N
):
retinal_view
.
add_
(
layers
[
i
,:,
Phi
[:,
:,
i
,
j
,
0
],
Phi
[:,
:,
i
,
j
,
1
]])
retinal
.
add_
(
retinal_view
)
retinal
.
div_
(
M
)
return
retinal
Parameters
--------
layers - 3N x H_l x W_l tensor, stacked layer images, with 3 channels in each layer
Returns
--------
3 x H_r x W_r tensor, 3 channels retinal image
H_r x W_r tensor, retinal image mask, indicates pixels valid or not
'''
# FOR GRAYSCALE 1 FOR RGB 3
mapped_layers
=
torch
.
empty
(
self
.
N
,
3
,
self
.
D_r
[
0
],
self
.
D_r
[
1
],
self
.
M
)
# 2 x 3 x 480 x 640 x 41
# print("mapped_layers:",mapped_layers.shape)
for
i
in
range
(
0
,
Phi
.
size
()[
0
]):
# print("gather layers:",layers[(i * 3) : (i * 3 + 3),Phi[i, :, :, :, 0],Phi[i, :, :, :, 1]].shape)
mapped_layers
[
i
,
:,
:,
:,
:]
=
layers
[(
i
*
3
)
:
(
i
*
3
+
3
),
Phi
[
i
,
:,
:,
:,
0
],
Phi
[
i
,
:,
:,
:,
1
]]
# print("mapped_layers:",mapped_layers.shape)
retinal
=
mapped_layers
.
prod
(
0
).
sum
(
3
).
div
(
Phi
.
size
()[
3
])
# print("retinal:",retinal.shape)
return
retinal
\ No newline at end of file
main.py
View file @
5069f8ae
...
@@ -16,55 +16,64 @@ import json
...
@@ -16,55 +16,64 @@ import json
from
ssim
import
*
from
ssim
import
*
from
perc_loss
import
*
from
perc_loss
import
*
# param
# param
BATCH_SIZE
=
5
BATCH_SIZE
=
16
NUM_EPOCH
=
5
000
NUM_EPOCH
=
1
000
INTERLEAVE_RATE
=
2
INTERLEAVE_RATE
=
2
IM_H
=
480
IM_H
=
320
IM_W
=
640
IM_W
=
320
Retinal_IM_H
=
320
Retinal_IM_W
=
320
N
=
9
# number of input light field stack
N
=
9
# number of input light field stack
M
=
2
# number of display layers
M
=
2
# number of display layers
DATA_FILE
=
"/home/yejiannan/Project/LightField/data/
try
"
DATA_FILE
=
"/home/yejiannan/Project/LightField/data/
gaze_small_nar_new
"
DATA_JSON
=
"/home/yejiannan/Project/LightField/data/data.json"
DATA_JSON
=
"/home/yejiannan/Project/LightField/data/data
_gaze_low_new
.json"
DATA_VAL_JSON
=
"/home/yejiannan/Project/LightField/data/data_val.json"
DATA_VAL_JSON
=
"/home/yejiannan/Project/LightField/data/data_val.json"
OUTPUT_DIR
=
"/home/yejiannan/Project/LightField/output"
OUTPUT_DIR
=
"/home/yejiannan/Project/LightField/output
/gaze_low_new_1125_minibatch
"
class
lightFieldDataLoader
(
torch
.
utils
.
data
.
dataset
.
Dataset
):
class
lightFieldDataLoader
(
torch
.
utils
.
data
.
dataset
.
Dataset
):
def
__init__
(
self
,
file_dir_path
,
file_json
,
transforms
=
None
):
def
__init__
(
self
,
file_dir_path
,
file_json
,
transforms
=
None
):
self
.
file_dir_path
=
file_dir_path
self
.
file_dir_path
=
file_dir_path
self
.
transforms
=
transforms
self
.
transforms
=
transforms
# self.datum_list = glob.glob(os.path.join(file_dir_path,"*"))
with
open
(
file_json
,
encoding
=
'utf-8'
)
as
file
:
with
open
(
file_json
,
encoding
=
'utf-8'
)
as
file
:
self
.
da
s
tset_desc
=
json
.
loads
(
file
.
read
())
self
.
dat
a
set_desc
=
json
.
loads
(
file
.
read
())
def
__len__
(
self
):
def
__len__
(
self
):
return
len
(
self
.
da
s
tset_desc
[
"focaldepth"
])
return
len
(
self
.
dat
a
set_desc
[
"focaldepth"
])
def
__getitem__
(
self
,
idx
):
def
__getitem__
(
self
,
idx
):
lightfield_images
,
gt
,
fd
=
self
.
get_datum
(
idx
)
lightfield_images
,
gt
,
fd
,
gazeX
,
gazeY
,
sample_idx
=
self
.
get_datum
(
idx
)
if
self
.
transforms
:
if
self
.
transforms
:
lightfield_images
=
self
.
transforms
(
lightfield_images
)
lightfield_images
=
self
.
transforms
(
lightfield_images
)
return
(
lightfield_images
,
gt
,
fd
)
return
(
lightfield_images
,
gt
,
fd
,
gazeX
,
gazeY
,
sample_idx
)
def
get_datum
(
self
,
idx
):
def
get_datum
(
self
,
idx
):
lf_image_paths
=
os
.
path
.
join
(
DATA_FILE
,
self
.
dastset_desc
[
"train"
][
idx
])
lf_image_paths
=
os
.
path
.
join
(
DATA_FILE
,
self
.
dataset_desc
[
"train"
][
idx
])
# print(lf_image_paths)
fd_gt_path
=
os
.
path
.
join
(
DATA_FILE
,
self
.
dataset_desc
[
"gt"
][
idx
])
fd_gt_path
=
os
.
path
.
join
(
DATA_FILE
,
self
.
dastset_desc
[
"gt"
][
idx
])
# print(fd_gt_path)
lf_images
=
[]
lf_images
=
[]
lf_image_big
=
cv2
.
imread
(
lf_image_paths
,
cv2
.
IMREAD_UNCHANGED
).
astype
(
np
.
float32
)
/
255.
lf_image_big
=
cv2
.
imread
(
lf_image_paths
,
cv2
.
IMREAD_UNCHANGED
).
astype
(
np
.
float32
)
/
255.
lf_image_big
=
cv2
.
cvtColor
(
lf_image_big
,
cv2
.
COLOR_BGR2RGB
)
lf_image_big
=
cv2
.
cvtColor
(
lf_image_big
,
cv2
.
COLOR_BGR2RGB
)
for
i
in
range
(
9
):
for
i
in
range
(
9
):
lf_image
=
lf_image_big
[
i
//
3
*
IM_H
:
i
//
3
*
IM_H
+
IM_H
,
i
%
3
*
IM_W
:
i
%
3
*
IM_W
+
IM_W
,
0
:
3
]
lf_image
=
lf_image_big
[
i
//
3
*
IM_H
:
i
//
3
*
IM_H
+
IM_H
,
i
%
3
*
IM_W
:
i
%
3
*
IM_W
+
IM_W
,
0
:
3
]
## IF GrayScale
# lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1]
# print(lf_image.shape)
# print(lf_image.shape)
lf_images
.
append
(
lf_image
)
lf_images
.
append
(
lf_image
)
gt
=
cv2
.
imread
(
fd_gt_path
,
cv2
.
IMREAD_UNCHANGED
).
astype
(
np
.
float32
)
/
255.
gt
=
cv2
.
imread
(
fd_gt_path
,
cv2
.
IMREAD_UNCHANGED
).
astype
(
np
.
float32
)
/
255.
gt
=
cv2
.
cvtColor
(
gt
,
cv2
.
COLOR_BGR2RGB
)
gt
=
cv2
.
cvtColor
(
gt
,
cv2
.
COLOR_BGR2RGB
)
fd
=
self
.
dastset_desc
[
"focaldepth"
][
idx
]
## IF GrayScale
return
(
np
.
asarray
(
lf_images
),
gt
,
fd
)
# gt = cv2.imread(fd_gt_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
# gt = np.expand_dims(gt, axis=-1)
fd
=
self
.
dataset_desc
[
"focaldepth"
][
idx
]
gazeX
=
self
.
dataset_desc
[
"gazeX"
][
idx
]
gazeY
=
self
.
dataset_desc
[
"gazeY"
][
idx
]
sample_idx
=
self
.
dataset_desc
[
"idx"
][
idx
]
return
np
.
asarray
(
lf_images
),
gt
,
fd
,
gazeX
,
gazeY
,
sample_idx
OUT_CHANNELS_RB
=
128
OUT_CHANNELS_RB
=
128
KERNEL_SIZE_RB
=
3
KERNEL_SIZE_RB
=
3
...
@@ -128,7 +137,6 @@ class interleave(torch.nn.Module):
...
@@ -128,7 +137,6 @@ class interleave(torch.nn.Module):
output
=
output
.
permute
(
0
,
3
,
1
,
2
)
output
=
output
.
permute
(
0
,
3
,
1
,
2
)
return
output
return
output
LAST_LAYER_CHANNELS
=
6
*
INTERLEAVE_RATE
**
2
LAST_LAYER_CHANNELS
=
6
*
INTERLEAVE_RATE
**
2
FIRSST_LAYER_CHANNELS
=
27
*
INTERLEAVE_RATE
**
2
FIRSST_LAYER_CHANNELS
=
27
*
INTERLEAVE_RATE
**
2
...
@@ -144,37 +152,39 @@ class model(torch.nn.Module):
...
@@ -144,37 +152,39 @@ class model(torch.nn.Module):
)
)
self
.
residual_block1
=
residual_block
(
0
)
self
.
residual_block1
=
residual_block
(
0
)
self
.
residual_block2
=
residual_block
(
1
)
self
.
residual_block2
=
residual_block
(
3
)
self
.
residual_block3
=
residual_block
(
1
)
self
.
residual_block3
=
residual_block
(
3
)
self
.
residual_block4
=
residual_block
(
3
)
self
.
residual_block5
=
residual_block
(
3
)
self
.
output_layer
=
torch
.
nn
.
Sequential
(
self
.
output_layer
=
torch
.
nn
.
Sequential
(
torch
.
nn
.
Conv2d
(
OUT_CHANNELS_RB
+
1
,
LAST_LAYER_CHANNELS
,
KERNEL_SIZE
,
stride
=
1
,
padding
=
1
),
torch
.
nn
.
Conv2d
(
OUT_CHANNELS_RB
+
3
,
LAST_LAYER_CHANNELS
,
KERNEL_SIZE
,
stride
=
1
,
padding
=
1
),
torch
.
nn
.
BatchNorm2d
(
LAST_LAYER_CHANNELS
),
torch
.
nn
.
BatchNorm2d
(
LAST_LAYER_CHANNELS
),
torch
.
nn
.
Sigmoid
()
torch
.
nn
.
Sigmoid
()
)
)
self
.
deinterleave
=
deinterleave
(
INTERLEAVE_RATE
)
self
.
deinterleave
=
deinterleave
(
INTERLEAVE_RATE
)
def
forward
(
self
,
lightfield_images
,
focal_length
):
def
forward
(
self
,
lightfield_images
,
focal_length
,
gazeX
,
gazeY
):
# lightfield_images: torch.Size([batch_size, channels * D, H, W])
# channels : RGB*D: 3*9, H:256, W:256
input_to_net
=
self
.
interleave
(
lightfield_images
)
input_to_net
=
self
.
interleave
(
lightfield_images
)
# print("after interleave:",input_to_net.shape)
input_to_rb
=
self
.
first_layer
(
input_to_net
)
input_to_rb
=
self
.
first_layer
(
input_to_net
)
output
=
self
.
residual_block1
(
input_to_rb
)
output
=
self
.
residual_block1
(
input_to_rb
)
# print("output1:",output.shape)
depth_layer
=
torch
.
ones
((
input_to_rb
.
shape
[
0
],
1
,
input_to_rb
.
shape
[
2
],
input_to_rb
.
shape
[
3
]))
gazeX_layer
=
torch
.
ones
((
input_to_rb
.
shape
[
0
],
1
,
input_to_rb
.
shape
[
2
],
input_to_rb
.
shape
[
3
]))
depth_layer
=
torch
.
ones
((
output
.
shape
[
0
],
1
,
output
.
shape
[
2
],
output
.
shape
[
3
]))
gazeY_layer
=
torch
.
ones
((
input_to_rb
.
shape
[
0
],
1
,
input_to_rb
.
shape
[
2
],
input_to_rb
.
shape
[
3
]))
# print(df.shape[0])
for
i
in
range
(
focal_length
.
shape
[
0
]):
for
i
in
range
(
focal_length
.
shape
[
0
]):
depth_layer
[
i
]
=
1.
/
focal_length
[
i
]
depth_layer
[
i
]
*=
1.
/
focal_length
[
i
]
# print(depth_layer.shape)
gazeX_layer
[
i
]
*=
(
gazeX
[
i
]
-
(
-
3.333
))
/
(
3.333
*
2
)
gazeY_layer
[
i
]
*=
(
gazeY
[
i
]
-
(
-
3.333
))
/
(
3.333
*
2
)
depth_layer
=
var_or_cuda
(
depth_layer
)
depth_layer
=
var_or_cuda
(
depth_layer
)
output
=
torch
.
cat
((
output
,
depth_layer
),
dim
=
1
)
gazeX_layer
=
var_or_cuda
(
gazeX_layer
)
gazeY_layer
=
var_or_cuda
(
gazeY_layer
)
output
=
torch
.
cat
((
output
,
depth_layer
,
gazeX_layer
,
gazeY_layer
),
dim
=
1
)
output
=
self
.
residual_block2
(
output
)
output
=
self
.
residual_block2
(
output
)
output
=
self
.
residual_block3
(
output
)
output
=
self
.
residual_block3
(
output
)
# output = output + input_to_net
output
=
self
.
residual_block4
(
output
)
output
=
self
.
residual_block5
(
output
)
output
=
self
.
output_layer
(
output
)
output
=
self
.
output_layer
(
output
)
output
=
self
.
deinterleave
(
output
)
output
=
self
.
deinterleave
(
output
)
return
output
return
output
...
@@ -182,72 +192,65 @@ class model(torch.nn.Module):
...
@@ -182,72 +192,65 @@ class model(torch.nn.Module):
class
Conf
(
object
):
class
Conf
(
object
):
def
__init__
(
self
):
def
__init__
(
self
):
self
.
pupil_size
=
0.02
# 2cm
self
.
pupil_size
=
0.02
# 2cm
self
.
retinal_res
=
torch
.
tensor
([
480
,
640
])
self
.
retinal_res
=
torch
.
tensor
([
Retinal_IM_H
,
Retinal_IM_W
])
self
.
layer_res
=
torch
.
tensor
([
480
,
640
])
self
.
layer_res
=
torch
.
tensor
([
IM_H
,
IM_W
])
self
.
n_layers
=
2
self
.
layer_hfov
=
90
# layers' horizontal FOV
self
.
d_layer
=
[
1.
,
3.
]
# layers' distance
self
.
eye_hfov
=
85
# eye's horizontal FOV
self
.
h_layer
=
[
1.
*
480.
/
640.
,
3.
*
480.
/
640.
]
# layers' height
self
.
d_layer
=
[
1
,
3
]
# layers' distance
def
GetNLayers
(
self
):
return
len
(
self
.
d_layer
)
def
GetLayerSize
(
self
,
i
):
w
=
Fov2Length
(
self
.
layer_hfov
)
h
=
w
*
self
.
layer_res
[
0
]
/
self
.
layer_res
[
1
]
return
torch
.
tensor
([
h
,
w
])
*
self
.
d_layer
[
i
]
def
GetEyeViewportSize
(
self
):
w
=
Fov2Length
(
self
.
eye_hfov
)
h
=
w
*
self
.
retinal_res
[
0
]
/
self
.
retinal_res
[
1
]
return
torch
.
tensor
([
h
,
w
])
#### Image Gen
#### Image Gen
conf
=
Conf
()
conf
=
Conf
()
v
=
torch
.
tensor
([
conf
.
h_layer
[
0
]
/
conf
.
d_layer
[
0
],
u
=
GenSamplesInPupil
(
conf
.
pupil_size
,
5
)
conf
.
h_layer
[
0
]
/
conf
.
d_layer
[
0
]
*
conf
.
layer_res
[
1
]
/
conf
.
layer_res
[
0
]])
u
=
GenSamplesInPupil
(
conf
,
5
)
gen
=
RetinalGen
(
conf
,
u
)
def
GenRetinalFromLayersBatch
(
layers
,
conf
,
df
,
v
,
u
):
def
GenRetinalFromLayersBatch
(
layers
,
gen
,
sample_idx
,
phi_dict
,
mask_dict
):
# layers: batchsize, 2
,
color, height, width
# layers: batchsize, 2
*
color, height, width
# Phi:torch.Size([batchsize, 480, 640, 2, 41, 2])
# Phi:torch.Size([batchsize, 480, 640, 2, 41, 2])
# df : batchsize,..
# df : batchsize,..
H_r
=
conf
.
retinal_res
[
0
]
W_r
=
conf
.
retinal_res
[
1
]
# retinal bs x color x height x width
D_r
=
conf
.
retinal_res
.
double
()
retinal
=
torch
.
zeros
(
layers
.
shape
[
0
],
3
,
Retinal_IM_H
,
Retinal_IM_W
)
N
=
conf
.
n_layers
mask
=
[]
# mask shape 480 x 640
M
=
u
.
size
()[
0
]
#41
for
i
in
range
(
0
,
layers
.
size
()[
0
]):
BS
=
df
.
shape
[
0
]
phi
=
phi_dict
[
int
(
sample_idx
[
i
].
data
)]
Phi
=
torch
.
empty
(
BS
,
H_r
,
W_r
,
N
,
M
,
2
,
dtype
=
torch
.
long
)
phi
=
var_or_cuda
(
phi
)
# print("Phi:",Phi.shape)
phi
.
requires_grad
=
False
retinal
[
i
]
=
gen
.
GenRetinalFromLayers
(
layers
[
i
],
phi
)
p_rx
,
p_ry
=
torch
.
meshgrid
(
torch
.
tensor
(
range
(
0
,
H_r
)),
mask
.
append
(
mask_dict
[
int
(
sample_idx
[
i
].
data
)])
torch
.
tensor
(
range
(
0
,
W_r
)))
p_r
=
torch
.
stack
([
p_rx
,
p_ry
],
2
).
unsqueeze
(
2
).
expand
(
-
1
,
-
1
,
M
,
-
1
)
# print("p_r:",p_r.shape) #torch.Size([480, 640, 41, 2])
for
bs
in
range
(
BS
):
for
i
in
range
(
0
,
N
):
dpi
=
conf
.
h_layer
[
i
]
/
float
(
conf
.
layer_res
[
0
])
# 1 / 480
# print("dpi:",dpi)
ci
=
conf
.
layer_res
/
2
# [240,320]
di
=
conf
.
d_layer
[
i
]
# 深度
pi_r
=
di
*
v
*
(
1.
/
D_r
*
(
p_r
+
0.5
)
-
0.5
)
/
dpi
# [480, 640, 41, 2]
wi
=
(
1
-
di
/
df
[
bs
])
/
dpi
# (1 - 深度/聚焦) / dpi df = 2.625 di = 1.75
pi
=
torch
.
floor
(
pi_r
+
ci
+
wi
*
u
)
torch
.
clamp_
(
pi
[:,
:,
:,
0
],
0
,
conf
.
layer_res
[
0
]
-
1
)
torch
.
clamp_
(
pi
[:,
:,
:,
1
],
0
,
conf
.
layer_res
[
1
]
-
1
)
Phi
[
bs
,
:,
:,
i
,
:,
:]
=
pi
# print("Phi slice:",Phi[0, :, :, 0, 0, 0].shape)
retinal
=
torch
.
ones
(
BS
,
3
,
H_r
,
W_r
)
retinal
=
var_or_cuda
(
retinal
)
retinal
=
var_or_cuda
(
retinal
)
for
bs
in
range
(
BS
):
mask
=
torch
.
stack
(
mask
,
dim
=
0
).
unsqueeze
(
1
)
# batch x 1 x height x width
for
j
in
range
(
0
,
M
):
return
retinal
,
mask
retinal_view
=
torch
.
ones
(
3
,
H_r
,
W_r
)
retinal_view
=
var_or_cuda
(
retinal_view
)
for
i
in
range
(
0
,
N
):
retinal_view
.
mul_
(
layers
[
bs
,
(
i
*
3
)
:
(
i
*
3
+
3
),
Phi
[
bs
,
:,
:,
i
,
j
,
0
],
Phi
[
bs
,
:,
:,
i
,
j
,
1
]])
retinal
[
bs
,:,:,:].
add_
(
retinal_view
)
retinal
[
bs
,:,:,:].
div_
(
M
)
return
retinal
#### Image Gen End
def
merge_two
(
near
,
far
):
def
GenRetinalFromLayersBatch_Online
(
layers
,
gen
,
phi
,
mask
):
df
=
conf
.
d_layer
[
0
]
+
(
conf
.
d_layer
[
1
]
-
conf
.
d_layer
[
0
])
/
2.
# layers: batchsize, 2*color, height, width
# Phi = GenRetinal2LayerMappings(conf, df, v, u)
# Phi:torch.Size([batchsize, 480, 640, 2, 41, 2])
# retinal = GenRetinalFromLayers(layers, Phi)
# df : batchsize,..
return
near
[:,
0
:
3
,:,:]
+
far
[:,
3
:
6
,:,:]
/
2.0
def
loss_two_images
(
generated
,
gt
):
# retinal bs x color x height x width
l1_loss
=
torch
.
nn
.
L1Loss
()
# retinal = torch.zeros(layers.shape[0], 3, Retinal_IM_H, Retinal_IM_W)
return
l1_loss
(
generated
,
gt
)
# retinal = var_or_cuda(retinal)
phi
=
var_or_cuda
(
phi
)
phi
.
requires_grad
=
False
retinal
=
gen
.
GenRetinalFromLayers
(
layers
[
0
],
phi
)
retinal
=
var_or_cuda
(
retinal
)
mask_out
=
mask
.
unsqueeze
(
0
).
unsqueeze
(
0
)
return
retinal
.
unsqueeze
(
0
),
mask_out
#### Image Gen End
weightVarScale
=
0.25
weightVarScale
=
0.25
bias_stddev
=
0.01
bias_stddev
=
0.01
...
@@ -269,7 +272,7 @@ def var_or_cuda(x):
...
@@ -269,7 +272,7 @@ def var_or_cuda(x):
def
calImageGradients
(
images
):
def
calImageGradients
(
images
):
# x is a 4-D tensor
# x is a 4-D tensor
dx
=
images
[:,
:,
1
:,
:]
-
images
[:,
:,
:
-
1
,
:]
dx
=
images
[:,
:,
1
:,
:]
-
images
[:,
:,
:
-
1
,
:]
dy
=
images
[:,
1
:,
:,
:]
-
images
[:,
:
-
1
,
:,
:]
dy
=
images
[:,
:,
:,
1
:]
-
images
[:,
:,
:,
:
-
1
]
return
dx
,
dy
return
dx
,
dy
...
@@ -279,16 +282,13 @@ perc_loss = perc_loss.to("cuda")
...
@@ -279,16 +282,13 @@ perc_loss = perc_loss.to("cuda")
def
loss_new
(
generated
,
gt
):
def
loss_new
(
generated
,
gt
):
mse_loss
=
torch
.
nn
.
MSELoss
()
mse_loss
=
torch
.
nn
.
MSELoss
()
rmse_intensity
=
mse_loss
(
generated
,
gt
)
rmse_intensity
=
mse_loss
(
generated
,
gt
)
RENORM_SCALE
=
torch
.
tensor
(
0.9
)
RENORM_SCALE
=
var_or_cuda
(
RENORM_SCALE
)
psnr_intensity
=
torch
.
log10
(
rmse_intensity
)
psnr_intensity
=
torch
.
log10
(
rmse_intensity
)
ssim_intensity
=
ssim
(
generated
,
gt
)
labels_dx
,
labels_dy
=
calImageGradients
(
gt
)
labels_dx
,
labels_dy
=
calImageGradients
(
gt
)
preds_dx
,
preds_dy
=
calImageGradients
(
generated
)
preds_dx
,
preds_dy
=
calImageGradients
(
generated
)
rmse_grad_x
,
rmse_grad_y
=
mse_loss
(
labels_dx
,
preds_dx
),
mse_loss
(
labels_dy
,
preds_dy
)
rmse_grad_x
,
rmse_grad_y
=
mse_loss
(
labels_dx
,
preds_dx
),
mse_loss
(
labels_dy
,
preds_dy
)
psnr_grad_x
,
psnr_grad_y
=
torch
.
log10
(
rmse_grad_x
),
torch
.
log10
(
rmse_grad_y
)
psnr_grad_x
,
psnr_grad_y
=
torch
.
log10
(
rmse_grad_x
),
torch
.
log10
(
rmse_grad_y
)
p_loss
=
perc_loss
(
generated
,
gt
)
p_loss
=
perc_loss
(
generated
,
gt
)
# print("-psnr:",-psnr_intensity,",0.5*(psnr_grad_x + psnr_grad_y):",0.5*(psnr_grad_x + psnr_grad_y),",perc_loss:",p_loss)
total_loss
=
10
+
psnr_intensity
+
0.5
*
(
psnr_grad_x
+
psnr_grad_y
)
+
p_loss
total_loss
=
10
+
psnr_intensity
+
0.5
*
(
psnr_grad_x
+
psnr_grad_y
)
+
p_loss
return
total_loss
return
total_loss
...
@@ -301,15 +301,56 @@ def save_checkpoints(file_path, epoch_idx, model, model_solver):
...
@@ -301,15 +301,56 @@ def save_checkpoints(file_path, epoch_idx, model, model_solver):
}
}
torch
.
save
(
checkpoint
,
file_path
)
torch
.
save
(
checkpoint
,
file_path
)
mode
=
"val"
def
hook_fn_back
(
m
,
i
,
o
):
for
grad
in
i
:
try
:
print
(
"Input Grad:"
,
m
,
grad
.
shape
,
grad
.
sum
())
except
AttributeError
:
print
(
"None found for Gradient"
)
for
grad
in
o
:
try
:
print
(
"Output Grad:"
,
m
,
grad
.
shape
,
grad
.
sum
())
except
AttributeError
:
print
(
"None found for Gradient"
)
print
(
"
\n
"
)
def
hook_fn_for
(
m
,
i
,
o
):
for
grad
in
i
:
try
:
print
(
"Input Feats:"
,
m
,
grad
.
shape
,
grad
.
sum
())
except
AttributeError
:
print
(
"None found for Gradient"
)
for
grad
in
o
:
try
:
print
(
"Output Feats:"
,
m
,
grad
.
shape
,
grad
.
sum
())
except
AttributeError
:
print
(
"None found for Gradient"
)
print
(
"
\n
"
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
#test
# train_dataset = lightFieldDataLoader(DATA_FILE,DATA_JSON)
# print(train_dataset[0][0].shape)
# cv2.imwrite("test_crop0.png",train_dataset[0][1]*255.)
# save_image(output[0][0:3].data,os.path.join(OUTPUT_DIR,"o%d_%d.png"%(epoch,batch_idx)))
#test end
############################## generate phi and mask in pre-training
phi_dict
=
{}
mask_dict
=
{}
idx_info_dict
=
{}
print
(
"generating phi and mask..."
)
with
open
(
DATA_JSON
,
encoding
=
'utf-8'
)
as
file
:
dataset_desc
=
json
.
loads
(
file
.
read
())
for
i
in
range
(
len
(
dataset_desc
[
"focaldepth"
])):
# if i == 2:
# break
idx
=
dataset_desc
[
"idx"
][
i
]
focaldepth
=
dataset_desc
[
"focaldepth"
][
i
]
gazeX
=
dataset_desc
[
"gazeX"
][
i
]
gazeY
=
dataset_desc
[
"gazeY"
][
i
]
# print("focaldepth:",focaldepth," idx:",idx," gazeX:",gazeX," gazeY:",gazeY)
phi
,
mask
=
gen
.
CalculateRetinal2LayerMappings
(
focaldepth
,
torch
.
tensor
([
gazeX
,
gazeY
]))
phi_dict
[
idx
]
=
phi
mask_dict
[
idx
]
=
mask
idx_info_dict
[
idx
]
=
[
idx
,
focaldepth
,
gazeX
,
gazeY
]
print
(
"generating phi and mask end."
)
# exit(0)
#train
#train
train_data_loader
=
torch
.
utils
.
data
.
DataLoader
(
dataset
=
lightFieldDataLoader
(
DATA_FILE
,
DATA_JSON
),
train_data_loader
=
torch
.
utils
.
data
.
DataLoader
(
dataset
=
lightFieldDataLoader
(
DATA_FILE
,
DATA_JSON
),
batch_size
=
BATCH_SIZE
,
batch_size
=
BATCH_SIZE
,
...
@@ -319,82 +360,51 @@ if __name__ == "__main__":
...
@@ -319,82 +360,51 @@ if __name__ == "__main__":
drop_last
=
False
)
drop_last
=
False
)
print
(
len
(
train_data_loader
))
print
(
len
(
train_data_loader
))
val_data_loader
=
torch
.
utils
.
data
.
DataLoader
(
dataset
=
lightFieldDataLoader
(
DATA_FILE
,
DATA_VAL_JSON
),
# exit(0)
batch_size
=
1
,
num_workers
=
0
,
pin_memory
=
True
,
shuffle
=
False
,
drop_last
=
False
)
print
(
len
(
val_data_loader
))
################################################ train #########################################################
device
=
torch
.
device
(
'cuda'
if
torch
.
cuda
.
is_available
()
else
'cpu'
)
device
=
torch
.
device
(
'cuda'
if
torch
.
cuda
.
is_available
()
else
'cpu'
)
lf_model
=
model
()
lf_model
=
model
()
lf_model
.
apply
(
weight_init_normal
)
epoch_begin
=
0
if
torch
.
cuda
.
is_available
():
if
torch
.
cuda
.
is_available
():
lf_model
=
torch
.
nn
.
DataParallel
(
lf_model
).
cuda
()
lf_model
=
torch
.
nn
.
DataParallel
(
lf_model
).
cuda
()
lf_model
.
train
()
optimizer
=
torch
.
optim
.
Adam
(
lf_model
.
parameters
(),
lr
=
1e-2
,
betas
=
(
0.9
,
0.999
))
#val
print
(
"begin training...."
)
checkpoint
=
torch
.
load
(
os
.
path
.
join
(
OUTPUT_DIR
,
"ckpt-epoch-3001.pth"
))
for
epoch
in
range
(
epoch_begin
,
NUM_EPOCH
):
lf_model
.
load_state_dict
(
checkpoint
[
"model_state_dict"
])
for
batch_idx
,
(
image_set
,
gt
,
df
,
gazeX
,
gazeY
,
sample_idx
)
in
enumerate
(
train_data_loader
):
lf_model
.
eval
()
print
(
"Eval::"
)
for
sample_idx
,
(
image_set
,
gt
,
df
)
in
enumerate
(
val_data_loader
):
print
(
"sample_idx::"
)
with
torch
.
no_grad
():
#reshape for input
#reshape for input
image_set
=
image_set
.
permute
(
0
,
1
,
4
,
2
,
3
)
# N LF C H W
image_set
=
image_set
.
permute
(
0
,
1
,
4
,
2
,
3
)
# N LF C H W
image_set
=
image_set
.
reshape
(
image_set
.
shape
[
0
],
-
1
,
image_set
.
shape
[
3
],
image_set
.
shape
[
4
])
# N, LFxC, H, W
image_set
=
image_set
.
reshape
(
image_set
.
shape
[
0
],
-
1
,
image_set
.
shape
[
3
],
image_set
.
shape
[
4
])
# N, LFxC, H, W
image_set
=
var_or_cuda
(
image_set
)
image_set
=
var_or_cuda
(
image_set
)
# image_set.to(device)
gt
=
gt
.
permute
(
0
,
3
,
1
,
2
)
gt
=
gt
.
permute
(
0
,
3
,
1
,
2
)
gt
=
var_or_cuda
(
gt
)
gt
=
var_or_cuda
(
gt
)
# print("Epoch:",epoch,",Iter:",batch_idx,",Input shape:",image_set.shape, ",Input gt:",gt.shape)
output
=
lf_model
(
image_set
,
df
)
print
(
"output:"
,
output
.
shape
,
" df:"
,
df
)
save_image
(
output
[
0
][
0
:
3
].
data
,
os
.
path
.
join
(
OUTPUT_DIR
,
"1113_interp_l1_%.3f.png"
%
(
df
[
0
].
data
)))
save_image
(
output
[
0
][
3
:
6
].
data
,
os
.
path
.
join
(
OUTPUT_DIR
,
"1113_interp_l2_%.3f.png"
%
(
df
[
0
].
data
)))
output
=
GenRetinalFromLayersBatch
(
output
,
conf
,
df
,
v
,
u
)
save_image
(
output
[
0
][
0
:
3
].
data
,
os
.
path
.
join
(
OUTPUT_DIR
,
"1113_interp_o%.3f.png"
%
(
df
[
0
].
data
)))
exit
()
# train
# print(lf_model)
# exit()
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# lf_model = model()
# lf_model.apply(weight_init_normal)
# if torch.cuda.is_available():
# lf_model = torch.nn.DataParallel(lf_model).cuda()
# lf_model.train()
# optimizer = torch.optim.Adam(lf_model.parameters(),lr=5e-2,betas=(0.9,0.999))
# for epoch in range(NUM_EPOCH):
# for batch_idx, (image_set, gt, df) in enumerate(train_data_loader):
# #reshape for input
# image_set = image_set.permute(0,1,4,2,3) # N LF C H W
# image_set = image_set.reshape(image_set.shape[0],-1,image_set.shape[3],image_set.shape[4]) # N, LFxC, H, W
# image_set = var_or_cuda(image_set)
# # image_set.to(device)
# gt = gt.permute(0,3,1,2)
# gt = var_or_cuda(gt)
# # print("Epoch:",epoch,",Iter:",batch_idx,",Input shape:",image_set.shape, ",Input gt:",gt.shape)
# optimizer.zero_grad()
# output = lf_model(image_set,df)
# # print("output:",output.shape," df:",df.shape)
# output = GenRetinalFromLayersBatch(output,conf,df,v,u)
# loss = loss_new(output,gt)
# print("Epoch:",epoch,",Iter:",batch_idx,",loss:",loss)
# loss.backward()
# optimizer.step()
# if (epoch%100 == 0):
# for i in range(BATCH_SIZE):
# save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"cuda_lr_5e-2_mul_dip_newloss_debug_conf_o%d_%d.png"%(epoch,i)))
# if (epoch%1000 == 0):
# save_checkpoints(os.path.join(OUTPUT_DIR, 'ckpt-epoch-%04d.pth' % (epoch + 1)),
# epoch,lf_model,optimizer)
optimizer
.
zero_grad
()
output
=
lf_model
(
image_set
,
df
,
gazeX
,
gazeY
)
########################### Use Pregen Phi and Mask ###################
output1
,
mask
=
GenRetinalFromLayersBatch
(
output
,
gen
,
sample_idx
,
phi_dict
,
mask_dict
)
mask
=
var_or_cuda
(
mask
)
mask
.
requires_grad
=
False
output_f
=
output1
*
mask
gt
=
gt
*
mask
loss
=
loss_new
(
output_f
,
gt
)
print
(
"Epoch:"
,
epoch
,
",Iter:"
,
batch_idx
,
",loss:"
,
loss
)
########################### Update ###################
loss
.
backward
()
optimizer
.
step
()
########################### Save #####################
if
((
epoch
%
50
==
0
)
or
epoch
==
5
):
for
i
in
range
(
output_f
.
size
()[
0
]):
save_image
(
output
[
i
][
0
:
3
].
data
,
os
.
path
.
join
(
OUTPUT_DIR
,
"gaze_fac1_o_%.3f_%.3f_%.3f.png"
%
(
df
[
i
].
data
,
gazeX
[
i
].
data
,
gazeY
[
i
].
data
)))
save_image
(
output
[
i
][
3
:
6
].
data
,
os
.
path
.
join
(
OUTPUT_DIR
,
"gaze_fac2_o_%.3f_%.3f_%.3f.png"
%
(
df
[
i
].
data
,
gazeX
[
i
].
data
,
gazeY
[
i
].
data
)))
save_image
(
output_f
[
i
][
0
:
3
].
data
,
os
.
path
.
join
(
OUTPUT_DIR
,
"gaze_test1_o_%.3f_%.3f_%.3f.png"
%
(
df
[
i
].
data
,
gazeX
[
i
].
data
,
gazeY
[
i
].
data
)))
if
((
epoch
%
200
==
0
)
and
epoch
!=
0
and
batch_idx
==
len
(
train_data_loader
)
-
1
):
save_checkpoints
(
os
.
path
.
join
(
OUTPUT_DIR
,
'gaze-ckpt-epoch-%04d.pth'
%
(
epoch
+
1
)),
epoch
,
lf_model
,
optimizer
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment