{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Stereo foveated rendering test\n",
    "\n",
    "Loads the fovea and periphery networks of one scene, renders a left/right eye pair around the central view with `GenFinal` for a chosen gaze set, plots every rendered layer (`fovea_raw`, `fovea`, `mid`, `periph`, `blended`) side by side and saves all layers to `output/mono_test/`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup: imports and GPU selection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import sys\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import torchvision.transforms.functional as trans_f\n",
    "\n",
    "CUDA_DEVICE = 2  # GPU index used for all rendering in this notebook\n",
    "\n",
    "# Make the parent of sys.path[0] importable (sys.path[0] is assumed to be this\n",
    "# notebook's directory, so this exposes the project packages imported below).\n",
    "sys.path.append(os.path.abspath(sys.path[0] + '/../'))\n",
    "torch.cuda.set_device(CUDA_DEVICE)\n",
    "print('Set CUDA:%d as current device.' % torch.cuda.current_device())\n",
    "\n",
    "# NOTE(review): star import kept in case other project names are relied on;\n",
    "# prefer explicit imports once the needed names are known.\n",
    "from data.spherical_view_syn import *\n",
    "from configs.spherical_view_syn import SphericalViewSynConfig\n",
    "from utils import netio\n",
    "from utils import misc\n",
    "from utils import img\n",
    "from utils import device\n",
    "from utils import view\n",
    "from components.foveation import Foveation\n",
    "from components.gen_final import GenFinal"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_net(path):\n",
    "    \"\"\"Create the network described by a checkpoint file name and load its weights.\n",
    "\n",
    "    The config id is the file name without its extension, e.g.\n",
    "    'fovea@nmsl-rgb_e10_fc128x4_d1-50_s32' for 'fovea@nmsl-rgb_e10_fc128x4_d1-50_s32.pth'.\n",
    "    Sample perturbation is switched off for inference; the config is printed\n",
    "    and the net is moved to the default device before loading.\n",
    "    \"\"\"\n",
    "    config = SphericalViewSynConfig()\n",
    "    config.from_id(os.path.splitext(path)[0])\n",
    "    config.SAMPLE_PARAMS['perturb_sample'] = False\n",
    "    config.print()\n",
    "    net = config.create_net().to(device.default())\n",
    "    netio.load(path, net)\n",
    "    return net\n",
    "\n",
    "\n",
    "def find_file(prefix):\n",
    "    \"\"\"Return the name of a file in the working directory that starts with `prefix`.\n",
    "\n",
    "    Matches are sorted so the choice is deterministic when several files\n",
    "    match (os.listdir order is arbitrary).\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: if no file name starts with `prefix`.\n",
    "    \"\"\"\n",
    "    matches = sorted(name for name in os.listdir() if name.startswith(prefix))\n",
    "    if not matches:\n",
    "        raise FileNotFoundError(\n",
    "            'No file starting with %r in %s' % (prefix, os.getcwd()))\n",
    "    return matches[0]\n",
    "\n",
    "\n",
    "def load_views(data_desc_file) -> view.Trans:\n",
    "    \"\"\"Load the camera poses of a view grid from a JSON description file.\n",
    "\n",
    "    'view_centers' is reshaped to samples + [3] and 'view_rots' to\n",
    "    samples + [3, 3], where samples is the file's 'samples' list; without\n",
    "    'samples' the poses form a flat list ([-1]).\n",
    "    \"\"\"\n",
    "    with open(data_desc_file, 'r', encoding='utf-8') as file:\n",
    "        data_desc = json.load(file)\n",
    "    samples = data_desc.get('samples', [-1])\n",
    "    view_centers = torch.tensor(\n",
    "        data_desc['view_centers'], device=device.default()).view(samples + [3])\n",
    "    view_rots = torch.tensor(\n",
    "        data_desc['view_rots'], device=device.default()).view(samples + [3, 3])\n",
    "    return view.Trans(view_centers, view_rots)\n",
    "\n",
    "\n",
    "def read_ref_images(idx):\n",
    "    \"\"\"Load reference image(s) ref/view_XXXX.png for an index or a 1-D tensor of indices.\"\"\"\n",
    "    patt = 'ref/view_%04d.png'\n",
    "    if isinstance(idx, torch.Tensor) and idx.dim() > 0:\n",
    "        return img.load([patt % int(i) for i in idx])\n",
    "    return img.load(patt % idx)\n",
    "\n",
    "\n",
    "def adjust_cam(cam, vr_cam, gaze_center):\n",
    "    \"\"\"Shift `cam`'s principal point by a gaze offset (mutates `cam.c` in place).\n",
    "\n",
    "    `gaze_center` is an (x, y) pixel offset in `vr_cam`'s image; it is rescaled\n",
    "    to `cam` pixels by the ratio of focal lengths and subtracted from the\n",
    "    centre of `cam`'s image.\n",
    "    \"\"\"\n",
    "    offset_x = gaze_center[0] / vr_cam.f[0].item() * cam.f[0].item()\n",
    "    offset_y = gaze_center[1] / vr_cam.f[1].item() * cam.f[1].item()\n",
    "    cam.c[0] = cam.res[1] / 2 - offset_x\n",
    "    cam.c[1] = cam.res[0] / 2 - offset_y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load networks and views"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scene to render. Other scenes used with this notebook:\n",
    "#   '__0_user_study/us_mc_all_in_one', 'bedroom_all_in_one'\n",
    "SCENE_DIR = '__0_user_study/us_gas_all_in_one'\n",
    "os.chdir(os.path.join(sys.path[0], '..', 'data', SCENE_DIR))\n",
    "print('Change working directory to ', os.getcwd())\n",
    "torch.autograd.set_grad_enabled(False)  # inference only\n",
    "\n",
    "fovea_net = load_net(find_file('fovea'))\n",
    "periph_net = load_net(find_file('periph'))\n",
    "\n",
    "views = load_views('views.json')\n",
    "print('Dataset loaded.')\n",
    "print('views:', views.size())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Field of view (presumably degrees) and resolution (H, W) of the fovea, mid\n",
    "# and periphery layers, plus the (H, W) of the final blended frame.\n",
    "fov_list = [20, 45, 110]\n",
    "res_list = [(128, 128), (256, 256), (256, 230)]\n",
    "res_full = (1600, 1440)\n",
    "\n",
    "gen = GenFinal(fov_list, res_list, res_full, fovea_net, periph_net, device.default())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Choose gaze set and test view"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (left_center, right_center) per test set; each is an (x, y) pixel offset\n",
    "# from the centre of the full-resolution frame.\n",
    "# Sets 0-2 belong to the 'gas' scene, sets 3-5 to the 'mc' scene.\n",
    "GAZE_SETS = {\n",
    "    0: ((-137, 64), (-142, 64)),\n",
    "    1: ((133, -44), (130, -44)),\n",
    "    2: ((-20, -5), (-25, -5)),\n",
    "    3: ((-107, 80), (-112, 80)),\n",
    "    4: ((-17, -90), (-22, -90)),\n",
    "    5: ((95, 30), (91, 30)),\n",
    "}\n",
    "set_id = 2\n",
    "left_center, right_center = GAZE_SETS[set_id]\n",
    "\n",
    "# Render around the view in the middle of every dimension of the view grid.\n",
    "view_coord = [n // 2 for n in views.size()]\n",
    "print('view_coord:', view_coord)\n",
    "test_view = views.get(*view_coord)\n",
    "\n",
    "# Fovea camera; only its resolution is used (to place the fovea crosshair).\n",
    "fovea_cam = view.CameraParam({\n",
    "    'fov': fov_list[0],\n",
    "    'cx': 0.5,\n",
    "    'cy': 0.5,\n",
    "    'normalized': True\n",
    "}, res_list[0]).to(device.default())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _plot_crosshair(x, y, half_size=5, color=(0, 1, 0)):\n",
    "    \"\"\"Draw a '+' marker (green by default) with arms of `half_size` px at (x, y) on the current axes.\"\"\"\n",
    "    plt.plot([x - half_size, x + half_size], [y, y], color=color)\n",
    "    plt.plot([x, x], [y - half_size, y + half_size], color=color)\n",
    "\n",
    "\n",
    "def _plot_pair(left_image, right_image, figsize=(8, 4), left_mark=None, right_mark=None):\n",
    "    \"\"\"Show two images side by side in a new figure, optionally marking an (x, y) point on each.\"\"\"\n",
    "    plt.figure(figsize=figsize)\n",
    "    panels = ((121, left_image, left_mark), (122, right_image, right_mark))\n",
    "    for position, image, mark in panels:\n",
    "        plt.subplot(position)\n",
    "        img.plot(image)\n",
    "        if mark is not None:\n",
    "            _plot_crosshair(*mark)\n",
    "\n",
    "\n",
    "def plot_figures(left_images, right_images, left_center, right_center):\n",
    "    \"\"\"Plot every rendered layer of the left and right eye side by side.\n",
    "\n",
    "    Figures, in order: fovea_raw, fovea (crosshair at the image centre), mid,\n",
    "    periph and blended (crosshair at each eye's gaze point). Reads the\n",
    "    notebook globals `fovea_cam` and `res_full` to place the crosshairs.\n",
    "    \"\"\"\n",
    "    fovea_mid = ((fovea_cam.res[1] - 1) / 2, (fovea_cam.res[0] - 1) / 2)\n",
    "    full_mid_x = (res_full[1] - 1) / 2\n",
    "    full_mid_y = (res_full[0] - 1) / 2\n",
    "\n",
    "    _plot_pair(left_images['fovea_raw'], right_images['fovea_raw'])\n",
    "    _plot_pair(left_images['fovea'], right_images['fovea'],\n",
    "               left_mark=fovea_mid, right_mark=fovea_mid)\n",
    "    _plot_pair(left_images['mid'], right_images['mid'])\n",
    "    _plot_pair(left_images['periph'], right_images['periph'])\n",
    "    _plot_pair(left_images['blended'], right_images['blended'], figsize=(12, 6),\n",
    "               left_mark=(full_mid_x + left_center[0], full_mid_y + left_center[1]),\n",
    "               right_mark=(full_mid_x + right_center[0], full_mid_y + right_center[1]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Render both eyes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each eye uses the local offset [+/-EYE_OFFSET, 0, 0], mapped through test_view.trans_point.\n",
    "EYE_OFFSET = 0.03\n",
    "\n",
    "\n",
    "def render_eye(gaze_center, x_offset):\n",
    "    \"\"\"Render all GenFinal layers for one eye displaced by `x_offset` from `test_view`.\"\"\"\n",
    "    eye_pos = test_view.trans_point(\n",
    "        torch.tensor([x_offset, 0, 0], device=device.default()))\n",
    "    return gen(gaze_center, view.Trans(eye_pos, test_view.r),\n",
    "               ret_raw=True, mono_trans=test_view, shift=0)\n",
    "\n",
    "\n",
    "left_images = render_eye(left_center, -EYE_OFFSET)\n",
    "right_images = render_eye(right_center, EYE_OFFSET)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot and save results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_figures(left_images, right_images, left_center, right_center)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Files are named set<id>_<layer>_<l|r>.png.\n",
    "OUTPUT_DIR = 'output/mono_test'\n",
    "misc.create_dir(OUTPUT_DIR)\n",
    "for eye, images in (('l', left_images), ('r', right_images)):\n",
    "    for key, image in images.items():\n",
    "        img.save(image, '%s/set%d_%s_%s.png' % (OUTPUT_DIR, set_id, key, eye))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.9 64-bit ('pytorch': conda)",
   "name": "python379jvsc74a57bd0660ca2a75467d3af74a68fcc6f40bc78ab96b99ff17d2f100b5ca821fbb183f2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  },
  "orig_nbformat": 2
 },
 "nbformat": 4,
 "nbformat_minor": 2
}