In [None]:
import sys
import os

rootdir = os.path.abspath('../')
sys.path.append(rootdir)

import torch
import matplotlib.pyplot as plt
import numpy as np
from utils import img
from utils import sphere
from utils.constants import *
from nets.msl_net import *

# Select device: use GPU 0 when CUDA is present instead of raising on CPU-only machines.
if torch.cuda.is_available():
    torch.cuda.set_device(0)
    print("Set CUDA:%d as current device." % torch.cuda.current_device())
else:
    print("CUDA is not available; running on CPU only.")


# Test Ray-Sphere Intersection & Cartesian-Spherical Conversion

In [None]:
def PlotSphere(ax, r, n_u=50, n_v=20, color='b', alpha=0.3):
    """Draw a translucent sphere of radius ``r`` centered at the origin.

    Args:
        ax: a Matplotlib 3D axes (anything providing ``plot_surface``).
        r: sphere radius.
        n_u: number of samples of the azimuthal angle over [0, 2*pi].
        n_v: number of samples of the polar angle over [0, pi].
        color: surface color.
        alpha: surface opacity.
    """
    # np.pi (rather than the star-imported PI) keeps this helper self-contained.
    # A complex step in mgrid means "number of points", endpoints included.
    u, v = np.mgrid[0:2 * np.pi:n_u * 1j, 0:np.pi:n_v * 1j]
    x = np.cos(u) * np.sin(v) * r
    y = np.sin(u) * np.sin(v) * r
    z = np.cos(v) * r
    ax.plot_surface(x, y, z, rstride=1, cstride=1,
                    color=color, linewidth=0.5, alpha=alpha)


def PlotPlane(ax, r):
    """Draw a green 3x3 wireframe grid lying on the plot's z = 0 plane.

    The grid spans [-r, r] along both horizontal plot axes. The plotting code in
    this notebook puts the scene's y coordinate on the vertical plot axis, so this
    grid marks the scene's y = 0 plane.
    """
    ticks = np.linspace(-r, r, 3)
    grid_x, grid_y = np.meshgrid(ticks, ticks)
    flat = np.zeros_like(grid_x)
    ax.plot_wireframe(grid_x, grid_y, flat, color='g', linewidth=1)


# Cast one ray from the origin against a sphere of radius 0.5, then convert the
# hit point Cartesian -> spherical -> Cartesian and compare it with the original.
p = torch.tensor([[0.0, 0.0, 0.0]])  # ray origin, (1, 3)
v = torch.tensor([[0.0, -1.0, 1.0]])  # ray direction, (1, 3)
r = torch.tensor([[0.5]])  # sphere radius, (1, 1)
# Scale the direction to length 2r so the drawn ray segment pokes out of the sphere.
v = v / torch.norm(v) * r * 2
p_on_sphere_ = sphere.ray_sphere_intersect(p, v, r)[0][0]
print(p_on_sphere_)
print(p_on_sphere_.norm())  # expected to equal r for a sphere centered at the origin
spher_coord = sphere.cartesian2spherical(p_on_sphere_)
print(spher_coord[..., 1:3].rad2deg())  # presumably the two angular components
p_on_sphere = sphere.spherical2cartesian(spher_coord)
print(p_on_sphere_.size())
print('round-trip error:', (p_on_sphere - p_on_sphere_).abs().max().item())

fig = plt.figure(figsize=(8, 8))
# Figure.gca(projection=...) was removed in Matplotlib 3.6; add_subplot works everywhere.
ax = fig.add_subplot(projection='3d')
# Plot axes are (x, z, y): the scene's y coordinate is drawn vertically.
ax.set_xlabel('x')
ax.set_ylabel('z')
ax.set_zlabel('y')

radius = r.item()
PlotPlane(ax, radius)
PlotSphere(ax, radius)

hx, hy, hz = p_on_sphere_[0].tolist()  # direct ray/sphere intersection
rx, ry, rz = p_on_sphere[0].tolist()  # same point after the spherical round trip
p_ = p + v  # end point of the drawn ray segment
ox, oy, oz = p[0].tolist()
ex, ey, ez = p_[0].tolist()

ax.scatter([0], [0], [0], color="g", s=10)  # sphere center
ax.scatter([rx], [rz], [ry], color="r", s=10)  # round-trip reconstruction
ax.scatter([hx], [hz], [hy], color="y", s=10)  # direct intersection
ax.plot([ox, ex], [oz, ez], [oy, ey], color="r")  # the ray

# Dashed guides: vertical drop from the hit point to the y = 0 plane, its
# projection back to the center, and the radius from the hit point to the center.
ax.plot([hx, hx], [hz, hz], [0, hy], color="k", linestyle='--', linewidth=0.5)
ax.plot([hx, 0], [hz, 0], [0, 0], linewidth=0.5, linestyle="--", color="k")
ax.plot([hx, 0], [hz, 0], [hy, 0], linewidth=0.5, linestyle="--", color="k")

ax.set_xlim(-radius, radius)
ax.set_ylim(-radius, radius)
ax.set_zlim(-radius, radius)

plt.show()


# Test Dataset Loader

In [None]:
from data.spherical_view_syn import SphericalViewSynDataset
from data.loader import FastDataLoader

DATA_DESC_FILE = f'{rootdir}/data/__new/street_fovea_r360x80_t1.0/train1.json'

dataset = SphericalViewSynDataset(DATA_DESC_FILE)
data_loader = FastDataLoader(dataset, 4, shuffle=False)

# Show the patches of the first few batches in a 4 x 7 grid. Batches and plotted
# patches are counted separately: the original used a single counter `i` for both
# the subplot slot and the batch stop condition, so the two interfered (overwriting
# one subplot, or missing `i == 4` and running past the grid's 28 slots).
GRID_ROWS, GRID_COLS = 4, 7
MAX_BATCHES = 4

fig = plt.figure(figsize=(12, 6.5))
n_plotted = 0
for batch_idx, (indices, patches, rays_o, rays_d) in enumerate(data_loader):
    if batch_idx >= MAX_BATCHES:
        break
    print(batch_idx, patches.size(), rays_o.size(), rays_d.size())
    for k in range(len(indices)):
        if n_plotted >= GRID_ROWS * GRID_COLS:
            break
        plt.subplot(GRID_ROWS, GRID_COLS, n_plotted + 1)
        img.plot(patches[k])
        n_plotted += 1


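# Validate Dataset

Project a subsampled set of rays from the first 12 views onto spheres of radius 1, 3, 5 and 7, and color each hit point with its pixel, to inspect how the views line up on each sphere.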

In [None]:
from data.spherical_view_syn import SphericalViewSynDataset
from data.loader import FastDataLoader


#DATA_DESC_FILE = f'{rootdir}/data/pabellon_fovea_r40x40_t0.3/train.json'
#DATA_DESC_FILE = f'{rootdir}/data/gas_fovea_r80x60_t0.3_2021.01.26/train.json'
#DATA_DESC_FILE = f'{rootdir}/data/nerf_fern/train.json'
#DATA_DESC_FILE = f'{rootdir}/data/lobby_fovea_2021.01.18/train.json'
#DATA_DESC_FILE = f'{rootdir}/data/__new/street_fovea_r360x80_t1.0/train1.json'
#DATA_DESC_FILE = f'{rootdir}/data/__new/stones_fovea_r360x80_t1.0/train1.json'
#DATA_DESC_FILE = f'{rootdir}/data/__new/lobby_periph_r360x180_t1.0/train1.json'
DATA_DESC_FILE = f'{rootdir}/data/__new/classroom_all/nerf_cvt.json'

# NOTE(review): `misc` and `device` are not imported explicitly in this notebook;
# presumably they arrive via `from nets.msl_net import *` -- confirm.

# Treat every pixel as a 1x1 patch and batch res[0] * res[1] of them, so with
# shuffle=False each batch presumably holds exactly one view (verify with the loader).
dataset = SphericalViewSynDataset(DATA_DESC_FILE, load_views=range(12))
dataset.set_patch_size(1)
res = dataset.view_res
data_loader = FastDataLoader(dataset, res[0] * res[1], shuffle=False)

# Keep every 5th row and every 5th column of each view to thin the point cloud.
selector = torch.arange(res[0] * res[1]).reshape(res[0], res[1])[::5, ::5].flatten()

# Batches (views) to include in the plots.
idx_range = list(range(12))  # + list(range(24, 30)) + list(range(42, 48))

# Load the selected rays and pixels once; the original re-read the whole loader
# for every radius and grew arrays with np.concatenate inside the loop.
view_rays_o, view_rays_d, view_pixels = [], [], []
for idx, (indices, patches, rays_o, rays_d) in enumerate(data_loader):
    if idx not in idx_range:
        continue
    view_rays_o.append(rays_o[selector])
    view_rays_d.append(rays_d[selector])
    view_pixels.append(misc.torch2np(patches[selector]))
assert view_pixels, 'no views were selected from the dataset'
pixels = np.concatenate(view_pixels, axis=0)

for ri in range(0, 4):
    # Build the (1, 1) float radius once per sphere. The original rebound `r` inside
    # the view loop (re-wrapping the tensor on every view) and used an int dtype.
    r = torch.tensor([[float(ri * 2 + 1)]], device=device.default())
    p = np.concatenate([
        misc.torch2np(sphere.ray_sphere_intersect(o, d, r)[0].view(-1, 3))
        for o, d in zip(view_rays_o, view_rays_d)
    ], axis=0)

    plt.figure(facecolor='white', figsize=(20, 20))
    ax = plt.axes(projection='3d')
    # Plot axes are (x, z, y): the scene's y coordinate is drawn vertically.
    ax.set_xlabel('x')
    ax.set_ylabel('z')
    ax.set_zlabel('y')
    ax.set_title('r = %f' % r.item())
    ax.scatter([0], [0], [0], color="k", s=10)
    ax.scatter(p[:, 0], p[:, 2], p[:, 1], color=pixels, s=0.5)
    ax.view_init(elev=0, azim=-90)
