In [1]:
import sys
sys.path.append('/e/dengnc')

from typing import List
import torch
from torch import nn
import matplotlib.pyplot as plt
from deeplightfield.data.lf_syn import LightFieldSynDataset
from deeplightfield.my import util
from deeplightfield.trans_unet import LatentSpaceTransformer

# Compute device. The original experiments ran on GPU index 2; fall back to the
# highest available GPU index (or the CPU) so a fresh "Run All" also works on
# machines with fewer GPUs.
PREFERRED_GPU_INDEX = 2
if torch.cuda.is_available():
    device = torch.device(
        f"cuda:{min(PREFERRED_GPU_INDEX, torch.cuda.device_count() - 1)}")
else:
    device = torch.device("cpu")


# Test data loader

In [None]:
# Location of the synthetic light-field dataset and its training-split descriptor.
DATA_DIR = '../data/lf_syn_2020.12.23'
TRAIN_DATA_DESC_FILE = f'{DATA_DIR}/train.json'

train_dataset = LightFieldSynDataset(TRAIN_DATA_DESC_FILE)

# Shuffled mini-batch loader over the training split; here only its length
# (number of batches) is inspected.
LOADER_KWARGS = dict(batch_size=3,
                     num_workers=8,
                     pin_memory=True,
                     shuffle=True,
                     drop_last=False)
train_data_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                                **LOADER_KWARGS)
print(len(train_data_loader))

# Dump the dataset's camera / view / layer metadata for a quick sanity check.
for dataset_attr in ('cam_params', 'sparse_view_positions', 'diopter_of_layers'):
    print(getattr(train_dataset, dataset_attr))

# First sparse view: its image, then its depth tensor scaled by 10/255 for display
# (NOTE(review): presumably maps stored 8-bit depth to a viewable range — confirm).
first_sparse_image = train_dataset.sparse_view_images[0]
first_sparse_depth_vis = train_dataset.sparse_view_depths[0] / 255 * 10
plt.figure()
util.PlotImageTensor(first_sparse_image)
plt.figure()
util.PlotImageTensor(first_sparse_depth_vis)

# Test disparity warping (warp the sparse views to novel view positions)

In [3]:



# Warp the sparse input views (images + depths) to a few other view positions
# from the dataset.
# Indices into train_dataset.view_positions of the novel views to synthesize.
NOVEL_VIEW_INDICES = [13, 30, 57]

# NOTE(review): the transformer itself is not moved to `device`; if
# LatentSpaceTransformer stores tensors (e.g. cam params) internally they may
# also need to live on `device` — confirm.
transformer = LatentSpaceTransformer(train_dataset.sparse_view_images.size(2),
                                     train_dataset.cam_params,
                                     train_dataset.diopter_of_layers,
                                     train_dataset.sparse_view_positions)

# Keep every tensor input on the same device: images and depths are sent to
# `device`, so the target view positions must be as well.
novel_views = torch.stack(
    [train_dataset.view_positions[i] for i in NOVEL_VIEW_INDICES], dim=0
).to(device)

# Visualisation only — no autograd graph is needed for the warped images.
with torch.no_grad():
    trans_images = transformer(train_dataset.sparse_view_images.to(device),
                               train_dataset.sparse_view_depths.to(device),
                               novel_views)


In [None]:

# Blend the warped sparse views for the first novel view (trans_images[0]).
# A view "covers" a pixel when its values summed over dim 1 (presumably the
# colour channels — confirm) exceed 1e-5; the blend is the sum over all views
# (dim 0) divided by the number of covering views.
warped_first_novel = trans_images[0]
mask = (torch.sum(warped_first_novel, 1) > 1e-5).to(dtype=torch.float)
weight = torch.sum(mask, 0)
# weight is 0 at pixels no sparse view reaches; dividing there gave NaN (0 / 0).
# Counts are integers, so clamping to >= 1 only affects those uncovered pixels.
blended = torch.sum(warped_first_novel, 0) / weight.clamp(min=1).unsqueeze(0)

# Ground truth for the first novel view (dataset view 13, the first position
# passed to the transformer) next to the blended warp.
plt.figure(figsize=(6, 6))
util.PlotImageTensor(train_dataset.view_images[13])
plt.title('Ground truth (view 13)')
plt.figure(figsize=(6, 6))
util.PlotImageTensor(blended)
plt.title('Blended warped views')

# Top row: each sparse input view; bottom row: that view warped to the first
# novel view. The grid adapts to however many sparse views were warped
# (4 columns / figsize (12, 6) reproduces the original layout).
n_sparse_views = trans_images.size(1)
plt.figure(figsize=(3 * n_sparse_views, 6))
for view_idx in range(n_sparse_views):
    plt.subplot(2, n_sparse_views, view_idx + 1)
    util.PlotImageTensor(train_dataset.sparse_view_images[view_idx])
    plt.title(f'Sparse view {view_idx}')
    plt.subplot(2, n_sparse_views, n_sparse_views + view_idx + 1)
    util.PlotImageTensor(trans_images[0, view_idx])
    plt.title(f'Warped view {view_idx}')
