Nianchen Deng / deeplightfield · Commits

Commit 5069f8ae, authored 4 years ago by BobYeah
Parent: 055dc0bb

    Gaze

Showing 2 changed files, with 289 additions and 223 deletions:
    gen_image.py  (+110, -54)
    main.py       (+179, -169)

gen_image.py (+110, -54), view file @ 5069f8ae
 import matplotlib.pyplot as plt
 import numpy as np
 import torch
+import glm
 
-def RandomGenSamplesInPupil(conf, n_samples):
+def Fov2Length(angle):
+    '''
+    '''
+    return np.tan(angle * np.pi / 360) * 2
+
+def RandomGenSamplesInPupil(pupil_size, n_samples):
     '''
     Random sample n_samples positions in pupil region
@@ -18,14 +26,14 @@ def RandomGenSamplesInPupil(conf, n_samples):
     samples = torch.empty(n_samples, 2)
     i = 0
     while i < n_samples:
-        s = (torch.rand(2) - 0.5) * conf.pupil_size
-        if np.linalg.norm(s) > conf.pupil_size / 2.:
+        s = (torch.rand(2) - 0.5) * pupil_size
+        if np.linalg.norm(s) > pupil_size / 2.:
             continue
-        samples[i, :] = s
+        samples[i, :] = [s[0], s[1], 0]
         i += 1
     return samples
-def GenSamplesInPupil(conf, circles):
+def GenSamplesInPupil(pupil_size, circles):
     '''
     Sample positions on circles in pupil region
@@ -38,68 +46,116 @@ def GenSamplesInPupil(conf, circles):
     --------
     a n_samples x 2 tensor with 2D sample position in each row
     '''
-    samples = torch.tensor([[0., 0.]])
+    samples = torch.zeros(1, 3)
     for i in range(1, circles):
-        r = conf.pupil_size / 2. / (circles - 1) * i
+        r = pupil_size / 2. / (circles - 1) * i
         n = 4 * i
         for j in range(0, n):
             angle = 2 * np.pi / n * j
-            samples = torch.cat((samples, torch.tensor([[r * np.cos(angle), r * np.sin(angle)]])), dim=0)
+            samples = torch.cat([samples, torch.tensor([[r * np.cos(angle), r * np.sin(angle), 0]])], 0)
     return samples
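
Note: with circles = 5 the circular sampler returns 1 + 4 + 8 + 12 + 16 = 41 positions (the center plus four rings), which is where the "M ... #41" figures in later comments come from. RandomGenSamplesInPupil, by contrast, still allocates torch.empty(n_samples, 2) while assigning 3-component rows, so only the circular sampler is consistent with the new 3D sample format. A minimal usage sketch, assuming the new signature:

    u = GenSamplesInPupil(0.02, 5)  # pupil_size = 0.02 m, 5 circles
    print(u.shape)                  # torch.Size([41, 3]); z = 0 for every sample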
-def GenRetinal2LayerMappings(conf, df, v, u):
-    '''
-    Generate the mapping matrix from retinal to layers.
-
-    Parameters
-    --------
-    conf - multi-layers' parameters configuration
-    df - focal distance
-    u - a M x 3 tensor stores M sample positions in pupil
-    v - a 1 x 2 tensor stores half viewport
-
-    Returns
-    --------
-    The mapping matrix
-    '''
-    H_r = conf.retinal_res[0]
-    W_r = conf.retinal_res[1]
-    D_r = conf.retinal_res.double()
-    N = conf.n_layers
-    M = u.size()[0] #41
-    Phi = torch.empty(H_r, W_r, N, M, 2, dtype=torch.long)
-    p_rx, p_ry = torch.meshgrid(torch.tensor(range(0, H_r)),
-                                torch.tensor(range(0, W_r)))
-    p_r = torch.stack([p_rx, p_ry], 2).unsqueeze(2).expand(-1, -1, M, -1)
-    # print(p_r.shape) #torch.Size([480, 640, 41, 2])
-    for i in range(0, N):
-        dpi = conf.h_layer[i] / conf.layer_res[0] # 1 / 480
-        ci = conf.layer_res / 2 # [240,320]
-        di = conf.d_layer[i] # depth
-        pi_r = di * v * (1. / D_r * (p_r + 0.5) - 0.5) / dpi # [480, 640, 41, 2]
-        wi = (1 - di / df) / dpi # (1 - depth/focus) / dpi; df = 2.625, di = 1.75
-        pi = torch.floor(pi_r + ci + wi * u)
-        torch.clamp_(pi[:, :, :, 0], 0, conf.layer_res[0] - 1)
-        torch.clamp_(pi[:, :, :, 1], 0, conf.layer_res[1] - 1)
-        Phi[:, :, i, :, :] = pi
-    return Phi
+class RetinalGen(object):
+    '''
+    Class for retinal generation process
+
+    Properties
+    --------
+    conf - multi-layers' parameters configuration
+    u    - M x 3 tensor, M sample positions in pupil
+    p_r  - H_r x W_r x 3 tensor, retinal pixel grid, [H_r, W_r] is the retinal resolution
+    Phi  - N x H_r x W_r x M x 2 tensor, retinal to layers mapping, N is number of layers
+    mask - N x H_r x W_r x M x 2 tensor, indicates invalid (out-of-range) mapping
+
+    Methods
+    --------
+    '''
+    def __init__(self, conf, u):
+        '''
+        Initialize retinal generator instance
+
+        Parameters
+        --------
+        conf - multi-layers' parameters configuration
+        u - a M x 2 tensor stores M sample positions on pupil
+        '''
+        self.conf = conf
+        # self.u = u.to(cuda_dev)
+        self.u = u # M x 3, M sample positions
+        self.D_r = conf.retinal_res # retinal res 480 x 640
+        self.N = conf.GetNLayers() # 2
+        self.M = u.size()[0] # samples
+        p_rx, p_ry = torch.meshgrid(torch.tensor(range(0, self.D_r[0])),
+                                    torch.tensor(range(0, self.D_r[1])))
+        self.p_r = torch.cat([
+            ((torch.stack([p_rx, p_ry], 2) + 0.5) / self.D_r - 0.5) * conf.GetEyeViewportSize(), # eye's viewport
+            torch.ones(self.D_r[0], self.D_r[1], 1)
+        ], 2)
+        # self.Phi = torch.empty(N, D_r[0], D_r[1], M, 2, device=cuda_dev, dtype=torch.long)
+        # self.mask = torch.empty(self.N, self.D_r[0], self.D_r[1], self.M, 2, dtype=torch.float) # 2 x 480 x 640 x 41 x 2
+
+    def CalculateRetinal2LayerMappings(self, df, gaze):
+        '''
+        Calculate the mapping matrix from retinal to layers.
+
+        Parameters
+        --------
+        df   - focus distance
+        gaze - 2 x 1 tensor, eye rotation angle (degs) in horizontal and vertical direction
+        '''
+        Phi = torch.empty(self.N, self.D_r[0], self.D_r[1], self.M, 2, dtype=torch.long) # 2 x 480 x 640 x 41 x 2
+        mask = torch.empty(self.N, self.D_r[0], self.D_r[1], self.M, 2, dtype=torch.float)
+        D_r = self.conf.retinal_res        # D_r: Resolution of retinal 480 640
+        V = self.conf.GetEyeViewportSize() # V: Viewport size of eye
+        c = (self.conf.layer_res / 2)      # c: Center of layers (pixel)
+        p_f = self.p_r * df                # p_f: H x W x 3, focus positions of retinal pixels on focus plane
+        rot_forward = glm.dvec3(glm.tan(glm.radians(glm.dvec2(gaze[1], -gaze[0]))), 1)
+        rot_mat = torch.from_numpy(np.array(
+            glm.dmat3(glm.lookAtLH(glm.dvec3(), rot_forward, glm.dvec3(0, 1, 0)))))
+        rot_mat = rot_mat.float()
+        u_rot = torch.mm(self.u, rot_mat)
+        v_rot = torch.matmul(p_f, rot_mat).unsqueeze(2).expand(
+            -1, -1, self.u.size()[0], -1) - u_rot # v_rot: H x W x M x 3, rotated rays' direction vector
+        v_rot.div_(v_rot[:, :, :, 2].unsqueeze(3)) # make z = 1 for each direction vector in v_rot
+        for i in range(0, self.conf.GetNLayers()):
+            dp_i = self.conf.GetLayerSize(i)[0] / self.conf.layer_res[0] # dp_i: Pixel size of layer i
+            d_i = self.conf.d_layer[i]                                   # d_i: Distance of layer i
+            k = (d_i - u_rot[:, 2]).unsqueeze(1)
+            pi_r = (u_rot[:, 0:2] + v_rot[:, :, :, 0:2] * k) / dp_i # pi_r: H x W x M x 2, rays' pixel coord on layer i
+            Phi[i, :, :, :, :] = torch.floor(pi_r + c)
+        mask[:, :, :, :, 0] = ((Phi[:, :, :, :, 0] >= 0) & (Phi[:, :, :, :, 0] < self.conf.layer_res[0])).float()
+        mask[:, :, :, :, 1] = ((Phi[:, :, :, :, 1] >= 0) & (Phi[:, :, :, :, 1] < self.conf.layer_res[1])).float()
+        Phi[:, :, :, :, 0].clamp_(0, self.conf.layer_res[0] - 1)
+        Phi[:, :, :, :, 1].clamp_(0, self.conf.layer_res[1] - 1)
+        retinal_mask = mask.prod(0).prod(2).prod(2)
+        return [Phi, retinal_mask]
 
-def GenRetinalFromLayers(layers, Phi):
-    # layers: 2, color, height, width
-    # Phi:torch.Size([480, 640, 2, 41, 2])
-    M = Phi.size()[3] # 41
-    N = Phi.size()[2] # 2
-    # print(layers.shape)# torch.Size([2, 3, 480, 640])
-    # print(Phi.shape)# torch.Size([480, 640, 2, 41, 2])
-    # retinal image: 3channels x retinal_size
-    retinal = torch.zeros(3, Phi.size()[0], Phi.size()[1])
-    for j in range(0, M):
-        retinal_view = torch.zeros(3, Phi.size()[0], Phi.size()[1])
-        for i in range(0, N):
-            retinal_view.add_(layers[i, :, Phi[:, :, i, j, 0], Phi[:, :, i, j, 1]])
-        retinal.add_(retinal_view)
-    retinal.div_(M)
-    return retinal
+    def GenRetinalFromLayers(self, layers, Phi):
+        '''
+        Generate retinal image from layers, using precalculated mapping matrix
+
+        Parameters
+        --------
+        layers - 3N x H_l x W_l tensor, stacked layer images, with 3 channels in each layer
+
+        Returns
+        --------
+        3 x H_r x W_r tensor, 3 channels retinal image
+        H_r x W_r tensor, retinal image mask, indicates pixels valid or not
+        '''
+        # FOR GRAYSCALE 1 FOR RGB 3
+        mapped_layers = torch.empty(self.N, 3, self.D_r[0], self.D_r[1], self.M) # 2 x 3 x 480 x 640 x 41
+        # print("mapped_layers:",mapped_layers.shape)
+        for i in range(0, Phi.size()[0]):
+            # print("gather layers:",layers[(i * 3) : (i * 3 + 3),Phi[i, :, :, :, 0],Phi[i, :, :, :, 1]].shape)
+            mapped_layers[i, :, :, :, :] = layers[(i * 3):(i * 3 + 3),
+                                                  Phi[i, :, :, :, 0],
+                                                  Phi[i, :, :, :, 1]]
+        # print("mapped_layers:",mapped_layers.shape)
+        retinal = mapped_layers.prod(0).sum(3).div(Phi.size()[3])
+        # print("retinal:",retinal.shape)
+        return retinal
\ No newline at end of file
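
Note: taken together, the commit replaces the one-shot GenRetinal2LayerMappings/GenRetinalFromLayers pair with a reusable generator object, so the per-gaze mapping can be cached and reused. A minimal sketch of the intended call sequence (conf stands for the Conf instance defined in main.py below; the df value merely echoes the 2.625 from the old comments, and layers is a 3N x H_l x W_l tensor from the network):

    u = GenSamplesInPupil(conf.pupil_size, 5)        # 41 x 3 pupil samples
    gen = RetinalGen(conf, u)
    phi, mask = gen.CalculateRetinal2LayerMappings(
        2.625,                                       # focus distance df
        torch.tensor([0.0, 0.0]))                    # gaze angles in degrees
    retinal = gen.GenRetinalFromLayers(layers, phi)  # 3 x H_r x W_r image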
main.py (+179, -169), view file @ 5069f8ae

@@ -16,55 +16,64 @@ import json
 from ssim import *
 from perc_loss import *
 # param
-BATCH_SIZE = 5
-NUM_EPOCH = 5000
+BATCH_SIZE = 16
+NUM_EPOCH = 1000
 INTERLEAVE_RATE = 2
-IM_H = 480
-IM_W = 640
+IM_H = 320
+IM_W = 320
+Retinal_IM_H = 320
+Retinal_IM_W = 320
 N = 9 # number of input light field stack
 M = 2 # number of display layers
-DATA_FILE = "/home/yejiannan/Project/LightField/data/try"
-DATA_JSON = "/home/yejiannan/Project/LightField/data/data.json"
+DATA_FILE = "/home/yejiannan/Project/LightField/data/gaze_small_nar_new"
+DATA_JSON = "/home/yejiannan/Project/LightField/data/data_gaze_low_new.json"
 DATA_VAL_JSON = "/home/yejiannan/Project/LightField/data/data_val.json"
-OUTPUT_DIR = "/home/yejiannan/Project/LightField/output"
+OUTPUT_DIR = "/home/yejiannan/Project/LightField/output/gaze_low_new_1125_minibatch"
 
 class lightFieldDataLoader(torch.utils.data.dataset.Dataset):
     def __init__(self, file_dir_path, file_json, transforms=None):
         self.file_dir_path = file_dir_path
         self.transforms = transforms
         # self.datum_list = glob.glob(os.path.join(file_dir_path,"*"))
         with open(file_json, encoding='utf-8') as file:
-            self.dastset_desc = json.loads(file.read())
+            self.dataset_desc = json.loads(file.read())
 
     def __len__(self):
-        return len(self.dastset_desc["focaldepth"])
+        return len(self.dataset_desc["focaldepth"])
 
     def __getitem__(self, idx):
-        lightfield_images, gt, fd = self.get_datum(idx)
+        lightfield_images, gt, fd, gazeX, gazeY, sample_idx = self.get_datum(idx)
         if self.transforms:
             lightfield_images = self.transforms(lightfield_images)
-        return (lightfield_images, gt, fd)
+        return (lightfield_images, gt, fd, gazeX, gazeY, sample_idx)
 
     def get_datum(self, idx):
-        lf_image_paths = os.path.join(DATA_FILE, self.dastset_desc["train"][idx])
+        lf_image_paths = os.path.join(DATA_FILE, self.dataset_desc["train"][idx])
         # print(lf_image_paths)
-        fd_gt_path = os.path.join(DATA_FILE, self.dastset_desc["gt"][idx])
+        fd_gt_path = os.path.join(DATA_FILE, self.dataset_desc["gt"][idx])
         # print(fd_gt_path)
         lf_images = []
         lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
         lf_image_big = cv2.cvtColor(lf_image_big, cv2.COLOR_BGR2RGB)
         for i in range(9):
             lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H, i%3*IM_W:i%3*IM_W+IM_W, 0:3]
             ## IF GrayScale
             # lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1]
             # print(lf_image.shape)
             lf_images.append(lf_image)
         gt = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
         gt = cv2.cvtColor(gt, cv2.COLOR_BGR2RGB)
-        fd = self.dastset_desc["focaldepth"][idx]
-        return (np.asarray(lf_images), gt, fd)
+        ## IF GrayScale
+        # gt = cv2.imread(fd_gt_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
+        # gt = np.expand_dims(gt, axis=-1)
+        fd = self.dataset_desc["focaldepth"][idx]
+        gazeX = self.dataset_desc["gazeX"][idx]
+        gazeY = self.dataset_desc["gazeY"][idx]
+        sample_idx = self.dataset_desc["idx"][idx]
+        return np.asarray(lf_images), gt, fd, gazeX, gazeY, sample_idx
 
 OUT_CHANNELS_RB = 128
 KERNEL_SIZE_RB = 3
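
Note: get_datum() now reads four extra per-sample fields from the dataset description, so the JSON behind DATA_JSON must carry them. A hypothetical fragment mirroring the keys the loader reads (file names and values invented purely for illustration):

    dataset_desc = {
        "train": ["gaze_0000.png"],   # 3x3 light-field grid of IM_H x IM_W tiles
        "gt": ["gt_0000.png"],        # retinal ground-truth image
        "focaldepth": [2.625],        # focus distance per sample
        "gazeX": [1.2],               # gaze angles (degrees)
        "gazeY": [-0.7],
        "idx": [0],                   # key into the precomputed phi/mask dicts
    }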
@@ -128,7 +137,6 @@ class interleave(torch.nn.Module):
         output = output.permute(0, 3, 1, 2)
         return output
 
-
 LAST_LAYER_CHANNELS = 6 * INTERLEAVE_RATE**2
 FIRSST_LAYER_CHANNELS = 27 * INTERLEAVE_RATE**2
@@ -144,37 +152,39 @@ class model(torch.nn.Module):
         )
         self.residual_block1 = residual_block(0)
-        self.residual_block2 = residual_block(1)
-        self.residual_block3 = residual_block(1)
+        self.residual_block2 = residual_block(3)
+        self.residual_block3 = residual_block(3)
+        self.residual_block4 = residual_block(3)
+        self.residual_block5 = residual_block(3)
         self.output_layer = torch.nn.Sequential(
-            torch.nn.Conv2d(OUT_CHANNELS_RB + 1, LAST_LAYER_CHANNELS, KERNEL_SIZE, stride=1, padding=1),
+            torch.nn.Conv2d(OUT_CHANNELS_RB + 3, LAST_LAYER_CHANNELS, KERNEL_SIZE, stride=1, padding=1),
             torch.nn.BatchNorm2d(LAST_LAYER_CHANNELS),
             torch.nn.Sigmoid()
         )
         self.deinterleave = deinterleave(INTERLEAVE_RATE)
 
-    def forward(self, lightfield_images, focal_length):
+    def forward(self, lightfield_images, focal_length, gazeX, gazeY):
         # lightfield_images: torch.Size([batch_size, channels * D, H, W])
         # channels : RGB*D: 3*9, H:256, W:256
         input_to_net = self.interleave(lightfield_images)
         # print("after interleave:",input_to_net.shape)
         input_to_rb = self.first_layer(input_to_net)
         output = self.residual_block1(input_to_rb)
-        depth_layer = torch.ones((input_to_rb.shape[0], 1, input_to_rb.shape[2], input_to_rb.shape[3]))
+        # print("output1:",output.shape)
+        depth_layer = torch.ones((output.shape[0], 1, output.shape[2], output.shape[3]))
+        gazeX_layer = torch.ones((input_to_rb.shape[0], 1, input_to_rb.shape[2], input_to_rb.shape[3]))
+        gazeY_layer = torch.ones((input_to_rb.shape[0], 1, input_to_rb.shape[2], input_to_rb.shape[3]))
         # print(df.shape[0])
         for i in range(focal_length.shape[0]):
-            depth_layer[i] = 1. / focal_length[i]
+            depth_layer[i] *= 1. / focal_length[i]
+            gazeX_layer[i] *= (gazeX[i] - (-3.333)) / (3.333 * 2)
+            gazeY_layer[i] *= (gazeY[i] - (-3.333)) / (3.333 * 2)
             # print(depth_layer.shape)
         depth_layer = var_or_cuda(depth_layer)
-        output = torch.cat((output, depth_layer), dim=1)
+        gazeX_layer = var_or_cuda(gazeX_layer)
+        gazeY_layer = var_or_cuda(gazeY_layer)
+        output = torch.cat((output, depth_layer, gazeX_layer, gazeY_layer), dim=1)
         output = self.residual_block2(output)
         output = self.residual_block3(output)
         # output = output + input_to_net
+        output = self.residual_block4(output)
+        output = self.residual_block5(output)
         output = self.output_layer(output)
         output = self.deinterleave(output)
         return output
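
Note: the forward pass now conditions on gaze as well as focus depth: each scalar is broadcast to a constant feature plane and concatenated, which is why the output layer grows from OUT_CHANNELS_RB + 1 to OUT_CHANNELS_RB + 3 input channels. The gaze planes rescale an assumed [-3.333, 3.333] degree range into [0, 1]; a standalone check of that affine map:

    def normalize_gaze(g):
        # same expression as in forward(): (g - (-3.333)) / (3.333 * 2)
        return (g - (-3.333)) / (3.333 * 2)

    print(normalize_gaze(-3.333), normalize_gaze(0.0), normalize_gaze(3.333))
    # 0.0 0.5 1.0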
@@ -182,72 +192,65 @@ class model(torch.nn.Module):
 class Conf(object):
     def __init__(self):
         self.pupil_size = 0.02 # 2cm
-        self.retinal_res = torch.tensor([480, 640])
-        self.layer_res = torch.tensor([480, 640])
-        self.n_layers = 2
-        self.d_layer = [1., 3.] # layers' distance
-        self.h_layer = [1. * 480. / 640., 3. * 480. / 640.] # layers' height
+        self.retinal_res = torch.tensor([Retinal_IM_H, Retinal_IM_W])
+        self.layer_res = torch.tensor([IM_H, IM_W])
+        self.layer_hfov = 90 # layers' horizontal FOV
+        self.eye_hfov = 85 # eye's horizontal FOV
+        self.d_layer = [1, 3] # layers' distance
+
+    def GetNLayers(self):
+        return len(self.d_layer)
+
+    def GetLayerSize(self, i):
+        w = Fov2Length(self.layer_hfov)
+        h = w * self.layer_res[0] / self.layer_res[1]
+        return torch.tensor([h, w]) * self.d_layer[i]
+
+    def GetEyeViewportSize(self):
+        w = Fov2Length(self.eye_hfov)
+        h = w * self.retinal_res[0] / self.retinal_res[1]
+        return torch.tensor([h, w])
 
 #### Image Gen
 conf = Conf()
-v = torch.tensor([conf.h_layer[0] / conf.d_layer[0],
-                  conf.h_layer[0] / conf.d_layer[0] * conf.layer_res[1] / conf.layer_res[0]])
-u = GenSamplesInPupil(conf, 5)
+u = GenSamplesInPupil(conf.pupil_size, 5)
+gen = RetinalGen(conf, u)
 
-def GenRetinalFromLayersBatch(layers, conf, df, v, u):
-    # layers: batchsize, 2, color, height, width
+def GenRetinalFromLayersBatch(layers, gen, sample_idx, phi_dict, mask_dict):
+    # layers: batchsize, 2 * color, height, width
     # Phi:torch.Size([batchsize, 480, 640, 2, 41, 2])
     # df : batchsize,..
-    H_r = conf.retinal_res[0]
-    W_r = conf.retinal_res[1]
-    D_r = conf.retinal_res.double()
-    N = conf.n_layers
-    M = u.size()[0] #41
-    BS = df.shape[0]
-    Phi = torch.empty(BS, H_r, W_r, N, M, 2, dtype=torch.long)
-    # print("Phi:",Phi.shape)
-    p_rx, p_ry = torch.meshgrid(torch.tensor(range(0, H_r)),
-                                torch.tensor(range(0, W_r)))
-    p_r = torch.stack([p_rx, p_ry], 2).unsqueeze(2).expand(-1, -1, M, -1)
-    # print("p_r:",p_r.shape) #torch.Size([480, 640, 41, 2])
-    for bs in range(BS):
-        for i in range(0, N):
-            dpi = conf.h_layer[i] / float(conf.layer_res[0]) # 1 / 480
-            # print("dpi:",dpi)
-            ci = conf.layer_res / 2 # [240,320]
-            di = conf.d_layer[i] # depth
-            pi_r = di * v * (1. / D_r * (p_r + 0.5) - 0.5) / dpi # [480, 640, 41, 2]
-            wi = (1 - di / df[bs]) / dpi # (1 - depth/focus) / dpi; df = 2.625, di = 1.75
-            pi = torch.floor(pi_r + ci + wi * u)
-            torch.clamp_(pi[:, :, :, 0], 0, conf.layer_res[0] - 1)
-            torch.clamp_(pi[:, :, :, 1], 0, conf.layer_res[1] - 1)
-            Phi[bs, :, :, i, :, :] = pi
-    # print("Phi slice:",Phi[0, :, :, 0, 0, 0].shape)
-    retinal = torch.ones(BS, 3, H_r, W_r)
-    retinal = var_or_cuda(retinal)
-    for bs in range(BS):
-        for j in range(0, M):
-            retinal_view = torch.ones(3, H_r, W_r)
-            retinal_view = var_or_cuda(retinal_view)
-            for i in range(0, N):
-                retinal_view.mul_(layers[bs, (i * 3):(i * 3 + 3), Phi[bs, :, :, i, j, 0], Phi[bs, :, :, i, j, 1]])
-            retinal[bs, :, :, :].add_(retinal_view)
-        retinal[bs, :, :, :].div_(M)
-    return retinal
-#### Image Gen End
+    # retinal bs x color x height x width
+    retinal = torch.zeros(layers.shape[0], 3, Retinal_IM_H, Retinal_IM_W)
+    mask = [] # mask shape 480 x 640
+    for i in range(0, layers.size()[0]):
+        phi = phi_dict[int(sample_idx[i].data)]
+        phi = var_or_cuda(phi)
+        phi.requires_grad = False
+        retinal[i] = gen.GenRetinalFromLayers(layers[i], phi)
+        mask.append(mask_dict[int(sample_idx[i].data)])
+    retinal = var_or_cuda(retinal)
+    mask = torch.stack(mask, dim=0).unsqueeze(1) # batch x 1 x height x width
+    return retinal, mask
 
-def merge_two(near, far):
-    df = conf.d_layer[0] + (conf.d_layer[1] - conf.d_layer[0]) / 2.
-    # Phi = GenRetinal2LayerMappings(conf, df, v, u)
-    # retinal = GenRetinalFromLayers(layers, Phi)
-    return near[:, 0:3, :, :] + far[:, 3:6, :, :] / 2.0
-
-def loss_two_images(generated, gt):
-    l1_loss = torch.nn.L1Loss()
-    return l1_loss(generated, gt)
+def GenRetinalFromLayersBatch_Online(layers, gen, phi, mask):
+    # layers: batchsize, 2*color, height, width
+    # Phi:torch.Size([batchsize, 480, 640, 2, 41, 2])
+    # df : batchsize,..
+    # retinal bs x color x height x width
+    # retinal = torch.zeros(layers.shape[0], 3, Retinal_IM_H, Retinal_IM_W)
+    # retinal = var_or_cuda(retinal)
+    phi = var_or_cuda(phi)
+    phi.requires_grad = False
+    retinal = gen.GenRetinalFromLayers(layers[0], phi)
+    retinal = var_or_cuda(retinal)
+    mask_out = mask.unsqueeze(0).unsqueeze(0)
+    return retinal.unsqueeze(0), mask_out
+#### Image Gen End
 
 weightVarScale = 0.25
 bias_stddev = 0.01
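
Note: Fov2Length(angle) computes 2 * tan(angle / 2), since angle * pi / 360 is half the angle in radians, i.e. the width subtended by a full field of view at unit distance. A quick check of the new Conf geometry (values follow directly from the constants above):

    w = Fov2Length(90)   # 2 * tan(45 deg) = 2.0
    # Layers are square here (IM_H == IM_W == 320), so GetLayerSize(i)
    # is [2, 2] * d_layer[i]: 2 m x 2 m at d = 1 m, 6 m x 6 m at d = 3 m.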
@@ -269,7 +272,7 @@ def var_or_cuda(x):
 def calImageGradients(images):
     # x is a 4-D tensor
     dx = images[:, :, 1:, :] - images[:, :, :-1, :]
-    dy = images[:, 1:, :, :] - images[:, :-1, :, :]
+    dy = images[:, :, :, 1:] - images[:, :, :, :-1]
     return dx, dy
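
Note: this is a genuine fix, not a refactor: for an N x C x H x W batch the old dy differenced along dim 1, the channel axis, while the new expression differences along the width. A shape check under the NCHW assumption:

    x = torch.rand(2, 3, 320, 320)  # N x C x H x W
    dx, dy = calImageGradients(x)
    print(dx.shape)                 # torch.Size([2, 3, 319, 320]), along H
    print(dy.shape)                 # torch.Size([2, 3, 320, 319]), along W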
@@ -279,16 +282,13 @@ perc_loss = perc_loss.to("cuda")
 def loss_new(generated, gt):
     mse_loss = torch.nn.MSELoss()
     rmse_intensity = mse_loss(generated, gt)
-    RENORM_SCALE = torch.tensor(0.9)
-    RENORM_SCALE = var_or_cuda(RENORM_SCALE)
     psnr_intensity = torch.log10(rmse_intensity)
-    ssim_intensity = ssim(generated, gt)
     labels_dx, labels_dy = calImageGradients(gt)
     preds_dx, preds_dy = calImageGradients(generated)
     rmse_grad_x, rmse_grad_y = mse_loss(labels_dx, preds_dx), mse_loss(labels_dy, preds_dy)
     psnr_grad_x, psnr_grad_y = torch.log10(rmse_grad_x), torch.log10(rmse_grad_y)
     p_loss = perc_loss(generated, gt)
     # print("-psnr:",-psnr_intensity,",0.5*(psnr_grad_x + psnr_grad_y):",0.5*(psnr_grad_x + psnr_grad_y),",perc_loss:",p_loss)
     total_loss = 10 + psnr_intensity + 0.5 * (psnr_grad_x + psnr_grad_y) + p_loss
     return total_loss
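
Note: despite the psnr_ names, these terms are log10 of the MSE, not a full PSNR (which would be -10 * log10(MSE) for images in [0, 1]); minimizing log10(MSE) still pushes MSE toward zero, and the constant 10 appears to just keep the total positive. A minimal numeric check:

    mse = torch.tensor(0.01)
    print(torch.log10(mse))   # tensor(-2.): shrinking MSE drives the sum down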
@@ -301,15 +301,56 @@ def save_checkpoints(file_path, epoch_idx, model, model_solver):
     }
     torch.save(checkpoint, file_path)
 
-mode = "val"
+def hook_fn_back(m, i, o):
+    for grad in i:
+        try:
+            print("Input Grad:", m, grad.shape, grad.sum())
+        except AttributeError:
+            print("None found for Gradient")
+    for grad in o:
+        try:
+            print("Output Grad:", m, grad.shape, grad.sum())
+        except AttributeError:
+            print("None found for Gradient")
+    print("\n")
+
+def hook_fn_for(m, i, o):
+    for grad in i:
+        try:
+            print("Input Feats:", m, grad.shape, grad.sum())
+        except AttributeError:
+            print("None found for Gradient")
+    for grad in o:
+        try:
+            print("Output Feats:", m, grad.shape, grad.sum())
+        except AttributeError:
+            print("None found for Gradient")
+    print("\n")
+
 if __name__ == "__main__":
     #test
     # train_dataset = lightFieldDataLoader(DATA_FILE,DATA_JSON)
     # print(train_dataset[0][0].shape)
     # cv2.imwrite("test_crop0.png",train_dataset[0][1]*255.)
     # save_image(output[0][0:3].data,os.path.join(OUTPUT_DIR,"o%d_%d.png"%(epoch,batch_idx)))
     #test end
+
+    ############################## generate phi and mask in pre-training
+    phi_dict = {}
+    mask_dict = {}
+    idx_info_dict = {}
+    print("generating phi and mask...")
+    with open(DATA_JSON, encoding='utf-8') as file:
+        dataset_desc = json.loads(file.read())
+        for i in range(len(dataset_desc["focaldepth"])):
+            # if i == 2:
+            #     break
+            idx = dataset_desc["idx"][i]
+            focaldepth = dataset_desc["focaldepth"][i]
+            gazeX = dataset_desc["gazeX"][i]
+            gazeY = dataset_desc["gazeY"][i]
+            # print("focaldepth:",focaldepth," idx:",idx," gazeX:",gazeX," gazeY:",gazeY)
+            phi, mask = gen.CalculateRetinal2LayerMappings(focaldepth, torch.tensor([gazeX, gazeY]))
+            phi_dict[idx] = phi
+            mask_dict[idx] = mask
+            idx_info_dict[idx] = [idx, focaldepth, gazeX, gazeY]
+    print("generating phi and mask end.")
+    # exit(0)
+
     #train
     train_data_loader = torch.utils.data.DataLoader(dataset=lightFieldDataLoader(DATA_FILE, DATA_JSON),
                                                     batch_size=BATCH_SIZE,
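
Note: precomputing one (phi, mask) pair per dataset sample trades memory for speed; each phi alone is an N x H_r x W_r x M x 2 tensor of torch.long, so at the resolutions in this commit the cache grows quickly:

    elems = 2 * 320 * 320 * 41 * 2   # N * H_r * W_r * M * 2
    print(elems, elems * 8 / 1e6)    # 16793600 values, ~134 MB per sample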
@@ -319,82 +360,51 @@ if __name__ == "__main__":
                                                     drop_last=False)
     print(len(train_data_loader))
-    val_data_loader = torch.utils.data.DataLoader(dataset=lightFieldDataLoader(DATA_FILE, DATA_VAL_JSON),
-                                                  batch_size=1,
-                                                  num_workers=0,
-                                                  pin_memory=True,
-                                                  shuffle=False,
-                                                  drop_last=False)
-    print(len(val_data_loader))
+    # exit(0)
 
+    ################################################ train #########################################################
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     lf_model = model()
+    lf_model.apply(weight_init_normal)
+    epoch_begin = 0
     if torch.cuda.is_available():
         lf_model = torch.nn.DataParallel(lf_model).cuda()
+    lf_model.train()
+    optimizer = torch.optim.Adam(lf_model.parameters(), lr=1e-2, betas=(0.9, 0.999))
+    print("begin training....")
+    for epoch in range(epoch_begin, NUM_EPOCH):
+        for batch_idx, (image_set, gt, df, gazeX, gazeY, sample_idx) in enumerate(train_data_loader):
 
-    #val
-    checkpoint = torch.load(os.path.join(OUTPUT_DIR, "ckpt-epoch-3001.pth"))
-    lf_model.load_state_dict(checkpoint["model_state_dict"])
-    lf_model.eval()
-    print("Eval::")
-    for sample_idx, (image_set, gt, df) in enumerate(val_data_loader):
-        print("sample_idx::")
-        with torch.no_grad():
             #reshape for input
             image_set = image_set.permute(0, 1, 4, 2, 3) # N LF C H W
             image_set = image_set.reshape(image_set.shape[0], -1, image_set.shape[3], image_set.shape[4]) # N, LFxC, H, W
             image_set = var_or_cuda(image_set)
-            # image_set.to(device)
             gt = gt.permute(0, 3, 1, 2)
             gt = var_or_cuda(gt)
-            output = lf_model(image_set, df)
-            print("output:", output.shape, " df:", df)
-            save_image(output[0][0:3].data, os.path.join(OUTPUT_DIR, "1113_interp_l1_%.3f.png" % (df[0].data)))
-            save_image(output[0][3:6].data, os.path.join(OUTPUT_DIR, "1113_interp_l2_%.3f.png" % (df[0].data)))
-            output = GenRetinalFromLayersBatch(output, conf, df, v, u)
-            save_image(output[0][0:3].data, os.path.join(OUTPUT_DIR, "1113_interp_o%.3f.png" % (df[0].data)))
-    exit()
-    # train
-    # print(lf_model)
-    # exit()
-    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-    # lf_model = model()
-    # lf_model.apply(weight_init_normal)
-    # if torch.cuda.is_available():
-    #     lf_model = torch.nn.DataParallel(lf_model).cuda()
-    # lf_model.train()
-    # optimizer = torch.optim.Adam(lf_model.parameters(),lr=5e-2,betas=(0.9,0.999))
-    # for epoch in range(NUM_EPOCH):
-    #     for batch_idx, (image_set, gt, df) in enumerate(train_data_loader):
-    #         #reshape for input
-    #         image_set = image_set.permute(0,1,4,2,3) # N LF C H W
-    #         image_set = image_set.reshape(image_set.shape[0],-1,image_set.shape[3],image_set.shape[4]) # N, LFxC, H, W
-    #         image_set = var_or_cuda(image_set)
-    #         # image_set.to(device)
-    #         gt = gt.permute(0,3,1,2)
-    #         gt = var_or_cuda(gt)
-    #         # print("Epoch:",epoch,",Iter:",batch_idx,",Input shape:",image_set.shape, ",Input gt:",gt.shape)
-    #         optimizer.zero_grad()
-    #         output = lf_model(image_set,df)
-    #         # print("output:",output.shape," df:",df.shape)
-    #         output = GenRetinalFromLayersBatch(output,conf,df,v,u)
-    #         loss = loss_new(output,gt)
-    #         print("Epoch:",epoch,",Iter:",batch_idx,",loss:",loss)
-    #         loss.backward()
-    #         optimizer.step()
-    #         if (epoch%100 == 0):
-    #             for i in range(BATCH_SIZE):
-    #                 save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"cuda_lr_5e-2_mul_dip_newloss_debug_conf_o%d_%d.png"%(epoch,i)))
-    #         if (epoch%1000 == 0):
-    #             save_checkpoints(os.path.join(OUTPUT_DIR, 'ckpt-epoch-%04d.pth' % (epoch + 1)),
-    #                              epoch,lf_model,optimizer)
+            # print("Epoch:",epoch,",Iter:",batch_idx,",Input shape:",image_set.shape, ",Input gt:",gt.shape)
+            optimizer.zero_grad()
+            output = lf_model(image_set, df, gazeX, gazeY)
+            ########################### Use Pregen Phi and Mask ###################
+            output1, mask = GenRetinalFromLayersBatch(output, gen, sample_idx, phi_dict, mask_dict)
+            mask = var_or_cuda(mask)
+            mask.requires_grad = False
+            output_f = output1 * mask
+            gt = gt * mask
+            loss = loss_new(output_f, gt)
+            print("Epoch:", epoch, ",Iter:", batch_idx, ",loss:", loss)
+            ########################### Update ###################
+            loss.backward()
+            optimizer.step()
+            ########################### Save #####################
+            if ((epoch % 50 == 0) or epoch == 5):
+                for i in range(output_f.size()[0]):
+                    save_image(output[i][0:3].data, os.path.join(OUTPUT_DIR, "gaze_fac1_o_%.3f_%.3f_%.3f.png" % (df[i].data, gazeX[i].data, gazeY[i].data)))
+                    save_image(output[i][3:6].data, os.path.join(OUTPUT_DIR, "gaze_fac2_o_%.3f_%.3f_%.3f.png" % (df[i].data, gazeX[i].data, gazeY[i].data)))
+                    save_image(output_f[i][0:3].data, os.path.join(OUTPUT_DIR, "gaze_test1_o_%.3f_%.3f_%.3f.png" % (df[i].data, gazeX[i].data, gazeY[i].data)))
+            if ((epoch % 200 == 0) and epoch != 0 and batch_idx == len(train_data_loader) - 1):
+                save_checkpoints(os.path.join(OUTPUT_DIR, 'gaze-ckpt-epoch-%04d.pth' % (epoch + 1)),
+                                 epoch, lf_model, optimizer)
\ No newline at end of file