In [1]:
import sys
import os
import torch
import matplotlib.pyplot as plt

# Make the repository root (the parent of the notebook's directory) importable so
# the `utils` / `components` packages imported below resolve.
# os.path.join is used instead of string concatenation: when sys.path[0] is ''
# (an empty entry meaning "current directory"), `'' + '/../'` would collapse to
# the filesystem root '/' after abspath, rather than the parent of the cwd.
rootdir = os.path.abspath(os.path.join(sys.path[0], os.pardir))
sys.path.append(rootdir)
# Hard-coded GPU index 2; presumably fails on machines with fewer than three
# CUDA devices — adjust for the local setup.
torch.cuda.set_device(2)
print("Set CUDA:%d as current device." % torch.cuda.current_device())

from ..data.spherical_view_syn import *
from ..configs.spherical_view_syn import SphericalViewSynConfig
from utils import netio
from utils import img
from utils import device
from utils import view
from components.gen_final import GenFinal
from utils.perf import Perf


def load_net(path):
    """Build a network from the config encoded in a checkpoint filename and load its weights.

    The config id is the checkpoint filename with its extension stripped
    (e.g. ``<id>.pth`` -> ``<id>``); it is parsed by
    ``SphericalViewSynConfig.from_id``. ``perturb_sample`` is switched off,
    presumably so sampling is deterministic at inference time.

    Args:
        path: checkpoint file name, typically the result of ``find_file``.

    Returns:
        The network created by the config, moved to ``device.default()``,
        with weights loaded from ``path`` via ``netio.load``.

    Raises:
        FileNotFoundError: if ``path`` is None (e.g. ``find_file`` found no
            matching checkpoint in the current directory).
    """
    if path is None:
        # find_file returns None on no match; fail with a clear message instead
        # of a TypeError from slicing None.
        raise FileNotFoundError('No matching checkpoint found in %s' % os.getcwd())
    config = SphericalViewSynConfig()
    # Strip the extension whatever its length (previously exactly 4 chars were cut).
    config.from_id(os.path.splitext(path)[0])
    config.sa['perturb_sample'] = False
    config.print()
    net = config.create_net().to(device.default())
    netio.load(path, net)
    return net


def find_file(prefix):
    """Return the name of an entry in the current working directory starting with ``prefix``.

    Entries are scanned in sorted (lexicographic) order so the result is
    deterministic when several names match; ``os.listdir`` itself returns
    entries in arbitrary order.

    Args:
        prefix: filename prefix to match, e.g. ``'fovea'``.

    Returns:
        The lexicographically first matching entry name (relative to the cwd),
        or None if nothing matches.
    """
    return next((name for name in sorted(os.listdir()) if name.startswith(prefix)), None)


def load_views(data_desc_file) -> view.Trans:
    """Load camera poses from a dataset description JSON file.

    Reads the ``view_centers`` and ``view_rots`` entries, reshapes them to
    (N, 3) positions and (N, 3, 3) rotation matrices on ``device.default()``,
    and wraps them in a ``view.Trans``.

    Args:
        data_desc_file: path to the JSON description (e.g. ``nerf_views.json``).

    Returns:
        A ``view.Trans`` built from the centers and rotations.
    """
    # Imported locally: `json` is not imported explicitly anywhere in this
    # notebook and was only reachable through the wildcard import.
    import json

    with open(data_desc_file, 'r', encoding='utf-8') as file:
        data_desc = json.load(file)

    # Force the default float dtype so integer-valued JSON (e.g. an identity
    # rotation written as 0/1) does not yield int64 pose tensors.
    dtype = torch.get_default_dtype()
    view_centers = torch.tensor(
        data_desc['view_centers'], dtype=dtype, device=device.default()).view(-1, 3)
    view_rots = torch.tensor(
        data_desc['view_rots'], dtype=dtype, device=device.default()).view(-1, 3, 3)
    return view.Trans(view_centers, view_rots)


def plot_figures(images, center, res_full=None):
    """Plot each layer next to its raw render, then the blended image with a crosshair.

    Args:
        images: dict whose values are passed to ``img.plot``; must contain
            ``'<layer>_raw'`` and ``'<layer>'`` for each layer in
            ``fovea``, ``mid``, ``periph`` and ``blended``.
        center: (x, y) pixel offset from the centre of the full frame at which
            a green crosshair is drawn on the blended image (presumably the
            gaze point).
        res_full: full-frame resolution; index 1 is used along x and index 0
            along y, i.e. (height, width). Defaults to the notebook-level
            ``res_full`` global, which earlier versions read implicitly.
    """
    if res_full is None:
        # Backward compatibility: fall back to the notebook global.
        res_full = globals()['res_full']

    # One 8x4 figure per layer: raw render (left) vs. processed image (right).
    for layer in ('fovea', 'mid', 'periph'):
        plt.figure(figsize=(8, 4))
        plt.subplot(121)
        img.plot(images[layer + '_raw'])
        plt.subplot(122)
        img.plot(images[layer])

    # Plot Blended (larger figure, crosshair overlaid on the right panel)
    plt.figure(figsize=(12, 6))
    plt.subplot(121)
    img.plot(images['blended_raw'])
    plt.subplot(122)
    img.plot(images['blended'])

    # Crosshair position: centre of the full frame shifted by `center`.
    cx = (res_full[1] - 1) / 2 + center[0]
    cy = (res_full[0] - 1) / 2 + center[1]
    half = 5  # crosshair half-length in pixels
    plt.plot([cx - half, cx + half], [cy, cy], color=[0, 1, 0])
    plt.plot([cx, cx], [cy - half, cy + half], color=[0, 1, 0])

Set CUDA:2 as current device.


In [None]:
# Switch to the scene directory: checkpoints and nerf_views.json below are
# looked up relative to the cwd. Alternative scenes are kept commented out —
# enable exactly one os.chdir line.
os.chdir(os.path.join(rootdir, 'data/__0_user_study/us_gas_all_in_one'))
#os.chdir(os.path.join(rootdir, 'data/__0_user_study/us_mc_all_in_one'))
#os.chdir(os.path.join(rootdir, 'data/__0_user_study/lobby_all_in_one'))
print('Change working directory to ', os.getcwd())
# Inference only: disable autograd globally for the rest of the session.
torch.autograd.set_grad_enabled(False)

# Checkpoints are the cwd entries whose names start with 'fovea' / 'periph';
# load_net derives the network config from the checkpoint filename.
fovea_net = load_net(find_file('fovea'))
periph_net = load_net(find_file('periph'))

# Load Dataset (camera poses only)
views = load_views('nerf_views.json')
print('Dataset loaded.')

print('views:', views.size())
# NOTE(review): `ref_dataset` is not defined anywhere in this notebook.
#print('ref views:', ref_dataset.samples)

# Per-layer field of view and render resolution (three layers — presumably
# fovea / mid / periphery, matching the keys used by plot_figures), composited
# into a full frame of size res_full. res_full is also read as a global by
# plot_figures for the crosshair position.
fov_list = [20, 45, 110]
res_list = [(128, 128), (256, 256), (256, 230)]  # alternative last entry: (192, 256)
res_full = (1600, 1440)
gen = GenFinal(fov_list, res_list, res_full, fovea_net, periph_net,
               device=device.default())


In [None]:
# Smoke-test the fovea network's ray sampling and input encoding on a single
# view at the origin with an identity rotation, timing each stage with Perf.
test_view = view.Trans(
    torch.tensor([[0.0, 0.0, 0.0]], device=device.default()),
    torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]], device=device.default())
)
# NOTE(review): the meaning of Perf's two boolean flags is defined in
# utils.perf — TODO confirm.
perf = Perf(True, True)
# Rays from the first layer camera (index 0; presumably the 20-degree fovea
# layer). The trailing `True` flag's meaning is defined by get_global_rays —
# TODO confirm.
rays_o, rays_d = gen.layer_cams[0].get_global_rays(test_view, True)
perf.checkpoint("GetRays")
# Flatten ray origins/directions to (num_rays, 3) for the sampler.
rays_o = rays_o.view(-1, 3)
rays_d = rays_d.view(-1, 3)
coords, pts, depths = fovea_net.sampler(rays_o, rays_d)
perf.checkpoint("Sample")
encoded = fovea_net.input_encoder(coords)
perf.checkpoint("Encode")
print("Rays:", rays_d)
print("Spherical coords:", coords)
print("Depths:", depths)
print("Encoded:", encoded)
# NOTE(review): the commented-out code below refers to `images`, `center` and
# `view_idx`, none of which are defined in this notebook.
#plot_figures(images, center)

#os.makedirs('output/teasers', exist_ok=True)
#for key in images:
#    img.save(
#        images[key], 'output/teasers/view%04d_%s.png' % (view_idx, key))
