{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Foveated view synthesis: evaluation renders\n",
    "\n",
    "Loads the fovea and periphery networks of one scene, wraps them in `GenFinal` and writes evaluation images to disk.\n",
    "\n",
    "## Setup\n",
    "\n",
    "Imports, CUDA device selection, helper functions and the per-scene configuration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Set CUDA:2 as current device.\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "import os\n",
    "import sys\n",
    "\n",
    "import torch\n",
    "\n",
    "CUDA_DEVICE = 2  # GPU index used for rendering\n",
    "\n",
    "# Put the parent of sys.path[0] (presumably the notebook's directory) on the\n",
    "# import path so the project packages imported below can be found.\n",
    "sys.path.append(os.path.abspath(sys.path[0] + '/../'))\n",
    "torch.cuda.set_device(CUDA_DEVICE)\n",
    "print(\"Set CUDA:%d as current device.\" % torch.cuda.current_device())\n",
    "# Inference only: turn autograd off globally.\n",
    "torch.autograd.set_grad_enabled(False)\n",
    "\n",
    "from data.spherical_view_syn import *\n",
    "from configs.spherical_view_syn import SphericalViewSynConfig\n",
    "from utils import netio\n",
    "from utils import img\n",
    "from utils import misc\n",
    "from utils import device\n",
    "from utils import view\n",
    "from components.gen_final import GenFinal\n",
    "\n",
    "\n",
    "def load_net(path):\n",
    "    \"\"\"Create the network described by a checkpoint file name and load its weights.\n",
    "\n",
    "    The file name without its extension is passed to ``config.from_id``, so it\n",
    "    must encode the network configuration, e.g.\n",
    "    ``fovea@nmsl-rgb_e10_fc128x4_d1-50_s32.pth``. ``perturb_sample`` is turned\n",
    "    off (presumably so that rendering is deterministic).\n",
    "\n",
    "    Returns the network, moved to ``device.default()``.\n",
    "    \"\"\"\n",
    "    config = SphericalViewSynConfig()\n",
    "    config.from_id(os.path.splitext(path)[0])\n",
    "    config.SAMPLE_PARAMS['perturb_sample'] = False\n",
    "    net = config.create_net().to(device.default())\n",
    "    netio.load(path, net)\n",
    "    return net\n",
    "\n",
    "\n",
    "def find_file(prefix):\n",
    "    \"\"\"Return the name of a file in the current directory starting with ``prefix + '@'``.\n",
    "\n",
    "    Candidates are checked in sorted order so the choice is deterministic when\n",
    "    several files share the prefix.\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: if no file matches. Failing here gives a clear error\n",
    "            instead of handing None to load_net, which would fail obscurely.\n",
    "    \"\"\"\n",
    "    for path in sorted(os.listdir()):\n",
    "        if path.startswith(prefix + '@'):\n",
    "            return path\n",
    "    raise FileNotFoundError('No file starting with %r in %s'\n",
    "                            % (prefix + '@', os.getcwd()))\n",
    "\n",
    "\n",
    "def load_views(data_desc_file) -> view.Trans:\n",
    "    \"\"\"Load camera poses from a dataset description JSON file.\n",
    "\n",
    "    Reads ``view_centers`` (reshaped to (N, 3)) and ``view_rots`` (reshaped to\n",
    "    (N, 3, 3)) and returns them as a ``view.Trans`` on ``device.default()``.\n",
    "    \"\"\"\n",
    "    with open(data_desc_file, 'r', encoding='utf-8') as file:\n",
    "        data_desc = json.load(file)\n",
    "    view_centers = torch.tensor(\n",
    "        data_desc['view_centers'], device=device.default()).view(-1, 3)\n",
    "    view_rots = torch.tensor(\n",
    "        data_desc['view_rots'], device=device.default()).view(-1, 3, 3)\n",
    "    return view.Trans(view_centers, view_rots)\n",
    "\n",
    "\n",
    "# Scene name -> dataset directory, relative to the project's data/ directory.\n",
    "scenes = {\n",
    "    'gas': '__0_user_study/us_gas_all_in_one',\n",
    "    'mc': '__0_user_study/us_mc_all_in_one',\n",
    "    'bedroom': 'bedroom_all_in_one',\n",
    "    'gallery': 'gallery_all_in_one',\n",
    "    'lobby': 'lobby_all_in_one'\n",
    "}\n",
    "\n",
    "# Per-layer settings passed to GenFinal. Presumably the field of view is in\n",
    "# degrees and resolutions are (H, W) - TODO confirm against GenFinal.\n",
    "fov_list = [20, 45, 110]\n",
    "res_list = [(128, 128), (256, 256), (256, 230)]\n",
    "res_full = (1600, 1440)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the scene\n",
    "\n",
    "Switches into the scene's dataset directory, loads the fovea and periphery networks from their checkpoints, loads the dataset views and builds the `GenFinal` renderer."
   ]
  },
  {
   "cell_type": 
"code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Change working directory to /home/dengnc/deep_view_syn/data/__0_user_study/us_mc_all_in_one\n",
      "Load net from fovea@nmsl-rgb_e10_fc128x4_d1-50_s32.pth ...\n",
      "Load net from periph@nnmsl-rgb_e10_fc64x4_d1-50_s16.pth ...\n",
      "Dataset loaded.\n",
      "views: [110]\n"
     ]
    }
   ],
   "source": [
    "scene = 'mc'\n",
    "# Root of all datasets and evaluation outputs: <parent of sys.path[0]>/data.\n",
    "data_dir = os.path.abspath(os.path.join(sys.path[0], '..', 'data'))\n",
    "os.chdir(os.path.join(data_dir, scenes[scene]))\n",
    "print('Change working directory to', os.getcwd())\n",
    "\n",
    "fovea_net = load_net(find_file('fovea'))\n",
    "periph_net = load_net(find_file('periph'))\n",
    "\n",
    "# Camera poses of the dataset views.\n",
    "views = load_views('nerf_views.json')\n",
    "print('Dataset loaded.')\n",
    "print('views:', views.size())\n",
    "gen = GenFinal(fov_list, res_list, res_full, fovea_net, periph_net,\n",
    "               device=device.default())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Render reference views\n",
    "\n",
    "Each of the first `N_VIEWS` dataset views is rendered with the gaze fixed at `(0, 0)`. A copy of the view translated by `MONO_OFFSET` in the view's local frame is passed as `mono_trans`. The output folder name (`ref_as_right_eye`) indicates that the reference view plays the right eye. Only the `blended` image is saved."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "N_VIEWS = 20\n",
    "# Offset of the mono_trans camera in the reference view's local frame\n",
    "# (presumably metres) - TODO confirm units.\n",
    "MONO_OFFSET = [-0.03, 0, 0]\n",
    "\n",
    "output_dir = os.path.join(data_dir, '__1_eval', 'output_mono_periph',\n",
    "                          'ref_as_right_eye', scene)\n",
    "misc.create_dir(output_dir)\n",
    "\n",
    "center = (0, 0)\n",
    "mono_offset = torch.tensor(MONO_OFFSET, device=device.default())\n",
    "for view_idx in range(N_VIEWS):\n",
    "    test_view = views.get(view_idx)\n",
    "    mono_view = view.Trans(test_view.trans_point(mono_offset), test_view.r)\n",
    "    images = gen.gen(center, test_view, mono_trans=mono_view)\n",
    "    img.save(images['blended'],\n",
    "             os.path.join(output_dir, 'view%04d_blended.png' % view_idx))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Gaze sweep\n",
    "\n",
    "Renders the first view with the gaze point moved over a `GAZE_STEPS` x `GAZE_STEPS` grid spanning `[-GAZE_RANGE, GAZE_RANGE]` on both axes. The blended image is saved to `output/eval_gaze` inside the scene directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "GAZE_RANGE = 200  # half-width of the gaze grid (presumably pixels) - TODO confirm\n",
    "GAZE_STEPS = 11   # grid points per axis\n",
    "gaze_dir = 'output/eval_gaze'  # relative to the scene directory chosen above\n",
    "misc.create_dir(gaze_dir)\n",
    "\n",
    "test_view = views.get(0)\n",
    "offsets = np.linspace(-GAZE_RANGE, GAZE_RANGE, GAZE_STEPS)\n",
    "# y outer, x inner: gaze000 is the (-GAZE_RANGE, -GAZE_RANGE) corner.\n",
    "gaze_points = [(int(x), int(y)) for y in offsets for x in offsets]\n",
    "for gaze_idx, center in enumerate(gaze_points):\n",
    "    # gen.gen accepts only (gaze, view) positionally; further options must be\n",
    "    # passed by keyword, as with mono_trans above.\n",
    "    images = gen.gen(center, test_view)\n",
    "    out_path = '%s/gaze%03d_%d,%d.png' % (gaze_dir, gaze_idx, *center)\n",
    "    img.save(images['blended'], out_path)\n",
    "    print('Output ' + out_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.9 64-bit ('pytorch': conda)",
   "name": "python379jvsc74a57bd0660ca2a75467d3af74a68fcc6f40bc78ab96b99ff17d2f100b5ca821fbb183f2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  },
  "orig_nbformat": 2
 },
 "nbformat": 4,
 "nbformat_minor": 2
}