{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import torchvision.transforms.functional as trans_f\n",
    "\n",
    "# Make the project root (parent of sys.path[0], assumed to be the notebook\n",
    "# directory) importable. os.path.join also copes with an empty sys.path[0].\n",
    "rootdir = os.path.abspath(os.path.join(sys.path[0], '..'))\n",
    "if rootdir not in sys.path:  # avoid duplicate entries when the cell is re-run\n",
    "    sys.path.append(rootdir)\n",
    "# GPU index to run on; adjust to match the local machine.\n",
    "CUDA_DEVICE = 2\n",
    "torch.cuda.set_device(CUDA_DEVICE)\n",
    "print(\"Set CUDA:%d as current device.\" % torch.cuda.current_device())\n",
    "\n",
    "from data.spherical_view_syn import *\n",
    "from configs.spherical_view_syn import SphericalViewSynConfig\n",
    "from utils import netio\n",
    "from utils import misc\n",
    "from utils import img\n",
    "from utils import device\n",
    "from utils import view\n",
    "from components.foveation import Foveation\n",
    "from components.gen_final import GenFinal\n",
    "\n",
    "\n",
    "def load_net(path):\n",
    "    \"\"\"Build the network described by a checkpoint file name and load its weights.\n",
    "\n",
    "    The config id is the checkpoint file name without its extension\n",
    "    (e.g. 'ID.pth' -> 'ID'). `perturb_sample` is switched off and the net is\n",
    "    moved to `device.default()` before the weights are loaded.\n",
    "\n",
    "    :param path: checkpoint file name, typically the result of `find_file`\n",
    "    :return: the loaded network\n",
    "    :raises FileNotFoundError: if `path` is None (no matching checkpoint found)\n",
    "    \"\"\"\n",
    "    if path is None:\n",
    "        raise FileNotFoundError('No checkpoint file found to load the network from')\n",
    "    config = SphericalViewSynConfig()\n",
    "    # splitext strips the extension whatever its length instead of assuming 4 chars\n",
    "    config.from_id(os.path.splitext(path)[0])\n",
    "    config.sa['perturb_sample'] = False\n",
    "    config.print()\n",
    "    net = config.create_net().to(device.default())\n",
    "    netio.load(path, net)\n",
    "    return net\n",
    "\n",
    "\n",
    "def find_file(prefix):\n",
    "    \"\"\"Return the first regular file in the current directory whose name starts\n",
    "    with `prefix`, or None if there is none.\n",
    "\n",
    "    Names are checked in sorted order so the result is deterministic when several\n",
    "    files match (os.listdir order is arbitrary); directories are skipped.\n",
    "    \"\"\"\n",
    "    for path in sorted(os.listdir()):\n",
    "        if path.startswith(prefix) and os.path.isfile(path):\n",
    "            return path\n",
    "    return None\n",
    "\n",
    "\n",
    "def load_views(data_desc_file) -> view.Trans:\n",
    "    \"\"\"Load the view poses listed in a JSON description file.\n",
    "\n",
    "    The file must contain `view_centers` and `view_rots`. If it also has a\n",
    "    `samples` list (presumably the view-grid dimensions) the poses are reshaped\n",
    "    to it; otherwise they are kept as a flat batch (shape [-1, ...]).\n",
    "\n",
    "    :param data_desc_file: path to the JSON description file\n",
    "    :return: `view.Trans` with centers viewed as samples + [3] and rotations\n",
    "             viewed as samples + [3, 3], on `device.default()`\n",
    "    \"\"\"\n",
    "    import json  # explicit, instead of relying on the star import above\n",
    "\n",
    "    with open(data_desc_file, 'r', encoding='utf-8') as file:\n",
    "        data_desc = json.load(file)\n",
    "    samples = data_desc.get('samples', [-1])\n",
    "    view_centers = torch.tensor(\n",
    "        data_desc['view_centers'], device=device.default()).view(samples + [3])\n",
    "    view_rots = torch.tensor(\n",
    "        data_desc['view_rots'], device=device.default()).view(samples + [3, 3])\n",
    "    return view.Trans(view_centers, view_rots)\n",
    "\n",
    "\n",
    "def read_ref_images(idx):\n",
    "    \"\"\"Load reference image(s) 'ref/view_NNNN.png' for the given view index.\n",
    "\n",
    "    A tensor `idx` with at least one dimension loads one image per element via\n",
    "    a single `img.load` call on the list of paths; anything else (an int or a\n",
    "    0-d tensor) loads a single image.\n",
    "    \"\"\"\n",
    "    pattern = 'ref/view_%04d.png'\n",
    "    is_batch = isinstance(idx, torch.Tensor) and idx.dim() > 0\n",
    "    if not is_batch:\n",
    "        return img.load(pattern % idx)\n",
    "    return img.load([pattern % i for i in idx])\n",
    "\n",
    "\n",
    "def adjust_cam(cam, vr_cam, gaze_center):\n",
    "    \"\"\"Set `cam`'s principal point `c` in place from a gaze offset.\n",
    "\n",
    "    `gaze_center` (x, y) is expressed for `vr_cam` and rescaled to `cam` by the\n",
    "    focal-length ratio; `cam.c` becomes half of `cam.res` (assumed to be\n",
    "    (height, width)) minus that rescaled offset. Returns None.\n",
    "    \"\"\"\n",
    "    half_width = cam.res[1] / 2\n",
    "    half_height = cam.res[0] / 2\n",
    "    offset_x = gaze_center[0] / vr_cam.f[0].item() * cam.f[0].item()\n",
    "    offset_y = gaze_center[1] / vr_cam.f[1].item() * cam.f[1].item()\n",
    "    cam.c[0] = half_width - offset_x\n",
    "    cam.c[1] = half_height - offset_y\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scene to evaluate. Alternatives: 'data/__0_user_study/us_mc_all_in_one',\n",
    "# 'data/bedroom_all_in_one'.\n",
    "scene_dir = 'data/__0_user_study/us_gas_all_in_one'\n",
    "# Resolve the scene against the directory the kernel started in, so re-running\n",
    "# this cell (or switching scene) does not chdir relative to a scene directory.\n",
    "if 'start_dir' not in globals():\n",
    "    start_dir = os.getcwd()\n",
    "os.chdir(os.path.join(start_dir, scene_dir))\n",
    "print('Change working directory to ', os.getcwd())\n",
    "# Inference only: no autograd graphs are needed.\n",
    "torch.autograd.set_grad_enabled(False)\n",
    "\n",
    "# Checkpoints are looked up by file-name prefix in the scene directory.\n",
    "fovea_net = load_net(find_file('fovea'))\n",
    "periph_net = load_net(find_file('periph'))\n",
    "\n",
    "# Load Dataset\n",
    "views = load_views('views.json')\n",
    "#ref_dataset = SphericalViewSynDataset('ref.json', load_images=False, calculate_rays=False)\n",
    "print('Dataset loaded.')\n",
    "\n",
    "print('views:', views.size())\n",
    "#print('ref views:', ref_dataset.samples)\n",
    "\n",
    "# Field of view (presumably degrees) and resolution of the fovea, mid and\n",
    "# periphery layers; res_full is the resolution of the final blended frame.\n",
    "fov_list = [20, 45, 110]\n",
    "res_list = [(128, 128), (256, 256), (256, 230)]  # alternative periphery res: (192, 256)\n",
    "res_full = (1600, 1440)\n",
    "\n",
    "gen = GenFinal(fov_list, res_list, res_full, fovea_net, periph_net, device.default())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gaze centres per test set: (left eye, right eye), each an (x, y) pixel offset\n",
    "# from the centre of the full-resolution frame. Sets 0-2 belong to the gas\n",
    "# scene, sets 3-5 to the mc scene; pick a set matching the loaded scene.\n",
    "GAZE_SETS = {\n",
    "    # ==gas==\n",
    "    0: ((-137, 64), (-142, 64)),\n",
    "    1: ((133, -44), (130, -44)),\n",
    "    2: ((-20, -5), (-25, -5)),\n",
    "    # ==mc==\n",
    "    3: ((-107, 80), (-112, 80)),\n",
    "    4: ((-17, -90), (-22, -90)),\n",
    "    5: ((95, 30), (91, 30)),\n",
    "}\n",
    "set_id = 2\n",
    "left_center, right_center = GAZE_SETS[set_id]\n",
    "\n",
    "# Central view of the view grid: half of every dimension of views.size(),\n",
    "# zero-padded to the five indices passed to views.get.\n",
    "view_coord = [0, 0, 0, 0, 0]\n",
    "for i, val in enumerate(views.size()):\n",
    "    view_coord[i] += val // 2\n",
    "print('view_coord:', view_coord)\n",
    "test_view = views.get(*view_coord)\n",
    "\n",
    "# One camera per layer (fovea, mid, periphery) with normalized cx = cy = 0.5.\n",
    "cams = [\n",
    "    view.Camera({\n",
    "        'fov': fov_list[i],\n",
    "        'cx': 0.5,\n",
    "        'cy': 0.5,\n",
    "        'normalized': True\n",
    "    }, res_list[i]).to(device.default())\n",
    "    for i in range(len(fov_list))\n",
    "]\n",
    "fovea_cam, mid_cam, periph_cam = cams[0], cams[1], cams[2]\n",
    "#guide_cam = ref_dataset.cam_params\n",
    "vr_cam = view.Camera({\n",
    "    'fov': fov_list[-1],\n",
    "    'cx': 0.5,\n",
    "    'cy': 0.5,\n",
    "    'normalized': True\n",
    "}, res_full)\n",
    "foveation = Foveation(fov_list, res_full, device=device.default())\n",
    "\n",
    "\n",
    "def _plot_crosshair(x, y, half_len=5, color=(0, 1, 0)):\n",
    "    \"\"\"Draw a small '+' marker centred at pixel (x, y) on the current axes.\"\"\"\n",
    "    plt.plot([x - half_len, x + half_len], [y, y], color=color)\n",
    "    plt.plot([x, x], [y - half_len, y + half_len], color=color)\n",
    "\n",
    "\n",
    "def _plot_pair(left_image, right_image, title, figsize=(8, 4),\n",
    "               left_mark=None, right_mark=None):\n",
    "    \"\"\"Show a left/right image pair side by side in a new titled figure.\n",
    "\n",
    "    `left_mark` / `right_mark`, when given, are (x, y) pixel positions at which\n",
    "    a green crosshair is drawn over the corresponding image.\n",
    "    \"\"\"\n",
    "    plt.figure(figsize=figsize)\n",
    "    plt.suptitle(title)\n",
    "    for position, image, mark in ((121, left_image, left_mark),\n",
    "                                  (122, right_image, right_mark)):\n",
    "        plt.subplot(position)\n",
    "        img.plot(image)\n",
    "        if mark is not None:\n",
    "            _plot_crosshair(*mark)\n",
    "\n",
    "\n",
    "def plot_figures(left_images, right_images, left_center, right_center):\n",
    "    \"\"\"Plot each layer of a stereo pair, left eye beside right eye.\n",
    "\n",
    "    A crosshair marks the image centre of the fovea layer, and the gaze point\n",
    "    (`left_center` / `right_center`, pixel offsets from the image centre) on\n",
    "    the blended full-resolution images.\n",
    "    \"\"\"\n",
    "    fovea_mid = ((fovea_cam.res[1] - 1) / 2, (fovea_cam.res[0] - 1) / 2)\n",
    "    full_mid = ((res_full[1] - 1) / 2, (res_full[0] - 1) / 2)\n",
    "\n",
    "    _plot_pair(left_images['fovea_raw'], right_images['fovea_raw'], 'Fovea raw')\n",
    "    _plot_pair(left_images['fovea'], right_images['fovea'], 'Fovea',\n",
    "               left_mark=fovea_mid, right_mark=fovea_mid)\n",
    "    _plot_pair(left_images['mid'], right_images['mid'], 'Mid')\n",
    "    _plot_pair(left_images['periph'], right_images['periph'], 'Periph')\n",
    "    _plot_pair(left_images['blended'], right_images['blended'], 'Blended',\n",
    "               figsize=(12, 6),\n",
    "               left_mark=(full_mid[0] + left_center[0], full_mid[1] + left_center[1]),\n",
    "               right_mark=(full_mid[0] + right_center[0], full_mid[1] + right_center[1]))\n",
    "\n",
    "\n",
    "# Horizontal offset of each eye from the central (mono) view; presumably half\n",
    "# the inter-pupillary distance in scene units - TODO confirm.\n",
    "EYE_OFFSET = 0.03\n",
    "\n",
    "\n",
    "def render_eye(gaze_center, x_offset):\n",
    "    \"\"\"Render one eye with `gen`, gazing at `gaze_center`.\n",
    "\n",
    "    The eye pose is `test_view.trans_point` applied to (x_offset, 0, 0), with\n",
    "    `test_view`'s rotation; `test_view` itself is passed as the mono reference.\n",
    "    Returns what `gen` returns with `ret_raw=True` (layer name -> image).\n",
    "    \"\"\"\n",
    "    eye_trans = view.Trans(\n",
    "        test_view.trans_point(\n",
    "            torch.tensor([x_offset, 0, 0], device=device.default())),\n",
    "        test_view.r)\n",
    "    return gen(gaze_center, eye_trans, ret_raw=True, mono_trans=test_view, shift=0)\n",
    "\n",
    "\n",
    "left_images = render_eye(left_center, -EYE_OFFSET)\n",
    "right_images = render_eye(right_center, EYE_OFFSET)\n",
    "\n",
    "plot_figures(left_images, right_images, left_center, right_center)\n",
    "\n",
    "os.makedirs('output/mono_test', exist_ok=True)\n",
    "for eye, images in (('l', left_images), ('r', right_images)):\n",
    "    for key in images:\n",
    "        img.save(images[key], 'output/mono_test/set%d_%s_%s.png' % (set_id, key, eye))\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.9 64-bit ('pytorch': conda)",
   "name": "python379jvsc74a57bd0660ca2a75467d3af74a68fcc6f40bc78ab96b99ff17d2f100b5ca821fbb183f2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  },
  "orig_nbformat": 2
 },
 "nbformat": 4,
 "nbformat_minor": 2
}