From c570c3b15b9766385408ea93ef06b26b0e909757 Mon Sep 17 00:00:00 2001
From: BobYeah <635596704@qq.com>
Date: Mon, 28 Dec 2020 21:04:08 +0800
Subject: [PATCH] checkpoint

---
 data/spherical_view_syn.py |  16 ++--
 image_scale.py             |  32 ++++++++
 msl_net.py                 | 152 ++++++++++----------------------
 run_spherical_view_syn.py  |  72 +++++++++---------
 spher_net.py               |  39 +---------
 5 files changed, 119 insertions(+), 192 deletions(-)
 create mode 100644 image_scale.py

diff --git a/data/spherical_view_syn.py b/data/spherical_view_syn.py
index b37602a..e8c88af 100644
--- a/data/spherical_view_syn.py
+++ b/data/spherical_view_syn.py
@@ -2,7 +2,6 @@ import torch
 import torchvision.transforms.functional as trans_f
 import json
 from ..my import util
-from ..my import imgio
 
 
 class SphericalViewSynDataset(torch.utils.data.dataset.Dataset):
@@ -44,8 +43,11 @@ class SphericalViewSynDataset(torch.utils.data.dataset.Dataset):
         # Load dataset description file
         with open(dataset_desc_path, 'r', encoding='utf-8') as file:
             data_desc = json.loads(file.read())
-        self.view_file_pattern: str = self.data_dir + \
-            data_desc['view_file_pattern']
+        if data_desc['view_file_pattern'] == '':
+            self.load_images = False
+        else:
+            self.view_file_pattern: str = self.data_dir + \
+                data_desc['view_file_pattern']
         self.view_res = (data_desc['view_res']['y'],
                          data_desc['view_res']['x'])
         self.cam_params = data_desc['cam_params']
@@ -54,7 +56,7 @@ class SphericalViewSynDataset(torch.utils.data.dataset.Dataset):
             .view(-1, 3, 3)  # (N, 3, 3)
 
         # Load view images
-        if load_images:
+        if self.load_images:
             self.view_images = util.ReadImageTensor(
                 [self.view_file_pattern % i for i in range(self.view_centers.size(0))])
             if gray:
@@ -75,8 +77,8 @@ class SphericalViewSynDataset(torch.utils.data.dataset.Dataset):
 
         # Flatten rays if ray_as_item = True
         if ray_as_item:
-            self.view_pixels = self.view_images.permute(
-                0, 2, 3, 1).flatten(0, 2)
+            self.view_pixels = self.view_images.permute(0, 2, 3, 1).flatten(
+                0, 2) if self.view_images != None else None
             self.ray_positions = self.ray_positions.flatten(0, 1)
             self.ray_directions = self.ray_directions.flatten(0, 1)
 
@@ -88,4 +90,4 @@ class SphericalViewSynDataset(torch.utils.data.dataset.Dataset):
             if self.ray_as_item:
                 return idx, self.view_pixels[idx], self.ray_positions[idx], self.ray_directions[idx]
             return idx, self.view_images[idx], self.ray_positions[idx], self.ray_directions[idx]
-        return idx, self.ray_positions[idx], self.ray_directions[idx]
+        return idx, False, self.ray_positions[idx], self.ray_directions[idx]
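
[Note] With the hunks above, a dataset whose view_file_pattern is empty skips
image loading entirely, and __getitem__ then yields False in place of pixel
data, so consumers must branch on dataset.load_images. A minimal sketch, not
part of the patch; the description-file path is hypothetical:

    import torch
    from deeplightfield.data.spherical_view_syn import SphericalViewSynDataset

    dataset = SphericalViewSynDataset('data/example/train.json', gray=True)
    loader = torch.utils.data.DataLoader(dataset, batch_size=4)
    for idx, view_images, ray_positions, ray_directions in loader:
        # view_images is the collated False placeholder when images are absent
        gt = view_images if dataset.load_images else None
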
diff --git a/image_scale.py b/image_scale.py
new file mode 100644
index 0000000..22630e6
--- /dev/null
+++ b/image_scale.py
@@ -0,0 +1,32 @@
+import sys
+import os
+sys.path.append(os.path.abspath(sys.path[0] + '/../'))
+__package__ = "deeplightfield"
+
+import argparse
+from PIL import Image
+from .my import util
+
+
+def batch_scale(src, target, size):
+    util.CreateDirIfNeed(target)
+    for file_name in os.listdir(src):
+        postfix = os.path.splitext(file_name)[1]
+        if postfix == '.jpg' or postfix == '.png':
+            im = Image.open(os.path.join(src, file_name))
+            im = im.resize(size)
+            im.save(os.path.join(target, file_name))
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('src', type=str,
+                        help='Source directory.')
+    parser.add_argument('target', type=str,
+                        help='Target directory.')
+    parser.add_argument('--width', type=int,
+                        help='Width of output images (pixel)')
+    parser.add_argument('--height', type=int,
+                        help='Height of output images (pixel)')
+    opt = parser.parse_args()
+    batch_scale(opt.src, opt.target, (opt.width, opt.height))
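
[Note] Neither --width nor --height has a default or required=True, so
omitting either passes None into Image.resize and raises a TypeError; both
must be supplied, e.g. python image_scale.py views/ views_small/ --width 640
--height 360 (directory names here are examples). The helper can also be
called directly; a hypothetical programmatic use:

    from deeplightfield.image_scale import batch_scale

    # Rescale every .jpg/.png in 'views' to 640x360, writing to 'views_small'
    batch_scale('views', 'views_small', (640, 360))
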
diff --git a/msl_net.py b/msl_net.py
index c576743..247472d 100644
--- a/msl_net.py
+++ b/msl_net.py
@@ -1,42 +1,11 @@
 from typing import List, Tuple
-from math import pi
 import torch
 import torch.nn as nn
-from .pytorch_prototyping.pytorch_prototyping import *
+from .my import net_modules
 from .my import util
 from .my import device
 
 
-def CartesianToSpherical(cart: torch.Tensor) -> torch.Tensor:
-    """
-    Convert coordinates from Cartesian to Spherical
-
-    :param cart: ... x 3, coordinates in Cartesian
-    :return: ... x 3, coordinates in Spherical (r, theta, phi)
-    """
-    rho = torch.norm(cart, p=2, dim=-1)
-    theta = torch.atan2(cart[..., 2], cart[..., 0])
-    theta = theta + (theta < 0).type_as(theta) * (2 * pi)
-    phi = torch.acos(cart[..., 1] / rho)
-    return torch.stack([rho, theta, phi], dim=-1)
-
-
-def SphericalToCartesian(spher: torch.Tensor) -> torch.Tensor:
-    """
-    Convert coordinates from Spherical to Cartesian
-
-    :param spher: ... x 3, coordinates in Spherical
-    :return: ... x 3, coordinates in Cartesian (r, theta, phi)
-    """
-    rho = spher[..., 0]
-    sin_theta_phi = torch.sin(spher[..., 1:3])
-    cos_theta_phi = torch.cos(spher[..., 1:3])
-    x = rho * cos_theta_phi[..., 0] * sin_theta_phi[..., 1]
-    y = rho * cos_theta_phi[..., 1]
-    z = rho * sin_theta_phi[..., 0] * sin_theta_phi[..., 1]
-    return torch.stack([x, y, z], dim=-1)
-
-
 def RaySphereIntersect(p: torch.Tensor, v: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
     """
     Calculate intersections of each rays and each spheres
@@ -68,115 +37,74 @@ def RayToSpherical(p: torch.Tensor, v: torch.Tensor, r: torch.Tensor) -> torch.T
     :return: B x B' x 3, spherical coordinates
     """
     p_on_spheres = RaySphereIntersect(p, v, r)
-    return CartesianToSpherical(p_on_spheres)
-
-
-class FcNet(nn.Module):
-
-    def __init__(self, in_chns: int, out_chns: int, nf: int, n_layers: int):
-        super().__init__()
-        self.layers = list()
-        self.layers += [
-            nn.Linear(in_chns, nf),
-            #nn.LayerNorm([nf]),
-            nn.ReLU()
-        ]
-        for _ in range(1, n_layers):
-            self.layers += [
-                nn.Linear(nf, nf),
-                #nn.LayerNorm([nf]),
-                nn.ReLU()
-            ]
-        self.layers.append(nn.Linear(nf, out_chns))
-        self.net = nn.Sequential(*self.layers)
-        self.net.apply(self.init_weights)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.net(x)
-
-    def init_weights(self, m):
-        if isinstance(m, nn.Linear):
-            nn.init.xavier_normal_(m.weight)
-            nn.init.constant_(m.bias, 0.0)
+    return util.CartesianToSpherical(p_on_spheres)
 
 
 class Rendering(nn.Module):
 
-    def __init__(self, sphere_layers: List[float]):
+    def __init__(self):
         """
         Initialize a Rendering module
-
-        :param sphere_layers: L x 1, radius of sphere layers
         """
         super().__init__()
-        self.sphere_layers = torch.tensor(
-            sphere_layers, device=device.GetDevice())
 
-    def forward(self, net: FcNet, p: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+    def forward(self, color_alpha: torch.Tensor) -> torch.Tensor:
         """
-        [summary]
+        Blend layers to get final color
 
-        :param net: the full-connected net
-        :param p: B x 3, positions of rays
-        :param v: B x 3, directions of rays
-        :return B x 1/3, view images by blended layers
+        :param color_alpha ```Tensor(B, L, C)```: RGB or gray with alpha channel
+        :return ```Tensor(B, C-1)``` blended pixels
         """
-        L = self.sphere_layers.size()[0]
-        sp = RayToSpherical(p, v, self.sphere_layers)  # B x L x 3
-        sp[..., 0] = 1 / sp[..., 0]  # Radius to diopter
-        color_alpha: torch.Tensor = net(
-            sp.flatten(0, 1)).view(p.size()[0], L, -1)
-        if (color_alpha.size(-1) == 2):  # Grayscale
-            c = color_alpha[..., 0:1]
-            a = color_alpha[..., 1:2]
-        else:  # RGB
-            c = color_alpha[..., 0:3]
-            a = color_alpha[..., 3:4]
+        c = color_alpha[..., :-1]
+        a = color_alpha[..., -1:]
         blended = c[:, 0, :] * a[:, 0, :]
-        for l in range(1, L):
+        for l in range(1, color_alpha.size(1)):
             blended = blended * (1 - a[:, l, :]) + c[:, l, :] * a[:, l, :]
         return blended
 
 
 class MslNet(nn.Module):
 
-    def __init__(self, cam_params, sphere_layers: List[float], out_res: Tuple[int, int], gray=False):
+    def __init__(self, cam_params, fc_params, sphere_layers: List[float],
+                 out_res: Tuple[int, int], gray=False, encode_to_dim: int = 0):
         """
         Initialize a multi-sphere-layer net
 
         :param cam_params: intrinsic parameters of camera
-        :param sphere_layers: L x 1, radius of sphere layers
+        :param fc_params: parameters of full-connection network
+        :param sphere_layers: list(L), radius of sphere layers
         :param out_res: resolution of output view image
+        :param gray: is grayscale mode
+        :param encode_to_dim: encode input to number of dimensions
         """
         super().__init__()
         self.cam_params = cam_params
+        self.sphere_layers = torch.tensor(sphere_layers,
+                                          dtype=torch.float,
+                                          device=device.GetDevice())
+        self.in_chns = 3
         self.out_res = out_res
-        self.v_local = util.GetLocalViewRays(self.cam_params, out_res, flatten=True) \
-            .to(device.GetDevice())  # N x 3
-        #self.net = FCBlock(hidden_ch=64,
-        #                   num_hidden_layers=4,
-        #                   in_features=3,
-        #                   out_features=2 if gray else 4,
-        #                   outermost_linear=True)
-        self.net = FcNet(in_chns=3, out_chns=2 if gray else 4, nf=256, n_layers=8)
-        self.rendering = Rendering(sphere_layers)
+        self.input_encoder = net_modules.InputEncoder.Get(
+            encode_to_dim, self.in_chns)
+        fc_params['in_chns'] = self.input_encoder.out_dim
+        fc_params['out_chns'] = 2 if gray else 4
+        self.net = net_modules.FcNet(**fc_params)
+        self.rendering = Rendering()
 
-    def forward(self, view_centers: torch.Tensor, view_rots: torch.Tensor) -> torch.Tensor:
+    def forward(self, ray_positions: torch.Tensor, ray_directions: torch.Tensor) -> torch.Tensor:
         """
-        T_view -> image
+        rays -> colors
 
-        :param view_centers: B x 3, centers of views
-        :param view_rots: B x 3 x 3, rotation matrices of views
-        :return: B x 1/3 x H_out x W_out, inferred images of views
+        :param ray_positions ```Tensor(B, M, 3)|Tensor(B, 3)```: ray positions
+        :param ray_directions ```Tensor(B, M, 3)|Tensor(B, 3)```: ray directions
+        :return: Tensor(B, 1|3, H, W)|Tensor(B, 1|3), inferred images/pixels
         """
-        # Transpose matrix so we can perform vec x mat
-        view_rots_t = view_rots.permute(0, 2, 1)
-
-        # p and v are B x N x 3 tensor
-        p = view_centers.unsqueeze(1).expand(-1, self.v_local.size(0), -1)
-        v = torch.matmul(self.v_local, view_rots_t)
-        c: torch.Tensor = self.rendering(
-            self.net, p.flatten(0, 1), v.flatten(0, 1))  # (BN) x 3
+        p = ray_positions.view(-1, 3)
+        v = ray_directions.view(-1, 3)
+        spher = RayToSpherical(p, v, self.sphere_layers).flatten(0, 1)
+        color_alpha = self.net(self.input_encoder(spher)).view(
+            p.size(0), self.sphere_layers.size(0), -1)
+        c: torch.Tensor = self.rendering(color_alpha)
         # unflatten
-        return c.view(view_centers.size(0), self.out_res[0],
-                      self.out_res[1], -1).permute(0, 3, 1, 2)
+        return c.view(ray_directions.size(0), self.out_res[0],
+                      self.out_res[1], -1).permute(0, 3, 1, 2) if len(ray_directions.size()) == 3 else c
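
[Note] Rendering.forward is now plain iterated alpha ("over") compositing:
each successive sphere layer is composited on top of the running result, and
splitting color_alpha on its last channel handles both grayscale (C = 2) and
RGB (C = 4) without branching. A self-contained sketch of the same recurrence:

    import torch

    def blend_layers(color_alpha: torch.Tensor) -> torch.Tensor:
        # color_alpha: (B, L, C); the last channel is alpha
        c = color_alpha[..., :-1]
        a = color_alpha[..., -1:]
        blended = c[:, 0, :] * a[:, 0, :]
        for l in range(1, color_alpha.size(1)):
            blended = blended * (1 - a[:, l, :]) + c[:, l, :] * a[:, l, :]
        return blended

    x = torch.rand(2, 4, 4)       # B=2 rays, L=4 layers, RGB + alpha
    print(blend_layers(x).shape)  # torch.Size([2, 3])
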
diff --git a/run_spherical_view_syn.py b/run_spherical_view_syn.py
index c724ff9..8474bec 100644
--- a/run_spherical_view_syn.py
+++ b/run_spherical_view_syn.py
@@ -1,19 +1,18 @@
 import sys
-sys.path.append('/e/dengnc')
+import os
+sys.path.append(os.path.abspath(sys.path[0] + '/../'))
 __package__ = "deeplightfield"
 
 import argparse
 import torch
 import torch.optim
 import torchvision
-from typing import List, Tuple
 from tensorboardX import SummaryWriter
 from torch import nn
 from .my import netio
 from .my import util
 from .my import device
 from .my.simple_perf import SimplePerf
-from .loss.loss import PerceptionReconstructionLoss
 from .data.spherical_view_syn import SphericalViewSynDataset
 from .msl_net import MslNet
 from .spher_net import SpherNet
@@ -36,9 +35,9 @@ TRAIN_MODE = True
 EVAL_TIME_PERFORMANCE = False
 RAY_AS_ITEM = True
 # ========
-#GRAY = True
-ROT_ONLY = True
-TRAIN_MODE = False
+GRAY = True
+#ROT_ONLY = True
+#TRAIN_MODE = False
 #EVAL_TIME_PERFORMANCE = True
 #RAY_AS_ITEM = False
 
@@ -48,39 +47,39 @@ N_DEPTH_LAYERS = 10
 N_ENCODE_DIM = 10
 FC_PARAMS = {
     'nf': 128,
-    'n_layers': 6,
+    'n_layers': 8,
     'skips': [4]
 }
 
 # Train
+TRAIN_DATA_DESC_FILE = 'train.json'
 BATCH_SIZE = 2048 if RAY_AS_ITEM else 4
 EPOCH_RANGE = range(0, 500)
 SAVE_INTERVAL = 20
 
+# Test
+TEST_NET_NAME = 'model-epoch_500'
+TEST_DATA_DESC_FILE = 'test_fovea.json'
+TEST_BATCH_SIZE = 5
+
 # Paths
-DATA_DIR = sys.path[0] + '/data/sp_view_syn_2020.12.26_rotonly/'
+DATA_DIR = sys.path[0] + '/data/sp_view_syn_2020.12.28/'
 RUN_ID = '%s_ray_b%d_encode%d_fc%dx%d%s' % ('gray' if GRAY else 'rgb',
                                             BATCH_SIZE, N_ENCODE_DIM,
                                             FC_PARAMS['nf'], FC_PARAMS['n_layers'],
                                             '_skip_%d' % FC_PARAMS['skips'][0]
                                             if len(FC_PARAMS['skips']) > 0 else '')
-TRAIN_DATA_DESC_FILE = DATA_DIR + 'train.json'
 RUN_DIR = DATA_DIR + RUN_ID + '/'
 OUTPUT_DIR = RUN_DIR + 'output/'
 LOG_DIR = RUN_DIR + 'log/'
 
-# Test
-TEST_NET_NAME = 'model-epoch_100'
-TEST_BATCH_SIZE = 5
-
 
 def train():
     # 1. Initialize data loader
-    print("Load dataset: " + TRAIN_DATA_DESC_FILE)
-    train_dataset = SphericalViewSynDataset(
-        TRAIN_DATA_DESC_FILE, gray=GRAY, ray_as_item=RAY_AS_ITEM)
+    print("Load dataset: " + DATA_DIR + TRAIN_DATA_DESC_FILE)
+    train_dataset = SphericalViewSynDataset(DATA_DIR + TRAIN_DATA_DESC_FILE,
+                                            gray=GRAY, ray_as_item=RAY_AS_ITEM)
     train_data_loader = torch.utils.data.DataLoader(
         dataset=train_dataset,
         batch_size=BATCH_SIZE,
@@ -98,10 +97,12 @@ def train():
                         encode_to_dim=N_ENCODE_DIM).to(device.GetDevice())
     else:
         model = MslNet(cam_params=train_dataset.cam_params,
+                       fc_params=FC_PARAMS,
                        sphere_layers=util.GetDepthLayers(
                            DEPTH_RANGE, N_DEPTH_LAYERS),
                        out_res=train_dataset.view_res,
-                       gray=GRAY).to(device.GetDevice())
+                       gray=GRAY,
+                       encode_to_dim=N_ENCODE_DIM).to(device.GetDevice())
 
     optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
     loss = nn.MSELoss()
@@ -172,11 +173,11 @@ def train():
 
 def test(net_file: str):
     # 1. Load train dataset
-    print("Load dataset: " + TRAIN_DATA_DESC_FILE)
-    train_dataset = SphericalViewSynDataset(TRAIN_DATA_DESC_FILE,
-                                            load_images=True, gray=GRAY)
-    train_data_loader = torch.utils.data.DataLoader(
-        dataset=train_dataset,
+    print("Load dataset: " + DATA_DIR + TEST_DATA_DESC_FILE)
+    test_dataset = SphericalViewSynDataset(DATA_DIR + TEST_DATA_DESC_FILE,
+                                           load_images=True, gray=GRAY)
+    test_data_loader = torch.utils.data.DataLoader(
+        dataset=test_dataset,
         batch_size=TEST_BATCH_SIZE,
         pin_memory=True,
         shuffle=False,
@@ -184,37 +185,38 @@ def test(net_file: str):
 
     # 2. Load trained model
     if ROT_ONLY:
-        model = SpherNet(cam_params=train_dataset.cam_params,
+        model = SpherNet(cam_params=test_dataset.cam_params,
                          fc_params=FC_PARAMS,
-                         out_res=train_dataset.view_res,
+                         out_res=test_dataset.view_res,
                          gray=GRAY,
                          encode_to_dim=N_ENCODE_DIM).to(device.GetDevice())
     else:
-        model = MslNet(cam_params=train_dataset.cam_params,
-                       sphere_layers=_GetSphereLayers(
+        model = MslNet(cam_params=test_dataset.cam_params,
+                       sphere_layers=util.GetDepthLayers(
                            DEPTH_RANGE, N_DEPTH_LAYERS),
-                       out_res=train_dataset.view_res,
+                       out_res=test_dataset.view_res,
                        gray=GRAY).to(device.GetDevice())
     netio.LoadNet(net_file, model)
 
     # 3. Test on train dataset
     print("Begin test on train dataset, batch size is %d" % TEST_BATCH_SIZE)
-    util.CreateDirIfNeed(OUTPUT_DIR)
-    util.CreateDirIfNeed(OUTPUT_DIR + TEST_NET_NAME)
+    output_dir = '%s%s/%s/' % (OUTPUT_DIR, TEST_NET_NAME, TEST_DATA_DESC_FILE)
+    util.CreateDirIfNeed(output_dir)
 
     perf = SimplePerf(True, start=True)
     i = 0
-    for view_idxs, view_images, ray_positions, ray_directions in train_data_loader:
+    for view_idxs, view_images, ray_positions, ray_directions in test_data_loader:
         ray_positions = ray_positions.to(device.GetDevice())
         ray_directions = ray_directions.to(device.GetDevice())
         perf.Checkpoint("%d - Load" % i)
         out_view_images = model(ray_positions, ray_directions)
         perf.Checkpoint("%d - Infer" % i)
-        util.WriteImageTensor(
-            view_images,
-            ['%s%s/gt_view_%04d.png' % (OUTPUT_DIR, TEST_NET_NAME, i) for i in view_idxs])
+        if test_dataset.load_images:
+            util.WriteImageTensor(
+                view_images,
+                ['%sgt_view_%04d.png' % (output_dir, i) for i in view_idxs])
         util.WriteImageTensor(
             out_view_images,
-            ['%s%s/out_view_%04d.png' % (OUTPUT_DIR, TEST_NET_NAME, i) for i in view_idxs])
+            ['%sout_view_%04d.png' % (output_dir, i) for i in view_idxs])
         perf.Checkpoint("%d - Write" % i)
         i += 1
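
[Note] For the configuration in this patch (GRAY=True, RAY_AS_ITEM=True so
BATCH_SIZE=2048, N_ENCODE_DIM=10, nf=128, n_layers=8, skips=[4]) the run
directory name evaluates to gray_ray_b2048_encode10_fc128x8_skip_4. A quick
check of the naming expression:

    FC_PARAMS = {'nf': 128, 'n_layers': 8, 'skips': [4]}
    RUN_ID = '%s_ray_b%d_encode%d_fc%dx%d%s' % ('gray', 2048, 10,
                                                FC_PARAMS['nf'],
                                                FC_PARAMS['n_layers'],
                                                '_skip_%d' % FC_PARAMS['skips'][0]
                                                if len(FC_PARAMS['skips']) > 0 else '')
    print(RUN_ID)  # gray_ray_b2048_encode10_fc128x8_skip_4
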
diff --git a/spher_net.py b/spher_net.py
index 4ab69a0..e1bc6e9 100644
--- a/spher_net.py
+++ b/spher_net.py
@@ -1,45 +1,8 @@
-from typing import List, Tuple
-from math import pi
+from typing import Tuple
 import torch
 import torch.nn as nn
-from .pytorch_prototyping.pytorch_prototyping import *
 from .my import net_modules
 from .my import util
-from .my import device
-
-
-def RaySphereIntersect(p: torch.Tensor, v: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
-    """
-    Calculate intersections of each rays and each spheres
-
-    :param p: B x 3, positions of rays
-    :param v: B x 3, directions of rays
-    :param r: B'(1D), radius of spheres
-    :return: B x B' x 3, points of intersection
-    """
-    # p, v: Expand to B x 1 x 3
-    p = p.unsqueeze(1)
-    v = v.unsqueeze(1)
-    # pp, vv, pv: B x 1
-    pp = (p * p).sum(dim=2)
-    vv = (v * v).sum(dim=2)
-    pv = (p * v).sum(dim=2)
-    # k: Expand to B x B' x 1
-    k = (((pv * pv - vv * (pp - r * r)).sqrt() - pv) / vv).unsqueeze(2)
-    return p + k * v
-
-
-def RayToSpherical(p: torch.Tensor, v: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
-    """
-    Calculate intersections of each rays and each spheres
-
-    :param p: B x 3, positions of rays
-    :param v: B x 3, directions of rays
-    :param r: B' x 1, radius of spheres
-    :return: B x B' x 3, spherical coordinates
-    """
-    p_on_spheres = RaySphereIntersect(p, v, r)
-    return util.CartesianToSpherical(p_on_spheres)
 
 
 class SpherNet(nn.Module):
-- 
GitLab
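
[Note] The functions deleted here were duplicates of the RaySphereIntersect /
RayToSpherical pair kept in msl_net.py. The closed form they share follows
from substituting x = p + k*v into ||x||^2 = r^2, which gives the quadratic
vv*k^2 + 2*pv*k + (pp - r^2) = 0 and the larger root
k = (sqrt(pv^2 - vv*(pp - r^2)) - pv) / vv, the forward hit when the ray
origin lies inside the sphere. A small numeric check, not part of the patch:

    import torch

    p = torch.tensor([[0.0, 0.0, 0.0]])  # ray origin at the sphere center
    v = torch.tensor([[0.0, 0.0, 1.0]])  # unit ray direction
    r = torch.tensor([2.0, 5.0])         # two sphere radii

    pp = (p * p).sum(-1, keepdim=True)
    vv = (v * v).sum(-1, keepdim=True)
    pv = (p * v).sum(-1, keepdim=True)
    k = ((pv * pv - vv * (pp - r * r)).sqrt() - pv) / vv   # (1, 2)
    x = p.unsqueeze(1) + k.unsqueeze(-1) * v.unsqueeze(1)  # (1, 2, 3)
    print(x.norm(dim=-1))  # tensor([[2., 5.]]) -> both points lie on their spheres
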