In [None]:
import sys
import os
import torch
import matplotlib.pyplot as plt

# Repository root: the parent of sys.path[0] (under a standard Jupyter kernel this is
# the notebook's directory — verify). It is added to sys.path so the project packages
# imported below (data, configs, utils, components) can be resolved.
# os.path.join avoids producing '/../' (the filesystem root) when sys.path[0] is ''.
rootdir = os.path.abspath(os.path.join(sys.path[0], os.pardir))
if rootdir not in sys.path:  # keep the cell idempotent when it is re-run
    sys.path.append(rootdir)

CUDA_DEVICE_INDEX = 2  # GPU used for rendering; adjust to the local machine
torch.cuda.set_device(CUDA_DEVICE_INDEX)
print("Set CUDA:%d as current device." % torch.cuda.current_device())
# This notebook only renders frames, so autograd is switched off globally.
torch.autograd.set_grad_enabled(False)

from data.spherical_view_syn import *
from configs.spherical_view_syn import SphericalViewSynConfig
from utils import netio
from utils import misc
from utils import img
from utils import device
from utils import view
from components.gen_final import GenFinal


def load_net(path):
    """Create a network from the config id encoded in a checkpoint filename and load it.

    Args:
        path: checkpoint filename (e.g. a value returned by ``find_file``). Its
            name without the extension is passed to ``SphericalViewSynConfig.from_id``.

    Returns:
        The network built by ``config.create_net()``, moved to ``device.default()``,
        with weights loaded from ``path`` via ``netio.load``.

    Raises:
        FileNotFoundError: if ``path`` is None (no checkpoint was found).
    """
    if path is None:
        # find_file() returns None when nothing matches; fail with a clear message
        # instead of "'NoneType' object is not subscriptable".
        raise FileNotFoundError('load_net: no checkpoint path given (path is None)')
    config = SphericalViewSynConfig()
    # The config id is the filename with its extension (of any length) removed.
    config.from_id(os.path.splitext(path)[0])
    # Overridden after from_id so this notebook always runs with it disabled.
    config.sa['perturb_sample'] = False
    net = config.create_net().to(device.default())
    netio.load(path, net)
    return net


def find_file(prefix, directory='.'):
    """Return the first entry in ``directory`` whose name starts with ``prefix``.

    Entries are examined in sorted order, so the result does not depend on the
    arbitrary order in which ``os.listdir`` reports them.

    Args:
        prefix: required start of the entry name.
        directory: directory to search; defaults to the current working directory.

    Returns:
        The matching entry name (not joined with ``directory``), or None if no
        entry matches.
    """
    return next(
        (name for name in sorted(os.listdir(directory)) if name.startswith(prefix)),
        None,
    )


def load_views(data_desc_file) -> tuple[view.Trans, torch.Tensor]:
    """Load paired gazes and view matrices from an HMD description file.

    The file is parsed as consecutive 7-line records. Within a record starting at
    line ``s``:

    * lines ``s + 1`` and ``s + 2`` each hold comma-separated floats for one gaze
      row (expected 3 values) — first and second view of the pair;
    * lines ``s + 3`` and ``s + 4`` each hold 16 comma-separated floats that are
      reshaped row-major into a 4x4 view matrix;
    * lines ``s``, ``s + 5`` and ``s + 6`` are not read.

    Trailing blank lines are ignored, and the last record may omit its two unread
    lines.

    Args:
        data_desc_file: path of the UTF-8 text file to read.

    Returns:
        ``(views, gazes)``: ``views`` is a ``view.Trans`` built from the first three
        entries of each matrix's last column and from its upper-left 3x3 block;
        ``gazes`` is a ``(2 * n_records, 3)`` tensor. Rows ``2k`` and ``2k + 1`` of
        both come from record ``k``. Both are moved to ``device.default()``.

    Raises:
        ValueError: if the last record is truncated before its fifth line.
    """
    record_len = 7

    def parse_row(line):
        # float() tolerates surrounding whitespace, including the trailing newline.
        return torch.tensor([float(value) for value in line.split(',')])

    with open(data_desc_file, 'r', encoding='utf-8') as file:
        lines = file.readlines()
    # A trailing empty line must not be mistaken for the start of a partial record.
    while lines and not lines[-1].strip():
        lines.pop()

    record_starts = range(0, len(lines), record_len)
    if record_starts and record_starts[-1] + 4 >= len(lines):
        raise ValueError(
            '%s: truncated record starting at line %d (need at least 5 lines)'
            % (data_desc_file, record_starts[-1] + 1))
    n = len(record_starts)

    gazes = torch.empty(n * 2, 3)
    views = torch.empty(n * 2, 4, 4)
    for record_idx, start in enumerate(record_starts):
        for pair_idx in range(2):
            row = record_idx * 2 + pair_idx
            gazes[row] = parse_row(lines[start + 1 + pair_idx])
            views[row] = parse_row(lines[start + 3 + pair_idx]).view(4, 4)
    gazes = gazes.to(device.default())
    views = views.to(device.default())
    return view.Trans(views[:, :3, 3], views[:, :3, :3]), gazes

# Per-layer settings handed to GenFinal together with the full output resolution.
# Entry i of fov_list presumably pairs with entry i of res_list — confirm in GenFinal.
fov_list = [20, 45, 110]  # field of view per layer (presumably degrees)
res_list = [
    (128, 128),
    (256, 256),
    (256, 230),
]
res_full = (1600, 1440)  # passed to GenFinal as res_full


In [None]:
# Scene directory (relative to rootdir) holding the network checkpoints and hmd.csv.
# Other scenes this notebook has been pointed at:
#   'data/__0_user_study/us_mc_all_in_one'
#   'data/bedroom_all_in_one'
#   'data/lobby_all_in_one'
#   'data/gallery_all_in_one'
scene_dir = 'data/__0_user_study/us_gas_all_in_one'
os.chdir(os.path.join(rootdir, scene_dir))
print('Change working directory to ', os.getcwd())

# Checkpoints are located by filename prefix in the scene directory.
fovea_path = find_file('fovea')
periph_path = find_file('periph')
if fovea_path is None or periph_path is None:
    raise FileNotFoundError(
        'Missing network checkpoint in %s (fovea: %s, periph: %s)'
        % (os.getcwd(), fovea_path, periph_path))
fovea_net = load_net(fovea_path)
periph_net = load_net(periph_path)

# Load Dataset
views, gazes = load_views('hmd.csv')
print('Dataset loaded.')
print('views:', views.size())

gen = GenFinal(fov_list, res_list, res_full, fovea_net, periph_net,
               device=device.default())
# Gaze centers for rendering; the loop below reads elements [0] and [1] of each row.
gaze_centers = gen.full_cam.proj(gazes, center_as_origin=True)


In [None]:
# Render every left/right pair (rows 2k and 2k+1 of views / gaze_centers) and save
# the two blended images side by side as one PNG per frame.
frame_dir = 'output/video_frames/hmd2'
os.makedirs(frame_dir, exist_ok=True)  # created once, not on every frame
# Integer division: range() raises TypeError on the float produced by `/`.
n_frames = gaze_centers.size(0) // 2
for view_idx in range(n_frames):
    left_idx = view_idx * 2
    right_idx = left_idx + 1
    left_center = (gaze_centers[left_idx][0].item(),
                   gaze_centers[left_idx][1].item())
    right_center = (gaze_centers[right_idx][0].item(),
                    gaze_centers[right_idx][1].item())
    left_view = views.get(left_idx)
    right_view = views.get(right_idx)
    # Shared mono view: average of the two views' `t`, with the left view's `r`.
    mono_trans = view.Trans((left_view.t + right_view.t) / 2, left_view.r)
    left_images = gen.gen(left_center, left_view, mono_trans=mono_trans)
    right_images = gen.gen(right_center, right_view, mono_trans=mono_trans)

    # Concatenate along the last dimension (presumably width): left | right.
    img.save(torch.cat([left_images['blended'], right_images['blended']], -1),
             '%s/view%04d.png' % (frame_dir, view_idx))
    print('Frame %d saved' % view_idx)
