Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
Nianchen Deng
deeplightfield
Commits
648dfd2c
Commit
648dfd2c
authored
Nov 10, 2020
by
BobYeah
Browse files
focal depth insert middle
parent
4d8a98da
Changes
1
Hide whitespace changes
Inline
Side-by-side
main.py
View file @
648dfd2c
...
...
@@ -25,9 +25,9 @@ IM_W = 640
N
=
9
# number of input light field stack
M
=
2
# number of display layers
DATA_FILE
=
"/home/yejiannan/Project/
deepl
ight
f
ield/data/try"
DATA_JSON
=
"/home/yejiannan/Project/
deepl
ight
f
ield/data/data.json"
OUTPUT_DIR
=
"/home/yejiannan/Project/
deepl
ight
f
ield/output"
DATA_FILE
=
"/home/yejiannan/Project/
L
ight
F
ield/data/try"
DATA_JSON
=
"/home/yejiannan/Project/
L
ight
F
ield/data/data.json"
OUTPUT_DIR
=
"/home/yejiannan/Project/
L
ight
F
ield/output"
class
lightFieldDataLoader
(
torch
.
utils
.
data
.
dataset
.
Dataset
):
def
__init__
(
self
,
file_dir_path
,
file_json
,
transforms
=
None
):
...
...
@@ -68,16 +68,16 @@ KERNEL_SIZE_RB = 3
KERNEL_SIZE
=
3
class
residual_block
(
torch
.
nn
.
Module
):
def
__init__
(
self
):
def
__init__
(
self
,
delta_channel_dim
):
super
(
residual_block
,
self
).
__init__
()
self
.
layer1
=
torch
.
nn
.
Sequential
(
torch
.
nn
.
Conv2d
(
OUT_CHANNELS_RB
,
OUT_CHANNELS_RB
,
KERNEL_SIZE_RB
,
stride
=
1
,
padding
=
1
),
torch
.
nn
.
BatchNorm2d
(
OUT_CHANNELS_RB
),
torch
.
nn
.
Conv2d
(
OUT_CHANNELS_RB
+
delta_channel_dim
,
OUT_CHANNELS_RB
+
delta_channel_dim
,
KERNEL_SIZE_RB
,
stride
=
1
,
padding
=
1
),
torch
.
nn
.
BatchNorm2d
(
OUT_CHANNELS_RB
+
delta_channel_dim
),
torch
.
nn
.
ELU
()
)
self
.
layer2
=
torch
.
nn
.
Sequential
(
torch
.
nn
.
Conv2d
(
OUT_CHANNELS_RB
,
OUT_CHANNELS_RB
,
KERNEL_SIZE_RB
,
stride
=
1
,
padding
=
1
),
torch
.
nn
.
BatchNorm2d
(
OUT_CHANNELS_RB
,
OUT_CHANNELS_RB
),
torch
.
nn
.
Conv2d
(
OUT_CHANNELS_RB
+
delta_channel_dim
,
OUT_CHANNELS_RB
+
delta_channel_dim
,
KERNEL_SIZE_RB
,
stride
=
1
,
padding
=
1
),
torch
.
nn
.
BatchNorm2d
(
OUT_CHANNELS_RB
+
delta_channel_dim
),
torch
.
nn
.
ELU
()
)
...
...
@@ -127,7 +127,7 @@ class interleave(torch.nn.Module):
LAST_LAYER_CHANNELS
=
6
*
INTERLEAVE_RATE
**
2
FIRSST_LAYER_CHANNELS
=
2
8
*
INTERLEAVE_RATE
**
2
FIRSST_LAYER_CHANNELS
=
2
7
*
INTERLEAVE_RATE
**
2
class
model
(
torch
.
nn
.
Module
):
def
__init__
(
self
):
...
...
@@ -140,12 +140,12 @@ class model(torch.nn.Module):
torch
.
nn
.
ELU
()
)
self
.
residual_block1
=
residual_block
()
self
.
residual_block2
=
residual_block
()
self
.
residual_block3
=
residual_block
()
self
.
residual_block1
=
residual_block
(
0
)
self
.
residual_block2
=
residual_block
(
1
)
self
.
residual_block3
=
residual_block
(
1
)
self
.
output_layer
=
torch
.
nn
.
Sequential
(
torch
.
nn
.
Conv2d
(
OUT_CHANNELS_RB
,
LAST_LAYER_CHANNELS
,
KERNEL_SIZE
,
stride
=
1
,
padding
=
1
),
torch
.
nn
.
Conv2d
(
OUT_CHANNELS_RB
+
1
,
LAST_LAYER_CHANNELS
,
KERNEL_SIZE
,
stride
=
1
,
padding
=
1
),
torch
.
nn
.
BatchNorm2d
(
LAST_LAYER_CHANNELS
),
torch
.
nn
.
Tanh
()
)
...
...
@@ -160,8 +160,16 @@ class model(torch.nn.Module):
input_to_rb
=
self
.
first_layer
(
input_to_net
)
output
=
self
.
residual_block1
(
input_to_rb
)
# print("output1:",output.shape)
output
=
self
.
residual_block2
(
output
)
depth_layer
=
torch
.
ones
((
output
.
shape
[
0
],
1
,
output
.
shape
[
2
],
output
.
shape
[
3
]))
# print(df.shape[0])
for
i
in
range
(
focal_length
.
shape
[
0
]):
depth_layer
[
i
]
=
depth_layer
[
i
]
*
focal_length
[
i
]
# print(depth_layer.shape)
depth_layer
=
var_or_cuda
(
depth_layer
)
output
=
torch
.
cat
((
output
,
depth_layer
),
dim
=
1
)
output
=
self
.
residual_block2
(
output
)
output
=
self
.
residual_block3
(
output
)
# output = output + input_to_net
output
=
self
.
output_layer
(
output
)
...
...
@@ -263,8 +271,6 @@ if __name__ == "__main__":
# save_image(output[0][0:3].data,os.path.join(OUTPUT_DIR,"o%d_%d.png"%(epoch,batch_idx)))
#test end
train_data_loader
=
torch
.
utils
.
data
.
DataLoader
(
dataset
=
lightFieldDataLoader
(
DATA_FILE
,
DATA_JSON
),
batch_size
=
BATCH_SIZE
,
num_workers
=
0
,
...
...
@@ -285,19 +291,14 @@ if __name__ == "__main__":
#reshape for input
image_set
=
image_set
.
permute
(
0
,
1
,
4
,
2
,
3
)
# N LF C H W
image_set
=
image_set
.
reshape
(
image_set
.
shape
[
0
],
-
1
,
image_set
.
shape
[
3
],
image_set
.
shape
[
4
])
# N, LFxC, H, W
depth_layer
=
torch
.
ones
((
image_set
.
shape
[
0
],
1
,
image_set
.
shape
[
2
],
image_set
.
shape
[
3
]))
# print(df.shape[0])
for
i
in
range
(
df
.
shape
[
0
]):
depth_layer
[
i
]
=
depth_layer
[
i
]
*
df
[
i
]
# print(depth_layer.shape)
image_set
=
torch
.
cat
((
image_set
,
depth_layer
),
dim
=
1
)
image_set
=
var_or_cuda
(
image_set
)
# image_set.to(device)
gt
=
gt
.
permute
(
0
,
3
,
1
,
2
)
gt
=
var_or_cuda
(
gt
)
# print("Epoch:",epoch,",Iter:",batch_idx,",Input shape:",image_set.shape, ",Input gt:",gt.shape)
optimizer
.
zero_grad
()
output
=
lf_model
(
image_set
,
0
)
output
=
lf_model
(
image_set
,
df
)
# print("output:",output.shape," df:",df.shape)
output
=
GenRetinalFromLayersBatch
(
output
,
conf
,
df
,
v
,
u
)
loss
=
loss_two_images
(
output
,
gt
)
...
...
@@ -305,5 +306,5 @@ if __name__ == "__main__":
loss
.
backward
()
optimizer
.
step
()
for
i
in
range
(
5
):
save_image
(
output
[
i
][
0
:
3
].
data
,
os
.
path
.
join
(
OUTPUT_DIR
,
"cuda_lr_5e-3_o%d_%d.png"
%
(
epoch
,
i
)))
save_image
(
output
[
i
][
0
:
3
].
data
,
os
.
path
.
join
(
OUTPUT_DIR
,
"cuda_lr_5e-3_
insertmid_
o%d_%d.png"
%
(
epoch
,
i
)))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment