In [None]:
import sys
import os
sys.path.append(os.path.abspath(sys.path[0] + '/../../'))

import torch
import math
import matplotlib.pyplot as plt
import numpy as np
from deep_view_syn.my import util
from deep_view_syn.msl_net import *

# Select device: pin all later CUDA work in this notebook to one GPU.
CUDA_DEVICE_INDEX = 2
torch.cuda.set_device(CUDA_DEVICE_INDEX)
active_device = torch.cuda.current_device()
print("Set CUDA:%d as current device." % active_device)

# Test Ray-Sphere Intersection & Cartesian-Spherical Conversion

In [None]:
def PlotSphere(ax, r):
    """Draw a translucent blue sphere of radius ``r`` centred at the origin.

    Args:
        ax: a 3D matplotlib axes (anything providing ``plot_surface``).
        r: sphere radius.
    """
    # Parametric grid: azimuth in [0, 2*pi] (50 samples) x polar angle in [0, pi] (20 samples).
    azimuth, polar = np.mgrid[0:2 * np.pi:50j, 0:np.pi:20j]
    sin_polar = np.sin(polar)
    surface = (
        np.cos(azimuth) * sin_polar * r,
        np.sin(azimuth) * sin_polar * r,
        np.cos(polar) * r,
    )
    ax.plot_surface(*surface, rstride=1, cstride=1,
                    color='b', linewidth=0.5, alpha=0.3)


def PlotPlane(ax, r):
    """Draw the z = 0 plane over [-r, r] x [-r, r] as a coarse green wireframe.

    Args:
        ax: a 3D matplotlib axes (anything providing ``plot_wireframe``).
        r: half-extent of the square plane.
    """
    # Three ticks per side -> a 2 x 2 cell grid, enough to read as a plane.
    ticks = np.linspace(-r, r, 3)
    grid_x, grid_y = np.meshgrid(ticks, ticks)
    flat = grid_x * 0  # zero height everywhere
    ax.plot_wireframe(grid_x, grid_y, flat, color='g', linewidth=1)


# Cast one ray from the origin and compare two ways of finding where it meets a
# sphere of radius r: RaySphereIntersect directly, and the Cartesian ->
# spherical -> Cartesian round trip (RayToSpherical + SphericalToCartesian).
p = torch.tensor([[0.0, 0.0, 0.0]])   # ray origin, 1 x 3
v = torch.tensor([[0.0, -1.0, 1.0]])  # ray direction, 1 x 3
r = torch.tensor([[2.5]])             # sphere radius, 1 x 1
# Rescale the direction to length 2r so the drawn segment pokes through the sphere.
v = v / torch.norm(v) * r * 2
p_on_sphere_ = RaySphereIntersect(p, v, r)[0]
print(p_on_sphere_)
print(p_on_sphere_.norm())  # expected ~r, assuming the sphere is centred at the origin
spher_coord = RayToSpherical(p, v, r)
print(spher_coord[..., 1:3].rad2deg())
p_on_sphere = util.SphericalToCartesian(spher_coord)[0]

r_val = r.item()
sx, sy, sz = p_on_sphere[0].tolist()   # point from the spherical round trip
ix, iy, iz = p_on_sphere_[0].tolist()  # point from the direct intersection

fig = plt.figure(figsize=(6, 6))
# Figure.gca(projection=...) was deprecated in Matplotlib 3.4 and removed in 3.6.
ax = fig.add_subplot(projection='3d')
# World y is drawn on the vertical axis, so the plot axes are (x, z, y).
ax.set_xlabel('x')
ax.set_ylabel('z')
ax.set_zlabel('y')

PlotPlane(ax, r_val)
PlotSphere(ax, r_val)

ax.scatter([0], [0], [0], color="g", s=10)  # Center
ax.scatter([sx], [sz], [sy], color="r", s=10)  # Ray position (spherical round trip)
ax.scatter([ix], [iz], [iy], color="y", s=10)  # Ray position (direct intersection)

# The ray segment itself, from p to p + v.
p_ = p + v
px, py, pz = p[0].tolist()
ex, ey, ez = p_[0].tolist()
ax.plot([px, ex], [pz, ez], [py, ey], color="r")

# Dashed guides: vertical drop from the intersection to the y = 0 plane, its
# projection back to the centre, and the radius from the centre to the hit.
ax.plot([ix, ix], [iz, iz], [0, iy], color="k", linestyle='--', linewidth=0.5)
ax.plot([ix, 0], [iz, 0], [0, 0], linewidth=0.5, linestyle="--", color="k")
ax.plot([ix, 0], [iz, 0], [iy, 0], linewidth=0.5, linestyle="--", color="k")

ax.set_xlim(-r_val, r_val)
ax.set_ylim(-r_val, r_val)
ax.set_zlim(-r_val, r_val)

plt.show()


# Test Dataset Loader & View-Spherical Transform

In [None]:
from deep_view_syn.data.spherical_view_syn import FastSphericalViewSynDataset
from deep_view_syn.data.spherical_view_syn import FastDataLoader

DATA_DIR = '../data/sp_view_syn_2020.12.28'
TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'
N_PATCHES_TO_SHOW = 16  # fills the 4 x 4 subplot grid below

dataset = FastSphericalViewSynDataset(TRAIN_DATA_DESC_FILE)
dataset.set_patch_size((64, 64))
data_loader = FastDataLoader(dataset=dataset, batch_size=4, shuffle=False, drop_last=False)
print(len(dataset))
plt.figure()
i = 0
for indices, patches, rays_o, rays_d in data_loader:
    print(i, patches.size(), rays_o.size(), rays_d.size())
    for idx in range(len(indices)):
        # Guard inside the batch too, in case batch_size does not divide 16.
        if i >= N_PATCHES_TO_SHOW:
            break
        plt.subplot(4, 4, i + 1)
        util.PlotImageTensor(patches[idx])
        i += 1
    # Stop the outer loop as well: breaking only the inner loop would make the
    # next batch request subplot 17 of a 4 x 4 grid, which raises ValueError.
    if i >= N_PATCHES_TO_SHOW:
        break


In [None]:
from deep_view_syn.data.spherical_view_syn import SphericalViewSynDataset

DATA_DIR = '../data/sp_view_syn_2020.12.26'
TRAIN_DATA_DESC_FILE = f'{DATA_DIR}/train.json'
# (near, far) depths spanned by the sphere layers, and how many layers to place.
DEPTH_RANGE = (1, 10)
N_DEPTH_LAYERS = 10

def _GetSphereLayers(depth_range: Tuple[float, float], n_layers: int) -> torch.Tensor:
    """Build sphere-layer depths spaced uniformly in diopters (inverse depth).

    Args:
        depth_range: (near, far) depths.
        n_layers: number of layers between far and near. n_layers == 1 divides by zero.

    Returns:
        An (n_layers + 1) x 1 tensor on ``device.GetDevice()``: a first entry at
        depth 1e5 (presumably a far/background layer), followed by ``n_layers``
        depths running from ``far`` to ``near``.
    """
    near, far = depth_range
    min_diopter, max_diopter = 1 / far, 1 / near
    step = (max_diopter - min_diopter) / (n_layers - 1)
    depths = [1e5] + [1 / (min_diopter + step * k) for k in range(n_layers)]
    return torch.tensor(depths, device=device.GetDevice()).view(-1, 1)

# Full-view dataset (one item per view) and an MslNet built on its camera.
loader_kwargs = dict(batch_size=4, num_workers=8, pin_memory=True,
                     shuffle=True, drop_last=False)
train_dataset = SphericalViewSynDataset(TRAIN_DATA_DESC_FILE)
train_data_loader = torch.utils.data.DataLoader(dataset=train_dataset, **loader_kwargs)
print(len(train_data_loader))

print("view_res", train_dataset.view_res)
print("cam_params", train_dataset.cam_params)

sphere_layers = _GetSphereLayers(DEPTH_RANGE, N_DEPTH_LAYERS)
msl_net = MslNet(train_dataset.cam_params, sphere_layers, train_dataset.view_res)
msl_net = msl_net.to(device.GetDevice())
print("sphere layers", msl_net.rendering.sphere_layers)

# Take one batch of views: show up to four images, then draw where a sparse
# subset of their rays meets the unit sphere.
N_SHOWN = 4      # 2 x 2 image grid
RAY_STRIDE = 50  # keep every 50th pixel row/column of each view

plt.figure(figsize=(6, 6))
p = None
v = None
for _, view_images, ray_positions, ray_directions in train_data_loader:
    # Put rays on the same device as the radius tensor below (no-op if already there).
    p = ray_positions.to(device.GetDevice())
    v = ray_directions.to(device.GetDevice())
    for k in range(min(N_SHOWN, view_images.size(0))):
        plt.subplot(2, 2, k + 1)
        util.PlotImageTensor(view_images[k])
    break

n_views = p.size(0)  # may be smaller than batch_size for a tiny dataset
view_h, view_w = train_dataset.view_res


def _sparse(t: torch.Tensor) -> np.ndarray:
    """Reshape per-ray data to n_views x H x W x 3, subsample pixels, return n_views x n_rays x 3."""
    grid = t.reshape(n_views, view_h, view_w, 3)
    return grid[:, ::RAY_STRIDE, ::RAY_STRIDE, :].flatten(1, 2).cpu().numpy()


# A float radius avoids relying on int -> float promotion inside RayToSpherical.
unit_radius = torch.tensor([[1.0]], device=device.GetDevice())
hits = util.SphericalToCartesian(
    RayToSpherical(p.flatten(0, 1), v.flatten(0, 1), unit_radius))

hits = _sparse(hits)
dirs = _sparse(v)
# Start points of the drawn ray segments (the original loop indexed an unset
# `centers` variable and raised TypeError).
# NOTE(review): assumes ray_positions holds one origin per ray, broadcastable to
# ray_directions — TODO confirm against the dataset.
origins = _sparse(p.expand_as(v))

fig = plt.figure(figsize=(6, 6))
# Figure.gca(projection=...) was removed in Matplotlib 3.6; add_subplot is the supported API.
ax = fig.add_subplot(projection='3d')
# World y is drawn on the vertical axis, so the plot axes are (x, z, y).
ax.set_xlabel('x')
ax.set_ylabel('z')
ax.set_zlabel('y')

PlotSphere(ax, 1)

ax.scatter([0], [0], [0], color="k", s=10)  # Center

colors = ['r', 'g', 'b', 'y']
for i in range(n_views):
    color = colors[i % len(colors)]
    ax.scatter(hits[i, :, 0], hits[i, :, 2], hits[i, :, 1], color=color, s=3)
    for o, d in zip(origins[i], dirs[i]):
        ax.plot([o[0], o[0] + d[0]],
                [o[2], o[2] + d[2]],
                [o[1], o[1] + d[1]],
                color=color, linewidth=0.5, alpha=0.6)

ax.set_xlim(-1, 1)
ax.set_ylim(-1, 1)
ax.set_zlim(-1, 1)

plt.show()


# Test Sampler

In [None]:
from deep_view_syn.data.spherical_view_syn import SphericalViewSynDataset

DATA_DIR = '../data/sp_view_syn_2020.12.29_finetrans'
TRAIN_DATA_DESC_FILE = f'{DATA_DIR}/train.json'
# Keyword arguments for Sampler: depth interval, samples per ray, and the
# perturbation switch (presumably jitters sample positions; off here).
SAMPLE_PARAMS = dict(depth_range=(1, 5), n_samples=5, perturb_sample=False)

# One view per batch so each iteration below plots a single view's samples.
train_dataset = SphericalViewSynDataset(TRAIN_DATA_DESC_FILE)
train_data_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=1, num_workers=8,
    pin_memory=True, shuffle=True, drop_last=False)
print(len(train_data_loader))

for label, value in (("view_res", train_dataset.view_res),
                     ("cam_params", train_dataset.cam_params)):
    print(label, value)

sampler = Sampler(**SAMPLE_PARAMS)

N_VIEWS_TO_PLOT = 20
PIXEL_STRIDE = 30  # keep every 30th pixel row/column

fig = plt.figure(figsize=(12, 12))
# Figure.gca(projection=...) was removed in Matplotlib 3.6; add_subplot is the supported API.
ax = fig.add_subplot(projection='3d')
# World y is drawn on the vertical axis, so the plot axes are (x, z, y).
ax.set_xlabel('x')
ax.set_ylabel('z')
ax.set_zlabel('y')

# Flat indices of a sparse pixel grid, derived from the dataset's resolution
# instead of a hard-coded 256 x 256.
# NOTE(review): assumes pixels are flattened row-major over view_res = (H, W).
view_h, view_w = train_dataset.view_res
selector = np.arange(view_h * view_w).reshape(view_h, view_w)[::PIXEL_STRIDE, ::PIXEL_STRIDE].flatten()

for i, (_, pixels, p, v) in enumerate(train_data_loader):
    if i >= N_VIEWS_TO_PLOT:
        break
    p = p.to(device.GetDevice())
    v = v.to(device.GetDevice())
    # Sample points along each ray and keep the sparse subset: n_rays x n_samples x 3.
    pts = sampler(p, v)[0].squeeze().cpu().numpy()[selector]
    # Per-ray colours (assumes pixels is 1 x C x H x W — TODO confirm).
    colors = pixels.squeeze().permute(1, 2, 0).flatten(0, 1).cpu().numpy()[selector]
    for j in range(pts.shape[0]):
        ax.scatter(pts[j, :, 0], pts[j, :, 2], pts[j, :, 1], color=colors[j], s=0.7)

ax.scatter([0], [0], [0], color="k", s=10)  # Center

ax.set_xlim(-5, 5)
ax.set_ylim(-5, 5)
ax.set_zlim(-5, 5)
# For a top-down view use: ax.view_init(elev=90, azim=-90)

plt.show()

In [None]:
from deep_view_syn.data.spherical_view_syn import SphericalViewSynDataset

DATA_DIR = '../data/sp_view_syn_2020.12.26_rotonly'
TRAIN_DATA_DESC_FILE = f'{DATA_DIR}/train.json'
# (near, far) depths spanned by the sphere layers, and how many layers to place.
DEPTH_RANGE = (1, 10)
N_DEPTH_LAYERS = 10

def _GetSphereLayers(depth_range: Tuple[float, float], n_layers: int) -> torch.Tensor:
    """Build sphere-layer depths spaced uniformly in diopters (inverse depth).

    NOTE(review): duplicate of the helper defined in the earlier MslNet cell;
    consider defining it once in a shared cell.

    Returns:
        An (n_layers + 1) x 1 tensor on ``device.GetDevice()``: depth 1e5 first,
        then ``n_layers`` depths from ``depth_range[1]`` down to ``depth_range[0]``.
    """
    far_diopter = 1 / depth_range[1]
    near_diopter = 1 / depth_range[0]
    step = (near_diopter - far_diopter) / (n_layers - 1)  # n_layers == 1 divides by zero
    depths = [1e5]
    for k in range(n_layers):
        depths.append(1 / (far_diopter + step * k))
    return torch.tensor(depths, device=device.GetDevice()).view(-1, 1)

# Per-ray dataset: each item is a single ray, batched 4096 at a time.
train_dataset = SphericalViewSynDataset(TRAIN_DATA_DESC_FILE, ray_as_item=True)
train_data_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=4096, num_workers=8,
    pin_memory=True, shuffle=True, drop_last=False)
print(len(train_data_loader))

for label, value in (("view_res", train_dataset.view_res),
                     ("cam_params", train_dataset.cam_params)):
    print(label, value)

# No MslNet is built here; this cell only visualises ray directions.

N_BATCHES_TO_PLOT = 20

fig = plt.figure(figsize=(12, 12))
# Figure.gca(projection=...) was removed in Matplotlib 3.6; add_subplot is the supported API.
ax = fig.add_subplot(projection='3d')
# World y is drawn on the vertical axis, so the plot axes are (x, z, y).
ax.set_xlabel('x')
ax.set_ylabel('z')
ax.set_zlabel('y')

# Scatter unit ray directions, coloured by their pixel values, to show which
# directions the training rays cover.
for i, (_, pixels, ray_positions, ray_directions) in enumerate(train_data_loader):
    if i >= N_BATCHES_TO_PLOT:
        break
    dirs = (ray_directions / ray_directions.norm(dim=1, keepdim=True)).numpy()
    ax.scatter(dirs[:, 0], dirs[:, 2], dirs[:, 1], color=pixels.numpy(), s=0.1)

ax.scatter([0], [0], [0], color="k", s=10)  # Center

ax.set_xlim(-1, 1)
ax.set_ylim(-1, 1)
ax.set_zlim(-1, 1)
ax.view_init(elev=0, azim=-90)

plt.show()


# Test Spherical View Synthesis

In [None]:
import ipywidgets as widgets # 控件库
from IPython.display import display # 显示控件的方法
from deep_view_syn.data.spherical_view_syn import SphericalViewSynDataset
from deep_view_syn.spher_net import SpherNet
from deep_view_syn.my import netio

DATA_DIR = '../data/sp_view_syn_2020.12.28_small'
DATA_DESC_FILE = DATA_DIR + '/train.json'
NET_FILE = DATA_DIR + '/rgb_ray_b2048_encode10_fc256x8/model-epoch_300.pth'
N_ENCODE_DIM = 10
FC_PARAMS = dict(nf=256, n_layers=8, skips=[])
GRAY = False
ROT_ONLY = False
FOV = 20  # passed to util.Fov2Length; unit not visible here (presumably degrees)

out_res = (256, 256)
# Focal length in pixels; fy is negated (presumably to flip the image row axis — verify).
focal = out_res[0] / util.Fov2Length(FOV)
cam_params = dict(fx=focal, fy=-focal, cx=out_res[0] / 2, cy=out_res[1] / 2)
local_rays = util.GetLocalViewRays(cam_params, out_res, flatten=True).to(device.GetDevice())

# Build the network with the same configuration it was trained with, then load the checkpoint.
model = SpherNet(
    cam_params=cam_params,
    fc_params=FC_PARAMS,
    out_res=out_res,
    gray=GRAY,
    translation=not ROT_ONLY,
    encode_to_dim=N_ENCODE_DIM,
)
model = model.to(device.GetDevice())
netio.LoadNet(NET_FILE, model)

def _slider(widget_cls, description, value, lo, hi, step, readout_format):
    """Create a continuously-updating slider that shows its current value."""
    return widget_cls(description=description, value=value,
                      min=lo, max=hi, step=step,
                      continuous_update=True,
                      readout=True, readout_format=readout_format)


# Eye position offsets.
slider_x = _slider(widgets.FloatSlider, 'X', 0, -0.05, 0.05, 0.002, '.3f')
slider_y = _slider(widgets.FloatSlider, 'Y', 0, -0.025, 0.025, 0.002, '.3f')
slider_z = _slider(widgets.FloatSlider, 'Z', 0, -0.05, 0.05, 0.002, '.3f')
# View orientation angles in degrees (converted with math.radians in f).
slider_theta = _slider(widgets.IntSlider, 'θ', 90, 10, 170, 2, '.1f')
slider_phi = _slider(widgets.IntSlider, 'φ', 90, -70, 110, 2, '.1f')

plt.figure()

def f(x, y, z, theta, phi):
    """Render and plot the view at position (x, y, z) with angles (theta, phi) in degrees.

    Called by ``widgets.interactive_output`` whenever a slider changes.
    """
    print((x, y, z, theta, phi))
    # p: 1 x M x 3 — the same eye position repeated for every local ray.
    p = torch.tensor([[[x, y, z]]], device=device.GetDevice()).expand(-1, local_rays.size(0), -1)
    r = util.GetRotMatrix(math.radians(theta), math.radians(phi)).to(device.GetDevice())
    # v: 1 x M x 3 — local rays multiplied by the rotation matrix.
    v = torch.mm(local_rays, r).unsqueeze(0)
    # Inference only. Skipping autograd avoids building a graph on every slider
    # move, and keeps the result plottable (tensors that require grad cannot be
    # converted with .numpy()).
    with torch.no_grad():
        image = model(p, v)
    util.PlotImageTensor(image)

# Map each of f's parameters to the slider that drives it.
controls = {
    'x': slider_x,
    'y': slider_y,
    'z': slider_z,
    'theta': slider_theta,
    'phi': slider_phi,
}
out = widgets.interactive_output(f, controls)
display(*controls.values(), out)
