{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import sys\n",
    "import os\n",
    "\n",
    "# Make the parent of sys.path[0] (the repository root) importable. os.path.join keeps this\n",
    "# relative to the cwd when sys.path[0] is '' (plain string concatenation would yield '/'),\n",
    "# and the membership check stops re-runs from appending duplicate entries.\n",
    "rootdir = os.path.abspath(os.path.join(sys.path[0], os.pardir))\n",
    "if rootdir not in sys.path:\n",
    "    sys.path.append(rootdir)\n",
    "\n",
    "from utils.voxels import *\n",
    "\n",
    "# NOTE(review): the z bounds are [1, 0] (min > max), so voxel_size[2] is negative -- confirm intentional.\n",
    "bbox, steps = torch.tensor([[-2, -3.14159, 1], [2, 3.14159, 0]]), torch.tensor([2, 3, 3])\n",
    "voxel_size = (bbox[1] - bbox[0]) / steps\n",
    "voxels = init_voxels(bbox, steps)\n",
    "corners, corner_indices = get_corners(voxels, bbox, steps)\n",
    "# Grid-cell -> voxel lookup offset by one: slot 0 holds the -1 'no voxel' sentinel and slot k + 1\n",
    "# holds k (identity, assuming init_voxels returns voxels in flat grid order). Overwritten in the next cell.\n",
    "voxel_indices_in_grid = torch.arange(-1, voxels.shape[0])\n",
    "# Embedding whose weights are the corner positions; used to check trilinear interpolation below.\n",
    "emb = torch.nn.Embedding(corners.shape[0], 3, _weight=corners)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([11, 3]) tensor([ 0, -1, -1,  1, -1, -1,  2,  3,  4, -1,  5,  6, -1,  7,  8, -1,  9, 10])\n"
     ]
    }
   ],
   "source": [
    "# Prune a fixed set of grid cells to make the voxel set sparse. This cell rebinds voxels and\n",
    "# corner_indices, so it only works on the unpruned grid from the first cell -- fail loudly on a re-run.\n",
    "n_grid_cells = steps.prod().item()\n",
    "assert voxels.shape[0] == n_grid_cells, 'voxels already pruned: re-run the first cell before this one'\n",
    "keeps = torch.ones(n_grid_cells, dtype=torch.bool)\n",
    "keeps[torch.tensor([1, 2, 4, 5, 9, 12, 15])] = False\n",
    "voxels = voxels[keeps]\n",
    "corner_indices = corner_indices[keeps]\n",
    "# Rebuild the grid-cell -> voxel lookup: slot 0 keeps the -1 sentinel, flat grid index k is stored\n",
    "# at k + 1, and cells that were pruned stay -1.\n",
    "grid_indices = to_grid_indices(voxels, bbox, steps)\n",
    "voxel_indices_in_grid = grid_indices.new_full([n_grid_cells + 1], -1)\n",
    "voxel_indices_in_grid[grid_indices + 1] = torch.arange(voxels.shape[0])\n",
    "print(voxels.shape, voxel_indices_in_grid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([88, 3]) torch.Size([185, 3]) torch.Size([88, 8])\n"
     ]
    }
   ],
   "source": [
    "# Subdivide each kept voxel and collect the corners of the resulting finer grid.\n",
    "# The same factor scales the grid resolution passed to get_corners, so keep it in one place.\n",
    "split_factor = 2\n",
    "new_voxels = split_voxels(voxels, voxel_size, split_factor, align_border=False).reshape(-1, 3)\n",
    "new_corners, new_corner_indices = get_corners(new_voxels, bbox, steps * split_factor)\n",
    "print(new_voxels.shape, new_corners.shape, new_corner_indices.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([ 0,  0, -1,  0,  0, -1,  1,  1, -1,  1,  1, -1,  2,  2,  3,  3,  4,  4,\n",
      "         4,  2,  2,  3,  3,  4,  4,  4,  2,  2,  3,  3,  4,  4,  4,  0,  0, -1,\n",
      "         0,  0, -1,  1,  1, -1,  1,  1, -1,  2,  2,  3,  3,  4,  4,  4,  2,  2,\n",
      "         3,  3,  4,  4,  4,  2,  2,  3,  3,  4,  4,  4, -1, -1,  5,  5,  6,  6,\n",
      "         6, -1, -1,  5,  5,  6,  6,  6, -1, -1,  7,  7,  8,  8,  8, -1, -1,  7,\n",
      "         7,  8,  8,  8, -1, -1,  9,  9, 10, 10, 10, -1, -1,  9,  9, 10, 10, 10,\n",
      "        -1, -1,  9,  9, 10, 10, 10,  5,  5,  6,  6,  6,  5,  5,  6,  6,  6,  7,\n",
      "         7,  8,  8,  8,  7,  7,  8,  8,  8,  9,  9, 10, 10, 10,  9,  9, 10, 10,\n",
      "        10,  9,  9, 10, 10, 10,  5,  5,  6,  6,  6,  5,  5,  6,  6,  6,  7,  7,\n",
      "         8,  8,  8,  7,  7,  8,  8,  8,  9,  9, 10, 10, 10,  9,  9, 10, 10, 10,\n",
      "         9,  9, 10, 10, 10])\n",
      "tensor(0)\n"
     ]
    }
   ],
   "source": [
    "# Map each new corner to the kept voxel whose grid cell contains it. Grid coords are clamped to at\n",
    "# most steps - 1 (presumably so corners on the upper bbox face land in the last cell), and the lookup\n",
    "# is offset by one because slot 0 of voxel_indices_in_grid holds the -1 sentinel.\n",
    "new_corner_grid_coords = to_grid_coords(new_corners, bbox, steps).min(steps - 1)\n",
    "voxel_indices_of_new_corner = voxel_indices_in_grid[to_flat_indices(new_corner_grid_coords, steps) + 1]\n",
    "print(voxel_indices_of_new_corner)\n",
    "# -1 means the corner's grid cell has no kept voxel. Using -1 as an index would silently select the\n",
    "# last voxel, so count those corners and check only the ones with a kept parent.\n",
    "has_parent = voxel_indices_of_new_corner >= 0\n",
    "parent_indices = voxel_indices_of_new_corner[has_parent]\n",
    "print('corners without a kept parent:', (~has_parent).sum().item())\n",
    "checked_corners = new_corners[has_parent]\n",
    "# Local coords inside the parent cell; the + .5 presumes voxels holds cell centres (TODO confirm).\n",
    "p_of_new_corners = (checked_corners - voxels[parent_indices]) / voxel_size + .5\n",
    "interp_corners = trilinear_interp(p_of_new_corners, emb(corner_indices[parent_indices]))\n",
    "# Compare magnitudes: a one-sided `diff > tol` test misses errors where the interpolation overshoots.\n",
    "print(((checked_corners - interp_corners).abs() > 1e-6).sum())"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "08b118544df3cb8970a671e5837a88fd458f4d4c799ef1fb2709465a22a45b92"
  },
  "kernelspec": {
   "display_name": "Python 3.9.5 64-bit ('base': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}