{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sphere Utilities and Spherical Dataset Checks\n",
    "\n",
    "Visual sanity checks for the ray-sphere intersection and Cartesian/spherical conversion in `utils.sphere`, and for loading `SphericalViewSynDataset` through `FastDataLoader`. The dataset cells read the description files referenced in each cell from the `data` directory under the project root."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Make the project root (one directory up) importable.\n",
    "rootdir = os.path.abspath('../')\n",
    "sys.path.append(rootdir)\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "from utils import img\n",
    "from utils import sphere\n",
    "from utils.constants import *\n",
    "from nets.msl_net import *\n",
    "# NOTE(review): `device` and `misc`, used in the Validate Dataset section, are not\n",
    "# imported explicitly; presumably one of the star imports above re-exports them.\n",
    "# TODO: replace the star imports with explicit imports once confirmed.\n",
    "\n",
    "# Select device 0; guarded so CPU-only machines can still run the sphere test.\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.set_device(0)\n",
    "    print('Set CUDA:%d as current device.' % torch.cuda.current_device())\n",
    "else:\n",
    "    print('CUDA is not available; skipping device selection.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Test Ray-Sphere Intersection & Cartesian-Spherical Conversion\n",
    "\n",
    "Cast one ray from the origin, intersect it with a sphere of radius `r`, convert the hit point to spherical coordinates and back, and plot both points (yellow: intersection, red: round trip). The two points should coincide; the cell asserts this at the end."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def PlotSphere(ax, r):\n",
    "    \"\"\"Draw a translucent blue sphere of radius `r` centred at the origin on 3D axes `ax`.\"\"\"\n",
    "    u, v = np.mgrid[0:2 * np.pi:50j, 0:np.pi:20j]\n",
    "    x = np.cos(u) * np.sin(v) * r\n",
    "    y = np.sin(u) * np.sin(v) * r\n",
    "    z = np.cos(v) * r\n",
    "    ax.plot_surface(x, y, z, rstride=1, cstride=1,\n",
    "                    color='b', linewidth=0.5, alpha=0.3)\n",
    "\n",
    "\n",
    "def PlotPlane(ax, r):\n",
    "    \"\"\"Draw the square [-r, r] x [-r, r] at height 0 on 3D axes `ax` as a green wireframe.\"\"\"\n",
    "    # Domain plane of a bivariate function: a 3x3 grid of vertices.\n",
    "    x = np.linspace(-r, r, 3)\n",
    "    y = np.linspace(-r, r, 3)\n",
    "    X, Y = np.meshgrid(x, y)\n",
    "    ax.plot_wireframe(X, Y, X * 0, color='g', linewidth=1)\n",
    "\n",
    "\n",
    "def to_plot_xyz(pt):\n",
    "    \"\"\"Return point `pt[0]` as Python floats in plot-axis order (x, z, y).\"\"\"\n",
    "    return pt[0, 0].item(), pt[0, 2].item(), pt[0, 1].item()\n",
    "\n",
    "\n",
    "# One ray from the origin; its direction is scaled to length 2r so the\n",
    "# segment p -> p + v drawn below always leaves the sphere.\n",
    "p = torch.tensor([[0.0, 0.0, 0.0]])\n",
    "v = torch.tensor([[0.0, -1.0, 1.0]])\n",
    "r = torch.tensor([[0.5]])\n",
    "v = v / torch.norm(v) * r * 2\n",
    "\n",
    "# Hit point from the ray-sphere intersection (yellow in the plot).\n",
    "p_on_sphere_ = sphere.ray_sphere_intersect(p, v, r)[0][0]\n",
    "print(p_on_sphere_)\n",
    "print(p_on_sphere_.norm())  # expected to equal r for a sphere centred at the origin\n",
    "\n",
    "# Round trip Cartesian -> spherical -> Cartesian (red in the plot).\n",
    "spher_coord = sphere.cartesian2spherical(p_on_sphere_)\n",
    "print(spher_coord[..., 1:3].rad2deg())  # the two angles, in degrees\n",
    "p_on_sphere = sphere.spherical2cartesian(spher_coord)\n",
    "print(p_on_sphere.size())\n",
    "\n",
    "fig = plt.figure(figsize=(8, 8))\n",
    "# Figure.gca(projection=...) is deprecated since Matplotlib 3.4 and removed in 3.6.\n",
    "ax = fig.add_subplot(projection='3d')\n",
    "# World (x, y, z) is drawn on plot axes (x, z, y) so that world y points up.\n",
    "ax.set_xlabel('x')\n",
    "ax.set_ylabel('z')\n",
    "ax.set_zlabel('y')\n",
    "\n",
    "r_val = r.item()\n",
    "PlotPlane(ax, r_val)\n",
    "PlotSphere(ax, r_val)\n",
    "\n",
    "hx, hz, hy = to_plot_xyz(p_on_sphere_)\n",
    "rx, rz, ry = to_plot_xyz(p_on_sphere)\n",
    "ax.scatter([0], [0], [0], color='g', s=10)  # sphere centre\n",
    "ax.scatter([rx], [rz], [ry], color='r', s=10)  # round-trip point\n",
    "ax.scatter([hx], [hz], [hy], color='y', s=10)  # intersection point\n",
    "\n",
    "# The ray segment from p to p + v.\n",
    "ox, oz, oy = to_plot_xyz(p)\n",
    "ex, ez, ey = to_plot_xyz(p + v)\n",
    "ax.plot([ox, ex], [oz, ez], [oy, ey], color='r')\n",
    "\n",
    "# Dashed construction lines: hit point -> its foot on the horizontal plane,\n",
    "# foot -> centre, and hit point -> centre.\n",
    "ax.plot([hx, hx], [hz, hz], [0, hy], color='k', linestyle='--', linewidth=0.5)\n",
    "ax.plot([hx, 0], [hz, 0], [0, 0], color='k', linestyle='--', linewidth=0.5)\n",
    "ax.plot([hx, 0], [hz, 0], [hy, 0], color='k', linestyle='--', linewidth=0.5)\n",
    "\n",
    "ax.set_xlim(-r_val, r_val)\n",
    "ax.set_ylim(-r_val, r_val)\n",
    "ax.set_zlim(-r_val, r_val)\n",
    "\n",
    "plt.show()\n",
    "\n",
    "# Checked after plotting so the figure is still available for debugging a failure.\n",
    "assert p_on_sphere.shape == p_on_sphere_.shape, 'round trip changed the point shape'\n",
    "assert torch.allclose(p_on_sphere, p_on_sphere_, atol=1e-5), 'spherical round trip does not reproduce the hit point'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Test Dataset Loader\n",
    "\n",
    "Load the first few batches with `FastDataLoader`, print the tensor sizes of each batch and plot its patches (one row per batch)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from data.spherical_view_syn import SphericalViewSynDataset\n",
    "from data.loader import FastDataLoader\n",
    "\n",
    "DATA_DESC_FILE = f'{rootdir}/data/__new/street_fovea_r360x80_t1.0/train1.json'\n",
    "BATCH_SIZE = 4\n",
    "N_BATCHES = 4  # number of batches to inspect\n",
    "\n",
    "dataset = SphericalViewSynDataset(DATA_DESC_FILE)\n",
    "data_loader = FastDataLoader(dataset, BATCH_SIZE, shuffle=False)\n",
    "\n",
    "# Separate batch and subplot indices: one row per batch, one column per sample.\n",
    "fig = plt.figure(figsize=(12, 6.5))\n",
    "for batch_idx, (indices, patches, rays_o, rays_d) in enumerate(data_loader):\n",
    "    if batch_idx == N_BATCHES:\n",
    "        break\n",
    "    print(batch_idx, patches.size(), rays_o.size(), rays_d.size())\n",
    "    for sample_idx in range(len(indices)):\n",
    "        plt.subplot(N_BATCHES, BATCH_SIZE, batch_idx * BATCH_SIZE + sample_idx + 1)\n",
    "        img.plot(patches[sample_idx])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Validate Dataset\n",
    "\n",
    "Intersect a subsampled set of rays from the loaded views with spheres of several radii and scatter the hit points, coloured by their pixel values, for visual inspection."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from data.spherical_view_syn import SphericalViewSynDataset\n",
    "from data.loader import FastDataLoader\n",
    "\n",
    "# Alternative dataset description files (uncomment one to switch):\n",
    "# DATA_DESC_FILE = f'{rootdir}/data/pabellon_fovea_r40x40_t0.3/train.json'\n",
    "# DATA_DESC_FILE = f'{rootdir}/data/gas_fovea_r80x60_t0.3_2021.01.26/train.json'\n",
    "# DATA_DESC_FILE = f'{rootdir}/data/nerf_fern/train.json'\n",
    "# DATA_DESC_FILE = f'{rootdir}/data/lobby_fovea_2021.01.18/train.json'\n",
    "# DATA_DESC_FILE = f'{rootdir}/data/__new/street_fovea_r360x80_t1.0/train1.json'\n",
    "# DATA_DESC_FILE = f'{rootdir}/data/__new/stones_fovea_r360x80_t1.0/train1.json'\n",
    "# DATA_DESC_FILE = f'{rootdir}/data/__new/lobby_periph_r360x180_t1.0/train1.json'\n",
    "DATA_DESC_FILE = f'{rootdir}/data/__new/classroom_all/nerf_cvt.json'\n",
    "\n",
    "N_VIEWS = 12                   # views loaded from the dataset\n",
    "VIEW_INDICES = range(N_VIEWS)  # loader batches that are projected\n",
    "SUBSAMPLE = 5                  # keep every 5th pixel along both image axes\n",
    "RADII = [1, 3, 5, 7]           # sphere radii to project onto\n",
    "\n",
    "dataset = SphericalViewSynDataset(DATA_DESC_FILE, load_views=range(N_VIEWS))\n",
    "dataset.set_patch_size(1)\n",
    "res = dataset.view_res\n",
    "# Batch size = pixels per view, so each batch presumably holds exactly one view\n",
    "# (patch size 1) -- `selector` below relies on this. TODO confirm.\n",
    "data_loader = FastDataLoader(dataset, res[0] * res[1], shuffle=False)\n",
    "\n",
    "# Flat indices of every SUBSAMPLE-th pixel along both image axes.\n",
    "selector = torch.arange(res[0] * res[1]).reshape(res[0], res[1])[::SUBSAMPLE, ::SUBSAMPLE].flatten()\n",
    "\n",
    "# NOTE(review): `device` and `misc` are not imported explicitly in this notebook;\n",
    "# presumably they come from a star import in the setup cell -- verify.\n",
    "\n",
    "# Rays and pixels do not depend on the radius: read the dataset once instead\n",
    "# of once per radius, and concatenate once instead of growing arrays in a loop.\n",
    "rays_o_batches, rays_d_batches, pixel_batches = [], [], []\n",
    "for batch_idx, (indices, patches, rays_o, rays_d) in enumerate(data_loader):\n",
    "    if batch_idx not in VIEW_INDICES:\n",
    "        continue\n",
    "    rays_o_batches.append(rays_o[selector])\n",
    "    rays_d_batches.append(rays_d[selector])\n",
    "    pixel_batches.append(misc.torch2np(patches[selector]))\n",
    "assert pixel_batches, 'no loader batch matched VIEW_INDICES'\n",
    "pixels = np.concatenate(pixel_batches, axis=0)\n",
    "\n",
    "for radius in RADII:\n",
    "    # Float (1, 1) radius tensor, created once per radius rather than re-wrapped per batch.\n",
    "    r = torch.tensor([[float(radius)]], device=device.default())\n",
    "    p = np.concatenate([\n",
    "        misc.torch2np(sphere.ray_sphere_intersect(ro, rd, r)[0].view(-1, 3))\n",
    "        for ro, rd in zip(rays_o_batches, rays_d_batches)\n",
    "    ], axis=0)\n",
    "\n",
    "    fig = plt.figure(facecolor='white', figsize=(20, 20))\n",
    "    ax = fig.add_subplot(projection='3d')\n",
    "    # World (x, y, z) is drawn on plot axes (x, z, y) so that world y points up.\n",
    "    ax.set_xlabel('x')\n",
    "    ax.set_ylabel('z')\n",
    "    ax.set_zlabel('y')\n",
    "    ax.set_title('r = %f' % radius)\n",
    "    ax.scatter([0], [0], [0], color='k', s=10)  # sphere centre\n",
    "    ax.scatter(p[:, 0], p[:, 2], p[:, 1], color=pixels, s=0.5)\n",
    "    ax.view_init(elev=0, azim=-90)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}