In [None]:
import sys
import os
import torch
import matplotlib.pyplot as plt

# Project root is taken to be the parent directory of sys.path[0]; it is
# appended to sys.path ahead of the project-local imports that follow
# (data, configs, utils, components).
rootdir = os.path.abspath(sys.path[0] + '/../')
sys.path.append(rootdir)
# Make CUDA device index 2 the current device for this process.
# NOTE(review): the index is hard-coded; presumably this fails on hosts that
# expose fewer than three GPUs -- confirm and consider making it configurable.
torch.cuda.set_device(2)
print("Set CUDA:%d as current device." % torch.cuda.current_device())

from data.spherical_view_syn import *
from configs.spherical_view_syn import SphericalViewSynConfig
from utils import netio
from utils import misc
from utils import img
from utils import device
from utils import view
from components.gen_final import GenFinal


def load_net(path):
    """Create a network from a checkpoint file name and load the checkpoint into it.

    The configuration id is ``path`` with its last four characters removed
    (e.g. a dot plus a three-letter extension). The config's
    ``sa['perturb_sample']`` entry is forced to ``False`` and the config is
    printed before the network is created.

    :param path: checkpoint file name/path; also the source of the config id
    :return: the object produced by ``config.create_net()``, moved to
        ``device.default()`` and passed through ``netio.load(path, net)``
    """
    cfg = SphericalViewSynConfig()
    config_id = path[:-4]  # drop the trailing 4 characters of the file name
    cfg.from_id(config_id)
    cfg.sa['perturb_sample'] = False
    cfg.print()

    model = cfg.create_net()
    model = model.to(device.default())
    netio.load(path, model)
    return model


def find_file(prefix):
    """Return the name of an entry in the current working directory that starts with ``prefix``.

    ``os.listdir`` returns entries in arbitrary order, so the candidates are
    sorted first: when several entries share the prefix, the lexicographically
    smallest one is returned and the result is reproducible across runs and
    file systems.

    :param prefix: leading string the entry name must start with
    :return: the matching entry name, or ``None`` when nothing matches
    """
    return next(
        (name for name in sorted(os.listdir()) if name.startswith(prefix)),
        None)


def load_views(data_desc_file) -> 'view.Trans':
    """Read view centers and rotations from a JSON description file.

    The JSON object must contain ``'view_centers'`` and ``'view_rots'``; an
    optional ``'samples'`` list gives the leading dimensions of both tensors
    (``[-1]`` is used when it is absent, i.e. a single inferred dimension).
    Centers are reshaped to ``samples + [3]`` and rotations to
    ``samples + [3, 3]``; both tensors are created on ``device.default()``.

    :param data_desc_file: path of the UTF-8 encoded JSON file
    :return: ``view.Trans(view_centers, view_rots)``
    :raises KeyError: if ``'view_centers'`` or ``'view_rots'`` is missing
    """
    # Imported locally: the notebook never imports json explicitly, so the
    # name would otherwise only be available if a star import happens to leak it.
    import json

    with open(data_desc_file, 'r', encoding='utf-8') as file:
        data_desc = json.load(file)

    samples = data_desc.get('samples', [-1])
    target_device = device.default()
    view_centers = torch.tensor(
        data_desc['view_centers'], device=target_device).view(samples + [3])
    view_rots = torch.tensor(
        data_desc['view_rots'], device=target_device).view(samples + [3, 3])
    return view.Trans(view_centers, view_rots)


def plot_cross(center, res):
    """Draw a green cross on the current matplotlib axes.

    The crossing point is ``((res[1] - 1) / 2 + center[0],
    (res[0] - 1) / 2 + center[1])``, i.e. ``center`` is an (x, y) offset from
    the middle of an extent whose horizontal size is ``res[1]`` and vertical
    size is ``res[0]``. Each arm extends 5 units from the crossing point.

    :param center: indexable pair, ``center[0]`` = x offset, ``center[1]`` = y offset
    :param res: indexable pair, ``res[0]`` = vertical size, ``res[1]`` = horizontal size
    """
    cross_x = (res[1] - 1) / 2 + center[0]
    cross_y = (res[0] - 1) / 2 + center[1]
    arm = 5
    green = [0, 1, 0]

    # Horizontal stroke, then vertical stroke.
    plt.plot([cross_x - arm, cross_x + arm], [cross_y, cross_y], color=green)
    plt.plot([cross_x, cross_x], [cross_y - arm, cross_y + arm], color=green)


def plot_figures(left_images, right_images, left_center, right_center):
    """Plot left/right image pairs side by side, one figure per layer.

    Figures are created in this order for the keys ``'fovea_raw'``,
    ``'fovea'``, ``'mid'``, ``'periph'`` and ``'blended'``. The ``'fovea'``
    pair gets a cross at offset ``(0, 0)``; the ``'blended'`` pair gets crosses
    at ``left_center`` / ``right_center``. For both, the resolution used to
    place the crosses is taken from the LEFT image only
    (``left_images[key].size()[-2:]``).

    :param left_images: mapping from layer key to the left-eye image
    :param right_images: mapping from layer key to the right-eye image
    :param left_center: (x, y) cross offset for the left blended image
    :param right_center: (x, y) cross offset for the right blended image
    """

    def draw_pair(key, figsize, crosses=None):
        # One figure with the left image in subplot 121 and the right in 122;
        # `crosses` is an optional (left_offset, right_offset) pair.
        plt.figure(figsize=figsize)

        plt.subplot(121)
        img.plot(left_images[key])
        if crosses is not None:
            resolution = left_images[key].size()[-2:]
            plot_cross(crosses[0], resolution)

        plt.subplot(122)
        img.plot(right_images[key])
        if crosses is not None:
            plot_cross(crosses[1], resolution)

    small = (8, 4)
    draw_pair('fovea_raw', small)
    draw_pair('fovea', small, crosses=((0, 0), (0, 0)))
    draw_pair('mid', small)
    draw_pair('periph', small)
    draw_pair('blended', (12, 6), crosses=(left_center, right_center))


In [None]:
# Select the scene directory (exactly one of the chdir lines is active).
# Everything below resolves files relative to this working directory:
# find_file() scans it and 'views.json' is opened from it.
#os.chdir(os.path.join(rootdir, 'data/__0_user_study/us_gas_all_in_one'))
os.chdir(os.path.join(rootdir, 'data/__0_user_study/us_mc_all_in_one'))
#os.chdir(os.path.join(rootdir, 'data/__0_user_study/bedroom_all_in_one'))
print('Change working directory to ', os.getcwd())
# Turn off gradient tracking globally for the rest of the kernel session.
torch.autograd.set_grad_enabled(False)

# Load the two networks from the files whose names start with 'fovea' and
# 'periph'. NOTE(review): find_file returns None when nothing matches, in
# which case load_net fails on `path[:-4]`.
fovea_net = load_net(find_file('fovea'))
periph_net = load_net(find_file('periph'))

# Load Dataset
views = load_views('views.json')
#ref_dataset = SphericalViewSynDataset('ref.json', load_images=False, calculate_rays=False)
print('Dataset loaded.')

print('views:', views.size())
#print('ref views:', ref_dataset.samples)

# Per-layer settings handed to GenFinal.
# NOTE(review): presumably fields of view in degrees and (rows, cols)
# resolutions for the fovea / mid / periph layers, plus the full output
# resolution -- confirm against components.gen_final.GenFinal.
fov_list = [20, 45, 110]
res_list = [(128, 128), (256, 256), (256, 230)] # (192,256)]
res_full = (1600, 1440)

gen = GenFinal(fov_list, res_list, res_full, fovea_net, periph_net, device.default())

In [None]:
# (left, right) center pairs, one entry per set; each center is an (x, y)
# pair that is passed as the first argument to `gen`.
centers = [
    # ==gas==
    [(-137, 64), (-142, 64)],
    [(133, -44), (130, -44)],
    [(-20, -5), (-25, -5)],
    # ==mc==
    [(-107, 80), (-112, 80)],
    [(-17, -90), (-22, -90)],
    [(95, 30), (91, 30)],
]
set_id = 5

# Five-entry coordinate: the middle index along every axis reported by
# views.size(); any remaining entries stay 0.
view_coord = [0] * 5
for axis, extent in enumerate(views.size()):
    view_coord[axis] = extent // 2
print('view_coord:', view_coord)
test_view = views.get(*view_coord)


def _shifted_view(x_offset):
    """Return a view.Trans located at test_view.trans_point((x_offset, 0, 0)) with test_view's rotation."""
    offset = torch.tensor([x_offset, 0, 0], device=device.default())
    return view.Trans(test_view.trans_point(offset), test_view.r)


# Render both eyes: same rotation, positions offset by -0.03 / +0.03 in x.
left_center, right_center = centers[set_id]
left_images = gen(left_center, _shifted_view(-0.03),
                  mono_trans=test_view, ret_raw=True)
right_images = gen(right_center, _shifted_view(0.03),
                   mono_trans=test_view, ret_raw=True)

# plot_figures(left_images, right_images, centers[set_id][0], centers[set_id][1])

# Write every returned image to output/, left eye first, then right eye.
misc.create_dir('output')
for images, name_pattern in (
        (left_images, 'output/set%d_%s_l.png'),
        (right_images, 'output/set%d_%s_r.png')):
    for key in images:
        img.save(images[key], name_pattern % (set_id, key))
