{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Set CUDA:0 as current device.\n", "Change working directory to /e/dengnc/deeplightfield/data/sp_view_syn_2020.12.31_fovea\n" ] }, { "data": { "text/plain": "" }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import sys\n", "import os\n", "import torch\n", "import matplotlib.pyplot as plt\n", "import torchvision.transforms.functional as trans_f\n", "import torch.nn.functional as nn_f\n", "\n", "sys.path.append(os.path.abspath(sys.path[0] + '/../../'))\n", "torch.cuda.set_device(0)\n", "print(\"Set CUDA:%d as current device.\" % torch.cuda.current_device())\n", "\n", "from deeplightfield.data.spherical_view_syn import *\n", "from deeplightfield.msl_net import MslNet\n", "from deeplightfield.configs.spherical_view_syn import SphericalViewSynConfig\n", "from deeplightfield.my import netio\n", "from deeplightfield.my import util\n", "from deeplightfield.my import device\n", "from deeplightfield.my import view\n", "\n", "\n", "os.chdir(sys.path[0] + '/../data/sp_view_syn_2020.12.31_fovea')\n", "print('Change working directory to ', os.getcwd())\n", "torch.autograd.set_grad_enabled(False)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "==== Config msl_coarse_gray1 ====\n", "Net type: msl\n", "Encode dim: 10\n", "Full-connected network parameters: {'nf': 64, 'n_layers': 12, 'skips': []}\n", "Sample parameters {'depth_range': (1, 20), 'n_samples': 16, 'perturb_sample': False, 'spherical': True}\n", "Loss mse\n", "==========================\n", "View dataset loaded.\n", "Ref dataset loaded.\n", "Load net from msl_coarse_gray1_b4096/model-epoch_500.pth ...\n", "Net loaded.\n" ] } ], "source": [ "# Load Config\n", "config = SphericalViewSynConfig()\n", "config.load_by_name('msl_coarse_gray1')\n", "config.SAMPLE_PARAMS['spherical'] = True\n", "config.SAMPLE_PARAMS['perturb_sample'] = False\n", "config.print()\n", "\n", "# Load Dataset\n", "view_dataset = SphericalViewSynDataset('train.json', load_images=True, load_depths=False, gray=True)\n", "print('View dataset loaded.')\n", "def read_ref_images(idx):\n", " patt= 'ref/train/view_%04d.png'\n", " if isinstance(idx, torch.Tensor) and len(idx.size()) > 0:\n", " return trans_f.rgb_to_grayscale(util.ReadImageTensor([patt % i for i in idx]))\n", " else:\n", " return trans_f.rgb_to_grayscale(util.ReadImageTensor(patt % idx))\n", "print('Ref dataset loaded.')\n", "\n", "indices = torch.arange(view_dataset.n_views, device=device.GetDevice()).view(view_dataset.samples)\n", "cam_params = view_dataset.cam_params\n", "lr_cam_params = view.CameraParam({\n", " \"fov\" : 10,\n", " \"cx\" : 25.0,\n", " \"cy\" : 25.0\n", "}, (50, 50)).to(device.GetDevice())\n", "ref_cam_params = view.CameraParam({\n", " \"fx\" : 519.615251596838,\n", " \"fy\" : -519.615251596838,\n", " \"cx\" : 300.0,\n", " \"cy\" : 300.0\n", "}, (600, 600)).to(device.GetDevice())\n", "gt_images = view_dataset.view_images\n", "gt_depths = view_dataset.view_depths\n", "rays_o = view_dataset.rays_o\n", "rays_d = view_dataset.rays_d\n", "views_o = view_dataset.view_centers\n", "views_r = view_dataset.view_rots\n", "\n", "# Load Spher net\n", "net = MslNet(config.FC_PARAMS, config.SAMPLE_PARAMS, gray=True, encode_to_dim=config.N_ENCODE_DIM).to(device.GetDevice())\n", "netio.LoadNet('msl_coarse_gray1_b4096/model-epoch_500.pth', net)\n", "print('Net loaded.')\n", "\n", "def plot_point_cloud(pcloud, colors, ax=None):\n", " if not ax:\n", " plt.figure(figsize=(12, 12))\n", " ax = plt.gca(projection='3d')\n", " points3 = pcloud.flatten(0, -2).cpu().numpy()\n", " colors = colors.permute(1, 2, 0).flatten(0, 1).expand(-1, 3).cpu().numpy()\n", " ax.scatter(points3[:, 0], points3[:, 2], points3[:, 1], color=colors, s=0.3)\n", " util.save_2d_tensor('points.csv', points3)\n", " util.save_2d_tensor('colors.csv', colors)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def get_warp (rays_o, rays_d, depthmap, tgt_o, tgt_r, tgt_cam):\n", " pcloud = rays_o + rays_d * depthmap[..., None]\n", " pcloud_in_tgt = view.trans_point(\n", " pcloud, tgt_o, tgt_r, inverse=True)[None, ...]\n", " pixel_positions = tgt_cam.proj(pcloud_in_tgt)\n", " pixel_positions[..., 0] /= ref_cam_params.res[1] * 0.5\n", " pixel_positions[..., 1] /= ref_cam_params.res[0] * 0.5\n", " pixel_positions -= 1\n", " return pixel_positions\n", "\n", "\n", "center_view_coord = tuple(val // 2 for val in view_dataset.samples)\n", "center_view_idx = indices[center_view_coord]\n", "bound_view_idxs = [\n", " indices[center_view_coord[0] - 1, center_view_coord[1] + 1, center_view_coord[2],\n", " center_view_coord[3] - 1, center_view_coord[4] + 1],\n", " indices[center_view_coord[0] + 1, center_view_coord[1] + 1, center_view_coord[2],\n", " center_view_coord[3] + 1, center_view_coord[4] + 1],\n", " indices[center_view_coord[0] - 1, center_view_coord[1] - 1, center_view_coord[2],\n", " center_view_coord[3] - 1, center_view_coord[4] - 1],\n", " indices[center_view_coord[0] + 1, center_view_coord[1] - 1, center_view_coord[2],\n", " center_view_coord[3] + 1, center_view_coord[4] - 1]\n", "]\n", "#bound_view_idxs = [\n", "# indices[center_view_coord[0], center_view_coord[1], center_view_coord[2],\n", "# center_view_coord[3] - 1, center_view_coord[4]],\n", "# indices[center_view_coord[0], center_view_coord[1], center_view_coord[2],\n", "# center_view_coord[3] + 1, center_view_coord[4]],\n", "#]\n", "o = views_o[center_view_idx]\n", "r = views_r[center_view_idx]\n", "center_rays_o = rays_o[center_view_idx]\n", "center_rays_d = rays_d[center_view_idx]\n", "lr_center_rays_o = o[None, None, :].expand(lr_cam_params.res[0], lr_cam_params.res[1], -1)\n", "lr_center_rays_d = view.trans_vector(lr_cam_params.get_local_rays(), r)\n", "input, depthmap_ = net(center_rays_o, center_rays_d, ret_depth=True)\n", "lr_input, lr_depthmap = net(lr_center_rays_o, lr_center_rays_d, ret_depth=True)\n", "print(lr_input.size(), lr_depthmap.size())\n", "lr_input = nn_f.upsample(lr_input[None, ...], scale_factor=2, mode='bicubic')[0]\n", "lr_depthmap = nn_f.upsample(lr_depthmap[None, None, ...], scale_factor=2, mode='bicubic')[0, 0]\n", "gt = gt_images[center_view_idx]\n", "bounds_img = [read_ref_images(idx).to(device.GetDevice())\n", " for idx in bound_view_idxs]\n", "bounds_o = [views_o[idx] for idx in bound_view_idxs]\n", "bounds_r = [views_r[idx] for idx in bound_view_idxs]\n", "bounds_rays_o = [\n", " views_o[idx][None, None, :].expand(ref_cam_params.res[0], ref_cam_params.res[1], -1)\n", " for idx in bound_view_idxs\n", "]\n", "bounds_rays_d = [\n", " view.trans_vector(ref_cam_params.get_local_rays(), views_r[idx])\n", " for idx in bound_view_idxs\n", "]\n", "bounds_inferred = [\n", " net(bounds_rays_o[i], bounds_rays_d[i])[None, ...]\n", " for i in range(len(bounds_img))\n", "]\n", "bounds_diff = [\n", " (bounds_img[i] - bounds_inferred[i] + 1e-5) / (bounds_inferred[i] + 1e-5)\n", " for i in range(len(bounds_img))\n", "]\n", "bounds_warp = [\n", " get_warp(center_rays_o, center_rays_d, depthmap_, bounds_o[i], bounds_r[i], ref_cam_params)\n", " for i in range(len(bounds_img))\n", "]\n", "bounds_warp_lr = [\n", " get_warp(center_rays_o, center_rays_d, lr_depthmap, bounds_o[i], bounds_r[i], ref_cam_params)\n", " for i in range(len(bounds_img))\n", "]\n", "\n", "def refine(input, bounds_diff, warps):\n", " warped_diff = [nn_f.grid_sample(bounds_diff[i], warps[i]) for i in range(len(warps))]\n", " avg_diff = sum(warped_diff) / len(warps)\n", " return input * (1 + avg_diff)\n", "\n", "warped = [nn_f.grid_sample(bounds_img[i], bounds_warp[i]) for i in range(len(bounds_warp))]\n", "warped_inferred = [nn_f.grid_sample(bounds_inferred[i], bounds_warp[i]) for i in range(len(bounds_warp))]\n", "\n", "input_refined = refine(input, bounds_diff, bounds_warp)\n", "input_refined_lr = refine(lr_input, bounds_diff, bounds_warp_lr)\n", "\n", "fig = plt.figure(figsize=(12, 3))\n", "plt.set_cmap('Greys_r')\n", "plt.subplot(1, 4, 1)\n", "util.PlotImageTensor(input)\n", "plt.subplot(1, 4, 2)\n", "util.PlotImageTensor(input_refined)\n", "plt.subplot(1, 4, 3)\n", "util.PlotImageTensor(input_refined_lr)\n", "plt.subplot(1, 4, 4)\n", "util.PlotImageTensor(gt)\n", "plt.show()\n", "\n", "def plot_image_matrices(center_image, ref_images):\n", " if len(ref_images) == 2:\n", " plt.figure(figsize=(12, 4))\n", " plt.set_cmap('Greys_r')\n", " plt.subplot(1, 3, 1)\n", " util.PlotImageTensor(ref_images[0])\n", " plt.subplot(1, 3, 3)\n", " util.PlotImageTensor(ref_images[1])\n", " plt.subplot(1, 3, 2)\n", " util.PlotImageTensor(center_image)\n", " elif len(ref_images) == 4:\n", " plt.figure(figsize=(12, 12))\n", " plt.set_cmap('Greys_r')\n", " plt.subplot(3, 3, 1)\n", " util.PlotImageTensor(ref_images[0])\n", " plt.subplot(3, 3, 3)\n", " util.PlotImageTensor(ref_images[1])\n", " plt.subplot(3, 3, 7)\n", " util.PlotImageTensor(ref_images[2])\n", " plt.subplot(3, 3, 9)\n", " util.PlotImageTensor(ref_images[3])\n", " plt.subplot(3, 3, 5)\n", " util.PlotImageTensor(center_image)\n", " plt.show()\n", "\n", "plot_image_matrices(input, warped_inferred)\n", "plot_image_matrices(gt, bounds_img)\n", "plot_image_matrices(torch.cat(warped[0:3], 1) if len(warped) >= 3 else torch.cat(warped + [torch.zeros_like(warped[0])], 1), warped)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6-final" }, "orig_nbformat": 2 }, "nbformat": 4, "nbformat_minor": 2 }