{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sys\n", "import os\n", "sys.path.append(os.path.abspath(sys.path[0] + '/../../'))\n", "\n", "import torch\n", "import math\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "from deep_view_syn.my import util\n", "from deep_view_syn.msl_net import *\n", "\n", "# Select device\n", "torch.cuda.set_device(2)\n", "print(\"Set CUDA:%d as current device.\" % torch.cuda.current_device())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Test Ray-Sphere Intersection & Cartesian-Spherical Conversion" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def PlotSphere(ax, r):\n", " u, v = np.mgrid[0:2 * np.pi:50j, 0:np.pi:20j]\n", " x = np.cos(u) * np.sin(v) * r\n", " y = np.sin(u) * np.sin(v) * r\n", " z = np.cos(v) * r\n", " ax.plot_surface(x, y, z, rstride=1, cstride=1,\n", " color='b', linewidth=0.5, alpha=0.3)\n", "\n", "\n", "def PlotPlane(ax, r):\n", " # 二元函数定义域平面\n", " x = np.linspace(-r, r, 3)\n", " y = np.linspace(-r, r, 3)\n", " X, Y = np.meshgrid(x, y)\n", " ax.plot_wireframe(X, Y, X * 0, color='g', linewidth=1)\n", "\n", "\n", "p = torch.tensor([[0.0, 0.0, 0.0]])\n", "v = torch.tensor([[0.0, -1.0, 1.0]])\n", "r = torch.tensor([[2.5]])\n", "v = v / torch.norm(v) * r * 2\n", "p_on_sphere_ = RaySphereIntersect(p, v, r)[0]\n", "print(p_on_sphere_)\n", "print(p_on_sphere_.norm())\n", "spher_coord = RayToSpherical(p, v, r)\n", "print(spher_coord[..., 1:3].rad2deg())\n", "p_on_sphere = util.SphericalToCartesian(spher_coord)[0]\n", "\n", "fig = plt.figure(figsize=(6, 6))\n", "ax = fig.gca(projection='3d')\n", "plt.xlabel('x')\n", "plt.ylabel('z')\n", "\n", "PlotPlane(ax, r.item())\n", "PlotSphere(ax, r[0, 0].item())\n", "\n", "ax.scatter([0], [0], [0], color=\"g\", s=10) # Center\n", "ax.scatter([p_on_sphere[0, 0].item()],\n", " [p_on_sphere[0, 2].item()],\n", " [p_on_sphere[0, 1].item()],\n", " color=\"r\", s=10) # Ray position\n", "ax.scatter([p_on_sphere_[0, 0].item()],\n", " [p_on_sphere_[0, 2].item()],\n", " [p_on_sphere_[0, 1].item()],\n", " color=\"y\", s=10) # Ray position\n", "\n", "p_ = p + v\n", "ax.plot([p[0, 0].item(), p_[0, 0].item()],\n", " [p[0, 2].item(), p_[0, 2].item()],\n", " [p[0, 1].item(), p_[0, 1].item()],\n", " color=\"r\")\n", "\n", "ax.plot([p_on_sphere_[0, 0].item(), p_on_sphere_[0, 0].item()],\n", " [p_on_sphere_[0, 2].item(), p_on_sphere_[0, 2].item()],\n", " [0, p_on_sphere_[0, 1].item()], color=\"k\", linestyle='--', linewidth=0.5)\n", "\n", "ax.plot([p_on_sphere_[0, 0].item(), 0],\n", " [p_on_sphere_[0, 2].item(), 0],\n", " [0, 0],\n", " linewidth=0.5, linestyle=\"--\", color=\"k\")\n", "\n", "ax.plot([p_on_sphere_[0, 0].item(), 0],\n", " [p_on_sphere_[0, 2].item(), 0],\n", " [p_on_sphere_[0, 1], 0],\n", " linewidth=0.5, linestyle=\"--\", color=\"k\")\n", "\n", "ax.set_xlim(-r.item(), r.item())\n", "ax.set_ylim(-r.item(), r.item())\n", "ax.set_zlim(-r.item(), r.item())\n", "\n", "plt.show()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Test Dataset Loader & View-Spherical Transform" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from deep_view_syn.data.spherical_view_syn import FastSphericalViewSynDataset\n", "from deep_view_syn.data.spherical_view_syn import FastDataLoader\n", "\n", "DATA_DIR = '../data/sp_view_syn_2020.12.28'\n", "TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'\n", "\n", "dataset = 
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Test Dataset Loader & View-Spherical Transform" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from deep_view_syn.data.spherical_view_syn import FastSphericalViewSynDataset\n", "from deep_view_syn.data.spherical_view_syn import FastDataLoader\n", "\n", "DATA_DIR = '../data/sp_view_syn_2020.12.28'\n", "TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'\n", "\n", "dataset = FastSphericalViewSynDataset(TRAIN_DATA_DESC_FILE)\n", "dataset.set_patch_size((64, 64))\n", "data_loader = FastDataLoader(dataset=dataset, batch_size=4, shuffle=False, drop_last=False)\n", "print(len(dataset))\n", "plt.figure()\n", "i = 0\n", "for indices, patches, rays_o, rays_d in data_loader:\n", "    print(i, patches.size(), rays_o.size(), rays_d.size())\n", "    for idx in range(len(indices)):\n", "        plt.subplot(4, 4, i + 1)\n", "        util.PlotImageTensor(patches[idx])\n", "        i += 1\n", "    if i == 16:\n", "        break\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from deep_view_syn.data.spherical_view_syn import SphericalViewSynDataset\n", "\n", "DATA_DIR = '../data/sp_view_syn_2020.12.26'\n", "TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'\n", "DEPTH_RANGE = (1, 10)\n", "N_DEPTH_LAYERS = 10\n", "\n", "def _GetSphereLayers(depth_range: Tuple[float, float], n_layers: int) -> torch.Tensor:\n", "    diopter_range = (1 / depth_range[1], 1 / depth_range[0])\n", "    step = (diopter_range[1] - diopter_range[0]) / (n_layers - 1)\n", "    depths = [1e5]\n", "    depths += [1 / (diopter_range[0] + step * i) for i in range(n_layers)]\n", "    return torch.tensor(depths, device=device.GetDevice()).view(-1, 1)\n", "\n", "train_dataset = SphericalViewSynDataset(TRAIN_DATA_DESC_FILE)\n", "train_data_loader = torch.utils.data.DataLoader(\n", "    dataset=train_dataset,\n", "    batch_size=4,\n", "    num_workers=8,\n", "    pin_memory=True,\n", "    shuffle=True,\n", "    drop_last=False)\n", "print(len(train_data_loader))\n", "\n", "print(\"view_res\", train_dataset.view_res)\n", "print(\"cam_params\", train_dataset.cam_params)\n", "\n", "msl_net = MslNet(train_dataset.cam_params,\n", "                 _GetSphereLayers(DEPTH_RANGE, N_DEPTH_LAYERS),\n", "                 train_dataset.view_res).to(device.GetDevice())\n", "print(\"sphere layers\", msl_net.rendering.sphere_layers)\n", "\n", "p = None\n", "v = None\n", "centers = None\n", "plt.figure(figsize=(6, 6))\n", "for _, view_images, ray_positions, ray_directions in train_data_loader:\n", "    p = ray_positions\n", "    v = ray_directions\n", "    plt.subplot(2, 2, 1)\n", "    util.PlotImageTensor(view_images[0])\n", "    plt.subplot(2, 2, 2)\n", "    util.PlotImageTensor(view_images[1])\n", "    plt.subplot(2, 2, 3)\n", "    util.PlotImageTensor(view_images[2])\n", "    plt.subplot(2, 2, 4)\n", "    util.PlotImageTensor(view_images[3])\n", "    break\n", "# Camera center of each view: assuming ray_positions stores one origin per ray and all rays\n", "# of a view share that origin, take the first ray's origin per view\n", "centers = p[:, 0, :].cpu().numpy()\n", "p_ = util.SphericalToCartesian(RayToSpherical(p.flatten(0, 1), v.flatten(0, 1),\n", "                                              torch.tensor([[1]], device=device.GetDevice()))) \\\n", "    .view(4, train_dataset.view_res[0], train_dataset.view_res[1], 3)\n", "v = v.view(4, train_dataset.view_res[0], train_dataset.view_res[1], 3)[:, 0::50, 0::50, :].flatten(1, 2).cpu().numpy()\n", "p_ = p_[:, 0::50, 0::50, :].flatten(1, 2).cpu().numpy()\n", "\n", "fig = plt.figure(figsize=(6, 6))\n", "ax = fig.gca(projection='3d')\n", "plt.xlabel('x')\n", "plt.ylabel('z')\n", "\n", "PlotSphere(ax, 1)\n", "\n", "ax.scatter([0], [0], [0], color=\"k\", s=10)  # Center\n", "\n", "colors = ['r', 'g', 'b', 'y']\n", "for i in range(4):\n", "    ax.scatter(p_[i, :, 0], p_[i, :, 2], p_[i, :, 1], color=colors[i], s=3)\n", "    for j in range(p_.shape[1]):\n", "        ax.plot([centers[i, 0], centers[i, 0] + v[i, j, 0]],\n", "                [centers[i, 2], centers[i, 2] + v[i, j, 2]],\n", "                [centers[i, 1], centers[i, 1] + v[i, j, 1]],\n", "                color=colors[i], linewidth=0.5, alpha=0.6)\n", "\n", "ax.set_xlim(-1, 1)\n", "ax.set_ylim(-1, 1)\n", "ax.set_zlim(-1, 1)\n", "\n", "plt.show()\n" ] },
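{ "cell_type": "markdown", "metadata": {}, "source": [ "A quick sanity check on `_GetSphereLayers`: the sphere radii are spaced uniformly in diopters (inverse depth) between `1 / depth_range[1]` and `1 / depth_range[0]`, with an extra layer at `1e5` prepended as a far background sphere. The cell below is a minimal NumPy-only sketch of the same computation (no project code, no device handling), so the expected radii for `DEPTH_RANGE = (1, 10)` and `N_DEPTH_LAYERS = 10` can be checked by eye." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Minimal sketch of the _GetSphereLayers computation; the real function additionally\n", "# builds a torch tensor on the training device and reshapes it to (n_layers + 1, 1).\n", "layer_diopters = np.linspace(1 / 10, 1 / 1, 10)  # uniform in inverse depth\n", "layer_depths = np.concatenate(([1e5], 1 / layer_diopters))  # prepend the far background layer\n", "print(layer_depths)  # expected: 1e5, 10, 5, 3.33, 2.5, 2, 1.67, 1.43, 1.25, 1.11, 1\n" ] },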
Sampler" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from deep_view_syn.data.spherical_view_syn import SphericalViewSynDataset\n", "\n", "DATA_DIR = '../data/sp_view_syn_2020.12.29_finetrans'\n", "TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'\n", "SAMPLE_PARAMS = {\n", " 'depth_range': (1, 5),\n", " 'n_samples': 5,\n", " 'perturb_sample': False\n", "}\n", "\n", "train_dataset = SphericalViewSynDataset(TRAIN_DATA_DESC_FILE)\n", "train_data_loader = torch.utils.data.DataLoader(\n", " dataset=train_dataset,\n", " batch_size=1,\n", " num_workers=8,\n", " pin_memory=True,\n", " shuffle=True,\n", " drop_last=False)\n", "print(len(train_data_loader))\n", "\n", "print(\"view_res\", train_dataset.view_res)\n", "print(\"cam_params\", train_dataset.cam_params)\n", "\n", "sampler = Sampler(**SAMPLE_PARAMS)\n", "\n", "fig = plt.figure(figsize=(12, 12))\n", "ax = fig.gca(projection='3d')\n", "plt.xlabel('x')\n", "plt.ylabel('z')\n", "\n", "i = 0\n", "selector: np.ndarray = np.array([j for j in range(65536)])\n", "selector = selector.reshape(256, 256)[::30, ::30]\n", "selector = selector.flatten()\n", "for _, pixels, p, v in train_data_loader:\n", " p = p.to(device.GetDevice())\n", " v = v.to(device.GetDevice())\n", " p_ = sampler(p, v)[0].squeeze().cpu().numpy()[selector]\n", " pixels_ = pixels.squeeze().permute(1, 2, 0).flatten(0, 1).cpu().numpy()[selector]\n", " for j in range(p_.shape[0]):\n", " #ax.plot(p_[j, :, 0], p_[j, :, 2], p_[j, :, 1], color=pixels_[j], linewidth=0.2)#, s=0.3)\n", " ax.scatter(p_[j, :, 0], p_[j, :, 2], p_[j, :, 1], color=pixels_[j], s=0.7)\n", " i += 1\n", " if i >= 20:\n", " break\n", "\n", "\n", "ax.scatter([0], [0], [0], color=\"k\", s=10) # Center\n", "\n", "ax.set_xlim(-5, 5)\n", "ax.set_ylim(-5, 5)\n", "ax.set_zlim(-5, 5)\n", "#ax.view_init(elev=90,azim=-90)\n", "\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from deep_view_syn.data.spherical_view_syn import SphericalViewSynDataset\n", "\n", "DATA_DIR = '../data/sp_view_syn_2020.12.26_rotonly'\n", "TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'\n", "DEPTH_RANGE = (1, 10)\n", "N_DEPTH_LAYERS = 10\n", "\n", "def _GetSphereLayers(depth_range: Tuple[float, float], n_layers: int) -> torch.Tensor:\n", " diopter_range = (1 / depth_range[1], 1 / depth_range[0])\n", " step = (diopter_range[1] - diopter_range[0]) / (n_layers - 1)\n", " depths = [1e5]\n", " depths += [1 / (diopter_range[0] + step * i) for i in range(n_layers)]\n", " return torch.tensor(depths, device=device.GetDevice()).view(-1, 1)\n", "\n", "train_dataset = SphericalViewSynDataset(TRAIN_DATA_DESC_FILE, ray_as_item=True)\n", "train_data_loader = torch.utils.data.DataLoader(\n", " dataset=train_dataset,\n", " batch_size=4096,\n", " num_workers=8,\n", " pin_memory=True,\n", " shuffle=True,\n", " drop_last=False)\n", "print(len(train_data_loader))\n", "\n", "print(\"view_res\", train_dataset.view_res)\n", "print(\"cam_params\", train_dataset.cam_params)\n", "\n", "#msl_net = MslNet(train_dataset.cam_params,\n", "# _GetSphereLayers(DEPTH_RANGE, N_DEPTH_LAYERS),\n", "# train_dataset.view_res).to(device.GetDevice())\n", "#print(\"sphere layers\", msl_net.rendering.sphere_layers)\n", "\n", "fig = plt.figure(figsize=(12, 12))\n", "ax = fig.gca(projection='3d')\n", "plt.xlabel('x')\n", "plt.ylabel('z')\n", "\n", "i = 0\n", "selector: np.ndarray = np.array([j for j in range(65536)])\n", "selector = selector.reshape(256, 256)[::3, ::3]\n", "selector = 
selector.flatten()\n", "for _, pixels, ray_positions, ray_directions in train_data_loader:\n", " p = ray_positions\n", " v = ray_directions / ray_directions.norm(dim=1, keepdim=True)\n", " v = v.numpy()\n", " #ax.scatter(v[selector, 0], v[selector, 2], v[selector, 1], color=pixels.numpy()[selector, :], s=0.1)\n", " ax.scatter(v[:, 0], v[:, 2], v[:, 1], color=pixels.numpy(), s=0.1)\n", " i += 1\n", " if i >= 20:\n", " break\n", "\n", "\n", "ax.scatter([0], [0], [0], color=\"k\", s=10) # Center\n", "\n", "ax.set_xlim(-1, 1)\n", "ax.set_ylim(-1, 1)\n", "ax.set_zlim(-1, 1)\n", "ax.view_init(elev=0,azim=-90)\n", "\n", "plt.show()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Test Spherical View Synthesis" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import ipywidgets as widgets # 控件库\n", "from IPython.display import display # 显示控件的方法\n", "from deep_view_syn.data.spherical_view_syn import SphericalViewSynDataset\n", "from deep_view_syn.spher_net import SpherNet\n", "from deep_view_syn.my import netio\n", "\n", "DATA_DIR = '../data/sp_view_syn_2020.12.28_small'\n", "DATA_DESC_FILE = DATA_DIR + '/train.json'\n", "NET_FILE = DATA_DIR + '/rgb_ray_b2048_encode10_fc256x8/model-epoch_300.pth'\n", "N_ENCODE_DIM = 10\n", "FC_PARAMS = {\n", " 'nf': 256,\n", " 'n_layers': 8,\n", " 'skips': []\n", "}\n", "GRAY = False\n", "ROT_ONLY = False\n", "FOV = 20\n", "\n", "out_res = (256, 256)\n", "cam_params = {\n", " 'fx': out_res[0] / util.Fov2Length(FOV),\n", " 'fy': -out_res[0] / util.Fov2Length(FOV),\n", " 'cx': out_res[0] / 2,\n", " 'cy': out_res[1] / 2\n", "}\n", "local_rays = util.GetLocalViewRays(cam_params, out_res, flatten=True).to(device.GetDevice())\n", "\n", "model = SpherNet(cam_params=cam_params,\n", " fc_params=FC_PARAMS,\n", " out_res=out_res,\n", " gray=GRAY,\n", " translation=not ROT_ONLY,\n", " encode_to_dim=N_ENCODE_DIM).to(device.GetDevice())\n", "netio.LoadNet(NET_FILE, model)\n", "\n", "slider_x = widgets.FloatSlider(description='X', value=0,\n", " min=-0.05, max=0.05, step=0.002,\n", " continuous_update=True,\n", " readout=True, readout_format='.3f')\n", "slider_y = widgets.FloatSlider(description='Y', value=0,\n", " min=-0.025, max=0.025, step=0.002,\n", " continuous_update=True,\n", " readout=True, readout_format='.3f')\n", "slider_z = widgets.FloatSlider(description='Z', value=0,\n", " min=-0.05, max=0.05, step=0.002,\n", " continuous_update=True,\n", " readout=True, readout_format='.3f')\n", "slider_theta = widgets.IntSlider(description='θ', value=90,\n", " min=10, max=170, step=2,\n", " continuous_update=True,\n", " readout=True, readout_format='.1f')\n", "slider_phi = widgets.IntSlider(description='φ', value=90,\n", " min=-70, max=110, step=2,\n", " continuous_update=True,\n", " readout=True, readout_format='.1f')\n", "\n", "plt.figure()\n", "\n", "def f(x, y, z, theta, phi):\n", " print((x, y, z, theta, phi))\n", " # p: 1 x M x 3\n", " p = torch.tensor([[[x, y, z]]], device=device.GetDevice()).expand(-1, local_rays.size(0), -1)\n", " r = util.GetRotMatrix(math.radians(theta), math.radians(phi)).to(device.GetDevice())\n", " # v: 1 x M x 3\n", " v = torch.mm(local_rays, r).unsqueeze(0)\n", " print(local_rays, r)\n", " image = model(p, v)\n", " util.PlotImageTensor(image)\n", "\n", "out = widgets.interactive_output(f, {\n", " 'x': slider_x, 'y': slider_y, 'z': slider_z,\n", " 'theta': slider_theta, 'phi': slider_phi\n", "})\n", "display(slider_x, slider_y, slider_z, slider_theta, slider_phi, out)\n" ] }, { 
"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6-final" } }, "nbformat": 4, "nbformat_minor": 2 }