Commit c570c3b1 authored by BobYeah's avatar BobYeah
Browse files

checkpoint

parent 172b5205
...@@ -2,7 +2,6 @@ import torch ...@@ -2,7 +2,6 @@ import torch
import torchvision.transforms.functional as trans_f import torchvision.transforms.functional as trans_f
import json import json
from ..my import util from ..my import util
from ..my import imgio
class SphericalViewSynDataset(torch.utils.data.dataset.Dataset): class SphericalViewSynDataset(torch.utils.data.dataset.Dataset):
...@@ -44,8 +43,11 @@ class SphericalViewSynDataset(torch.utils.data.dataset.Dataset): ...@@ -44,8 +43,11 @@ class SphericalViewSynDataset(torch.utils.data.dataset.Dataset):
# Load dataset description file # Load dataset description file
with open(dataset_desc_path, 'r', encoding='utf-8') as file: with open(dataset_desc_path, 'r', encoding='utf-8') as file:
data_desc = json.loads(file.read()) data_desc = json.loads(file.read())
self.view_file_pattern: str = self.data_dir + \ if data_desc['view_file_pattern'] == '':
data_desc['view_file_pattern'] self.load_images = False
else:
self.view_file_pattern: str = self.data_dir + \
data_desc['view_file_pattern']
self.view_res = (data_desc['view_res']['y'], self.view_res = (data_desc['view_res']['y'],
data_desc['view_res']['x']) data_desc['view_res']['x'])
self.cam_params = data_desc['cam_params'] self.cam_params = data_desc['cam_params']
...@@ -54,7 +56,7 @@ class SphericalViewSynDataset(torch.utils.data.dataset.Dataset): ...@@ -54,7 +56,7 @@ class SphericalViewSynDataset(torch.utils.data.dataset.Dataset):
.view(-1, 3, 3) # (N, 3, 3) .view(-1, 3, 3) # (N, 3, 3)
# Load view images # Load view images
if load_images: if self.load_images:
self.view_images = util.ReadImageTensor( self.view_images = util.ReadImageTensor(
[self.view_file_pattern % i for i in range(self.view_centers.size(0))]) [self.view_file_pattern % i for i in range(self.view_centers.size(0))])
if gray: if gray:
...@@ -75,8 +77,8 @@ class SphericalViewSynDataset(torch.utils.data.dataset.Dataset): ...@@ -75,8 +77,8 @@ class SphericalViewSynDataset(torch.utils.data.dataset.Dataset):
# Flatten rays if ray_as_item = True # Flatten rays if ray_as_item = True
if ray_as_item: if ray_as_item:
self.view_pixels = self.view_images.permute( self.view_pixels = self.view_images.permute(0, 2, 3, 1).flatten(
0, 2, 3, 1).flatten(0, 2) 0, 2) if self.view_images != None else None
self.ray_positions = self.ray_positions.flatten(0, 1) self.ray_positions = self.ray_positions.flatten(0, 1)
self.ray_directions = self.ray_directions.flatten(0, 1) self.ray_directions = self.ray_directions.flatten(0, 1)
...@@ -88,4 +90,4 @@ class SphericalViewSynDataset(torch.utils.data.dataset.Dataset): ...@@ -88,4 +90,4 @@ class SphericalViewSynDataset(torch.utils.data.dataset.Dataset):
if self.ray_as_item: if self.ray_as_item:
return idx, self.view_pixels[idx], self.ray_positions[idx], self.ray_directions[idx] return idx, self.view_pixels[idx], self.ray_positions[idx], self.ray_directions[idx]
return idx, self.view_images[idx], self.ray_positions[idx], self.ray_directions[idx] return idx, self.view_images[idx], self.ray_positions[idx], self.ray_directions[idx]
return idx, self.ray_positions[idx], self.ray_directions[idx] return idx, False, self.ray_positions[idx], self.ray_directions[idx]
import sys
import os
sys.path.append(os.path.abspath(sys.path[0] + '/../'))
__package__ = "deeplightfield"
import argparse
from PIL import Image
from .my import util
def batch_scale(src, target, size):
util.CreateDirIfNeed(target)
for file_name in os.listdir(src):
postfix = os.path.splitext(file_name)[1]
if postfix == '.jpg' or postfix == '.png':
im = Image.open(os.path.join(src, file_name))
im = im.resize(size)
im.save(os.path.join(target, file_name))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('src', type=str,
help='Source directory.')
parser.add_argument('target', type=str,
help='Target directory.')
parser.add_argument('--width', type=int,
help='Width of output images (pixel)')
parser.add_argument('--height', type=int,
help='Height of output images (pixel)')
opt = parser.parse_args()
batch_scale(opt.src, opt.target, (opt.width, opt.height))
from typing import List, Tuple from typing import List, Tuple
from math import pi
import torch import torch
import torch.nn as nn import torch.nn as nn
from .pytorch_prototyping.pytorch_prototyping import * from .my import net_modules
from .my import util from .my import util
from .my import device from .my import device
def CartesianToSpherical(cart: torch.Tensor) -> torch.Tensor:
"""
Convert coordinates from Cartesian to Spherical
:param cart: ... x 3, coordinates in Cartesian
:return: ... x 3, coordinates in Spherical (r, theta, phi)
"""
rho = torch.norm(cart, p=2, dim=-1)
theta = torch.atan2(cart[..., 2], cart[..., 0])
theta = theta + (theta < 0).type_as(theta) * (2 * pi)
phi = torch.acos(cart[..., 1] / rho)
return torch.stack([rho, theta, phi], dim=-1)
def SphericalToCartesian(spher: torch.Tensor) -> torch.Tensor:
"""
Convert coordinates from Spherical to Cartesian
:param spher: ... x 3, coordinates in Spherical
:return: ... x 3, coordinates in Cartesian (r, theta, phi)
"""
rho = spher[..., 0]
sin_theta_phi = torch.sin(spher[..., 1:3])
cos_theta_phi = torch.cos(spher[..., 1:3])
x = rho * cos_theta_phi[..., 0] * sin_theta_phi[..., 1]
y = rho * cos_theta_phi[..., 1]
z = rho * sin_theta_phi[..., 0] * sin_theta_phi[..., 1]
return torch.stack([x, y, z], dim=-1)
def RaySphereIntersect(p: torch.Tensor, v: torch.Tensor, r: torch.Tensor) -> torch.Tensor: def RaySphereIntersect(p: torch.Tensor, v: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
""" """
Calculate intersections of each rays and each spheres Calculate intersections of each rays and each spheres
...@@ -68,115 +37,74 @@ def RayToSpherical(p: torch.Tensor, v: torch.Tensor, r: torch.Tensor) -> torch.T ...@@ -68,115 +37,74 @@ def RayToSpherical(p: torch.Tensor, v: torch.Tensor, r: torch.Tensor) -> torch.T
:return: B x B' x 3, spherical coordinates :return: B x B' x 3, spherical coordinates
""" """
p_on_spheres = RaySphereIntersect(p, v, r) p_on_spheres = RaySphereIntersect(p, v, r)
return CartesianToSpherical(p_on_spheres) return util.CartesianToSpherical(p_on_spheres)
class FcNet(nn.Module):
def __init__(self, in_chns: int, out_chns: int, nf: int, n_layers: int):
super().__init__()
self.layers = list()
self.layers += [
nn.Linear(in_chns, nf),
#nn.LayerNorm([nf]),
nn.ReLU()
]
for _ in range(1, n_layers):
self.layers += [
nn.Linear(nf, nf),
#nn.LayerNorm([nf]),
nn.ReLU()
]
self.layers.append(nn.Linear(nf, out_chns))
self.net = nn.Sequential(*self.layers)
self.net.apply(self.init_weights)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.net(x)
def init_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.xavier_normal_(m.weight)
nn.init.constant_(m.bias, 0.0)
class Rendering(nn.Module): class Rendering(nn.Module):
def __init__(self, sphere_layers: List[float]): def __init__(self):
""" """
Initialize a Rendering module Initialize a Rendering module
:param sphere_layers: L x 1, radius of sphere layers
""" """
super().__init__() super().__init__()
self.sphere_layers = torch.tensor(
sphere_layers, device=device.GetDevice())
def forward(self, net: FcNet, p: torch.Tensor, v: torch.Tensor) -> torch.Tensor: def forward(self, color_alpha: torch.Tensor) -> torch.Tensor:
""" """
[summary] Blend layers to get final color
:param net: the full-connected net :param color_alpha ```Tensor(B, L, C)```: RGB or gray with alpha channel
:param p: B x 3, positions of rays :return ```Tensor(B, C-1)``` blended pixels
:param v: B x 3, directions of rays
:return B x 1/3, view images by blended layers
""" """
L = self.sphere_layers.size()[0] c = color_alpha[..., :-1]
sp = RayToSpherical(p, v, self.sphere_layers) # B x L x 3 a = color_alpha[..., -1:]
sp[..., 0] = 1 / sp[..., 0] # Radius to diopter
color_alpha: torch.Tensor = net(
sp.flatten(0, 1)).view(p.size()[0], L, -1)
if (color_alpha.size(-1) == 2): # Grayscale
c = color_alpha[..., 0:1]
a = color_alpha[..., 1:2]
else: # RGB
c = color_alpha[..., 0:3]
a = color_alpha[..., 3:4]
blended = c[:, 0, :] * a[:, 0, :] blended = c[:, 0, :] * a[:, 0, :]
for l in range(1, L): for l in range(1, color_alpha.size(1)):
blended = blended * (1 - a[:, l, :]) + c[:, l, :] * a[:, l, :] blended = blended * (1 - a[:, l, :]) + c[:, l, :] * a[:, l, :]
return blended return blended
class MslNet(nn.Module): class MslNet(nn.Module):
def __init__(self, cam_params, sphere_layers: List[float], out_res: Tuple[int, int], gray=False): def __init__(self, cam_params, fc_params, sphere_layers: List[float],
out_res: Tuple[int, int], gray=False, encode_to_dim: int = 0):
""" """
Initialize a multi-sphere-layer net Initialize a multi-sphere-layer net
:param cam_params: intrinsic parameters of camera :param cam_params: intrinsic parameters of camera
:param sphere_layers: L x 1, radius of sphere layers :param fc_params: parameters of full-connection network
:param sphere_layers: list(L), radius of sphere layers
:param out_res: resolution of output view image :param out_res: resolution of output view image
:param gray: is grayscale mode
:param encode_to_dim: encode input to number of dimensions
""" """
super().__init__() super().__init__()
self.cam_params = cam_params self.cam_params = cam_params
self.sphere_layers = torch.tensor(sphere_layers,
dtype=torch.float,
device=device.GetDevice())
self.in_chns = 3
self.out_res = out_res self.out_res = out_res
self.v_local = util.GetLocalViewRays(self.cam_params, out_res, flatten=True) \ self.input_encoder = net_modules.InputEncoder.Get(
.to(device.GetDevice()) # N x 3 encode_to_dim, self.in_chns)
#self.net = FCBlock(hidden_ch=64, fc_params['in_chns'] = self.input_encoder.out_dim
# num_hidden_layers=4, fc_params['out_chns'] = 2 if gray else 4
# in_features=3, self.net = net_modules.FcNet(**fc_params)
# out_features=2 if gray else 4, self.rendering = Rendering()
# outermost_linear=True)
self.net = FcNet(in_chns=3, out_chns=2 if gray else 4, nf=256, n_layers=8) def forward(self, ray_positions: torch.Tensor, ray_directions: torch.Tensor) -> torch.Tensor:
self.rendering = Rendering(sphere_layers)
def forward(self, view_centers: torch.Tensor, view_rots: torch.Tensor) -> torch.Tensor:
""" """
T_view -> image rays -> colors
:param view_centers: B x 3, centers of views :param ray_positions ```Tensor(B, M, 3)|Tensor(B, 3)```: ray positions
:param view_rots: B x 3 x 3, rotation matrices of views :param ray_directions ```Tensor(B, M, 3)|Tensor(B, 3)```: ray directions
:return: B x 1/3 x H_out x W_out, inferred images of views :return: Tensor(B, 1|3, H, W)|Tensor(B, 1|3), inferred images/pixels
""" """
# Transpose matrix so we can perform vec x mat p = ray_positions.view(-1, 3)
view_rots_t = view_rots.permute(0, 2, 1) v = ray_directions.view(-1, 3)
spher = RayToSpherical(p, v, self.sphere_layers).flatten(0, 1)
# p and v are B x N x 3 tensor color_alpha = self.net(self.input_encoder(spher)).view(
p = view_centers.unsqueeze(1).expand(-1, self.v_local.size(0), -1) p.size(0), self.sphere_layers.size(0), -1)
v = torch.matmul(self.v_local, view_rots_t) c: torch.Tensor = self.rendering(color_alpha)
c: torch.Tensor = self.rendering(
self.net, p.flatten(0, 1), v.flatten(0, 1)) # (BN) x 3
# unflatten # unflatten
return c.view(view_centers.size(0), self.out_res[0], return c.view(ray_directions.size(0), self.out_res[0],
self.out_res[1], -1).permute(0, 3, 1, 2) self.out_res[1], -1).permute(0, 3, 1, 2) if len(ray_directions.size()) == 3 else c
import sys import sys
sys.path.append('/e/dengnc') import os
sys.path.append(os.path.abspath(sys.path[0] + '/../'))
__package__ = "deeplightfield" __package__ = "deeplightfield"
import argparse import argparse
import torch import torch
import torch.optim import torch.optim
import torchvision import torchvision
from typing import List, Tuple
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
from torch import nn from torch import nn
from .my import netio from .my import netio
from .my import util from .my import util
from .my import device from .my import device
from .my.simple_perf import SimplePerf from .my.simple_perf import SimplePerf
from .loss.loss import PerceptionReconstructionLoss
from .data.spherical_view_syn import SphericalViewSynDataset from .data.spherical_view_syn import SphericalViewSynDataset
from .msl_net import MslNet from .msl_net import MslNet
from .spher_net import SpherNet from .spher_net import SpherNet
...@@ -36,9 +35,9 @@ TRAIN_MODE = True ...@@ -36,9 +35,9 @@ TRAIN_MODE = True
EVAL_TIME_PERFORMANCE = False EVAL_TIME_PERFORMANCE = False
RAY_AS_ITEM = True RAY_AS_ITEM = True
# ======== # ========
#GRAY = True GRAY = True
ROT_ONLY = True #ROT_ONLY = True
TRAIN_MODE = False #TRAIN_MODE = False
#EVAL_TIME_PERFORMANCE = True #EVAL_TIME_PERFORMANCE = True
#RAY_AS_ITEM = False #RAY_AS_ITEM = False
...@@ -48,39 +47,39 @@ N_DEPTH_LAYERS = 10 ...@@ -48,39 +47,39 @@ N_DEPTH_LAYERS = 10
N_ENCODE_DIM = 10 N_ENCODE_DIM = 10
FC_PARAMS = { FC_PARAMS = {
'nf': 128, 'nf': 128,
'n_layers': 6, 'n_layers': 8,
'skips': [4] 'skips': [4]
} }
# Train # Train
TRAIN_DATA_DESC_FILE = 'train.json'
BATCH_SIZE = 2048 if RAY_AS_ITEM else 4 BATCH_SIZE = 2048 if RAY_AS_ITEM else 4
EPOCH_RANGE = range(0, 500) EPOCH_RANGE = range(0, 500)
SAVE_INTERVAL = 20 SAVE_INTERVAL = 20
# Test
TEST_NET_NAME = 'model-epoch_500'
TEST_DATA_DESC_FILE = 'test_fovea.json'
TEST_BATCH_SIZE = 5
# Paths # Paths
DATA_DIR = sys.path[0] + '/data/sp_view_syn_2020.12.26_rotonly/' DATA_DIR = sys.path[0] + '/data/sp_view_syn_2020.12.28/'
RUN_ID = '%s_ray_b%d_encode%d_fc%dx%d%s' % ('gray' if GRAY else 'rgb', RUN_ID = '%s_ray_b%d_encode%d_fc%dx%d%s' % ('gray' if GRAY else 'rgb',
BATCH_SIZE, BATCH_SIZE,
N_ENCODE_DIM, N_ENCODE_DIM,
FC_PARAMS['nf'], FC_PARAMS['nf'],
FC_PARAMS['n_layers'], FC_PARAMS['n_layers'],
'_skip_%d' % FC_PARAMS['skips'][0] if len(FC_PARAMS['skips']) > 0 else '') '_skip_%d' % FC_PARAMS['skips'][0] if len(FC_PARAMS['skips']) > 0 else '')
TRAIN_DATA_DESC_FILE = DATA_DIR + 'train.json'
RUN_DIR = DATA_DIR + RUN_ID + '/' RUN_DIR = DATA_DIR + RUN_ID + '/'
OUTPUT_DIR = RUN_DIR + 'output/' OUTPUT_DIR = RUN_DIR + 'output/'
LOG_DIR = RUN_DIR + 'log/' LOG_DIR = RUN_DIR + 'log/'
# Test
TEST_NET_NAME = 'model-epoch_100'
TEST_BATCH_SIZE = 5
def train(): def train():
# 1. Initialize data loader # 1. Initialize data loader
print("Load dataset: " + TRAIN_DATA_DESC_FILE) print("Load dataset: " + DATA_DIR + TRAIN_DATA_DESC_FILE)
train_dataset = SphericalViewSynDataset( train_dataset = SphericalViewSynDataset(DATA_DIR + TRAIN_DATA_DESC_FILE,
TRAIN_DATA_DESC_FILE, gray=GRAY, ray_as_item=RAY_AS_ITEM) gray=GRAY, ray_as_item=RAY_AS_ITEM)
train_data_loader = torch.utils.data.DataLoader( train_data_loader = torch.utils.data.DataLoader(
dataset=train_dataset, dataset=train_dataset,
batch_size=BATCH_SIZE, batch_size=BATCH_SIZE,
...@@ -98,10 +97,12 @@ def train(): ...@@ -98,10 +97,12 @@ def train():
encode_to_dim=N_ENCODE_DIM).to(device.GetDevice()) encode_to_dim=N_ENCODE_DIM).to(device.GetDevice())
else: else:
model = MslNet(cam_params=train_dataset.cam_params, model = MslNet(cam_params=train_dataset.cam_params,
fc_params=FC_PARAMS,
sphere_layers=util.GetDepthLayers( sphere_layers=util.GetDepthLayers(
DEPTH_RANGE, N_DEPTH_LAYERS), DEPTH_RANGE, N_DEPTH_LAYERS),
out_res=train_dataset.view_res, out_res=train_dataset.view_res,
gray=GRAY).to(device.GetDevice()) gray=GRAY,
encode_to_dim=N_ENCODE_DIM).to(device.GetDevice())
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4) optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
loss = nn.MSELoss() loss = nn.MSELoss()
...@@ -172,11 +173,11 @@ def train(): ...@@ -172,11 +173,11 @@ def train():
def test(net_file: str): def test(net_file: str):
# 1. Load train dataset # 1. Load train dataset
print("Load dataset: " + TRAIN_DATA_DESC_FILE) print("Load dataset: " + DATA_DIR + TEST_DATA_DESC_FILE)
train_dataset = SphericalViewSynDataset(TRAIN_DATA_DESC_FILE, test_dataset = SphericalViewSynDataset(DATA_DIR + TEST_DATA_DESC_FILE,
load_images=True, gray=GRAY) load_images=True, gray=GRAY)
train_data_loader = torch.utils.data.DataLoader( test_data_loader = torch.utils.data.DataLoader(
dataset=train_dataset, dataset=test_dataset,
batch_size=TEST_BATCH_SIZE, batch_size=TEST_BATCH_SIZE,
pin_memory=True, pin_memory=True,
shuffle=False, shuffle=False,
...@@ -184,37 +185,38 @@ def test(net_file: str): ...@@ -184,37 +185,38 @@ def test(net_file: str):
# 2. Load trained model # 2. Load trained model
if ROT_ONLY: if ROT_ONLY:
model = SpherNet(cam_params=train_dataset.cam_params, model = SpherNet(cam_params=test_dataset.cam_params,
fc_params=FC_PARAMS, fc_params=FC_PARAMS,
out_res=train_dataset.view_res, out_res=test_dataset.view_res,
gray=GRAY, gray=GRAY,
encode_to_dim=N_ENCODE_DIM).to(device.GetDevice()) encode_to_dim=N_ENCODE_DIM).to(device.GetDevice())
else: else:
model = MslNet(cam_params=train_dataset.cam_params, model = MslNet(cam_params=test_dataset.cam_params,
sphere_layers=_GetSphereLayers( sphere_layers=util.GetDepthLayers(
DEPTH_RANGE, N_DEPTH_LAYERS), DEPTH_RANGE, N_DEPTH_LAYERS),
out_res=train_dataset.view_res, out_res=test_dataset.view_res,
gray=GRAY).to(device.GetDevice()) gray=GRAY).to(device.GetDevice())
netio.LoadNet(net_file, model) netio.LoadNet(net_file, model)
# 3. Test on train dataset # 3. Test on train dataset
print("Begin test on train dataset, batch size is %d" % TEST_BATCH_SIZE) print("Begin test on train dataset, batch size is %d" % TEST_BATCH_SIZE)
util.CreateDirIfNeed(OUTPUT_DIR) output_dir = '%s%s/%s/' % (OUTPUT_DIR, TEST_NET_NAME, TEST_DATA_DESC_FILE)
util.CreateDirIfNeed(OUTPUT_DIR + TEST_NET_NAME) util.CreateDirIfNeed(output_dir)
perf = SimplePerf(True, start=True) perf = SimplePerf(True, start=True)
i = 0 i = 0
for view_idxs, view_images, ray_positions, ray_directions in train_data_loader: for view_idxs, view_images, ray_positions, ray_directions in test_data_loader:
ray_positions = ray_positions.to(device.GetDevice()) ray_positions = ray_positions.to(device.GetDevice())
ray_directions = ray_directions.to(device.GetDevice()) ray_directions = ray_directions.to(device.GetDevice())
perf.Checkpoint("%d - Load" % i) perf.Checkpoint("%d - Load" % i)
out_view_images = model(ray_positions, ray_directions) out_view_images = model(ray_positions, ray_directions)
perf.Checkpoint("%d - Infer" % i) perf.Checkpoint("%d - Infer" % i)
util.WriteImageTensor( if test_dataset.load_images:
view_images, util.WriteImageTensor(
['%s%s/gt_view_%04d.png' % (OUTPUT_DIR, TEST_NET_NAME, i) for i in view_idxs]) view_images,
['%sgt_view_%04d.png' % (output_dir, i) for i in view_idxs])
util.WriteImageTensor( util.WriteImageTensor(
out_view_images, out_view_images,
['%s%s/out_view_%04d.png' % (OUTPUT_DIR, TEST_NET_NAME, i) for i in view_idxs]) ['%sout_view_%04d.png' % (output_dir, i) for i in view_idxs])
perf.Checkpoint("%d - Write" % i) perf.Checkpoint("%d - Write" % i)
i += 1 i += 1
......
from typing import List, Tuple from typing import Tuple
from math import pi
import torch import torch
import torch.nn as nn import torch.nn as nn
from .pytorch_prototyping.pytorch_prototyping import *
from .my import net_modules from .my import net_modules
from .my import util from .my import util
from .my import device
def RaySphereIntersect(p: torch.Tensor, v: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
"""
Calculate intersections of each rays and each spheres
:param p: B x 3, positions of rays
:param v: B x 3, directions of rays
:param r: B'(1D), radius of spheres
:return: B x B' x 3, points of intersection
"""
# p, v: Expand to B x 1 x 3
p = p.unsqueeze(1)
v = v.unsqueeze(1)
# pp, vv, pv: B x 1
pp = (p * p).sum(dim=2)
vv = (v * v).sum(dim=2)
pv = (p * v).sum(dim=2)
# k: Expand to B x B' x 1
k = (((pv * pv - vv * (pp - r * r)).sqrt() - pv) / vv).unsqueeze(2)
return p + k * v
def RayToSpherical(p: torch.Tensor, v: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
"""
Calculate intersections of each rays and each spheres
:param p: B x 3, positions of rays
:param v: B x 3, directions of rays
:param r: B' x 1, radius of spheres
:return: B x B' x 3, spherical coordinates
"""
p_on_spheres = RaySphereIntersect(p, v, r)
return util.CartesianToSpherical(p_on_spheres)
class SpherNet(nn.Module): class SpherNet(nn.Module):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment