Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
Nianchen Deng
deeplightfield
Commits
3554ba52
Commit
3554ba52
authored
Jan 12, 2021
by
Nianchen Deng
Browse files
sync
parent
f7038e26
Changes
53
Hide whitespace changes
Inline
Side-by-side
run_spherical_view_syn.py
View file @
3554ba52
...
@@ -5,31 +5,41 @@ import argparse
...
@@ -5,31 +5,41 @@ import argparse
import
torch
import
torch
import
torch.optim
import
torch.optim
import
torchvision
import
torchvision
import
numpy
as
np
from
tensorboardX
import
SummaryWriter
from
tensorboardX
import
SummaryWriter
from
torch
import
nn
from
torch
import
nn
sys
.
path
.
append
(
os
.
path
.
abspath
(
sys
.
path
[
0
]
+
'/../'
))
sys
.
path
.
append
(
os
.
path
.
abspath
(
sys
.
path
[
0
]
+
'/../'
))
__package__
=
"deep
lightfield
"
__package__
=
"deep
_view_syn
"
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--device'
,
type
=
int
,
default
=
3
,
parser
.
add_argument
(
'--device'
,
type
=
int
,
default
=
3
,
help
=
'Which CUDA device to use.'
)
help
=
'Which CUDA device to use.'
)
parser
.
add_argument
(
'--config'
,
type
=
str
,
parser
.
add_argument
(
'--config'
,
type
=
str
,
help
=
'Net config files'
)
help
=
'Net config files'
)
parser
.
add_argument
(
'--config-id'
,
type
=
str
,
help
=
'Net config id'
)
parser
.
add_argument
(
'--dataset'
,
type
=
str
,
required
=
True
,
parser
.
add_argument
(
'--dataset'
,
type
=
str
,
required
=
True
,
help
=
'Dataset description file'
)
help
=
'Dataset description file'
)
parser
.
add_argument
(
'--epochs'
,
type
=
int
,
help
=
'Max epochs for train'
)
parser
.
add_argument
(
'--test'
,
type
=
str
,
parser
.
add_argument
(
'--test'
,
type
=
str
,
help
=
'Test net file'
)
help
=
'Test net file'
)
parser
.
add_argument
(
'--test-samples'
,
type
=
int
,
parser
.
add_argument
(
'--test-samples'
,
type
=
int
,
help
=
'Samples used for test'
)
help
=
'Samples used for test'
)
parser
.
add_argument
(
'--res'
,
type
=
str
,
help
=
'Resolution'
)
parser
.
add_argument
(
'--output-gt'
,
action
=
'store_true'
,
parser
.
add_argument
(
'--output-gt'
,
action
=
'store_true'
,
help
=
'Output ground truth images if exist'
)
help
=
'Output ground truth images if exist'
)
parser
.
add_argument
(
'--output-alongside'
,
action
=
'store_true'
,
parser
.
add_argument
(
'--output-alongside'
,
action
=
'store_true'
,
help
=
'Output generated image alongside ground truth image'
)
help
=
'Output generated image alongside ground truth image'
)
parser
.
add_argument
(
'--output-video'
,
action
=
'store_true'
,
parser
.
add_argument
(
'--output-video'
,
action
=
'store_true'
,
help
=
'Output test results as video'
)
help
=
'Output test results as video'
)
parser
.
add_argument
(
'--perf'
,
action
=
'store_true'
,
help
=
'Test performance'
)
opt
=
parser
.
parse_args
()
opt
=
parser
.
parse_args
()
if
opt
.
res
:
opt
.
res
=
tuple
(
int
(
s
)
for
s
in
opt
.
res
.
split
(
'x'
))
# Select device
# Select device
torch
.
cuda
.
set_device
(
opt
.
device
)
torch
.
cuda
.
set_device
(
opt
.
device
)
...
@@ -58,8 +68,8 @@ EVAL_TIME_PERFORMANCE = False
...
@@ -58,8 +68,8 @@ EVAL_TIME_PERFORMANCE = False
# Train
# Train
BATCH_SIZE
=
4096
BATCH_SIZE
=
4096
EPOCH_RANGE
=
range
(
0
,
500
)
EPOCH_RANGE
=
range
(
0
,
opt
.
epochs
if
opt
.
epochs
else
500
)
SAVE_INTERVAL
=
2
0
SAVE_INTERVAL
=
5
0
# Test
# Test
TEST_BATCH_SIZE
=
1
TEST_BATCH_SIZE
=
1
...
@@ -67,13 +77,14 @@ TEST_MAX_RAYS = 32768
...
@@ -67,13 +77,14 @@ TEST_MAX_RAYS = 32768
# Paths
# Paths
data_desc_path
=
opt
.
dataset
data_desc_path
=
opt
.
dataset
data_desc_name
=
os
.
path
.
split
(
data_desc_path
)
[
1
]
data_desc_name
=
os
.
path
.
split
ext
(
os
.
path
.
basename
(
data_desc_path
)
)[
0
]
if
opt
.
test
:
if
opt
.
test
:
test_net_path
=
opt
.
test
test_net_path
=
opt
.
test
test_net_name
=
os
.
path
.
splitext
(
os
.
path
.
basename
(
test_net_path
))[
0
]
test_net_name
=
os
.
path
.
splitext
(
os
.
path
.
basename
(
test_net_path
))[
0
]
run_dir
=
os
.
path
.
dirname
(
test_net_path
)
+
'/'
run_dir
=
os
.
path
.
dirname
(
test_net_path
)
+
'/'
run_id
=
os
.
path
.
basename
(
run_dir
[:
-
1
])
run_id
=
os
.
path
.
basename
(
run_dir
[:
-
1
])
output_dir
=
run_dir
+
'output/%s/%s/'
%
(
test_net_name
,
data_desc_name
)
output_dir
=
run_dir
+
'output/%s/%s%s/'
%
(
test_net_name
,
data_desc_name
,
'_%dx%d'
%
(
opt
.
res
[
0
],
opt
.
res
[
1
])
if
opt
.
res
else
''
)
config
.
from_id
(
run_id
)
config
.
from_id
(
run_id
)
train_mode
=
False
train_mode
=
False
if
opt
.
test_samples
:
if
opt
.
test_samples
:
...
@@ -83,6 +94,8 @@ if opt.test:
...
@@ -83,6 +94,8 @@ if opt.test:
else
:
else
:
if
opt
.
config
:
if
opt
.
config
:
config
.
load
(
opt
.
config
)
config
.
load
(
opt
.
config
)
if
opt
.
config_id
:
config
.
from_id
(
opt
.
config_id
)
data_dir
=
os
.
path
.
dirname
(
data_desc_path
)
+
'/'
data_dir
=
os
.
path
.
dirname
(
data_desc_path
)
+
'/'
run_id
=
config
.
to_id
()
run_id
=
config
.
to_id
()
run_dir
=
data_dir
+
run_id
+
'/'
run_dir
=
data_dir
+
run_id
+
'/'
...
@@ -105,17 +118,17 @@ NETS = {
...
@@ -105,17 +118,17 @@ NETS = {
fc_params
=
config
.
FC_PARAMS
,
fc_params
=
config
.
FC_PARAMS
,
sampler_params
=
(
config
.
SAMPLE_PARAMS
.
update
(
sampler_params
=
(
config
.
SAMPLE_PARAMS
.
update
(
{
'spherical'
:
True
}),
config
.
SAMPLE_PARAMS
)[
1
],
{
'spherical'
:
True
}),
config
.
SAMPLE_PARAMS
)[
1
],
gray
=
config
.
GRAY
,
color
=
config
.
COLOR
,
encode_to_dim
=
config
.
N_ENCODE_DIM
),
encode_to_dim
=
config
.
N_ENCODE_DIM
),
'nerf'
:
lambda
:
MslNet
(
'nerf'
:
lambda
:
MslNet
(
fc_params
=
config
.
FC_PARAMS
,
fc_params
=
config
.
FC_PARAMS
,
sampler_params
=
(
config
.
SAMPLE_PARAMS
.
update
(
sampler_params
=
(
config
.
SAMPLE_PARAMS
.
update
(
{
'spherical'
:
False
}),
config
.
SAMPLE_PARAMS
)[
1
],
{
'spherical'
:
False
}),
config
.
SAMPLE_PARAMS
)[
1
],
gray
=
config
.
GRAY
,
color
=
config
.
COLOR
,
encode_to_dim
=
config
.
N_ENCODE_DIM
),
encode_to_dim
=
config
.
N_ENCODE_DIM
),
'spher'
:
lambda
:
SpherNet
(
'spher'
:
lambda
:
SpherNet
(
fc_params
=
config
.
FC_PARAMS
,
fc_params
=
config
.
FC_PARAMS
,
gray
=
config
.
GRAY
,
color
=
config
.
COLOR
,
translation
=
not
ROT_ONLY
,
translation
=
not
ROT_ONLY
,
encode_to_dim
=
config
.
N_ENCODE_DIM
)
encode_to_dim
=
config
.
N_ENCODE_DIM
)
}
}
...
@@ -146,7 +159,11 @@ def train_loop(data_loader, optimizer, loss, perf, writer, epoch, iters):
...
@@ -146,7 +159,11 @@ def train_loop(data_loader, optimizer, loss, perf, writer, epoch, iters):
perf
.
Checkpoint
(
"Forward"
)
perf
.
Checkpoint
(
"Forward"
)
optimizer
.
zero_grad
()
optimizer
.
zero_grad
()
loss_mse_value
=
loss_mse
(
out
,
gt
)
if
config
.
COLOR
==
color_mode
.
YCbCr
:
loss_mse_value
=
0.3
*
loss_mse
(
out
[...,
0
:
2
],
gt
[...,
0
:
2
])
+
\
0.7
*
loss_mse
(
out
[...,
2
],
gt
[...,
2
])
else
:
loss_mse_value
=
loss_mse
(
out
,
gt
)
loss_grad_value
=
loss_grad
(
out
,
gt
)
if
patch
else
None
loss_grad_value
=
loss_grad
(
out
,
gt
)
if
patch
else
None
loss_value
=
loss_mse_value
# + 0.5 * loss_grad_value if patch \
loss_value
=
loss_mse_value
# + 0.5 * loss_grad_value if patch \
# else loss_mse_value
# else loss_mse_value
...
@@ -183,7 +200,8 @@ def train_loop(data_loader, optimizer, loss, perf, writer, epoch, iters):
...
@@ -183,7 +200,8 @@ def train_loop(data_loader, optimizer, loss, perf, writer, epoch, iters):
def
train
():
def
train
():
# 1. Initialize data loader
# 1. Initialize data loader
print
(
"Load dataset: "
+
data_desc_path
)
print
(
"Load dataset: "
+
data_desc_path
)
train_dataset
=
SphericalViewSynDataset
(
data_desc_path
,
gray
=
config
.
GRAY
)
train_dataset
=
SphericalViewSynDataset
(
data_desc_path
,
color
=
config
.
COLOR
,
res
=
opt
.
res
)
train_dataset
.
set_patch_size
(
1
)
train_dataset
.
set_patch_size
(
1
)
train_data_loader
=
FastDataLoader
(
train_data_loader
=
FastDataLoader
(
dataset
=
train_dataset
,
dataset
=
train_dataset
,
...
@@ -194,7 +212,7 @@ def train():
...
@@ -194,7 +212,7 @@ def train():
# 2. Initialize components
# 2. Initialize components
optimizer
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
5e-4
)
optimizer
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
5e-4
)
loss
=
0
#
LOSSES[config.LOSS]().to(device.GetDevice())
loss
=
0
#
LOSSES[config.LOSS]().to(device.GetDevice())
if
EPOCH_RANGE
.
start
>
0
:
if
EPOCH_RANGE
.
start
>
0
:
iters
=
netio
.
LoadNet
(
'%smodel-epoch_%d.pth'
%
(
run_dir
,
EPOCH_RANGE
.
start
),
iters
=
netio
.
LoadNet
(
'%smodel-epoch_%d.pth'
%
(
run_dir
,
EPOCH_RANGE
.
start
),
...
@@ -223,15 +241,80 @@ def train():
...
@@ -223,15 +241,80 @@ def train():
netio
.
SaveNet
(
'%smodel-epoch_%d.pth'
%
(
run_dir
,
epoch
+
1
),
model
,
netio
.
SaveNet
(
'%smodel-epoch_%d.pth'
%
(
run_dir
,
epoch
+
1
),
model
,
solver
=
optimizer
,
iters
=
iters
)
solver
=
optimizer
,
iters
=
iters
)
print
(
"Train finished"
)
print
(
"Train finished"
)
netio
.
SaveNet
(
'%smodel-epoch_%d.pth'
%
(
run_dir
,
epoch
+
1
),
model
,
solver
=
optimizer
,
iters
=
iters
)
def
perf
():
with
torch
.
no_grad
():
# 1. Load dataset
print
(
"Load dataset: "
+
data_desc_path
)
test_dataset
=
SphericalViewSynDataset
(
data_desc_path
,
load_images
=
True
,
color
=
config
.
COLOR
,
res
=
opt
.
res
)
test_data_loader
=
FastDataLoader
(
dataset
=
test_dataset
,
batch_size
=
1
,
shuffle
=
False
,
drop_last
=
False
,
pin_memory
=
True
)
# 2. Load trained model
netio
.
LoadNet
(
test_net_path
,
model
)
# 3. Test on dataset
print
(
"Begin perf, batch size is %d"
%
TEST_BATCH_SIZE
)
perf
=
SimplePerf
(
True
,
start
=
True
)
loss
=
nn
.
MSELoss
()
i
=
0
n
=
test_dataset
.
n_views
chns
=
1
if
config
.
COLOR
==
color_mode
.
GRAY
else
3
out_view_images
=
torch
.
empty
(
n
,
chns
,
test_dataset
.
view_res
[
0
],
test_dataset
.
view_res
[
1
],
device
=
device
.
GetDevice
())
perf_times
=
torch
.
empty
(
n
)
perf_errors
=
torch
.
empty
(
n
)
for
view_idxs
,
gt
,
rays_o
,
rays_d
in
test_data_loader
:
perf
.
Checkpoint
(
"%d - Load"
%
i
)
rays_o
=
rays_o
.
to
(
device
.
GetDevice
()).
view
(
-
1
,
3
)
rays_d
=
rays_d
.
to
(
device
.
GetDevice
()).
view
(
-
1
,
3
)
n_rays
=
rays_o
.
size
(
0
)
chunk_size
=
min
(
n_rays
,
TEST_MAX_RAYS
)
out_pixels
=
torch
.
empty
(
n_rays
,
chns
,
device
=
device
.
GetDevice
())
for
offset
in
range
(
0
,
n_rays
,
chunk_size
):
idx
=
slice
(
offset
,
offset
+
chunk_size
)
out_pixels
[
idx
]
=
model
(
rays_o
[
idx
],
rays_d
[
idx
])
if
config
.
COLOR
==
color_mode
.
YCbCr
:
out_pixels
=
util
.
ycbcr2rgb
(
out_pixels
)
out_view_images
[
view_idxs
]
=
out_pixels
.
view
(
TEST_BATCH_SIZE
,
test_dataset
.
view_res
[
0
],
test_dataset
.
view_res
[
1
],
-
1
).
permute
(
0
,
3
,
1
,
2
)
perf_times
[
view_idxs
]
=
perf
.
Checkpoint
(
"%d - Infer"
%
i
)
if
config
.
COLOR
==
color_mode
.
YCbCr
:
gt
=
util
.
ycbcr2rgb
(
gt
)
error
=
loss
(
out_view_images
[
view_idxs
],
gt
).
item
()
print
(
"%d - Error: %f"
%
(
i
,
error
))
perf_errors
[
view_idxs
]
=
error
i
+=
1
# 4. Save results
perf_mean_time
=
torch
.
mean
(
perf_times
).
item
()
perf_mean_error
=
torch
.
mean
(
perf_errors
).
item
()
with
open
(
run_dir
+
'perf_%s_%s_%.1fms_%.2e.txt'
%
(
test_net_name
,
data_desc_name
,
perf_mean_time
,
perf_mean_error
),
'w'
)
as
fp
:
fp
.
write
(
'View, Time, Error
\n
'
)
fp
.
writelines
([
'%d, %f, %f
\n
'
%
(
i
,
perf_times
[
i
].
item
(),
perf_errors
[
i
].
item
())
for
i
in
range
(
n
)])
def
test
():
def
test
():
with
torch
.
no_grad
():
with
torch
.
no_grad
():
# 1. Load
train
dataset
# 1. Load dataset
print
(
"Load dataset: "
+
data_desc_path
)
print
(
"Load dataset: "
+
data_desc_path
)
test_dataset
=
SphericalViewSynDataset
(
data_desc_path
,
test_dataset
=
SphericalViewSynDataset
(
data_desc_path
,
load_images
=
opt
.
output_gt
or
opt
.
output_alongside
,
load_images
=
opt
.
output_gt
or
opt
.
output_alongside
,
gray
=
config
.
GRAY
)
color
=
config
.
COLOR
,
res
=
opt
.
res
)
test_data_loader
=
FastDataLoader
(
test_data_loader
=
FastDataLoader
(
dataset
=
test_dataset
,
dataset
=
test_dataset
,
batch_size
=
1
,
batch_size
=
1
,
...
@@ -242,14 +325,14 @@ def test():
...
@@ -242,14 +325,14 @@ def test():
# 2. Load trained model
# 2. Load trained model
netio
.
LoadNet
(
test_net_path
,
model
)
netio
.
LoadNet
(
test_net_path
,
model
)
# 3. Test on
train
dataset
# 3. Test on dataset
print
(
"Begin test
on train dataset
, batch size is %d"
%
TEST_BATCH_SIZE
)
print
(
"Begin test, batch size is %d"
%
TEST_BATCH_SIZE
)
util
.
CreateDirIfNeed
(
output_dir
)
util
.
CreateDirIfNeed
(
output_dir
)
perf
=
SimplePerf
(
True
,
start
=
True
)
perf
=
SimplePerf
(
True
,
start
=
True
)
i
=
0
i
=
0
n
=
test_dataset
.
n_views
n
=
test_dataset
.
n_views
chns
=
1
if
config
.
GRAY
else
3
chns
=
1
if
config
.
COLOR
==
color_mode
.
GRAY
else
3
out_view_images
=
torch
.
empty
(
n
,
chns
,
test_dataset
.
view_res
[
0
],
out_view_images
=
torch
.
empty
(
n
,
chns
,
test_dataset
.
view_res
[
0
],
test_dataset
.
view_res
[
1
],
test_dataset
.
view_res
[
1
],
device
=
device
.
GetDevice
())
device
=
device
.
GetDevice
())
...
@@ -263,6 +346,8 @@ def test():
...
@@ -263,6 +346,8 @@ def test():
for
offset
in
range
(
0
,
n_rays
,
chunk_size
):
for
offset
in
range
(
0
,
n_rays
,
chunk_size
):
idx
=
slice
(
offset
,
offset
+
chunk_size
)
idx
=
slice
(
offset
,
offset
+
chunk_size
)
out_pixels
[
idx
]
=
model
(
rays_o
[
idx
],
rays_d
[
idx
])
out_pixels
[
idx
]
=
model
(
rays_o
[
idx
],
rays_d
[
idx
])
if
config
.
COLOR
==
color_mode
.
YCbCr
:
out_pixels
=
util
.
ycbcr2rgb
(
out_pixels
)
out_view_images
[
view_idxs
]
=
out_pixels
.
view
(
out_view_images
[
view_idxs
]
=
out_pixels
.
view
(
TEST_BATCH_SIZE
,
test_dataset
.
view_res
[
0
],
TEST_BATCH_SIZE
,
test_dataset
.
view_res
[
0
],
test_dataset
.
view_res
[
1
],
-
1
).
permute
(
0
,
3
,
1
,
2
)
test_dataset
.
view_res
[
1
],
-
1
).
permute
(
0
,
3
,
1
,
2
)
...
@@ -297,5 +382,7 @@ def test():
...
@@ -297,5 +382,7 @@ def test():
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
if
train_mode
:
if
train_mode
:
train
()
train
()
elif
opt
.
perf
:
perf
()
else
:
else
:
test
()
test
()
run_upsampling.py
View file @
3554ba52
...
@@ -4,11 +4,12 @@ import argparse
...
@@ -4,11 +4,12 @@ import argparse
import
os
import
os
import
sys
import
sys
import
torch
import
torch
import
torch.nn.functional
as
nn_f
from
torch.utils.data
import
DataLoader
from
torch.utils.data
import
DataLoader
from
tensorboardX.writer
import
SummaryWriter
from
tensorboardX.writer
import
SummaryWriter
sys
.
path
.
append
(
os
.
path
.
abspath
(
sys
.
path
[
0
]
+
'/../'
))
sys
.
path
.
append
(
os
.
path
.
abspath
(
sys
.
path
[
0
]
+
'/../'
))
__package__
=
"deep
lightfield
"
__package__
=
"deep
_view_syn
"
# ===========================================================
# ===========================================================
# Training settings
# Training settings
...
@@ -31,6 +32,8 @@ parser.add_argument('--dataset', type=str, required=True,
...
@@ -31,6 +32,8 @@ parser.add_argument('--dataset', type=str, required=True,
help
=
'dataset directory'
)
help
=
'dataset directory'
)
parser
.
add_argument
(
'--test'
,
type
=
str
,
help
=
'path of model to test'
)
parser
.
add_argument
(
'--test'
,
type
=
str
,
help
=
'path of model to test'
)
parser
.
add_argument
(
'--testOutPatt'
,
type
=
str
,
help
=
'test output path pattern'
)
parser
.
add_argument
(
'--testOutPatt'
,
type
=
str
,
help
=
'test output path pattern'
)
parser
.
add_argument
(
'--color'
,
type
=
str
,
default
=
'rgb'
,
help
=
'color'
)
# model configuration
# model configuration
parser
.
add_argument
(
'--upscale_factor'
,
'-uf'
,
type
=
int
,
parser
.
add_argument
(
'--upscale_factor'
,
'-uf'
,
type
=
int
,
...
@@ -46,51 +49,57 @@ print("Set CUDA:%d as current device." % torch.cuda.current_device())
...
@@ -46,51 +49,57 @@ print("Set CUDA:%d as current device." % torch.cuda.current_device())
from
.my
import
util
from
.my
import
util
from
.my
import
netio
from
.my
import
netio
from
.my
import
device
from
.my
import
device
from
.SRGAN.solver
import
SRGANTrainer
as
Solver
from
.my
import
color_mode
#from .upsampling.SubPixelCNN.solver import SubPixelTrainer as Solver
from
.upsampling.SRCNN.solver
import
SRCNNTrainer
as
Solver
from
.data.upsampling
import
UpsamplingDataset
from
.data.upsampling
import
UpsamplingDataset
from
.data.loader
import
FastDataLoader
from
.data.loader
import
FastDataLoader
os
.
chdir
(
args
.
dataset
)
os
.
chdir
(
args
.
dataset
)
print
(
'Change working directory to '
+
os
.
getcwd
())
print
(
'Change working directory to '
+
os
.
getcwd
())
run_dir
=
'run/'
run_dir
=
'run/'
args
.
color
=
color_mode
.
from_str
(
args
.
color
)
def
train
():
def
train
():
util
.
CreateDirIfNeed
(
run_dir
)
util
.
CreateDirIfNeed
(
run_dir
)
train_set
=
UpsamplingDataset
(
'.'
,
'out_view_%04d.png'
,
train_set
=
UpsamplingDataset
(
'.'
,
'
input/
out_view_%04d.png'
,
'gt
_
view_%04d.png'
,
gray
=
True
)
'gt
/
view_%04d.png'
,
color
=
args
.
color
)
training_data_loader
=
FastDataLoader
(
dataset
=
train_set
,
training_data_loader
=
FastDataLoader
(
dataset
=
train_set
,
batch_size
=
args
.
batchSize
,
batch_size
=
args
.
batchSize
,
shuffle
=
True
,
shuffle
=
True
,
drop_last
=
False
)
drop_last
=
False
)
trainer
=
Solver
(
args
,
training_data_loader
,
training_data_loader
,
trainer
=
Solver
(
args
,
training_data_loader
,
training_data_loader
,
SummaryWriter
(
run_dir
))
SummaryWriter
(
run_dir
))
trainer
.
build_model
()
trainer
.
build_model
(
3
if
args
.
color
==
color_mode
.
RGB
else
1
)
# ===
for
epoch
in
range
(
1
,
20
+
1
):
trainer
.
pretrain
()
print
(
"{}/{} pretrained"
.
format
(
epoch
,
trainer
.
epoch_pretrain
))
# ===
iters
=
0
iters
=
0
for
epoch
in
range
(
1
,
args
.
nEpochs
+
1
):
for
epoch
in
range
(
1
,
args
.
nEpochs
+
1
):
print
(
"
\n
===> Epoch {} starts:"
.
format
(
epoch
))
print
(
"
\n
===> Epoch {} starts:"
.
format
(
epoch
))
iters
=
trainer
.
train
(
epoch
,
iters
)
iters
=
trainer
.
train
(
epoch
,
iters
,
netio
.
SaveNet
(
run_dir
+
'model-epoch_%d.pth'
%
args
.
nEpochs
,
trainer
.
netG
)
channels
=
slice
(
2
,
3
)
if
args
.
color
==
color_mode
.
YCbCr
else
None
)
netio
.
SaveNet
(
run_dir
+
'model-epoch_%d.pth'
%
args
.
nEpochs
,
trainer
.
model
)
def
test
():
def
test
():
util
.
CreateDirIfNeed
(
os
.
path
.
dirname
(
args
.
testOutPatt
))
util
.
CreateDirIfNeed
(
os
.
path
.
dirname
(
args
.
testOutPatt
))
train_set
=
UpsamplingDataset
(
'.'
,
'out_view_%04d.png'
,
None
,
gray
=
True
)
train_set
=
UpsamplingDataset
(
'.'
,
'input/out_view_%04d.png'
,
None
,
color
=
args
.
color
)
training_data_loader
=
FastDataLoader
(
dataset
=
train_set
,
training_data_loader
=
FastDataLoader
(
dataset
=
train_set
,
batch_size
=
args
.
testBatchSize
,
batch_size
=
args
.
testBatchSize
,
shuffle
=
False
,
shuffle
=
False
,
drop_last
=
False
)
drop_last
=
False
)
trainer
=
Solver
(
args
,
training_data_loader
,
training_data_loader
,
trainer
=
Solver
(
args
,
training_data_loader
,
training_data_loader
,
SummaryWriter
(
run_dir
))
SummaryWriter
(
run_dir
))
trainer
.
build_model
()
trainer
.
build_model
(
3
if
args
.
color
==
color_mode
.
RGB
else
1
)
netio
.
LoadNet
(
args
.
test
,
trainer
.
netG
)
netio
.
LoadNet
(
args
.
test
,
trainer
.
model
)
for
idx
,
input
,
_
in
training_data_loader
:
for
idx
,
input
,
_
in
training_data_loader
:
output
=
trainer
.
netG
(
input
)
if
args
.
color
==
color_mode
.
YCbCr
:
output_y
=
trainer
.
model
(
input
[:,
-
1
:])
output_cbcr
=
nn_f
.
upsample
(
input
[:,
0
:
2
],
scale_factor
=
2
)
output
=
util
.
ycbcr2rgb
(
torch
.
cat
([
output_cbcr
,
output_y
],
-
3
))
else
:
output
=
trainer
.
model
(
input
)
util
.
WriteImageTensor
(
output
,
args
.
testOutPatt
%
idx
)
util
.
WriteImageTensor
(
output
,
args
.
testOutPatt
%
idx
)
...
...
FSRCNN/README.md
→
upsampling/
FSRCNN/README.md
View file @
3554ba52
File moved
FSRCNN/model.py
→
upsampling/
FSRCNN/model.py
View file @
3554ba52
File moved
FSRCNN/solver.py
→
upsampling/
FSRCNN/solver.py
View file @
3554ba52
File moved
SRCNN/README.md
→
upsampling/
SRCNN/README.md
View file @
3554ba52
File moved
SRCNN/model.py
→
upsampling/
SRCNN/model.py
View file @
3554ba52
File moved
SRCNN/solver.py
→
upsampling/
SRCNN/solver.py
View file @
3554ba52
...
@@ -8,7 +8,7 @@ import torch.backends.cudnn as cudnn
...
@@ -8,7 +8,7 @@ import torch.backends.cudnn as cudnn
import
torchvision
import
torchvision
from
.model
import
Net
from
.model
import
Net
from
..
my.progress_bar
import
progress_bar
from
my.progress_bar
import
progress_bar
class
SRCNNTrainer
(
object
):
class
SRCNNTrainer
(
object
):
...
@@ -28,8 +28,8 @@ class SRCNNTrainer(object):
...
@@ -28,8 +28,8 @@ class SRCNNTrainer(object):
self
.
testing_loader
=
testing_loader
self
.
testing_loader
=
testing_loader
self
.
writer
=
writer
self
.
writer
=
writer
def
build_model
(
self
):
def
build_model
(
self
,
num_channels
):
self
.
model
=
Net
(
num_channels
=
1
,
base_filter
=
64
,
upscale_factor
=
self
.
upscale_factor
).
to
(
self
.
device
)
self
.
model
=
Net
(
num_channels
=
num_channels
,
base_filter
=
64
,
upscale_factor
=
self
.
upscale_factor
).
to
(
self
.
device
)
self
.
model
.
weight_init
(
mean
=
0.0
,
std
=
0.01
)
self
.
model
.
weight_init
(
mean
=
0.0
,
std
=
0.01
)
self
.
criterion
=
torch
.
nn
.
MSELoss
()
self
.
criterion
=
torch
.
nn
.
MSELoss
()
torch
.
manual_seed
(
self
.
seed
)
torch
.
manual_seed
(
self
.
seed
)
...
@@ -47,11 +47,15 @@ class SRCNNTrainer(object):
...
@@ -47,11 +47,15 @@ class SRCNNTrainer(object):
torch
.
save
(
self
.
model
,
model_out_path
)
torch
.
save
(
self
.
model
,
model_out_path
)
print
(
"Checkpoint saved to {}"
.
format
(
model_out_path
))
print
(
"Checkpoint saved to {}"
.
format
(
model_out_path
))
def
train
(
self
,
epoch
,
iters
):
def
train
(
self
,
epoch
,
iters
,
channels
=
None
):
self
.
model
.
train
()
self
.
model
.
train
()
train_loss
=
0
train_loss
=
0
for
batch_num
,
(
_
,
data
,
target
)
in
enumerate
(
self
.
training_loader
):
for
batch_num
,
(
_
,
data
,
target
)
in
enumerate
(
self
.
training_loader
):
data
,
target
=
data
.
to
(
self
.
device
),
target
.
to
(
self
.
device
)
if
channels
:
data
=
data
[...,
channels
,
:,
:]
target
=
target
[...,
channels
,
:,
:]
data
=
data
.
to
(
self
.
device
)
target
=
target
.
to
(
self
.
device
)
self
.
optimizer
.
zero_grad
()
self
.
optimizer
.
zero_grad
()
out
=
self
.
model
(
data
)
out
=
self
.
model
(
data
)
loss
=
self
.
criterion
(
out
,
target
)
loss
=
self
.
criterion
(
out
,
target
)
...
...
SRGAN/README.md
→
upsampling/
SRGAN/README.md
View file @
3554ba52
File moved
SRGAN/model.py
→
upsampling/
SRGAN/model.py
View file @
3554ba52
File moved
SRGAN/solver.py
→
upsampling/
SRGAN/solver.py
View file @
3554ba52
File moved
SubPixelCNN/model.py
→
upsampling/
SubPixelCNN/model.py
View file @
3554ba52
File moved
SubPixelCNN/solver.py
→
upsampling/
SubPixelCNN/solver.py
View file @
3554ba52
...
@@ -8,7 +8,7 @@ import torch.backends.cudnn as cudnn
...
@@ -8,7 +8,7 @@ import torch.backends.cudnn as cudnn
import
torchvision
import
torchvision
from
.model
import
Net
from
.model
import
Net
from
..
my.progress_bar
import
progress_bar
from
my.progress_bar
import
progress_bar
class
SubPixelTrainer
(
object
):
class
SubPixelTrainer
(
object
):
...
@@ -28,7 +28,9 @@ class SubPixelTrainer(object):
...
@@ -28,7 +28,9 @@ class SubPixelTrainer(object):
self
.
testing_loader
=
testing_loader
self
.
testing_loader
=
testing_loader
self
.
writer
=
writer
self
.
writer
=
writer
def
build_model
(
self
):
def
build_model
(
self
,
num_channels
):
if
num_channels
!=
1
:
raise
ValueError
(
'num_channels must be 1'
)
self
.
model
=
Net
(
upscale_factor
=
self
.
upscale_factor
).
to
(
self
.
device
)
self
.
model
=
Net
(
upscale_factor
=
self
.
upscale_factor
).
to
(
self
.
device
)
self
.
criterion
=
torch
.
nn
.
MSELoss
()
self
.
criterion
=
torch
.
nn
.
MSELoss
()
torch
.
manual_seed
(
self
.
seed
)
torch
.
manual_seed
(
self
.
seed
)
...
@@ -39,17 +41,21 @@ class SubPixelTrainer(object):
...
@@ -39,17 +41,21 @@ class SubPixelTrainer(object):
self
.
criterion
.
cuda
()
self
.
criterion
.
cuda
()
self
.
optimizer
=
torch
.
optim
.
Adam
(
self
.
model
.
parameters
(),
lr
=
self
.
lr
)
self
.
optimizer
=
torch
.
optim
.
Adam
(
self
.
model
.
parameters
(),
lr
=
self
.
lr
)
self
.
scheduler
=
torch
.
optim
.
lr_scheduler
.
MultiStepLR
(
self
.
optimizer
,
milestones
=
[
50
,
75
,
100
],
gamma
=
0.5
)
# lr decay
self
.
scheduler
=
torch
.
optim
.
lr_scheduler
.
MultiStepLR
(
self
.
optimizer
,
milestones
=
[
50
,
75
,
100
],
gamma
=
0.5
)
# lr decay
def
save
(
self
):
def
save
(
self
):
model_out_path
=
"model_path.pth"
model_out_path
=
"model_path.pth"
torch
.
save
(
self
.
model
,
model_out_path
)
torch
.
save
(
self
.
model
,
model_out_path
)
print
(
"Checkpoint saved to {}"
.
format
(
model_out_path
))
print
(
"Checkpoint saved to {}"
.
format
(
model_out_path
))
def
train
(
self
,
epoch
,
iters
):
def
train
(
self
,
epoch
,
iters
,
channels
=
None
):
self
.
model
.
train
()
self
.
model
.
train
()
train_loss
=
0
train_loss
=
0
for
batch_num
,
(
_
,
data
,
target
)
in
enumerate
(
self
.
training_loader
):
for
batch_num
,
(
_
,
data
,
target
)
in
enumerate
(
self
.
training_loader
):
if
channels
:
data
=
data
[...,
channels
,
:,
:]
target
=
target
[...,
channels
,
:,
:]
data
,
target
=
data
.
to
(
self
.
device
),
target
.
to
(
self
.
device
)
data
,
target
=
data
.
to
(
self
.
device
),
target
.
to
(
self
.
device
)
self
.
optimizer
.
zero_grad
()
self
.
optimizer
.
zero_grad
()
out
=
self
.
model
(
data
)
out
=
self
.
model
(
data
)
...
@@ -58,7 +64,8 @@ class SubPixelTrainer(object):
...
@@ -58,7 +64,8 @@ class SubPixelTrainer(object):
loss
.
backward
()
loss
.
backward
()
self
.
optimizer
.
step
()
self
.
optimizer
.
step
()
sys
.
stdout
.
write
(
'Epoch %d: '
%
epoch
)
sys
.
stdout
.
write
(
'Epoch %d: '
%
epoch
)
progress_bar
(
batch_num
,
len
(
self
.
training_loader
),
'Loss: %.4f'
%
(
train_loss
/
(
batch_num
+
1
)))
progress_bar
(
batch_num
,
len
(
self
.
training_loader
),
'Loss: %.4f'
%
(
train_loss
/
(
batch_num
+
1
)))
if
self
.
writer
:
if
self
.
writer
:
self
.
writer
.
add_scalar
(
"loss"
,
loss
,
iters
)
self
.
writer
.
add_scalar
(
"loss"
,
loss
,
iters
)
if
iters
%
100
==
0
:
if
iters
%
100
==
0
:
...
@@ -66,11 +73,13 @@ class SubPixelTrainer(object):
...
@@ -66,11 +73,13 @@ class SubPixelTrainer(object):
.
flatten
(
0
,
1
).
detach
()
.
flatten
(
0
,
1
).
detach
()
self
.
writer
.
add_image
(
self
.
writer
.
add_image
(
"Output_vs_gt"
,
"Output_vs_gt"
,
torchvision
.
utils
.
make_grid
(
output_vs_gt
,
nrow
=
2
).
cpu
().
numpy
(),
torchvision
.
utils
.
make_grid
(
output_vs_gt
,
nrow
=
2
).
cpu
().
numpy
(),
iters
)
iters
)
iters
+=
1
iters
+=
1
print
(
" Average Loss: {:.4f}"
.
format
(
train_loss
/
len
(
self
.
training_loader
)))
print
(
" Average Loss: {:.4f}"
.
format
(
train_loss
/
len
(
self
.
training_loader
)))
return
iters
return
iters
def
test
(
self
):
def
test
(
self
):
...
@@ -84,9 +93,11 @@ class SubPixelTrainer(object):
...
@@ -84,9 +93,11 @@ class SubPixelTrainer(object):
mse
=
self
.
criterion
(
prediction
,
target
)
mse
=
self
.
criterion
(
prediction
,
target
)
psnr
=
10
*
log10
(
1
/
mse
.
item
())
psnr
=
10
*
log10
(
1
/
mse
.
item
())
avg_psnr
+=
psnr
avg_psnr
+=
psnr
progress_bar
(
batch_num
,
len
(
self
.
testing_loader
),
'PSNR: %.4f'
%
(
avg_psnr
/
(
batch_num
+
1
)))
progress_bar
(
batch_num
,
len
(
self
.
testing_loader
),
'PSNR: %.4f'
%
(
avg_psnr
/
(
batch_num
+
1
)))
print
(
" Average PSNR: {:.4f} dB"
.
format
(
avg_psnr
/
len
(
self
.
testing_loader
)))
print
(
" Average PSNR: {:.4f} dB"
.
format
(
avg_psnr
/
len
(
self
.
testing_loader
)))
def
run
(
self
):
def
run
(
self
):
self
.
build_model
()
self
.
build_model
()
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment