{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Set CUDA:2 as current device.\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "import os\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "rootdir = os.path.abspath(sys.path[0] + '/../')\n",
    "sys.path.append(rootdir)\n",
    "torch.cuda.set_device(2)\n",
    "print(\"Set CUDA:%d as current device.\" % torch.cuda.current_device())\n",
    "\n",
    "from ..data.spherical_view_syn import *\n",
    "from ..configs.spherical_view_syn import SphericalViewSynConfig\n",
    "from utils import netio\n",
    "from utils import img\n",
    "from utils import device\n",
    "from utils import view\n",
    "from components.gen_final import GenFinal\n",
    "from utils.perf import Perf\n",
    "\n",
    "\n",
    "def load_net(path):\n",
    "    config = SphericalViewSynConfig()\n",
    "    config.from_id(path[:-4])\n",
    "    config.SAMPLE_PARAMS['perturb_sample'] = False\n",
    "    config.print()\n",
    "    net = config.create_net().to(device.default())\n",
    "    netio.load(path, net)\n",
    "    return net\n",
    "\n",
    "\n",
    "def find_file(prefix):\n",
    "    for path in os.listdir():\n",
    "        if path.startswith(prefix):\n",
    "            return path\n",
    "    return None\n",
    "\n",
    "\n",
    "def load_views(data_desc_file) -> view.Trans:\n",
    "    with open(data_desc_file, 'r', encoding='utf-8') as file:\n",
    "        data_desc = json.loads(file.read())\n",
    "        view_centers = torch.tensor(\n",
    "            data_desc['view_centers'], device=device.default()).view(-1, 3)\n",
    "        view_rots = torch.tensor(\n",
    "            data_desc['view_rots'], device=device.default()).view(-1, 3, 3)\n",
    "        return view.Trans(view_centers, view_rots)\n",
    "\n",
    "\n",
    "def plot_figures(images, center):\n",
    "    plt.figure(figsize=(8, 4))\n",
    "    plt.subplot(121)\n",
Nianchen Deng's avatar
sync    
Nianchen Deng committed
67
    "    img.plot(images['fovea_raw'])\n",
Nianchen Deng's avatar
Nianchen Deng committed
68
    "    plt.subplot(122)\n",
Nianchen Deng's avatar
sync    
Nianchen Deng committed
69
    "    img.plot(images['fovea'])\n",
Nianchen Deng's avatar
Nianchen Deng committed
70
71
72
    "\n",
    "    plt.figure(figsize=(8, 4))\n",
    "    plt.subplot(121)\n",
Nianchen Deng's avatar
sync    
Nianchen Deng committed
73
    "    img.plot(images['mid_raw'])\n",
Nianchen Deng's avatar
Nianchen Deng committed
74
    "    plt.subplot(122)\n",
Nianchen Deng's avatar
sync    
Nianchen Deng committed
75
    "    img.plot(images['mid'])\n",
Nianchen Deng's avatar
Nianchen Deng committed
76
77
78
    "\n",
    "    plt.figure(figsize=(8, 4))\n",
    "    plt.subplot(121)\n",
Nianchen Deng's avatar
sync    
Nianchen Deng committed
79
    "    img.plot(images['periph_raw'])\n",
Nianchen Deng's avatar
Nianchen Deng committed
80
    "    plt.subplot(122)\n",
Nianchen Deng's avatar
sync    
Nianchen Deng committed
81
    "    img.plot(images['periph'])\n",
Nianchen Deng's avatar
Nianchen Deng committed
82
83
84
85
    "\n",
    "    # Plot Blended\n",
    "    plt.figure(figsize=(12, 6))\n",
    "    plt.subplot(121)\n",
Nianchen Deng's avatar
sync    
Nianchen Deng committed
86
    "    img.plot(images['blended_raw'])\n",
Nianchen Deng's avatar
Nianchen Deng committed
87
    "    plt.subplot(122)\n",
Nianchen Deng's avatar
sync    
Nianchen Deng committed
88
    "    img.plot(images['blended'])\n",
Nianchen Deng's avatar
Nianchen Deng committed
89
90
91
92
93
94
95
96
97
98
99
100
    "    plt.plot([(res_full[1] - 1) / 2 + center[0] - 5, (res_full[1] - 1) / 2 + center[0] + 5],\n",
    "                [(res_full[0] - 1) / 2 + center[1],\n",
    "                (res_full[0] - 1) / 2 + center[1]],\n",
    "                color=[0, 1, 0])\n",
    "    plt.plot([(res_full[1] - 1) / 2 + center[0], (res_full[1] - 1) / 2 + center[0]],\n",
    "                [(res_full[0] - 1) / 2 + center[1] - 5,\n",
    "                (res_full[0] - 1) / 2 + center[1] + 5],\n",
    "                color=[0, 1, 0])"
   ]
  },
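  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal sketch of the dataset description structure that load_views() above expects.\n",
    "# The values below are illustrative placeholders only, not taken from any real dataset.\n",
    "_example_desc = {\n",
    "    'view_centers': [[0.0, 0.0, 0.0], [0.1, 0.0, 0.0]],  # N x 3 camera positions\n",
    "    'view_rots': [[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]] * 2  # N x 3 x 3 rotations\n",
    "}\n",
    "_centers = torch.tensor(_example_desc['view_centers'], device=device.default()).view(-1, 3)\n",
    "_rots = torch.tensor(_example_desc['view_rots'], device=device.default()).view(-1, 3, 3)\n",
    "print('example views:', view.Trans(_centers, _rots).size())\n"
   ]
  },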
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.chdir(os.path.join(rootdir, 'data/__0_user_study/us_gas_all_in_one'))\n",
    "#os.chdir(os.path.join(rootdir, 'data/__0_user_study/us_mc_all_in_one'))\n",
    "#os.chdir(os.path.join(rootdir, 'data/__0_user_study/lobby_all_in_one'))\n",
    "print('Change working directory to ', os.getcwd())\n",
    "torch.autograd.set_grad_enabled(False)\n",
    "\n",
    "fovea_net = load_net(find_file('fovea'))\n",
    "periph_net = load_net(find_file('periph'))\n",
    "\n",
    "# Load Dataset\n",
    "views = load_views('nerf_views.json')\n",
    "print('Dataset loaded.')\n",
    "\n",
    "print('views:', views.size())\n",
    "#print('ref views:', ref_dataset.samples)\n",
    "\n",
    "fov_list = [20, 45, 110]\n",
    "res_list = [(128, 128), (256, 256), (256, 230)]  # (192,256)]\n",
    "res_full = (1600, 1440)\n",
    "gen = GenFinal(fov_list, res_list, res_full, fovea_net, periph_net,\n",
    "               device=device.default())\n"
   ]
  },
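  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rough sanity check of the layer setup above: approximate pixels per degree at the image\n",
    "# center for each layer, assuming a simple pinhole model, fov as full horizontal FOV in\n",
    "# degrees and res as (height, width). These assumptions are illustrative, not confirmed here.\n",
    "import math\n",
    "for fov, res in zip(fov_list, res_list):\n",
    "    focal = res[1] / 2 / math.tan(math.radians(fov) / 2)  # focal length in pixels\n",
    "    print('fov %3d deg, res %s -> ~%.1f px/deg at center' % (fov, res, focal * math.pi / 180))\n"
   ]
  },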
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_view = view.Trans(\n",
    "    torch.tensor([[0.0, 0.0, 0.0]], device=device.default()),\n",
    "    torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]], device=device.default())\n",
    ")\n",
    "perf = Perf(True, True)\n",
    "rays_o, rays_d = gen.layer_cams[0].get_global_rays(test_view, True)\n",
    "perf.checkpoint(\"GetRays\")\n",
    "rays_o = rays_o.view(-1, 3)\n",
    "rays_d = rays_d.view(-1, 3)\n",
    "coords, pts, depths = fovea_net.sampler(rays_o, rays_d)\n",
    "perf.checkpoint(\"Sample\")\n",
    "encoded = fovea_net.input_encoder(coords)\n",
    "perf.checkpoint(\"Encode\")\n",
    "print(\"Rays:\", rays_d)\n",
    "print(\"Spherical coords:\", coords)\n",
    "print(\"Depths:\", depths)\n",
    "print(\"Encoded:\", encoded)\n",
    "#plot_figures(images, center)\n",
    "\n",
    "#misc.create_dir('output/teasers')\n",
    "#for key in images:\n",
    "#    img.save(\n",
    "#        images[key], 'output/teasers/view%04d_%s.png' % (view_idx, key))\n"
   ]
  },
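  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick shape check of the tensors produced in the cell above, to see how many rays and\n",
    "# samples per ray flow through sampling and encoding (falls back to the type if an object\n",
    "# has no shape attribute).\n",
    "for name, t in [('rays_o', rays_o), ('rays_d', rays_d), ('coords', coords),\n",
    "                ('pts', pts), ('depths', depths), ('encoded', encoded)]:\n",
    "    print(name, getattr(t, 'shape', type(t)))\n"
   ]
  },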
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.9 64-bit ('pytorch': conda)",
   "name": "python379jvsc74a57bd0660ca2a75467d3af74a68fcc6f40bc78ab96b99ff17d2f100b5ca821fbb183f2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  },
  "orig_nbformat": 2
 },
 "nbformat": 4,
 "nbformat_minor": 2
}