{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Per-layer rendering demo\n",
    "\n",
    "Load a trained `snerf_fast` model, render one test view, split the\n",
    "weighted sample colors into layers and plot (optionally save) them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import sys\n",
    "import os\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Make the project root importable regardless of the notebook's location.\n",
    "rootdir = os.path.abspath(sys.path[0] + '/../../')\n",
    "sys.path.append(rootdir)\n",
    "\n",
    "# Inference only: disable autograd globally to save memory and time.\n",
    "torch.autograd.set_grad_enabled(False)\n",
    "\n",
    "from model import Model\n",
    "from data import Dataset\n",
    "from utils import netio, img, device\n",
    "from utils.view import *\n",
    "from utils.types import *\n",
    "from components.render import render"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Configuration -------------------------------------------------------\n",
    "scene = \"gas\"\n",
    "model_path = f\"{rootdir}/data/__thesis/{scene}/_nets/train/snerf_fast\"\n",
    "dataset_path = f\"{rootdir}/data/__thesis/{scene}/test.json\"\n",
    "\n",
    "view_index = 6           # which test view to render\n",
    "samples_per_layer = 4    # TODO: derive from model.core.samples_per_field\n",
    "plot_cols = 2            # columns in the layer plot grid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_model(path: PathLike) -> Model:\n",
    "    \"\"\"Load the checkpoint found under `path` and return an eval-mode model.\"\"\"\n",
    "    ckpt_path = netio.find_checkpoint(Path(path))\n",
    "    # map_location keeps loading working on CPU-only machines even when the\n",
    "    # checkpoint was saved from a GPU run.\n",
    "    ckpt = torch.load(ckpt_path, map_location=device.default())\n",
    "    model = Model.create(ckpt[\"args\"][\"model\"], ckpt[\"args\"][\"model_args\"])\n",
    "    model.load_state_dict(ckpt[\"states\"][\"model\"])\n",
    "    model.to(device.default()).eval()\n",
    "    return model\n",
    "\n",
    "\n",
    "def load_dataset(path: PathLike, the_model: Model = None) -> Dataset:\n",
    "    \"\"\"Load a dataset whose color mode and coordinate system match the model.\n",
    "\n",
    "    Falls back to the notebook-global `model` when `the_model` is omitted,\n",
    "    so existing single-argument calls keep working.\n",
    "    \"\"\"\n",
    "    the_model = the_model if the_model is not None else model\n",
    "    return Dataset(path, color_mode=the_model.color, coord_sys=the_model.args.coord,\n",
    "                   device=device.default())\n",
    "\n",
    "\n",
    "def plot_images(images, rows: int, cols: int):\n",
    "    \"\"\"Show `images` in a rows x cols grid, filled in row-major order.\n",
    "\n",
    "    Extra grid slots are left empty if there are fewer images than slots.\n",
    "    \"\"\"\n",
    "    plt.figure(figsize=(20, int(20 / cols * rows)))\n",
    "    for idx in range(min(len(images), rows * cols)):\n",
    "        plt.subplot(rows, cols, idx + 1)\n",
    "        img.plot(images[idx])\n",
    "\n",
    "\n",
    "def save_images(images, scene: str, i: int):\n",
    "    \"\"\"Save each layer image as `<scene>_<i>(<layer>).png` under data/__demo/layers/.\"\"\"\n",
    "    outputdir = f'{rootdir}/data/__demo/layers/'\n",
    "    os.makedirs(outputdir, exist_ok=True)\n",
    "    for layer, image in enumerate(images):\n",
    "        img.save(image, f'{outputdir}{scene}_{i:04d}({layer}).png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = load_model(model_path)\n",
    "dataset = load_dataset(dataset_path, model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cam = dataset.cam\n",
    "view = Trans(dataset.centers[view_index], dataset.rots[view_index])\n",
    "output = render(model, cam, view, \"colors\", \"weights\")\n",
    "# Weight each sampled color by its rendering weight before summing per layer.\n",
    "output_colors = output.colors * output.weights\n",
    "\n",
    "n_samples = model.args.n_samples\n",
    "output_layers = [\n",
    "    output_colors[..., offset:offset + samples_per_layer, :].sum(-2)\n",
    "    for offset in range(0, n_samples, samples_per_layer)\n",
    "]\n",
    "\n",
    "# Grid size follows the actual number of layers instead of a hard-coded 8x2,\n",
    "# so changing n_samples or samples_per_layer cannot cause an IndexError.\n",
    "plot_rows = (len(output_layers) + plot_cols - 1) // plot_cols\n",
    "plot_images(output_layers, plot_rows, plot_cols)\n",
    "# save_images(output_layers, scene, view_index)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.0 ('dvs')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  },
  "metadata": {
   "interpreter": {
    "hash": "82066b63b621a9e3d15e3b7c11ca76da6238eff3834294910d715044bd0561e5"
   }
  },
  "vscode": {
   "interpreter": {
    "hash": "4469b029896260c1221afa6e0e6159922aafd2738570e75b7bc15e28db242604"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}