Commit b7fae973 authored by BobYeah's avatar BobYeah
Browse files

implemented basic msl net

parent c570c3b1
[Net Parameters]
aaa = 1
bbb = 'abc'
ccc = (2,3)
def update_config(config):
# Dataset settings
config.GRAY = False
# Net parameters
config.NET_TYPE = 'msl'
config.N_ENCODE_DIM = 10
config.FC_PARAMS = {
'nf': 256,
'n_layers': 8,
'skips': [4]
}
config.SAMPLE_PARAMS = {
'depth_range': (1, 50),
'n_samples': 32,
'perturb_sample': True
}
\ No newline at end of file
def update_config(config):
# Dataset settings
config.GRAY = False
# Net parameters
config.NET_TYPE = 'msl'
config.N_ENCODE_DIM = 10
config.FC_PARAMS = {
'nf': 64,
'n_layers': 12,
'skips': []
}
config.SAMPLE_PARAMS = {
'depth_range': (1, 20),
'n_samples': 16,
'perturb_sample': True
}
config.LOSS = 'mse_grad'
\ No newline at end of file
def update_config(config):
# Dataset settings
config.GRAY = True
# Net parameters
config.NET_TYPE = 'msl'
config.N_ENCODE_DIM = 10
config.FC_PARAMS = {
'nf': 128,
'n_layers': 8,
'skips': [4]
}
config.SAMPLE_PARAMS = {
'depth_range': (1, 50),
'n_samples': 8,
'perturb_sample': True
}
\ No newline at end of file
def update_config(config):
# Dataset settings
config.GRAY = True
# Net parameters
config.NET_TYPE = 'msl'
config.N_ENCODE_DIM = 10
config.FC_PARAMS = {
'nf': 64,
'n_layers': 12,
'skips': []
}
config.SAMPLE_PARAMS = {
'depth_range': (1, 20),
'n_samples': 16,
'perturb_sample': True
}
\ No newline at end of file
def update_config(config):
# Dataset settings
config.GRAY = False
# Net parameters
config.NET_TYPE = 'msl'
config.N_ENCODE_DIM = 10
config.FC_PARAMS = {
'nf': 64,
'n_layers': 12,
'skips': []
}
config.SAMPLE_PARAMS = {
'depth_range': (1, 20),
'n_samples': 16,
'perturb_sample': True
}
\ No newline at end of file
def update_config(config):
# Dataset settings
config.GRAY = True
# Net parameters
config.NET_TYPE = 'msl'
config.N_ENCODE_DIM = 10
config.FC_PARAMS = {
'nf': 256,
'n_layers': 8,
'skips': [4]
}
config.SAMPLE_PARAMS = {
'depth_range': (1, 50),
'n_samples': 32,
'perturb_sample': True
}
\ No newline at end of file
def update_config(config):
# Dataset settings
config.GRAY = True
# Net parameters
config.NET_TYPE = 'msl'
config.N_ENCODE_DIM = 10
config.FC_PARAMS = {
'nf': 256,
'n_layers': 8,
'skips': [4]
}
config.SAMPLE_PARAMS = {
'depth_range': (1, 20),
'n_samples': 16,
'perturb_sample': True
}
\ No newline at end of file
def update_config(config):
# Dataset settings
config.GRAY = True
# Net parameters
config.NET_TYPE = 'msl'
config.N_ENCODE_DIM = 10
config.FC_PARAMS = {
'nf': 64,
'n_layers': 8,
'skips': [4]
}
config.SAMPLE_PARAMS = {
'depth_range': (1, 50),
'n_samples': 4,
'perturb_sample': True
}
\ No newline at end of file
import math
from typing import Tuple
import torch
import torchvision.transforms.functional as trans_f
import json
from ..my import util
from ..my import device
def _convert_camera_params(input_camera_params, view_res):
"""
Check and convert camera parameters in config file to pixel-space
:param cam_params: { ["fx", "fy" | "fov"], "cx", "cy", ["normalized"] },
the parameters of camera
:return: camera parameters
"""
input_is_normalized = bool(input_camera_params.get('normalized'))
camera_params = {}
if 'fov' in input_camera_params:
camera_params['fx'] = camera_params['fy'] = \
(1 if input_is_normalized else view_res[0]) / \
util.Fov2Length(input_camera_params['fov'])
camera_params['fy'] *= -1
else:
camera_params['fx'] = input_camera_params['fx']
camera_params['fy'] = input_camera_params['fy']
camera_params['cx'] = input_camera_params['cx']
camera_params['cy'] = input_camera_params['cy']
if input_is_normalized:
camera_params['fx'] *= view_res[1]
camera_params['fy'] *= view_res[0]
camera_params['cx'] *= view_res[1]
camera_params['cy'] *= view_res[0]
return camera_params
class SphericalViewSynDataset(torch.utils.data.dataset.Dataset):
......@@ -27,7 +58,7 @@ class SphericalViewSynDataset(torch.utils.data.dataset.Dataset):
- view_file_pattern: string, the path pattern of view images
- view_res: { "x", "y" }, the resolution of view images
- cam_params: { "fx", "fy", "cx", "cy" }, the focal and center of camera (in normalized image space)
- cam_params: { ["fx", "fy" | "fov"], "cx", "cy", ["normalized"] }, the parameters of camera
- view_centers: [ [ x, y, z ], ... ], centers of views
- view_rots: [ [ m00, m01, ..., m22 ], ... ], rotation matrices of views
......@@ -50,7 +81,8 @@ class SphericalViewSynDataset(torch.utils.data.dataset.Dataset):
data_desc['view_file_pattern']
self.view_res = (data_desc['view_res']['y'],
data_desc['view_res']['x'])
self.cam_params = data_desc['cam_params']
self.cam_params = _convert_camera_params(
data_desc['cam_params'], self.view_res)
self.view_centers = torch.tensor(data_desc['view_centers']) # (N, 3)
self.view_rots = torch.tensor(data_desc['view_rots']) \
.view(-1, 3, 3) # (N, 3, 3)
......@@ -70,24 +102,181 @@ class SphericalViewSynDataset(torch.utils.data.dataset.Dataset):
# Transpose matrix so we can perform vec x mat
view_rots_t = self.view_rots.permute(0, 2, 1)
# ray_positions & ray_directions are both (N, M, 3)
self.ray_positions = self.view_centers.unsqueeze(1) \
# rays_o & rays_d are both (N, M, 3)
self.rays_o = self.view_centers.unsqueeze(1) \
.expand(-1, local_view_rays.size(0), -1)
self.ray_directions = torch.matmul(local_view_rays, view_rots_t)
self.rays_d = torch.matmul(local_view_rays, view_rots_t)
# Flatten rays if ray_as_item = True
if ray_as_item:
self.view_pixels = self.view_images.permute(0, 2, 3, 1).flatten(
0, 2) if self.view_images != None else None
self.ray_positions = self.ray_positions.flatten(0, 1)
self.ray_directions = self.ray_directions.flatten(0, 1)
self.rays_o = self.rays_o.flatten(0, 1)
self.rays_d = self.rays_d.flatten(0, 1)
def __len__(self):
return self.ray_positions.size(0)
return self.rays_o.size(0)
def __getitem__(self, idx):
if self.load_images:
if self.ray_as_item:
return idx, self.view_pixels[idx], self.ray_positions[idx], self.ray_directions[idx]
return idx, self.view_images[idx], self.ray_positions[idx], self.ray_directions[idx]
return idx, False, self.ray_positions[idx], self.ray_directions[idx]
return idx, self.view_pixels[idx], self.rays_o[idx], self.rays_d[idx]
return idx, self.view_images[idx], self.rays_o[idx], self.rays_d[idx]
return idx, False, self.rays_o[idx], self.rays_d[idx]
class FastSphericalViewSynDataset(torch.utils.data.dataset.Dataset):
"""
Data loader for spherical view synthesis task
Attributes
--------
data_dir ```str```: the directory of dataset\n
view_file_pattern ```str```: the filename pattern of view images\n
cam_params ```object```: camera intrinsic parameters\n
view_centers ```Tensor(N, 3)```: centers of views\n
view_rots ```Tensor(N, 3, 3)```: rotation matrices of views\n
view_images ```Tensor(N, 3, H, W)```: images of views\n
"""
def __init__(self, dataset_desc_path: str, load_images: bool = True, gray: bool = False):
"""
Initialize data loader for spherical view synthesis task
The dataset description file is a JSON file with following fields:
- view_file_pattern: string, the path pattern of view images
- view_res: { "x", "y" }, the resolution of view images
- cam_params: { "fx", "fy", "cx", "cy" }, the focal and center of camera (in normalized image space)
- view_centers: [ [ x, y, z ], ... ], centers of views
- view_rots: [ [ m00, m01, ..., m22 ], ... ], rotation matrices of views
:param dataset_desc_path ```str```: path to the data description file
:param load_images ```bool```: whether load view images and return in __getitem__()
:param gray ```bool```: whether convert view images to grayscale
"""
super().__init__()
self.data_dir = dataset_desc_path.rsplit('/', 1)[0] + '/'
self.load_images = load_images
# Load dataset description file
with open(dataset_desc_path, 'r', encoding='utf-8') as file:
data_desc = json.loads(file.read())
if data_desc['view_file_pattern'] == '':
self.load_images = False
else:
self.view_file_pattern: str = self.data_dir + \
data_desc['view_file_pattern']
self.view_res = (data_desc['view_res']['y'],
data_desc['view_res']['x'])
self.cam_params = _convert_camera_params(
data_desc['cam_params'], self.view_res)
self.view_centers = torch.tensor(
data_desc['view_centers'], device=device.GetDevice()) # (N, 3)
self.view_rots = torch.tensor(
data_desc['view_rots'], device=device.GetDevice()).view(-1, 3, 3) # (N, 3, 3)
self.n_views = self.view_centers.size(0)
self.n_pixels = self.n_views * self.view_res[0] * self.view_res[1]
# Load view images
if self.load_images:
self.view_images = util.ReadImageTensor(
[self.view_file_pattern % i
for i in range(self.view_centers.size(0))]
).to(device.GetDevice())
if gray:
self.view_images = trans_f.rgb_to_grayscale(self.view_images)
else:
self.view_images = None
local_view_rays = util.GetLocalViewRays(self.cam_params, self.view_res, True) \
.to(device.GetDevice()) # (HW, 3)
# Transpose matrix so we can perform vec x mat
view_rots_t = self.view_rots.permute(0, 2, 1)
# rays_o & rays_d are both (N, H, W, 3)
self.rays_o = self.view_centers[:, None, None, :] \
.expand(-1, self.view_res[0], self.view_res[1], -1)
self.rays_d = torch.matmul(local_view_rays, view_rots_t) \
.view_as(self.rays_o)
self.patched_images = self.view_images # (N, 1|3, H, W)
self.patched_rays_o = self.rays_o # (N, H, W, 3)
self.patched_rays_d = self.rays_d # (N, H, W, 3)
def set_patch_size(self, patch_size: Tuple[int, int], offset: Tuple[int, int] = (0, 0)):
"""
Set the size of patch and (optional) offset. If patch_size = (1, 1)
:param patch_size:
:param offset:
"""
patches = ((self.view_res[0] - offset[0]) // patch_size[0],
(self.view_res[1] - offset[1]) // patch_size[1])
slices = (..., slice(offset[0], offset[0] + patches[0] * patch_size[0]),
slice(offset[1], offset[1] + patches[1] * patch_size[1]))
if patch_size[0] == 1 and patch_size[1] == 1:
self.patched_images = self.view_images[slices] \
.permute(0, 2, 3, 1).flatten(0, 2) if self.load_images else None
self.patched_rays_o = self.rays_o[slices].flatten(0, 2)
self.patched_rays_d = self.rays_d[slices].flatten(0, 2)
elif patch_size[0] == self.view_res[0] and patch_size[1] == self.view_res[1]:
self.patched_images = self.view_images
self.patched_rays_o = self.rays_o
self.patched_rays_d = self.rays_d
else:
print(self.view_images.size(), self.rays_o.size())
print(self.view_images[slices].size(), self.rays_o[slices].size())
self.patched_images = self.view_images[slices] \
.view(self.n_views, -1, patches[0], patch_size[0], patches[1], patch_size[1]) \
.permute(0, 2, 4, 1, 3, 5).flatten(0, 2) if self.load_images else None
self.patched_rays_o = self.rays_o[slices] \
.view(self.n_views, patches[0], patch_size[0], patches[1], patch_size[1], -1) \
.permute(0, 1, 3, 2, 4, 5).flatten(0, 2)
self.patched_rays_d = self.rays_d[slices] \
.view(self.n_views, patches[0], patch_size[0], patches[1], patch_size[1], -1) \
.permute(0, 1, 3, 2, 4, 5).flatten(0, 2)
def __len__(self):
return self.patched_rays_o.size(0)
def __getitem__(self, idx):
if self.load_images:
return idx, self.patched_images[idx], self.patched_rays_o[idx], \
self.patched_rays_d[idx]
return idx, False, self.patched_rays_o[idx], self.patched_rays_d[idx]
class FastDataLoader(object):
class Iter(object):
def __init__(self, dataset, batch_size, shuffle, drop_last) -> None:
super().__init__()
self.indices = torch.randperm(len(dataset), device=device.GetDevice()) \
if shuffle else torch.arange(len(dataset), device=device.GetDevice())
self.offset = 0
self.batch_size = batch_size
self.dataset = dataset
self.drop_last = drop_last
def __next__(self):
if self.offset + (self.batch_size if self.drop_last else 0) >= len(self.dataset):
raise StopIteration()
indices = self.indices[self.offset:self.offset + self.batch_size]
self.offset += self.batch_size
return self.dataset[indices]
def __init__(self, dataset, batch_size, shuffle, drop_last, **kwargs) -> None:
super().__init__()
self.dataset = dataset
self.batch_size = batch_size
self.shuffle = shuffle
self.drop_last = drop_last
def __iter__(self):
return FastDataLoader.Iter(self.dataset, self.batch_size,
self.shuffle, self.drop_last)
def __len__(self):
return math.floor(len(self.dataset) / self.batch_size) if self.drop_last \
else math.ceil(len(self.dataset) / self.batch_size)
import torch
from .ssim import *
from .perc_loss import *
device=torch.device("cuda:2")
l1loss = torch.nn.L1Loss()
perc_loss = VGGPerceptualLoss().to(device)
##### LOSS #####
def calImageGradients(images):
# x is a 4-D tensor
dx = images[:, :, 1:, :] - images[:, :, :-1, :]
dy = images[:, :, :, 1:] - images[:, :, :, :-1]
return dx, dy
def loss_new(generated, gt):
mse_loss = torch.nn.MSELoss()
rmse_intensity = mse_loss(generated, gt)
psnr_intensity = torch.log10(rmse_intensity)
# print("psnr:",psnr_intensity)
# ssim_intensity = ssim(generated, gt)
labels_dx, labels_dy = calImageGradients(gt)
# print("generated:",generated.shape)
preds_dx, preds_dy = calImageGradients(generated)
rmse_grad_x, rmse_grad_y = mse_loss(
labels_dx, preds_dx), mse_loss(labels_dy, preds_dy)
psnr_grad_x, psnr_grad_y = torch.log10(
rmse_grad_x), torch.log10(rmse_grad_y)
# print("psnr x&y:",psnr_grad_x," ",psnr_grad_y)
p_loss = perc_loss(generated, gt)
# print("-psnr:",-psnr_intensity,",0.5*(psnr_grad_x + psnr_grad_y):",0.5*(psnr_grad_x + psnr_grad_y),",perc_loss:",p_loss)
total_loss = psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y) + p_loss
# total_loss = rmse_intensity + 0.5*(rmse_grad_x + rmse_grad_y) # + p_loss
return total_loss
def loss_without_perc(generated, gt):
mse_loss = torch.nn.MSELoss()
rmse_intensity = mse_loss(generated, gt)
psnr_intensity = torch.log10(rmse_intensity)
# print("psnr:",psnr_intensity)
# ssim_intensity = ssim(generated, gt)
labels_dx, labels_dy = calImageGradients(gt)
# print("generated:",generated.shape)
preds_dx, preds_dy = calImageGradients(generated)
rmse_grad_x, rmse_grad_y = mse_loss(
labels_dx, preds_dx), mse_loss(labels_dy, preds_dy)
psnr_grad_x, psnr_grad_y = torch.log10(
rmse_grad_x), torch.log10(rmse_grad_y)
# print("psnr x&y:",psnr_grad_x," ",psnr_grad_y)
# print("-psnr:",-psnr_intensity,",0.5*(psnr_grad_x + psnr_grad_y):",0.5*(psnr_grad_x + psnr_grad_y),",perc_loss:",p_loss)
total_loss = psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y)
# total_loss = rmse_intensity + 0.5*(rmse_grad_x + rmse_grad_y) # + p_loss
return total_loss
##### LOSS #####
class ReconstructionLoss(torch.nn.Module):
def __init__(self):
super(ReconstructionLoss, self).__init__()
def forward(self, generated, gt):
rmse_intensity = torch.nn.functional.mse_loss(generated, gt)
psnr_intensity = torch.log10(rmse_intensity)
labels_dx, labels_dy = calImageGradients(gt)
preds_dx, preds_dy = calImageGradients(generated)
rmse_grad_x, rmse_grad_y = torch.nn.functional.mse_loss(
labels_dx, preds_dx), torch.nn.functional.mse_loss(labels_dy, preds_dy)
psnr_grad_x, psnr_grad_y = torch.log10(
rmse_grad_x), torch.log10(rmse_grad_y)
total_loss = psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y)
return total_loss
class PerceptionReconstructionLoss(torch.nn.Module):
def __init__(self):
super(PerceptionReconstructionLoss, self).__init__()
def forward(self, generated, gt):
rmse_intensity = torch.nn.functional.mse_loss(generated, gt)
psnr_intensity = torch.log10(rmse_intensity)
labels_dx, labels_dy = calImageGradients(gt)
preds_dx, preds_dy = calImageGradients(generated)
rmse_grad_x = torch.nn.functional.mse_loss(labels_dx, preds_dx)
rmse_grad_y = torch.nn.functional.mse_loss(labels_dy, preds_dy)
psnr_grad_x = torch.log10(rmse_grad_x)
psnr_grad_y = torch.log10(rmse_grad_y)
p_loss = perc_loss(generated, gt)
total_loss = psnr_intensity + 0.5 * (psnr_grad_x + psnr_grad_y) + p_loss
return total_loss
from typing import List, Tuple
from typing import Tuple
import torch
import torch.nn as nn
from .my import net_modules
from .my import util
from .my import device
rand_gen = torch.Generator(device=device.GetDevice())
rand_gen.manual_seed(torch.seed())
def RaySphereIntersect(p: torch.Tensor, v: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
"""
Calculate intersections of each rays and each spheres
:param p: B x 3, positions of rays
:param v: B x 3, directions of rays
:param r: B'(1D), radius of spheres
:return: B x B' x 3, points of intersection
:param p ```Tensor(B, 3)```: positions of rays
:param v ```Tensor(B, 3)```: directions of rays
:param r ```Tensor(N)```: , radius of spheres
:return ```Tensor(B, N, 3)```: points of intersection
:return ```Tensor(B, N)```: depths of intersection along ray
"""
# p, v: Expand to B x 1 x 3
# p, v: Expand to (B, 1, 3)
p = p.unsqueeze(1)
v = v.unsqueeze(1)
# pp, vv, pv: B x 1
# pp, vv, pv: (B, 1)
pp = (p * p).sum(dim=2)
vv = (v * v).sum(dim=2)
pv = (p * v).sum(dim=2)
# k: Expand to B x B' x 1
k = (((pv * pv - vv * (pp - r * r)).sqrt() - pv) / vv).unsqueeze(2)
return p + k * v
def RayToSpherical(p: torch.Tensor, v: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
"""
Calculate intersections of each rays and each spheres
:param p: B x 3, positions of rays
:param v: B x 3, directions of rays
:param r: B' x 1, radius of spheres
:return: B x B' x 3, spherical coordinates
"""
p_on_spheres = RaySphereIntersect(p, v, r)
return util.CartesianToSpherical(p_on_spheres)
depths = (((pv * pv - vv * (pp - r * r)).sqrt() - pv) / vv)
return p + depths[..., None] * v, depths
class Rendering(nn.Module):
def __init__(self):
def __init__(self, *, raw_noise_std: float = 0.0, white_bg: bool = False):
"""
Initialize a Rendering module
"""
super().__init__()
self.raw_noise_std = raw_noise_std
self.white_bg = white_bg
def forward(self, raw, z_vals, ret_extra: bool = False):
"""Transforms model's predictions to semantically meaningful values.
Args:
raw: [num_rays, num_samples along ray, 4]. Prediction from model.
z_vals: [num_rays, num_samples along ray]. Integration time.
Returns:
rgb_map: [num_rays, 3]. Estimated RGB color of a ray.
disp_map: [num_rays]. Disparity map. Inverse of depth map.
acc_map: [num_rays]. Sum of weights along each ray.
weights: [num_rays, num_samples]. Weights assigned to each sampled color.
depth_map: [num_rays]. Estimated distance to object.
"""
# Function for computing density from model prediction. This value is
# strictly between [0, 1].
def raw2alpha(raw, dists, act_fn=torch.relu):
return 1.0 - torch.exp(-act_fn(raw) * dists)
# Compute 'distance' (in time) between each integration time along a ray.
# The 'distance' from the last integration time is infinity.
# dists: (N_rays, N_samples)
dists = z_vals[..., 1:] - z_vals[..., :-1]
dists = util.broadcast_cat(dists, 1e10)
# Extract RGB of each sample position along each ray.
color = torch.sigmoid(raw[..., :-1]) # (N_rays, N_samples, 1|3)
# Add noise to model's predictions for density. Can be used to
# regularize network during training (prevents floater artifacts).
noise = 0.
if self.raw_noise_std > 0.:
noise = torch.normal(0.0, self.raw_noise_std,
raw[..., 3].size(), rand_gen)
# Predict density of each sample along each ray. Higher values imply
# higher likelihood of being absorbed at this point.
alpha = raw2alpha(raw[..., -1] + noise, dists) # (N_rays, N_samples)
# Compute weight for RGB of each sample along each ray. A cumprod() is
# used to express the idea of the ray not having reflected up to this
# sample yet.
one_minus_alpha = util.broadcast_cat(
torch.cumprod(1 - alpha[..., :-1] + 1e-10, dim=-1),
1.0, append=False)
weights = alpha * one_minus_alpha # (N_rays, N_samples)
# (N_rays, 1|3), computed weighted color of each sample along each ray.
color_map = torch.sum(weights[..., None] * color, dim=-2)
# To composite onto a white background, use the accumulated alpha map.
if self.white_bg or ret_extra:
# Sum of weights along each ray. This value is in [0, 1] up to numerical error.
acc_map = torch.sum(weights, -1)
if self.white_bg:
color_map = color_map + (1. - acc_map[..., None])
else:
acc_map = None
if not ret_extra:
return color_map
# Estimated depth map is expected distance.
depth_map = torch.sum(weights * z_vals, dim=-1)
# Disparity map is inverse depth.
disp_map = 1. / torch.max(1e-10, depth_map /
torch.sum(weights, dim=-1))
return color_map, disp_map, acc_map, weights, depth_map
class Sampler(nn.Module):
def __init__(self, *, depth_range: Tuple[float, float], n_samples: int,
perturb_sample: bool, spherical: bool):
"""
Initialize a Sampler module
def forward(self, color_alpha: torch.Tensor) -> torch.Tensor:
:param depth_range: depth range for sampler
:param n_samples: count to sample along ray
:param perturb_sample: perturb the sample depths
"""
super().__init__()
self.r = 1 / torch.linspace(1 / depth_range[0], 1 / depth_range[1],
n_samples, device=device.GetDevice())
self.perturb_sample = perturb_sample
self.spherical = spherical
if perturb_sample:
mids = .5 * (self.r[1:] + self.r[:-1])
self.upper = torch.cat([mids, self.r[-1:]], -1)
self.lower = torch.cat([self.r[:1], mids], -1)
def forward(self, rays_o, rays_d):
"""
Blend layers to get final color
Sample points along rays. return Spherical or Cartesian coordinates,
specified by ```self.shperical```
:param color_alpha ```Tensor(B, L, C)```: RGB or gray with alpha channel
:return ```Tensor(B, C-1)``` blended pixels
:param rays_o ```Tensor(B, 3)```: rays' origin
:param rays_d ```Tensor(B, 3)```: rays' direction
:return ```Tensor(B, N, 3)```: sampled points
:return ```Tensor(B, N)```: corresponding depths along rays
"""
c = color_alpha[..., :-1]
a = color_alpha[..., -1:]
blended = c[:, 0, :] * a[:, 0, :]
for l in range(1, color_alpha.size(1)):
blended = blended * (1 - a[:, l, :]) + c[:, l, :] * a[:, l, :]
return blended
if self.perturb_sample:
# stratified samples in those intervals
t_rand = torch.rand(self.r.size(),
generator=rand_gen,
device=device.GetDevice())
r = self.lower + (self.upper - self.lower) * t_rand
else:
r = self.r
if self.spherical:
pts, depths = RaySphereIntersect(rays_o, rays_d, r)
sphers = util.CartesianToSpherical(pts)
sphers[..., 0] = 1 / sphers[..., 0]
return sphers, depths
else:
return rays_o[..., None, :] + rays_d[..., None, :] * r[..., None], r
class MslNet(nn.Module):
def __init__(self, cam_params, fc_params, sphere_layers: List[float],
out_res: Tuple[int, int], gray=False, encode_to_dim: int = 0):
def __init__(self, fc_params, sampler_params,
gray=False,
encode_to_dim: int = 0):
"""
Initialize a multi-sphere-layer net
:param cam_params: intrinsic parameters of camera
:param fc_params: parameters of full-connection network
:param sphere_layers: list(L), radius of sphere layers
:param out_res: resolution of output view image
:param fc_params: parameters for full-connection network
:param sampler_params: parameters for sampler
:param gray: is grayscale mode
:param encode_to_dim: encode input to number of dimensions
"""
super().__init__()
self.cam_params = cam_params
self.sphere_layers = torch.tensor(sphere_layers,
dtype=torch.float,
device=device.GetDevice())
self.in_chns = 3
self.out_res = out_res
self.input_encoder = net_modules.InputEncoder.Get(
encode_to_dim, self.in_chns)
fc_params['in_chns'] = self.input_encoder.out_dim
fc_params['out_chns'] = 2 if gray else 4
self.sampler = Sampler(**sampler_params)
self.net = net_modules.FcNet(**fc_params)
self.rendering = Rendering()
def forward(self, ray_positions: torch.Tensor, ray_directions: torch.Tensor) -> torch.Tensor:
def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor) -> torch.Tensor:
"""
rays -> colors
:param ray_positions ```Tensor(B, M, 3)|Tensor(B, 3)```: ray positions
:param ray_directions ```Tensor(B, M, 3)|Tensor(B, 3)```: ray directions
:return: Tensor(B, 1|3, H, W)|Tensor(B, 1|3), inferred images/pixels
:param rays_o ```Tensor(B, ..., 3)```: rays' origin
:param rays_d ```Tensor(B, ..., 3)```: rays' direction
:return: Tensor(B, 1|3, ...), inferred images/pixels
"""
p = ray_positions.view(-1, 3)
v = ray_directions.view(-1, 3)
spher = RayToSpherical(p, v, self.sphere_layers).flatten(0, 1)
color_alpha = self.net(self.input_encoder(spher)).view(
p.size(0), self.sphere_layers.size(0), -1)
c: torch.Tensor = self.rendering(color_alpha)
# unflatten
return c.view(ray_directions.size(0), self.out_res[0],
self.out_res[1], -1).permute(0, 3, 1, 2) if len(ray_directions.size()) == 3 else c
p = rays_o.view(-1, 3)
v = rays_d.view(-1, 3)
coords, depths = self.sampler(p, v)
encoded = self.input_encoder(coords)
color_map = self.rendering(self.net(encoded), depths)
# Unflatten according to input shape
out_shape = list(rays_d.size())
out_shape[-1] = -1
return color_map.view(out_shape).movedim(-1, 1)
import torch
from typing import List
from torch import nn
class CombinedLoss(nn.Module):
def __init__(self, loss_modules: List[nn.Module], weights: List[float]):
super().__init__()
self.loss_modules = nn.ModuleList(loss_modules)
self.weights = weights
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
return sum([self.weights[i] * self.loss_modules[i](input, target)
for i in range(len(self.loss_modules))])
class GradLoss(nn.Module):
def __init__(self):
super().__init__()
self.mse_loss = nn.MSELoss()
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
input_dy, input_dx = self._cal_grad(input)
target_dy, target_dx = self._cal_grad(target)
return self.mse_loss(
torch.cat([
input_dy.flatten(1, -1),
input_dx.flatten(1, -1)
], 1),
torch.cat([
target_dy.flatten(1, -1),
target_dx.flatten(1, -1)
], 1))
def _cal_grad(self, images):
"""
Calculate gradient of images
:param image ```Tensor(..., C, H, W)```: input images
:return ```Tensor(..., 2C, H-2, W-2)```: gradient map of input images
"""
dy = images[..., 2:, :] - images[..., :-2, :]
dx = images[..., :, 2:] - images[..., :, :-2]
return dy, dx
from typing import Mapping
import torch
import numpy as np
......@@ -9,15 +10,16 @@ def PrintNet(net):
def LoadNet(path, model, solver=None, discriminator=None):
print('Load net from %s ...' % path)
whole_dict = torch.load(path)
whole_dict: Mapping = torch.load(path)
model.load_state_dict(whole_dict['model'])
if solver:
solver.load_state_dict(whole_dict['solver'])
if discriminator:
discriminator.load_state_dict(whole_dict['discriminator'])
return whole_dict['iters'] if 'iters' in whole_dict else 0
def SaveNet(path, model, solver=None, discriminator=None):
def SaveNet(path, model, solver=None, discriminator=None, iters=None):
print('Saving net to %s ...' % path)
whole_dict = {
'model': model.state_dict()
......@@ -26,4 +28,6 @@ def SaveNet(path, model, solver=None, discriminator=None):
whole_dict.update({'solver': solver.state_dict()})
if discriminator:
whole_dict.update({'discriminator': discriminator.state_dict()})
if iters:
whole_dict.update({'iters': iters})
torch.save(whole_dict, path)
\ No newline at end of file
from typing import List, Tuple
from math import pi
import numpy as np
from typing import List, Tuple, Union
import os
import math
import torch
import torchvision
import matplotlib.pyplot as plt
import torchvision.transforms.functional as trans_func
import glm
import os
import numpy as np
import matplotlib.pyplot as plt
from torch.types import Number
from torchvision.utils import save_image
......@@ -16,7 +18,7 @@ gmat_type = [[glm.dmat2, glm.dmat2x3, glm.dmat2x4],
def Fov2Length(angle):
return np.tan(angle * np.pi / 360) * 2
return math.tan(math.radians(angle) / 2) * 2
def SmoothStep(x0, x1, x):
......@@ -153,14 +155,13 @@ def CreateDirIfNeed(path):
os.makedirs(path)
def GetLocalViewRays(cam_params, res: Tuple[int, int], flatten=False) -> torch.Tensor:
def GetLocalViewRays(cam_params, res: Tuple[int, int], flatten=False, norm=True) -> torch.Tensor:
coords = MeshGrid(res)
c = torch.tensor([cam_params['cx'], cam_params['cy']])
f = torch.tensor([cam_params['fx'], cam_params['fy']])
rays = torch.cat([
(coords - c) / f,
torch.ones(res[0], res[1], 1, )
], dim=2)
rays = broadcast_cat((coords - c) / f, 1.0)
if norm:
rays = rays / rays.norm(dim=-1, keepdim=True)
if flatten:
rays = rays.flatten(0, 1)
return rays
......@@ -175,7 +176,7 @@ def CartesianToSpherical(cart: torch.Tensor) -> torch.Tensor:
"""
rho = torch.norm(cart, p=2, dim=-1)
theta = torch.atan2(cart[..., 2], cart[..., 0])
theta = theta + (theta < 0).type_as(theta) * (2 * pi)
theta = theta + (theta < 0).type_as(theta) * (2 * math.pi)
phi = torch.acos(cart[..., 1] / rho)
return torch.stack([rho, theta, phi], dim=-1)
......@@ -207,5 +208,78 @@ def GetDepthLayers(depth_range: Tuple[float, float], n_layers: int) -> List[floa
"""
diopter_range = (1 / depth_range[1], 1 / depth_range[0])
depths = [1e5] # Background layer
depths += list(1.0 / np.linspace(diopter_range[0], diopter_range[1], n_layers))
depths += list(1.0 /
np.linspace(diopter_range[0], diopter_range[1], n_layers))
return depths
def GetRotMatrix(theta: Union[float, torch.Tensor], phi: Union[float, torch.Tensor]) -> torch.Tensor:
"""
Get rotation matrix from angles in spherical space
:param theta ```Tensor(..., 1) | float```: rotation angles around y axis
:param phi ```Tensor(..., 1) | float```: rotation angles around x axis
:return: ```Tensor(..., 3, 3)``` rotation matrices
"""
if not isinstance(theta, torch.Tensor):
theta = torch.tensor([theta])
if not isinstance(phi, torch.Tensor):
phi = torch.tensor([phi])
spher = torch.cat([torch.ones_like(theta), theta, phi], dim=-1)
print(spher)
forward = SphericalToCartesian(spher) # (..., 3)
up = torch.tensor([0.0, 1.0, 0.0])
forward, up = torch.broadcast_tensors(forward, up)
print(forward, up)
right = torch.cross(forward, up, dim=-1) # (..., 3)
up = torch.cross(right, forward, dim=-1) # (..., 3)
print(right, up, forward)
return torch.stack([right, up, forward], dim=-2) # (..., 3, 3)
def broadcast_cat(input: torch.Tensor,
s: Union[Number, List[Number], torch.Tensor],
dim=-1,
append: bool = True) -> torch.Tensor:
"""
Concatenate a tensor with a scalar along last dimension
:param input ```Tensor(..., N)```: input tensor
:param s: scalar
:param append: append or prepend the scalar to input tensor
:return: ```Tensor(..., N+1)```
"""
if dim != -1:
raise NotImplementedError('currently only support the last dimension')
if isinstance(s, torch.Tensor):
x = s
elif isinstance(s, list):
x = torch.tensor(s, dtype=input.dtype, device=input.device)
else:
x = torch.tensor([s], dtype=input.dtype, device=input.device)
expand_shape = list(input.size())
expand_shape[dim] = -1
x = x.expand(expand_shape)
return torch.cat([input, x] if append else [x, input], dim)
def generate_video(frames: torch.Tensor, path: str, fps: float,
repeat: int = 1, pingpong: bool = False,
video_codec: str = 'libx264'):
"""
Generate video from a sequence of frames after converting type and
permuting channels to meet the requirement of ```torchvision.io.write_video()```
:param frames ```Tensor(B, C, H, W)```: a sequence of frames
:param path: video path
:param fps: frames per second
:param repeat: repeat times
:param pingpong: whether repeat sequence in pinpong form
:param video_codec: video codec
"""
frames = trans_func.convert_image_dtype(frames, torch.uint8)
frames = frames.detach().cpu().permute(0, 2, 3, 1)
if pingpong:
frames = torch.cat([frames, frames.flip(0)], 0)
frames = frames.expand(repeat, -1, -1, -1, 3).flatten(0, 1)
torchvision.io.write_video(path, frames, fps, video_codec)
This source diff could not be displayed because it is too large. You can view the blob instead.
import math
import sys
import os
sys.path.append(os.path.abspath(sys.path[0] + '/../'))
__package__ = "deeplightfield"
import argparse
import torch
import torch.optim
import torchvision
import importlib
from tensorboardX import SummaryWriter
from torch import nn
from .my import netio
from .my import util
from .my import device
from .my.simple_perf import SimplePerf
from .data.spherical_view_syn import SphericalViewSynDataset
from .msl_net import MslNet
from .spher_net import SpherNet
sys.path.append(os.path.abspath(sys.path[0] + '/../'))
__package__ = "deeplightfield"
parser = argparse.ArgumentParser()
parser.add_argument('--device', type=int, default=3,
help='Which CUDA device to use.')
parser.add_argument('--config', type=str,
help='Net config files')
parser.add_argument('--dataset', type=str, required=True,
help='Dataset description file')
parser.add_argument('--test', type=str,
help='Test net file')
parser.add_argument('--test-samples', type=int,
help='Samples used for test')
parser.add_argument('--output-gt', action='store_true',
help='Output ground truth images if exist')
parser.add_argument('--output-alongside', action='store_true',
help='Output generated image alongside ground truth image')
parser.add_argument('--output-video', action='store_true',
help='Output test results as video')
opt = parser.parse_args()
......@@ -28,201 +36,297 @@ opt = parser.parse_args()
torch.cuda.set_device(opt.device)
print("Set CUDA:%d as current device." % torch.cuda.current_device())
from .my import netio
from .my import util
from .my import device
from .my.simple_perf import SimplePerf
from .data.spherical_view_syn import *
from .msl_net import MslNet
from .spher_net import SpherNet
from .my import loss
class Config(object):
def __init__(self):
self.name = 'default'
self.GRAY = False
# Net parameters
self.NET_TYPE = 'msl'
self.N_ENCODE_DIM = 10
self.FC_PARAMS = {
'nf': 256,
'n_layers': 8,
'skips': [4]
}
self.SAMPLE_PARAMS = {
'depth_range': (1, 50),
'n_samples': 32,
'perturb_sample': True
}
self.LOSS = 'mse'
def load(self, path):
module_name = os.path.splitext(path)[0].replace('/', '.')
config_module = importlib.import_module(
'deeplightfield.' + module_name)
config_module.update_config(config)
self.name = module_name.split('.')[-1]
def load_by_name(self, name):
config_module = importlib.import_module(
'deeplightfield.configs.' + name)
config_module.update_config(config)
self.name = name
def print(self):
print('==== Config %s ====' % self.name)
print('Net type: ', self.NET_TYPE)
print('Encode dim: ', self.N_ENCODE_DIM)
print('Full-connected network parameters:', self.FC_PARAMS)
print('Sample parameters', self.SAMPLE_PARAMS)
print('Loss', self.LOSS)
print('==========================')
config = Config()
# Toggles
GRAY = False
ROT_ONLY = False
TRAIN_MODE = True
EVAL_TIME_PERFORMANCE = False
RAY_AS_ITEM = True
# ========
GRAY = True
#ROT_ONLY = True
#TRAIN_MODE = False
#EVAL_TIME_PERFORMANCE = True
#RAY_AS_ITEM = False
# Net parameters
DEPTH_RANGE = (1, 10)
N_DEPTH_LAYERS = 10
N_ENCODE_DIM = 10
FC_PARAMS = {
'nf': 128,
'n_layers': 8,
'skips': [4]
}
# Train
TRAIN_DATA_DESC_FILE = 'train.json'
BATCH_SIZE = 2048 if RAY_AS_ITEM else 4
PATCH_SIZE = 1
BATCH_SIZE = 4096 // (PATCH_SIZE * PATCH_SIZE)
EPOCH_RANGE = range(0, 500)
SAVE_INTERVAL = 20
# Test
TEST_NET_NAME = 'model-epoch_500'
TEST_DATA_DESC_FILE = 'test_fovea.json'
TEST_BATCH_SIZE = 5
TEST_BATCH_SIZE = 1
TEST_CHUNKS = 1
# Paths
DATA_DIR = sys.path[0] + '/data/sp_view_syn_2020.12.28/'
RUN_ID = '%s_ray_b%d_encode%d_fc%dx%d%s' % ('gray' if GRAY else 'rgb',
BATCH_SIZE,
N_ENCODE_DIM,
FC_PARAMS['nf'],
FC_PARAMS['n_layers'],
'_skip_%d' % FC_PARAMS['skips'][0] if len(FC_PARAMS['skips']) > 0 else '')
RUN_DIR = DATA_DIR + RUN_ID + '/'
OUTPUT_DIR = RUN_DIR + 'output/'
LOG_DIR = RUN_DIR + 'log/'
def train():
# 1. Initialize data loader
print("Load dataset: " + DATA_DIR + TRAIN_DATA_DESC_FILE)
train_dataset = SphericalViewSynDataset(DATA_DIR + TRAIN_DATA_DESC_FILE,
gray=GRAY, ray_as_item=RAY_AS_ITEM)
train_data_loader = torch.utils.data.DataLoader(
dataset=train_dataset,
batch_size=BATCH_SIZE,
pin_memory=True,
shuffle=True,
drop_last=False)
print('Data loaded. %d iters per epoch.' % len(train_data_loader))
# 2. Initialize components
if ROT_ONLY:
model = SpherNet(cam_params=train_dataset.cam_params,
fc_params=FC_PARAMS,
out_res=train_dataset.view_res,
gray=GRAY,
encode_to_dim=N_ENCODE_DIM).to(device.GetDevice())
else:
model = MslNet(cam_params=train_dataset.cam_params,
fc_params=FC_PARAMS,
sphere_layers=util.GetDepthLayers(
DEPTH_RANGE, N_DEPTH_LAYERS),
out_res=train_dataset.view_res,
gray=GRAY,
encode_to_dim=N_ENCODE_DIM).to(device.GetDevice())
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
loss = nn.MSELoss()
if EPOCH_RANGE.start > 0:
netio.LoadNet('%smodel-epoch_%d.pth' % (RUN_DIR, EPOCH_RANGE.start),
model, solver=optimizer)
# 3. Train
model.train()
epoch = None
iters = EPOCH_RANGE.start * len(train_data_loader)
data_desc_path = opt.dataset
data_desc_name = os.path.split(data_desc_path)[1]
if opt.test:
test_net_path = opt.test
test_net_name = os.path.splitext(os.path.basename(test_net_path))[0]
run_dir = os.path.dirname(test_net_path) + '/'
run_id = os.path.basename(run_dir[:-1])
config_name = run_id.split('_b')[0]
output_dir = run_dir + 'output/%s/%s/' % (test_net_name, data_desc_name)
config.load_by_name(config_name)
train_mode = False
if opt.test_samples:
config.SAMPLE_PARAMS['n_samples'] = opt.test_samples
output_dir = run_dir + 'output/%s/%s_s%d/' % \
(test_net_name, data_desc_name, opt.test_samples)
else:
if opt.config:
config.load(opt.config)
data_dir = os.path.dirname(data_desc_path) + '/'
run_id = '%s_b%d[%d]' % (config.name, BATCH_SIZE, PATCH_SIZE)
run_dir = data_dir + run_id + '/'
log_dir = run_dir + 'log/'
output_dir = None
train_mode = True
config.print()
print("dataset: ", data_desc_path)
print("train_mode: ", train_mode)
print("run_dir: ", run_dir)
if not train_mode:
print("output_dir", output_dir)
config.SAMPLE_PARAMS['perturb_sample'] = \
config.SAMPLE_PARAMS['perturb_sample'] and train_mode
NETS = {
'msl': lambda: MslNet(
fc_params=config.FC_PARAMS,
sampler_params=(config.SAMPLE_PARAMS.update(
{'spherical': True}), config.SAMPLE_PARAMS)[1],
gray=config.GRAY,
encode_to_dim=config.N_ENCODE_DIM),
'nerf': lambda: MslNet(
fc_params=config.FC_PARAMS,
sampler_params=(config.SAMPLE_PARAMS.update(
{'spherical': False}), config.SAMPLE_PARAMS)[1],
gray=config.GRAY,
encode_to_dim=config.N_ENCODE_DIM),
'spher': lambda: SpherNet(
fc_params=config.FC_PARAMS,
gray=config.GRAY,
translation=not ROT_ONLY,
encode_to_dim=config.N_ENCODE_DIM)
}
util.CreateDirIfNeed(RUN_DIR)
util.CreateDirIfNeed(LOG_DIR)
LOSSES = {
'mse': lambda: nn.MSELoss(),
'mse_grad': lambda: loss.CombinedLoss(
[nn.MSELoss(), loss.GradLoss()], [1.0, 0.5])
}
perf = SimplePerf(EVAL_TIME_PERFORMANCE, start=True)
perf_epoch = SimplePerf(True, start=True)
writer = SummaryWriter(LOG_DIR)
# Initialize model
model = NETS[config.NET_TYPE]().to(device.GetDevice())
print("Begin training...")
for epoch in EPOCH_RANGE:
for _, gt, ray_positions, ray_directions in train_data_loader:
def train_loop(data_loader, optimizer, loss, perf, writer, epoch, iters):
sub_iters = 0
iters_in_epoch = len(data_loader)
for _, gt, rays_o, rays_d in data_loader:
gt = gt.to(device.GetDevice())
ray_positions = ray_positions.to(device.GetDevice())
ray_directions = ray_directions.to(device.GetDevice())
rays_o = rays_o.to(device.GetDevice())
rays_d = rays_d.to(device.GetDevice())
perf.Checkpoint("Load")
out = model(ray_positions, ray_directions)
out = model(rays_o, rays_d)
perf.Checkpoint("Forward")
optimizer.zero_grad()
loss_value = loss(out, gt)
perf.Checkpoint("Compute loss")
loss_value.backward()
perf.Checkpoint("Backward")
optimizer.step()
perf.Checkpoint("Update")
print("Epoch: ", epoch, ", Iter: ", iters,
", Loss: ", loss_value.item())
print("Epoch: %d, Iter: %d(%d/%d), Loss: %f" %
(epoch, iters, sub_iters, iters_in_epoch, loss_value.item()))
# Write tensorboard logs.
writer.add_scalar("loss", loss_value, iters)
if not RAY_AS_ITEM and iters % 100 == 0:
output_vs_gt = torch.cat([out, gt], dim=0)
if len(gt.size()) == 4 and iters % 100 == 0:
output_vs_gt = torch.cat([out[0:4], gt[0:4]], 0).detach()
writer.add_image("Output_vs_gt", torchvision.utils.make_grid(
output_vs_gt, scale_each=True, normalize=False)
.cpu().detach().numpy(), iters)
output_vs_gt, nrow=4).cpu().numpy(), iters)
iters += 1
sub_iters += 1
return iters
def train():
# 1. Initialize data loader
print("Load dataset: " + data_desc_path)
train_dataset = FastSphericalViewSynDataset(data_desc_path,
gray=config.GRAY)
train_dataset.set_patch_size((PATCH_SIZE, PATCH_SIZE))
train_data_loader = FastDataLoader(
dataset=train_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
drop_last=False,
pin_memory=True)
# 2. Initialize components
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
loss = LOSSES[config.LOSS]().to(device.GetDevice())
if EPOCH_RANGE.start > 0:
iters = netio.LoadNet('%smodel-epoch_%d.pth' % (run_dir, EPOCH_RANGE.start),
model, solver=optimizer)
else:
iters = 0
epoch = None
# 3. Train
model.train()
util.CreateDirIfNeed(run_dir)
util.CreateDirIfNeed(log_dir)
perf = SimplePerf(EVAL_TIME_PERFORMANCE, start=True)
perf_epoch = SimplePerf(True, start=True)
writer = SummaryWriter(log_dir)
print("Begin training...")
for epoch in EPOCH_RANGE:
perf_epoch.Checkpoint("Epoch")
iters = train_loop(train_data_loader, optimizer, loss,
perf, writer, epoch, iters)
# Save checkpoint
if ((epoch + 1) % SAVE_INTERVAL == 0):
netio.SaveNet('%smodel-epoch_%d.pth' % (RUN_DIR, epoch + 1), model,
solver=optimizer)
netio.SaveNet('%smodel-epoch_%d.pth' % (run_dir, epoch + 1), model,
solver=optimizer, iters=iters)
print("Train finished")
def test(net_file: str):
def test():
torch.autograd.set_grad_enabled(False)
# 1. Load train dataset
print("Load dataset: " + DATA_DIR + TEST_DATA_DESC_FILE)
test_dataset = SphericalViewSynDataset(DATA_DIR + TEST_DATA_DESC_FILE,
load_images=True, gray=GRAY)
print("Load dataset: " + data_desc_path)
test_dataset = SphericalViewSynDataset(data_desc_path,
load_images=opt.output_gt,
gray=config.GRAY)
test_data_loader = torch.utils.data.DataLoader(
dataset=test_dataset,
batch_size=TEST_BATCH_SIZE,
pin_memory=True,
batch_size=1,
shuffle=False,
drop_last=False)
drop_last=False,
pin_memory=True)
# 2. Load trained model
if ROT_ONLY:
model = SpherNet(cam_params=test_dataset.cam_params,
fc_params=FC_PARAMS,
out_res=test_dataset.view_res,
gray=GRAY,
encode_to_dim=N_ENCODE_DIM).to(device.GetDevice())
else:
model = MslNet(cam_params=test_dataset.cam_params,
sphere_layers=util.GetDepthLayers(
DEPTH_RANGE, N_DEPTH_LAYERS),
out_res=test_dataset.view_res,
gray=GRAY).to(device.GetDevice())
netio.LoadNet(net_file, model)
netio.LoadNet(test_net_path, model)
# 3. Test on train dataset
print("Begin test on train dataset, batch size is %d" % TEST_BATCH_SIZE)
output_dir = '%s%s/%s/' % (OUTPUT_DIR, TEST_NET_NAME, TEST_DATA_DESC_FILE)
util.CreateDirIfNeed(output_dir)
perf = SimplePerf(True, start=True)
i = 0
for view_idxs, view_images, ray_positions, ray_directions in test_data_loader:
ray_positions = ray_positions.to(device.GetDevice())
ray_directions = ray_directions.to(device.GetDevice())
n = test_dataset.view_rots.size(0)
chns = 1 if config.GRAY else 3
out_view_images = torch.empty(n, chns, test_dataset.view_res[0],
test_dataset.view_res[1], device=device.GetDevice())
print(out_view_images.size())
for view_idxs, _, rays_o, rays_d in test_data_loader:
perf.Checkpoint("%d - Load" % i)
out_view_images = model(ray_positions, ray_directions)
rays_o = rays_o.to(device.GetDevice()).view(-1, 3)
rays_d = rays_d.to(device.GetDevice()).view(-1, 3)
n_rays = rays_o.size(0)
chunk_size = n_rays // TEST_CHUNKS
out_pixels = torch.empty(n_rays, chns, device=device.GetDevice())
for offset in range(0, n_rays, chunk_size):
rays_o_ = rays_o[offset:offset + chunk_size]
rays_d_ = rays_d[offset:offset + chunk_size]
out_pixels[offset:offset + chunk_size] = \
model(rays_o_, rays_d_)
out_view_images[view_idxs] = out_pixels.view(
TEST_BATCH_SIZE, test_dataset.view_res[0],
test_dataset.view_res[1], -1).permute(0, 3, 1, 2)
perf.Checkpoint("%d - Infer" % i)
i += 1
if opt.output_video:
util.generate_video(out_view_images, output_dir +
'out.mp4', 24, 3, True)
else:
gt_paths = ['%sgt_view_%04d.png' % (output_dir, i) for i in range(n)]
out_paths = ['%sout_view_%04d.png' % (output_dir, i) for i in range(n)]
if test_dataset.load_images:
if opt.output_alongside:
util.WriteImageTensor(
view_images,
['%sgt_view_%04d.png' % (output_dir, i) for i in view_idxs])
util.WriteImageTensor(
out_view_images,
['%sout_view_%04d.png' % (output_dir, i) for i in view_idxs])
perf.Checkpoint("%d - Write" % i)
i += 1
torch.cat([test_dataset.view_images,
out_view_images.cpu()], 3),
out_paths)
else:
util.WriteImageTensor(out_view_images, out_paths)
util.WriteImageTensor(test_dataset.view_images, gt_paths)
else:
util.WriteImageTensor(out_view_images, out_paths)
if __name__ == "__main__":
if TRAIN_MODE:
if train_mode:
train()
else:
test(RUN_DIR + TEST_NET_NAME + '.pth')
test()
from typing import Tuple
import torch
import torch.nn as nn
from .my import net_modules
......@@ -7,45 +6,42 @@ from .my import util
class SpherNet(nn.Module):
def __init__(self, cam_params, # spher_min: Tuple[float, float], spher_max: Tuple[float, float],
fc_params,
out_res: Tuple[int, int] = None,
def __init__(self, fc_params,
gray: bool = False,
translation: bool = False,
encode_to_dim: int = 0):
"""
Initialize a sphere net
:param cam_params: intrinsic parameters of camera
:param fc_params: parameters of full-connection network
:param out_res: resolution of output view image
:param gray: is grayscale mode
:param fc_params: parameters for full-connection network
:param gray: whether grayscale mode
:param translation: whether support translation of view
:param encode_to_dim: encode input to number of dimensions
"""
super().__init__()
self.cam_params = cam_params
self.in_chns = 2
self.out_res = out_res
#self.spher_min = torch.tensor(spher_min, device=device.GetDevice()).view(1, 2)
#self.spher_max = torch.tensor(spher_max, device=device.GetDevice()).view(1, 2)
#self.spher_range = self.spher_max - self.spher_min
self.input_encoder = net_modules.InputEncoder.Get(encode_to_dim, self.in_chns)
self.in_chns = 5 if translation else 2
self.input_encoder = net_modules.InputEncoder.Get(
encode_to_dim, self.in_chns)
fc_params['in_chns'] = self.input_encoder.out_dim
fc_params['out_chns'] = 1 if gray else 3
self.net = net_modules.FcNet(**fc_params)
def forward(self, _, ray_directions: torch.Tensor) -> torch.Tensor:
def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor) -> torch.Tensor:
"""
rays -> colors
:param ray_directions ```Tensor(B, M, 3)|Tensor(B, 3)```: ray directions
:return: Tensor(B, 1|3, H, W)|Tensor(B, 1|3), inferred images/pixels
:param rays_o ```Tensor(B, ..., 3)```: rays' origin
:param rays_d ```Tensor(B, ..., 3)```: rays' direction
:return: Tensor(B, 1|3, ...), inferred images/pixels
"""
v = ray_directions.view(-1, 3) # (*, 3)
spher = util.CartesianToSpherical(v)[..., 1:3] # (*, 2)
# (spher - self.spher_min) / self.spher_range * 2 - 0.5
spher_normed = spher
c: torch.Tensor = self.net(self.input_encoder(spher_normed))
# Unflatten to (B, 1|3, H, W) if take view as item
return c.view(ray_directions.size(0), self.out_res[0], self.out_res[1],
-1).permute(0, 3, 1, 2) if len(ray_directions.size()) == 3 else c
p = rays_o.view(-1, 3)
v = rays_d.view(-1, 3)
spher = util.CartesianToSpherical(v)[..., 1:3] # (..., 2)
input = torch.cat([p, spher], dim=-1) if self.in_chns == 5 else spher
c: torch.Tensor = self.net(self.input_encoder(input))
# Unflatten according to input shape
out_shape = list(rays_d.size())
out_shape[-1] = -1
return c.view(out_shape).movedim(-1, 1)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment