Commit b7fae973 authored by BobYeah

implemented basic msl net

parent c570c3b1
[Net Parameters]
aaa = 1
bbb = 'abc'
ccc = (2,3)
def update_config(config):
# Dataset settings
config.GRAY = False
# Net parameters
config.NET_TYPE = 'msl'
config.N_ENCODE_DIM = 10
config.FC_PARAMS = {
'nf': 256,
'n_layers': 8,
'skips': [4]
}
config.SAMPLE_PARAMS = {
'depth_range': (1, 50),
'n_samples': 32,
'perturb_sample': True
}
\ No newline at end of file
def update_config(config):
# Dataset settings
config.GRAY = False
# Net parameters
config.NET_TYPE = 'msl'
config.N_ENCODE_DIM = 10
config.FC_PARAMS = {
'nf': 64,
'n_layers': 12,
'skips': []
}
config.SAMPLE_PARAMS = {
'depth_range': (1, 20),
'n_samples': 16,
'perturb_sample': True
}
config.LOSS = 'mse_grad'
\ No newline at end of file
def update_config(config):
# Dataset settings
config.GRAY = True
# Net parameters
config.NET_TYPE = 'msl'
config.N_ENCODE_DIM = 10
config.FC_PARAMS = {
'nf': 128,
'n_layers': 8,
'skips': [4]
}
config.SAMPLE_PARAMS = {
'depth_range': (1, 50),
'n_samples': 8,
'perturb_sample': True
}
\ No newline at end of file
def update_config(config):
# Dataset settings
config.GRAY = True
# Net parameters
config.NET_TYPE = 'msl'
config.N_ENCODE_DIM = 10
config.FC_PARAMS = {
'nf': 64,
'n_layers': 12,
'skips': []
}
config.SAMPLE_PARAMS = {
'depth_range': (1, 20),
'n_samples': 16,
'perturb_sample': True
}
\ No newline at end of file
def update_config(config):
# Dataset settings
config.GRAY = False
# Net parameters
config.NET_TYPE = 'msl'
config.N_ENCODE_DIM = 10
config.FC_PARAMS = {
'nf': 64,
'n_layers': 12,
'skips': []
}
config.SAMPLE_PARAMS = {
'depth_range': (1, 20),
'n_samples': 16,
'perturb_sample': True
}
\ No newline at end of file
def update_config(config):
# Dataset settings
config.GRAY = True
# Net parameters
config.NET_TYPE = 'msl'
config.N_ENCODE_DIM = 10
config.FC_PARAMS = {
'nf': 256,
'n_layers': 8,
'skips': [4]
}
config.SAMPLE_PARAMS = {
'depth_range': (1, 50),
'n_samples': 32,
'perturb_sample': True
}
\ No newline at end of file
def update_config(config):
# Dataset settings
config.GRAY = True
# Net parameters
config.NET_TYPE = 'msl'
config.N_ENCODE_DIM = 10
config.FC_PARAMS = {
'nf': 256,
'n_layers': 8,
'skips': [4]
}
config.SAMPLE_PARAMS = {
'depth_range': (1, 20),
'n_samples': 16,
'perturb_sample': True
}
\ No newline at end of file
def update_config(config):
# Dataset settings
config.GRAY = True
# Net parameters
config.NET_TYPE = 'msl'
config.N_ENCODE_DIM = 10
config.FC_PARAMS = {
'nf': 64,
'n_layers': 8,
'skips': [4]
}
config.SAMPLE_PARAMS = {
'depth_range': (1, 50),
'n_samples': 4,
'perturb_sample': True
}
\ No newline at end of file
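Each of the config modules above only defines an update_config(config) hook; the training script added in this commit builds a default Config object and applies the selected module to it. A minimal sketch of that flow, assuming a hypothetical config module deeplightfield/configs/msl_gray.py:

# Sketch only: mirrors Config.load_by_name() defined later in this commit.
import importlib

config = Config()                      # defaults as defined in run_spherical_view_syn.py
config_module = importlib.import_module('deeplightfield.configs.msl_gray')
config_module.update_config(config)    # overwrite NET_TYPE, FC_PARAMS, SAMPLE_PARAMS, ...
config.name = 'msl_gray'
config.print()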
import math
from typing import Tuple
import torch
import torchvision.transforms.functional as trans_f
import json
from ..my import util
from ..my import device
def _convert_camera_params(input_camera_params, view_res):
"""
Check and convert camera parameters in config file to pixel-space
:param cam_params: { ["fx", "fy" | "fov"], "cx", "cy", ["normalized"] },
the parameters of camera
:return: camera parameters
"""
input_is_normalized = bool(input_camera_params.get('normalized'))
camera_params = {}
if 'fov' in input_camera_params:
camera_params['fx'] = camera_params['fy'] = \
(1 if input_is_normalized else view_res[0]) / \
util.Fov2Length(input_camera_params['fov'])
camera_params['fy'] *= -1
else:
camera_params['fx'] = input_camera_params['fx']
camera_params['fy'] = input_camera_params['fy']
camera_params['cx'] = input_camera_params['cx']
camera_params['cy'] = input_camera_params['cy']
if input_is_normalized:
camera_params['fx'] *= view_res[1]
camera_params['fy'] *= view_res[0]
camera_params['cx'] *= view_res[1]
camera_params['cy'] *= view_res[0]
return camera_params
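A quick sanity-check sketch of the conversion above (the numbers are made up, not from any dataset): a normalized 'fov' specification first yields fx = fy = 1 / Fov2Length(fov), fy is negated because image-space y points down, and the denormalization step then scales x-components by the width and y-components by the height.

# Illustrative only; values are placeholders.
cam = _convert_camera_params(
    {'fov': 90, 'cx': 0.5, 'cy': 0.5, 'normalized': True},
    view_res=(120, 160))   # (H, W)
# Fov2Length(90) == 2, so before denormalization fx == fy == 0.5 and fy is negated.
# After scaling by the resolution: fx == 80.0, fy == -60.0, cx == 80.0, cy == 60.0.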
class SphericalViewSynDataset(torch.utils.data.dataset.Dataset):
@@ -27,7 +58,7 @@ class SphericalViewSynDataset(torch.utils.data.dataset.Dataset):
        - view_file_pattern: string, the path pattern of view images
        - view_res: { "x", "y" }, the resolution of view images
        - cam_params: { ["fx", "fy" | "fov"], "cx", "cy", ["normalized"] }, the parameters of camera
        - view_centers: [ [ x, y, z ], ... ], centers of views
        - view_rots: [ [ m00, m01, ..., m22 ], ... ], rotation matrices of views
@@ -50,7 +81,8 @@ class SphericalViewSynDataset(torch.utils.data.dataset.Dataset):
                data_desc['view_file_pattern']
        self.view_res = (data_desc['view_res']['y'],
                         data_desc['view_res']['x'])
        self.cam_params = _convert_camera_params(
            data_desc['cam_params'], self.view_res)
        self.view_centers = torch.tensor(data_desc['view_centers'])  # (N, 3)
        self.view_rots = torch.tensor(data_desc['view_rots']) \
            .view(-1, 3, 3)  # (N, 3, 3)
@@ -70,24 +102,181 @@ class SphericalViewSynDataset(torch.utils.data.dataset.Dataset):
        # Transpose matrix so we can perform vec x mat
        view_rots_t = self.view_rots.permute(0, 2, 1)

        # rays_o & rays_d are both (N, M, 3)
        self.rays_o = self.view_centers.unsqueeze(1) \
            .expand(-1, local_view_rays.size(0), -1)
        self.rays_d = torch.matmul(local_view_rays, view_rots_t)

        # Flatten rays if ray_as_item = True
        if ray_as_item:
            self.view_pixels = self.view_images.permute(0, 2, 3, 1).flatten(
                0, 2) if self.view_images != None else None
            self.rays_o = self.rays_o.flatten(0, 1)
            self.rays_d = self.rays_d.flatten(0, 1)

    def __len__(self):
        return self.rays_o.size(0)

    def __getitem__(self, idx):
        if self.load_images:
            if self.ray_as_item:
                return idx, self.view_pixels[idx], self.rays_o[idx], self.rays_d[idx]
            return idx, self.view_images[idx], self.rays_o[idx], self.rays_d[idx]
        return idx, False, self.rays_o[idx], self.rays_d[idx]
class FastSphericalViewSynDataset(torch.utils.data.dataset.Dataset):
"""
Data loader for spherical view synthesis task
Attributes
--------
data_dir ```str```: the directory of dataset\n
view_file_pattern ```str```: the filename pattern of view images\n
cam_params ```object```: camera intrinsic parameters\n
view_centers ```Tensor(N, 3)```: centers of views\n
view_rots ```Tensor(N, 3, 3)```: rotation matrices of views\n
view_images ```Tensor(N, 3, H, W)```: images of views\n
"""
def __init__(self, dataset_desc_path: str, load_images: bool = True, gray: bool = False):
"""
Initialize data loader for spherical view synthesis task
The dataset description file is a JSON file with following fields:
- view_file_pattern: string, the path pattern of view images
- view_res: { "x", "y" }, the resolution of view images
- cam_params: { "fx", "fy", "cx", "cy" }, the focal and center of camera (in normalized image space)
- view_centers: [ [ x, y, z ], ... ], centers of views
- view_rots: [ [ m00, m01, ..., m22 ], ... ], rotation matrices of views
:param dataset_desc_path ```str```: path to the data description file
:param load_images ```bool```: whether load view images and return in __getitem__()
:param gray ```bool```: whether convert view images to grayscale
"""
super().__init__()
self.data_dir = dataset_desc_path.rsplit('/', 1)[0] + '/'
self.load_images = load_images
# Load dataset description file
with open(dataset_desc_path, 'r', encoding='utf-8') as file:
data_desc = json.loads(file.read())
if data_desc['view_file_pattern'] == '':
self.load_images = False
else:
self.view_file_pattern: str = self.data_dir + \
data_desc['view_file_pattern']
self.view_res = (data_desc['view_res']['y'],
data_desc['view_res']['x'])
self.cam_params = _convert_camera_params(
data_desc['cam_params'], self.view_res)
self.view_centers = torch.tensor(
data_desc['view_centers'], device=device.GetDevice()) # (N, 3)
self.view_rots = torch.tensor(
data_desc['view_rots'], device=device.GetDevice()).view(-1, 3, 3) # (N, 3, 3)
self.n_views = self.view_centers.size(0)
self.n_pixels = self.n_views * self.view_res[0] * self.view_res[1]
# Load view images
if self.load_images:
self.view_images = util.ReadImageTensor(
[self.view_file_pattern % i
for i in range(self.view_centers.size(0))]
).to(device.GetDevice())
if gray:
self.view_images = trans_f.rgb_to_grayscale(self.view_images)
else:
self.view_images = None
local_view_rays = util.GetLocalViewRays(self.cam_params, self.view_res, True) \
.to(device.GetDevice()) # (HW, 3)
# Transpose matrix so we can perform vec x mat
view_rots_t = self.view_rots.permute(0, 2, 1)
# rays_o & rays_d are both (N, H, W, 3)
self.rays_o = self.view_centers[:, None, None, :] \
.expand(-1, self.view_res[0], self.view_res[1], -1)
self.rays_d = torch.matmul(local_view_rays, view_rots_t) \
.view_as(self.rays_o)
self.patched_images = self.view_images # (N, 1|3, H, W)
self.patched_rays_o = self.rays_o # (N, H, W, 3)
self.patched_rays_d = self.rays_d # (N, H, W, 3)
def set_patch_size(self, patch_size: Tuple[int, int], offset: Tuple[int, int] = (0, 0)):
"""
Set the size of patch and (optional) offset. If patch_size = (1, 1), rays are flattened into individual items; if patch_size equals the view resolution, whole views are kept as items; otherwise each view is split into patches.
:param patch_size:
:param offset:
"""
patches = ((self.view_res[0] - offset[0]) // patch_size[0],
(self.view_res[1] - offset[1]) // patch_size[1])
slices = (..., slice(offset[0], offset[0] + patches[0] * patch_size[0]),
slice(offset[1], offset[1] + patches[1] * patch_size[1]))
if patch_size[0] == 1 and patch_size[1] == 1:
self.patched_images = self.view_images[slices] \
.permute(0, 2, 3, 1).flatten(0, 2) if self.load_images else None
self.patched_rays_o = self.rays_o[slices].flatten(0, 2)
self.patched_rays_d = self.rays_d[slices].flatten(0, 2)
elif patch_size[0] == self.view_res[0] and patch_size[1] == self.view_res[1]:
self.patched_images = self.view_images
self.patched_rays_o = self.rays_o
self.patched_rays_d = self.rays_d
else:
print(self.view_images.size(), self.rays_o.size())
print(self.view_images[slices].size(), self.rays_o[slices].size())
self.patched_images = self.view_images[slices] \
.view(self.n_views, -1, patches[0], patch_size[0], patches[1], patch_size[1]) \
.permute(0, 2, 4, 1, 3, 5).flatten(0, 2) if self.load_images else None
self.patched_rays_o = self.rays_o[slices] \
.view(self.n_views, patches[0], patch_size[0], patches[1], patch_size[1], -1) \
.permute(0, 1, 3, 2, 4, 5).flatten(0, 2)
self.patched_rays_d = self.rays_d[slices] \
.view(self.n_views, patches[0], patch_size[0], patches[1], patch_size[1], -1) \
.permute(0, 1, 3, 2, 4, 5).flatten(0, 2)
def __len__(self):
return self.patched_rays_o.size(0)
def __getitem__(self, idx):
if self.load_images:
return idx, self.patched_images[idx], self.patched_rays_o[idx], \
self.patched_rays_d[idx]
return idx, False, self.patched_rays_o[idx], self.patched_rays_d[idx]
class FastDataLoader(object):
class Iter(object):
def __init__(self, dataset, batch_size, shuffle, drop_last) -> None:
super().__init__()
self.indices = torch.randperm(len(dataset), device=device.GetDevice()) \
if shuffle else torch.arange(len(dataset), device=device.GetDevice())
self.offset = 0
self.batch_size = batch_size
self.dataset = dataset
self.drop_last = drop_last
def __next__(self):
if self.offset + (self.batch_size if self.drop_last else 0) >= len(self.dataset):
raise StopIteration()
indices = self.indices[self.offset:self.offset + self.batch_size]
self.offset += self.batch_size
return self.dataset[indices]
def __init__(self, dataset, batch_size, shuffle, drop_last, **kwargs) -> None:
super().__init__()
self.dataset = dataset
self.batch_size = batch_size
self.shuffle = shuffle
self.drop_last = drop_last
def __iter__(self):
return FastDataLoader.Iter(self.dataset, self.batch_size,
self.shuffle, self.drop_last)
def __len__(self):
return math.floor(len(self.dataset) / self.batch_size) if self.drop_last \
else math.ceil(len(self.dataset) / self.batch_size)
import torch
from .ssim import *
from .perc_loss import *
device=torch.device("cuda:2")
l1loss = torch.nn.L1Loss()
perc_loss = VGGPerceptualLoss().to(device)
##### LOSS #####
def calImageGradients(images):
# x is a 4-D tensor
dx = images[:, :, 1:, :] - images[:, :, :-1, :]
dy = images[:, :, :, 1:] - images[:, :, :, :-1]
return dx, dy
def loss_new(generated, gt):
mse_loss = torch.nn.MSELoss()
rmse_intensity = mse_loss(generated, gt)
psnr_intensity = torch.log10(rmse_intensity)
# print("psnr:",psnr_intensity)
# ssim_intensity = ssim(generated, gt)
labels_dx, labels_dy = calImageGradients(gt)
# print("generated:",generated.shape)
preds_dx, preds_dy = calImageGradients(generated)
rmse_grad_x, rmse_grad_y = mse_loss(
labels_dx, preds_dx), mse_loss(labels_dy, preds_dy)
psnr_grad_x, psnr_grad_y = torch.log10(
rmse_grad_x), torch.log10(rmse_grad_y)
# print("psnr x&y:",psnr_grad_x," ",psnr_grad_y)
p_loss = perc_loss(generated, gt)
# print("-psnr:",-psnr_intensity,",0.5*(psnr_grad_x + psnr_grad_y):",0.5*(psnr_grad_x + psnr_grad_y),",perc_loss:",p_loss)
total_loss = psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y) + p_loss
# total_loss = rmse_intensity + 0.5*(rmse_grad_x + rmse_grad_y) # + p_loss
return total_loss
def loss_without_perc(generated, gt):
mse_loss = torch.nn.MSELoss()
rmse_intensity = mse_loss(generated, gt)
psnr_intensity = torch.log10(rmse_intensity)
# print("psnr:",psnr_intensity)
# ssim_intensity = ssim(generated, gt)
labels_dx, labels_dy = calImageGradients(gt)
# print("generated:",generated.shape)
preds_dx, preds_dy = calImageGradients(generated)
rmse_grad_x, rmse_grad_y = mse_loss(
labels_dx, preds_dx), mse_loss(labels_dy, preds_dy)
psnr_grad_x, psnr_grad_y = torch.log10(
rmse_grad_x), torch.log10(rmse_grad_y)
# print("psnr x&y:",psnr_grad_x," ",psnr_grad_y)
# print("-psnr:",-psnr_intensity,",0.5*(psnr_grad_x + psnr_grad_y):",0.5*(psnr_grad_x + psnr_grad_y),",perc_loss:",p_loss)
total_loss = psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y)
# total_loss = rmse_intensity + 0.5*(rmse_grad_x + rmse_grad_y) # + p_loss
return total_loss
##### LOSS #####
class ReconstructionLoss(torch.nn.Module):
def __init__(self):
super(ReconstructionLoss, self).__init__()
def forward(self, generated, gt):
rmse_intensity = torch.nn.functional.mse_loss(generated, gt)
psnr_intensity = torch.log10(rmse_intensity)
labels_dx, labels_dy = calImageGradients(gt)
preds_dx, preds_dy = calImageGradients(generated)
rmse_grad_x, rmse_grad_y = torch.nn.functional.mse_loss(
labels_dx, preds_dx), torch.nn.functional.mse_loss(labels_dy, preds_dy)
psnr_grad_x, psnr_grad_y = torch.log10(
rmse_grad_x), torch.log10(rmse_grad_y)
total_loss = psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y)
return total_loss
class PerceptionReconstructionLoss(torch.nn.Module):
def __init__(self):
super(PerceptionReconstructionLoss, self).__init__()
def forward(self, generated, gt):
rmse_intensity = torch.nn.functional.mse_loss(generated, gt)
psnr_intensity = torch.log10(rmse_intensity)
labels_dx, labels_dy = calImageGradients(gt)
preds_dx, preds_dy = calImageGradients(generated)
rmse_grad_x = torch.nn.functional.mse_loss(labels_dx, preds_dx)
rmse_grad_y = torch.nn.functional.mse_loss(labels_dy, preds_dy)
psnr_grad_x = torch.log10(rmse_grad_x)
psnr_grad_y = torch.log10(rmse_grad_y)
p_loss = perc_loss(generated, gt)
total_loss = psnr_intensity + 0.5 * (psnr_grad_x + psnr_grad_y) + p_loss
return total_loss
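Despite the psnr_* names, the quantities above are log-MSE terms, not true PSNR values. Written out, both reconstruction losses compute (the perceptual term appearing only in PerceptionReconstructionLoss and loss_new):

$$\mathcal{L} = \log_{10}\mathrm{MSE}(\hat{I}, I) + \tfrac{1}{2}\bigl(\log_{10}\mathrm{MSE}(\partial_x\hat{I}, \partial_x I) + \log_{10}\mathrm{MSE}(\partial_y\hat{I}, \partial_y I)\bigr) + \mathcal{L}_{\mathrm{perc}}(\hat{I}, I)$$

where the partial derivatives are the finite image differences from calImageGradients along the two spatial axes.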
from typing import Tuple
import torch
import torch.nn as nn
from .my import net_modules
from .my import util
from .my import device

rand_gen = torch.Generator(device=device.GetDevice())
rand_gen.manual_seed(torch.seed())


def RaySphereIntersect(p: torch.Tensor, v: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
    """
    Calculate intersections of each ray with each sphere

    :param p ```Tensor(B, 3)```: positions of rays
    :param v ```Tensor(B, 3)```: directions of rays
    :param r ```Tensor(N)```: radius of spheres
    :return ```Tensor(B, N, 3)```: points of intersection
    :return ```Tensor(B, N)```: depths of intersection along ray
    """
    # p, v: Expand to (B, 1, 3)
    p = p.unsqueeze(1)
    v = v.unsqueeze(1)
    # pp, vv, pv: (B, 1)
    pp = (p * p).sum(dim=2)
    vv = (v * v).sum(dim=2)
    pv = (p * v).sum(dim=2)
    depths = (((pv * pv - vv * (pp - r * r)).sqrt() - pv) / vv)
    return p + depths[..., None] * v, depths
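For reference, the depth computed above is the root (in the ray's forward direction, valid when the ray origin lies inside the sphere) of the quadratic obtained by intersecting the ray p + t·v with a sphere of radius r centered at the origin:

$$\|p + t v\|^2 = r^2 \;\Rightarrow\; (v{\cdot}v)\,t^2 + 2(p{\cdot}v)\,t + (p{\cdot}p - r^2) = 0 \;\Rightarrow\; t = \frac{-(p{\cdot}v) + \sqrt{(p{\cdot}v)^2 - (v{\cdot}v)\,(p{\cdot}p - r^2)}}{v{\cdot}v}$$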
def RayToSpherical(p: torch.Tensor, v: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
"""
Calculate intersections of each rays and each spheres
:param p: B x 3, positions of rays
:param v: B x 3, directions of rays
:param r: B' x 1, radius of spheres
:return: B x B' x 3, spherical coordinates
"""
p_on_spheres, _ = RaySphereIntersect(p, v, r)
return util.CartesianToSpherical(p_on_spheres)
class Rendering(nn.Module):

    def __init__(self, *, raw_noise_std: float = 0.0, white_bg: bool = False):
        """
        Initialize a Rendering module
        """
        super().__init__()
        self.raw_noise_std = raw_noise_std
        self.white_bg = white_bg

    def forward(self, raw, z_vals, ret_extra: bool = False):
        """Transforms model's predictions to semantically meaningful values.
        Args:
            raw: [num_rays, num_samples along ray, 4]. Prediction from model.
            z_vals: [num_rays, num_samples along ray]. Integration time.
        Returns:
            rgb_map: [num_rays, 3]. Estimated RGB color of a ray.
            disp_map: [num_rays]. Disparity map. Inverse of depth map.
            acc_map: [num_rays]. Sum of weights along each ray.
            weights: [num_rays, num_samples]. Weights assigned to each sampled color.
            depth_map: [num_rays]. Estimated distance to object.
        """
        # Function for computing density from model prediction. This value is
        # strictly between [0, 1].
        def raw2alpha(raw, dists, act_fn=torch.relu):
            return 1.0 - torch.exp(-act_fn(raw) * dists)

        # Compute 'distance' (in time) between each integration time along a ray.
        # The 'distance' from the last integration time is infinity.
        # dists: (N_rays, N_samples)
        dists = z_vals[..., 1:] - z_vals[..., :-1]
        dists = util.broadcast_cat(dists, 1e10)

        # Extract RGB of each sample position along each ray.
        color = torch.sigmoid(raw[..., :-1])  # (N_rays, N_samples, 1|3)

        # Add noise to model's predictions for density. Can be used to
        # regularize network during training (prevents floater artifacts).
        noise = 0.
        if self.raw_noise_std > 0.:
            noise = torch.normal(0.0, self.raw_noise_std,
                                 raw[..., 3].size(), generator=rand_gen)

        # Predict density of each sample along each ray. Higher values imply
        # higher likelihood of being absorbed at this point.
        alpha = raw2alpha(raw[..., -1] + noise, dists)  # (N_rays, N_samples)

        # Compute weight for RGB of each sample along each ray. A cumprod() is
        # used to express the idea of the ray not having reflected up to this
        # sample yet.
        one_minus_alpha = util.broadcast_cat(
            torch.cumprod(1 - alpha[..., :-1] + 1e-10, dim=-1),
            1.0, append=False)
        weights = alpha * one_minus_alpha  # (N_rays, N_samples)

        # (N_rays, 1|3), computed weighted color of each sample along each ray.
        color_map = torch.sum(weights[..., None] * color, dim=-2)

        # To composite onto a white background, use the accumulated alpha map.
        if self.white_bg or ret_extra:
            # Sum of weights along each ray. This value is in [0, 1] up to numerical error.
            acc_map = torch.sum(weights, -1)
            if self.white_bg:
                color_map = color_map + (1. - acc_map[..., None])
        else:
            acc_map = None

        if not ret_extra:
            return color_map

        # Estimated depth map is expected distance.
        depth_map = torch.sum(weights * z_vals, dim=-1)
        # Disparity map is inverse depth.
        disp_map = 1. / torch.max(torch.full_like(depth_map, 1e-10),
                                  depth_map / torch.sum(weights, dim=-1))
        return color_map, disp_map, acc_map, weights, depth_map


class Sampler(nn.Module):

    def __init__(self, *, depth_range: Tuple[float, float], n_samples: int,
                 perturb_sample: bool, spherical: bool):
        """
        Initialize a Sampler module

        :param depth_range: depth range for sampler
        :param n_samples: count to sample along ray
        :param perturb_sample: perturb the sample depths
        """
        super().__init__()
        self.r = 1 / torch.linspace(1 / depth_range[0], 1 / depth_range[1],
                                    n_samples, device=device.GetDevice())
        self.perturb_sample = perturb_sample
        self.spherical = spherical
        if perturb_sample:
            mids = .5 * (self.r[1:] + self.r[:-1])
            self.upper = torch.cat([mids, self.r[-1:]], -1)
            self.lower = torch.cat([self.r[:1], mids], -1)

    def forward(self, rays_o, rays_d):
        """
        Sample points along rays. Return spherical or Cartesian coordinates,
        specified by ```self.spherical```

        :param rays_o ```Tensor(B, 3)```: rays' origin
        :param rays_d ```Tensor(B, 3)```: rays' direction
        :return ```Tensor(B, N, 3)```: sampled points
        :return ```Tensor(B, N)```: corresponding depths along rays
        """
        if self.perturb_sample:
            # Stratified samples in those intervals
            t_rand = torch.rand(self.r.size(),
                                generator=rand_gen,
                                device=device.GetDevice())
            r = self.lower + (self.upper - self.lower) * t_rand
        else:
            r = self.r
        if self.spherical:
            pts, depths = RaySphereIntersect(rays_o, rays_d, r)
            sphers = util.CartesianToSpherical(pts)
            sphers[..., 0] = 1 / sphers[..., 0]
            return sphers, depths
        else:
            return rays_o[..., None, :] + rays_d[..., None, :] * r[..., None], r


class MslNet(nn.Module):

    def __init__(self, fc_params, sampler_params,
                 gray=False,
                 encode_to_dim: int = 0):
        """
        Initialize a multi-sphere-layer net

        :param fc_params: parameters for full-connection network
        :param sampler_params: parameters for sampler
        :param gray: is grayscale mode
        :param encode_to_dim: encode input to number of dimensions
        """
        super().__init__()
        self.in_chns = 3
        self.input_encoder = net_modules.InputEncoder.Get(
            encode_to_dim, self.in_chns)
        fc_params['in_chns'] = self.input_encoder.out_dim
        fc_params['out_chns'] = 2 if gray else 4
        self.sampler = Sampler(**sampler_params)
        self.net = net_modules.FcNet(**fc_params)
        self.rendering = Rendering()

    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor) -> torch.Tensor:
        """
        rays -> colors

        :param rays_o ```Tensor(B, ..., 3)```: rays' origin
        :param rays_d ```Tensor(B, ..., 3)```: rays' direction
        :return: Tensor(B, 1|3, ...), inferred images/pixels
        """
        p = rays_o.view(-1, 3)
        v = rays_d.view(-1, 3)
        coords, depths = self.sampler(p, v)
        encoded = self.input_encoder(coords)
        color_map = self.rendering(self.net(encoded), depths)

        # Unflatten according to input shape
        out_shape = list(rays_d.size())
        out_shape[-1] = -1
        return color_map.view(out_shape).movedim(-1, 1)
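The Rendering module above is the standard NeRF-style alpha compositing, and the Sampler makes the 'msl' variant sample depths uniformly in disparity (inverse depth) and feed the network spherical coordinates with inverted radius. For per-sample densities σ_i, colors c_i and depth intervals δ_i, the composited ray color is:

$$\alpha_i = 1 - e^{-\mathrm{relu}(\sigma_i)\,\delta_i},\qquad w_i = \alpha_i \prod_{j<i}(1-\alpha_j),\qquad \hat{C} = \sum_i w_i\, c_i$$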
import torch
from typing import List
from torch import nn
class CombinedLoss(nn.Module):
def __init__(self, loss_modules: List[nn.Module], weights: List[float]):
super().__init__()
self.loss_modules = nn.ModuleList(loss_modules)
self.weights = weights
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
return sum([self.weights[i] * self.loss_modules[i](input, target)
for i in range(len(self.loss_modules))])
class GradLoss(nn.Module):
def __init__(self):
super().__init__()
self.mse_loss = nn.MSELoss()
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
input_dy, input_dx = self._cal_grad(input)
target_dy, target_dx = self._cal_grad(target)
return self.mse_loss(
torch.cat([
input_dy.flatten(1, -1),
input_dx.flatten(1, -1)
], 1),
torch.cat([
target_dy.flatten(1, -1),
target_dx.flatten(1, -1)
], 1))
def _cal_grad(self, images):
"""
Calculate gradient of images
:param images ```Tensor(..., C, H, W)```: input images
:return: a pair of tensors, vertical gradient ```Tensor(..., C, H-2, W)``` and horizontal gradient ```Tensor(..., C, H, W-2)```
"""
dy = images[..., 2:, :] - images[..., :-2, :]
dx = images[..., :, 2:] - images[..., :, :-2]
return dy, dx
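A minimal usage sketch of the two modules above, matching the 'mse_grad' option that the training script in this commit registers (output and target stand for the network prediction and the ground-truth image):

# Sketch only.
criterion = CombinedLoss([nn.MSELoss(), GradLoss()], [1.0, 0.5])
loss_value = criterion(output, target)   # == 1.0 * MSE + 0.5 * gradient MSE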
from typing import Mapping
import torch
import numpy as np
@@ -9,15 +10,16 @@ def PrintNet(net):
def LoadNet(path, model, solver=None, discriminator=None):
    print('Load net from %s ...' % path)
    whole_dict: Mapping = torch.load(path)
    model.load_state_dict(whole_dict['model'])
    if solver:
        solver.load_state_dict(whole_dict['solver'])
    if discriminator:
        discriminator.load_state_dict(whole_dict['discriminator'])
    return whole_dict['iters'] if 'iters' in whole_dict else 0


def SaveNet(path, model, solver=None, discriminator=None, iters=None):
    print('Saving net to %s ...' % path)
    whole_dict = {
        'model': model.state_dict()
@@ -26,4 +28,6 @@ def SaveNet(path, model, solver=None, discriminator=None, iters=None):
        whole_dict.update({'solver': solver.state_dict()})
    if discriminator:
        whole_dict.update({'discriminator': discriminator.state_dict()})
    if iters:
        whole_dict.update({'iters': iters})
    torch.save(whole_dict, path)
\ No newline at end of file
from typing import List, Tuple, Union
import os
import math
import torch
import torchvision
import torchvision.transforms.functional as trans_func
import glm
import numpy as np
import matplotlib.pyplot as plt
from torch.types import Number
from torchvision.utils import save_image
@@ -16,7 +18,7 @@ gmat_type = [[glm.dmat2, glm.dmat2x3, glm.dmat2x4],
def Fov2Length(angle):
    return math.tan(math.radians(angle) / 2) * 2


def SmoothStep(x0, x1, x):
@@ -153,14 +155,13 @@ def CreateDirIfNeed(path):
        os.makedirs(path)


def GetLocalViewRays(cam_params, res: Tuple[int, int], flatten=False, norm=True) -> torch.Tensor:
    coords = MeshGrid(res)
    c = torch.tensor([cam_params['cx'], cam_params['cy']])
    f = torch.tensor([cam_params['fx'], cam_params['fy']])
    rays = broadcast_cat((coords - c) / f, 1.0)
    if norm:
        rays = rays / rays.norm(dim=-1, keepdim=True)
    if flatten:
        rays = rays.flatten(0, 1)
    return rays
@@ -175,7 +176,7 @@ def CartesianToSpherical(cart: torch.Tensor) -> torch.Tensor:
    """
    rho = torch.norm(cart, p=2, dim=-1)
    theta = torch.atan2(cart[..., 2], cart[..., 0])
    theta = theta + (theta < 0).type_as(theta) * (2 * math.pi)
    phi = torch.acos(cart[..., 1] / rho)
    return torch.stack([rho, theta, phi], dim=-1)
@@ -207,5 +208,78 @@ def GetDepthLayers(depth_range: Tuple[float, float], n_layers: int) -> List[float]:
    """
    diopter_range = (1 / depth_range[1], 1 / depth_range[0])
    depths = [1e5]  # Background layer
    depths += list(1.0 /
                   np.linspace(diopter_range[0], diopter_range[1], n_layers))
    return depths
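The layers returned above are spaced uniformly in diopters (inverse depth), not in depth. For example, GetDepthLayers((1, 50), 4) yields approximately [1e5, 50.0, 2.88, 1.49, 1.0]: the background layer plus four depths whose reciprocals are evenly spaced between 1/50 and 1.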
def GetRotMatrix(theta: Union[float, torch.Tensor], phi: Union[float, torch.Tensor]) -> torch.Tensor:
"""
Get rotation matrix from angles in spherical space
:param theta ```Tensor(..., 1) | float```: rotation angles around y axis
:param phi ```Tensor(..., 1) | float```: rotation angles around x axis
:return: ```Tensor(..., 3, 3)``` rotation matrices
"""
if not isinstance(theta, torch.Tensor):
theta = torch.tensor([theta])
if not isinstance(phi, torch.Tensor):
phi = torch.tensor([phi])
spher = torch.cat([torch.ones_like(theta), theta, phi], dim=-1)
print(spher)
forward = SphericalToCartesian(spher) # (..., 3)
up = torch.tensor([0.0, 1.0, 0.0])
forward, up = torch.broadcast_tensors(forward, up)
print(forward, up)
right = torch.cross(forward, up, dim=-1) # (..., 3)
up = torch.cross(right, forward, dim=-1) # (..., 3)
print(right, up, forward)
return torch.stack([right, up, forward], dim=-2) # (..., 3, 3)
def broadcast_cat(input: torch.Tensor,
s: Union[Number, List[Number], torch.Tensor],
dim=-1,
append: bool = True) -> torch.Tensor:
"""
Concatenate a tensor with a scalar along last dimension
:param input ```Tensor(..., N)```: input tensor
:param s: scalar
:param append: append or prepend the scalar to input tensor
:return: ```Tensor(..., N+1)```
"""
if dim != -1:
raise NotImplementedError('currently only support the last dimension')
if isinstance(s, torch.Tensor):
x = s
elif isinstance(s, list):
x = torch.tensor(s, dtype=input.dtype, device=input.device)
else:
x = torch.tensor([s], dtype=input.dtype, device=input.device)
expand_shape = list(input.size())
expand_shape[dim] = -1
x = x.expand(expand_shape)
return torch.cat([input, x] if append else [x, input], dim)
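Two typical uses of broadcast_cat appearing elsewhere in this commit, shown as a small sketch (coords, c, f, z_diffs and t are placeholder tensors):

# Sketch only.
dirs = broadcast_cat((coords - c) / f, 1.0)        # append homogeneous z = 1 to 2D pixel coords
dists = broadcast_cat(z_diffs, 1e10)               # pad the last sample interval with "infinity"
ones_first = broadcast_cat(t, 1.0, append=False)   # prepend instead of append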
def generate_video(frames: torch.Tensor, path: str, fps: float,
repeat: int = 1, pingpong: bool = False,
video_codec: str = 'libx264'):
"""
Generate video from a sequence of frames after converting type and
permuting channels to meet the requirement of ```torchvision.io.write_video()```
:param frames ```Tensor(B, C, H, W)```: a sequence of frames
:param path: video path
:param fps: frames per second
:param repeat: repeat times
:param pingpong: whether repeat sequence in pinpong form
:param video_codec: video codec
"""
frames = trans_func.convert_image_dtype(frames, torch.uint8)
frames = frames.detach().cpu().permute(0, 2, 3, 1)
if pingpong:
frames = torch.cat([frames, frames.flip(0)], 0)
frames = frames.expand(repeat, -1, -1, -1, 3).flatten(0, 1)
torchvision.io.write_video(path, frames, fps, video_codec)
[The diff of one large source file in this commit could not be displayed and is omitted here.]
import math
import sys
import os
import argparse
import torch
import torch.optim
import torchvision
import importlib
from tensorboardX import SummaryWriter
from torch import nn

sys.path.append(os.path.abspath(sys.path[0] + '/../'))
__package__ = "deeplightfield"

parser = argparse.ArgumentParser()
parser.add_argument('--device', type=int, default=3,
                    help='Which CUDA device to use.')
parser.add_argument('--config', type=str,
                    help='Net config files')
parser.add_argument('--dataset', type=str, required=True,
                    help='Dataset description file')
parser.add_argument('--test', type=str,
                    help='Test net file')
parser.add_argument('--test-samples', type=int,
                    help='Samples used for test')
parser.add_argument('--output-gt', action='store_true',
                    help='Output ground truth images if exist')
parser.add_argument('--output-alongside', action='store_true',
                    help='Output generated image alongside ground truth image')
parser.add_argument('--output-video', action='store_true',
                    help='Output test results as video')
opt = parser.parse_args()
@@ -28,201 +36,297 @@ opt = parser.parse_args()
torch.cuda.set_device(opt.device)
print("Set CUDA:%d as current device." % torch.cuda.current_device())
from .my import netio
from .my import util
from .my import device
from .my.simple_perf import SimplePerf
from .data.spherical_view_syn import *
from .msl_net import MslNet
from .spher_net import SpherNet
from .my import loss
class Config(object):
def __init__(self):
self.name = 'default'
self.GRAY = False
# Net parameters
self.NET_TYPE = 'msl'
self.N_ENCODE_DIM = 10
self.FC_PARAMS = {
'nf': 256,
'n_layers': 8,
'skips': [4]
}
self.SAMPLE_PARAMS = {
'depth_range': (1, 50),
'n_samples': 32,
'perturb_sample': True
}
self.LOSS = 'mse'
def load(self, path):
module_name = os.path.splitext(path)[0].replace('/', '.')
config_module = importlib.import_module(
'deeplightfield.' + module_name)
config_module.update_config(config)
self.name = module_name.split('.')[-1]
def load_by_name(self, name):
config_module = importlib.import_module(
'deeplightfield.configs.' + name)
config_module.update_config(config)
self.name = name
def print(self):
print('==== Config %s ====' % self.name)
print('Net type: ', self.NET_TYPE)
print('Encode dim: ', self.N_ENCODE_DIM)
print('Full-connected network parameters:', self.FC_PARAMS)
print('Sample parameters', self.SAMPLE_PARAMS)
print('Loss', self.LOSS)
print('==========================')
config = Config()
# Toggles
ROT_ONLY = False
EVAL_TIME_PERFORMANCE = False
# ========
#ROT_ONLY = True
#EVAL_TIME_PERFORMANCE = True

# Train
PATCH_SIZE = 1
BATCH_SIZE = 4096 // (PATCH_SIZE * PATCH_SIZE)
EPOCH_RANGE = range(0, 500)
SAVE_INTERVAL = 20

# Test
TEST_BATCH_SIZE = 1
TEST_CHUNKS = 1

# Paths
data_desc_path = opt.dataset
data_desc_name = os.path.split(data_desc_path)[1]
if opt.test:
    test_net_path = opt.test
    test_net_name = os.path.splitext(os.path.basename(test_net_path))[0]
    run_dir = os.path.dirname(test_net_path) + '/'
    run_id = os.path.basename(run_dir[:-1])
    config_name = run_id.split('_b')[0]
    output_dir = run_dir + 'output/%s/%s/' % (test_net_name, data_desc_name)
    config.load_by_name(config_name)
    train_mode = False
    if opt.test_samples:
        config.SAMPLE_PARAMS['n_samples'] = opt.test_samples
        output_dir = run_dir + 'output/%s/%s_s%d/' % \
            (test_net_name, data_desc_name, opt.test_samples)
else:
    if opt.config:
        config.load(opt.config)
    data_dir = os.path.dirname(data_desc_path) + '/'
    run_id = '%s_b%d[%d]' % (config.name, BATCH_SIZE, PATCH_SIZE)
    run_dir = data_dir + run_id + '/'
    log_dir = run_dir + 'log/'
    output_dir = None
    train_mode = True

config.print()
print("dataset: ", data_desc_path)
print("train_mode: ", train_mode)
print("run_dir: ", run_dir)
if not train_mode:
    print("output_dir", output_dir)

config.SAMPLE_PARAMS['perturb_sample'] = \
    config.SAMPLE_PARAMS['perturb_sample'] and train_mode
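The run directory name encodes the experiment as run_id = '<config name>_b<batch size>[<patch size>]', and test mode recovers the config name by splitting on '_b'. A sketch with made-up values:

# Sketch only.
name, batch_size, patch_size = 'msl_gray', 4096, 1
run_id = '%s_b%d[%d]' % (name, batch_size, patch_size)   # -> 'msl_gray_b4096[1]'
config_name = run_id.split('_b')[0]                      # -> 'msl_gray'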
NETS = {
'msl': lambda: MslNet(
fc_params=config.FC_PARAMS,
sampler_params=(config.SAMPLE_PARAMS.update(
{'spherical': True}), config.SAMPLE_PARAMS)[1],
gray=config.GRAY,
encode_to_dim=config.N_ENCODE_DIM),
'nerf': lambda: MslNet(
fc_params=config.FC_PARAMS,
sampler_params=(config.SAMPLE_PARAMS.update(
{'spherical': False}), config.SAMPLE_PARAMS)[1],
gray=config.GRAY,
encode_to_dim=config.N_ENCODE_DIM),
'spher': lambda: SpherNet(
fc_params=config.FC_PARAMS,
gray=config.GRAY,
translation=not ROT_ONLY,
encode_to_dim=config.N_ENCODE_DIM)
}
LOSSES = {
'mse': lambda: nn.MSELoss(),
'mse_grad': lambda: loss.CombinedLoss(
[nn.MSELoss(), loss.GradLoss()], [1.0, 0.5])
}
# Initialize model
model = NETS[config.NET_TYPE]().to(device.GetDevice())
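The sampler_params argument in the NETS dict uses the (d.update(...), d)[1] tuple trick: dict.update() returns None, so indexing the tuple at [1] yields the mutated dict itself. An equivalent, more explicit sketch (which copies the dict instead of mutating config.SAMPLE_PARAMS in place):

# Sketch only; equivalent to the 'msl' entry above.
sample_params = dict(config.SAMPLE_PARAMS)
sample_params['spherical'] = True
net = MslNet(fc_params=config.FC_PARAMS, sampler_params=sample_params,
             gray=config.GRAY, encode_to_dim=config.N_ENCODE_DIM)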
def train_loop(data_loader, optimizer, loss, perf, writer, epoch, iters):
sub_iters = 0
iters_in_epoch = len(data_loader)
for _, gt, rays_o, rays_d in data_loader:
gt = gt.to(device.GetDevice())
rays_o = rays_o.to(device.GetDevice())
rays_d = rays_d.to(device.GetDevice())
perf.Checkpoint("Load")
out = model(rays_o, rays_d)
perf.Checkpoint("Forward")
optimizer.zero_grad()
loss_value = loss(out, gt)
perf.Checkpoint("Compute loss")
loss_value.backward()
perf.Checkpoint("Backward")
optimizer.step()
perf.Checkpoint("Update")
print("Epoch: %d, Iter: %d(%d/%d), Loss: %f" %
(epoch, iters, sub_iters, iters_in_epoch, loss_value.item()))
# Write tensorboard logs.
writer.add_scalar("loss", loss_value, iters)
if len(gt.size()) == 4 and iters % 100 == 0:
output_vs_gt = torch.cat([out[0:4], gt[0:4]], 0).detach()
writer.add_image("Output_vs_gt", torchvision.utils.make_grid(
output_vs_gt, nrow=4).cpu().numpy(), iters)
iters += 1
sub_iters += 1
return iters
def train():
    # 1. Initialize data loader
    print("Load dataset: " + data_desc_path)
    train_dataset = FastSphericalViewSynDataset(data_desc_path,
                                                gray=config.GRAY)
    train_dataset.set_patch_size((PATCH_SIZE, PATCH_SIZE))
    train_data_loader = FastDataLoader(
        dataset=train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=False,
        pin_memory=True)

    # 2. Initialize components
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
    loss = LOSSES[config.LOSS]().to(device.GetDevice())

    if EPOCH_RANGE.start > 0:
        iters = netio.LoadNet('%smodel-epoch_%d.pth' % (run_dir, EPOCH_RANGE.start),
                              model, solver=optimizer)
    else:
        iters = 0
    epoch = None

    # 3. Train
    model.train()

    util.CreateDirIfNeed(run_dir)
    util.CreateDirIfNeed(log_dir)

    perf = SimplePerf(EVAL_TIME_PERFORMANCE, start=True)
    perf_epoch = SimplePerf(True, start=True)
    writer = SummaryWriter(log_dir)

    print("Begin training...")
    for epoch in EPOCH_RANGE:
        perf_epoch.Checkpoint("Epoch")
        iters = train_loop(train_data_loader, optimizer, loss,
                           perf, writer, epoch, iters)
        # Save checkpoint
        if ((epoch + 1) % SAVE_INTERVAL == 0):
            netio.SaveNet('%smodel-epoch_%d.pth' % (run_dir, epoch + 1), model,
                          solver=optimizer, iters=iters)
    print("Train finished")
def test():
    torch.autograd.set_grad_enabled(False)

    # 1. Load train dataset
    print("Load dataset: " + data_desc_path)
    test_dataset = SphericalViewSynDataset(data_desc_path,
                                           load_images=opt.output_gt,
                                           gray=config.GRAY)
    test_data_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=1,
        shuffle=False,
        drop_last=False,
        pin_memory=True)

    # 2. Load trained model
    netio.LoadNet(test_net_path, model)

    # 3. Test on train dataset
    print("Begin test on train dataset, batch size is %d" % TEST_BATCH_SIZE)
    util.CreateDirIfNeed(output_dir)

    perf = SimplePerf(True, start=True)
    i = 0
    n = test_dataset.view_rots.size(0)
    chns = 1 if config.GRAY else 3
    out_view_images = torch.empty(n, chns, test_dataset.view_res[0],
                                  test_dataset.view_res[1], device=device.GetDevice())
    print(out_view_images.size())
    for view_idxs, _, rays_o, rays_d in test_data_loader:
        perf.Checkpoint("%d - Load" % i)
        rays_o = rays_o.to(device.GetDevice()).view(-1, 3)
        rays_d = rays_d.to(device.GetDevice()).view(-1, 3)
        n_rays = rays_o.size(0)
        chunk_size = n_rays // TEST_CHUNKS
        out_pixels = torch.empty(n_rays, chns, device=device.GetDevice())
        for offset in range(0, n_rays, chunk_size):
            rays_o_ = rays_o[offset:offset + chunk_size]
            rays_d_ = rays_d[offset:offset + chunk_size]
            out_pixels[offset:offset + chunk_size] = \
                model(rays_o_, rays_d_)
        out_view_images[view_idxs] = out_pixels.view(
            TEST_BATCH_SIZE, test_dataset.view_res[0],
            test_dataset.view_res[1], -1).permute(0, 3, 1, 2)
        perf.Checkpoint("%d - Infer" % i)
        i += 1

    if opt.output_video:
        util.generate_video(out_view_images, output_dir +
                            'out.mp4', 24, 3, True)
    else:
        gt_paths = ['%sgt_view_%04d.png' % (output_dir, i) for i in range(n)]
        out_paths = ['%sout_view_%04d.png' % (output_dir, i) for i in range(n)]
        if test_dataset.load_images:
            if opt.output_alongside:
                util.WriteImageTensor(
                    torch.cat([test_dataset.view_images,
                               out_view_images.cpu()], 3),
                    out_paths)
            else:
                util.WriteImageTensor(out_view_images, out_paths)
                util.WriteImageTensor(test_dataset.view_images, gt_paths)
        else:
            util.WriteImageTensor(out_view_images, out_paths)


if __name__ == "__main__":
    if train_mode:
        train()
    else:
        test()
import torch
import torch.nn as nn
from .my import net_modules
@@ -7,45 +6,42 @@ from .my import util
class SpherNet(nn.Module):

    def __init__(self, fc_params,
                 gray: bool = False,
                 translation: bool = False,
                 encode_to_dim: int = 0):
        """
        Initialize a sphere net

        :param fc_params: parameters for full-connection network
        :param gray: whether grayscale mode
        :param translation: whether support translation of view
        :param encode_to_dim: encode input to number of dimensions
        """
        super().__init__()
        self.in_chns = 5 if translation else 2
        self.input_encoder = net_modules.InputEncoder.Get(
            encode_to_dim, self.in_chns)
        fc_params['in_chns'] = self.input_encoder.out_dim
        fc_params['out_chns'] = 1 if gray else 3
        self.net = net_modules.FcNet(**fc_params)

    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor) -> torch.Tensor:
        """
        rays -> colors

        :param rays_o ```Tensor(B, ..., 3)```: rays' origin
        :param rays_d ```Tensor(B, ..., 3)```: rays' direction
        :return: Tensor(B, 1|3, ...), inferred images/pixels
        """
        p = rays_o.view(-1, 3)
        v = rays_d.view(-1, 3)
        spher = util.CartesianToSpherical(v)[..., 1:3]  # (..., 2)
        input = torch.cat([p, spher], dim=-1) if self.in_chns == 5 else spher

        c: torch.Tensor = self.net(self.input_encoder(input))

        # Unflatten according to input shape
        out_shape = list(rays_d.size())
        out_shape[-1] = -1
        return c.view(out_shape).movedim(-1, 1)
\ No newline at end of file