Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
Nianchen Deng
deeplightfield
Commits
3554ba52
Commit
3554ba52
authored
Jan 12, 2021
by
Nianchen Deng
Browse files
sync
parent
f7038e26
Changes
53
Show whitespace changes
Inline
Side-by-side
run_spherical_view_syn.py
View file @
3554ba52
...
...
@@ -5,31 +5,41 @@ import argparse
import
torch
import
torch.optim
import
torchvision
import
numpy
as
np
from
tensorboardX
import
SummaryWriter
from
torch
import
nn
sys
.
path
.
append
(
os
.
path
.
abspath
(
sys
.
path
[
0
]
+
'/../'
))
__package__
=
"deep
lightfield
"
__package__
=
"deep
_view_syn
"
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--device'
,
type
=
int
,
default
=
3
,
help
=
'Which CUDA device to use.'
)
parser
.
add_argument
(
'--config'
,
type
=
str
,
help
=
'Net config files'
)
parser
.
add_argument
(
'--config-id'
,
type
=
str
,
help
=
'Net config id'
)
parser
.
add_argument
(
'--dataset'
,
type
=
str
,
required
=
True
,
help
=
'Dataset description file'
)
parser
.
add_argument
(
'--epochs'
,
type
=
int
,
help
=
'Max epochs for train'
)
parser
.
add_argument
(
'--test'
,
type
=
str
,
help
=
'Test net file'
)
parser
.
add_argument
(
'--test-samples'
,
type
=
int
,
help
=
'Samples used for test'
)
parser
.
add_argument
(
'--res'
,
type
=
str
,
help
=
'Resolution'
)
parser
.
add_argument
(
'--output-gt'
,
action
=
'store_true'
,
help
=
'Output ground truth images if exist'
)
parser
.
add_argument
(
'--output-alongside'
,
action
=
'store_true'
,
help
=
'Output generated image alongside ground truth image'
)
parser
.
add_argument
(
'--output-video'
,
action
=
'store_true'
,
help
=
'Output test results as video'
)
parser
.
add_argument
(
'--perf'
,
action
=
'store_true'
,
help
=
'Test performance'
)
opt
=
parser
.
parse_args
()
if
opt
.
res
:
opt
.
res
=
tuple
(
int
(
s
)
for
s
in
opt
.
res
.
split
(
'x'
))
# Select device
torch
.
cuda
.
set_device
(
opt
.
device
)
...
...
@@ -58,8 +68,8 @@ EVAL_TIME_PERFORMANCE = False
# Train
BATCH_SIZE
=
4096
EPOCH_RANGE
=
range
(
0
,
500
)
SAVE_INTERVAL
=
2
0
EPOCH_RANGE
=
range
(
0
,
opt
.
epochs
if
opt
.
epochs
else
500
)
SAVE_INTERVAL
=
5
0
# Test
TEST_BATCH_SIZE
=
1
...
...
@@ -67,13 +77,14 @@ TEST_MAX_RAYS = 32768
# Paths
data_desc_path
=
opt
.
dataset
data_desc_name
=
os
.
path
.
split
(
data_desc_path
)
[
1
]
data_desc_name
=
os
.
path
.
split
ext
(
os
.
path
.
basename
(
data_desc_path
)
)[
0
]
if
opt
.
test
:
test_net_path
=
opt
.
test
test_net_name
=
os
.
path
.
splitext
(
os
.
path
.
basename
(
test_net_path
))[
0
]
run_dir
=
os
.
path
.
dirname
(
test_net_path
)
+
'/'
run_id
=
os
.
path
.
basename
(
run_dir
[:
-
1
])
output_dir
=
run_dir
+
'output/%s/%s/'
%
(
test_net_name
,
data_desc_name
)
output_dir
=
run_dir
+
'output/%s/%s%s/'
%
(
test_net_name
,
data_desc_name
,
'_%dx%d'
%
(
opt
.
res
[
0
],
opt
.
res
[
1
])
if
opt
.
res
else
''
)
config
.
from_id
(
run_id
)
train_mode
=
False
if
opt
.
test_samples
:
...
...
@@ -83,6 +94,8 @@ if opt.test:
else
:
if
opt
.
config
:
config
.
load
(
opt
.
config
)
if
opt
.
config_id
:
config
.
from_id
(
opt
.
config_id
)
data_dir
=
os
.
path
.
dirname
(
data_desc_path
)
+
'/'
run_id
=
config
.
to_id
()
run_dir
=
data_dir
+
run_id
+
'/'
...
...
@@ -105,17 +118,17 @@ NETS = {
fc_params
=
config
.
FC_PARAMS
,
sampler_params
=
(
config
.
SAMPLE_PARAMS
.
update
(
{
'spherical'
:
True
}),
config
.
SAMPLE_PARAMS
)[
1
],
gray
=
config
.
GRAY
,
color
=
config
.
COLOR
,
encode_to_dim
=
config
.
N_ENCODE_DIM
),
'nerf'
:
lambda
:
MslNet
(
fc_params
=
config
.
FC_PARAMS
,
sampler_params
=
(
config
.
SAMPLE_PARAMS
.
update
(
{
'spherical'
:
False
}),
config
.
SAMPLE_PARAMS
)[
1
],
gray
=
config
.
GRAY
,
color
=
config
.
COLOR
,
encode_to_dim
=
config
.
N_ENCODE_DIM
),
'spher'
:
lambda
:
SpherNet
(
fc_params
=
config
.
FC_PARAMS
,
gray
=
config
.
GRAY
,
color
=
config
.
COLOR
,
translation
=
not
ROT_ONLY
,
encode_to_dim
=
config
.
N_ENCODE_DIM
)
}
...
...
@@ -146,6 +159,10 @@ def train_loop(data_loader, optimizer, loss, perf, writer, epoch, iters):
perf
.
Checkpoint
(
"Forward"
)
optimizer
.
zero_grad
()
if
config
.
COLOR
==
color_mode
.
YCbCr
:
loss_mse_value
=
0.3
*
loss_mse
(
out
[...,
0
:
2
],
gt
[...,
0
:
2
])
+
\
0.7
*
loss_mse
(
out
[...,
2
],
gt
[...,
2
])
else
:
loss_mse_value
=
loss_mse
(
out
,
gt
)
loss_grad_value
=
loss_grad
(
out
,
gt
)
if
patch
else
None
loss_value
=
loss_mse_value
# + 0.5 * loss_grad_value if patch \
...
...
@@ -183,7 +200,8 @@ def train_loop(data_loader, optimizer, loss, perf, writer, epoch, iters):
def
train
():
# 1. Initialize data loader
print
(
"Load dataset: "
+
data_desc_path
)
train_dataset
=
SphericalViewSynDataset
(
data_desc_path
,
gray
=
config
.
GRAY
)
train_dataset
=
SphericalViewSynDataset
(
data_desc_path
,
color
=
config
.
COLOR
,
res
=
opt
.
res
)
train_dataset
.
set_patch_size
(
1
)
train_data_loader
=
FastDataLoader
(
dataset
=
train_dataset
,
...
...
@@ -194,7 +212,7 @@ def train():
# 2. Initialize components
optimizer
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
5e-4
)
loss
=
0
#
LOSSES[config.LOSS]().to(device.GetDevice())
loss
=
0
#
LOSSES[config.LOSS]().to(device.GetDevice())
if
EPOCH_RANGE
.
start
>
0
:
iters
=
netio
.
LoadNet
(
'%smodel-epoch_%d.pth'
%
(
run_dir
,
EPOCH_RANGE
.
start
),
...
...
@@ -223,15 +241,80 @@ def train():
netio
.
SaveNet
(
'%smodel-epoch_%d.pth'
%
(
run_dir
,
epoch
+
1
),
model
,
solver
=
optimizer
,
iters
=
iters
)
print
(
"Train finished"
)
netio
.
SaveNet
(
'%smodel-epoch_%d.pth'
%
(
run_dir
,
epoch
+
1
),
model
,
solver
=
optimizer
,
iters
=
iters
)
def
perf
():
with
torch
.
no_grad
():
# 1. Load dataset
print
(
"Load dataset: "
+
data_desc_path
)
test_dataset
=
SphericalViewSynDataset
(
data_desc_path
,
load_images
=
True
,
color
=
config
.
COLOR
,
res
=
opt
.
res
)
test_data_loader
=
FastDataLoader
(
dataset
=
test_dataset
,
batch_size
=
1
,
shuffle
=
False
,
drop_last
=
False
,
pin_memory
=
True
)
# 2. Load trained model
netio
.
LoadNet
(
test_net_path
,
model
)
# 3. Test on dataset
print
(
"Begin perf, batch size is %d"
%
TEST_BATCH_SIZE
)
perf
=
SimplePerf
(
True
,
start
=
True
)
loss
=
nn
.
MSELoss
()
i
=
0
n
=
test_dataset
.
n_views
chns
=
1
if
config
.
COLOR
==
color_mode
.
GRAY
else
3
out_view_images
=
torch
.
empty
(
n
,
chns
,
test_dataset
.
view_res
[
0
],
test_dataset
.
view_res
[
1
],
device
=
device
.
GetDevice
())
perf_times
=
torch
.
empty
(
n
)
perf_errors
=
torch
.
empty
(
n
)
for
view_idxs
,
gt
,
rays_o
,
rays_d
in
test_data_loader
:
perf
.
Checkpoint
(
"%d - Load"
%
i
)
rays_o
=
rays_o
.
to
(
device
.
GetDevice
()).
view
(
-
1
,
3
)
rays_d
=
rays_d
.
to
(
device
.
GetDevice
()).
view
(
-
1
,
3
)
n_rays
=
rays_o
.
size
(
0
)
chunk_size
=
min
(
n_rays
,
TEST_MAX_RAYS
)
out_pixels
=
torch
.
empty
(
n_rays
,
chns
,
device
=
device
.
GetDevice
())
for
offset
in
range
(
0
,
n_rays
,
chunk_size
):
idx
=
slice
(
offset
,
offset
+
chunk_size
)
out_pixels
[
idx
]
=
model
(
rays_o
[
idx
],
rays_d
[
idx
])
if
config
.
COLOR
==
color_mode
.
YCbCr
:
out_pixels
=
util
.
ycbcr2rgb
(
out_pixels
)
out_view_images
[
view_idxs
]
=
out_pixels
.
view
(
TEST_BATCH_SIZE
,
test_dataset
.
view_res
[
0
],
test_dataset
.
view_res
[
1
],
-
1
).
permute
(
0
,
3
,
1
,
2
)
perf_times
[
view_idxs
]
=
perf
.
Checkpoint
(
"%d - Infer"
%
i
)
if
config
.
COLOR
==
color_mode
.
YCbCr
:
gt
=
util
.
ycbcr2rgb
(
gt
)
error
=
loss
(
out_view_images
[
view_idxs
],
gt
).
item
()
print
(
"%d - Error: %f"
%
(
i
,
error
))
perf_errors
[
view_idxs
]
=
error
i
+=
1
# 4. Save results
perf_mean_time
=
torch
.
mean
(
perf_times
).
item
()
perf_mean_error
=
torch
.
mean
(
perf_errors
).
item
()
with
open
(
run_dir
+
'perf_%s_%s_%.1fms_%.2e.txt'
%
(
test_net_name
,
data_desc_name
,
perf_mean_time
,
perf_mean_error
),
'w'
)
as
fp
:
fp
.
write
(
'View, Time, Error
\n
'
)
fp
.
writelines
([
'%d, %f, %f
\n
'
%
(
i
,
perf_times
[
i
].
item
(),
perf_errors
[
i
].
item
())
for
i
in
range
(
n
)])
def
test
():
with
torch
.
no_grad
():
# 1. Load
train
dataset
# 1. Load dataset
print
(
"Load dataset: "
+
data_desc_path
)
test_dataset
=
SphericalViewSynDataset
(
data_desc_path
,
load_images
=
opt
.
output_gt
or
opt
.
output_alongside
,
gray
=
config
.
GRAY
)
color
=
config
.
COLOR
,
res
=
opt
.
res
)
test_data_loader
=
FastDataLoader
(
dataset
=
test_dataset
,
batch_size
=
1
,
...
...
@@ -242,14 +325,14 @@ def test():
# 2. Load trained model
netio
.
LoadNet
(
test_net_path
,
model
)
# 3. Test on
train
dataset
print
(
"Begin test
on train dataset
, batch size is %d"
%
TEST_BATCH_SIZE
)
# 3. Test on dataset
print
(
"Begin test, batch size is %d"
%
TEST_BATCH_SIZE
)
util
.
CreateDirIfNeed
(
output_dir
)
perf
=
SimplePerf
(
True
,
start
=
True
)
i
=
0
n
=
test_dataset
.
n_views
chns
=
1
if
config
.
GRAY
else
3
chns
=
1
if
config
.
COLOR
==
color_mode
.
GRAY
else
3
out_view_images
=
torch
.
empty
(
n
,
chns
,
test_dataset
.
view_res
[
0
],
test_dataset
.
view_res
[
1
],
device
=
device
.
GetDevice
())
...
...
@@ -263,6 +346,8 @@ def test():
for
offset
in
range
(
0
,
n_rays
,
chunk_size
):
idx
=
slice
(
offset
,
offset
+
chunk_size
)
out_pixels
[
idx
]
=
model
(
rays_o
[
idx
],
rays_d
[
idx
])
if
config
.
COLOR
==
color_mode
.
YCbCr
:
out_pixels
=
util
.
ycbcr2rgb
(
out_pixels
)
out_view_images
[
view_idxs
]
=
out_pixels
.
view
(
TEST_BATCH_SIZE
,
test_dataset
.
view_res
[
0
],
test_dataset
.
view_res
[
1
],
-
1
).
permute
(
0
,
3
,
1
,
2
)
...
...
@@ -297,5 +382,7 @@ def test():
if
__name__
==
"__main__"
:
if
train_mode
:
train
()
elif
opt
.
perf
:
perf
()
else
:
test
()
run_upsampling.py
View file @
3554ba52
...
...
@@ -4,11 +4,12 @@ import argparse
import
os
import
sys
import
torch
import
torch.nn.functional
as
nn_f
from
torch.utils.data
import
DataLoader
from
tensorboardX.writer
import
SummaryWriter
sys
.
path
.
append
(
os
.
path
.
abspath
(
sys
.
path
[
0
]
+
'/../'
))
__package__
=
"deep
lightfield
"
__package__
=
"deep
_view_syn
"
# ===========================================================
# Training settings
...
...
@@ -31,6 +32,8 @@ parser.add_argument('--dataset', type=str, required=True,
help
=
'dataset directory'
)
parser
.
add_argument
(
'--test'
,
type
=
str
,
help
=
'path of model to test'
)
parser
.
add_argument
(
'--testOutPatt'
,
type
=
str
,
help
=
'test output path pattern'
)
parser
.
add_argument
(
'--color'
,
type
=
str
,
default
=
'rgb'
,
help
=
'color'
)
# model configuration
parser
.
add_argument
(
'--upscale_factor'
,
'-uf'
,
type
=
int
,
...
...
@@ -46,51 +49,57 @@ print("Set CUDA:%d as current device." % torch.cuda.current_device())
from
.my
import
util
from
.my
import
netio
from
.my
import
device
from
.SRGAN.solver
import
SRGANTrainer
as
Solver
from
.my
import
color_mode
#from .upsampling.SubPixelCNN.solver import SubPixelTrainer as Solver
from
.upsampling.SRCNN.solver
import
SRCNNTrainer
as
Solver
from
.data.upsampling
import
UpsamplingDataset
from
.data.loader
import
FastDataLoader
os
.
chdir
(
args
.
dataset
)
print
(
'Change working directory to '
+
os
.
getcwd
())
run_dir
=
'run/'
args
.
color
=
color_mode
.
from_str
(
args
.
color
)
def
train
():
util
.
CreateDirIfNeed
(
run_dir
)
train_set
=
UpsamplingDataset
(
'.'
,
'out_view_%04d.png'
,
'gt
_
view_%04d.png'
,
gray
=
True
)
train_set
=
UpsamplingDataset
(
'.'
,
'
input/
out_view_%04d.png'
,
'gt
/
view_%04d.png'
,
color
=
args
.
color
)
training_data_loader
=
FastDataLoader
(
dataset
=
train_set
,
batch_size
=
args
.
batchSize
,
shuffle
=
True
,
drop_last
=
False
)
trainer
=
Solver
(
args
,
training_data_loader
,
training_data_loader
,
SummaryWriter
(
run_dir
))
trainer
.
build_model
()
# ===
for
epoch
in
range
(
1
,
20
+
1
):
trainer
.
pretrain
()
print
(
"{}/{} pretrained"
.
format
(
epoch
,
trainer
.
epoch_pretrain
))
# ===
trainer
.
build_model
(
3
if
args
.
color
==
color_mode
.
RGB
else
1
)
iters
=
0
for
epoch
in
range
(
1
,
args
.
nEpochs
+
1
):
print
(
"
\n
===> Epoch {} starts:"
.
format
(
epoch
))
iters
=
trainer
.
train
(
epoch
,
iters
)
netio
.
SaveNet
(
run_dir
+
'model-epoch_%d.pth'
%
args
.
nEpochs
,
trainer
.
netG
)
iters
=
trainer
.
train
(
epoch
,
iters
,
channels
=
slice
(
2
,
3
)
if
args
.
color
==
color_mode
.
YCbCr
else
None
)
netio
.
SaveNet
(
run_dir
+
'model-epoch_%d.pth'
%
args
.
nEpochs
,
trainer
.
model
)
def
test
():
util
.
CreateDirIfNeed
(
os
.
path
.
dirname
(
args
.
testOutPatt
))
train_set
=
UpsamplingDataset
(
'.'
,
'out_view_%04d.png'
,
None
,
gray
=
True
)
train_set
=
UpsamplingDataset
(
'.'
,
'input/out_view_%04d.png'
,
None
,
color
=
args
.
color
)
training_data_loader
=
FastDataLoader
(
dataset
=
train_set
,
batch_size
=
args
.
testBatchSize
,
shuffle
=
False
,
drop_last
=
False
)
trainer
=
Solver
(
args
,
training_data_loader
,
training_data_loader
,
SummaryWriter
(
run_dir
))
trainer
.
build_model
()
netio
.
LoadNet
(
args
.
test
,
trainer
.
netG
)
trainer
.
build_model
(
3
if
args
.
color
==
color_mode
.
RGB
else
1
)
netio
.
LoadNet
(
args
.
test
,
trainer
.
model
)
for
idx
,
input
,
_
in
training_data_loader
:
output
=
trainer
.
netG
(
input
)
if
args
.
color
==
color_mode
.
YCbCr
:
output_y
=
trainer
.
model
(
input
[:,
-
1
:])
output_cbcr
=
nn_f
.
upsample
(
input
[:,
0
:
2
],
scale_factor
=
2
)
output
=
util
.
ycbcr2rgb
(
torch
.
cat
([
output_cbcr
,
output_y
],
-
3
))
else
:
output
=
trainer
.
model
(
input
)
util
.
WriteImageTensor
(
output
,
args
.
testOutPatt
%
idx
)
...
...
FSRCNN/README.md
→
upsampling/
FSRCNN/README.md
View file @
3554ba52
File moved
FSRCNN/model.py
→
upsampling/
FSRCNN/model.py
View file @
3554ba52
File moved
FSRCNN/solver.py
→
upsampling/
FSRCNN/solver.py
View file @
3554ba52
File moved
SRCNN/README.md
→
upsampling/
SRCNN/README.md
View file @
3554ba52
File moved
SRCNN/model.py
→
upsampling/
SRCNN/model.py
View file @
3554ba52
File moved
SRCNN/solver.py
→
upsampling/
SRCNN/solver.py
View file @
3554ba52
...
...
@@ -8,7 +8,7 @@ import torch.backends.cudnn as cudnn
import
torchvision
from
.model
import
Net
from
..
my.progress_bar
import
progress_bar
from
my.progress_bar
import
progress_bar
class
SRCNNTrainer
(
object
):
...
...
@@ -28,8 +28,8 @@ class SRCNNTrainer(object):
self
.
testing_loader
=
testing_loader
self
.
writer
=
writer
def
build_model
(
self
):
self
.
model
=
Net
(
num_channels
=
1
,
base_filter
=
64
,
upscale_factor
=
self
.
upscale_factor
).
to
(
self
.
device
)
def
build_model
(
self
,
num_channels
):
self
.
model
=
Net
(
num_channels
=
num_channels
,
base_filter
=
64
,
upscale_factor
=
self
.
upscale_factor
).
to
(
self
.
device
)
self
.
model
.
weight_init
(
mean
=
0.0
,
std
=
0.01
)
self
.
criterion
=
torch
.
nn
.
MSELoss
()
torch
.
manual_seed
(
self
.
seed
)
...
...
@@ -47,11 +47,15 @@ class SRCNNTrainer(object):
torch
.
save
(
self
.
model
,
model_out_path
)
print
(
"Checkpoint saved to {}"
.
format
(
model_out_path
))
def
train
(
self
,
epoch
,
iters
):
def
train
(
self
,
epoch
,
iters
,
channels
=
None
):
self
.
model
.
train
()
train_loss
=
0
for
batch_num
,
(
_
,
data
,
target
)
in
enumerate
(
self
.
training_loader
):
data
,
target
=
data
.
to
(
self
.
device
),
target
.
to
(
self
.
device
)
if
channels
:
data
=
data
[...,
channels
,
:,
:]
target
=
target
[...,
channels
,
:,
:]
data
=
data
.
to
(
self
.
device
)
target
=
target
.
to
(
self
.
device
)
self
.
optimizer
.
zero_grad
()
out
=
self
.
model
(
data
)
loss
=
self
.
criterion
(
out
,
target
)
...
...
SRGAN/README.md
→
upsampling/
SRGAN/README.md
View file @
3554ba52
File moved
SRGAN/model.py
→
upsampling/
SRGAN/model.py
View file @
3554ba52
File moved
SRGAN/solver.py
→
upsampling/
SRGAN/solver.py
View file @
3554ba52
File moved
SubPixelCNN/model.py
→
upsampling/
SubPixelCNN/model.py
View file @
3554ba52
File moved
SubPixelCNN/solver.py
→
upsampling/
SubPixelCNN/solver.py
View file @
3554ba52
...
...
@@ -8,7 +8,7 @@ import torch.backends.cudnn as cudnn
import
torchvision
from
.model
import
Net
from
..
my.progress_bar
import
progress_bar
from
my.progress_bar
import
progress_bar
class
SubPixelTrainer
(
object
):
...
...
@@ -28,7 +28,9 @@ class SubPixelTrainer(object):
self
.
testing_loader
=
testing_loader
self
.
writer
=
writer
def
build_model
(
self
):
def
build_model
(
self
,
num_channels
):
if
num_channels
!=
1
:
raise
ValueError
(
'num_channels must be 1'
)
self
.
model
=
Net
(
upscale_factor
=
self
.
upscale_factor
).
to
(
self
.
device
)
self
.
criterion
=
torch
.
nn
.
MSELoss
()
torch
.
manual_seed
(
self
.
seed
)
...
...
@@ -39,17 +41,21 @@ class SubPixelTrainer(object):
self
.
criterion
.
cuda
()
self
.
optimizer
=
torch
.
optim
.
Adam
(
self
.
model
.
parameters
(),
lr
=
self
.
lr
)
self
.
scheduler
=
torch
.
optim
.
lr_scheduler
.
MultiStepLR
(
self
.
optimizer
,
milestones
=
[
50
,
75
,
100
],
gamma
=
0.5
)
# lr decay
self
.
scheduler
=
torch
.
optim
.
lr_scheduler
.
MultiStepLR
(
self
.
optimizer
,
milestones
=
[
50
,
75
,
100
],
gamma
=
0.5
)
# lr decay
def
save
(
self
):
model_out_path
=
"model_path.pth"
torch
.
save
(
self
.
model
,
model_out_path
)
print
(
"Checkpoint saved to {}"
.
format
(
model_out_path
))
def
train
(
self
,
epoch
,
iters
):
def
train
(
self
,
epoch
,
iters
,
channels
=
None
):
self
.
model
.
train
()
train_loss
=
0
for
batch_num
,
(
_
,
data
,
target
)
in
enumerate
(
self
.
training_loader
):
if
channels
:
data
=
data
[...,
channels
,
:,
:]
target
=
target
[...,
channels
,
:,
:]
data
,
target
=
data
.
to
(
self
.
device
),
target
.
to
(
self
.
device
)
self
.
optimizer
.
zero_grad
()
out
=
self
.
model
(
data
)
...
...
@@ -58,7 +64,8 @@ class SubPixelTrainer(object):
loss
.
backward
()
self
.
optimizer
.
step
()
sys
.
stdout
.
write
(
'Epoch %d: '
%
epoch
)
progress_bar
(
batch_num
,
len
(
self
.
training_loader
),
'Loss: %.4f'
%
(
train_loss
/
(
batch_num
+
1
)))
progress_bar
(
batch_num
,
len
(
self
.
training_loader
),
'Loss: %.4f'
%
(
train_loss
/
(
batch_num
+
1
)))
if
self
.
writer
:
self
.
writer
.
add_scalar
(
"loss"
,
loss
,
iters
)
if
iters
%
100
==
0
:
...
...
@@ -66,11 +73,13 @@ class SubPixelTrainer(object):
.
flatten
(
0
,
1
).
detach
()
self
.
writer
.
add_image
(
"Output_vs_gt"
,
torchvision
.
utils
.
make_grid
(
output_vs_gt
,
nrow
=
2
).
cpu
().
numpy
(),
torchvision
.
utils
.
make_grid
(
output_vs_gt
,
nrow
=
2
).
cpu
().
numpy
(),
iters
)
iters
+=
1
print
(
" Average Loss: {:.4f}"
.
format
(
train_loss
/
len
(
self
.
training_loader
)))
print
(
" Average Loss: {:.4f}"
.
format
(
train_loss
/
len
(
self
.
training_loader
)))
return
iters
def
test
(
self
):
...
...
@@ -84,9 +93,11 @@ class SubPixelTrainer(object):
mse
=
self
.
criterion
(
prediction
,
target
)
psnr
=
10
*
log10
(
1
/
mse
.
item
())
avg_psnr
+=
psnr
progress_bar
(
batch_num
,
len
(
self
.
testing_loader
),
'PSNR: %.4f'
%
(
avg_psnr
/
(
batch_num
+
1
)))
progress_bar
(
batch_num
,
len
(
self
.
testing_loader
),
'PSNR: %.4f'
%
(
avg_psnr
/
(
batch_num
+
1
)))
print
(
" Average PSNR: {:.4f} dB"
.
format
(
avg_psnr
/
len
(
self
.
testing_loader
)))
print
(
" Average PSNR: {:.4f} dB"
.
format
(
avg_psnr
/
len
(
self
.
testing_loader
)))
def
run
(
self
):
self
.
build_model
()
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment