{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sys\n", "import os\n", "import torch\n", "import matplotlib.pyplot as plt\n", "\n", "os.chdir('../')\n", "sys.path.append(os.getcwd())\n", "torch.cuda.set_device(1)\n", "print(\"Set CUDA:%d as current device.\" % torch.cuda.current_device())\n", "\n", "from data.spherical_view_syn import *\n", "from configs.spherical_view_syn import SphericalViewSynConfig\n", "from utils import netio\n", "from utils import img\n", "from utils import device\n", "from utils import view\n", "from components import refine\n", "\n", "\n", "os.chdir('data/us_gas_all_in_one')\n", "print('Change working directory to ', os.getcwd())\n", "torch.autograd.set_grad_enabled(False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Load Config\n", "model_path = 'fovea@nmsl-rgb_e10_fc128x4_d1-50_s32.pth'\n", "config = SphericalViewSynConfig()\n", "config.from_id(os.path.splitext(os.path.basename(model_path))[0])\n", "config.SAMPLE_PARAMS['perturb_sample'] = False\n", "config.print()\n", "\n", "# Load Dataset\n", "view_dataset = SphericalViewSynDataset(\n", " 'views.json', load_images=False, load_depths=False,\n", " color=config.COLOR, calculate_rays=False)\n", "ref_dataset = SphericalViewSynDataset(\n", " 'ref.json', load_images=False, load_depths=False,\n", " color=config.COLOR, calculate_rays=False)\n", "print('Dataset loaded.')\n", "\n", "def read_ref_images(idx):\n", " patt = 'ref/view_%04d.png'\n", " if isinstance(idx, torch.Tensor) and len(idx.size()) > 0:\n", " return img.load([patt % i for i in idx]).to(device.default())\n", " else:\n", " return img.load(patt % idx).to(device.default())\n", "\n", "indices = torch.arange(view_dataset.n_views, device=device.default()).view(\n", " view_dataset.samples)\n", "ref_indices = torch.arange(ref_dataset.n_views, device=device.default()).view(\n", " ref_dataset.samples)\n", "cam_params = view.CameraParam({\n", " \"fov\": 20,\n", " \"cx\": 0.5,\n", " \"cy\": 0.5,\n", " \"normalized\": True\n", "}, (100, 100)).to(device.default())\n", "ref_cam_params = ref_dataset.cam_params\n", "\n", "# Load Spher net\n", "net = config.create_net().to(device.default())\n", "netio.load(model_path, net)\n", "print('Net loaded.')\n", "\n", "vr_cam = view.CameraParam({\n", " 'fov': 110,\n", " 'cx': 0.5,\n", " 'cy': 0.5,\n", " 'normalized': True\n", "}, (1600, 1440))\n", "\n", "def adjust_cam(cam, vr_cam, gaze_center):\n", " fovea_offset = (\n", " (gaze_center[0]) / vr_cam.f[0].item() * cam.f[0].item(),\n", " (gaze_center[1]) / vr_cam.f[1].item() * cam.f[1].item()\n", " )\n", " cam.c[0] = cam.res[1] / 2 - fovea_offset[0]\n", " cam.c[1] = cam.res[0] / 2 - fovea_offset[1]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "center = (-187, 64)\n", "test_view_coord = list(val // 2 for val in view_dataset.samples)\n", "#test_view_coord[3] -= 1\n", "test_view_coord = tuple(test_view_coord)\n", "test_view_coord_in_ref = (\n", " test_view_coord[0] + 1, test_view_coord[1] + 1, test_view_coord[2] + 1,\n", " test_view_coord[3] + 2, test_view_coord[4] - 1,\n", ")\n", "print('test_view_coord', test_view_coord)\n", "print('test_view_coord_in_ref', test_view_coord_in_ref)\n", "print('ref_dataset.samples', ref_dataset.samples)\n", "print('ref_indices.size()', ref_indices.size())\n", "print('indices.size()', indices.size())\n", "test_view_idx = indices[test_view_coord]\n", "a = ref_indices[test_view_coord_in_ref[0] - 2, test_view_coord_in_ref[1] + 2, test_view_coord_in_ref[2],\n", " test_view_coord_in_ref[3] - 1, test_view_coord_in_ref[4] + 1][None]\n", "b = ref_indices[test_view_coord_in_ref[0] + 2, test_view_coord_in_ref[1] + 2, test_view_coord_in_ref[2],\n", " test_view_coord_in_ref[3] + 1, test_view_coord_in_ref[4] + 1][None]\n", "c = ref_indices[test_view_coord_in_ref[0] - 2, test_view_coord_in_ref[1] - 2, test_view_coord_in_ref[2],\n", " test_view_coord_in_ref[3] - 1, test_view_coord_in_ref[4] - 1][None]\n", "d = ref_indices[test_view_coord_in_ref[0] + 2, test_view_coord_in_ref[1] - 2, test_view_coord_in_ref[2],\n", " test_view_coord_in_ref[3] + 1, test_view_coord_in_ref[4] - 1][None]\n", "bound_view_idxs = torch.cat([a, b, c, d])\n", "print(bound_view_idxs)\n", "# bound_view_idxs = [\n", "# indices[center_view_coord[0], center_view_coord[1], center_view_coord[2],\n", "# center_view_coord[3] - 1, center_view_coord[4]],\n", "# indices[center_view_coord[0], center_view_coord[1], center_view_coord[2],\n", "# center_view_coord[3] + 1, center_view_coord[4]],\n", "# ]\n", "\n", "def get_guides(view_coord):\n", " coord_offset = [val // 2 - view_dataset.samples[i] // 2 for i, val in enumerate(ref_dataset.samples)]\n", " guides_coord = [\n", " [\n", " view_coord[0] - 2, view_coord[1] + 2, view_coord[2],\n", " view_coord[3] - 1, view_coord[4] + 1\n", " ],\n", " [\n", " view_coord[0] + 2, view_coord[1] + 2, view_coord[2],\n", " view_coord[3] + 1, view_coord[4] + 1\n", " ],\n", " [\n", " view_coord[0] - 2, view_coord[1] - 2, view_coord[2],\n", " view_coord[3] - 1, view_coord[4] - 1\n", " ],\n", " [\n", " view_coord[0] + 2, view_coord[1] - 2, view_coord[2],\n", " view_coord[3] + 1, view_coord[4] - 1\n", " ]\n", " ]\n", " for coord in guides_coord:\n", " for i in range(len(coord_offset)):\n", " coord[i] += coord_offset[i]\n", " ref_indices = torch.arange(0, ref_dataset.n_views,\n", " device=device.default()).view(ref_dataset.samples)\n", " guides_idx = torch.stack([\n", " ref_indices[tuple(coord)] for coord in guides_coord\n", " ])\n", " print('guides_idx:', guides_idx)\n", " guides_image = read_ref_images(guides_idx).to(device.default())\n", " guides_trans = view.Trans(ref_dataset.view_centers[guides_idx],\n", " ref_dataset.view_rots[guides_idx])\n", " return refine.GuideRefinement(guides_image, guides_trans, ref_cam_params, net)\n", "\n", "guide_refine = get_guides(test_view_coord)\n", "\n", "def gen(fovea_center, trans):\n", " adjust_cam(cam_params, vr_cam, fovea_center)\n", "\n", " fovea_rays_o, fovea_rays_d = cam_params.get_global_rays(trans) # (H_fovea, W_fovea, 3)\n", "\n", " fovea_inferred, fovea_depthmap = net(\n", " fovea_rays_o.view(-1, 3), fovea_rays_d.view(-1, 3), ret_depth=True)\n", " fovea_inferred = fovea_inferred.view(\n", " cam_params.res[0], cam_params.res[1], -1).permute(2, 0, 1) # (C, H_fovea, W_fovea)\n", " fovea_depthmap = fovea_depthmap.view(cam_params.res[0], cam_params.res[1])\n", " \n", " fovea_refined = guide_refine.refine_by_guide(fovea_inferred, fovea_depthmap,\n", " fovea_rays_o, fovea_rays_d, False)\n", "\n", " return {\n", " 'fovea_raw': fovea_inferred,\n", " 'fovea': fovea_refined,\n", " 'fovea_depth': fovea_depthmap\n", " }\n", "\n", "#adjust_cam(cam_params, vr_cam, center)\n", "trans = view.Trans(view_dataset.view_centers[test_view_idx],\n", " view_dataset.view_rots[test_view_idx])\n", "#rays_o, rays_d = cam_params.get_global_rays(trans, flatten=True)\n", "\n", "#inferred, depthmap = net(rays_o.view(-1, 3),\n", "# rays_d.view(-1, 3), ret_depth=True)\n", "#inferred = inferred.view(\n", "# cam_params.res[0], cam_params.res[1], -1).permute(2, 0, 1)\n", "#inferred = nn_f.upsample_bilinear(inferred.unsqueeze(0), scale_factor=2)[0]\n", "#depthmap = depthmap.view(cam_params.res[0], cam_params.res[1])\n", "#depthmap = nn_f.upsample_bilinear(depthmap[None, None, :, :], scale_factor=2)[0, 0]\n", "#gt = view_dataset.view_images[test_view_idx]\n", "#bounds_img = read_ref_images(bound_view_idxs)\n", "#bounds_o = ref_dataset.view_centers[bound_view_idxs]\n", "#bounds_r = ref_dataset.view_rots[bound_view_idxs]\n", "images = gen(center, trans)\n", "inferred = images['fovea_raw']\n", "refined = images['fovea']\n", "\n", "\n", "#guide_refine = refine.GuideRefinement(bounds_img, view.Trans(bounds_o, bounds_r), ref_cam_params, net)\n", "#refined = guide_refine.refine_by_guide(inferred, depthmap, rays_o, rays_d, False)\n", "\n", "# warped = [nn_f.grid_sample(bounds_img[i], bounds_warp[i])\n", "# for i in range(len(bounds_warp))]\n", "# warped_inferred = [nn_f.grid_sample(\n", "# bounds_inferred[i], bounds_warp[i]) for i in range(len(bounds_warp))]\n", "\n", "fig = plt.figure(figsize=(12, 3))\n", "plt.set_cmap('Greys_r')\n", "plt.subplot(1, 3, 1)\n", "img.plot(inferred)\n", "plt.subplot(1, 3, 2)\n", "img.plot(refined)\n", "#plt.subplot(1, 3, 3)\n", "#img.plot(gt)\n", "plt.show()\n", "\n", "\n", "def plot_image_matrices(center_image, ref_images):\n", " if len(ref_images) == 2:\n", " plt.figure(figsize=(12, 4))\n", " plt.set_cmap('Greys_r')\n", " plt.subplot(1, 3, 1)\n", " img.plot(ref_images[0])\n", " plt.subplot(1, 3, 3)\n", " img.plot(ref_images[1])\n", " if center_image != None:\n", " plt.subplot(1, 3, 2)\n", " img.plot(center_image)\n", " elif len(ref_images) == 4:\n", " plt.figure(figsize=(12, 12))\n", " plt.set_cmap('Greys_r')\n", " plt.subplot(3, 3, 1)\n", " img.plot(ref_images[0])\n", " plt.subplot(3, 3, 3)\n", " img.plot(ref_images[1])\n", " plt.subplot(3, 3, 7)\n", " img.plot(ref_images[2])\n", " plt.subplot(3, 3, 9)\n", " img.plot(ref_images[3])\n", " if center_image != None:\n", " plt.subplot(3, 3, 5)\n", " img.plot(center_image)\n", " plt.show()\n", "\n", "\n", "#plot_image_matrices(input, warped_inferred)\n", "plot_image_matrices(None, bounds_img)\n", "# plot_image_matrices(torch.cat(warped[0:3], 1) if len(\n", "# warped) >= 3 else torch.cat(warped + [torch.zeros_like(warped[0])], 1), warped)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3.7.9 64-bit ('pytorch': conda)", "name": "python379jvsc74a57bd0660ca2a75467d3af74a68fcc6f40bc78ab96b99ff17d2f100b5ca821fbb183f2" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.9" }, "orig_nbformat": 2 }, "nbformat": 4, "nbformat_minor": 2 }