test_refinement.ipynb 11.6 KB
Newer Older
BobYeah's avatar
sync    
BobYeah committed
1
2
3
4
{
 "cells": [
  {
   "cell_type": "code",
5
   "execution_count": null,
BobYeah's avatar
sync    
BobYeah committed
6
   "metadata": {},
7
   "outputs": [],
BobYeah's avatar
sync    
BobYeah committed
8
9
10
11
12
13
   "source": [
    "import sys\n",
    "import os\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Resolve the repository root only once: a relative os.chdir('../') would move one\n",
    "# level further up on every re-run and break the chdir into the scene directory below.\n",
    "if 'REPO_ROOT' not in globals():\n",
    "    REPO_ROOT = os.path.abspath('../')\n",
    "os.chdir(REPO_ROOT)\n",
    "if REPO_ROOT not in sys.path:\n",
    "    sys.path.append(REPO_ROOT)\n",
    "\n",
    "# GPU index to run on; change this if the machine has fewer CUDA devices.\n",
    "CUDA_DEVICE_ID = 1\n",
    "torch.cuda.set_device(CUDA_DEVICE_ID)\n",
    "print(\"Set CUDA:%d as current device.\" % torch.cuda.current_device())\n",
    "\n",
    "from data.spherical_view_syn import *\n",
    "from configs.spherical_view_syn import SphericalViewSynConfig\n",
    "from utils import netio\n",
    "from utils import img\n",
    "from utils import device\n",
    "from utils import view\n",
    "from components import refine\n",
    "\n",
    "\n",
    "# The dataset descriptions, reference images and checkpoint used below are relative to this scene directory.\n",
    "os.chdir(os.path.join(REPO_ROOT, 'data/us_gas_all_in_one'))\n",
    "print('Change working directory to ', os.getcwd())\n",
    "# Inference only: turn autograd off globally (';' hides the returned object's repr).\n",
    "torch.autograd.set_grad_enabled(False);"
   ]
  },
  {
   "cell_type": "code",
35
   "execution_count": null,
BobYeah's avatar
sync    
BobYeah committed
36
   "metadata": {},
37
   "outputs": [],
BobYeah's avatar
sync    
BobYeah committed
38
39
   "source": [
    "# ---- Configuration ----------------------------------------------------\n",
    "# The network config is reconstructed from the id embedded in the checkpoint's file name.\n",
    "model_path = 'fovea@nmsl-rgb_e10_fc128x4_d1-50_s32.pth'\n",
    "model_id = os.path.splitext(os.path.basename(model_path))[0]\n",
    "config = SphericalViewSynConfig()\n",
    "config.from_id(model_id)\n",
    "# Turn off sample perturbation for this test run.\n",
    "config.sa['perturb_sample'] = False\n",
    "config.print()\n",
    "\n",
    "# ---- Datasets (images/depths not loaded; reference images are read on demand) ----\n",
    "_dataset_kwargs = dict(load_images=False, load_depths=False,\n",
    "                       color=config.c, calculate_rays=False)\n",
    "view_dataset = SphericalViewSynDataset('views.json', **_dataset_kwargs)\n",
    "ref_dataset = SphericalViewSynDataset('ref.json', **_dataset_kwargs)\n",
    "print('Dataset loaded.')\n",
    "\n",
BobYeah's avatar
sync    
BobYeah committed
55
    "def read_ref_images(idx):\n",
    "    \"\"\"Load reference view image(s) from ref/ and move them to the default device.\n",
    "\n",
    "    idx: a non-scalar tensor of view indices (loads one image per index) or a single index.\n",
    "    \"\"\"\n",
    "    patt = 'ref/view_%04d.png'\n",
    "    is_batch = isinstance(idx, torch.Tensor) and len(idx.size()) > 0\n",
    "    paths = [patt % i for i in idx] if is_batch else patt % idx\n",
    "    return img.load(paths).to(device.default())\n",
BobYeah's avatar
sync    
BobYeah committed
61
    "\n",
Nianchen Deng's avatar
sync    
Nianchen Deng committed
62
    "def _index_grid(dataset):\n",
    "    \"\"\"Flat view indices reshaped to the dataset's sample grid (shape == dataset.samples).\"\"\"\n",
    "    flat = torch.arange(dataset.n_views, device=device.default())\n",
    "    return flat.view(dataset.samples)\n",
    "\n",
    "indices = _index_grid(view_dataset)\n",
    "ref_indices = _index_grid(ref_dataset)\n",
    "\n",
    "# Fovea camera: fov=20 (presumably degrees), normalized principal point at the centre, 100x100 resolution.\n",
    "fovea_cam_spec = {\"fov\": 20, \"cx\": 0.5, \"cy\": 0.5, \"normalized\": True}\n",
    "cam_params = view.Camera(fovea_cam_spec, (100, 100)).to(device.default())\n",
    "ref_cam_params = ref_dataset.cam_params\n",
    "\n",
    "# Load the spherical view-synthesis network and its trained weights.\n",
    "net = config.create_net()\n",
    "net = net.to(device.default())\n",
    "netio.load(model_path, net)\n",
    "print('Net loaded.')\n",
    "\n",
    "# Full VR display camera; adjust_cam reads its focal lengths to rescale gaze offsets.\n",
    "vr_cam = view.Camera(\n",
    "    {'fov': 110, 'cx': 0.5, 'cy': 0.5, 'normalized': True},\n",
    "    (1600, 1440))\n",
    "\n",
    "def adjust_cam(cam, vr_cam, gaze_center):\n",
    "    \"\"\"Re-centre `cam`'s principal point (in place) to follow a gaze offset.\n",
    "\n",
    "    gaze_center is an (x, y) offset in vr_cam's image plane (presumably pixels);\n",
    "    it is rescaled by cam.f / vr_cam.f and subtracted from the centre of cam.res.\n",
    "    \"\"\"\n",
    "    # Evaluate each offset as (gaze / vr_f) * cam_f to keep the floating-point order.\n",
    "    off_x = gaze_center[0] / vr_cam.f[0].item() * cam.f[0].item()\n",
    "    off_y = gaze_center[1] / vr_cam.f[1].item() * cam.f[1].item()\n",
    "    height, width = cam.res[0], cam.res[1]\n",
    "    cam.c[0] = width / 2 - off_x\n",
    "    cam.c[1] = height / 2 - off_y"
BobYeah's avatar
sync    
BobYeah committed
93
94
95
96
   ]
  },
  {
   "cell_type": "code",
97
   "execution_count": null,
BobYeah's avatar
sync    
BobYeah committed
98
   "metadata": {},
99
   "outputs": [],
BobYeah's avatar
sync    
BobYeah committed
100
   "source": [
Nianchen Deng's avatar
sync    
Nianchen Deng committed
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
    "# Gaze offset (x, y) passed to gen() below; presumably in VR-display pixels.\n",
    "center = (-187, 64)\n",
    "# Test view: the middle of the view grid.\n",
    "test_view_coord = tuple(val // 2 for val in view_dataset.samples)\n",
    "#test_view_coord = test_view_coord[:3] + (test_view_coord[3] - 1,) + test_view_coord[4:]\n",
    "# Same view as a coordinate in the reference grid (hand-tuned per-dimension shift).\n",
    "_ref_shift = (1, 1, 1, 2, -1)\n",
    "test_view_coord_in_ref = tuple(test_view_coord[i] + _ref_shift[i] for i in range(5))\n",
    "print('test_view_coord', test_view_coord)\n",
    "print('test_view_coord_in_ref', test_view_coord_in_ref)\n",
    "print('ref_dataset.samples', ref_dataset.samples)\n",
    "print('ref_indices.size()', ref_indices.size())\n",
    "print('indices.size()', indices.size())\n",
    "test_view_idx = indices[test_view_coord]\n",
    "\n",
    "def _bound_ref_idx(sx, sy):\n",
    "    \"\"\"1-element index tensor of the ref view offset by (2*sx, 2*sy, 0, sx, sy).\"\"\"\n",
    "    r = test_view_coord_in_ref\n",
    "    return ref_indices[r[0] + 2 * sx, r[1] + 2 * sy, r[2], r[3] + sx, r[4] + sy][None]\n",
    "\n",
    "# Four surrounding reference views, in sign order (-,+), (+,+), (-,-), (+,-).\n",
    "a, b, c, d = (_bound_ref_idx(sx, sy) for sx, sy in ((-1, 1), (1, 1), (-1, -1), (1, -1)))\n",
    "bound_view_idxs = torch.cat([a, b, c, d])\n",
    "print(bound_view_idxs)\n",
Nianchen Deng's avatar
sync    
Nianchen Deng committed
125
    "# bound_view_idxs = [\n",
BobYeah's avatar
sync    
BobYeah committed
126
127
128
129
    "#    indices[center_view_coord[0], center_view_coord[1], center_view_coord[2],\n",
    "#            center_view_coord[3] - 1, center_view_coord[4]],\n",
    "#    indices[center_view_coord[0], center_view_coord[1], center_view_coord[2],\n",
    "#            center_view_coord[3] + 1, center_view_coord[4]],\n",
Nianchen Deng's avatar
sync    
Nianchen Deng committed
130
131
    "# ]\n",
    "\n",
Nianchen Deng's avatar
sync    
Nianchen Deng committed
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
    "def get_guides(view_coord):\n",
    "    \"\"\"Build a GuideRefinement from the four reference views surrounding `view_coord`.\n",
    "\n",
    "    view_coord is a coordinate in view_dataset's grid. It is mapped into ref_dataset's\n",
    "    grid by the per-dimension difference of the two grids' centres (samples // 2).\n",
    "    \"\"\"\n",
    "    grid_shift = [\n",
    "        ref_n // 2 - view_dataset.samples[dim] // 2\n",
    "        for dim, ref_n in enumerate(ref_dataset.samples)\n",
    "    ]\n",
    "    guides_coord = []\n",
    "    for sx, sy in ((-1, 1), (1, 1), (-1, -1), (1, -1)):\n",
    "        # Step (2*sx, 2*sy) in dims 0/1 and (sx, sy) in dims 3/4; dim 2 is unchanged.\n",
    "        coord = [view_coord[0] + 2 * sx, view_coord[1] + 2 * sy, view_coord[2],\n",
    "                 view_coord[3] + sx, view_coord[4] + sy]\n",
    "        for dim, shift in enumerate(grid_shift):\n",
    "            coord[dim] += shift\n",
    "        guides_coord.append(coord)\n",
    "    ref_grid = torch.arange(0, ref_dataset.n_views,\n",
    "                            device=device.default()).view(ref_dataset.samples)\n",
    "    guides_idx = torch.stack([ref_grid[tuple(coord)] for coord in guides_coord])\n",
    "    print('guides_idx:', guides_idx)\n",
    "    guides_image = read_ref_images(guides_idx).to(device.default())\n",
    "    guides_trans = view.Trans(ref_dataset.view_centers[guides_idx],\n",
    "                              ref_dataset.view_rots[guides_idx])\n",
    "    return refine.GuideRefinement(guides_image, guides_trans, ref_cam_params, net)\n",
    "\n",
    "guide_refine = get_guides(test_view_coord)\n",
    "\n",
    "def gen(fovea_center, trans):\n",
    "    \"\"\"Render the fovea view for `trans` with the gaze at `fovea_center`, then refine it.\n",
    "\n",
    "    Side effect: re-centres the global `cam_params` via adjust_cam.\n",
    "    Returns {'fovea_raw': (C, H, W) network output, 'fovea': guide-refined image,\n",
    "             'fovea_depth': (H, W) depth map}.\n",
    "    \"\"\"\n",
    "    adjust_cam(cam_params, vr_cam, fovea_center)\n",
    "    height, width = cam_params.res[0], cam_params.res[1]\n",
    "    rays_o, rays_d = cam_params.get_global_rays(trans)  # (H_fovea, W_fovea, 3)\n",
    "    color, depth = net(rays_o.view(-1, 3), rays_d.view(-1, 3), ret_depth=True)\n",
    "    raw = color.view(height, width, -1).permute(2, 0, 1)  # (C, H_fovea, W_fovea)\n",
    "    depth = depth.view(height, width)\n",
    "    refined = guide_refine.refine_by_guide(raw, depth, rays_o, rays_d, False)\n",
    "    return {'fovea_raw': raw, 'fovea': refined, 'fovea_depth': depth}\n",
    "\n",
    "# Pose of the selected test view.\n",
    "test_center = view_dataset.view_centers[test_view_idx]\n",
    "test_rot = view_dataset.view_rots[test_view_idx]\n",
    "trans = view.Trans(test_center, test_rot)\n",
    "\n",
    "# Earlier manual pipeline (ray generation, upsampling, bounding views), kept for reference:\n",
    "#adjust_cam(cam_params, vr_cam, center)\n",
    "#rays_o, rays_d = cam_params.get_global_rays(trans, flatten=True)\n",
    "#inferred, depthmap = net(rays_o.view(-1, 3),\n",
    "#                         rays_d.view(-1, 3), ret_depth=True)\n",
    "#inferred = inferred.view(\n",
    "#    cam_params.res[0], cam_params.res[1], -1).permute(2, 0, 1)\n",
    "#inferred = nn_f.upsample_bilinear(inferred.unsqueeze(0), scale_factor=2)[0]\n",
    "#depthmap = depthmap.view(cam_params.res[0], cam_params.res[1])\n",
    "#depthmap = nn_f.upsample_bilinear(depthmap[None, None, :, :], scale_factor=2)[0, 0]\n",
    "#gt = view_dataset.view_images[test_view_idx]\n",
    "#bounds_img = read_ref_images(bound_view_idxs)\n",
    "#bounds_o = ref_dataset.view_centers[bound_view_idxs]\n",
    "#bounds_r = ref_dataset.view_rots[bound_view_idxs]\n",
    "\n",
    "# Current pipeline: render and guide-refine the fovea view at gaze offset `center`.\n",
    "images = gen(center, trans)\n",
    "inferred, refined = images['fovea_raw'], images['fovea']\n",
Nianchen Deng's avatar
sync    
Nianchen Deng committed
207
    "\n",
BobYeah's avatar
sync    
BobYeah committed
208
    "\n",
Nianchen Deng's avatar
sync    
Nianchen Deng committed
209
210
    "#guide_refine = refine.GuideRefinement(bounds_img, view.Trans(bounds_o, bounds_r), ref_cam_params, net)\n",
    "#refined = guide_refine.refine_by_guide(inferred, depthmap, rays_o, rays_d, False)\n",
BobYeah's avatar
sync    
BobYeah committed
211
    "\n",
Nianchen Deng's avatar
sync    
Nianchen Deng committed
212
    "# warped = [nn_f.grid_sample(bounds_img[i], bounds_warp[i])\n",
Nianchen Deng's avatar
sync    
Nianchen Deng committed
213
    "#          for i in range(len(bounds_warp))]\n",
Nianchen Deng's avatar
sync    
Nianchen Deng committed
214
    "# warped_inferred = [nn_f.grid_sample(\n",
Nianchen Deng's avatar
sync    
Nianchen Deng committed
215
    "#    bounds_inferred[i], bounds_warp[i]) for i in range(len(bounds_warp))]\n",
BobYeah's avatar
sync    
BobYeah committed
216
217
218
    "\n",
    "# Side by side: raw network output | guide-refined output | (ground truth, disabled).\n",
    "fig = plt.figure(figsize=(12, 3))\n",
    "plt.set_cmap('Greys_r')\n",
    "for panel_idx, panel_img in enumerate((inferred, refined), start=1):\n",
    "    plt.subplot(1, 3, panel_idx)\n",
    "    img.plot(panel_img)\n",
    "#plt.subplot(1, 3, 3)\n",
    "#img.plot(gt)\n",
    "plt.show()\n",
    "\n",
Nianchen Deng's avatar
sync    
Nianchen Deng committed
227
    "\n",
BobYeah's avatar
sync    
BobYeah committed
228
229
230
231
232
    "def plot_image_matrices(center_image, ref_images):\n",
    "    \"\"\"Plot reference images around an optional centre image, then show the figure.\n",
    "\n",
    "    2 reference images go left/right of the centre in a 1x3 grid; 4 go in the corners\n",
    "    of a 3x3 grid with the centre in the middle. For any other count nothing is drawn\n",
    "    (plt.show() is still called).\n",
    "\n",
    "    center_image: image for the middle cell, or None to leave it empty.\n",
    "    ref_images: indexable/iterable collection (list or batched tensor) of 2 or 4 images.\n",
    "    \"\"\"\n",
    "    # count -> (figsize, (rows, cols), subplot slots for ref_images, slot for center_image)\n",
    "    layouts = {\n",
    "        2: ((12, 4), (1, 3), (1, 3), 2),\n",
    "        4: ((12, 12), (3, 3), (1, 3, 7, 9), 5),\n",
    "    }\n",
    "    layout = layouts.get(len(ref_images))\n",
    "    if layout is not None:\n",
    "        figsize, (rows, cols), ref_slots, center_slot = layout\n",
    "        plt.figure(figsize=figsize)\n",
    "        plt.set_cmap('Greys_r')\n",
    "        for slot, ref_image in zip(ref_slots, ref_images):\n",
    "            plt.subplot(rows, cols, slot)\n",
    "            img.plot(ref_image)\n",
    "        # Identity test: `!=` would dispatch to the tensor's overloaded comparison.\n",
    "        if center_image is not None:\n",
    "            plt.subplot(rows, cols, center_slot)\n",
    "            img.plot(center_image)\n",
    "    plt.show()\n",
    "\n",
Nianchen Deng's avatar
sync    
Nianchen Deng committed
255
256
    "\n",
    "#plot_image_matrices(input, warped_inferred)\n",
    "# Show the four reference views around the test view. `bounds_img` is never assigned\n",
    "# outside comments, so load the images for `bound_view_idxs` directly.\n",
    "plot_image_matrices(None, read_ref_images(bound_view_idxs))\n",
    "# plot_image_matrices(torch.cat(warped[0:3], 1) if len(\n",
    "#    warped) >= 3 else torch.cat(warped + [torch.zeros_like(warped[0])], 1), warped)\n"
BobYeah's avatar
sync    
BobYeah committed
260
261
262
263
264
265
266
267
268
269
270
271
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
Nianchen Deng's avatar
sync    
Nianchen Deng committed
272
273
   "display_name": "Python 3.7.9 64-bit ('pytorch': conda)",
   "name": "python379jvsc74a57bd0660ca2a75467d3af74a68fcc6f40bc78ab96b99ff17d2f100b5ca821fbb183f2"
BobYeah's avatar
sync    
BobYeah committed
274
275
276
277
278
279
280
281
282
283
284
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
Nianchen Deng's avatar
sync    
Nianchen Deng committed
285
   "version": "3.7.9"
BobYeah's avatar
sync    
BobYeah committed
286
287
288
289
290
291
  },
  "orig_nbformat": 2
 },
 "nbformat": 4,
 "nbformat_minor": 2
}