In [None]:
%matplotlib inline
import sys
import os
import torch
import matplotlib.pyplot as plt

rootdir = os.path.abspath(sys.path[0] + '/../../')
sys.path.append(rootdir)

torch.autograd.set_grad_enabled(False)

from model import Model
from data import Dataset
from utils import netio, img, device
from utils.view import *
from utils.types import *
from components.render import render


# Module-level handles used by the helpers below; `load_dataset` reads
# `model` from here, so `model` must be assigned before it is called.
model: "Model | None" = None
dataset: "Dataset | None" = None


def load_model(path: PathLike):
    """Load a trained model from the checkpoint found under ``path``.

    The checkpoint is located with ``netio.find_checkpoint`` and is expected
    to contain ``args.model`` / ``args.model_args`` (used to build the model)
    and ``states.model`` (its state dict).

    Returns:
        The model, moved to ``device.default()`` and switched to eval mode.
    """
    ckpt_path = netio.find_checkpoint(Path(path))
    # Map all tensors to CPU while loading so a checkpoint saved on a GPU can
    # be read on any machine; the model is moved to the target device below.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    net = Model.create(ckpt["args"]["model"], ckpt["args"]["model_args"])
    net.load_state_dict(ckpt["states"]["model"])
    net.to(device.default()).eval()
    return net


def load_dataset(path: PathLike):
    """Open the dataset at ``path`` configured to match the loaded model.

    The color mode and coordinate system are taken from the module-level
    ``model``, so ``load_model`` must have been called (and its result stored
    in ``model``) beforehand. Data is placed on ``device.default()``.
    """
    color_mode = model.color
    coord_sys = model.args.coord
    return Dataset(
        path,
        color_mode=color_mode,
        coord_sys=coord_sys,
        device=device.default(),
    )


def plot_images(images, rows, cols):
    """Plot ``images`` in row-major order on a ``rows`` x ``cols`` grid.

    The figure is 20 inches wide, and its height is scaled so every grid cell
    is square. If ``images`` has fewer entries than ``rows * cols``, the
    remaining cells are left empty. Extra images beyond the grid are ignored.
    """
    # Keep the height as a float: int() truncation could produce a
    # zero-height figure when rows is small relative to cols.
    plt.figure(figsize=(20, 20 / cols * rows))
    for index, image in enumerate(images[:rows * cols]):
        plt.subplot(rows, cols, index + 1)
        img.plot(image)


def save_images(images, scene, i):
    """Write each image in ``images`` as a PNG for view ``i`` of ``scene``.

    Files go to ``{rootdir}/data/__demo/layers/`` (created if missing) and are
    named ``{scene}_{i:04d}({layer}).png``, where ``layer`` is the image's
    position in ``images``.
    """
    outputdir = f'{rootdir}/data/__demo/layers/'
    os.makedirs(outputdir, exist_ok=True)
    for layer, image in enumerate(images):
        img.save(image, f'{outputdir}{scene}_{i:04d}({layer}).png')

# Which scene to inspect; both the trained network and its test split live
# under data/__thesis/<scene>/.
scene = "gas"
model_path, dataset_path = (
    f"{rootdir}/data/__thesis/{scene}/_nets/train/snerf_fast",
    f"{rootdir}/data/__thesis/{scene}/test.json",
)

# Order matters: load_dataset reads the global `model` set on the next line.
model = load_model(model_path)
dataset = load_dataset(dataset_path)


# Render test view `i` and split its per-sample colors into groups ("layers")
# of consecutive samples.
i = 6  # index of the test view to inspect
cam = dataset.cam
view = Trans(dataset.centers[i], dataset.rots[i])
output = render(model, cam, view, "colors", "weights")
# Scale each sample's color by its weight so that summing a group of samples
# below yields that group's weighted contribution.
output_colors = output.colors * output.weights

# NOTE: hard-coded group size; model.core.samples_per_field was considered
# as the source of this value.
samples_per_layer = 4
n_samples = model.args.n_samples
# Sum each run of `samples_per_layer` samples along axis -2 (assumed to be
# the sample axis of output_colors — TODO confirm) into one image per layer.
output_layers = [
    output_colors[..., offset:offset + samples_per_layer, :].sum(-2)
    for offset in range(0, n_samples, samples_per_layer)
]

# Derive the grid from the actual number of layers instead of assuming 8 x 2
# (which only matches n_samples == 64).
n_cols = 2
n_rows = (len(output_layers) + n_cols - 1) // n_cols
plot_images(output_layers, n_rows, n_cols)
# save_images(output_layers, scene, i)  # uncomment to also write PNGs