{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Foveated view synthesis with MslNet\n",
    "\n",
    "Renders one view of the `sp_view_syn_2021.01.04_all_in_one` scene as three layers (fovea, mid, periphery) using two trained `MslNet` models, blends the layers with `Foveation` into a 1440x1440 image, and writes the result to `output/`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Set CUDA:0 as current device.\n",
      "Change working directory to /e/dengnc/deeplightfield/data/sp_view_syn_2021.01.04_all_in_one\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "import os\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import torchvision.transforms.functional as trans_f\n",
    "\n",
    "# Make the repository root importable. Assumes sys.path[0] is the notebook's\n",
    "# directory; the membership check keeps re-runs from appending duplicates.\n",
    "ROOT_DIR = os.path.abspath(sys.path[0] + '/../../')\n",
    "if ROOT_DIR not in sys.path:\n",
    "    sys.path.append(ROOT_DIR)\n",
    "torch.cuda.set_device(0)\n",
    "print(\"Set CUDA:%d as current device.\" % torch.cuda.current_device())\n",
    "\n",
    "from deeplightfield.data.spherical_view_syn import *\n",
    "from deeplightfield.msl_net import MslNet\n",
    "from deeplightfield.configs.spherical_view_syn import SphericalViewSynConfig\n",
    "from deeplightfield.my import netio\n",
    "from deeplightfield.my import util\n",
    "from deeplightfield.my import device\n",
    "from deeplightfield.my import view\n",
    "from deeplightfield.my.simple_perf import SimplePerf\n",
    "from deeplightfield.my.foveation import Foveation\n",
    "\n",
    "\n",
    "# All data and model paths used below are relative to this scene directory.\n",
    "os.chdir(sys.path[0] + '/../data/sp_view_syn_2021.01.04_all_in_one')\n",
    "print('Change working directory to ', os.getcwd())\n",
    "# Inference only: disable autograd globally for the rest of the notebook.\n",
    "torch.autograd.set_grad_enabled(False)\n",
    "# True selects the grayscale models and grayscale reference images; False the RGB ones.\n",
    "GRAY = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load models, datasets and cameras"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==== Config msl_coarse_rgb1 ====\n",
      "Net type: msl\n",
      "Encode dim: 10\n",
      "Full-connected network parameters: {'nf': 64, 'n_layers': 12, 'skips': []}\n",
      "Sample parameters {'spherical': True, 'depth_range': (1, 20), 'n_samples': 16, 'perturb_sample': False, 'lindisp': True, 'inverse_r': True}\n",
      "==========================\n",
      "Load net from fovea@msl_coarse_rgb1.pth ...\n",
      "==== Config msl_rgb_periph ====\n",
      "Net type: msl\n",
      "Encode dim: 10\n",
      "Full-connected network parameters: {'nf': 64, 'n_layers': 8, 'skips': []}\n",
      "Sample parameters {'spherical': True, 'depth_range': (1, 50), 'n_samples': 4, 'perturb_sample': False, 'lindisp': True, 'inverse_r': True}\n",
      "==========================\n",
      "Load net from periph@msl_rgb_periph.pth ...\n",
      "Dataset loaded.\n"
     ]
    }
   ],
   "source": [
    "def load_net(name):\n",
    "    \"\"\"Build an MslNet from a named config and load its trained weights.\n",
    "\n",
    "    `name` has the form '<prefix>@<config name>': the part after '@' selects\n",
    "    the config, and `name + '.pth'` is the weight file that gets loaded.\n",
    "    \"\"\"\n",
    "    config = SphericalViewSynConfig()\n",
    "    config.load_by_name(name.split('@')[1])\n",
    "    # Use spherical sampling and disable sample perturbation for inference.\n",
    "    config.SAMPLE_PARAMS['spherical'] = True\n",
    "    config.SAMPLE_PARAMS['perturb_sample'] = False\n",
    "    config.print()\n",
    "    net = MslNet(config.FC_PARAMS, config.SAMPLE_PARAMS, GRAY,\n",
    "                 config.N_ENCODE_DIM).to(device.GetDevice())\n",
    "    netio.LoadNet(name + '.pth', net)\n",
    "    return net\n",
    "\n",
    "\n",
    "def read_ref_images(idx):\n",
    "    \"\"\"Read reference view image(s) from 'ref/view_%04d.png'.\n",
    "\n",
    "    `idx` is either a single index or a non-scalar tensor of indices (one file\n",
    "    per element). Images are converted to grayscale only when GRAY is set, so\n",
    "    they match the channel count of the loaded models.\n",
    "    \"\"\"\n",
    "    patt = 'ref/view_%04d.png'\n",
    "    if isinstance(idx, torch.Tensor) and len(idx.size()) > 0:\n",
    "        images = util.ReadImageTensor([patt % i for i in idx])\n",
    "    else:\n",
    "        images = util.ReadImageTensor(patt % idx)\n",
    "    return trans_f.rgb_to_grayscale(images) if GRAY else images\n",
    "\n",
    "\n",
    "if GRAY:\n",
    "    fovea_net = load_net('fovea@msl_coarse_gray1')\n",
    "    periph_net = load_net('periph@msl_gray_periph')\n",
    "else:\n",
    "    fovea_net = load_net('fovea@msl_coarse_rgb1')\n",
    "    periph_net = load_net('periph@msl_rgb_periph')\n",
    "\n",
    "# Load datasets (images, depths and rays are not loaded; only view metadata is used)\n",
    "view_dataset = SphericalViewSynDataset(\n",
    "    'train.json', load_images=False, load_depths=False,\n",
    "    gray=GRAY, calculate_rays=False)\n",
    "ref_dataset = SphericalViewSynDataset(\n",
    "    'ref.json', load_images=False, load_depths=False,\n",
    "    gray=GRAY, calculate_rays=False)\n",
    "print('Dataset loaded.')\n",
    "\n",
    "# One camera per foveation layer: fovea, mid, periphery.\n",
    "fov_list = [10, 60, 110]\n",
    "res_list = [(64, 64), (256, 256), (256, 256)]\n",
    "cams = [\n",
    "    view.CameraParam({\n",
    "        \"fov\": fov,\n",
    "        \"cx\": 0.5,\n",
    "        \"cy\": 0.5,\n",
    "        \"normalized\": True\n",
    "    }, res).to(device.GetDevice())\n",
    "    for fov, res in zip(fov_list, res_list)\n",
    "]\n",
    "fovea_cam, mid_cam, periph_cam = cams\n",
    "ref_cam_params = ref_dataset.cam_params\n",
    "\n",
    "# View indices arranged in the dataset's sample-grid shape.\n",
    "indices = torch.arange(view_dataset.n_views,\n",
    "                       device=device.GetDevice()).view(view_dataset.samples)\n",
    "ref_indices = torch.arange(\n",
    "    ref_dataset.n_views, device=device.GetDevice()).view(ref_dataset.samples)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Render one view\n",
    "\n",
    "Pick the view one step past the centre of the training grid along its first axis, and generate global rays for each layer. The mid and periphery layers are both inferred by `periph_net` in a single batch, so their ray grids must have the same shape."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "view_coord = [val // 2 for val in view_dataset.samples]\n",
    "# Step one past the centre on the first axis, clamped so small grids stay in range.\n",
    "view_coord[0] = min(view_coord[0] + 1, view_dataset.samples[0] - 1)\n",
    "print(view_coord, indices.size())\n",
    "view_idx = indices[tuple(view_coord)]\n",
    "view_o = view_dataset.view_centers[view_idx]  # (3)\n",
    "view_r = view_dataset.view_rots[view_idx]  # (3, 3)\n",
    "foveation = Foveation(fov_list, (1440, 1440), device=device.GetDevice())\n",
    "\n",
    "perf = SimplePerf(True, True)\n",
    "\n",
    "fovea_rays_o, fovea_rays_d = fovea_cam.get_global_rays(view_o, view_r)  # (H_fovea, W_fovea, 3)\n",
    "mid_rays_o, mid_rays_d = mid_cam.get_global_rays(view_o, view_r)  # (H_mid, W_mid, 3)\n",
    "periph_rays_o, periph_rays_d = periph_cam.get_global_rays(view_o, view_r)  # (H_periph, W_periph, 3)\n",
    "assert mid_rays_o.shape == periph_rays_o.shape, \\\n",
    "    'mid and periph layers are batched together and must share a resolution'\n",
    "mid_periph_rays_o = torch.stack([mid_rays_o, periph_rays_o], 0)  # (2, H_mid, W_mid, 3)\n",
    "mid_periph_rays_d = torch.stack([mid_rays_d, periph_rays_d], 0)  # (2, H_mid, W_mid, 3)\n",
    "perf.Checkpoint('Get rays')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Infer layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "perf_infer = SimplePerf(True, True)\n",
    "\n",
    "fovea_inferred = fovea_net(fovea_rays_o.view(-1, 3), fovea_rays_d.view(-1, 3)).view(\n",
    "    fovea_cam.res[0], fovea_cam.res[1], -1).permute(2, 0, 1)  # (C, H_fovea, W_fovea)\n",
    "perf_infer.Checkpoint('Infer fovea')\n",
    "\n",
    "# Mid and periphery rays go through periph_net as one flattened batch.\n",
    "periph_mid_inferred = periph_net(mid_periph_rays_o.view(-1, 3),\n",
    "                                 mid_periph_rays_d.view(-1, 3))\n",
    "periph_mid_inferred = periph_mid_inferred.view(\n",
    "    2, mid_cam.res[0], mid_cam.res[1], -1).permute(0, 3, 1, 2)  # (2, C, H_mid, W_mid)\n",
    "mid_inferred = periph_mid_inferred[0]\n",
    "periph_inferred = periph_mid_inferred[1]\n",
    "perf_infer.Checkpoint('Infer mid & periph')\n",
    "\n",
    "perf.Checkpoint('Infer')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Blend, plot and save"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Foveation.synthesis takes the layers in fov_list order, each with a leading batch axis.\n",
    "blended = foveation.synthesis([\n",
    "    fovea_inferred[None, ...],\n",
    "    mid_inferred[None, ...],\n",
    "    periph_inferred[None, ...]\n",
    "])\n",
    "\n",
    "perf.Checkpoint('Blend')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(12, 4))\n",
    "plt.set_cmap('Greys_r')\n",
    "layers = [('Fovea', fovea_inferred), ('Mid', mid_inferred), ('Periphery', periph_inferred)]\n",
    "for pos, (title, image) in enumerate(layers, start=1):\n",
    "    plt.subplot(1, 3, pos)\n",
    "    util.PlotImageTensor(image)\n",
    "    plt.title(title)\n",
    "\n",
    "plt.figure(figsize=(12, 12))\n",
    "util.PlotImageTensor(blended)\n",
    "plt.title('Blended')\n",
    "\n",
    "util.CreateDirIfNeed('output')\n",
    "util.WriteImageTensor(blended, 'output/blended_%04d.png' % view_idx.item())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  },
  "orig_nbformat": 2
 },
 "nbformat": 4,
 "nbformat_minor": 2
}