Commit c570c3b1 authored by BobYeah

checkpoint

parent 172b5205
@@ -2,7 +2,6 @@ import torch
import torchvision.transforms.functional as trans_f
import json
from ..my import util
from ..my import imgio
class SphericalViewSynDataset(torch.utils.data.dataset.Dataset):
@@ -44,8 +43,11 @@ class SphericalViewSynDataset(torch.utils.data.dataset.Dataset):
# Load dataset description file
with open(dataset_desc_path, 'r', encoding='utf-8') as file:
data_desc = json.loads(file.read())
self.view_file_pattern: str = self.data_dir + \
data_desc['view_file_pattern']
if data_desc['view_file_pattern'] == '':
self.load_images = False
else:
self.view_file_pattern: str = self.data_dir + \
data_desc['view_file_pattern']
self.view_res = (data_desc['view_res']['y'],
data_desc['view_res']['x'])
self.cam_params = data_desc['cam_params']
@@ -54,7 +56,7 @@ class SphericalViewSynDataset(torch.utils.data.dataset.Dataset):
.view(-1, 3, 3) # (N, 3, 3)
# Load view images
if load_images:
if self.load_images:
self.view_images = util.ReadImageTensor(
[self.view_file_pattern % i for i in range(self.view_centers.size(0))])
if gray:
@@ -75,8 +77,8 @@ class SphericalViewSynDataset(torch.utils.data.dataset.Dataset):
# Flatten rays if ray_as_item = True
if ray_as_item:
self.view_pixels = self.view_images.permute(
0, 2, 3, 1).flatten(0, 2)
self.view_pixels = self.view_images.permute(0, 2, 3, 1).flatten(
0, 2) if self.view_images is not None else None
self.ray_positions = self.ray_positions.flatten(0, 1)
self.ray_directions = self.ray_directions.flatten(0, 1)
@@ -88,4 +90,4 @@ class SphericalViewSynDataset(torch.utils.data.dataset.Dataset):
if self.ray_as_item:
return idx, self.view_pixels[idx], self.ray_positions[idx], self.ray_directions[idx]
return idx, self.view_images[idx], self.ray_positions[idx], self.ray_directions[idx]
return idx, self.ray_positions[idx], self.ray_directions[idx]
return idx, False, self.ray_positions[idx], self.ray_directions[idx]
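For orientation, a hedged sketch of how items from this dataset are typically consumed; the description path and batch size are placeholders:

import torch
# ray_as_item=True flattens views into individual rays, so every batch
# element is one ray: pixels is (B, C), positions/directions are (B, 3).
dataset = SphericalViewSynDataset('data/train.json', gray=True, ray_as_item=True)
loader = torch.utils.data.DataLoader(dataset, batch_size=2048, shuffle=True)
for idx, pixels, ray_positions, ray_directions in loader:
    pass  # feed (ray_positions, ray_directions) to the model, supervise with pixels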
import sys
import os
sys.path.append(os.path.abspath(sys.path[0] + '/../'))
__package__ = "deeplightfield"
import argparse
from PIL import Image
from .my import util
def batch_scale(src, target, size):
util.CreateDirIfNeed(target)
for file_name in os.listdir(src):
postfix = os.path.splitext(file_name)[1]
if postfix == '.jpg' or postfix == '.png':
im = Image.open(os.path.join(src, file_name))
im = im.resize(size)
im.save(os.path.join(target, file_name))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('src', type=str,
help='Source directory.')
parser.add_argument('target', type=str,
help='Target directory.')
parser.add_argument('--width', type=int,
help='Width of output images (pixels)')
parser.add_argument('--height', type=int,
help='Height of output images (pixels)')
opt = parser.parse_args()
batch_scale(opt.src, opt.target, (opt.width, opt.height))
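An equivalent direct call to the function above, for reference (paths and size are placeholders):

batch_scale('./views_raw', './views_small', (640, 360))  # resize every .jpg/.png in place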
from typing import List, Tuple
from math import pi
import torch
import torch.nn as nn
from .pytorch_prototyping.pytorch_prototyping import *
from .my import net_modules
from .my import util
from .my import device
def CartesianToSpherical(cart: torch.Tensor) -> torch.Tensor:
"""
Convert coordinates from Cartesian to Spherical
:param cart: ... x 3, coordinates in Cartesian (x, y, z)
:return: ... x 3, coordinates in Spherical (r, theta, phi)
"""
rho = torch.norm(cart, p=2, dim=-1)
theta = torch.atan2(cart[..., 2], cart[..., 0])
theta = theta + (theta < 0).type_as(theta) * (2 * pi)
phi = torch.acos(cart[..., 1] / rho)
return torch.stack([rho, theta, phi], dim=-1)
def SphericalToCartesian(spher: torch.Tensor) -> torch.Tensor:
"""
Convert coordinates from Spherical to Cartesian
:param spher: ... x 3, coordinates in Spherical (r, theta, phi)
:return: ... x 3, coordinates in Cartesian (x, y, z)
"""
rho = spher[..., 0]
sin_theta_phi = torch.sin(spher[..., 1:3])
cos_theta_phi = torch.cos(spher[..., 1:3])
x = rho * cos_theta_phi[..., 0] * sin_theta_phi[..., 1]
y = rho * cos_theta_phi[..., 1]
z = rho * sin_theta_phi[..., 0] * sin_theta_phi[..., 1]
return torch.stack([x, y, z], dim=-1)
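A quick round-trip sanity check of the two conversions above (the test point is arbitrary):

import torch
cart = torch.tensor([[1.0, 2.0, 2.0]])     # a point at radius 3
spher = CartesianToSpherical(cart)         # -> (rho, theta, phi)
cart2 = SphericalToCartesian(spher)
assert torch.allclose(cart, cart2, atol=1e-5)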
def RaySphereIntersect(p: torch.Tensor, v: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
"""
Calculate the intersection of each ray with each sphere
@@ -68,115 +37,74 @@ def RayToSpherical(p: torch.Tensor, v: torch.Tensor, r: torch.Tensor) -> torch.T
:return: B x B' x 3, spherical coordinates
"""
p_on_spheres = RaySphereIntersect(p, v, r)
return CartesianToSpherical(p_on_spheres)
class FcNet(nn.Module):
def __init__(self, in_chns: int, out_chns: int, nf: int, n_layers: int):
super().__init__()
self.layers = list()
self.layers += [
nn.Linear(in_chns, nf),
#nn.LayerNorm([nf]),
nn.ReLU()
]
for _ in range(1, n_layers):
self.layers += [
nn.Linear(nf, nf),
#nn.LayerNorm([nf]),
nn.ReLU()
]
self.layers.append(nn.Linear(nf, out_chns))
self.net = nn.Sequential(*self.layers)
self.net.apply(self.init_weights)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.net(x)
def init_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.xavier_normal_(m.weight)
nn.init.constant_(m.bias, 0.0)
return util.CartesianToSpherical(p_on_spheres)
class Rendering(nn.Module):
def __init__(self, sphere_layers: List[float]):
def __init__(self):
"""
Initialize a Rendering module
:param sphere_layers: L x 1, radius of sphere layers
"""
super().__init__()
self.sphere_layers = torch.tensor(
sphere_layers, device=device.GetDevice())
def forward(self, net: FcNet, p: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
def forward(self, color_alpha: torch.Tensor) -> torch.Tensor:
"""
[summary]
Blend layers to get final color
:param net: the full-connected net
:param p: B x 3, positions of rays
:param v: B x 3, directions of rays
:return B x 1/3, view images by blended layers
:param color_alpha ```Tensor(B, L, C)```: RGB or gray with alpha channel
:return ```Tensor(B, C-1)``` blended pixels
"""
L = self.sphere_layers.size()[0]
sp = RayToSpherical(p, v, self.sphere_layers) # B x L x 3
sp[..., 0] = 1 / sp[..., 0] # Radius to diopter
color_alpha: torch.Tensor = net(
sp.flatten(0, 1)).view(p.size()[0], L, -1)
if (color_alpha.size(-1) == 2): # Grayscale
c = color_alpha[..., 0:1]
a = color_alpha[..., 1:2]
else: # RGB
c = color_alpha[..., 0:3]
a = color_alpha[..., 3:4]
c = color_alpha[..., :-1]
a = color_alpha[..., -1:]
blended = c[:, 0, :] * a[:, 0, :]
for l in range(1, L):
for l in range(1, color_alpha.size(1)):
blended = blended * (1 - a[:, l, :]) + c[:, l, :] * a[:, l, :]
return blended
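A toy check of the blending rule above; layer 1 is composited over layer 0 (numbers are arbitrary):

import torch
# B=1 pixel, L=2 layers, grayscale + alpha (C=2): layer 0 is opaque gray 0.2,
# layer 1 is half-transparent white.
color_alpha = torch.tensor([[[0.2, 1.0],
                             [1.0, 0.5]]])
print(Rendering()(color_alpha))  # 0.2 * (1 - 0.5) + 1.0 * 0.5 -> tensor([[0.6000]])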
class MslNet(nn.Module):
def __init__(self, cam_params, sphere_layers: List[float], out_res: Tuple[int, int], gray=False):
def __init__(self, cam_params, fc_params, sphere_layers: List[float],
out_res: Tuple[int, int], gray=False, encode_to_dim: int = 0):
"""
Initialize a multi-sphere-layer net
:param cam_params: intrinsic parameters of camera
:param sphere_layers: L x 1, radius of sphere layers
:param fc_params: parameters of full-connection network
:param sphere_layers: list(L), radius of sphere layers
:param out_res: resolution of output view image
:param gray: whether to operate in grayscale mode
:param encode_to_dim: number of dimensions to encode the input to
"""
super().__init__()
self.cam_params = cam_params
self.sphere_layers = torch.tensor(sphere_layers,
dtype=torch.float,
device=device.GetDevice())
self.in_chns = 3
self.out_res = out_res
self.v_local = util.GetLocalViewRays(self.cam_params, out_res, flatten=True) \
.to(device.GetDevice()) # N x 3
#self.net = FCBlock(hidden_ch=64,
# num_hidden_layers=4,
# in_features=3,
# out_features=2 if gray else 4,
# outermost_linear=True)
self.net = FcNet(in_chns=3, out_chns=2 if gray else 4, nf=256, n_layers=8)
self.rendering = Rendering(sphere_layers)
def forward(self, view_centers: torch.Tensor, view_rots: torch.Tensor) -> torch.Tensor:
self.input_encoder = net_modules.InputEncoder.Get(
encode_to_dim, self.in_chns)
fc_params['in_chns'] = self.input_encoder.out_dim
fc_params['out_chns'] = 2 if gray else 4
self.net = net_modules.FcNet(**fc_params)
self.rendering = Rendering()
def forward(self, ray_positions: torch.Tensor, ray_directions: torch.Tensor) -> torch.Tensor:
"""
T_view -> image
rays -> colors
:param view_centers: B x 3, centers of views
:param view_rots: B x 3 x 3, rotation matrices of views
:return: B x 1/3 x H_out x W_out, inferred images of views
:param ray_positions ```Tensor(B, M, 3)|Tensor(B, 3)```: ray positions
:param ray_directions ```Tensor(B, M, 3)|Tensor(B, 3)```: ray directions
:return: Tensor(B, 1|3, H, W)|Tensor(B, 1|3), inferred images/pixels
"""
# Transpose matrix so we can perform vec x mat
view_rots_t = view_rots.permute(0, 2, 1)
# p and v are B x N x 3 tensor
p = view_centers.unsqueeze(1).expand(-1, self.v_local.size(0), -1)
v = torch.matmul(self.v_local, view_rots_t)
c: torch.Tensor = self.rendering(
self.net, p.flatten(0, 1), v.flatten(0, 1)) # (BN) x 3
p = ray_positions.view(-1, 3)
v = ray_directions.view(-1, 3)
spher = RayToSpherical(p, v, self.sphere_layers).flatten(0, 1)
color_alpha = self.net(self.input_encoder(spher)).view(
p.size(0), self.sphere_layers.size(0), -1)
c: torch.Tensor = self.rendering(color_alpha)
# unflatten
return c.view(view_centers.size(0), self.out_res[0],
self.out_res[1], -1).permute(0, 3, 1, 2)
return c.view(ray_directions.size(0), self.out_res[0],
self.out_res[1], -1).permute(0, 3, 1, 2) if ray_directions.dim() == 3 else c
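The closing view/permute is just the unflatten from per-ray colors back to view images; a standalone shape check (dimensions are arbitrary):

import torch
B, H, W = 4, 2, 3
c = torch.rand(B * H * W, 1)                    # flattened per-ray colors
imgs = c.view(B, H, W, -1).permute(0, 3, 1, 2)  # -> (B, C, H, W)
print(imgs.shape)                               # torch.Size([4, 1, 2, 3])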
import sys
sys.path.append('/e/dengnc')
import os
sys.path.append(os.path.abspath(sys.path[0] + '/../'))
__package__ = "deeplightfield"
import argparse
import torch
import torch.optim
import torchvision
from typing import List, Tuple
from tensorboardX import SummaryWriter
from torch import nn
from .my import netio
from .my import util
from .my import device
from .my.simple_perf import SimplePerf
from .loss.loss import PerceptionReconstructionLoss
from .data.spherical_view_syn import SphericalViewSynDataset
from .msl_net import MslNet
from .spher_net import SpherNet
@@ -36,9 +35,9 @@ TRAIN_MODE = True
EVAL_TIME_PERFORMANCE = False
RAY_AS_ITEM = True
# ========
#GRAY = True
ROT_ONLY = True
TRAIN_MODE = False
GRAY = True
#ROT_ONLY = True
#TRAIN_MODE = False
#EVAL_TIME_PERFORMANCE = True
#RAY_AS_ITEM = False
@@ -48,39 +47,39 @@ N_DEPTH_LAYERS = 10
N_ENCODE_DIM = 10
FC_PARAMS = {
'nf': 128,
'n_layers': 6,
'n_layers': 8,
'skips': [4]
}
# Train
TRAIN_DATA_DESC_FILE = 'train.json'
BATCH_SIZE = 2048 if RAY_AS_ITEM else 4
EPOCH_RANGE = range(0, 500)
SAVE_INTERVAL = 20
# Test
TEST_NET_NAME = 'model-epoch_500'
TEST_DATA_DESC_FILE = 'test_fovea.json'
TEST_BATCH_SIZE = 5
# Paths
DATA_DIR = sys.path[0] + '/data/sp_view_syn_2020.12.26_rotonly/'
DATA_DIR = sys.path[0] + '/data/sp_view_syn_2020.12.28/'
RUN_ID = '%s_ray_b%d_encode%d_fc%dx%d%s' % ('gray' if GRAY else 'rgb',
BATCH_SIZE,
N_ENCODE_DIM,
FC_PARAMS['nf'],
FC_PARAMS['n_layers'],
'_skip_%d' % FC_PARAMS['skips'][0] if len(FC_PARAMS['skips']) > 0 else '')
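With the settings above (GRAY=True, RAY_AS_ITEM=True so BATCH_SIZE=2048, N_ENCODE_DIM=10, nf=128, n_layers=8, skips=[4]) this evaluates to:

# RUN_ID == 'gray_ray_b2048_encode10_fc128x8_skip_4'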
TRAIN_DATA_DESC_FILE = DATA_DIR + 'train.json'
RUN_DIR = DATA_DIR + RUN_ID + '/'
OUTPUT_DIR = RUN_DIR + 'output/'
LOG_DIR = RUN_DIR + 'log/'
# Test
TEST_NET_NAME = 'model-epoch_100'
TEST_BATCH_SIZE = 5
def train():
# 1. Initialize data loader
print("Load dataset: " + TRAIN_DATA_DESC_FILE)
train_dataset = SphericalViewSynDataset(
TRAIN_DATA_DESC_FILE, gray=GRAY, ray_as_item=RAY_AS_ITEM)
print("Load dataset: " + DATA_DIR + TRAIN_DATA_DESC_FILE)
train_dataset = SphericalViewSynDataset(DATA_DIR + TRAIN_DATA_DESC_FILE,
gray=GRAY, ray_as_item=RAY_AS_ITEM)
train_data_loader = torch.utils.data.DataLoader(
dataset=train_dataset,
batch_size=BATCH_SIZE,
@@ -98,10 +97,12 @@ def train():
encode_to_dim=N_ENCODE_DIM).to(device.GetDevice())
else:
model = MslNet(cam_params=train_dataset.cam_params,
fc_params=FC_PARAMS,
sphere_layers=util.GetDepthLayers(
DEPTH_RANGE, N_DEPTH_LAYERS),
out_res=train_dataset.view_res,
gray=GRAY).to(device.GetDevice())
gray=GRAY,
encode_to_dim=N_ENCODE_DIM).to(device.GetDevice())
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
loss = nn.MSELoss()
@@ -172,11 +173,11 @@ def train():
def test(net_file: str):
# 1. Load train dataset
print("Load dataset: " + TRAIN_DATA_DESC_FILE)
train_dataset = SphericalViewSynDataset(TRAIN_DATA_DESC_FILE,
load_images=True, gray=GRAY)
train_data_loader = torch.utils.data.DataLoader(
dataset=train_dataset,
print("Load dataset: " + DATA_DIR + TEST_DATA_DESC_FILE)
test_dataset = SphericalViewSynDataset(DATA_DIR + TEST_DATA_DESC_FILE,
load_images=True, gray=GRAY)
test_data_loader = torch.utils.data.DataLoader(
dataset=test_dataset,
batch_size=TEST_BATCH_SIZE,
pin_memory=True,
shuffle=False,
@@ -184,37 +185,38 @@ def test(net_file: str):
# 2. Load trained model
if ROT_ONLY:
model = SpherNet(cam_params=train_dataset.cam_params,
model = SpherNet(cam_params=test_dataset.cam_params,
fc_params=FC_PARAMS,
out_res=train_dataset.view_res,
out_res=test_dataset.view_res,
gray=GRAY,
encode_to_dim=N_ENCODE_DIM).to(device.GetDevice())
else:
model = MslNet(cam_params=train_dataset.cam_params,
sphere_layers=_GetSphereLayers(
model = MslNet(cam_params=test_dataset.cam_params,
sphere_layers=util.GetDepthLayers(
DEPTH_RANGE, N_DEPTH_LAYERS),
out_res=train_dataset.view_res,
out_res=test_dataset.view_res,
gray=GRAY).to(device.GetDevice())
netio.LoadNet(net_file, model)
# 3. Test on train dataset
print("Begin test on train dataset, batch size is %d" % TEST_BATCH_SIZE)
util.CreateDirIfNeed(OUTPUT_DIR)
util.CreateDirIfNeed(OUTPUT_DIR + TEST_NET_NAME)
output_dir = '%s%s/%s/' % (OUTPUT_DIR, TEST_NET_NAME, TEST_DATA_DESC_FILE)
util.CreateDirIfNeed(output_dir)
perf = SimplePerf(True, start=True)
i = 0
for view_idxs, view_images, ray_positions, ray_directions in train_data_loader:
for view_idxs, view_images, ray_positions, ray_directions in test_data_loader:
ray_positions = ray_positions.to(device.GetDevice())
ray_directions = ray_directions.to(device.GetDevice())
perf.Checkpoint("%d - Load" % i)
out_view_images = model(ray_positions, ray_directions)
perf.Checkpoint("%d - Infer" % i)
util.WriteImageTensor(
view_images,
['%s%s/gt_view_%04d.png' % (OUTPUT_DIR, TEST_NET_NAME, i) for i in view_idxs])
if test_dataset.load_images:
util.WriteImageTensor(
view_images,
['%sgt_view_%04d.png' % (output_dir, i) for i in view_idxs])
util.WriteImageTensor(
out_view_images,
['%s%s/out_view_%04d.png' % (OUTPUT_DIR, TEST_NET_NAME, i) for i in view_idxs])
['%sout_view_%04d.png' % (output_dir, i) for i in view_idxs])
perf.Checkpoint("%d - Write" % i)
i += 1
......
from typing import List, Tuple
from math import pi
from typing import Tuple
import torch
import torch.nn as nn
from .pytorch_prototyping.pytorch_prototyping import *
from .my import net_modules
from .my import util
from .my import device
def RaySphereIntersect(p: torch.Tensor, v: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
"""
Calculate the intersection of each ray with each sphere
:param p: B x 3, positions of rays
:param v: B x 3, directions of rays
:param r: B'(1D), radius of spheres
:return: B x B' x 3, points of intersection
"""
# p, v: Expand to B x 1 x 3
p = p.unsqueeze(1)
v = v.unsqueeze(1)
# pp, vv, pv: B x 1
pp = (p * p).sum(dim=2)
vv = (v * v).sum(dim=2)
pv = (p * v).sum(dim=2)
# k: Expand to B x B' x 1
k = (((pv * pv - vv * (pp - r * r)).sqrt() - pv) / vv).unsqueeze(2)
return p + k * v
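The closed form for k solves the ray-sphere condition |p + k v|^2 = r^2, i.e. vv * k^2 + 2 * pv * k + (pp - r^2) = 0, taking the larger root. A brief numeric check (ray and sphere are arbitrary):

import torch
p = torch.tensor([[0.0, 0.0, -2.0]])  # ray origin
v = torch.tensor([[0.0, 0.0, 1.0]])   # ray direction
r = torch.tensor([1.0])               # unit sphere
print(RaySphereIntersect(p, v, r))    # k = 3 -> tensor([[[0., 0., 1.]]])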
def RayToSpherical(p: torch.Tensor, v: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
"""
Convert each ray to spherical coordinates at its intersection with each sphere
:param p: B x 3, positions of rays
:param v: B x 3, directions of rays
:param r: B'(1D), radius of spheres
:return: B x B' x 3, spherical coordinates
"""
p_on_spheres = RaySphereIntersect(p, v, r)
return util.CartesianToSpherical(p_on_spheres)
class SpherNet(nn.Module):
......