Skip to content
Snippets Groups Projects
Commit 3554ba52 authored by Nianchen Deng's avatar Nianchen Deng
Browse files

sync

parent f7038e26
Branches
No related merge requests found
Showing
with 107 additions and 206 deletions
{
}
\ No newline at end of file
# deeplightfield
# deep_view_syn
Requirement:
......
#/usr/bin/bash
testcase=$1
dataset='data/sp_view_syn_2020.12.31_fovea'
epochs=100
n_layers_arr=(4 8 4 8)
n_samples_arr=(16 16 32 32)
for nf in 64 128 256; do
n_layers=${n_layers_arr[$testcase]}
n_samples=${n_samples_arr[$testcase]}
configid="infer_test@msl-rgb_e10_fc${nf}x${n_layers}_d1-50_s${n_samples}"
python run_spherical_view_syn.py --dataset $dataset/train.json --config-id $configid --device $testcase --epochs $epochs
python run_spherical_view_syn.py --dataset $dataset/train.json --test $dataset/$configid/model-epoch_$epochs.pth --perf --device $testcase
python run_spherical_view_syn.py --dataset $dataset/test.json --test $dataset/$configid/model-epoch_$epochs.pth --perf --device $testcase
done
from ..my import color_mode
def update_config(config):
# Dataset settings
config.GRAY = False
config.COLOR = color_mode.RGB
# Net parameters
config.NET_TYPE = 'msl'
......@@ -12,4 +14,6 @@ def update_config(config):
config.SAMPLE_PARAMS.update({
'depth_range': (1, 50),
'n_samples': 16
})
\ No newline at end of file
})
def update_config(config):
# Dataset settings
config.GRAY = True
# Net parameters
config.NET_TYPE = 'msl'
config.N_ENCODE_DIM = 10
config.FC_PARAMS.update({
'nf': 128,
'n_layers': 8,
'skips': [4]
})
config.SAMPLE_PARAMS.update({
'depth_range': (1, 50),
'n_samples': 8
})
\ No newline at end of file
def update_config(config):
# Dataset settings
config.GRAY = True
# Net parameters
config.NET_TYPE = 'msl'
config.N_ENCODE_DIM = 10
config.FC_PARAMS.update({
'nf': 64,
'n_layers': 12
})
config.SAMPLE_PARAMS.update({
'depth_range': (1, 20),
'n_samples': 16
})
\ No newline at end of file
def update_config(config):
# Dataset settings
config.GRAY = False
# Net parameters
config.NET_TYPE = 'msl'
config.N_ENCODE_DIM = 10
config.FC_PARAMS.update({
'nf': 64,
'n_layers': 12
})
config.SAMPLE_PARAMS.update({
'depth_range': (1, 20),
'n_samples': 16
})
\ No newline at end of file
def update_config(config):
# Dataset settings
config.GRAY = True
# Net parameters
config.NET_TYPE = 'msl'
config.N_ENCODE_DIM = 10
config.FC_PARAMS.update({
'nf': 256,
'n_layers': 8,
'skips': [4]
})
config.SAMPLE_PARAMS.update({
'depth_range': (1, 50),
'n_samples': 32
})
\ No newline at end of file
def update_config(config):
# Dataset settings
config.GRAY = True
# Net parameters
config.NET_TYPE = 'msl'
config.N_ENCODE_DIM = 10
config.FC_PARAMS.update({
'nf': 256,
'n_layers': 8,
'skips': [4]
})
config.SAMPLE_PARAMS.update({
'depth_range': (1, 20),
'n_samples': 16
})
\ No newline at end of file
def update_config(config):
# Dataset settings
config.GRAY = True
# Net parameters
config.NET_TYPE = 'msl'
config.N_ENCODE_DIM = 10
config.FC_PARAMS.update({
'nf': 64,
'n_layers': 8
})
config.SAMPLE_PARAMS.update({
'depth_range': (1, 50),
'n_samples': 4
})
\ No newline at end of file
def update_config(config):
# Dataset settings
config.GRAY = True
# Net parameters
config.NET_TYPE = 'msl'
config.N_ENCODE_DIM = 20
config.FC_PARAMS.update({
'nf': 64,
'n_layers': 12,
})
config.SAMPLE_PARAMS.update({
'depth_range': (1, 20),
'n_samples': 16
})
\ No newline at end of file
def update_config(config):
# Dataset settings
config.GRAY = True
# Net parameters
config.NET_TYPE = 'msl'
config.N_ENCODE_DIM = 10
config.FC_PARAMS.update({
'nf': 64,
'n_layers': 8,
'skips': [4]
})
config.SAMPLE_PARAMS.update({
'depth_range': (1, 50),
'n_samples': 4
})
\ No newline at end of file
def update_config(config):
# Dataset settings
config.GRAY = False
# Net parameters
config.NET_TYPE = 'msl'
config.N_ENCODE_DIM = 10
config.FC_PARAMS.update({
'nf': 128,
'n_layers': 6
})
config.SAMPLE_PARAMS.update({
'depth_range': (1, 50),
'n_samples': 16
})
\ No newline at end of file
def update_config(config):
# Dataset settings
config.GRAY = False
# Net parameters
config.NET_TYPE = 'msl'
config.N_ENCODE_DIM = 10
config.FC_PARAMS.update({
'nf': 64,
'n_layers': 8
})
config.SAMPLE_PARAMS.update({
'depth_range': (1, 50),
'n_samples': 4
})
\ No newline at end of file
from ..my import color_mode
def update_config(config):
# Dataset settings
config.GRAY = False
config.COLOR = color_mode.RGB
# Net parameters
config.NET_TYPE = 'msl'
......
import os
import importlib
from os.path import join
from ..my import color_mode
class SphericalViewSynConfig(object):
......@@ -8,7 +9,7 @@ class SphericalViewSynConfig(object):
def __init__(self):
self.name = 'default'
self.GRAY = False
self.COLOR = color_mode.RGB
# Net parameters
self.NET_TYPE = 'msl'
......@@ -30,18 +31,18 @@ class SphericalViewSynConfig(object):
def load(self, path):
module_name = os.path.splitext(path)[0].replace('/', '.')
config_module = importlib.import_module(
'deeplightfield.' + module_name)
'deep_view_syn.' + module_name)
config_module.update_config(self)
self.name = module_name.split('.')[-1]
def load_by_name(self, name):
config_module = importlib.import_module(
'deeplightfield.configs.' + name)
'deep_view_syn.configs.' + name)
config_module.update_config(self)
self.name = name
def to_id(self):
net_type_id = "%s-%s" % (self.NET_TYPE, "gray" if self.GRAY else "rgb")
net_type_id = "%s-%s" % (self.NET_TYPE, color_mode.to_str(self.COLOR))
encode_id = "_e%d" % self.N_ENCODE_DIM
fc_id = "_fc%dx%d" % (self.FC_PARAMS['nf'], self.FC_PARAMS['n_layers'])
skip_id = "_skip%s" % ','.join([
......@@ -49,7 +50,7 @@ class SphericalViewSynConfig(object):
for val in self.FC_PARAMS['skips']
]) if len(self.FC_PARAMS['skips']) > 0 else ""
depth_id = "_d%d-%d" % (self.SAMPLE_PARAMS['depth_range'][0],
self.SAMPLE_PARAMS['depth_range'][1])
self.SAMPLE_PARAMS['depth_range'][1])
samples_id = '_s%d' % self.SAMPLE_PARAMS['n_samples']
neg_flags = '%s%s%s' % (
'p' if not self.SAMPLE_PARAMS['perturb_sample'] else '',
......@@ -60,33 +61,38 @@ class SphericalViewSynConfig(object):
return "%s@%s%s%s%s%s%s%s" % (self.name, net_type_id, encode_id, fc_id, skip_id, depth_id, samples_id, neg_flags)
def from_id(self, id: str):
self.name, config_str = id.split('@')
segs = config_str.split('_')
id_splited = id.split('@')
if len(id_splited) == 2:
self.name = id_splited[0]
segs = id_splited[-1].split('_')
for i, seg in enumerate(segs):
if i == 0: # NetType
self.NET_TYPE, color_mode = seg.split('-')
self.GRAY = (color_mode == 'gray')
continue
if seg.startswith('e'): # Encode
if seg.startswith('e'): # Encode
self.N_ENCODE_DIM = int(seg[1:])
continue
if seg.startswith('fc'): # Full-connected network parameters
self.FC_PARAMS['nf'], self.FC_PARAMS['n_layers'] = (int(str) for str in seg[2:].split('x'))
if seg.startswith('fc'): # Full-connected network parameters
self.FC_PARAMS['nf'], self.FC_PARAMS['n_layers'] = (
int(str) for str in seg[2:].split('x'))
continue
if seg.startswith('skip'): # Skip connection
self.FC_PARAMS['skips'] = [int(str) for str in seg[4:].split(',')]
if seg.startswith('skip'): # Skip connection
self.FC_PARAMS['skips'] = [int(str)
for str in seg[4:].split(',')]
continue
if seg.startswith('d'): # Depth range
self.SAMPLE_PARAMS['depth_range'] = tuple(float(str) for str in seg[1:].split('-'))
if seg.startswith('d'): # Depth range
self.SAMPLE_PARAMS['depth_range'] = tuple(
float(str) for str in seg[1:].split('-'))
continue
if seg.startswith('s'): # Number of samples
if seg.startswith('s'): # Number of samples
self.SAMPLE_PARAMS['n_samples'] = int(seg[1:])
continue
if seg.startswith('~'): # Negative flags
if seg.startswith('~'): # Negative flags
self.SAMPLE_PARAMS['perturb_sample'] = (seg.find('p') < 0)
self.SAMPLE_PARAMS['lindisp'] = (seg.find('l') < 0)
self.SAMPLE_PARAMS['inverse_r'] = (seg.find('i') < 0)
continue
if i == 0: # NetType
self.NET_TYPE, color_str = seg.split('-')
self.COLOR = color_mode.from_str(color_str)
continue
def print(self):
print('==== Config %s ====' % self.name)
......
......@@ -2,10 +2,12 @@ import os
import json
import torch
import torchvision.transforms.functional as trans_f
import torch.nn.functional as nn_f
from typing import Tuple, Union
from ..my import util
from ..my import device
from ..my import view
from ..my import color_mode
class SphericalViewSynDataset(object):
......@@ -23,7 +25,8 @@ class SphericalViewSynDataset(object):
"""
def __init__(self, dataset_desc_path: str, load_images: bool = True,
load_depths: bool = False, gray: bool = False, calculate_rays: bool = True):
load_depths: bool = False, color: int = color_mode.RGB,
calculate_rays: bool = True, res: Tuple[int, int] = None):
"""
Initialize data loader for spherical view synthesis task
......@@ -38,7 +41,7 @@ class SphericalViewSynDataset(object):
:param dataset_desc_path ```str```: path to the data description file
:param load_images ```bool```: whether load view images and return in __getitem__()
:param load_depths ```bool```: whether load depth images and return in __getitem__()
:param gray ```bool```: whether convert view images to grayscale
:param color ```int```: color space to convert view images to
:param calculate_rays ```bool```: whether calculate rays
"""
super().__init__()
......@@ -47,7 +50,7 @@ class SphericalViewSynDataset(object):
self.load_depths = load_depths
# Load dataset description file
self._load_desc(dataset_desc_path)
self._load_desc(dataset_desc_path, res)
# Load view images
if self.load_images:
......@@ -55,8 +58,12 @@ class SphericalViewSynDataset(object):
[self.view_file_pattern % i
for i in range(self.view_centers.size(0))]
).to(device.GetDevice())
if gray:
if color == color_mode.GRAY:
self.view_images = trans_f.rgb_to_grayscale(self.view_images)
elif color == color_mode.YCbCr:
self.view_images = util.rgb2ycbcr(self.view_images)
if res:
self.view_images = nn_f.interpolate(self.view_images, res)
else:
self.view_images = None
......@@ -68,6 +75,8 @@ class SphericalViewSynDataset(object):
for i in range(self.view_centers.size(0))]
).to(device.GetDevice()),
self.cam_params.get_local_rays())
if res:
self.view_depths = nn_f.interpolate(self.view_depths, res)
else:
self.view_depths = None
......@@ -85,7 +94,7 @@ class SphericalViewSynDataset(object):
output /= local_rays[..., 2]
return output
def _load_desc(self, path):
def _load_desc(self, path, res = None):
with open(path, 'r', encoding='utf-8') as file:
data_desc = json.loads(file.read())
if data_desc['view_file_pattern'] == '':
......@@ -103,6 +112,9 @@ class SphericalViewSynDataset(object):
self.cam_params = view.CameraParam(data_desc['cam_params'],
self.view_res,
device=device.GetDevice())
if res:
self.view_res = res
self.cam_params.resize(res)
self.depth_range = [
data_desc['depth_range']['min'],
data_desc['depth_range']['max']
......
import os
from numpy.core.fromnumeric import trace
from numpy.lib.arraysetops import isin
import torch
import torchvision.transforms.functional as trans_f
from ..my import util
from ..my import device
from ..my import color_mode
class UpsamplingDataset(torch.utils.data.dataset.Dataset):
......@@ -13,7 +15,7 @@ class UpsamplingDataset(torch.utils.data.dataset.Dataset):
"""
def __init__(self, data_dir: str, input_patt: str, gt_patt: str,
gray: bool = False, load_once: bool = True):
color: int, load_once: bool = True):
"""
Initialize dataset for upsampling task
......@@ -33,15 +35,18 @@ class UpsamplingDataset(torch.utils.data.dataset.Dataset):
)))
self.load_once = load_once
self.load_gt = self.gt_patt != None
self.gray = gray
self.color = color
self.input = util.ReadImageTensor([self.input_patt % i for i in range(self.n)]) \
.to(device.GetDevice()) if self.load_once else None
self.gt = util.ReadImageTensor([self.gt_patt % i for i in range(self.n)]) \
.to(device.GetDevice()) if self.load_once and self.load_gt else None
if self.gray:
if self.color == color_mode.GRAY:
self.input = trans_f.rgb_to_grayscale(self.input)
self.gt = trans_f.rgb_to_grayscale(self.gt) \
if self.gt != None else None
elif self.color == color_mode.YCbCr:
self.input = util.rgb2ycbcr(self.input)
self.gt = util.rgb2ycbcr(self.gt) if self.gt != None else None
def __len__(self):
return self.n
......@@ -50,13 +55,16 @@ class UpsamplingDataset(torch.utils.data.dataset.Dataset):
if self.load_once:
return idx, self.input[idx], self.gt[idx] if self.load_gt else False
if isinstance(idx, torch.Tensor):
return idx, \
trans_f.rgb_to_grayscale(util.ReadImageTensor(
[self.input_patt % i for i in idx])), \
trans_f.rgb_to_grayscale(util.ReadImageTensor(
[self.gt_patt % i for i in idx])) if self.load_gt else False
return idx, \
trans_f.rgb_to_grayscale(util.ReadImageTensor(
self.input_patt % idx)), \
trans_f.rgb_to_grayscale(util.ReadImageTensor(
self.gt_patt % idx)) if self.load_gt else False
input = util.ReadImageTensor([self.input_patt % i for i in idx])
gt = util.ReadImageTensor([self.gt_patt % i for i in idx]) if self.load_gt else False
else:
input = util.ReadImageTensor([self.input_patt % idx])
gt = util.ReadImageTensor([self.gt_patt % idx]) if self.load_gt else False
if self.color == color_mode.GRAY:
input = trans_f.rgb_to_grayscale(input)
gt = trans_f.rgb_to_grayscale(gt) if isinstance(gt, torch.Tensor) else False
return idx, input, gt
elif self.color == color_mode.YCbCr:
input = util.rgb2ycbcr(input)
gt = util.rgb2ycbcr(gt) if isinstance(gt, torch.Tensor) else False
return idx, input, gt
\ No newline at end of file
......@@ -6,17 +6,17 @@ import torch.optim
from torch import onnx
sys.path.append(os.path.abspath(sys.path[0] + '/../'))
__package__ = "deeplightfield"
__package__ = "deep_view_syn"
parser = argparse.ArgumentParser()
parser.add_argument('--device', type=int, default=0,
help='Which CUDA device to use.')
parser.add_argument('--model', type=str,
help='Path of model to export')
parser.add_argument('--batch-size', type=int,
parser.add_argument('--batch-size', type=str,
help='Resolution')
parser.add_argument('--outdir', type=str, default='./',
help='Output directory')
parser.add_argument('model', type=str,
help='Path of model to export')
opt = parser.parse_args()
# Select device
......@@ -29,13 +29,18 @@ from .my import device
from .my import netio
from .my import util
dir_path, model_file = os.path.split(opt.model)
batch_size = eval(opt.batch_size)
os.chdir(dir_path)
config = SphericalViewSynConfig()
def load_net(path):
name = os.path.splitext(os.path.basename(path))[0]
config = SphericalViewSynConfig()
config.load_by_name(name.split('@')[1])
config.from_id(name)
config.SAMPLE_PARAMS['spherical'] = True
config.SAMPLE_PARAMS['perturb_sample'] = False
config.SAMPLE_PARAMS['n_samples'] = 4
config.print()
net = MslNet(config.FC_PARAMS, config.SAMPLE_PARAMS, config.GRAY,
config.N_ENCODE_DIM, export_mode=True).to(device.GetDevice())
......@@ -46,16 +51,16 @@ def load_net(path):
if __name__ == "__main__":
with torch.no_grad():
# Load model
net, name = load_net(opt.model)
net, name = load_net(model_file)
# Input to the model
rays_o = torch.empty(opt.batch_size, 3, device=device.GetDevice())
rays_d = torch.empty(opt.batch_size, 3, device=device.GetDevice())
rays_o = torch.empty(batch_size, 3, device=device.GetDevice())
rays_d = torch.empty(batch_size, 3, device=device.GetDevice())
util.CreateDirIfNeed(opt.outdir)
# Export the model
outpath = os.path.join(opt.outdir, name + ".onnx")
outpath = os.path.join(opt.outdir, config.to_id() + ".onnx")
onnx.export(
net, # model being run
(rays_o, rays_d), # model input (or a tuple for multiple inputs)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment