Commit 172b5205 authored by BobYeah's avatar BobYeah
Browse files

spher_net (for fix view point) implemented and tested

parent 3e1a5b04
import torch
import torchvision.transforms.functional as trans_f
import json
from ..my import util
from ..my import imgio
class SphericalViewSynDataset(torch.utils.data.dataset.Dataset):
"""
Data loader for spherical view synthesis task
Attributes
--------
data_dir ```str```: the directory of dataset\n
view_file_pattern ```str```: the filename pattern of view images\n
cam_params ```object```: camera intrinsic parameters\n
view_centers ```Tensor(N, 3)```: centers of views\n
view_rots ```Tensor(N, 3, 3)```: rotation matrices of views\n
view_images ```Tensor(N, 3, H, W)```: images of views\n
"""
def __init__(self, dataset_desc_path: str, load_images: bool = True, gray: bool = False,
ray_as_item=False):
"""
Initialize data loader for spherical view synthesis task
The dataset description file is a JSON file with following fields:
- view_file_pattern: string, the path pattern of view images
- view_res: { "x", "y" }, the resolution of view images
- cam_params: { "fx", "fy", "cx", "cy" }, the focal and center of camera (in normalized image space)
- view_centers: [ [ x, y, z ], ... ], centers of views
- view_rots: [ [ m00, m01, ..., m22 ], ... ], rotation matrices of views
:param dataset_desc_path ```str```: path to the data description file
:param load_images ```bool```: whether load view images and return in __getitem__()
:param gray ```bool```: whether convert view images to grayscale
:param ray_as_item ```bool```: whether to treat each ray in view as an item
"""
self.data_dir = dataset_desc_path.rsplit('/', 1)[0] + '/'
self.load_images = load_images
self.ray_as_item = ray_as_item
# Load dataset description file
with open(dataset_desc_path, 'r', encoding='utf-8') as file:
data_desc = json.loads(file.read())
self.view_file_pattern: str = self.data_dir + \
data_desc['view_file_pattern']
self.view_res = (data_desc['view_res']['y'],
data_desc['view_res']['x'])
self.cam_params = data_desc['cam_params']
self.view_centers = torch.tensor(data_desc['view_centers']) # (N, 3)
self.view_rots = torch.tensor(data_desc['view_rots']) \
.view(-1, 3, 3) # (N, 3, 3)
# Load view images
if load_images:
self.view_images = util.ReadImageTensor(
[self.view_file_pattern % i for i in range(self.view_centers.size(0))])
if gray:
self.view_images = trans_f.rgb_to_grayscale(self.view_images)
else:
self.view_images = None
local_view_rays = util.GetLocalViewRays(self.cam_params,
self.view_res,
flatten=True) # (M, 3)
# Transpose matrix so we can perform vec x mat
view_rots_t = self.view_rots.permute(0, 2, 1)
# ray_positions & ray_directions are both (N, M, 3)
self.ray_positions = self.view_centers.unsqueeze(1) \
.expand(-1, local_view_rays.size(0), -1)
self.ray_directions = torch.matmul(local_view_rays, view_rots_t)
# Flatten rays if ray_as_item = True
if ray_as_item:
self.view_pixels = self.view_images.permute(
0, 2, 3, 1).flatten(0, 2)
self.ray_positions = self.ray_positions.flatten(0, 1)
self.ray_directions = self.ray_directions.flatten(0, 1)
def __len__(self):
return self.ray_positions.size(0)
def __getitem__(self, idx):
if self.load_images:
if self.ray_as_item:
return idx, self.view_pixels[idx], self.ray_positions[idx], self.ray_directions[idx]
return idx, self.view_images[idx], self.ray_positions[idx], self.ray_directions[idx]
return idx, self.ray_positions[idx], self.ray_directions[idx]
from typing import List from typing import List, Tuple
from math import pi
import torch import torch
import torch.nn as nn import torch.nn as nn
from .pytorch_prototyping.pytorch_prototyping import * from .pytorch_prototyping.pytorch_prototyping import *
...@@ -6,40 +7,176 @@ from .my import util ...@@ -6,40 +7,176 @@ from .my import util
from .my import device from .my import device
def CartesianToSpherical(cart: torch.Tensor) -> torch.Tensor:
"""
Convert coordinates from Cartesian to Spherical
:param cart: ... x 3, coordinates in Cartesian
:return: ... x 3, coordinates in Spherical (r, theta, phi)
"""
rho = torch.norm(cart, p=2, dim=-1)
theta = torch.atan2(cart[..., 2], cart[..., 0])
theta = theta + (theta < 0).type_as(theta) * (2 * pi)
phi = torch.acos(cart[..., 1] / rho)
return torch.stack([rho, theta, phi], dim=-1)
def SphericalToCartesian(spher: torch.Tensor) -> torch.Tensor:
"""
Convert coordinates from Spherical to Cartesian
:param spher: ... x 3, coordinates in Spherical
:return: ... x 3, coordinates in Cartesian (r, theta, phi)
"""
rho = spher[..., 0]
sin_theta_phi = torch.sin(spher[..., 1:3])
cos_theta_phi = torch.cos(spher[..., 1:3])
x = rho * cos_theta_phi[..., 0] * sin_theta_phi[..., 1]
y = rho * cos_theta_phi[..., 1]
z = rho * sin_theta_phi[..., 0] * sin_theta_phi[..., 1]
return torch.stack([x, y, z], dim=-1)
def RaySphereIntersect(p: torch.Tensor, v: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
"""
Calculate intersections of each rays and each spheres
:param p: B x 3, positions of rays
:param v: B x 3, directions of rays
:param r: B'(1D), radius of spheres
:return: B x B' x 3, points of intersection
"""
# p, v: Expand to B x 1 x 3
p = p.unsqueeze(1)
v = v.unsqueeze(1)
# pp, vv, pv: B x 1
pp = (p * p).sum(dim=2)
vv = (v * v).sum(dim=2)
pv = (p * v).sum(dim=2)
# k: Expand to B x B' x 1
k = (((pv * pv - vv * (pp - r * r)).sqrt() - pv) / vv).unsqueeze(2)
return p + k * v
def RayToSpherical(p: torch.Tensor, v: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
"""
Calculate intersections of each rays and each spheres
:param p: B x 3, positions of rays
:param v: B x 3, directions of rays
:param r: B' x 1, radius of spheres
:return: B x B' x 3, spherical coordinates
"""
p_on_spheres = RaySphereIntersect(p, v, r)
return CartesianToSpherical(p_on_spheres)
class FcNet(nn.Module): class FcNet(nn.Module):
def __init__(self, in_chns, out_chns, nf, n_layers): def __init__(self, in_chns: int, out_chns: int, nf: int, n_layers: int):
super().__init__() super().__init__()
self.layers = list() self.layers = list()
self.layers.append(nn.Linear(in_chns, nf)) self.layers += [
self.layers.append(nn.LeakyReLU()) nn.Linear(in_chns, nf),
#nn.LayerNorm([nf]),
nn.ReLU()
]
for _ in range(1, n_layers): for _ in range(1, n_layers):
self.layers.append(nn.Linear(nf, nf)) self.layers += [
self.layers.append(nn.LeakyReLU()) nn.Linear(nf, nf),
#nn.LayerNorm([nf]),
nn.ReLU()
]
self.layers.append(nn.Linear(nf, out_chns)) self.layers.append(nn.Linear(nf, out_chns))
self.net = nn.Sequential(*self.layers) self.net = nn.Sequential(*self.layers)
self.net.apply(self.init_weights)
def forward(self, x): def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.net(x) return self.net(x)
def init_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.xavier_normal_(m.weight)
nn.init.constant_(m.bias, 0.0)
class Rendering(nn.Module): class Rendering(nn.Module):
def __init__(self, n_sphere_layers): def __init__(self, sphere_layers: List[float]):
"""
Initialize a Rendering module
:param sphere_layers: L x 1, radius of sphere layers
"""
super().__init__() super().__init__()
self.n_sl = n_sphere_layers self.sphere_layers = torch.tensor(
sphere_layers, device=device.GetDevice())
def forward(self, net, pos, dir): def forward(self, net: FcNet, p: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
""" """
[summary] [summary]
:param pos: B x 3, position of a ray :param net: the full-connected net
:param dir: B x 3, direction of a ray :param p: B x 3, positions of rays
:param v: B x 3, directions of rays
:return B x 1/3, view images by blended layers
""" """
L = self.sphere_layers.size()[0]
sp = RayToSpherical(p, v, self.sphere_layers) # B x L x 3
sp[..., 0] = 1 / sp[..., 0] # Radius to diopter
color_alpha: torch.Tensor = net(
sp.flatten(0, 1)).view(p.size()[0], L, -1)
if (color_alpha.size(-1) == 2): # Grayscale
c = color_alpha[..., 0:1]
a = color_alpha[..., 1:2]
else: # RGB
c = color_alpha[..., 0:3]
a = color_alpha[..., 3:4]
blended = c[:, 0, :] * a[:, 0, :]
for l in range(1, L):
blended = blended * (1 - a[:, l, :]) + c[:, l, :] * a[:, l, :]
return blended
class MslNet(nn.Module): class MslNet(nn.Module):
def __init__(self): def __init__(self, cam_params, sphere_layers: List[float], out_res: Tuple[int, int], gray=False):
"""
Initialize a multi-sphere-layer net
:param cam_params: intrinsic parameters of camera
:param sphere_layers: L x 1, radius of sphere layers
:param out_res: resolution of output view image
"""
super().__init__() super().__init__()
self.cam_params = cam_params
self.out_res = out_res
self.v_local = util.GetLocalViewRays(self.cam_params, out_res, flatten=True) \
.to(device.GetDevice()) # N x 3
#self.net = FCBlock(hidden_ch=64,
# num_hidden_layers=4,
# in_features=3,
# out_features=2 if gray else 4,
# outermost_linear=True)
self.net = FcNet(in_chns=3, out_chns=2 if gray else 4, nf=256, n_layers=8)
self.rendering = Rendering(sphere_layers)
def forward(self, view_centers: torch.Tensor, view_rots: torch.Tensor) -> torch.Tensor:
"""
T_view -> image
:param view_centers: B x 3, centers of views
:param view_rots: B x 3 x 3, rotation matrices of views
:return: B x 1/3 x H_out x W_out, inferred images of views
"""
# Transpose matrix so we can perform vec x mat
view_rots_t = view_rots.permute(0, 2, 1)
def forward(self, x): # p and v are B x N x 3 tensor
p = view_centers.unsqueeze(1).expand(-1, self.v_local.size(0), -1)
v = torch.matmul(self.v_local, view_rots_t)
c: torch.Tensor = self.rendering(
self.net, p.flatten(0, 1), v.flatten(0, 1)) # (BN) x 3
# unflatten
return c.view(view_centers.size(0), self.out_res[0],
self.out_res[1], -1).permute(0, 3, 1, 2)
from typing import List, NoReturn
import torch
from torchvision.transforms.functional import convert_image_dtype
from torchvision.io.image import read_image
from torchvision.utils import save_image
def ReadImages(*args, paths: List[str] = None, dtype=torch.float) -> torch.Tensor:
raise NotImplementedError('The method has bug. Use util.ReadImageTensor instead')
if not paths:
paths = args
images = torch.stack([read_image(path) for path in paths], dim=0)
return convert_image_dtype(images, dtype)
def SaveImages(input, *args, paths: List[str] = None) -> NoReturn:
raise NotImplementedError('The method has bug. Use util.WriteImageTensor instead')
if not paths:
paths = args
if input.size(0) != len(paths):
raise ValueError('batch size of input images is not same as length of paths')
for i, path in enumerate(range(paths)):
save_image(input[i], path)
\ No newline at end of file
from typing import List
import torch
import torch.nn as nn
import numpy as np
class FcLayer(nn.Module):
def __init__(self, in_chns: int, out_chns: int, activate: nn.Module = None,
skip_chns: int = 0):
super().__init__()
self.net = nn.Sequential(
nn.Linear(in_chns + skip_chns, out_chns),
activate
) if activate else nn.Linear(in_chns + skip_chns, out_chns)
self.skip = skip_chns != 0
def forward(self, x: torch.Tensor, x0: torch.Tensor) -> torch.Tensor:
return self.net(torch.cat([x0, x], dim=-1) if self.skip else x)
class FcNet(nn.Module):
def __init__(self, *, in_chns: int, out_chns: int,
nf: int, n_layers: int, skips: List[int] = []):
"""
Initialize a full-connection net
:kwarg in_chns: channels of input
:kwarg out_chns: channels of output
:kwarg nf: number of features in each hidden layer
:kwarg n_layers: number of layers
:kwarg skips: create skip connections from input to layers in this list
"""
super().__init__()
self.layers = list()
self.layers += [FcLayer(in_chns, nf, nn.ReLU())]
self.layers += [
FcLayer(nf, nf, nn.ReLU(),
skip_chns=in_chns if i in skips else 0)
for i in range(1, n_layers)
]
self.layers += [FcLayer(nf, out_chns)]
for i, layer in enumerate(self.layers):
self.add_module('layer%d' % i, layer)
self.apply(self.init_weights)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x0 = x
for layer in self.layers:
x = layer(x, x0)
return x
def init_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.xavier_normal_(m.weight)
nn.init.constant_(m.bias, 0.0)
class InputEncoder(nn.Module):
def Get(multires, input_dims):
embed_kwargs = {
'include_input': True,
'input_dims': input_dims,
'max_freq_log2': multires - 1,
'num_freqs': multires,
'log_sampling': True,
'periodic_fns': [torch.sin, torch.cos],
}
return InputEncoder(**embed_kwargs)
def __init__(self, **kwargs):
super().__init__()
self._CreateFunc(**kwargs)
def forward(self, input: torch.Tensor) -> torch.Tensor:
"""
Encode the given input to R^D space
:param input ```Tensor(B x C)```: input
:return ```Tensor(B x D): encoded
:rtype: torch.Tensor
"""
return torch.cat([fn(input) for fn in self.embed_fns], dim=-1)
def _CreateFunc(self, **kwargs):
embed_fns = []
d = kwargs['input_dims']
out_dim = 0
if kwargs['include_input'] or kwargs['num_freqs'] == 0:
embed_fns.append(lambda x: x)
out_dim += d
if kwargs['num_freqs'] != 0:
max_freq = kwargs['max_freq_log2']
N_freqs = kwargs['num_freqs']
if kwargs['log_sampling']:
freq_bands = 2. ** np.linspace(0., max_freq, N_freqs)
else:
freq_bands = np.linspace(2. ** 0., 2. ** max_freq, N_freqs)
for freq in freq_bands:
for p_fn in kwargs['periodic_fns']:
embed_fns.append(lambda x, p_fn=p_fn,
freq=freq: p_fn(x * freq))
out_dim += d
self.embed_fns = embed_fns
self.out_dim = out_dim
...@@ -6,6 +6,7 @@ class SimplePerf(object): ...@@ -6,6 +6,7 @@ class SimplePerf(object):
def __init__(self, enable, start = False) -> None: def __init__(self, enable, start = False) -> None:
super().__init__() super().__init__()
self.enable = enable self.enable = enable
self.start_event = None
if start: if start:
self.Start() self.Start()
...@@ -23,6 +24,6 @@ class SimplePerf(object): ...@@ -23,6 +24,6 @@ class SimplePerf(object):
return return
self.end_event.record() self.end_event.record()
torch.cuda.synchronize() torch.cuda.synchronize()
print(name, ': ', self.start_event.elapsed_time(self.end_event)) print('%s: %.1fms' % (name, self.start_event.elapsed_time(self.end_event)))
if not end: if not end:
self.start_event.record() self.start_event.record()
\ No newline at end of file
from typing import Tuple from typing import List, Tuple
from math import pi
import numpy as np import numpy as np
import torch import torch
import torchvision import torchvision
...@@ -150,3 +151,61 @@ def MeshGrid(size: Tuple[int, int], normalize: bool = False, swap_dim: bool = Fa ...@@ -150,3 +151,61 @@ def MeshGrid(size: Tuple[int, int], normalize: bool = False, swap_dim: bool = Fa
def CreateDirIfNeed(path): def CreateDirIfNeed(path):
if not os.path.exists(path): if not os.path.exists(path):
os.makedirs(path) os.makedirs(path)
def GetLocalViewRays(cam_params, res: Tuple[int, int], flatten=False) -> torch.Tensor:
coords = MeshGrid(res)
c = torch.tensor([cam_params['cx'], cam_params['cy']])
f = torch.tensor([cam_params['fx'], cam_params['fy']])
rays = torch.cat([
(coords - c) / f,
torch.ones(res[0], res[1], 1, )
], dim=2)
if flatten:
rays = rays.flatten(0, 1)
return rays
def CartesianToSpherical(cart: torch.Tensor) -> torch.Tensor:
"""
Convert coordinates from Cartesian to Spherical
:param cart: ... x 3, coordinates in Cartesian
:return: ... x 3, coordinates in Spherical (r, theta, phi)
"""
rho = torch.norm(cart, p=2, dim=-1)
theta = torch.atan2(cart[..., 2], cart[..., 0])
theta = theta + (theta < 0).type_as(theta) * (2 * pi)
phi = torch.acos(cart[..., 1] / rho)
return torch.stack([rho, theta, phi], dim=-1)
def SphericalToCartesian(spher: torch.Tensor) -> torch.Tensor:
"""
Convert coordinates from Spherical to Cartesian
:param spher: ... x 3, coordinates in Spherical
:return: ... x 3, coordinates in Cartesian (r, theta, phi)
"""
rho = spher[..., 0]
sin_theta_phi = torch.sin(spher[..., 1:3])
cos_theta_phi = torch.cos(spher[..., 1:3])
x = rho * cos_theta_phi[..., 0] * sin_theta_phi[..., 1]
y = rho * cos_theta_phi[..., 1]
z = rho * sin_theta_phi[..., 0] * sin_theta_phi[..., 1]
return torch.stack([x, y, z], dim=-1)
def GetDepthLayers(depth_range: Tuple[float, float], n_layers: int) -> List[float]:
"""
Get [n_layers] foreground layers whose diopters are distributed uniformly
in [depth_range] plus a background layer
:param depth_range: depth range of foreground layers
:param n_layers: number of foreground layers
:return: list of [n_layers+1] depths
"""
diopter_range = (1 / depth_range[1], 1 / depth_range[0])
depths = [1e5] # Background layer
depths += list(1.0 / np.linspace(diopter_range[0], diopter_range[1], n_layers))
return depths
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -144,5 +144,5 @@ def test(net_file: str): ...@@ -144,5 +144,5 @@ def test(net_file: str):
if __name__ == "__main__": if __name__ == "__main__":
train() #train()
#test(RUN_DIR + '/model-epoch_1000.pth') test(RUN_DIR + '/model-epoch_1000.pth')
import sys
sys.path.append('/e/dengnc')
__package__ = "deeplightfield"
import argparse
import torch
import torch.optim
import torchvision
from typing import List, Tuple
from tensorboardX import SummaryWriter
from torch import nn
from .my import netio
from .my import util
from .my import device
from .my.simple_perf import SimplePerf
from .loss.loss import PerceptionReconstructionLoss
from .data.spherical_view_syn import SphericalViewSynDataset
from .msl_net import MslNet
from .spher_net import SpherNet
parser = argparse.ArgumentParser()
parser.add_argument('--device', type=int, default=3,
help='Which CUDA device to use.')
opt = parser.parse_args()
# Select device
torch.cuda.set_device(opt.device)
print("Set CUDA:%d as current device." % torch.cuda.current_device())
# Toggles
GRAY = False
ROT_ONLY = False
TRAIN_MODE = True
EVAL_TIME_PERFORMANCE = False
RAY_AS_ITEM = True
# ========
#GRAY = True
ROT_ONLY = True
TRAIN_MODE = False
#EVAL_TIME_PERFORMANCE = True
#RAY_AS_ITEM = False
# Net parameters
DEPTH_RANGE = (1, 10)
N_DEPTH_LAYERS = 10
N_ENCODE_DIM = 10
FC_PARAMS = {
'nf': 128,
'n_layers': 6,
'skips': [4]
}
# Train
BATCH_SIZE = 2048 if RAY_AS_ITEM else 4
EPOCH_RANGE = range(0, 500)
SAVE_INTERVAL = 20
# Paths
DATA_DIR = sys.path[0] + '/data/sp_view_syn_2020.12.26_rotonly/'
RUN_ID = '%s_ray_b%d_encode%d_fc%dx%d%s' % ('gray' if GRAY else 'rgb',
BATCH_SIZE,
N_ENCODE_DIM,
FC_PARAMS['nf'],
FC_PARAMS['n_layers'],
'_skip_%d' % FC_PARAMS['skips'][0] if len(FC_PARAMS['skips']) > 0 else '')
TRAIN_DATA_DESC_FILE = DATA_DIR + 'train.json'
RUN_DIR = DATA_DIR + RUN_ID + '/'
OUTPUT_DIR = RUN_DIR + 'output/'
LOG_DIR = RUN_DIR + 'log/'
# Test
TEST_NET_NAME = 'model-epoch_100'
TEST_BATCH_SIZE = 5
def train():
# 1. Initialize data loader
print("Load dataset: " + TRAIN_DATA_DESC_FILE)
train_dataset = SphericalViewSynDataset(
TRAIN_DATA_DESC_FILE, gray=GRAY, ray_as_item=RAY_AS_ITEM)
train_data_loader = torch.utils.data.DataLoader(
dataset=train_dataset,
batch_size=BATCH_SIZE,
pin_memory=True,
shuffle=True,
drop_last=False)
print('Data loaded. %d iters per epoch.' % len(train_data_loader))
# 2. Initialize components
if ROT_ONLY:
model = SpherNet(cam_params=train_dataset.cam_params,
fc_params=FC_PARAMS,
out_res=train_dataset.view_res,
gray=GRAY,
encode_to_dim=N_ENCODE_DIM).to(device.GetDevice())
else:
model = MslNet(cam_params=train_dataset.cam_params,
sphere_layers=util.GetDepthLayers(
DEPTH_RANGE, N_DEPTH_LAYERS),
out_res=train_dataset.view_res,
gray=GRAY).to(device.GetDevice())
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
loss = nn.MSELoss()
if EPOCH_RANGE.start > 0:
netio.LoadNet('%smodel-epoch_%d.pth' % (RUN_DIR, EPOCH_RANGE.start),
model, solver=optimizer)
# 3. Train
model.train()
epoch = None
iters = EPOCH_RANGE.start * len(train_data_loader)
util.CreateDirIfNeed(RUN_DIR)
util.CreateDirIfNeed(LOG_DIR)
perf = SimplePerf(EVAL_TIME_PERFORMANCE, start=True)
perf_epoch = SimplePerf(True, start=True)
writer = SummaryWriter(LOG_DIR)
print("Begin training...")
for epoch in EPOCH_RANGE:
for _, gt, ray_positions, ray_directions in train_data_loader:
gt = gt.to(device.GetDevice())
ray_positions = ray_positions.to(device.GetDevice())
ray_directions = ray_directions.to(device.GetDevice())
perf.Checkpoint("Load")
out = model(ray_positions, ray_directions)
perf.Checkpoint("Forward")
optimizer.zero_grad()
loss_value = loss(out, gt)
perf.Checkpoint("Compute loss")
loss_value.backward()
perf.Checkpoint("Backward")
optimizer.step()
perf.Checkpoint("Update")
print("Epoch: ", epoch, ", Iter: ", iters,
", Loss: ", loss_value.item())
# Write tensorboard logs.
writer.add_scalar("loss", loss_value, iters)
if not RAY_AS_ITEM and iters % 100 == 0:
output_vs_gt = torch.cat([out, gt], dim=0)
writer.add_image("Output_vs_gt", torchvision.utils.make_grid(
output_vs_gt, scale_each=True, normalize=False)
.cpu().detach().numpy(), iters)
iters += 1
perf_epoch.Checkpoint("Epoch")
# Save checkpoint
if ((epoch + 1) % SAVE_INTERVAL == 0):
netio.SaveNet('%smodel-epoch_%d.pth' % (RUN_DIR, epoch + 1), model,
solver=optimizer)
print("Train finished")
def test(net_file: str):
# 1. Load train dataset
print("Load dataset: " + TRAIN_DATA_DESC_FILE)
train_dataset = SphericalViewSynDataset(TRAIN_DATA_DESC_FILE,
load_images=True, gray=GRAY)
train_data_loader = torch.utils.data.DataLoader(
dataset=train_dataset,
batch_size=TEST_BATCH_SIZE,
pin_memory=True,
shuffle=False,
drop_last=False)
# 2. Load trained model
if ROT_ONLY:
model = SpherNet(cam_params=train_dataset.cam_params,
fc_params=FC_PARAMS,
out_res=train_dataset.view_res,
gray=GRAY,
encode_to_dim=N_ENCODE_DIM).to(device.GetDevice())
else:
model = MslNet(cam_params=train_dataset.cam_params,
sphere_layers=_GetSphereLayers(
DEPTH_RANGE, N_DEPTH_LAYERS),
out_res=train_dataset.view_res,
gray=GRAY).to(device.GetDevice())
netio.LoadNet(net_file, model)
# 3. Test on train dataset
print("Begin test on train dataset, batch size is %d" % TEST_BATCH_SIZE)
util.CreateDirIfNeed(OUTPUT_DIR)
util.CreateDirIfNeed(OUTPUT_DIR + TEST_NET_NAME)
perf = SimplePerf(True, start=True)
i = 0
for view_idxs, view_images, ray_positions, ray_directions in train_data_loader:
ray_positions = ray_positions.to(device.GetDevice())
ray_directions = ray_directions.to(device.GetDevice())
perf.Checkpoint("%d - Load" % i)
out_view_images = model(ray_positions, ray_directions)
perf.Checkpoint("%d - Infer" % i)
util.WriteImageTensor(
view_images,
['%s%s/gt_view_%04d.png' % (OUTPUT_DIR, TEST_NET_NAME, i) for i in view_idxs])
util.WriteImageTensor(
out_view_images,
['%s%s/out_view_%04d.png' % (OUTPUT_DIR, TEST_NET_NAME, i) for i in view_idxs])
perf.Checkpoint("%d - Write" % i)
i += 1
if __name__ == "__main__":
if TRAIN_MODE:
train()
else:
test(RUN_DIR + TEST_NET_NAME + '.pth')
from typing import List, Tuple
from math import pi
import torch
import torch.nn as nn
from .pytorch_prototyping.pytorch_prototyping import *
from .my import net_modules
from .my import util
from .my import device
def RaySphereIntersect(p: torch.Tensor, v: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
"""
Calculate intersections of each rays and each spheres
:param p: B x 3, positions of rays
:param v: B x 3, directions of rays
:param r: B'(1D), radius of spheres
:return: B x B' x 3, points of intersection
"""
# p, v: Expand to B x 1 x 3
p = p.unsqueeze(1)
v = v.unsqueeze(1)
# pp, vv, pv: B x 1
pp = (p * p).sum(dim=2)
vv = (v * v).sum(dim=2)
pv = (p * v).sum(dim=2)
# k: Expand to B x B' x 1
k = (((pv * pv - vv * (pp - r * r)).sqrt() - pv) / vv).unsqueeze(2)
return p + k * v
def RayToSpherical(p: torch.Tensor, v: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
"""
Calculate intersections of each rays and each spheres
:param p: B x 3, positions of rays
:param v: B x 3, directions of rays
:param r: B' x 1, radius of spheres
:return: B x B' x 3, spherical coordinates
"""
p_on_spheres = RaySphereIntersect(p, v, r)
return util.CartesianToSpherical(p_on_spheres)
class SpherNet(nn.Module):
def __init__(self, cam_params, # spher_min: Tuple[float, float], spher_max: Tuple[float, float],
fc_params,
out_res: Tuple[int, int] = None,
gray: bool = False,
encode_to_dim: int = 0):
"""
Initialize a sphere net
:param cam_params: intrinsic parameters of camera
:param fc_params: parameters of full-connection network
:param out_res: resolution of output view image
:param gray: is grayscale mode
:param encode_to_dim: encode input to number of dimensions
"""
super().__init__()
self.cam_params = cam_params
self.in_chns = 2
self.out_res = out_res
#self.spher_min = torch.tensor(spher_min, device=device.GetDevice()).view(1, 2)
#self.spher_max = torch.tensor(spher_max, device=device.GetDevice()).view(1, 2)
#self.spher_range = self.spher_max - self.spher_min
self.input_encoder = net_modules.InputEncoder.Get(encode_to_dim, self.in_chns)
fc_params['in_chns'] = self.input_encoder.out_dim
fc_params['out_chns'] = 1 if gray else 3
self.net = net_modules.FcNet(**fc_params)
def forward(self, _, ray_directions: torch.Tensor) -> torch.Tensor:
"""
rays -> colors
:param ray_directions ```Tensor(B, M, 3)|Tensor(B, 3)```: ray directions
:return: Tensor(B, 1|3, H, W)|Tensor(B, 1|3), inferred images/pixels
"""
v = ray_directions.view(-1, 3) # (*, 3)
spher = util.CartesianToSpherical(v)[..., 1:3] # (*, 2)
# (spher - self.spher_min) / self.spher_range * 2 - 0.5
spher_normed = spher
c: torch.Tensor = self.net(self.input_encoder(spher_normed))
# Unflatten to (B, 1|3, H, W) if take view as item
return c.view(ray_directions.size(0), self.out_res[0], self.out_res[1],
-1).permute(0, 3, 1, 2) if len(ray_directions.size()) == 3 else c
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment