Commit 055dc0bb ("First Stage")
Authored 4 years ago by BobYeah
Parent: 648dfd2c

Showing 4 changed files with 337 additions and 25 deletions:

.gitignore    +112  −0
main.py       +115  −25
perc_loss.py   +38  −0
ssim.py        +72  −0
.gitignore  (new file, +112 −0)
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
# log
*.txt
*.out
*.ipynb
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# dotenv
.env
# virtualenv
.venv
venv/
ENV/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
# macOS
.DS_Store
# Output
output/
\ No newline at end of file
main.py  (+115 −25)

@@ -13,6 +13,8 @@ from torch.autograd import Variable
 import cv2
 from gen_image import *
 import json
+from ssim import *
+from perc_loss import *
 # param
 BATCH_SIZE = 5
 NUM_EPOCH = 5000
@@ -27,6 +29,7 @@ M = 2 # number of display layers
 DATA_FILE = "/home/yejiannan/Project/LightField/data/try"
 DATA_JSON = "/home/yejiannan/Project/LightField/data/data.json"
+DATA_VAL_JSON = "/home/yejiannan/Project/LightField/data/data_val.json"
 OUTPUT_DIR = "/home/yejiannan/Project/LightField/output"
 
 
 class lightFieldDataLoader(torch.utils.data.dataset.Dataset):
@@ -34,7 +37,7 @@ class lightFieldDataLoader(torch.utils.data.dataset.Dataset):
         self.file_dir_path = file_dir_path
         self.transforms = transforms
         # self.datum_list = glob.glob(os.path.join(file_dir_path,"*"))
-        with open(DATA_JSON, encoding='utf-8') as file:
+        with open(file_json, encoding='utf-8') as file:
             self.dastset_desc = json.loads(file.read())
 
     def __len__(self):
@@ -147,7 +150,7 @@ class model(torch.nn.Module):
         self.output_layer = torch.nn.Sequential(
             torch.nn.Conv2d(OUT_CHANNELS_RB + 1, LAST_LAYER_CHANNELS, KERNEL_SIZE, stride=1, padding=1),
             torch.nn.BatchNorm2d(LAST_LAYER_CHANNELS),
-            torch.nn.Tanh()
+            torch.nn.Sigmoid()
         )
 
         self.deinterleave = deinterleave(INTERLEAVE_RATE)
@@ -164,7 +167,7 @@ class model(torch.nn.Module):
         depth_layer = torch.ones((output.shape[0], 1, output.shape[2], output.shape[3]))
         # print(df.shape[0])
         for i in range(focal_length.shape[0]):
-            depth_layer[i] = depth_layer[i] * focal_length[i]
+            depth_layer[i] = 1. / focal_length[i]
         # print(depth_layer.shape)
         depth_layer = var_or_cuda(depth_layer)
         output = torch.cat((output, depth_layer), dim=1)
@@ -182,8 +185,8 @@ class Conf(object):
         self.retinal_res = torch.tensor([480, 640])
         self.layer_res = torch.tensor([480, 640])
         self.n_layers = 2
-        self.d_layer = [1.75, 3.5] # layers' distance
-        self.h_layer = [1., 2.] # layers' height
+        self.d_layer = [1., 3.] # layers' distance
+        self.h_layer = [1. * 480. / 640., 3. * 480. / 640.] # layers' height
 
 #### Image Gen
 conf = Conf()
@@ -223,14 +226,14 @@ def GenRetinalFromLayersBatch(layers, conf, df, v, u):
             torch.clamp_(pi[:, :, :, 1], 0, conf.layer_res[1] - 1)
             Phi[bs, :, :, i, :, :] = pi
     # print("Phi slice:",Phi[0, :, :, 0, 0, 0].shape)
-    retinal = torch.zeros(BS, 3, H_r, W_r)
+    retinal = torch.ones(BS, 3, H_r, W_r)
     retinal = var_or_cuda(retinal)
     for bs in range(BS):
         for j in range(0, M):
-            retinal_view = torch.zeros(3, H_r, W_r)
+            retinal_view = torch.ones(3, H_r, W_r)
             retinal_view = var_or_cuda(retinal_view)
             for i in range(0, N):
-                retinal_view.add_(layers[bs, (i * 3):(i * 3 + 3), Phi[bs, :, :, i, j, 0], Phi[bs, :, :, i, j, 1]])
+                retinal_view.mul_(layers[bs, (i * 3):(i * 3 + 3), Phi[bs, :, :, i, j, 0], Phi[bs, :, :, i, j, 1]])
             retinal[bs,:,:,:].add_(retinal_view)
         retinal[bs,:,:,:].div_(M)
     return retinal
@@ -263,6 +266,42 @@ def var_or_cuda(x):
         x = x.cuda(non_blocking=True)
     return x
 
+def calImageGradients(images):
+    # x is a 4-D tensor
+    dx = images[:, :, 1:, :] - images[:, :, :-1, :]
+    dy = images[:, 1:, :, :] - images[:, :-1, :, :]
+    return dx, dy
+
+perc_loss = VGGPerceptualLoss()
+perc_loss = perc_loss.to("cuda")
+
+def loss_new(generated, gt):
+    mse_loss = torch.nn.MSELoss()
+    rmse_intensity = mse_loss(generated, gt)
+    RENORM_SCALE = torch.tensor(0.9)
+    RENORM_SCALE = var_or_cuda(RENORM_SCALE)
+    psnr_intensity = torch.log10(rmse_intensity)
+    ssim_intensity = ssim(generated, gt)
+    labels_dx, labels_dy = calImageGradients(gt)
+    preds_dx, preds_dy = calImageGradients(generated)
+    rmse_grad_x, rmse_grad_y = mse_loss(labels_dx, preds_dx), mse_loss(labels_dy, preds_dy)
+    psnr_grad_x, psnr_grad_y = torch.log10(rmse_grad_x), torch.log10(rmse_grad_y)
+    p_loss = perc_loss(generated, gt)
+    # print("-psnr:",-psnr_intensity,",0.5*(psnr_grad_x + psnr_grad_y):",0.5*(psnr_grad_x + psnr_grad_y),",perc_loss:",p_loss)
+    total_loss = 10 + psnr_intensity + 0.5 * (psnr_grad_x + psnr_grad_y) + p_loss
+    return total_loss
+
+def save_checkpoints(file_path, epoch_idx, model, model_solver):
+    print('[INFO] Saving checkpoint to %s ...' % (file_path))
+    checkpoint = {
+        'epoch_idx': epoch_idx,
+        'model_state_dict': model.state_dict(),
+        'model_solver_state_dict': model_solver.state_dict()
+    }
+    torch.save(checkpoint, file_path)
+
+mode = "val"
+
 if __name__ == "__main__":
     #test
     # train_dataset = lightFieldDataLoader(DATA_FILE,DATA_JSON)
@@ -270,41 +309,92 @@ if __name__ == "__main__":
     # cv2.imwrite("test_crop0.png",train_dataset[0][1]*255.)
     # save_image(output[0][0:3].data,os.path.join(OUTPUT_DIR,"o%d_%d.png"%(epoch,batch_idx)))
     #test end
-    #train
     train_data_loader = torch.utils.data.DataLoader(dataset=lightFieldDataLoader(DATA_FILE, DATA_JSON),
                                                     batch_size=BATCH_SIZE,
                                                     num_workers=0,
                                                     pin_memory=True,
-                                                    shuffle=False,
+                                                    shuffle=True,
                                                     drop_last=False)
     print(len(train_data_loader))
+    val_data_loader = torch.utils.data.DataLoader(dataset=lightFieldDataLoader(DATA_FILE, DATA_VAL_JSON),
+                                                  batch_size=1,
+                                                  num_workers=0,
+                                                  pin_memory=True,
+                                                  shuffle=False,
+                                                  drop_last=False)
+    print(len(val_data_loader))
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     lf_model = model()
-    lf_model.apply(weight_init_normal)
     if torch.cuda.is_available():
         lf_model = torch.nn.DataParallel(lf_model).cuda()
-    lf_model.train()
-    optimizer = torch.optim.Adam(lf_model.parameters(), lr=5e-3, betas=(0.9, 0.999))
-    for epoch in range(NUM_EPOCH):
-        for batch_idx, (image_set, gt, df) in enumerate(train_data_loader):
+    #val
+    checkpoint = torch.load(os.path.join(OUTPUT_DIR, "ckpt-epoch-3001.pth"))
+    lf_model.load_state_dict(checkpoint["model_state_dict"])
+    lf_model.eval()
+    print("Eval::")
+    for sample_idx, (image_set, gt, df) in enumerate(val_data_loader):
+        print("sample_idx::")
+        with torch.no_grad():
             #reshape for input
             image_set = image_set.permute(0, 1, 4, 2, 3) # N LF C H W
             image_set = image_set.reshape(image_set.shape[0], -1, image_set.shape[3], image_set.shape[4]) # N, LFxC, H, W
             image_set = var_or_cuda(image_set)
             # image_set.to(device)
             gt = gt.permute(0, 3, 1, 2)
             gt = var_or_cuda(gt)
             # print("Epoch:",epoch,",Iter:",batch_idx,",Input shape:",image_set.shape, ",Input gt:",gt.shape)
-            optimizer.zero_grad()
             output = lf_model(image_set, df)
-            # print("output:",output.shape," df:",df.shape)
+            print("output:", output.shape, " df:", df)
+            save_image(output[0][0:3].data, os.path.join(OUTPUT_DIR, "1113_interp_l1_%.3f.png" % (df[0].data)))
+            save_image(output[0][3:6].data, os.path.join(OUTPUT_DIR, "1113_interp_l2_%.3f.png" % (df[0].data)))
             output = GenRetinalFromLayersBatch(output, conf, df, v, u)
-            loss = loss_two_images(output, gt)
-            print("Epoch:", epoch, ",Iter:", batch_idx, ",loss:", loss)
-            loss.backward()
-            optimizer.step()
-            for i in range(5):
-                save_image(output[i][0:3].data, os.path.join(OUTPUT_DIR, "cuda_lr_5e-3_insertmid_o%d_%d.png" % (epoch, i)))
+            save_image(output[0][0:3].data, os.path.join(OUTPUT_DIR, "1113_interp_o%.3f.png" % (df[0].data)))
+            exit()
+
+    # train
+    # print(lf_model)
+    # exit()
+    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    # lf_model = model()
+    # lf_model.apply(weight_init_normal)
+    # if torch.cuda.is_available():
+    #     lf_model = torch.nn.DataParallel(lf_model).cuda()
+    # lf_model.train()
+    # optimizer = torch.optim.Adam(lf_model.parameters(),lr=5e-2,betas=(0.9,0.999))
+    # for epoch in range(NUM_EPOCH):
+    #     for batch_idx, (image_set, gt, df) in enumerate(train_data_loader):
+    #         #reshape for input
+    #         image_set = image_set.permute(0,1,4,2,3) # N LF C H W
+    #         image_set = image_set.reshape(image_set.shape[0],-1,image_set.shape[3],image_set.shape[4]) # N, LFxC, H, W
+    #         image_set = var_or_cuda(image_set)
+    #         # image_set.to(device)
+    #         gt = gt.permute(0,3,1,2)
+    #         gt = var_or_cuda(gt)
+    #         # print("Epoch:",epoch,",Iter:",batch_idx,",Input shape:",image_set.shape, ",Input gt:",gt.shape)
+    #         optimizer.zero_grad()
+    #         output = lf_model(image_set,df)
+    #         # print("output:",output.shape," df:",df.shape)
+    #         output = GenRetinalFromLayersBatch(output,conf,df,v,u)
+    #         loss = loss_new(output,gt)
+    #         print("Epoch:",epoch,",Iter:",batch_idx,",loss:",loss)
+    #         loss.backward()
+    #         optimizer.step()
+    #         if (epoch%100 == 0):
+    #             for i in range(BATCH_SIZE):
+    #                 save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"cuda_lr_5e-2_mul_dip_newloss_debug_conf_o%d_%d.png"%(epoch,i)))
+    #         if (epoch%1000 == 0):
+    #             save_checkpoints(os.path.join(OUTPUT_DIR, 'ckpt-epoch-%04d.pth' % (epoch + 1)),
+    #                 epoch,lf_model,optimizer)
\ No newline at end of file
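Note: as a sanity check for the new loss terms, the following is a minimal standalone sketch (not part of the commit) that exercises calImageGradients and the log10-MSE terms from loss_new on random stand-in tensors. The SSIM and VGG perceptual terms are omitted so the snippet runs without the other two files, and all shapes are illustrative assumptions.

import torch

def calImageGradients(images):
    # Mirrors the helper added in main.py: forward differences along dims 2 and 1.
    dx = images[:, :, 1:, :] - images[:, :, :-1, :]
    dy = images[:, 1:, :, :] - images[:, :-1, :, :]
    return dx, dy

mse_loss = torch.nn.MSELoss()
generated = torch.rand(2, 3, 64, 64)  # stand-in for the rendered retinal image
gt = torch.rand(2, 3, 64, 64)         # stand-in for the ground truth

labels_dx, labels_dy = calImageGradients(gt)
preds_dx, preds_dy = calImageGradients(generated)
psnr_intensity = torch.log10(mse_loss(generated, gt))
psnr_grad_x = torch.log10(mse_loss(labels_dx, preds_dx))
psnr_grad_y = torch.log10(mse_loss(labels_dy, preds_dy))
# Same combination as loss_new, minus the SSIM and VGG perceptual terms.
total_loss = 10 + psnr_intensity + 0.5 * (psnr_grad_x + psnr_grad_y)
print(float(total_loss))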
perc_loss.py  (new file, +38 −0)

import torch
import torchvision

class VGGPerceptualLoss(torch.nn.Module):
    def __init__(self, resize=True):
        super(VGGPerceptualLoss, self).__init__()
        blocks = []
        blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval())
        for bl in blocks:
            for p in bl:
                p.requires_grad = False
        self.blocks = torch.nn.ModuleList(blocks)
        self.transform = torch.nn.functional.interpolate
        self.mean = torch.nn.Parameter(torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.std = torch.nn.Parameter(torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
        self.resize = resize

    def forward(self, input, target):
        if input.shape[1] != 3:
            input = input.repeat(1, 3, 1, 1)
            target = target.repeat(1, 3, 1, 1)
        input = (input - self.mean) / self.std
        target = (target - self.mean) / self.std
        if self.resize:
            input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
            target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
        loss = 0.0
        x = input
        y = target
        for block in self.blocks:
            x = block(x)
            y = block(y)
            loss += torch.nn.functional.l1_loss(x, y)
        return loss
\ No newline at end of file
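Usage note: a minimal sketch (not part of the commit) of how VGGPerceptualLoss can be called. It downloads the pretrained VGG-16 weights on first use, and the input batches here are random NCHW stand-ins in [0, 1].

import torch
from perc_loss import VGGPerceptualLoss

loss_fn = VGGPerceptualLoss(resize=True)
if torch.cuda.is_available():
    loss_fn = loss_fn.to("cuda")  # same placement as the perc_loss.to("cuda") call in main.py

a = torch.rand(1, 3, 480, 640)  # stand-in image batches
b = torch.rand(1, 3, 480, 640)
if torch.cuda.is_available():
    a, b = a.cuda(), b.cuda()
print(loss_fn(a, b))  # L1 distance accumulated over the four VGG-16 feature blocks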
ssim.py  (new file, +72 −0)

import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from math import exp

def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss / gauss.sum()

def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window

def _ssim(img1, img2, window, window_size, channel, size_average=True):
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)

class SSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            window = create_window(self.window_size, channel)

            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)

            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)

def ssim(img1, img2, window_size=11, size_average=True):
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)

    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)

    return _ssim(img1, img2, window, window_size, channel, size_average)
\ No newline at end of file
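Usage note: a minimal sketch (not part of the commit) of the two entry points above; identical images should score approximately 1.0, and the shapes are illustrative assumptions.

import torch
from ssim import ssim, SSIM

img1 = torch.rand(1, 3, 64, 64)
img2 = img1.clone()

print(ssim(img1, img2))                  # functional form: builds a window per call
print(SSIM(window_size=11)(img1, img2))  # module form: caches the window across calls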