In [1]:
import sys
import os
import torch
import matplotlib.pyplot as plt
import torchvision.transforms.functional as trans_f

# Make the directory two levels above sys.path[0] importable so the
# `deep_view_syn` package named in __package__ below can be resolved
# (presumably sys.path[0] is this notebook's directory -- TODO confirm).
sys.path.append(os.path.abspath(sys.path[0] + '/../../'))
# Declare the package context; the relative imports (`from ..data ...`) that
# follow need it in order to resolve.
__package__ = "deep_view_syn.notebook"
# NOTE(review): the GPU index is hard-coded to 2; this presumably fails on
# hosts that expose fewer than three CUDA devices -- confirm before sharing.
torch.cuda.set_device(2)
print("Set CUDA:%d as current device." % torch.cuda.current_device())

from ..data.spherical_view_syn import *
from ..configs.spherical_view_syn import SphericalViewSynConfig
from ..my import netio
from ..my import util
from ..my import device
from ..my import view
from ..my.foveation import Foveation
from ..my.gen_final import GenFinal


def load_net(path):
    """Build a network from the config encoded in a checkpoint file name and load its weights.

    Args:
        path: checkpoint file name. The last four characters are dropped
            (".pth" for the files used in this notebook) and the remainder
            is handed to ``SphericalViewSynConfig.from_id``.

    Returns:
        The network created by the config, moved to ``device.GetDevice()``
        and passed through ``netio.LoadNet(path, ...)``.
    """
    cfg = SphericalViewSynConfig()
    net_id = path[:-4]
    cfg.from_id(net_id)
    # Sample perturbation is switched off for this notebook's rendering.
    cfg.SAMPLE_PARAMS['perturb_sample'] = False
    cfg.print()

    model = cfg.create_net()
    model = model.to(device.GetDevice())
    netio.LoadNet(path, model)
    return model


def find_file(prefix):
    """Return the first entry of the current working directory whose name starts with ``prefix``.

    Entries are examined in the order ``os.listdir()`` yields them.
    Returns ``None`` when no entry matches.
    """
    candidates = (name for name in os.listdir() if name.startswith(prefix))
    return next(candidates, None)


def load_views(data_desc_file) -> view.Trans:
    """Load view centers and rotations from a JSON description file.

    The JSON object must contain ``'view_centers'`` and ``'view_rots'``.
    An optional ``'samples'`` list gives the leading dimensions of the view
    grid; when it is absent, ``[-1]`` is used so the leading dimension is
    inferred by ``Tensor.view``.

    Args:
        data_desc_file: path of the UTF-8 encoded JSON description file.

    Returns:
        A ``view.Trans`` built from the centers reshaped to
        ``samples + [3]`` and the rotations reshaped to ``samples + [3, 3]``,
        both created on ``device.GetDevice()``.

    Raises:
        KeyError: if ``'view_centers'`` or ``'view_rots'`` is missing.
    """
    # Imported locally: no explicit `import json` is visible in this notebook,
    # so the name would otherwise only exist if a wildcard import leaks it.
    import json

    # Only parsing needs the file handle; release it before building tensors.
    with open(data_desc_file, 'r', encoding='utf-8') as file:
        data_desc = json.load(file)

    samples = list(data_desc.get('samples', [-1]))
    view_centers = torch.tensor(
        data_desc['view_centers'], device=device.GetDevice()).view(samples + [3])
    view_rots = torch.tensor(
        data_desc['view_rots'], device=device.GetDevice()).view(samples + [3, 3])
    return view.Trans(view_centers, view_rots)


def read_ref_images(idx):
    """Read reference image(s) named ``ref/view_%04d.png`` for the given index/indices.

    Args:
        idx: a single index, or a ``torch.Tensor`` with at least one
            dimension whose elements are each formatted into a file name.

    Returns:
        Whatever ``util.ReadImageTensor`` returns: it is called with a list
        of paths for a non-scalar tensor, otherwise with a single path.
    """
    path_template = 'ref/view_%04d.png'
    is_index_batch = isinstance(idx, torch.Tensor) and idx.dim() > 0
    if not is_index_batch:
        return util.ReadImageTensor(path_template % idx)
    paths = [path_template % i for i in idx]
    return util.ReadImageTensor(paths)


def adjust_cam(cam, vr_cam, gaze_center):
    """Shift ``cam.c`` in place so the camera is centered on a gaze position.

    The gaze offset is divided by ``vr_cam.f`` and multiplied by ``cam.f``
    per axis (presumably converting from VR-camera pixels to ``cam`` pixels
    via the focal-length ratio -- TODO confirm), then subtracted from the
    middle of ``cam.res``.

    Args:
        cam: camera whose ``c[0]`` and ``c[1]`` are overwritten; its ``f``
            elements must expose ``.item()`` and ``res`` is indexed as
            ``res[1]`` for axis 0 and ``res[0]`` for axis 1.
        vr_cam: reference camera; only ``f[0]`` and ``f[1]`` are read.
        gaze_center: indexable pair holding the gaze offset along axes 0 and 1.

    Returns:
        None. ``cam`` is modified in place.
    """
    offset_x = gaze_center[0] / vr_cam.f[0].item() * cam.f[0].item()
    offset_y = gaze_center[1] / vr_cam.f[1].item() * cam.f[1].item()

    half_extent_x = cam.res[1] / 2
    cam.c[0] = half_extent_x - offset_x
    half_extent_y = cam.res[0] / 2
    cam.c[1] = half_extent_y - offset_y


Set CUDA:2 as current device.


In [2]:
# Switch into the dataset directory: find_file() and load_views() below use
# paths relative to the current working directory. Alternative datasets are
# kept as commented-out lines; enable exactly one os.chdir call.
os.chdir(sys.path[0] + '/../data/__0_user_study/us_gas_all_in_one')
#os.chdir(sys.path[0] + '/../data/__0_user_study/us_mc_all_in_one')
#os.chdir(sys.path[0] + '/../data/bedroom_all_in_one')
print('Change working directory to ', os.getcwd())
# Gradient tracking is disabled globally for the rest of the notebook.
torch.autograd.set_grad_enabled(False)

# Checkpoints are located by file-name prefix inside the dataset directory.
# NOTE(review): find_file() returns None when nothing matches, which makes
# load_net() fail on `path[:-4]`.
fovea_net = load_net(find_file('fovea'))
periph_net = load_net(find_file('periph'))

# Load Dataset
views = load_views('views.json')
#ref_dataset = SphericalViewSynDataset('ref.json', load_images=False, calculate_rays=False)
print('Dataset loaded.')

print('views:', views.size())
#print('ref views:', ref_dataset.samples)

# One entry per layer; fov_list[i] is paired with res_list[i] when the cameras
# are built in a later cell. Later code uses res[1] for horizontal and res[0]
# for vertical plot coordinates, so the tuples presumably read
# (height, width) -- TODO confirm.
fov_list = [20, 45, 110]
res_list = [(128, 128), (256, 256), (256, 230)]  # (192,256)]
res_full = (1600, 1440)

# Renderer that receives both networks plus the layer configuration above.
gen = GenFinal(fov_list, res_list, res_full, fovea_net, periph_net, device.GetDevice())

Change working directory to  /home/dengnc/deep_view_syn/data/__0_user_study/us_gas_all_in_one
==== Config fovea ====
Net type:  nmsl
Encode dim:  10
Optimizer decay:  0
Normalize:  False
Direction as input:  False
Full-connected network parameters: {'nf': 128, 'n_layers': 4, 'skips': []}
Sample parameters {'spherical': True, 'depth_range': (1.0, 50.0), 'n_samples': 32, 'perturb_sample': False, 'lindisp': True, 'inverse_r': True}
Load net from fovea@nmsl-rgb_e10_fc128x4_d1-50_s32.pth ...
==== Config periph ====
Net type:  nnmsl
Encode dim:  10
Optimizer decay:  0
Normalize:  False
Direction as input:  False
Full-connected network parameters: {'nf': 64, 'n_layers': 4, 'skips': []}
Sample parameters {'spherical': True, 'depth_range': (1.0, 50.0), 'n_samples': 16, 'perturb_sample': False, 'lindisp': True, 'inverse_r': True}
Load net from periph@nnmsl-rgb_e10_fc64x4_d1-50_s16.pth ...
Dataset loaded.
views: [5, 5, 5, 5, 5]


In [None]:
# Gaze-center presets keyed by set id; each value is
# (left_center, right_center). Previously every preset was assigned in
# sequence, so only the last one took effect; select one explicitly here.
GAZE_PRESETS = {
    # ==gas==
    0: ((-137, 64), (-142, 64)),
    1: ((133, -44), (130, -44)),
    2: ((-20, -5), (-25, -5)),
    # ==mc== (presumably for the us_mc_all_in_one dataset -- TODO confirm)
    3: ((-107, 80), (-112, 80)),
    4: ((-17, -90), (-22, -90)),
    5: ((95, 30), (91, 30)),
}

# The preset to render; also used in the output file names.
set_id = 2
left_center, right_center = GAZE_PRESETS[set_id]

# Take the middle index along every axis of the view grid, however many
# axes views.size() reports (previously hard-coded to exactly five).
view_coord = [val // 2 for val in views.size()]
print('view_coord:', view_coord)
test_view = views.get(*view_coord)

# One camera per layer, pairing fov_list[i] with res_list[i].
cams = [
    view.CameraParam({
        "fov": fov_list[i],
        "cx": 0.5,
        "cy": 0.5,
        "normalized": True
    }, res_list[i]).to(device.GetDevice())
    for i in range(len(fov_list))
]
fovea_cam, mid_cam, periph_cam = cams[0], cams[1], cams[2]
#guide_cam = ref_dataset.cam_params
# Full-resolution camera using the widest field of view (last fov_list entry).
vr_cam = view.CameraParam({
    'fov': fov_list[-1],
    'cx': 0.5,
    'cy': 0.5,
    'normalized': True
}, res_full)
foveation = Foveation(fov_list, res_full, device=device.GetDevice())


def _plot_cross(x, y, half_len=5, color=(0, 1, 0)):
    """Draw a crosshair centered at (x, y) on the current matplotlib axes."""
    plt.plot([x - half_len, x + half_len], [y, y], color=color)
    plt.plot([x, x], [y - half_len, y + half_len], color=color)


def _plot_pair(left_images, right_images, key, figsize,
               left_cross=None, right_cross=None):
    """Open a new figure showing left/right ``key`` images side by side, with optional crosshairs."""
    plt.figure(figsize=figsize)
    plt.subplot(121)
    util.PlotImageTensor(left_images[key])
    if left_cross is not None:
        _plot_cross(*left_cross)
    plt.subplot(122)
    util.PlotImageTensor(right_images[key])
    if right_cross is not None:
        _plot_cross(*right_cross)


def plot_figures(left_images, right_images, left_center, right_center):
    """Plot every rendered layer of a stereo pair as side-by-side figures.

    Five figures are created in order: 'fovea_raw', 'fovea' (with a
    crosshair at the image center of ``fovea_cam``), 'mid', 'periph' and
    'blended' (with a crosshair at the center of ``res_full`` shifted by
    the respective gaze center).

    Args:
        left_images: mapping with keys 'fovea_raw', 'fovea', 'mid',
            'periph' and 'blended' for the left eye.
        right_images: same mapping for the right eye.
        left_center: (x, y) gaze offset from the full-image center, left eye.
        right_center: (x, y) gaze offset from the full-image center, right eye.
    """
    # Pixel-center coordinates; res is indexed [1] for x and [0] for y.
    fovea_mid = ((fovea_cam.res[1] - 1) / 2, (fovea_cam.res[0] - 1) / 2)
    full_mid_x = (res_full[1] - 1) / 2
    full_mid_y = (res_full[0] - 1) / 2

    # Plot Fovea raw
    _plot_pair(left_images, right_images, 'fovea_raw', (8, 4))

    # Plot Fovea
    _plot_pair(left_images, right_images, 'fovea', (8, 4),
               left_cross=fovea_mid, right_cross=fovea_mid)

    # Plot Mid
    _plot_pair(left_images, right_images, 'mid', (8, 4))

    # Plot Periph
    _plot_pair(left_images, right_images, 'periph', (8, 4))

    # Plot Blended
    _plot_pair(
        left_images, right_images, 'blended', (12, 6),
        left_cross=(full_mid_x + left_center[0], full_mid_y + left_center[1]),
        right_cross=(full_mid_x + right_center[0], full_mid_y + right_center[1]))


# Horizontal offset of each eye from test_view, applied as -/+ along x.
# NOTE(review): presumably half the inter-pupillary distance in scene
# units -- confirm.
EYE_OFFSET_X = 0.03


def _render_eye(gaze_center, eye_offset_x):
    """Render one eye: test_view shifted by ``eye_offset_x`` along x, same rotation."""
    eye_trans = view.Trans(
        test_view.trans_point(
            torch.tensor([eye_offset_x, 0, 0], device=device.GetDevice())
        ),
        test_view.r
    )
    return gen(
        gaze_center,
        eye_trans,
        ret_raw=True,
        mono_trans=test_view,
        shift=0)


left_images = _render_eye(left_center, -EYE_OFFSET_X)
right_images = _render_eye(right_center, EYE_OFFSET_X)

plot_figures(left_images, right_images, left_center, right_center)

# Save every returned layer of both eyes; files are named
# set<set_id>_<layer>_<l|r>.png, left eye first.
util.CreateDirIfNeed('output/mono_test')
for eye_suffix, eye_images in (('l', left_images), ('r', right_images)):
    for key in eye_images:
        util.WriteImageTensor(
            eye_images[key],
            'output/mono_test/set%d_%s_%s.png' % (set_id, key, eye_suffix))
