Commit c10f614f authored by Nianchen Deng

sync

parent dcba5844
@@ -27,7 +27,7 @@ class SphericalViewSynDataset(object):
    def __init__(self, dataset_desc_path: str, load_images: bool = True,
                 load_depths: bool = False, load_bins: bool = False, c: int = color.RGB,
                 calculate_rays: bool = True, res: Tuple[int, int] = None):
                 calculate_rays: bool = True, res: Tuple[int, int] = None, load_views=None):
        """
        Initialize data loader for spherical view synthesis task
@@ -52,7 +52,7 @@ class SphericalViewSynDataset(object):
        self.load_bins = load_bins
        # Load dataset description file
        self._load_desc(dataset_desc_path, res)
        self._load_desc(dataset_desc_path, res, load_views)
        # Load view images
        if self.load_images:
@@ -98,7 +98,7 @@ class SphericalViewSynDataset(object):
        disp_val = (1 - input[..., 0, :, :]) * (disp_range[1] - disp_range[0]) + disp_range[0]
        return torch.reciprocal(disp_val)
    def _load_desc(self, path, res=None):
    def _load_desc(self, path, res=None, load_views=None):
        with open(path, 'r', encoding='utf-8') as file:
            data_desc = json.loads(file.read())
        if not data_desc.get('view_file_pattern'):
@@ -127,11 +127,17 @@ class SphericalViewSynDataset(object):
            [view.euler_to_matrix([rot[1], rot[0], 0]) for rot in data_desc['view_rots']]
            if len(data_desc['view_rots'][0]) == 2 else data_desc['view_rots'],
            device=device.default()).view(-1, 3, 3)  # (N, 3, 3)
        #self.view_centers = self.view_centers[:6]
        #self.view_rots = self.view_rots[:6]
        self.view_idxs = torch.tensor(
            data_desc['views'] if 'views' in data_desc else list(range(self.view_centers.size(0))),
            device=device.default())
        if load_views is not None:
            self.view_centers = self.view_centers[load_views]
            self.view_rots = self.view_rots[load_views]
            self.view_idxs = self.view_idxs[load_views]
        self.n_views = self.view_centers.size(0)
        self.n_pixels = self.n_views * self.view_res[0] * self.view_res[1]
        self.view_idxs = data_desc['views'][:self.n_views] if 'views' in data_desc else range(self.n_views)
        if 'gl_coord' in data_desc and data_desc['gl_coord'] == True:
            print('Convert from OGL coordinate to DX coordinate (i.e. flip z-axis)')
......
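For reference, a minimal sketch of how the new load_views argument could be used; the dataset file name and view indices here are hypothetical:

    from data.spherical_view_syn import SphericalViewSynDataset

    # Keep only views 0, 5 and 9. Any value accepted by tensor indexing works,
    # since the loader simply indexes view_centers/view_rots/view_idxs with it.
    dataset = SphericalViewSynDataset('train.json', load_views=[0, 5, 9])
    print(dataset.n_views)    # 3
    print(dataset.view_idxs)  # original indices of the kept views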
@@ -70,6 +70,10 @@
" plt.subplot(133)\n",
" img.plot(images['layers_img'][2])\n",
" plt.figure(figsize=(12, 12))\n",
" img.plot(images['overlaid'])\n",
" plt.figure(figsize=(12, 12))\n",
" img.plot(images['blended_raw'])\n",
" plt.figure(figsize=(12, 12))\n",
" img.plot(images['blended'])\n",
"\n",
"\n",
@@ -87,7 +91,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 12,
"metadata": {},
"outputs": [
{
@@ -108,12 +112,12 @@
"fovea_net = load_net(find_file('fovea'))\n",
"periph_net = load_net(find_file('periph'))\n",
"renderer = FoveatedNeuralRenderer(fov_list, res_list, nn.ModuleList([fovea_net, periph_net, periph_net]),\n",
" res_full, using_mask=False, device=device.default())\n"
" res_full, device=device.default())"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
@@ -129,13 +133,14 @@
" ],\n",
" 'barbershop': [\n",
" [0, 0, 0, 0, 0, 0, 0],\n",
" #[0, 0, 0, 20, 0, -300, 50],\n",
" #[0, 0, 0, -140, -30, 150, -250],\n",
" #[0, 0, 0, -60, -30, 75, -125],\n",
" [0, 0, 0, 20, 0, -300, 50],\n",
" [0, 0, 0, -140, -30, 150, -250],\n",
" [0, 0, 0, -60, -30, 75, -125],\n",
" [0, 0, 0, -10, -5, 0, 0]\n",
" ],\n",
" 'lobby': [\n",
" #[0, 0, 0, 0, 0, 75, 0],\n",
" #[0, 0, 0, 0, 0, 5, 150],\n",
" [0, 0, 0, 0, 0, 75, 0],\n",
" [0, 0, 0, 0, 0, 5, 150],\n",
" [0, 0, 0, -120, 0, 75, 50],\n",
" ]\n",
"}\n",
@@ -143,14 +148,17 @@
"for i, param in enumerate(params[scene]):\n",
" view = Trans(torch.tensor(param[:3], device=device.default()),\n",
" torch.tensor(euler_to_matrix([-param[4], param[3], 0]), device=device.default()).view(3, 3))\n",
" images = renderer(view, param[-2:])\n",
" if False:\n",
" images = renderer(view, param[-2:], using_mask=False, ret_raw=True)\n",
" images['overlaid'] = renderer.foveation.synthesis(images['layers_raw'], param[-2:], do_blend=False)\n",
" if True:\n",
" outputdir = '../__demo/mono/'\n",
" misc.create_dir(outputdir)\n",
" img.save(images['layers_img'][0], f'{outputdir}{scene}_{i}_fovea.png')\n",
" img.save(images['layers_img'][1], f'{outputdir}{scene}_{i}_mid.png')\n",
" img.save(images['layers_img'][2], f'{outputdir}{scene}_{i}_periph.png')\n",
" img.save(images['blended'], f'{outputdir}{scene}_{i}_blended.png')\n",
" img.save(images['overlaid'], f'{outputdir}{scene}_{i}_overlaid.png')\n",
" img.save(images['blended_raw'], f'{outputdir}{scene}_{i}_blended_raw.png')\n",
" else:\n",
" images = plot_images(images)\n"
]
@@ -212,8 +220,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.8.5 64-bit ('base': conda)",
"name": "python385jvsc74a57bd082066b63b621a9e3d15e3b7c11ca76da6238eff3834294910d715044bd0561e5"
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
@@ -231,9 +240,8 @@
"interpreter": {
"hash": "82066b63b621a9e3d15e3b7c11ca76da6238eff3834294910d715044bd0561e5"
}
},
"orig_nbformat": 2
}
},
"nbformat": 4,
"nbformat_minor": 2
}
\ No newline at end of file
}
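As the hunks above show, the mono demo now requests raw layers from the renderer and derives an extra overlaid composite. Schematically (a sketch of the call pattern only, with view and gaze as in the notebook; the exact semantics of the extra keys are inferred from how the notebook consumes them):

    gaze = param[-2:]
    images = renderer(view, gaze, using_mask=False, ret_raw=True)
    # ret_raw=True is assumed to add 'layers_raw' and 'blended_raw' alongside
    # 'layers_img' and 'blended'; 'overlaid' composites the raw layers
    # without blending:
    images['overlaid'] = renderer.foveation.synthesis(images['layers_raw'], gaze, do_blend=False)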
@@ -2,37 +2,44 @@
"cells": [
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Set CUDA:0 as current device.\n"
]
}
],
"source": [
"import sys\n",
"import os\n",
"import torch\n",
"import torch.nn as nn\n",
"import matplotlib.pyplot as plt\n",
"\n",
"rootdir = os.path.abspath(sys.path[0] + '/../')\n",
"sys.path.append(rootdir)\n",
"\n",
"torch.cuda.set_device(2)\n",
"torch.cuda.set_device(0)\n",
"print(\"Set CUDA:%d as current device.\" % torch.cuda.current_device())\n",
"torch.autograd.set_grad_enabled(False)\n",
"\n",
"from data.spherical_view_syn import *\n",
"from configs.spherical_view_syn import SphericalViewSynConfig\n",
"from utils import netio\n",
"from utils import misc\n",
"from utils import img\n",
"from utils import device\n",
"from utils import view\n",
"from utils.view import *\n",
"from components.fnr import FoveatedNeuralRenderer\n",
"\n",
"\n",
"def load_net(path):\n",
" config = SphericalViewSynConfig()\n",
" config.from_id(path[:-4])\n",
" config.from_id(os.path.splitext(path)[0])\n",
" config.SAMPLE_PARAMS['perturb_sample'] = False\n",
" # config.print()\n",
" net = config.create_net().to(device.default())\n",
" netio.load(path, net)\n",
" return net\n",
@@ -45,14 +52,14 @@
" return None\n",
"\n",
"\n",
"def load_views(data_desc_file) -> view.Trans:\n",
"def load_views(data_desc_file) -> Trans:\n",
" with open(data_desc_file, 'r', encoding='utf-8') as file:\n",
" data_desc = json.loads(file.read())\n",
" view_centers = torch.tensor(\n",
" data_desc['view_centers'], device=device.default()).view(-1, 3)\n",
" view_rots = torch.tensor(\n",
" data_desc['view_rots'], device=device.default()).view(-1, 3, 3)\n",
" return view.Trans(view_centers, view_rots)\n",
" return Trans(view_centers, view_rots)\n",
"\n",
"\n",
"def plot_cross(center, res):\n",
@@ -78,115 +85,120 @@
" color=[0, 1, 0])\n",
"\n",
"\n",
"def plot_fovea(left_images, right_images, left_center, right_center):\n",
" plt.figure(figsize=(8, 4))\n",
"def plot_figures(left_images, right_images, left_center, right_center):\n",
" # Plot Fovea\n",
" plt.figure(figsize=(12, 6))\n",
" plt.subplot(121)\n",
" img.plot(left_images['fovea'])\n",
" fovea_res = left_images['fovea'].size()[-2:]\n",
" img.plot(left_images['layers_img'][0])\n",
" fovea_res = left_images['layers_img'][0].size()[-2:]\n",
" plot_cross((0, 0), fovea_res)\n",
" plt.subplot(122)\n",
" img.plot(right_images['fovea'])\n",
" img.plot(right_images['layers_img'][0])\n",
" plot_cross((0, 0), fovea_res)\n",
"\n",
" # Plot Mid\n",
" plt.figure(figsize=(12, 6))\n",
" plt.subplot(121)\n",
" img.plot(left_images['layers_img'][1])\n",
" plt.subplot(122)\n",
" img.plot(right_images['layers_img'][1])\n",
"\n",
" # Plot Periph\n",
" plt.figure(figsize=(12, 6))\n",
" plt.subplot(121)\n",
" img.plot(left_images['layers_img'][2])\n",
" plt.subplot(122)\n",
" img.plot(right_images['layers_img'][2])\n",
"\n",
" # Plot Blended\n",
" plt.figure(figsize=(12, 6))\n",
" plt.subplot(121)\n",
" img.plot(left_images['blended'])\n",
" full_res = left_images['blended'].size()[-2:]\n",
" plot_cross(left_center, full_res)\n",
" plt.subplot(122)\n",
" img.plot(right_images['blended'])\n",
" plot_cross(right_center, full_res)\n",
"\n",
"\n",
"scenes = {\n",
" 'gas': '__0_user_study/us_gas_all_in_one',\n",
" 'mc': '__0_user_study/us_mc_all_in_one',\n",
" 'bedroom': 'bedroom_all_in_one',\n",
" 'gallery': 'gallery_all_in_one',\n",
" 'lobby': 'lobby_all_in_one'\n",
" 'classroom': 'classroom_all',\n",
" 'stones': 'stones_all',\n",
" 'barbershop': 'barbershop_all',\n",
" 'lobby': 'lobby_all'\n",
"}\n",
"\n",
"\n",
"fov_list = [20, 45, 110]\n",
"res_list = [(128, 128), (256, 256), (256, 230)]\n",
"res_list = [(256, 256), (256, 256), (256, 230)]\n",
"res_full = (1600, 1440)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 26,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Change working directory to /home/dengnc/dvs/data/__new/classroom_all\n",
"Load net from fovea@snerffast4-rgb_e6_fc256x8_d1.00-7.00_s64_~p.pth ...\n",
"Load net from periph@snerffast4-rgb_e6_fc128x4_d1.00-7.00_s64_~p.pth ...\n"
]
}
],
"source": [
"centers = {\n",
" 'gas': [\n",
" [(3.5, 0), (-3.5, 0)],\n",
" [(1.5, 0), (-1.5, 0)]\n",
" ],\n",
" 'mc': [\n",
" [(2, 0), (-2, 0)],\n",
" [(2, 0), (-2, 0)]\n",
"params = {\n",
" 'classroom': [\n",
" [(0, 0, 0, 0, 0), (1, -83), (-5, -83)],\n",
" [(0, 0, 0, 0, 0), (-171, 55), (-178, 55)],\n",
" [(0, 0, 0, 0, 0), (60, 55), (55, 55)],\n",
" [(0, 0, 0, 0, 0), (138, 160), (130, 160)]\n",
" ],\n",
" 'bedroom': [\n",
" [(5, 0), (-5, 0)],\n",
" [(6, 0), (-6, 0)],\n",
" [(5, 0), (-5, 0)]\n",
" ],\n",
" 'gallery': [\n",
" [(2.5, 0), (-2.5, 0)],\n",
" [(11.5, 0), (-11.5, 0)]\n",
" ]\n",
"}\n",
"scene = 'bedroom'\n",
"os.chdir(os.path.join(rootdir, f'data/{scenes[scene]}'))\n",
"scene = 'classroom'\n",
"os.chdir(f'{rootdir}/data/__new/{scenes[scene]}')\n",
"print('Change working directory to ', os.getcwd())\n",
"\n",
"fovea_net = load_net(find_file('fovea'))\n",
"periph_net = load_net(find_file('periph'))\n",
"\n",
"# Load Dataset\n",
"views = load_views('demo.json')\n",
"print('Dataset loaded.')\n",
"print('views:', views.size())\n",
"gen = GenFinal(fov_list, res_list, res_full, fovea_net, periph_net,\n",
" device=device.default())\n",
"\n",
"for view_idx in range(views.size()[0]):\n",
" test_view = views.get(view_idx)\n",
" left_images = gen(centers[scene][view_idx][0], view.Trans(\n",
" test_view.trans_point(\n",
" torch.tensor([-0.03, 0, 0], device=device.default())\n",
" ), test_view.r), mono_trans=test_view)\n",
" right_images = gen(centers[scene][view_idx][1], view.Trans(\n",
" test_view.trans_point(\n",
" torch.tensor([0.03, 0, 0], device=device.default())\n",
" ), test_view.r), mono_trans=test_view)\n",
" #plot_fovea(left_images, right_images, centers[scene][view_idx][0],\n",
" # centers[scene][view_idx][1])\n",
" outputdir = '../__2_demo/mono_periph/stereo/'\n",
" misc.create_dir(outputdir)\n",
" # for key in images:\n",
" key = 'blended'\n",
" img.save(left_images[key], '%s%s_view%04d_%s_l.png' % (outputdir, scene, view_idx, key))\n",
" img.save(right_images[key], '%s%s_view%04d_%s_r.png' % (outputdir, scene, view_idx, key))\n",
" stereo_overlap = torch.cat([left_images['blended'][:, 0:1], right_images['blended'][:, 1:3]], dim=1)\n",
" img.save(stereo_overlap, '%s%s_view%04d_%s_stereo.png' % (outputdir, scene, view_idx, key))\n",
"\n",
" left_images = gen(centers[scene][view_idx][0], view.Trans(\n",
" test_view.trans_point(\n",
" torch.tensor([-0.03, 0, 0], device=device.default())\n",
" ), test_view.r))\n",
" right_images = gen(centers[scene][view_idx][1], view.Trans(\n",
" test_view.trans_point(\n",
" torch.tensor([0.03, 0, 0], device=device.default())\n",
" ), test_view.r))\n",
" #plot_fovea(left_images, right_images, centers[scene][view_idx][0],\n",
" # centers[scene][view_idx][1])\n",
" outputdir = '../__2_demo/stereo/'\n",
" misc.create_dir(outputdir)\n",
" # for key in images:\n",
" key = 'blended'\n",
" img.save(left_images[key], '%s%s_view%04d_%s_l.png' % (outputdir, scene, view_idx, key))\n",
" img.save(right_images[key], '%s%s_view%04d_%s_r.png' % (outputdir, scene, view_idx, key))\n",
" stereo_overlap = torch.cat([left_images['blended'][:, 0:1], right_images['blended'][:, 1:3]], dim=1)\n",
" img.save(stereo_overlap, '%s%s_view%04d_%s_stereo.png' % (outputdir, scene, view_idx, key))\n"
"renderer = FoveatedNeuralRenderer(fov_list, res_list,\n",
" nn.ModuleList([fovea_net, periph_net, periph_net]),\n",
" res_full, device=device.default())\n",
"\n",
"for i, param in enumerate(params[scene]):\n",
" view = Trans(torch.tensor(param[0][:3], device=device.default()),\n",
" torch.tensor(euler_to_matrix([-param[0][4], param[0][3], 0]),\n",
" device=device.default()).view(3, 3))\n",
" eye_offset = torch.tensor([0.03, 0, 0], device=device.default())\n",
" left_view = Trans(view.trans_point(-eye_offset), view.r)\n",
" right_view = Trans(view.trans_point(eye_offset), view.r)\n",
" left_images, right_images = renderer(view, param[1], param[2],\n",
" stereo_disparity=0.06, using_mask=False, ret_raw=False)\n",
" if True:\n",
" outputdir = '../__demo/stereo/'\n",
" misc.create_dir(outputdir)\n",
" img.save(left_images['blended'], '%s%s_%d_l.png' % (outputdir, scene, i))\n",
" img.save(right_images['blended'], '%s%s_%d_r.png' % (outputdir, scene, i))\n",
" stereo_overlap = torch.cat([\n",
" left_images['blended'][:, 0:1],\n",
" right_images['blended'][:, 1:3]\n",
" ], dim=1)\n",
" img.save(stereo_overlap, '%s%s_%d_stereo.png' % (outputdir, scene, i))\n",
" else:\n",
" plot_figures(left_images, right_images, param[1], param[2])\n"
]
}
],
"metadata": {
"interpreter": {
"hash": "82066b63b621a9e3d15e3b7c11ca76da6238eff3834294910d715044bd0561e5"
},
"kernelspec": {
"display_name": "Python 3.7.9 64-bit ('pytorch': conda)",
"name": "python379jvsc74a57bd0660ca2a75467d3af74a68fcc6f40bc78ab96b99ff17d2f100b5ca821fbb183f2"
"display_name": "Python 3.8.5 64-bit ('base': conda)",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
@@ -198,7 +210,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
"version": "3.8.5"
},
"orig_nbformat": 2
},
......
#!/usr/bin/bash
datadir='data/__new/lobby_fovea_r360x80_t1.0'
trainset='data/__new/lobby_fovea_r360x80_t1.0/train1.json'
testset='data/__new/lobby_fovea_r360x80_t1.0/test1.json'
epochs=50
n_nets=$1
nf=$2
n_layers=$3
configid="eval@snerffast${n_nets}-rgb_e6_fc${nf}x${n_layers}_d1.20-6.00_s64_~p"

# Train (or resume from the latest checkpoint) if the final epoch is missing
if [ ! -f "$datadir/$configid/model-epoch_$epochs.pth" ]; then
    cont_epoch=0
    for ((i = epochs - 1; i > 0; i--)); do
        if [ -f "$datadir/$configid/model-epoch_$i.pth" ]; then
            cont_epoch=$i
            break
        fi
    done
    if [ ${cont_epoch} -gt 0 ]; then
        python run_spherical_view_syn.py $trainset -e $epochs -m $configid/model-epoch_${cont_epoch}.pth
    else
        python run_spherical_view_syn.py $trainset -i $configid -e $epochs
    fi
fi

# Evaluate on train and test sets if no perf output exists yet
if ! ls $datadir/$configid/output_$epochs/perf_* >/dev/null 2>&1; then
    python run_spherical_view_syn.py $trainset -t -m $configid/model-epoch_$epochs.pth -o perf
    python run_spherical_view_syn.py $testset -t -m $configid/model-epoch_$epochs.pth -o perf
fi
\ No newline at end of file
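Assuming the script above is saved as, say, train_eval_lobby.sh (its file name is not visible in this diff), a sweep over one configuration would look like:

    # 4 sub-nets, 256 channels, 8 layers: trains to epoch 50 (resuming from the
    # latest checkpoint if one exists), then runs perf on train and test sets
    bash train_eval_lobby.sh 4 256 8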
@@ -23,7 +23,7 @@ with open(data_desc_path, 'r') as fp:
    dataset_desc['cam_params'] = view.CameraParam.convert_camera_params(
        dataset_desc['cam_params'],
        (dataset_desc['view_res']['x'], dataset_desc['view_res']['x']))
        (dataset_desc['view_res']['y'], dataset_desc['view_res']['x']))
    dataset_desc['view_rots'] = [
        view.euler_to_matrix([rot[1], rot[0], 0])
......
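The one-line fix above matters because resolutions elsewhere in this code base are (height, width) tuples (e.g. res_full = (1600, 1440) in the notebooks), while the JSON stores them under 'x'/'y' keys; the old code passed the width twice. A sketch with made-up values matching that notebook resolution:

    view_res = {'x': 1440, 'y': 1600}     # width / height as stored in the JSON
    old = (view_res['x'], view_res['x'])  # (1440, 1440): width twice - wrong
    new = (view_res['y'], view_res['x'])  # (1600, 1440): (height, width)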
@@ -9,12 +9,10 @@ from typing import Mapping, List
sys.path.append(os.path.abspath(sys.path[0] + '/../'))
parser = argparse.ArgumentParser()
parser.add_argument('--device', type=int, default=0,
                    help='Which CUDA device to use.')
parser.add_argument('--batch-size', type=str,
                    help='Resolution')
parser.add_argument('model', type=str,
                    help='Path of model to export')
parser.add_argument('-b', '--batch-size', type=str, help='Resolution')
parser.add_argument('-o', '--output', type=str)
parser.add_argument('--device', type=int, default=0, help='Which CUDA device to use.')
parser.add_argument('model', type=str, help='Path of model to export')
opt = parser.parse_args()
# Select device
@@ -28,30 +26,34 @@ from utils import device
from configs.spherical_view_syn import SphericalViewSynConfig
dir_path, model_file = os.path.split(opt.model)
config_id = os.path.split(dir_path)[-1]
batch_size = eval(opt.batch_size)
batch_size_str = opt.batch_size.replace('*', 'x')
outdir = f"output_{int(os.path.splitext(model_file)[0][12:])}"
os.chdir(dir_path)
misc.create_dir(outdir)
if not opt.output:
    epochs = os.path.splitext(model_file)[0][12:]
    outdir = f"{dir_path}/output_{epochs}"
    output = os.path.join(outdir, f"net@{batch_size_str}.onnx")
    misc.create_dir(outdir)
else:
    output = opt.output
config = SphericalViewSynConfig()
def load_net(path):
    id = os.path.split(dir_path)[-1]  # os.path.splitext(os.path.basename(path))[0]
    config.from_id(id)
def load_net():
    config = SphericalViewSynConfig()
    config.from_id(config_id)
    config.SAMPLE_PARAMS['perturb_sample'] = False
    config.name += batch_size_str
    config.print()
    net = config.create_net().to(device.default())
    netio.load(path, net)
    return net, id
    netio.load(opt.model, net)
    return net
def export_net(net: torch.nn.Module, name: str,
               input: Mapping[str, List[int]], output_names: List[str]):
    outpath = os.path.join(outdir, f"{name}@{batch_size_str}.onnx")
def export_net(net: torch.nn.Module, path: str, input: Mapping[str, List[int]],
               output_names: List[str]):
    input_tensors = tuple([
        torch.empty(size, device=device.default())
        for size in input.values()
@@ -59,21 +61,25 @@ def export_net(net: torch.nn.Module, name: str,
    onnx.export(
        net,
        input_tensors,
        outpath,
        path,
        export_params=True,  # store the trained parameter weights inside the model file
        verbose=True,
        opset_version=9,  # the ONNX version to export the model to
        do_constant_folding=True,  # whether to execute constant folding
        input_names=input.keys(),  # the model's input names
        output_names=output_names  # the model's output names
        do_constant_folding=True,  # whether to execute constant folding
        input_names=list(input.keys()),  # the model's input names
        output_names=output_names  # the model's output names
    )
    print('Model exported to ' + outpath)
    print('Model exported to ' + path)
if __name__ == "__main__":
    with torch.no_grad():
        net: SnerfFast = load_net(model_file)[0]
        export_net(SnerfFastExport(net), 'net', {
            'Encoded': [batch_size, net.n_samples, net.coord_encoder.out_dim],
            'Depths': [batch_size, net.n_samples]
        }, ['Colors'])
\ No newline at end of file
        net: SnerfFast = load_net()
        export_net(
            SnerfFastExport(net),
            output,
            {
                'Encoded': [batch_size, net.n_samples, net.coord_encoder.out_dim],
                'Depths': [batch_size, net.n_samples]
            },
            ['Colors'])
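A hypothetical invocation of the reworked export script (the script and model paths here are illustrative; only checkpoints following the model-epoch_N.pth naming work, since the epoch number is sliced out of the file name):

    # --batch-size is eval()'ed, so arithmetic such as 256*1024 is accepted;
    # '*' is replaced by 'x' in the generated file name.
    python export_snerf_fast.py -b "256*1024" \
        data/fovea@snerffast4-rgb_e6_fc256x8_d1.00-7.00_s64_~p/model-epoch_50.pth
    # -> writes <model dir>/output_50/net@256x1024.onnx unless -o/--output is given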
import sys
import os
import json

rootdir = os.path.abspath(sys.path[0] + '/../')
datadir = f"{rootdir}/data/__new/classroom_fovea_r360x80_t0.6"
n_nets_arr = [1, 2, 4, 8]
nf_arr = [64, 128, 256, 512, 1024]
n_layers_arr = [2, 4, 8]
head = "Nets,Layers," + ",".join([f"{val}" for val in nf_arr])

perf_train_table = []
perf_test_table = []
perf_time_table = []
for n_nets in n_nets_arr:
    for n_layers in n_layers_arr:
        perf_train_row = []
        perf_test_row = []
        perf_time_row = []
        for nf in nf_arr:
            configid = f"eval@snerffast{n_nets}-rgb_e6_fc{nf}x{n_layers}_d1.00-7.00_s64_~p"
            outputdir = f"{datadir}/{configid}/output_50"
            if not os.path.exists(outputdir):
                perf_train_row.append("-")
                perf_test_row.append("-")
                perf_time_row.append("-")
                continue
            perf_test_found = False
            perf_train_found = False
            for file in os.listdir(outputdir):
                if file.startswith("perf_r120x80_test"):
                    if perf_test_found:
                        os.remove(f"{outputdir}/{file}")
                    else:
                        perf_test_row.append(os.path.splitext(file)[0].split("_")[-1])
                        perf_test_found = True
                elif file.startswith("perf_r120x80"):
                    if perf_train_found:
                        os.remove(f"{outputdir}/{file}")
                    else:
                        perf_train_row.append(os.path.splitext(file)[0].split("_")[-1])
                        perf_train_found = True
            if not perf_train_found:
                perf_train_row.append("-")
            if not perf_test_found:
                perf_test_row.append("-")
            # Collect time values
            time_file = f"{datadir}/eval_trt/time/eval_{n_nets}x{nf}x{n_layers}.json"
            if not os.path.exists(time_file):
                perf_time_row.append("-")
            else:
                with open(time_file) as fp:
                    time_data = json.load(fp)
                time = 0
                for item in time_data:
                    time += item['computeMs']
                time /= len(time_data)
                perf_time_row.append(f"{time:.1f}")
        perf_train_table.append(perf_train_row)
        perf_test_table.append(perf_test_row)
        perf_time_table.append(perf_time_row)

perf_train_content = head + "\n"
for i, row in enumerate(perf_train_table):
    if i % len(n_layers_arr) == 0:
        perf_train_content += f"{n_nets_arr[i // len(n_layers_arr)]}"
    perf_train_content += f",{n_layers_arr[i % len(n_layers_arr)]},"
    perf_train_content += ",".join(row) + "\n"

perf_test_content = head + "\n"
for i, row in enumerate(perf_test_table):
    if i % len(n_layers_arr) == 0:
        perf_test_content += f"{n_nets_arr[i // len(n_layers_arr)]}"
    perf_test_content += f",{n_layers_arr[i % len(n_layers_arr)]},"
    perf_test_content += ",".join(row) + "\n"

perf_time_content = head + "\n"
for i, row in enumerate(perf_time_table):
    if i % len(n_layers_arr) == 0:
        perf_time_content += f"{n_nets_arr[i // len(n_layers_arr)]}"
    perf_time_content += f",{n_layers_arr[i % len(n_layers_arr)]},"
    perf_time_content += ",".join(row) + "\n"

with open(f"{datadir}/eval_perf.csv", "w") as fp:
    fp.write("Train:\n")
    fp.write(perf_train_content)
    fp.write("Test:\n")
    fp.write(perf_test_content)
    fp.write("Time:\n")
    fp.write(perf_time_content)
\ No newline at end of file
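The resulting eval_perf.csv stacks three sections (Train, Test, Time) with the same layout; schematically, with '-' placeholders where no output or timing file was found:

    Train:
    Nets,Layers,64,128,256,512,1024
    1,2,-,-,-,-,-
    ,4,-,-,-,-,-
    ,8,-,-,-,-,-
    2,2,-,-,-,-,-
    ...
    Test:
    (same layout; values parsed from the perf_r120x80* file names)
    Time:
    (same layout; average computeMs per configuration from the timing JSONs)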