Commit c10f614f authored by Nianchen Deng's avatar Nianchen Deng
Browse files

sync

parent dcba5844
...@@ -27,7 +27,7 @@ class SphericalViewSynDataset(object): ...@@ -27,7 +27,7 @@ class SphericalViewSynDataset(object):
def __init__(self, dataset_desc_path: str, load_images: bool = True, def __init__(self, dataset_desc_path: str, load_images: bool = True,
load_depths: bool = False, load_bins: bool = False, c: int = color.RGB, load_depths: bool = False, load_bins: bool = False, c: int = color.RGB,
calculate_rays: bool = True, res: Tuple[int, int] = None): calculate_rays: bool = True, res: Tuple[int, int] = None, load_views=None):
""" """
Initialize data loader for spherical view synthesis task Initialize data loader for spherical view synthesis task
...@@ -52,7 +52,7 @@ class SphericalViewSynDataset(object): ...@@ -52,7 +52,7 @@ class SphericalViewSynDataset(object):
self.load_bins = load_bins self.load_bins = load_bins
# Load dataset description file # Load dataset description file
self._load_desc(dataset_desc_path, res) self._load_desc(dataset_desc_path, res, load_views)
# Load view images # Load view images
if self.load_images: if self.load_images:
...@@ -98,7 +98,7 @@ class SphericalViewSynDataset(object): ...@@ -98,7 +98,7 @@ class SphericalViewSynDataset(object):
disp_val = (1 - input[..., 0, :, :]) * (disp_range[1] - disp_range[0]) + disp_range[0] disp_val = (1 - input[..., 0, :, :]) * (disp_range[1] - disp_range[0]) + disp_range[0]
return torch.reciprocal(disp_val) return torch.reciprocal(disp_val)
def _load_desc(self, path, res=None): def _load_desc(self, path, res=None, load_views=None):
with open(path, 'r', encoding='utf-8') as file: with open(path, 'r', encoding='utf-8') as file:
data_desc = json.loads(file.read()) data_desc = json.loads(file.read())
if not data_desc.get('view_file_pattern'): if not data_desc.get('view_file_pattern'):
...@@ -127,11 +127,17 @@ class SphericalViewSynDataset(object): ...@@ -127,11 +127,17 @@ class SphericalViewSynDataset(object):
[view.euler_to_matrix([rot[1], rot[0], 0]) for rot in data_desc['view_rots']] [view.euler_to_matrix([rot[1], rot[0], 0]) for rot in data_desc['view_rots']]
if len(data_desc['view_rots'][0]) == 2 else data_desc['view_rots'], if len(data_desc['view_rots'][0]) == 2 else data_desc['view_rots'],
device=device.default()).view(-1, 3, 3) # (N, 3, 3) device=device.default()).view(-1, 3, 3) # (N, 3, 3)
#self.view_centers = self.view_centers[:6] self.view_idxs = torch.tensor(
#self.view_rots = self.view_rots[:6] data_desc['views'] if 'views' in data_desc else list(range(self.view_centers.size(0))),
device=device.default())
if load_views is not None:
self.view_centers = self.view_centers[load_views]
self.view_rots = self.view_rots[load_views]
self.view_idxs = self.view_idxs[load_views]
self.n_views = self.view_centers.size(0) self.n_views = self.view_centers.size(0)
self.n_pixels = self.n_views * self.view_res[0] * self.view_res[1] self.n_pixels = self.n_views * self.view_res[0] * self.view_res[1]
self.view_idxs = data_desc['views'][:self.n_views] if 'views' in data_desc else range(self.n_views)
if 'gl_coord' in data_desc and data_desc['gl_coord'] == True: if 'gl_coord' in data_desc and data_desc['gl_coord'] == True:
print('Convert from OGL coordinate to DX coordinate (i. e. flip z axis)') print('Convert from OGL coordinate to DX coordinate (i. e. flip z axis)')
......
...@@ -70,6 +70,10 @@ ...@@ -70,6 +70,10 @@
" plt.subplot(133)\n", " plt.subplot(133)\n",
" img.plot(images['layers_img'][2])\n", " img.plot(images['layers_img'][2])\n",
" plt.figure(figsize=(12, 12))\n", " plt.figure(figsize=(12, 12))\n",
" img.plot(images['overlaid'])\n",
" plt.figure(figsize=(12, 12))\n",
" img.plot(images['blended_raw'])\n",
" plt.figure(figsize=(12, 12))\n",
" img.plot(images['blended'])\n", " img.plot(images['blended'])\n",
"\n", "\n",
"\n", "\n",
...@@ -87,7 +91,7 @@ ...@@ -87,7 +91,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 12,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -108,12 +112,12 @@ ...@@ -108,12 +112,12 @@
"fovea_net = load_net(find_file('fovea'))\n", "fovea_net = load_net(find_file('fovea'))\n",
"periph_net = load_net(find_file('periph'))\n", "periph_net = load_net(find_file('periph'))\n",
"renderer = FoveatedNeuralRenderer(fov_list, res_list, nn.ModuleList([fovea_net, periph_net, periph_net]),\n", "renderer = FoveatedNeuralRenderer(fov_list, res_list, nn.ModuleList([fovea_net, periph_net, periph_net]),\n",
" res_full, using_mask=False, device=device.default())\n" " res_full, device=device.default())"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 15,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -129,13 +133,14 @@ ...@@ -129,13 +133,14 @@
" ],\n", " ],\n",
" 'barbershop': [\n", " 'barbershop': [\n",
" [0, 0, 0, 0, 0, 0, 0],\n", " [0, 0, 0, 0, 0, 0, 0],\n",
" #[0, 0, 0, 20, 0, -300, 50],\n", " [0, 0, 0, 20, 0, -300, 50],\n",
" #[0, 0, 0, -140, -30, 150, -250],\n", " [0, 0, 0, -140, -30, 150, -250],\n",
" #[0, 0, 0, -60, -30, 75, -125],\n", " [0, 0, 0, -60, -30, 75, -125],\n",
" [0, 0, 0, -10, -5, 0, 0]\n",
" ],\n", " ],\n",
" 'lobby': [\n", " 'lobby': [\n",
" #[0, 0, 0, 0, 0, 75, 0],\n", " [0, 0, 0, 0, 0, 75, 0],\n",
" #[0, 0, 0, 0, 0, 5, 150],\n", " [0, 0, 0, 0, 0, 5, 150],\n",
" [0, 0, 0, -120, 0, 75, 50],\n", " [0, 0, 0, -120, 0, 75, 50],\n",
" ]\n", " ]\n",
"}\n", "}\n",
...@@ -143,14 +148,17 @@ ...@@ -143,14 +148,17 @@
"for i, param in enumerate(params[scene]):\n", "for i, param in enumerate(params[scene]):\n",
" view = Trans(torch.tensor(param[:3], device=device.default()),\n", " view = Trans(torch.tensor(param[:3], device=device.default()),\n",
" torch.tensor(euler_to_matrix([-param[4], param[3], 0]), device=device.default()).view(3, 3))\n", " torch.tensor(euler_to_matrix([-param[4], param[3], 0]), device=device.default()).view(3, 3))\n",
" images = renderer(view, param[-2:])\n", " images = renderer(view, param[-2:], using_mask=False, ret_raw=True)\n",
" if False:\n", " images['overlaid'] = renderer.foveation.synthesis(images['layers_raw'], param[-2:], do_blend=False)\n",
" if True:\n",
" outputdir = '../__demo/mono/'\n", " outputdir = '../__demo/mono/'\n",
" misc.create_dir(outputdir)\n", " misc.create_dir(outputdir)\n",
" img.save(images['layers_img'][0], f'{outputdir}{scene}_{i}_fovea.png')\n", " img.save(images['layers_img'][0], f'{outputdir}{scene}_{i}_fovea.png')\n",
" img.save(images['layers_img'][1], f'{outputdir}{scene}_{i}_mid.png')\n", " img.save(images['layers_img'][1], f'{outputdir}{scene}_{i}_mid.png')\n",
" img.save(images['layers_img'][2], f'{outputdir}{scene}_{i}_periph.png')\n", " img.save(images['layers_img'][2], f'{outputdir}{scene}_{i}_periph.png')\n",
" img.save(images['blended'], f'{outputdir}{scene}_{i}_blended.png')\n", " img.save(images['blended'], f'{outputdir}{scene}_{i}_blended.png')\n",
" img.save(images['overlaid'], f'{outputdir}{scene}_{i}_overlaid.png')\n",
" img.save(images['blended_raw'], f'{outputdir}{scene}_{i}_blended_raw.png')\n",
" else:\n", " else:\n",
" images = plot_images(images)\n" " images = plot_images(images)\n"
] ]
...@@ -212,8 +220,9 @@ ...@@ -212,8 +220,9 @@
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "Python 3.8.5 64-bit ('base': conda)", "display_name": "Python 3",
"name": "python385jvsc74a57bd082066b63b621a9e3d15e3b7c11ca76da6238eff3834294910d715044bd0561e5" "language": "python",
"name": "python3"
}, },
"language_info": { "language_info": {
"codemirror_mode": { "codemirror_mode": {
...@@ -231,9 +240,8 @@ ...@@ -231,9 +240,8 @@
"interpreter": { "interpreter": {
"hash": "82066b63b621a9e3d15e3b7c11ca76da6238eff3834294910d715044bd0561e5" "hash": "82066b63b621a9e3d15e3b7c11ca76da6238eff3834294910d715044bd0561e5"
} }
}, }
"orig_nbformat": 2
}, },
"nbformat": 4, "nbformat": 4,
"nbformat_minor": 2 "nbformat_minor": 2
} }
\ No newline at end of file
...@@ -2,37 +2,44 @@ ...@@ -2,37 +2,44 @@
"cells": [ "cells": [
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 5,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Set CUDA:0 as current device.\n"
]
}
],
"source": [ "source": [
"import sys\n", "import sys\n",
"import os\n", "import os\n",
"import torch\n", "import torch\n",
"import torch.nn as nn\n",
"import matplotlib.pyplot as plt\n", "import matplotlib.pyplot as plt\n",
"\n", "\n",
"rootdir = os.path.abspath(sys.path[0] + '/../')\n", "rootdir = os.path.abspath(sys.path[0] + '/../')\n",
"sys.path.append(rootdir)\n", "sys.path.append(rootdir)\n",
"\n", "\n",
"torch.cuda.set_device(2)\n", "torch.cuda.set_device(0)\n",
"print(\"Set CUDA:%d as current device.\" % torch.cuda.current_device())\n", "print(\"Set CUDA:%d as current device.\" % torch.cuda.current_device())\n",
"torch.autograd.set_grad_enabled(False)\n", "torch.autograd.set_grad_enabled(False)\n",
"\n", "\n",
"from data.spherical_view_syn import *\n", "from data.spherical_view_syn import *\n",
"from configs.spherical_view_syn import SphericalViewSynConfig\n", "from configs.spherical_view_syn import SphericalViewSynConfig\n",
"from utils import netio\n", "from utils import netio\n",
"from utils import misc\n",
"from utils import img\n", "from utils import img\n",
"from utils import device\n", "from utils import device\n",
"from utils import view\n", "from utils.view import *\n",
"from components.fnr import FoveatedNeuralRenderer\n", "from components.fnr import FoveatedNeuralRenderer\n",
"\n", "\n",
"\n", "\n",
"def load_net(path):\n", "def load_net(path):\n",
" config = SphericalViewSynConfig()\n", " config = SphericalViewSynConfig()\n",
" config.from_id(path[:-4])\n", " config.from_id(os.path.splitext(path)[0])\n",
" config.SAMPLE_PARAMS['perturb_sample'] = False\n", " config.SAMPLE_PARAMS['perturb_sample'] = False\n",
" # config.print()\n",
" net = config.create_net().to(device.default())\n", " net = config.create_net().to(device.default())\n",
" netio.load(path, net)\n", " netio.load(path, net)\n",
" return net\n", " return net\n",
...@@ -45,14 +52,14 @@ ...@@ -45,14 +52,14 @@
" return None\n", " return None\n",
"\n", "\n",
"\n", "\n",
"def load_views(data_desc_file) -> view.Trans:\n", "def load_views(data_desc_file) -> Trans:\n",
" with open(data_desc_file, 'r', encoding='utf-8') as file:\n", " with open(data_desc_file, 'r', encoding='utf-8') as file:\n",
" data_desc = json.loads(file.read())\n", " data_desc = json.loads(file.read())\n",
" view_centers = torch.tensor(\n", " view_centers = torch.tensor(\n",
" data_desc['view_centers'], device=device.default()).view(-1, 3)\n", " data_desc['view_centers'], device=device.default()).view(-1, 3)\n",
" view_rots = torch.tensor(\n", " view_rots = torch.tensor(\n",
" data_desc['view_rots'], device=device.default()).view(-1, 3, 3)\n", " data_desc['view_rots'], device=device.default()).view(-1, 3, 3)\n",
" return view.Trans(view_centers, view_rots)\n", " return Trans(view_centers, view_rots)\n",
"\n", "\n",
"\n", "\n",
"def plot_cross(center, res):\n", "def plot_cross(center, res):\n",
...@@ -78,115 +85,120 @@ ...@@ -78,115 +85,120 @@
" color=[0, 1, 0])\n", " color=[0, 1, 0])\n",
"\n", "\n",
"\n", "\n",
"def plot_fovea(left_images, right_images, left_center, right_center):\n", "def plot_figures(left_images, right_images, left_center, right_center):\n",
" plt.figure(figsize=(8, 4))\n", " # Plot Fovea\n",
" plt.figure(figsize=(12, 6))\n",
" plt.subplot(121)\n", " plt.subplot(121)\n",
" img.plot(left_images['fovea'])\n", " img.plot(left_images['layers_img'][0])\n",
" fovea_res = left_images['fovea'].size()[-2:]\n", " fovea_res = left_images['layers_img'][0].size()[-2:]\n",
" plot_cross((0, 0), fovea_res)\n", " plot_cross((0, 0), fovea_res)\n",
" plt.subplot(122)\n", " plt.subplot(122)\n",
" img.plot(right_images['fovea'])\n", " img.plot(right_images['layers_img'][0])\n",
" plot_cross((0, 0), fovea_res)\n", " plot_cross((0, 0), fovea_res)\n",
"\n", "\n",
" # Plot Mid\n",
" plt.figure(figsize=(12, 6))\n",
" plt.subplot(121)\n",
" img.plot(left_images['layers_img'][1])\n",
" plt.subplot(122)\n",
" img.plot(right_images['layers_img'][1])\n",
"\n",
" # Plot Periph\n",
" plt.figure(figsize=(12, 6))\n",
" plt.subplot(121)\n",
" img.plot(left_images['layers_img'][2])\n",
" plt.subplot(122)\n",
" img.plot(right_images['layers_img'][2])\n",
"\n",
" # Plot Blended\n",
" plt.figure(figsize=(12, 6))\n",
" plt.subplot(121)\n",
" img.plot(left_images['blended'])\n",
" full_res = left_images['blended'].size()[-2:]\n",
" plot_cross(left_center, full_res)\n",
" plt.subplot(122)\n",
" img.plot(right_images['blended'])\n",
" plot_cross(right_center, full_res)\n",
"\n",
"\n", "\n",
"scenes = {\n", "scenes = {\n",
" 'gas': '__0_user_study/us_gas_all_in_one',\n", " 'classroom': 'classroom_all',\n",
" 'mc': '__0_user_study/us_mc_all_in_one',\n", " 'stones': 'stones_all',\n",
" 'bedroom': 'bedroom_all_in_one',\n", " 'barbershop': 'barbershop_all',\n",
" 'gallery': 'gallery_all_in_one',\n", " 'lobby': 'lobby_all'\n",
" 'lobby': 'lobby_all_in_one'\n",
"}\n", "}\n",
"\n", "\n",
"\n",
"fov_list = [20, 45, 110]\n", "fov_list = [20, 45, 110]\n",
"res_list = [(128, 128), (256, 256), (256, 230)]\n", "res_list = [(256, 256), (256, 256), (256, 230)]\n",
"res_full = (1600, 1440)\n" "res_full = (1600, 1440)\n"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 26,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Change working directory to /home/dengnc/dvs/data/__new/classroom_all\n",
"Load net from fovea@snerffast4-rgb_e6_fc256x8_d1.00-7.00_s64_~p.pth ...\n",
"Load net from periph@snerffast4-rgb_e6_fc128x4_d1.00-7.00_s64_~p.pth ...\n"
]
}
],
"source": [ "source": [
"centers = {\n", "params = {\n",
" 'gas': [\n", " 'classroom': [\n",
" [(3.5, 0), (-3.5, 0)],\n", " [(0, 0, 0, 0, 0), (1, -83), (-5, -83)],\n",
" [(1.5, 0), (-1.5, 0)]\n", " [(0, 0, 0, 0, 0), (-171, 55), (-178, 55)],\n",
" ],\n", " [(0, 0, 0, 0, 0), (60, 55), (55, 55)],\n",
" 'mc': [\n", " [(0, 0, 0, 0, 0), (138, 160), (130, 160)]\n",
" [(2, 0), (-2, 0)],\n",
" [(2, 0), (-2, 0)]\n",
" ],\n", " ],\n",
" 'bedroom': [\n",
" [(5, 0), (-5, 0)],\n",
" [(6, 0), (-6, 0)],\n",
" [(5, 0), (-5, 0)]\n",
" ],\n",
" 'gallery': [\n",
" [(2.5, 0), (-2.5, 0)],\n",
" [(11.5, 0), (-11.5, 0)]\n",
" ]\n",
"}\n", "}\n",
"scene = 'bedroom'\n", "scene = 'classroom'\n",
"os.chdir(os.path.join(rootdir, f'data/{scenes[scene]}'))\n", "os.chdir(f'{rootdir}/data/__new/{scenes[scene]}')\n",
"print('Change working directory to ', os.getcwd())\n", "print('Change working directory to ', os.getcwd())\n",
"\n", "\n",
"fovea_net = load_net(find_file('fovea'))\n", "fovea_net = load_net(find_file('fovea'))\n",
"periph_net = load_net(find_file('periph'))\n", "periph_net = load_net(find_file('periph'))\n",
"\n", "renderer = FoveatedNeuralRenderer(fov_list, res_list,\n",
"# Load Dataset\n", " nn.ModuleList([fovea_net, periph_net, periph_net]),\n",
"views = load_views('demo.json')\n", " res_full, device=device.default())\n",
"print('Dataset loaded.')\n", "\n",
"print('views:', views.size())\n", "for i, param in enumerate(params[scene]):\n",
"gen = GenFinal(fov_list, res_list, res_full, fovea_net, periph_net,\n", " view = Trans(torch.tensor(param[0][:3], device=device.default()),\n",
" device=device.default())\n", " torch.tensor(euler_to_matrix([-param[0][4], param[0][3], 0]),\n",
"\n", " device=device.default()).view(3, 3))\n",
"for view_idx in range(views.size()[0]):\n", " eye_offset = torch.tensor([0.03, 0, 0], device=device.default())\n",
" test_view = views.get(view_idx)\n", " left_view = Trans(view.trans_point(-eye_offset), view.r)\n",
" left_images = gen(centers[scene][view_idx][0], view.Trans(\n", " right_view = Trans(view.trans_point(eye_offset), view.r)\n",
" test_view.trans_point(\n", " left_images, right_images = renderer(view, param[1], param[2],\n",
" torch.tensor([-0.03, 0, 0], device=device.default())\n", " stereo_disparity=0.06, using_mask=False, ret_raw=False)\n",
" ), test_view.r), mono_trans=test_view)\n", " if True:\n",
" right_images = gen(centers[scene][view_idx][1], view.Trans(\n", " outputdir = '../__demo/stereo/'\n",
" test_view.trans_point(\n", " misc.create_dir(outputdir)\n",
" torch.tensor([0.03, 0, 0], device=device.default())\n", " img.save(left_images['blended'], '%s%s_%d_l.png' % (outputdir, scene, i))\n",
" ), test_view.r), mono_trans=test_view)\n", " img.save(right_images['blended'], '%s%s_%d_r.png' % (outputdir, scene, i))\n",
" #plot_fovea(left_images, right_images, centers[scene][view_idx][0],\n", " stereo_overlap = torch.cat([\n",
" # centers[scene][view_idx][1])\n", " left_images['blended'][:, 0:1],\n",
" outputdir = '../__2_demo/mono_periph/stereo/'\n", " right_images['blended'][:, 1:3]\n",
" misc.create_dir(outputdir)\n", " ], dim=1)\n",
" # for key in images:\n", " img.save(stereo_overlap, '%s%s_%d_stereo.png' % (outputdir, scene, i))\n",
" key = 'blended'\n", " else:\n",
" img.save(left_images[key], '%s%s_view%04d_%s_l.png' % (outputdir, scene, view_idx, key))\n", " plot_figures(left_images, right_images, param[1], param[2])\n"
" img.save(right_images[key], '%s%s_view%04d_%s_r.png' % (outputdir, scene, view_idx, key))\n",
" stereo_overlap = torch.cat([left_images['blended'][:, 0:1], right_images['blended'][:, 1:3]], dim=1)\n",
" img.save(stereo_overlap, '%s%s_view%04d_%s_stereo.png' % (outputdir, scene, view_idx, key))\n",
"\n",
" left_images = gen(centers[scene][view_idx][0], view.Trans(\n",
" test_view.trans_point(\n",
" torch.tensor([-0.03, 0, 0], device=device.default())\n",
" ), test_view.r))\n",
" right_images = gen(centers[scene][view_idx][1], view.Trans(\n",
" test_view.trans_point(\n",
" torch.tensor([0.03, 0, 0], device=device.default())\n",
" ), test_view.r))\n",
" #plot_fovea(left_images, right_images, centers[scene][view_idx][0],\n",
" # centers[scene][view_idx][1])\n",
" outputdir = '../__2_demo/stereo/'\n",
" misc.create_dir(outputdir)\n",
" # for key in images:\n",
" key = 'blended'\n",
" img.save(left_images[key], '%s%s_view%04d_%s_l.png' % (outputdir, scene, view_idx, key))\n",
" img.save(right_images[key], '%s%s_view%04d_%s_r.png' % (outputdir, scene, view_idx, key))\n",
" stereo_overlap = torch.cat([left_images['blended'][:, 0:1], right_images['blended'][:, 1:3]], dim=1)\n",
" img.save(stereo_overlap, '%s%s_view%04d_%s_stereo.png' % (outputdir, scene, view_idx, key))\n"
] ]
} }
], ],
"metadata": { "metadata": {
"interpreter": {
"hash": "82066b63b621a9e3d15e3b7c11ca76da6238eff3834294910d715044bd0561e5"
},
"kernelspec": { "kernelspec": {
"display_name": "Python 3.7.9 64-bit ('pytorch': conda)", "display_name": "Python 3.8.5 64-bit ('base': conda)",
"name": "python379jvsc74a57bd0660ca2a75467d3af74a68fcc6f40bc78ab96b99ff17d2f100b5ca821fbb183f2" "name": "python3"
}, },
"language_info": { "language_info": {
"codemirror_mode": { "codemirror_mode": {
...@@ -198,7 +210,7 @@ ...@@ -198,7 +210,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.7.9" "version": "3.8.5"
}, },
"orig_nbformat": 2 "orig_nbformat": 2
}, },
......
This diff is collapsed.
This diff is collapsed.
#/usr/bin/bash
datadir='data/__new/lobby_fovea_r360x80_t1.0'
trainset='data/__new/lobby_fovea_r360x80_t1.0/train1.json'
testset='data/__new/lobby_fovea_r360x80_t1.0/test1.json'
epochs=50
n_nets=$1
nf=$2
n_layers=$3
configid="eval@snerffast${n_nets}-rgb_e6_fc${nf}x${n_layers}_d1.20-6.00_s64_~p"
if [ ! -f "$datadir/$configid/model-epoch_$epochs.pth" ]; then
cont_epoch=0
for ((i=$epochs-1;i>0;i--)) do
if [ -f "$datadir/$configid/model-epoch_$i.pth" ]; then
cont_epoch=$i
break
fi
done
if [ ${cont_epoch} -gt 0 ]; then
python run_spherical_view_syn.py $trainset -e $epochs -m $configid/model-epoch_${cont_epoch}.pth
else
python run_spherical_view_syn.py $trainset -i $configid -e $epochs
fi
fi
if ! ls $datadir/$configid/output_$epochs/perf_* >/dev/null 2>&1; then
python run_spherical_view_syn.py $trainset -t -m $configid/model-epoch_$epochs.pth -o perf
python run_spherical_view_syn.py $testset -t -m $configid/model-epoch_$epochs.pth -o perf
fi
\ No newline at end of file
...@@ -23,7 +23,7 @@ with open(data_desc_path, 'r') as fp: ...@@ -23,7 +23,7 @@ with open(data_desc_path, 'r') as fp:
dataset_desc['cam_params'] = view.CameraParam.convert_camera_params( dataset_desc['cam_params'] = view.CameraParam.convert_camera_params(
dataset_desc['cam_params'], dataset_desc['cam_params'],
(dataset_desc['view_res']['x'], dataset_desc['view_res']['x'])) (dataset_desc['view_res']['y'], dataset_desc['view_res']['x']))
dataset_desc['view_rots'] = [ dataset_desc['view_rots'] = [
view.euler_to_matrix([rot[1], rot[0], 0]) view.euler_to_matrix([rot[1], rot[0], 0])
......
...@@ -9,12 +9,10 @@ from typing import Mapping, List ...@@ -9,12 +9,10 @@ from typing import Mapping, List
sys.path.append(os.path.abspath(sys.path[0] + '/../')) sys.path.append(os.path.abspath(sys.path[0] + '/../'))
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--device', type=int, default=0, parser.add_argument('-b', '--batch-size', type=str, help='Resolution')
help='Which CUDA device to use.') parser.add_argument('-o', '--output', type=str)
parser.add_argument('--batch-size', type=str, parser.add_argument('--device', type=int, default=0, help='Which CUDA device to use.')
help='Resolution') parser.add_argument('model', type=str, help='Path of model to export')
parser.add_argument('model', type=str,
help='Path of model to export')
opt = parser.parse_args() opt = parser.parse_args()
# Select device # Select device
...@@ -28,30 +26,34 @@ from utils import device ...@@ -28,30 +26,34 @@ from utils import device
from configs.spherical_view_syn import SphericalViewSynConfig from configs.spherical_view_syn import SphericalViewSynConfig
dir_path, model_file = os.path.split(opt.model) dir_path, model_file = os.path.split(opt.model)
config_id = os.path.split(dir_path)[-1]
batch_size = eval(opt.batch_size) batch_size = eval(opt.batch_size)
batch_size_str = opt.batch_size.replace('*', 'x') batch_size_str = opt.batch_size.replace('*', 'x')
outdir = f"output_{int(os.path.splitext(model_file)[0][12:])}"
os.chdir(dir_path) if not opt.output:
misc.create_dir(outdir) epochs = os.path.splitext(model_file)[0][12:]
outdir = f"{dir_path}/output_{epochs}"
output = os.path.join(outdir, f"net@{batch_size_str}.onnx")
misc.create_dir(outdir)
else:
output = opt.output
config = SphericalViewSynConfig()
def load_net(path): def load_net():
id=os.path.split(dir_path)[-1]#os.path.splitext(os.path.basename(path))[0] config = SphericalViewSynConfig()
config.from_id(id) config.from_id(config_id)
config.SAMPLE_PARAMS['perturb_sample'] = False config.SAMPLE_PARAMS['perturb_sample'] = False
config.name += batch_size_str config.name += batch_size_str
config.print() config.print()
net = config.create_net().to(device.default()) net = config.create_net().to(device.default())
netio.load(path, net) netio.load(opt.model, net)
return net, id return net
def export_net(net: torch.nn.Module, name: str, def export_net(net: torch.nn.Module, path: str, input: Mapping[str, List[int]],
input: Mapping[str, List[int]], output_names: List[str]): output_names: List[str]):
outpath = os.path.join(outdir, f"{name}@{batch_size_str}.onnx")
input_tensors = tuple([ input_tensors = tuple([
torch.empty(size, device=device.default()) torch.empty(size, device=device.default())
for size in input.values() for size in input.values()
...@@ -59,21 +61,25 @@ def export_net(net: torch.nn.Module, name: str, ...@@ -59,21 +61,25 @@ def export_net(net: torch.nn.Module, name: str,
onnx.export( onnx.export(
net, net,
input_tensors, input_tensors,
outpath, path,
export_params=True, # store the trained parameter weights inside the model file export_params=True, # store the trained parameter weights inside the model file
verbose=True, verbose=True,
opset_version=9, # the ONNX version to export the model to opset_version=9, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding do_constant_folding=True, # whether to execute constant folding
input_names=input.keys(), # the model's input names input_names=list(input.keys()), # the model's input names
output_names=output_names # the model's output names output_names=output_names # the model's output names
) )
print('Model exported to ' + outpath) print('Model exported to ' + path)
if __name__ == "__main__": if __name__ == "__main__":
with torch.no_grad(): with torch.no_grad():
net: SnerfFast = load_net(model_file)[0] net: SnerfFast = load_net()
export_net(SnerfFastExport(net), 'net', { export_net(
'Encoded': [batch_size, net.n_samples, net.coord_encoder.out_dim], SnerfFastExport(net),
'Depths': [batch_size, net.n_samples] output,
}, ['Colors']) {
\ No newline at end of file 'Encoded': [batch_size, net.n_samples, net.coord_encoder.out_dim],
'Depths': [batch_size, net.n_samples]
},
['Colors'])
import sys
import os
import json
rootdir = os.path.abspath(sys.path[0] + '/../')
datadir = f"{rootdir}/data/__new/classroom_fovea_r360x80_t0.6"
n_nets_arr = [ 1, 2, 4, 8 ]
nf_arr = [ 64, 128, 256, 512, 1024 ]
n_layers_arr = [ 2, 4, 8 ]
head = "Nets,Layers," + ",".join([f"{val}" for val in nf_arr])
perf_train_table = []
perf_test_table = []
perf_time_table = []
for n_nets in n_nets_arr:
for n_layers in n_layers_arr:
perf_train_row = []
perf_test_row = []
perf_time_row = []
for nf in nf_arr:
configid = f"eval@snerffast{n_nets}-rgb_e6_fc{nf}x{n_layers}_d1.00-7.00_s64_~p"
outputdir = f"{datadir}/{configid}/output_50"
if not os.path.exists(outputdir):
perf_train_row.append("-")
perf_test_row.append("-")
perf_time_row.append("-")
continue
perf_test_found=False
perf_train_found=False
for file in os.listdir(outputdir):
if file.startswith("perf_r120x80_test"):
if perf_test_found:
os.remove(f"{outputdir}/{file}")
else:
perf_test_row.append(os.path.splitext(file)[0].split("_")[-1])
perf_test_found=True
elif file.startswith("perf_r120x80"):
if perf_train_found:
os.remove(f"{outputdir}/{file}")
else:
perf_train_row.append(os.path.splitext(file)[0].split("_")[-1])
perf_train_found=True
if perf_train_found == False:
perf_train_row.append("-")
if perf_test_found == False:
perf_test_row.append("-")
# Collect time values
time_file = f"{datadir}/eval_trt/time/eval_{n_nets}x{nf}x{n_layers}.json"
if not os.path.exists(time_file):
perf_time_row.append("-")
else:
with open(time_file) as fp:
time_data = json.load(fp)
time = 0
for item in time_data:
time += item['computeMs']
time /= len(time_data)
perf_time_row.append(f"{time:.1f}")
perf_train_table.append(perf_train_row)
perf_test_table.append(perf_test_row)
perf_time_table.append(perf_time_row)
perf_train_content = head + "\n"
for i, row in enumerate(perf_train_table):
if i % len(n_layers_arr) == 0:
perf_train_content += f"{n_nets_arr[i // len(n_layers_arr)]}"
perf_train_content += f",{n_layers_arr[i % len(n_layers_arr)]},"
perf_train_content += ",".join(row) + "\n"
perf_test_content = head + "\n"
for i, row in enumerate(perf_test_table):
if i % len(n_layers_arr) == 0:
perf_test_content += f"{n_nets_arr[i // len(n_layers_arr)]}"
perf_test_content += f",{n_layers_arr[i % len(n_layers_arr)]},"
perf_test_content += ",".join(row) + "\n"
perf_time_content = head + "\n"
for i, row in enumerate(perf_time_table):
if i % len(n_layers_arr) == 0:
perf_time_content += f"{n_nets_arr[i // len(n_layers_arr)]}"
perf_time_content += f",{n_layers_arr[i % len(n_layers_arr)]},"
perf_time_content += ",".join(row) + "\n"
with open(f"{datadir}/eval_perf.csv", "w") as fp:
fp.write("Train:\n")
fp.write(perf_train_content)
fp.write("Test:\n")
fp.write(perf_test_content)
fp.write("Time:\n")
fp.write(perf_time_content)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment