In [1]:
import sys
import os
import torch
import matplotlib.pyplot as plt
import torchvision.transforms.functional as trans_f

# Make the directory two levels above the notebook importable, so the
# `deeplightfield` package imports that follow can be resolved.
_repo_root = os.path.abspath(sys.path[0] + '/../../')
sys.path.append(_repo_root)

# Pin GPU 0 as the current CUDA device and report it.
_gpu_index = 0
torch.cuda.set_device(_gpu_index)
print("Set CUDA:%d as current device." % torch.cuda.current_device())

from deeplightfield.data.spherical_view_syn import *
from deeplightfield.msl_net import MslNet
from deeplightfield.configs.spherical_view_syn import SphericalViewSynConfig
from deeplightfield.my import netio
from deeplightfield.my import util
from deeplightfield.my import device
from deeplightfield.my import view
from deeplightfield.my.simple_perf import SimplePerf
from deeplightfield.my.foveation import Foveation


# Work from the dataset directory: checkpoint files ('<name>.pth') and the
# 'train.json' / 'ref.json' files opened later are given as relative paths.
DATASET_DIR = sys.path[0] + '/../data/sp_view_syn_2021.01.04_all_in_one'
os.chdir(DATASET_DIR)
print('Change working directory to ', os.getcwd())

# Inference only: switch autograd off (not scoped to a `with` block).
torch.autograd.set_grad_enabled(False)

# Colour-mode switch: False selects the RGB networks, True the grayscale ones.
GRAY = False

Set CUDA:0 as current device.
Change working directory to /e/dengnc/deeplightfield/data/sp_view_syn_2021.01.04_all_in_one


In [2]:
def load_net(name):
    """Build an MslNet for ``name`` and load its weights from ``name + '.pth'``.

    ``name`` is expected to look like ``'<prefix>@<config_name>'``; the text
    after the '@' selects the SphericalViewSynConfig to load. Sampling is forced
    to spherical with sample perturbation disabled, the config is printed, and
    the network is built in the colour mode given by the module-level ``GRAY``
    flag and moved to ``device.GetDevice()``.

    Returns:
        The MslNet with weights loaded by ``netio.LoadNet``.
    """
    config_name = name.split('@')[1]
    config = SphericalViewSynConfig()
    config.load_by_name(config_name)

    # Override sampling for deterministic spherical evaluation.
    sample_params = config.SAMPLE_PARAMS
    sample_params['spherical'] = True
    sample_params['perturb_sample'] = False
    config.print()

    net = MslNet(config.FC_PARAMS, sample_params, GRAY, config.N_ENCODE_DIM)
    net = net.to(device.GetDevice())
    netio.LoadNet(name + '.pth', net)
    return net


def read_ref_images(idx, gray=True):
    """Read one or more reference views from ``ref/view_XXXX.png``.

    Args:
        idx: a single view index (int or 0-dim tensor), or a tensor of indices
            of any rank. A multi-dimensional tensor is flattened, so its images
            are read in row-major order.
        gray: convert the result with ``rgb_to_grayscale``. The default (True)
            keeps the original always-grayscale behaviour.
            NOTE(review): the module-level GRAY flag is not consulted; pass
            ``gray=GRAY`` to match the colour mode of the loaded networks.

    Returns:
        The tensor produced by ``util.ReadImageTensor`` for the path(s),
        optionally converted to grayscale.
    """
    patt = 'ref/view_%04d.png'
    if isinstance(idx, torch.Tensor) and idx.dim() > 0:
        # Convert to plain ints. Iterating a >1-D tensor directly yields rows,
        # which cannot be %-formatted as a single index.
        paths = [patt % i for i in idx.flatten().tolist()]
    else:
        paths = patt % idx
    images = util.ReadImageTensor(paths)
    return trans_f.rgb_to_grayscale(images) if gray else images


# (fovea, periphery) checkpoint names for each colour mode; load_net uses the
# text after '@' as the config name.
NET_NAMES = {
    True: ('fovea@msl_coarse_gray1', 'periph@msl_gray_periph'),
    False: ('fovea@msl_coarse_rgb1', 'periph@msl_rgb_periph'),
}
fovea_name, periph_name = NET_NAMES[bool(GRAY)]
fovea_net = load_net(fovea_name)
periph_net = load_net(periph_name)

# Load Dataset: images, depths and rays are not loaded, only view metadata.
dataset_opts = dict(load_images=False, load_depths=False,
                    gray=GRAY, calculate_rays=False)
view_dataset = SphericalViewSynDataset('train.json', **dataset_opts)
ref_dataset = SphericalViewSynDataset('ref.json', **dataset_opts)
print('Dataset loaded.')

# One camera per foveation layer (fovea, mid, periphery), centred principal point.
fov_list = [10, 60, 110]
res_list = [(64, 64), (256, 256), (256, 256)]
cams = [
    view.CameraParam({"fov": fov, "cx": 0.5, "cy": 0.5, "normalized": True},
                     res).to(device.GetDevice())
    for fov, res in zip(fov_list, res_list)
]
fovea_cam, mid_cam, periph_cam = cams
ref_cam_params = ref_dataset.cam_params

# Flat view indices reshaped to each dataset's sample grid, so a grid
# coordinate tuple can be mapped to a view index.
indices = torch.arange(view_dataset.n_views,
                       device=device.GetDevice()).view(view_dataset.samples)
ref_indices = torch.arange(ref_dataset.n_views,
                           device=device.GetDevice()).view(ref_dataset.samples)


==== Config msl_coarse_rgb1 ====
Net type: msl
Encode dim: 10
Full-connected network parameters: {'nf': 64, 'n_layers': 12, 'skips': []}
Sample parameters {'spherical': True, 'depth_range': (1, 20), 'n_samples': 16, 'perturb_sample': False, 'lindisp': True, 'inverse_r': True}
Load net from fovea@msl_coarse_rgb1.pth ...
==== Config msl_rgb_periph ====
Net type: msl
Encode dim: 10
Full-connected network parameters: {'nf': 64, 'n_layers': 8, 'skips': []}
Sample parameters {'spherical': True, 'depth_range': (1, 50), 'n_samples': 4, 'perturb_sample': False, 'lindisp': True, 'inverse_r': True}
Load net from periph@msl_rgb_periph.pth ...
Dataset loaded.


In [None]:
# Render one view through the three-layer foveated pipeline, then blend it.
# fovea_net renders the narrow central layer. periph_net renders the mid and
# periphery layers together in a single batched call. Foveation blends the
# three layers, and the result is plotted and written to output/.

# Select the view one step past the centre of the first sample axis.
view_coord = [val // 2 for val in view_dataset.samples]
view_coord[0] = view_coord[0] + 1
print(view_coord, indices.size())
view_idx = indices[tuple(view_coord)]
view_o = view_dataset.view_centers[view_idx]  # (3)
view_r = view_dataset.view_rots[view_idx]  # (3, 3)
# NOTE(review): (1440, 1440) is presumably the blended output resolution — confirm.
foveation = Foveation(fov_list, (1440, 1440), device=device.GetDevice())

perf = SimplePerf(True, True)

fovea_rays_o, fovea_rays_d = fovea_cam.get_global_rays(view_o, view_r)  # (H_fovea, W_fovea, 3)
mid_rays_o, mid_rays_d = mid_cam.get_global_rays(view_o, view_r)  # (H_mid, W_mid, 3)
periph_rays_o, periph_rays_d = periph_cam.get_global_rays(view_o, view_r)  # (H_periph, W_periph, 3)
# Mid and periphery rays are stacked into one periph_net call and reshaped with
# mid_cam.res below, so both layers must share a resolution.
if mid_rays_o.shape != periph_rays_o.shape:
    raise ValueError(
        'mid and periph layers must share a resolution to be inferred together, '
        'got %s and %s' % (tuple(mid_rays_o.shape), tuple(periph_rays_o.shape)))
mid_periph_rays_o = torch.stack([mid_rays_o, periph_rays_o], 0)
mid_periph_rays_d = torch.stack([mid_rays_d, periph_rays_d], 0)
perf.Checkpoint('Get rays')

perf1 = SimplePerf(True, True)

# The network output is flat over rays; restore the image grid and put channels first.
fovea_inferred = fovea_net(fovea_rays_o.view(-1, 3), fovea_rays_d.view(-1, 3)).view(
    fovea_cam.res[0], fovea_cam.res[1], -1).permute(2, 0, 1)  # (C, H_fovea, W_fovea)
perf1.Checkpoint('Infer fovea')

# Flat output over all mid + periph rays; reshaped to (2, H, W, C), then
# permuted to (2, C, H, W).
periph_mid_inferred = periph_net(mid_periph_rays_o.view(-1, 3),
                                 mid_periph_rays_d.view(-1, 3))
periph_mid_inferred = periph_mid_inferred.view(
    2, mid_cam.res[0], mid_cam.res[1], -1).permute(0, 3, 1, 2)
mid_inferred = periph_mid_inferred[0]  # (C, H_mid, W_mid)
periph_inferred = periph_mid_inferred[1]  # (C, H_periph, W_periph)
perf1.Checkpoint('Infer mid & periph')

perf.Checkpoint('Infer')

# Blend the three layers, each with a leading size-1 dimension added.
blended = foveation.synthesis([
    fovea_inferred[None, ...],
    mid_inferred[None, ...],
    periph_inferred[None, ...]
])

perf.Checkpoint('Blend')

# Show the three inferred layers side by side, then the blended result.
plt.figure(figsize=(12, 4))
plt.set_cmap('Greys_r')
plt.subplot(1, 3, 1)
util.PlotImageTensor(fovea_inferred)
plt.subplot(1, 3, 2)
util.PlotImageTensor(mid_inferred)
plt.subplot(1, 3, 3)
util.PlotImageTensor(periph_inferred)

plt.figure(figsize=(12, 12))
util.PlotImageTensor(blended)

util.CreateDirIfNeed('output')
util.WriteImageTensor(blended, 'output/blended_%04d.png' % view_idx)