Commit 3554ba52 authored by Nianchen Deng's avatar Nianchen Deng
Browse files

sync

parent f7038e26
{
}
\ No newline at end of file
# deeplightfield
# deep_view_syn
Requirement:
......
#!/usr/bin/bash
# Train and evaluate the 'msl' network for one test case.
#
# Usage: <this script> TESTCASE
#   TESTCASE  zero-based index into n_layers_arr / n_samples_arr below. The
#             same value is passed to --device for every python invocation.
#
# For each nf in {64, 128, 256} the script trains for $epochs epochs and then
# runs the saved checkpoint with --perf on train.json and on test.json.

testcase=$1
dataset='data/sp_view_syn_2020.12.31_fovea'
epochs=100
n_layers_arr=(4 8 4 8)
n_samples_arr=(16 16 32 32)

# Fail early: an empty or out-of-range index would otherwise produce an empty
# config id and a dangling --device flag.
if [[ ! $testcase =~ ^(0|[1-9][0-9]*)$ ]] || (( testcase >= ${#n_layers_arr[@]} )); then
    echo "usage: $0 TESTCASE (0-$(( ${#n_layers_arr[@]} - 1 )))" >&2
    exit 1
fi

# Both values depend only on the test case, so look them up once.
n_layers=${n_layers_arr[$testcase]}
n_samples=${n_samples_arr[$testcase]}

for nf in 64 128 256; do
    configid="infer_test@msl-rgb_e10_fc${nf}x${n_layers}_d1-50_s${n_samples}"
    model="$dataset/$configid/model-epoch_$epochs.pth"
    python run_spherical_view_syn.py --dataset "$dataset/train.json" --config-id "$configid" --device "$testcase" --epochs "$epochs"
    python run_spherical_view_syn.py --dataset "$dataset/train.json" --test "$model" --perf --device "$testcase"
    python run_spherical_view_syn.py --dataset "$dataset/test.json" --test "$model" --perf --device "$testcase"
done
from ..my import color_mode
def update_config(config):
# Dataset settings
config.GRAY = False
config.COLOR = color_mode.RGB
# Net parameters
config.NET_TYPE = 'msl'
......@@ -12,4 +14,6 @@ def update_config(config):
config.SAMPLE_PARAMS.update({
'depth_range': (1, 50),
'n_samples': 16
})
\ No newline at end of file
})
def update_config(config):
    """Apply this preset's settings to ``config`` in place.

    :param config: configuration object to modify. ``GRAY``, ``NET_TYPE`` and
        ``N_ENCODE_DIM`` are assigned as attributes; ``FC_PARAMS`` and
        ``SAMPLE_PARAMS`` must already exist and support ``update()``
        (presumably a ``SphericalViewSynConfig`` -- confirm against the caller).
    """
    # Dataset settings
    config.GRAY = True

    # Net parameters
    config.NET_TYPE = 'msl'
    config.N_ENCODE_DIM = 10
    # update() merges into the existing mappings, so keys not listed here keep
    # whatever value they already had.
    config.FC_PARAMS.update({
        'nf': 128,
        'n_layers': 8,
        'skips': [4]
    })
    config.SAMPLE_PARAMS.update({
        'depth_range': (1, 50),
        'n_samples': 8
    })
\ No newline at end of file
def update_config(config):
    """Apply this preset's settings to ``config`` in place.

    :param config: configuration object to modify. ``GRAY``, ``NET_TYPE`` and
        ``N_ENCODE_DIM`` are assigned as attributes; ``FC_PARAMS`` and
        ``SAMPLE_PARAMS`` must already exist and support ``update()``
        (presumably a ``SphericalViewSynConfig`` -- confirm against the caller).
    """
    # Dataset settings
    config.GRAY = True

    # Net parameters
    config.NET_TYPE = 'msl'
    config.N_ENCODE_DIM = 10
    # 'skips' is deliberately not set here; update() leaves any existing value
    # of keys that are not listed untouched.
    config.FC_PARAMS.update({
        'nf': 64,
        'n_layers': 12
    })
    config.SAMPLE_PARAMS.update({
        'depth_range': (1, 20),
        'n_samples': 16
    })
\ No newline at end of file
def update_config(config):
    """Apply this preset's settings to ``config`` in place.

    :param config: configuration object to modify. ``GRAY``, ``NET_TYPE`` and
        ``N_ENCODE_DIM`` are assigned as attributes; ``FC_PARAMS`` and
        ``SAMPLE_PARAMS`` must already exist and support ``update()``
        (presumably a ``SphericalViewSynConfig`` -- confirm against the caller).
    """
    # Dataset settings
    config.GRAY = False

    # Net parameters
    config.NET_TYPE = 'msl'
    config.N_ENCODE_DIM = 10
    # 'skips' is deliberately not set here; update() leaves any existing value
    # of keys that are not listed untouched.
    config.FC_PARAMS.update({
        'nf': 64,
        'n_layers': 12
    })
    config.SAMPLE_PARAMS.update({
        'depth_range': (1, 20),
        'n_samples': 16
    })
\ No newline at end of file
def update_config(config):
    """Apply this preset's settings to ``config`` in place.

    :param config: configuration object to modify. ``GRAY``, ``NET_TYPE`` and
        ``N_ENCODE_DIM`` are assigned as attributes; ``FC_PARAMS`` and
        ``SAMPLE_PARAMS`` must already exist and support ``update()``
        (presumably a ``SphericalViewSynConfig`` -- confirm against the caller).
    """
    # Dataset settings
    config.GRAY = True

    # Net parameters
    config.NET_TYPE = 'msl'
    config.N_ENCODE_DIM = 10
    # update() merges into the existing mappings, so keys not listed here keep
    # whatever value they already had.
    config.FC_PARAMS.update({
        'nf': 256,
        'n_layers': 8,
        'skips': [4]
    })
    config.SAMPLE_PARAMS.update({
        'depth_range': (1, 50),
        'n_samples': 32
    })
\ No newline at end of file
def update_config(config):
    """Apply this preset's settings to ``config`` in place.

    :param config: configuration object to modify. ``GRAY``, ``NET_TYPE`` and
        ``N_ENCODE_DIM`` are assigned as attributes; ``FC_PARAMS`` and
        ``SAMPLE_PARAMS`` must already exist and support ``update()``
        (presumably a ``SphericalViewSynConfig`` -- confirm against the caller).
    """
    # Dataset settings
    config.GRAY = True

    # Net parameters
    config.NET_TYPE = 'msl'
    config.N_ENCODE_DIM = 10
    # update() merges into the existing mappings, so keys not listed here keep
    # whatever value they already had.
    config.FC_PARAMS.update({
        'nf': 256,
        'n_layers': 8,
        'skips': [4]
    })
    config.SAMPLE_PARAMS.update({
        'depth_range': (1, 20),
        'n_samples': 16
    })
\ No newline at end of file
def update_config(config):
    """Apply this preset's settings to ``config`` in place.

    :param config: configuration object to modify. ``GRAY``, ``NET_TYPE`` and
        ``N_ENCODE_DIM`` are assigned as attributes; ``FC_PARAMS`` and
        ``SAMPLE_PARAMS`` must already exist and support ``update()``
        (presumably a ``SphericalViewSynConfig`` -- confirm against the caller).
    """
    # Dataset settings
    config.GRAY = True

    # Net parameters
    config.NET_TYPE = 'msl'
    config.N_ENCODE_DIM = 10
    # 'skips' is deliberately not set here; update() leaves any existing value
    # of keys that are not listed untouched.
    config.FC_PARAMS.update({
        'nf': 64,
        'n_layers': 8
    })
    config.SAMPLE_PARAMS.update({
        'depth_range': (1, 50),
        'n_samples': 4
    })
\ No newline at end of file
def update_config(config):
    """Apply this preset's settings to ``config`` in place.

    :param config: configuration object to modify. ``GRAY``, ``NET_TYPE`` and
        ``N_ENCODE_DIM`` are assigned as attributes; ``FC_PARAMS`` and
        ``SAMPLE_PARAMS`` must already exist and support ``update()``
        (presumably a ``SphericalViewSynConfig`` -- confirm against the caller).
    """
    # Dataset settings
    config.GRAY = True

    # Net parameters
    config.NET_TYPE = 'msl'
    config.N_ENCODE_DIM = 20
    # 'skips' is deliberately not set here; update() leaves any existing value
    # of keys that are not listed untouched.
    config.FC_PARAMS.update({
        'nf': 64,
        'n_layers': 12,
    })
    config.SAMPLE_PARAMS.update({
        'depth_range': (1, 20),
        'n_samples': 16
    })
\ No newline at end of file
def update_config(config):
    """Apply this preset's settings to ``config`` in place.

    :param config: configuration object to modify. ``GRAY``, ``NET_TYPE`` and
        ``N_ENCODE_DIM`` are assigned as attributes; ``FC_PARAMS`` and
        ``SAMPLE_PARAMS`` must already exist and support ``update()``
        (presumably a ``SphericalViewSynConfig`` -- confirm against the caller).
    """
    # Dataset settings
    config.GRAY = True

    # Net parameters
    config.NET_TYPE = 'msl'
    config.N_ENCODE_DIM = 10
    # update() merges into the existing mappings, so keys not listed here keep
    # whatever value they already had.
    config.FC_PARAMS.update({
        'nf': 64,
        'n_layers': 8,
        'skips': [4]
    })
    config.SAMPLE_PARAMS.update({
        'depth_range': (1, 50),
        'n_samples': 4
    })
\ No newline at end of file
def update_config(config):
    """Apply this preset's settings to ``config`` in place.

    :param config: configuration object to modify. ``GRAY``, ``NET_TYPE`` and
        ``N_ENCODE_DIM`` are assigned as attributes; ``FC_PARAMS`` and
        ``SAMPLE_PARAMS`` must already exist and support ``update()``
        (presumably a ``SphericalViewSynConfig`` -- confirm against the caller).
    """
    # Dataset settings
    config.GRAY = False

    # Net parameters
    config.NET_TYPE = 'msl'
    config.N_ENCODE_DIM = 10
    # 'skips' is deliberately not set here; update() leaves any existing value
    # of keys that are not listed untouched.
    config.FC_PARAMS.update({
        'nf': 128,
        'n_layers': 6
    })
    config.SAMPLE_PARAMS.update({
        'depth_range': (1, 50),
        'n_samples': 16
    })
\ No newline at end of file
def update_config(config):
    """Apply this preset's settings to ``config`` in place.

    :param config: configuration object to modify. ``GRAY``, ``NET_TYPE`` and
        ``N_ENCODE_DIM`` are assigned as attributes; ``FC_PARAMS`` and
        ``SAMPLE_PARAMS`` must already exist and support ``update()``
        (presumably a ``SphericalViewSynConfig`` -- confirm against the caller).
    """
    # Dataset settings
    config.GRAY = False

    # Net parameters
    config.NET_TYPE = 'msl'
    config.N_ENCODE_DIM = 10
    # 'skips' is deliberately not set here; update() leaves any existing value
    # of keys that are not listed untouched.
    config.FC_PARAMS.update({
        'nf': 64,
        'n_layers': 8
    })
    config.SAMPLE_PARAMS.update({
        'depth_range': (1, 50),
        'n_samples': 4
    })
\ No newline at end of file
from ..my import color_mode
def update_config(config):
# Dataset settings
config.GRAY = False
config.COLOR = color_mode.RGB
# Net parameters
config.NET_TYPE = 'msl'
......
import os
import importlib
from os.path import join
from ..my import color_mode
class SphericalViewSynConfig(object):
......@@ -8,7 +9,7 @@ class SphericalViewSynConfig(object):
def __init__(self):
self.name = 'default'
self.GRAY = False
self.COLOR = color_mode.RGB
# Net parameters
self.NET_TYPE = 'msl'
......@@ -30,18 +31,18 @@ class SphericalViewSynConfig(object):
def load(self, path):
module_name = os.path.splitext(path)[0].replace('/', '.')
config_module = importlib.import_module(
'deeplightfield.' + module_name)
'deep_view_syn.' + module_name)
config_module.update_config(self)
self.name = module_name.split('.')[-1]
def load_by_name(self, name):
config_module = importlib.import_module(
'deeplightfield.configs.' + name)
'deep_view_syn.configs.' + name)
config_module.update_config(self)
self.name = name
def to_id(self):
net_type_id = "%s-%s" % (self.NET_TYPE, "gray" if self.GRAY else "rgb")
net_type_id = "%s-%s" % (self.NET_TYPE, color_mode.to_str(self.COLOR))
encode_id = "_e%d" % self.N_ENCODE_DIM
fc_id = "_fc%dx%d" % (self.FC_PARAMS['nf'], self.FC_PARAMS['n_layers'])
skip_id = "_skip%s" % ','.join([
......@@ -49,7 +50,7 @@ class SphericalViewSynConfig(object):
for val in self.FC_PARAMS['skips']
]) if len(self.FC_PARAMS['skips']) > 0 else ""
depth_id = "_d%d-%d" % (self.SAMPLE_PARAMS['depth_range'][0],
self.SAMPLE_PARAMS['depth_range'][1])
self.SAMPLE_PARAMS['depth_range'][1])
samples_id = '_s%d' % self.SAMPLE_PARAMS['n_samples']
neg_flags = '%s%s%s' % (
'p' if not self.SAMPLE_PARAMS['perturb_sample'] else '',
......@@ -60,33 +61,38 @@ class SphericalViewSynConfig(object):
return "%s@%s%s%s%s%s%s%s" % (self.name, net_type_id, encode_id, fc_id, skip_id, depth_id, samples_id, neg_flags)
def from_id(self, id: str):
self.name, config_str = id.split('@')
segs = config_str.split('_')
id_splited = id.split('@')
if len(id_splited) == 2:
self.name = id_splited[0]
segs = id_splited[-1].split('_')
for i, seg in enumerate(segs):
if i == 0: # NetType
self.NET_TYPE, color_mode = seg.split('-')
self.GRAY = (color_mode == 'gray')
continue
if seg.startswith('e'): # Encode
if seg.startswith('e'): # Encode
self.N_ENCODE_DIM = int(seg[1:])
continue
if seg.startswith('fc'): # Full-connected network parameters
self.FC_PARAMS['nf'], self.FC_PARAMS['n_layers'] = (int(str) for str in seg[2:].split('x'))
if seg.startswith('fc'): # Full-connected network parameters
self.FC_PARAMS['nf'], self.FC_PARAMS['n_layers'] = (
int(str) for str in seg[2:].split('x'))
continue
if seg.startswith('skip'): # Skip connection
self.FC_PARAMS['skips'] = [int(str) for str in seg[4:].split(',')]
if seg.startswith('skip'): # Skip connection
self.FC_PARAMS['skips'] = [int(str)
for str in seg[4:].split(',')]
continue
if seg.startswith('d'): # Depth range
self.SAMPLE_PARAMS['depth_range'] = tuple(float(str) for str in seg[1:].split('-'))
if seg.startswith('d'): # Depth range
self.SAMPLE_PARAMS['depth_range'] = tuple(
float(str) for str in seg[1:].split('-'))
continue
if seg.startswith('s'): # Number of samples
if seg.startswith('s'): # Number of samples
self.SAMPLE_PARAMS['n_samples'] = int(seg[1:])
continue
if seg.startswith('~'): # Negative flags
if seg.startswith('~'): # Negative flags
self.SAMPLE_PARAMS['perturb_sample'] = (seg.find('p') < 0)
self.SAMPLE_PARAMS['lindisp'] = (seg.find('l') < 0)
self.SAMPLE_PARAMS['inverse_r'] = (seg.find('i') < 0)
continue
if i == 0: # NetType
self.NET_TYPE, color_str = seg.split('-')
self.COLOR = color_mode.from_str(color_str)
continue
def print(self):
print('==== Config %s ====' % self.name)
......
......@@ -2,10 +2,12 @@ import os
import json
import torch
import torchvision.transforms.functional as trans_f
import torch.nn.functional as nn_f
from typing import Tuple, Union
from ..my import util
from ..my import device
from ..my import view
from ..my import color_mode
class SphericalViewSynDataset(object):
......@@ -23,7 +25,8 @@ class SphericalViewSynDataset(object):
"""
def __init__(self, dataset_desc_path: str, load_images: bool = True,
load_depths: bool = False, gray: bool = False, calculate_rays: bool = True):
load_depths: bool = False, color: int = color_mode.RGB,
calculate_rays: bool = True, res: Tuple[int, int] = None):
"""
Initialize data loader for spherical view synthesis task
......@@ -38,7 +41,7 @@ class SphericalViewSynDataset(object):
:param dataset_desc_path ```str```: path to the data description file
:param load_images ```bool```: whether load view images and return in __getitem__()
:param load_depths ```bool```: whether load depth images and return in __getitem__()
:param gray ```bool```: whether convert view images to grayscale
:param color ```int```: color space to convert view images to
:param calculate_rays ```bool```: whether calculate rays
"""
super().__init__()
......@@ -47,7 +50,7 @@ class SphericalViewSynDataset(object):
self.load_depths = load_depths
# Load dataset description file
self._load_desc(dataset_desc_path)
self._load_desc(dataset_desc_path, res)
# Load view images
if self.load_images:
......@@ -55,8 +58,12 @@ class SphericalViewSynDataset(object):
[self.view_file_pattern % i
for i in range(self.view_centers.size(0))]
).to(device.GetDevice())
if gray:
if color == color_mode.GRAY:
self.view_images = trans_f.rgb_to_grayscale(self.view_images)
elif color == color_mode.YCbCr:
self.view_images = util.rgb2ycbcr(self.view_images)
if res:
self.view_images = nn_f.interpolate(self.view_images, res)
else:
self.view_images = None
......@@ -68,6 +75,8 @@ class SphericalViewSynDataset(object):
for i in range(self.view_centers.size(0))]
).to(device.GetDevice()),
self.cam_params.get_local_rays())
if res:
self.view_depths = nn_f.interpolate(self.view_depths, res)
else:
self.view_depths = None
......@@ -85,7 +94,7 @@ class SphericalViewSynDataset(object):
output /= local_rays[..., 2]
return output
def _load_desc(self, path):
def _load_desc(self, path, res = None):
with open(path, 'r', encoding='utf-8') as file:
data_desc = json.loads(file.read())
if data_desc['view_file_pattern'] == '':
......@@ -103,6 +112,9 @@ class SphericalViewSynDataset(object):
self.cam_params = view.CameraParam(data_desc['cam_params'],
self.view_res,
device=device.GetDevice())
if res:
self.view_res = res
self.cam_params.resize(res)
self.depth_range = [
data_desc['depth_range']['min'],
data_desc['depth_range']['max']
......
import os
from numpy.core.fromnumeric import trace
from numpy.lib.arraysetops import isin
import torch
import torchvision.transforms.functional as trans_f
from ..my import util
from ..my import device
from ..my import color_mode
class UpsamplingDataset(torch.utils.data.dataset.Dataset):
......@@ -13,7 +15,7 @@ class UpsamplingDataset(torch.utils.data.dataset.Dataset):
"""
def __init__(self, data_dir: str, input_patt: str, gt_patt: str,
gray: bool = False, load_once: bool = True):
color: int, load_once: bool = True):
"""
Initialize dataset for upsampling task
......@@ -33,15 +35,18 @@ class UpsamplingDataset(torch.utils.data.dataset.Dataset):
)))
self.load_once = load_once
self.load_gt = self.gt_patt != None
self.gray = gray
self.color = color
self.input = util.ReadImageTensor([self.input_patt % i for i in range(self.n)]) \
.to(device.GetDevice()) if self.load_once else None
self.gt = util.ReadImageTensor([self.gt_patt % i for i in range(self.n)]) \
.to(device.GetDevice()) if self.load_once and self.load_gt else None
if self.gray:
if self.color == color_mode.GRAY:
self.input = trans_f.rgb_to_grayscale(self.input)
self.gt = trans_f.rgb_to_grayscale(self.gt) \
if self.gt != None else None
elif self.color == color_mode.YCbCr:
self.input = util.rgb2ycbcr(self.input)
self.gt = util.rgb2ycbcr(self.gt) if self.gt != None else None
def __len__(self):
return self.n
......@@ -50,13 +55,16 @@ class UpsamplingDataset(torch.utils.data.dataset.Dataset):
if self.load_once:
return idx, self.input[idx], self.gt[idx] if self.load_gt else False
if isinstance(idx, torch.Tensor):
return idx, \
trans_f.rgb_to_grayscale(util.ReadImageTensor(
[self.input_patt % i for i in idx])), \
trans_f.rgb_to_grayscale(util.ReadImageTensor(
[self.gt_patt % i for i in idx])) if self.load_gt else False
return idx, \
trans_f.rgb_to_grayscale(util.ReadImageTensor(
self.input_patt % idx)), \
trans_f.rgb_to_grayscale(util.ReadImageTensor(
self.gt_patt % idx)) if self.load_gt else False
input = util.ReadImageTensor([self.input_patt % i for i in idx])
gt = util.ReadImageTensor([self.gt_patt % i for i in idx]) if self.load_gt else False
else:
input = util.ReadImageTensor([self.input_patt % idx])
gt = util.ReadImageTensor([self.gt_patt % idx]) if self.load_gt else False
if self.color == color_mode.GRAY:
input = trans_f.rgb_to_grayscale(input)
gt = trans_f.rgb_to_grayscale(gt) if isinstance(gt, torch.Tensor) else False
return idx, input, gt
elif self.color == color_mode.YCbCr:
input = util.rgb2ycbcr(input)
gt = util.rgb2ycbcr(gt) if isinstance(gt, torch.Tensor) else False
return idx, input, gt
\ No newline at end of file
......@@ -6,17 +6,17 @@ import torch.optim
from torch import onnx
sys.path.append(os.path.abspath(sys.path[0] + '/../'))
__package__ = "deeplightfield"
__package__ = "deep_view_syn"
parser = argparse.ArgumentParser()
parser.add_argument('--device', type=int, default=0,
help='Which CUDA device to use.')
parser.add_argument('--model', type=str,
help='Path of model to export')
parser.add_argument('--batch-size', type=int,
parser.add_argument('--batch-size', type=str,
help='Resolution')
parser.add_argument('--outdir', type=str, default='./',
help='Output directory')
parser.add_argument('model', type=str,
help='Path of model to export')
opt = parser.parse_args()
# Select device
......@@ -29,13 +29,18 @@ from .my import device
from .my import netio
from .my import util
dir_path, model_file = os.path.split(opt.model)
batch_size = eval(opt.batch_size)
os.chdir(dir_path)
config = SphericalViewSynConfig()
def load_net(path):
name = os.path.splitext(os.path.basename(path))[0]
config = SphericalViewSynConfig()
config.load_by_name(name.split('@')[1])
config.from_id(name)
config.SAMPLE_PARAMS['spherical'] = True
config.SAMPLE_PARAMS['perturb_sample'] = False
config.SAMPLE_PARAMS['n_samples'] = 4
config.print()
net = MslNet(config.FC_PARAMS, config.SAMPLE_PARAMS, config.GRAY,
config.N_ENCODE_DIM, export_mode=True).to(device.GetDevice())
......@@ -46,16 +51,16 @@ def load_net(path):
if __name__ == "__main__":
with torch.no_grad():
# Load model
net, name = load_net(opt.model)
net, name = load_net(model_file)
# Input to the model
rays_o = torch.empty(opt.batch_size, 3, device=device.GetDevice())
rays_d = torch.empty(opt.batch_size, 3, device=device.GetDevice())
rays_o = torch.empty(batch_size, 3, device=device.GetDevice())
rays_d = torch.empty(batch_size, 3, device=device.GetDevice())
util.CreateDirIfNeed(opt.outdir)
# Export the model
outpath = os.path.join(opt.outdir, name + ".onnx")
outpath = os.path.join(opt.outdir, config.to_id() + ".onnx")
onnx.export(
net, # model being run
(rays_o, rays_d), # model input (or a tuple for multiple inputs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment