Commit 3554ba52 authored by Nianchen Deng
Browse files

sync

parent f7038e26
{
}
\ No newline at end of file
# deeplightfield # deep_view_syn
Requirement: Requirement:
......
#!/usr/bin/bash
# For each network width nf in {64, 128, 256}, invoke
# run_spherical_view_syn.py three times: once with --config-id/--epochs, then
# twice with --test --perf against the saved epoch checkpoint (first on
# train.json, then on test.json).
#
# Usage: <this script> TESTCASE
#   TESTCASE  zero-based index into n_layers_arr / n_samples_arr below; it is
#             also forwarded as the --device argument of every invocation.
testcase=$1
dataset='data/sp_view_syn_2020.12.31_fovea'
epochs=100
n_layers_arr=(4 8 4 8)
n_samples_arr=(16 16 32 32)

# Reject a missing or out-of-range index up front; otherwise the array lookups
# expand to empty strings and a malformed config id is passed on.
if [[ ! $testcase =~ ^[0-9]+$ ]] || (( 10#$testcase >= ${#n_layers_arr[@]} )); then
    echo "Usage: $0 TESTCASE (0-$(( ${#n_layers_arr[@]} - 1 )))" >&2
    exit 1
fi

# These depend only on TESTCASE, so look them up once outside the loop.
n_layers=${n_layers_arr[$testcase]}
n_samples=${n_samples_arr[$testcase]}

for nf in 64 128 256; do
    configid="infer_test@msl-rgb_e10_fc${nf}x${n_layers}_d1-50_s${n_samples}"
    python run_spherical_view_syn.py --dataset "$dataset/train.json" --config-id "$configid" --device "$testcase" --epochs "$epochs"
    python run_spherical_view_syn.py --dataset "$dataset/train.json" --test "$dataset/$configid/model-epoch_$epochs.pth" --perf --device "$testcase"
    python run_spherical_view_syn.py --dataset "$dataset/test.json" --test "$dataset/$configid/model-epoch_$epochs.pth" --perf --device "$testcase"
done
from ..my import color_mode
def update_config(config): def update_config(config):
# Dataset settings # Dataset settings
config.GRAY = False config.COLOR = color_mode.RGB
# Net parameters # Net parameters
config.NET_TYPE = 'msl' config.NET_TYPE = 'msl'
...@@ -12,4 +14,6 @@ def update_config(config): ...@@ -12,4 +14,6 @@ def update_config(config):
config.SAMPLE_PARAMS.update({ config.SAMPLE_PARAMS.update({
'depth_range': (1, 50), 'depth_range': (1, 50),
'n_samples': 16 'n_samples': 16
}) })
\ No newline at end of file
def update_config(config):
    """
    Overwrite fields of ``config`` in place with this preset's values.

    ``GRAY``, ``NET_TYPE`` and ``N_ENCODE_DIM`` are assigned directly; the
    listed keys are merged into the existing ``FC_PARAMS`` and
    ``SAMPLE_PARAMS`` mappings with ``update``, so keys not listed here keep
    whatever value they already had.

    :param config: object exposing the attributes above; presumably a
        ``SphericalViewSynConfig`` - confirm against the loader
    """
    # Dataset settings
    config.GRAY = True

    # Net parameters
    config.NET_TYPE = 'msl'
    config.N_ENCODE_DIM = 10
    config.FC_PARAMS.update({
        'nf': 128,
        'n_layers': 8,
        'skips': [4]
    })
    config.SAMPLE_PARAMS.update({
        'depth_range': (1, 50),
        'n_samples': 8
    })
\ No newline at end of file
def update_config(config):
    """
    Overwrite fields of ``config`` in place with this preset's values.

    ``GRAY``, ``NET_TYPE`` and ``N_ENCODE_DIM`` are assigned directly; the
    listed keys are merged into the existing ``FC_PARAMS`` and
    ``SAMPLE_PARAMS`` mappings with ``update``, so keys not listed here
    (for example ``skips``) keep whatever value they already had.

    :param config: object exposing the attributes above; presumably a
        ``SphericalViewSynConfig`` - confirm against the loader
    """
    # Dataset settings
    config.GRAY = True

    # Net parameters
    config.NET_TYPE = 'msl'
    config.N_ENCODE_DIM = 10
    config.FC_PARAMS.update({
        'nf': 64,
        'n_layers': 12
    })
    config.SAMPLE_PARAMS.update({
        'depth_range': (1, 20),
        'n_samples': 16
    })
\ No newline at end of file
def update_config(config):
    """
    Overwrite fields of ``config`` in place with this preset's values.

    ``GRAY``, ``NET_TYPE`` and ``N_ENCODE_DIM`` are assigned directly; the
    listed keys are merged into the existing ``FC_PARAMS`` and
    ``SAMPLE_PARAMS`` mappings with ``update``, so keys not listed here
    (for example ``skips``) keep whatever value they already had.

    :param config: object exposing the attributes above; presumably a
        ``SphericalViewSynConfig`` - confirm against the loader
    """
    # Dataset settings
    config.GRAY = False

    # Net parameters
    config.NET_TYPE = 'msl'
    config.N_ENCODE_DIM = 10
    config.FC_PARAMS.update({
        'nf': 64,
        'n_layers': 12
    })
    config.SAMPLE_PARAMS.update({
        'depth_range': (1, 20),
        'n_samples': 16
    })
\ No newline at end of file
def update_config(config):
    """
    Overwrite fields of ``config`` in place with this preset's values.

    ``GRAY``, ``NET_TYPE`` and ``N_ENCODE_DIM`` are assigned directly; the
    listed keys are merged into the existing ``FC_PARAMS`` and
    ``SAMPLE_PARAMS`` mappings with ``update``, so keys not listed here keep
    whatever value they already had.

    :param config: object exposing the attributes above; presumably a
        ``SphericalViewSynConfig`` - confirm against the loader
    """
    # Dataset settings
    config.GRAY = True

    # Net parameters
    config.NET_TYPE = 'msl'
    config.N_ENCODE_DIM = 10
    config.FC_PARAMS.update({
        'nf': 256,
        'n_layers': 8,
        'skips': [4]
    })
    config.SAMPLE_PARAMS.update({
        'depth_range': (1, 50),
        'n_samples': 32
    })
\ No newline at end of file
def update_config(config):
    """
    Overwrite fields of ``config`` in place with this preset's values.

    ``GRAY``, ``NET_TYPE`` and ``N_ENCODE_DIM`` are assigned directly; the
    listed keys are merged into the existing ``FC_PARAMS`` and
    ``SAMPLE_PARAMS`` mappings with ``update``, so keys not listed here keep
    whatever value they already had.

    :param config: object exposing the attributes above; presumably a
        ``SphericalViewSynConfig`` - confirm against the loader
    """
    # Dataset settings
    config.GRAY = True

    # Net parameters
    config.NET_TYPE = 'msl'
    config.N_ENCODE_DIM = 10
    config.FC_PARAMS.update({
        'nf': 256,
        'n_layers': 8,
        'skips': [4]
    })
    config.SAMPLE_PARAMS.update({
        'depth_range': (1, 20),
        'n_samples': 16
    })
\ No newline at end of file
def update_config(config):
    """
    Overwrite fields of ``config`` in place with this preset's values.

    ``GRAY``, ``NET_TYPE`` and ``N_ENCODE_DIM`` are assigned directly; the
    listed keys are merged into the existing ``FC_PARAMS`` and
    ``SAMPLE_PARAMS`` mappings with ``update``, so keys not listed here
    (for example ``skips``) keep whatever value they already had.

    :param config: object exposing the attributes above; presumably a
        ``SphericalViewSynConfig`` - confirm against the loader
    """
    # Dataset settings
    config.GRAY = True

    # Net parameters
    config.NET_TYPE = 'msl'
    config.N_ENCODE_DIM = 10
    config.FC_PARAMS.update({
        'nf': 64,
        'n_layers': 8
    })
    config.SAMPLE_PARAMS.update({
        'depth_range': (1, 50),
        'n_samples': 4
    })
\ No newline at end of file
def update_config(config):
    """
    Overwrite fields of ``config`` in place with this preset's values.

    ``GRAY``, ``NET_TYPE`` and ``N_ENCODE_DIM`` are assigned directly; the
    listed keys are merged into the existing ``FC_PARAMS`` and
    ``SAMPLE_PARAMS`` mappings with ``update``, so keys not listed here
    (for example ``skips``) keep whatever value they already had.

    :param config: object exposing the attributes above; presumably a
        ``SphericalViewSynConfig`` - confirm against the loader
    """
    # Dataset settings
    config.GRAY = True

    # Net parameters
    config.NET_TYPE = 'msl'
    config.N_ENCODE_DIM = 20
    config.FC_PARAMS.update({
        'nf': 64,
        'n_layers': 12,
    })
    config.SAMPLE_PARAMS.update({
        'depth_range': (1, 20),
        'n_samples': 16
    })
\ No newline at end of file
def update_config(config):
    """
    Overwrite fields of ``config`` in place with this preset's values.

    ``GRAY``, ``NET_TYPE`` and ``N_ENCODE_DIM`` are assigned directly; the
    listed keys are merged into the existing ``FC_PARAMS`` and
    ``SAMPLE_PARAMS`` mappings with ``update``, so keys not listed here keep
    whatever value they already had.

    :param config: object exposing the attributes above; presumably a
        ``SphericalViewSynConfig`` - confirm against the loader
    """
    # Dataset settings
    config.GRAY = True

    # Net parameters
    config.NET_TYPE = 'msl'
    config.N_ENCODE_DIM = 10
    config.FC_PARAMS.update({
        'nf': 64,
        'n_layers': 8,
        'skips': [4]
    })
    config.SAMPLE_PARAMS.update({
        'depth_range': (1, 50),
        'n_samples': 4
    })
\ No newline at end of file
def update_config(config):
    """
    Overwrite fields of ``config`` in place with this preset's values.

    ``GRAY``, ``NET_TYPE`` and ``N_ENCODE_DIM`` are assigned directly; the
    listed keys are merged into the existing ``FC_PARAMS`` and
    ``SAMPLE_PARAMS`` mappings with ``update``, so keys not listed here
    (for example ``skips``) keep whatever value they already had.

    :param config: object exposing the attributes above; presumably a
        ``SphericalViewSynConfig`` - confirm against the loader
    """
    # Dataset settings
    config.GRAY = False

    # Net parameters
    config.NET_TYPE = 'msl'
    config.N_ENCODE_DIM = 10
    config.FC_PARAMS.update({
        'nf': 128,
        'n_layers': 6
    })
    config.SAMPLE_PARAMS.update({
        'depth_range': (1, 50),
        'n_samples': 16
    })
\ No newline at end of file
def update_config(config):
    """
    Overwrite fields of ``config`` in place with this preset's values.

    ``GRAY``, ``NET_TYPE`` and ``N_ENCODE_DIM`` are assigned directly; the
    listed keys are merged into the existing ``FC_PARAMS`` and
    ``SAMPLE_PARAMS`` mappings with ``update``, so keys not listed here
    (for example ``skips``) keep whatever value they already had.

    :param config: object exposing the attributes above; presumably a
        ``SphericalViewSynConfig`` - confirm against the loader
    """
    # Dataset settings
    config.GRAY = False

    # Net parameters
    config.NET_TYPE = 'msl'
    config.N_ENCODE_DIM = 10
    config.FC_PARAMS.update({
        'nf': 64,
        'n_layers': 8
    })
    config.SAMPLE_PARAMS.update({
        'depth_range': (1, 50),
        'n_samples': 4
    })
\ No newline at end of file
from ..my import color_mode
def update_config(config): def update_config(config):
# Dataset settings # Dataset settings
config.GRAY = False config.COLOR = color_mode.RGB
# Net parameters # Net parameters
config.NET_TYPE = 'msl' config.NET_TYPE = 'msl'
......
import os import os
import importlib import importlib
from os.path import join from os.path import join
from ..my import color_mode
class SphericalViewSynConfig(object): class SphericalViewSynConfig(object):
...@@ -8,7 +9,7 @@ class SphericalViewSynConfig(object): ...@@ -8,7 +9,7 @@ class SphericalViewSynConfig(object):
def __init__(self): def __init__(self):
self.name = 'default' self.name = 'default'
self.GRAY = False self.COLOR = color_mode.RGB
# Net parameters # Net parameters
self.NET_TYPE = 'msl' self.NET_TYPE = 'msl'
...@@ -30,18 +31,18 @@ class SphericalViewSynConfig(object): ...@@ -30,18 +31,18 @@ class SphericalViewSynConfig(object):
def load(self, path): def load(self, path):
module_name = os.path.splitext(path)[0].replace('/', '.') module_name = os.path.splitext(path)[0].replace('/', '.')
config_module = importlib.import_module( config_module = importlib.import_module(
'deeplightfield.' + module_name) 'deep_view_syn.' + module_name)
config_module.update_config(self) config_module.update_config(self)
self.name = module_name.split('.')[-1] self.name = module_name.split('.')[-1]
def load_by_name(self, name): def load_by_name(self, name):
config_module = importlib.import_module( config_module = importlib.import_module(
'deeplightfield.configs.' + name) 'deep_view_syn.configs.' + name)
config_module.update_config(self) config_module.update_config(self)
self.name = name self.name = name
def to_id(self): def to_id(self):
net_type_id = "%s-%s" % (self.NET_TYPE, "gray" if self.GRAY else "rgb") net_type_id = "%s-%s" % (self.NET_TYPE, color_mode.to_str(self.COLOR))
encode_id = "_e%d" % self.N_ENCODE_DIM encode_id = "_e%d" % self.N_ENCODE_DIM
fc_id = "_fc%dx%d" % (self.FC_PARAMS['nf'], self.FC_PARAMS['n_layers']) fc_id = "_fc%dx%d" % (self.FC_PARAMS['nf'], self.FC_PARAMS['n_layers'])
skip_id = "_skip%s" % ','.join([ skip_id = "_skip%s" % ','.join([
...@@ -49,7 +50,7 @@ class SphericalViewSynConfig(object): ...@@ -49,7 +50,7 @@ class SphericalViewSynConfig(object):
for val in self.FC_PARAMS['skips'] for val in self.FC_PARAMS['skips']
]) if len(self.FC_PARAMS['skips']) > 0 else "" ]) if len(self.FC_PARAMS['skips']) > 0 else ""
depth_id = "_d%d-%d" % (self.SAMPLE_PARAMS['depth_range'][0], depth_id = "_d%d-%d" % (self.SAMPLE_PARAMS['depth_range'][0],
self.SAMPLE_PARAMS['depth_range'][1]) self.SAMPLE_PARAMS['depth_range'][1])
samples_id = '_s%d' % self.SAMPLE_PARAMS['n_samples'] samples_id = '_s%d' % self.SAMPLE_PARAMS['n_samples']
neg_flags = '%s%s%s' % ( neg_flags = '%s%s%s' % (
'p' if not self.SAMPLE_PARAMS['perturb_sample'] else '', 'p' if not self.SAMPLE_PARAMS['perturb_sample'] else '',
...@@ -60,33 +61,38 @@ class SphericalViewSynConfig(object): ...@@ -60,33 +61,38 @@ class SphericalViewSynConfig(object):
return "%s@%s%s%s%s%s%s%s" % (self.name, net_type_id, encode_id, fc_id, skip_id, depth_id, samples_id, neg_flags) return "%s@%s%s%s%s%s%s%s" % (self.name, net_type_id, encode_id, fc_id, skip_id, depth_id, samples_id, neg_flags)
def from_id(self, id: str): def from_id(self, id: str):
self.name, config_str = id.split('@') id_splited = id.split('@')
segs = config_str.split('_') if len(id_splited) == 2:
self.name = id_splited[0]
segs = id_splited[-1].split('_')
for i, seg in enumerate(segs): for i, seg in enumerate(segs):
if i == 0: # NetType if seg.startswith('e'): # Encode
self.NET_TYPE, color_mode = seg.split('-')
self.GRAY = (color_mode == 'gray')
continue
if seg.startswith('e'): # Encode
self.N_ENCODE_DIM = int(seg[1:]) self.N_ENCODE_DIM = int(seg[1:])
continue continue
if seg.startswith('fc'): # Full-connected network parameters if seg.startswith('fc'): # Full-connected network parameters
self.FC_PARAMS['nf'], self.FC_PARAMS['n_layers'] = (int(str) for str in seg[2:].split('x')) self.FC_PARAMS['nf'], self.FC_PARAMS['n_layers'] = (
int(str) for str in seg[2:].split('x'))
continue continue
if seg.startswith('skip'): # Skip connection if seg.startswith('skip'): # Skip connection
self.FC_PARAMS['skips'] = [int(str) for str in seg[4:].split(',')] self.FC_PARAMS['skips'] = [int(str)
for str in seg[4:].split(',')]
continue continue
if seg.startswith('d'): # Depth range if seg.startswith('d'): # Depth range
self.SAMPLE_PARAMS['depth_range'] = tuple(float(str) for str in seg[1:].split('-')) self.SAMPLE_PARAMS['depth_range'] = tuple(
float(str) for str in seg[1:].split('-'))
continue continue
if seg.startswith('s'): # Number of samples if seg.startswith('s'): # Number of samples
self.SAMPLE_PARAMS['n_samples'] = int(seg[1:]) self.SAMPLE_PARAMS['n_samples'] = int(seg[1:])
continue continue
if seg.startswith('~'): # Negative flags if seg.startswith('~'): # Negative flags
self.SAMPLE_PARAMS['perturb_sample'] = (seg.find('p') < 0) self.SAMPLE_PARAMS['perturb_sample'] = (seg.find('p') < 0)
self.SAMPLE_PARAMS['lindisp'] = (seg.find('l') < 0) self.SAMPLE_PARAMS['lindisp'] = (seg.find('l') < 0)
self.SAMPLE_PARAMS['inverse_r'] = (seg.find('i') < 0) self.SAMPLE_PARAMS['inverse_r'] = (seg.find('i') < 0)
continue continue
if i == 0: # NetType
self.NET_TYPE, color_str = seg.split('-')
self.COLOR = color_mode.from_str(color_str)
continue
def print(self): def print(self):
print('==== Config %s ====' % self.name) print('==== Config %s ====' % self.name)
......
...@@ -2,10 +2,12 @@ import os ...@@ -2,10 +2,12 @@ import os
import json import json
import torch import torch
import torchvision.transforms.functional as trans_f import torchvision.transforms.functional as trans_f
import torch.nn.functional as nn_f
from typing import Tuple, Union from typing import Tuple, Union
from ..my import util from ..my import util
from ..my import device from ..my import device
from ..my import view from ..my import view
from ..my import color_mode
class SphericalViewSynDataset(object): class SphericalViewSynDataset(object):
...@@ -23,7 +25,8 @@ class SphericalViewSynDataset(object): ...@@ -23,7 +25,8 @@ class SphericalViewSynDataset(object):
""" """
def __init__(self, dataset_desc_path: str, load_images: bool = True, def __init__(self, dataset_desc_path: str, load_images: bool = True,
load_depths: bool = False, gray: bool = False, calculate_rays: bool = True): load_depths: bool = False, color: int = color_mode.RGB,
calculate_rays: bool = True, res: Tuple[int, int] = None):
""" """
Initialize data loader for spherical view synthesis task Initialize data loader for spherical view synthesis task
...@@ -38,7 +41,7 @@ class SphericalViewSynDataset(object): ...@@ -38,7 +41,7 @@ class SphericalViewSynDataset(object):
:param dataset_desc_path ```str```: path to the data description file :param dataset_desc_path ```str```: path to the data description file
:param load_images ```bool```: whether load view images and return in __getitem__() :param load_images ```bool```: whether load view images and return in __getitem__()
:param load_depths ```bool```: whether load depth images and return in __getitem__() :param load_depths ```bool```: whether load depth images and return in __getitem__()
:param gray ```bool```: whether convert view images to grayscale :param color ```int```: color space to convert view images to
:param calculate_rays ```bool```: whether calculate rays :param calculate_rays ```bool```: whether calculate rays
""" """
super().__init__() super().__init__()
...@@ -47,7 +50,7 @@ class SphericalViewSynDataset(object): ...@@ -47,7 +50,7 @@ class SphericalViewSynDataset(object):
self.load_depths = load_depths self.load_depths = load_depths
# Load dataset description file # Load dataset description file
self._load_desc(dataset_desc_path) self._load_desc(dataset_desc_path, res)
# Load view images # Load view images
if self.load_images: if self.load_images:
...@@ -55,8 +58,12 @@ class SphericalViewSynDataset(object): ...@@ -55,8 +58,12 @@ class SphericalViewSynDataset(object):
[self.view_file_pattern % i [self.view_file_pattern % i
for i in range(self.view_centers.size(0))] for i in range(self.view_centers.size(0))]
).to(device.GetDevice()) ).to(device.GetDevice())
if gray: if color == color_mode.GRAY:
self.view_images = trans_f.rgb_to_grayscale(self.view_images) self.view_images = trans_f.rgb_to_grayscale(self.view_images)
elif color == color_mode.YCbCr:
self.view_images = util.rgb2ycbcr(self.view_images)
if res:
self.view_images = nn_f.interpolate(self.view_images, res)
else: else:
self.view_images = None self.view_images = None
...@@ -68,6 +75,8 @@ class SphericalViewSynDataset(object): ...@@ -68,6 +75,8 @@ class SphericalViewSynDataset(object):
for i in range(self.view_centers.size(0))] for i in range(self.view_centers.size(0))]
).to(device.GetDevice()), ).to(device.GetDevice()),
self.cam_params.get_local_rays()) self.cam_params.get_local_rays())
if res:
self.view_depths = nn_f.interpolate(self.view_depths, res)
else: else:
self.view_depths = None self.view_depths = None
...@@ -85,7 +94,7 @@ class SphericalViewSynDataset(object): ...@@ -85,7 +94,7 @@ class SphericalViewSynDataset(object):
output /= local_rays[..., 2] output /= local_rays[..., 2]
return output return output
def _load_desc(self, path): def _load_desc(self, path, res = None):
with open(path, 'r', encoding='utf-8') as file: with open(path, 'r', encoding='utf-8') as file:
data_desc = json.loads(file.read()) data_desc = json.loads(file.read())
if data_desc['view_file_pattern'] == '': if data_desc['view_file_pattern'] == '':
...@@ -103,6 +112,9 @@ class SphericalViewSynDataset(object): ...@@ -103,6 +112,9 @@ class SphericalViewSynDataset(object):
self.cam_params = view.CameraParam(data_desc['cam_params'], self.cam_params = view.CameraParam(data_desc['cam_params'],
self.view_res, self.view_res,
device=device.GetDevice()) device=device.GetDevice())
if res:
self.view_res = res
self.cam_params.resize(res)
self.depth_range = [ self.depth_range = [
data_desc['depth_range']['min'], data_desc['depth_range']['min'],
data_desc['depth_range']['max'] data_desc['depth_range']['max']
......
import os import os
from numpy.core.fromnumeric import trace from numpy.core.fromnumeric import trace
from numpy.lib.arraysetops import isin
import torch import torch
import torchvision.transforms.functional as trans_f import torchvision.transforms.functional as trans_f
from ..my import util from ..my import util
from ..my import device from ..my import device
from ..my import color_mode
class UpsamplingDataset(torch.utils.data.dataset.Dataset): class UpsamplingDataset(torch.utils.data.dataset.Dataset):
...@@ -13,7 +15,7 @@ class UpsamplingDataset(torch.utils.data.dataset.Dataset): ...@@ -13,7 +15,7 @@ class UpsamplingDataset(torch.utils.data.dataset.Dataset):
""" """
def __init__(self, data_dir: str, input_patt: str, gt_patt: str, def __init__(self, data_dir: str, input_patt: str, gt_patt: str,
gray: bool = False, load_once: bool = True): color: int, load_once: bool = True):
""" """
Initialize dataset for upsampling task Initialize dataset for upsampling task
...@@ -33,15 +35,18 @@ class UpsamplingDataset(torch.utils.data.dataset.Dataset): ...@@ -33,15 +35,18 @@ class UpsamplingDataset(torch.utils.data.dataset.Dataset):
))) )))
self.load_once = load_once self.load_once = load_once
self.load_gt = self.gt_patt != None self.load_gt = self.gt_patt != None
self.gray = gray self.color = color
self.input = util.ReadImageTensor([self.input_patt % i for i in range(self.n)]) \ self.input = util.ReadImageTensor([self.input_patt % i for i in range(self.n)]) \
.to(device.GetDevice()) if self.load_once else None .to(device.GetDevice()) if self.load_once else None
self.gt = util.ReadImageTensor([self.gt_patt % i for i in range(self.n)]) \ self.gt = util.ReadImageTensor([self.gt_patt % i for i in range(self.n)]) \
.to(device.GetDevice()) if self.load_once and self.load_gt else None .to(device.GetDevice()) if self.load_once and self.load_gt else None
if self.gray: if self.color == color_mode.GRAY:
self.input = trans_f.rgb_to_grayscale(self.input) self.input = trans_f.rgb_to_grayscale(self.input)
self.gt = trans_f.rgb_to_grayscale(self.gt) \ self.gt = trans_f.rgb_to_grayscale(self.gt) \
if self.gt != None else None if self.gt != None else None
elif self.color == color_mode.YCbCr:
self.input = util.rgb2ycbcr(self.input)
self.gt = util.rgb2ycbcr(self.gt) if self.gt != None else None
def __len__(self): def __len__(self):
return self.n return self.n
...@@ -50,13 +55,16 @@ class UpsamplingDataset(torch.utils.data.dataset.Dataset): ...@@ -50,13 +55,16 @@ class UpsamplingDataset(torch.utils.data.dataset.Dataset):
if self.load_once: if self.load_once:
return idx, self.input[idx], self.gt[idx] if self.load_gt else False return idx, self.input[idx], self.gt[idx] if self.load_gt else False
if isinstance(idx, torch.Tensor): if isinstance(idx, torch.Tensor):
return idx, \ input = util.ReadImageTensor([self.input_patt % i for i in idx])
trans_f.rgb_to_grayscale(util.ReadImageTensor( gt = util.ReadImageTensor([self.gt_patt % i for i in idx]) if self.load_gt else False
[self.input_patt % i for i in idx])), \ else:
trans_f.rgb_to_grayscale(util.ReadImageTensor( input = util.ReadImageTensor([self.input_patt % idx])
[self.gt_patt % i for i in idx])) if self.load_gt else False gt = util.ReadImageTensor([self.gt_patt % idx]) if self.load_gt else False
return idx, \ if self.color == color_mode.GRAY:
trans_f.rgb_to_grayscale(util.ReadImageTensor( input = trans_f.rgb_to_grayscale(input)
self.input_patt % idx)), \ gt = trans_f.rgb_to_grayscale(gt) if isinstance(gt, torch.Tensor) else False
trans_f.rgb_to_grayscale(util.ReadImageTensor( return idx, input, gt
self.gt_patt % idx)) if self.load_gt else False elif self.color == color_mode.YCbCr:
input = util.rgb2ycbcr(input)
gt = util.rgb2ycbcr(gt) if isinstance(gt, torch.Tensor) else False
return idx, input, gt
\ No newline at end of file
...@@ -6,17 +6,17 @@ import torch.optim ...@@ -6,17 +6,17 @@ import torch.optim
from torch import onnx from torch import onnx
sys.path.append(os.path.abspath(sys.path[0] + '/../')) sys.path.append(os.path.abspath(sys.path[0] + '/../'))
__package__ = "deeplightfield" __package__ = "deep_view_syn"
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--device', type=int, default=0, parser.add_argument('--device', type=int, default=0,
help='Which CUDA device to use.') help='Which CUDA device to use.')
parser.add_argument('--model', type=str, parser.add_argument('--batch-size', type=str,
help='Path of model to export')
parser.add_argument('--batch-size', type=int,
help='Resolution') help='Resolution')
parser.add_argument('--outdir', type=str, default='./', parser.add_argument('--outdir', type=str, default='./',
help='Output directory') help='Output directory')
parser.add_argument('model', type=str,
help='Path of model to export')
opt = parser.parse_args() opt = parser.parse_args()
# Select device # Select device
...@@ -29,13 +29,18 @@ from .my import device ...@@ -29,13 +29,18 @@ from .my import device
from .my import netio from .my import netio
from .my import util from .my import util
# Split the positional model path into its directory and file name.
dir_path, model_file = os.path.split(opt.model)
# NOTE(review): eval() executes arbitrary code taken from the command line.
# It is presumably kept so the size can be written as an expression
# (e.g. "256*256") - confirm, and only run this with trusted arguments.
batch_size = eval(opt.batch_size)
# os.path.split() returns '' as the directory for a bare file name, and
# os.chdir('') raises FileNotFoundError, so only change directory when the
# model path actually has a directory component.
if dir_path:
    os.chdir(dir_path)
config = SphericalViewSynConfig()
def load_net(path): def load_net(path):
name = os.path.splitext(os.path.basename(path))[0] name = os.path.splitext(os.path.basename(path))[0]
config = SphericalViewSynConfig() config.from_id(name)
config.load_by_name(name.split('@')[1])
config.SAMPLE_PARAMS['spherical'] = True config.SAMPLE_PARAMS['spherical'] = True
config.SAMPLE_PARAMS['perturb_sample'] = False config.SAMPLE_PARAMS['perturb_sample'] = False
config.SAMPLE_PARAMS['n_samples'] = 4
config.print() config.print()
net = MslNet(config.FC_PARAMS, config.SAMPLE_PARAMS, config.GRAY, net = MslNet(config.FC_PARAMS, config.SAMPLE_PARAMS, config.GRAY,
config.N_ENCODE_DIM, export_mode=True).to(device.GetDevice()) config.N_ENCODE_DIM, export_mode=True).to(device.GetDevice())
...@@ -46,16 +51,16 @@ def load_net(path): ...@@ -46,16 +51,16 @@ def load_net(path):
if __name__ == "__main__": if __name__ == "__main__":
with torch.no_grad(): with torch.no_grad():
# Load model # Load model
net, name = load_net(opt.model) net, name = load_net(model_file)
# Input to the model # Input to the model
rays_o = torch.empty(opt.batch_size, 3, device=device.GetDevice()) rays_o = torch.empty(batch_size, 3, device=device.GetDevice())
rays_d = torch.empty(opt.batch_size, 3, device=device.GetDevice()) rays_d = torch.empty(batch_size, 3, device=device.GetDevice())
util.CreateDirIfNeed(opt.outdir) util.CreateDirIfNeed(opt.outdir)
# Export the model # Export the model
outpath = os.path.join(opt.outdir, name + ".onnx") outpath = os.path.join(opt.outdir, config.to_id() + ".onnx")
onnx.export( onnx.export(
net, # model being run net, # model being run
(rays_o, rays_d), # model input (or a tuple for multiple inputs) (rays_o, rays_d), # model input (or a tuple for multiple inputs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment