From 3e1a5b04caf00aec7c14ab54366e3bcfcae2b5ba Mon Sep 17 00:00:00 2001 From: BobYeah <635596704@qq.com> Date: Fri, 25 Dec 2020 18:46:36 +0800 Subject: [PATCH] sync --- msl_net.py | 45 +++++++++++++++++++++++++++++++++++++++++++++ run_lf_syn.py | 4 ++-- 2 files changed, 47 insertions(+), 2 deletions(-) create mode 100644 msl_net.py diff --git a/msl_net.py b/msl_net.py new file mode 100644 index 0000000..6923827 --- /dev/null +++ b/msl_net.py @@ -0,0 +1,45 @@ +from typing import List +import torch +import torch.nn as nn +from .pytorch_prototyping.pytorch_prototyping import * +from .my import util +from .my import device + + +class FcNet(nn.Module): + + def __init__(self, in_chns, out_chns, nf, n_layers): + super().__init__() + self.layers = list() + self.layers.append(nn.Linear(in_chns, nf)) + self.layers.append(nn.LeakyReLU()) + for _ in range(1, n_layers): + self.layers.append(nn.Linear(nf, nf)) + self.layers.append(nn.LeakyReLU()) + self.layers.append(nn.Linear(nf, out_chns)) + self.net = nn.Sequential(*self.layers) + + def forward(self, x): + return self.net(x) + +class Rendering(nn.Module): + + def __init__(self, n_sphere_layers): + super().__init__() + self.n_sl = n_sphere_layers + + def forward(self, net, pos, dir): + """ + [summary] + + :param pos: B x 3, position of a ray + :param dir: B x 3, direction of a ray + """ + + +class MslNet(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, x): diff --git a/run_lf_syn.py b/run_lf_syn.py index b221d86..5f84021 100644 --- a/run_lf_syn.py +++ b/run_lf_syn.py @@ -27,7 +27,7 @@ BATCH_SIZE = 8 TEST_BATCH_SIZE = 10 NUM_EPOCH = 1000 MODE = "Silence" # "Perf" -EPOCH_BEGIN = 0 +EPOCH_BEGIN = 600 def train(): @@ -58,7 +58,7 @@ def train(): # 3. Train model.train() epoch = EPOCH_BEGIN - iters = EPOCH_BEGIN * len(train_data_loader) + iters = EPOCH_BEGIN * len(train_data_loader) * BATCH_SIZE util.CreateDirIfNeed(RUN_DIR) -- GitLab