Commit 3e1a5b04 authored by BobYeah's avatar BobYeah
Browse files

sync

parent 67c4de9e
from typing import List
import torch
import torch.nn as nn
from .pytorch_prototyping.pytorch_prototyping import *
from .my import util
from .my import device
class FcNet(nn.Module):
def __init__(self, in_chns, out_chns, nf, n_layers):
super().__init__()
self.layers = list()
self.layers.append(nn.Linear(in_chns, nf))
self.layers.append(nn.LeakyReLU())
for _ in range(1, n_layers):
self.layers.append(nn.Linear(nf, nf))
self.layers.append(nn.LeakyReLU())
self.layers.append(nn.Linear(nf, out_chns))
self.net = nn.Sequential(*self.layers)
def forward(self, x):
return self.net(x)
class Rendering(nn.Module):
def __init__(self, n_sphere_layers):
super().__init__()
self.n_sl = n_sphere_layers
def forward(self, net, pos, dir):
"""
[summary]
:param pos: B x 3, position of a ray
:param dir: B x 3, direction of a ray
"""
class MslNet(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
...@@ -27,7 +27,7 @@ BATCH_SIZE = 8 ...@@ -27,7 +27,7 @@ BATCH_SIZE = 8
TEST_BATCH_SIZE = 10 TEST_BATCH_SIZE = 10
NUM_EPOCH = 1000 NUM_EPOCH = 1000
MODE = "Silence" # "Perf" MODE = "Silence" # "Perf"
EPOCH_BEGIN = 0 EPOCH_BEGIN = 600
def train(): def train():
...@@ -58,7 +58,7 @@ def train(): ...@@ -58,7 +58,7 @@ def train():
# 3. Train # 3. Train
model.train() model.train()
epoch = EPOCH_BEGIN epoch = EPOCH_BEGIN
iters = EPOCH_BEGIN * len(train_data_loader) iters = EPOCH_BEGIN * len(train_data_loader) * BATCH_SIZE
util.CreateDirIfNeed(RUN_DIR) util.CreateDirIfNeed(RUN_DIR)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment