Commit 5699ccbf authored by Nianchen Deng

sync

parent 338ae906
Showing with 764 additions and 58 deletions
%% Cell type:code id: tags:

```
import sys
import os
import json
import torch
import matplotlib.pyplot as plt
import torchvision.transforms.functional as trans_f

rootdir = os.path.abspath(sys.path[0] + '/../')
sys.path.append(rootdir)

torch.cuda.set_device(2)
print("Set CUDA:%d as current device." % torch.cuda.current_device())

from data.spherical_view_syn import *
from configs.spherical_view_syn import SphericalViewSynConfig
from utils import netio
from utils import misc
from utils import img
from utils import device
from utils import view
from components.foveation import Foveation
from components.gen_final import GenFinal


def load_net(path):
    config = SphericalViewSynConfig()
    config.from_id(path[:-4])
    config.sa['perturb_sample'] = False
    config.print()
    net = config.create_net().to(device.default())
    netio.load(path, net)
    return net


def find_file(prefix):
    for path in os.listdir():
        if path.startswith(prefix):
            return path
    return None


def load_views(data_desc_file) -> view.Trans:
    with open(data_desc_file, 'r', encoding='utf-8') as file:
        data_desc = json.loads(file.read())
    samples = data_desc['samples'] if 'samples' in data_desc else [-1]
    view_centers = torch.tensor(
        data_desc['view_centers'], device=device.default()).view(samples + [3])
    view_rots = torch.tensor(
        data_desc['view_rots'], device=device.default()).view(samples + [3, 3])
    return view.Trans(view_centers, view_rots)


def read_ref_images(idx):
    patt = 'ref/view_%04d.png'
    if isinstance(idx, torch.Tensor) and len(idx.size()) > 0:
        return img.load([patt % i for i in idx])
    else:
        return img.load(patt % idx)


def adjust_cam(cam, vr_cam, gaze_center):
    fovea_offset = (
        gaze_center[0] / vr_cam.f[0].item() * cam.f[0].item(),
        gaze_center[1] / vr_cam.f[1].item() * cam.f[1].item()
    )
    cam.c[0] = cam.res[1] / 2 - fovea_offset[0]
    cam.c[1] = cam.res[0] / 2 - fovea_offset[1]
```
%% Cell type:code id: tags:

```
os.chdir(os.path.join('data/__0_user_study/us_gas_all_in_one'))
#os.chdir(os.path.join('data/__0_user_study/us_mc_all_in_one'))
#os.chdir(os.path.join('data/bedroom_all_in_one'))
print('Change working directory to ', os.getcwd())
torch.autograd.set_grad_enabled(False)

fovea_net = load_net(find_file('fovea'))
periph_net = load_net(find_file('periph'))

# Load dataset
views = load_views('views.json')
#ref_dataset = SphericalViewSynDataset('ref.json', load_images=False, calculate_rays=False)
print('Dataset loaded.')
print('views:', views.size())
#print('ref views:', ref_dataset.samples)

fov_list = [20, 45, 110]
res_list = [(128, 128), (256, 256), (256, 230)]  # (192, 256)
res_full = (1600, 1440)
gen = GenFinal(fov_list, res_list, res_full, fovea_net, periph_net, device.default())
```
%% Cell type:code id: tags:

```
# Select a parameter set: later assignments override earlier ones, so
# comment/uncomment the block you want to use.
# ==gas==
set_id = 0
left_center = (-137, 64)
right_center = (-142, 64)
set_id = 1
left_center = (133, -44)
right_center = (130, -44)
set_id = 2
left_center = (-20, -5)
right_center = (-25, -5)
# ==mc==
#set_id = 3
#left_center = (-107, 80)
#right_center = (-112, 80)
#set_id = 4
#left_center = (-17, -90)
#right_center = (-22, -90)
#set_id = 5
#left_center = (95, 30)
#right_center = (91, 30)

view_coord = [0, 0, 0, 0, 0]
for i, val in enumerate(views.size()):
    view_coord[i] += val // 2
print('view_coord:', view_coord)
test_view = views.get(*view_coord)

cams = [
    view.CameraParam({
        "fov": fov_list[i],
        "cx": 0.5,
        "cy": 0.5,
        "normalized": True
    }, res_list[i]).to(device.default())
    for i in range(len(fov_list))
]
fovea_cam, mid_cam, periph_cam = cams[0], cams[1], cams[2]
#guide_cam = ref_dataset.cam_params
vr_cam = view.CameraParam({
    'fov': fov_list[-1],
    'cx': 0.5,
    'cy': 0.5,
    'normalized': True
}, res_full)
foveation = Foveation(fov_list, res_full, device=device.default())


def plot_figures(left_images, right_images, left_center, right_center):
    # Plot raw fovea
    plt.figure(figsize=(8, 4))
    plt.subplot(121)
    img.plot(left_images['fovea_raw'])
    plt.subplot(122)
    img.plot(right_images['fovea_raw'])

    # Plot fovea with a crosshair at the image center
    plt.figure(figsize=(8, 4))
    plt.subplot(121)
    img.plot(left_images['fovea'])
    plt.plot([(fovea_cam.res[1] - 1) / 2 - 5, (fovea_cam.res[1] - 1) / 2 + 5],
             [(fovea_cam.res[0] - 1) / 2, (fovea_cam.res[0] - 1) / 2],
             color=[0, 1, 0])
    plt.plot([(fovea_cam.res[1] - 1) / 2, (fovea_cam.res[1] - 1) / 2],
             [(fovea_cam.res[0] - 1) / 2 - 5, (fovea_cam.res[0] - 1) / 2 + 5],
             color=[0, 1, 0])
    plt.subplot(122)
    img.plot(right_images['fovea'])
    plt.plot([(fovea_cam.res[1] - 1) / 2 - 5, (fovea_cam.res[1] - 1) / 2 + 5],
             [(fovea_cam.res[0] - 1) / 2, (fovea_cam.res[0] - 1) / 2],
             color=[0, 1, 0])
    plt.plot([(fovea_cam.res[1] - 1) / 2, (fovea_cam.res[1] - 1) / 2],
             [(fovea_cam.res[0] - 1) / 2 - 5, (fovea_cam.res[0] - 1) / 2 + 5],
             color=[0, 1, 0])
    #plt.subplot(1, 4, 2)
    #img.plot(fovea_refined)

    # Plot mid
    plt.figure(figsize=(8, 4))
    plt.subplot(121)
    img.plot(left_images['mid'])
    plt.subplot(122)
    img.plot(right_images['mid'])

    # Plot periph
    plt.figure(figsize=(8, 4))
    plt.subplot(121)
    img.plot(left_images['periph'])
    plt.subplot(122)
    img.plot(right_images['periph'])

    # Plot blended results with a crosshair at each gaze center
    plt.figure(figsize=(12, 6))
    plt.subplot(121)
    img.plot(left_images['blended'])
    plt.plot([(res_full[1] - 1) / 2 + left_center[0] - 5, (res_full[1] - 1) / 2 + left_center[0] + 5],
             [(res_full[0] - 1) / 2 + left_center[1],
              (res_full[0] - 1) / 2 + left_center[1]],
             color=[0, 1, 0])
    plt.plot([(res_full[1] - 1) / 2 + left_center[0], (res_full[1] - 1) / 2 + left_center[0]],
             [(res_full[0] - 1) / 2 + left_center[1] - 5,
              (res_full[0] - 1) / 2 + left_center[1] + 5],
             color=[0, 1, 0])
    plt.subplot(122)
    img.plot(right_images['blended'])
    plt.plot([(res_full[1] - 1) / 2 + right_center[0] - 5, (res_full[1] - 1) / 2 + right_center[0] + 5],
             [(res_full[0] - 1) / 2 + right_center[1],
              (res_full[0] - 1) / 2 + right_center[1]],
             color=[0, 1, 0])
    plt.plot([(res_full[1] - 1) / 2 + right_center[0], (res_full[1] - 1) / 2 + right_center[0]],
             [(res_full[0] - 1) / 2 + right_center[1] - 5,
              (res_full[0] - 1) / 2 + right_center[1] + 5],
             color=[0, 1, 0])


left_images = gen(
    left_center,
    view.Trans(
        test_view.trans_point(
            torch.tensor([-0.03, 0, 0], device=device.default())
        ),
        test_view.r
    ),
    ret_raw=True,
    mono_trans=test_view,
    shift=0)
right_images = gen(
    right_center,
    view.Trans(
        test_view.trans_point(
            torch.tensor([0.03, 0, 0], device=device.default())
        ),
        test_view.r
    ),
    ret_raw=True,
    mono_trans=test_view,
    shift=0)
plot_figures(left_images, right_images, left_center, right_center)

os.makedirs('output/mono_test', exist_ok=True)
for key in left_images:
    img.save(
        left_images[key], 'output/mono_test/set%d_%s_l.png' % (set_id, key))
for key in right_images:
    img.save(
        right_images[key], 'output/mono_test/set%d_%s_r.png' % (set_id, key))
```
...
@@ -58,7 +58,7 @@ def train():
     epoch = EPOCH_BEGIN
     iters = EPOCH_BEGIN * len(train_data_loader) * BATCH_SIZE
-    misc.create_dir(RUN_DIR)
+    os.makedirs(RUN_DIR, exist_ok=True)
     perf = Perf(enable=(MODE == "Perf"), start=True)
     writer = SummaryWriter(RUN_DIR)
@@ -129,7 +129,7 @@ def test(net_file: str):
     # 3. Test on train dataset
     print("Begin test on train dataset...")
-    misc.create_dir(OUTPUT_DIR)
+    os.makedirs(OUTPUT_DIR, exist_ok=True)
     for view_idxs, view_images, _, view_positions in train_data_loader:
         out_view_images = model(view_positions)
         img.save(view_images,
...
@@ -316,8 +316,8 @@ def train():
     if epochRange.start > 1:
         iters = netio.load(f'{run_dir}model-epoch_{epochRange.start - 1}.pth', model)
     else:
-        misc.create_dir(run_dir)
-        misc.create_dir(log_dir)
+        os.makedirs(run_dir, exist_ok=True)
+        os.makedirs(log_dir, exist_ok=True)
         iters = 0
     # 3. Train
@@ -400,7 +400,7 @@ def test():
     # 4. Save results
     print('Saving results...')
-    misc.create_dir(output_dir)
+    os.makedirs(output_dir, exist_ok=True)
     for key in out:
         shape = [n] + list(dataset.res) + list(out[key].size()[1:])
@@ -446,7 +446,7 @@ def test():
         img.save_video(out['color'], output_file, 30)
     else:
         output_subdir = f"{output_dir}/{output_dataset_id}_color"
-        misc.create_dir(output_subdir)
+        os.makedirs(output_subdir, exist_ok=True)
         img.save(out['color'], [f'{output_subdir}/{i:0>4d}.png' for i in dataset.indices])
     if args.output_flags['depth']:
@@ -457,13 +457,13 @@ def test():
         img.save_video(colorized_depths, output_file, 30)
     else:
         output_subdir = f"{output_dir}/{output_dataset_id}_depth"
-        misc.create_dir(output_subdir)
+        os.makedirs(output_subdir, exist_ok=True)
         img.save(colorized_depths, [
             f'{output_subdir}/{i:0>4d}.png'
             for i in dataset.indices
         ])
         output_subdir = f"{output_dir}/{output_dataset_id}_bins"
-        misc.create_dir(output_subdir)
+        os.makedirs(output_subdir, exist_ok=True)
         img.save(out['bins'], [f'{output_subdir}/{i:0>4d}.png' for i in dataset.indices])
     if args.output_flags['layers']:
@@ -473,7 +473,7 @@ def test():
         img.save_video(out['layers'][j], output_file, 30)
     else:
         output_subdir = f"{output_dir}/{output_dataset_id}_layers"
-        misc.create_dir(output_subdir)
+        os.makedirs(output_subdir, exist_ok=True)
         for j in range(config.sa['n_samples']):
             img.save(out['layers'][j], [
                 f'{output_subdir}/{i:0>4d}[{j:0>3d}].png'
@@ -543,7 +543,7 @@ def test1():
     # 4. Save results
     print('Saving results...')
-    misc.create_dir(output_dir)
+    os.makedirs(output_dir, exist_ok=True)
     for key in out:
         shape = [n] + list(dataset.res) + list(out[key].size()[1:])
@@ -587,7 +587,7 @@ def test1():
         img.save_video(out['color'], output_file, 30)
     else:
         output_subdir = f"{output_dir}/{output_dataset_id}_color"
-        misc.create_dir(output_subdir)
+        os.makedirs(output_subdir, exist_ok=True)
         img.save(out['color'], [f'{output_subdir}/{i:0>4d}.png' for i in dataset.indices])
     if args.output_flags['depth']:
@@ -598,7 +598,7 @@ def test1():
         img.save_video(colorized_depths, output_file, 30)
     else:
         output_subdir = f"{output_dir}/{output_dataset_id}_depth"
-        misc.create_dir(output_subdir)
+        os.makedirs(output_subdir, exist_ok=True)
         img.save(colorized_depths, [
             f'{output_subdir}/{i:0>4d}.png'
             for i in dataset.indices
@@ -611,7 +611,7 @@ def test1():
         img.save_video(out['layers'][j], output_file, 30)
     else:
         output_subdir = f"{output_dir}/{output_dataset_id}_layers"
-        misc.create_dir(output_subdir)
+        os.makedirs(output_subdir, exist_ok=True)
         for j in range(config.sa['n_samples']):
             img.save(out['layers'][j], [
                 f'{output_subdir}/{i:0>4d}[{j:0>3d}].png'
@@ -679,7 +679,7 @@ def test2():
     # 4. Save results
     print('Saving results...')
-    misc.create_dir(output_dir)
+    os.makedirs(output_dir, exist_ok=True)
     for key in out:
         shape = [n] + list(dataset.res) + list(out[key].size()[1:])
@@ -723,7 +723,7 @@ def test2():
         img.save_video(out['color'], output_file, 30)
     else:
         output_subdir = f"{output_dir}/{output_dataset_id}_color"
-        misc.create_dir(output_subdir)
+        os.makedirs(output_subdir, exist_ok=True)
         img.save(out['color'], [f'{output_subdir}/{i:0>4d}.png' for i in dataset.indices])
     if args.output_flags['depth']:
@@ -734,7 +734,7 @@ def test2():
         img.save_video(colorized_depths, output_file, 30)
     else:
         output_subdir = f"{output_dir}/{output_dataset_id}_depth"
-        misc.create_dir(output_subdir)
+        os.makedirs(output_subdir, exist_ok=True)
         img.save(colorized_depths, [
             f'{output_subdir}/{i:0>4d}.png'
             for i in dataset.indices
@@ -747,7 +747,7 @@ def test2():
         img.save_video(out['layers'][j], output_file, 30)
     else:
         output_subdir = f"{output_dir}/{output_dataset_id}_layers"
-        misc.create_dir(output_subdir)
+        os.makedirs(output_subdir, exist_ok=True)
         for j in range(config.sa['n_samples']):
             img.save(out['layers'][j], [
                 f'{output_subdir}/{i:0>4d}[{j:0>3d}].png'
...
setup.py 0 → 100644
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import glob
import os
import sys

# Build the CUDA extension library (clib)
src_root = "clib"
sources = glob.glob(f"{src_root}/src/*.cpp") + glob.glob(f"{src_root}/src/*.cu")
includes = f"{sys.path[0]}/{src_root}/include"

setup(
    name='dvs',
    ext_modules=[
        CUDAExtension(
            name='clib._ext',
            sources=sources,
            extra_compile_args={
                "cxx": ["-O2", f"-I{includes}"],
                "nvcc": ["-O2", f"-I{includes}"],
            },
        )
    ],
    cmdclass={
        'build_ext': BuildExtension
    }
)
\ No newline at end of file
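Assuming the standard setuptools/torch extension workflow (the repo may wrap this in its own scripts), the extension can be built in place before running the training or testing code:

```
# Hedged usage sketch: requires a CUDA toolchain matching the installed PyTorch.
#   python setup.py build_ext --inplace   # builds clib._ext next to the sources
#   pip install -e .                      # or install the package in editable mode
```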
import os
import shutil
from sys import stdout
from time import sleep

from utils.progress_bar import *

# Endless progress-bar rendering test: redraw at the bottom of the terminal
i = 0
while True:
    rows = shutil.get_terminal_size().lines
    cols = shutil.get_terminal_size().columns
    os.system('cls' if os.name == 'nt' else 'clear')
    stdout.write("\n" * (rows - 1))
    progress_bar(i, 10000, "Test", "XXX")
    i += 1
    sleep(0.02)
test.py 0 → 100644
import os
import json
import argparse
import torch
import torch.nn.functional as nn_f
from math import nan, ceil, prod
from pathlib import Path

parser = argparse.ArgumentParser()
parser.add_argument('-m', '--model', type=str,
                    help='The model file to load for testing')
parser.add_argument('-r', '--output-res', type=str,
                    help='Output resolution')
parser.add_argument('-o', '--output', nargs='+', type=str, default=['perf', 'color'],
                    help='Specify what to output (perf, color, depth, all)')
parser.add_argument('--output-type', type=str, default='image',
                    help='Specify the output type (image, video, debug)')
parser.add_argument('--views', type=str,
                    help='Specify the range of views to test')
parser.add_argument('-p', '--prompt', action='store_true',
                    help='Interactive prompt mode')
parser.add_argument('--time', action='store_true',
                    help='Enable time measurement')
parser.add_argument('dataset', type=str,
                    help='Dataset description file')
args = parser.parse_args()

import model as mdl
from loss.ssim import ssim
from utils import color
from utils import interact
from utils import device
from utils import img
from utils.perf import Perf, enable_perf, get_perf_result
from utils.progress_bar import progress_bar
from data.dataset_factory import *
from data.loader import DataLoader
from utils.constants import HUGE_FLOAT

RAYS_PER_BATCH = 2 ** 14
DATA_LOADER_CHUNK_SIZE = 1e8

data_desc_path = DatasetFactory.get_dataset_desc_path(args.dataset)
os.chdir(data_desc_path.parent)
nets_dir = Path("_nets")
data_desc_path = data_desc_path.name


def set_outputs(args, outputs_str: str):
    args.output = [s.strip() for s in outputs_str.split(',')]


if args.prompt:  # Prompt for test model, output resolution and output mode
    model_files = [str(path.relative_to(nets_dir)) for path in nets_dir.rglob("*.tar")] \
        + [str(path.relative_to(nets_dir)) for path in nets_dir.rglob("*.pth")]
    args.model = interact.input_enum('Specify test model:', model_files,
                                     err_msg='No such model file')
    args.output_res = interact.input_ex('Specify output resolution:',
                                        default='')
    set_outputs(args, interact.input_ex('Specify the outputs | [perf,color,depth,layers,diffuse,specular]/all:',
                                        default='perf,color'))
    args.output_type = interact.input_enum('Specify the output type | image/video:',
                                           ['image', 'video'],
                                           err_msg='Wrong output type',
                                           default='image')
args.output_res = tuple(int(s) for s in reversed(args.output_res.split('x'))) if args.output_res \
    else None
args.output_flags = {
    item: item in args.output or 'all' in args.output
    for item in ['perf', 'color', 'depth', 'layers', 'diffuse', 'specular']
}
args.views = range(*[int(val) for val in args.views.split('-')]) if args.views else None
if args.time:
    enable_perf()

dataset = DatasetFactory.load(data_desc_path, res=args.output_res,
                              load_images=args.output_flags['perf'],
                              views_to_load=args.views)
print(f"Dataset loaded: {dataset.root}/{dataset.name}")

model_path: Path = nets_dir / args.model
model_name = model_path.parent.name
model = mdl.load(model_path, {
    "raymarching_early_stop_tolerance": 0.01,
    # "raymarching_chunk_size_or_sections": [8],
    "perturb_sample": False
})[0].to(device.default()).eval()
model_class = model.__class__.__name__
model_args = model.args
print(f"model: {model_name} ({model_class})")
print("args:", json.dumps(model.args0))

run_dir = model_path.parent
output_dir = run_dir / f"output_{int(model_path.stem.split('_')[-1])}"
output_dataset_id = '%s%s' % (
    dataset.name,
    f'_{args.output_res[1]}x{args.output_res[0]}' if args.output_res else ''
)

if __name__ == "__main__":
    with torch.no_grad():
        # 1. Initialize data loader
        data_loader = DataLoader(dataset, RAYS_PER_BATCH, chunk_max_items=DATA_LOADER_CHUNK_SIZE,
                                 shuffle=False, enable_preload=True,
                                 color=color.from_str(model.args['color']))

        # 3. Test on dataset
        print("Begin test, batch size is %d" % RAYS_PER_BATCH)
        i = 0
        offset = 0
        chns = model.chns('color')
        n = dataset.n_views
        total_pixels = prod([n, *dataset.res])

        out = {}
        if args.output_flags['perf'] or args.output_flags['color']:
            out['color'] = torch.zeros(total_pixels, chns, device=device.default())
        if args.output_flags['diffuse']:
            out['diffuse'] = torch.zeros(total_pixels, chns, device=device.default())
        if args.output_flags['specular']:
            out['specular'] = torch.zeros(total_pixels, chns, device=device.default())
        if args.output_flags['depth']:
            out['depth'] = torch.full([total_pixels, 1], HUGE_FLOAT, device=device.default())
        gt_images = torch.empty_like(out['color']) if dataset.image_path else None
        tot_time = 0
        tot_iters = len(data_loader)

        progress_bar(i, tot_iters, 'Inferring...')
        for _, rays_o, rays_d, extra in data_loader:
            if args.output_flags['perf']:
                test_perf = Perf.Node("Test")
            n_rays = rays_o.size(0)
            idx = slice(offset, offset + n_rays)
            ret = model(rays_o, rays_d, extra_outputs=[key for key in out.keys() if key != 'color'])
            if ret is not None:
                for key in out:
                    out[key][idx][ret['rays_mask']] = ret[key]
            if args.output_flags['perf']:
                test_perf.close()
                torch.cuda.synchronize()
                tot_time += test_perf.duration()
            if gt_images is not None:
                gt_images[idx] = extra['color']
            i += 1
            progress_bar(i, tot_iters, 'Inferring...')
            offset += n_rays

        # 4. Save results
        print('Saving results...')
        output_dir.mkdir(parents=True, exist_ok=True)

        for key in out:
            out[key] = out[key].reshape([n, *dataset.res, *out[key].shape[1:]])
        if 'color' in out:
            out['color'] = out['color'].permute(0, 3, 1, 2)
        if 'diffuse' in out:
            out['diffuse'] = out['diffuse'].permute(0, 3, 1, 2)
        if 'specular' in out:
            out['specular'] = out['specular'].permute(0, 3, 1, 2)

        if args.output_flags['perf']:
            perf_errors = torch.full([n], nan)
            perf_ssims = torch.full([n], nan)
            if gt_images is not None:
                gt_images = gt_images.reshape(n, *dataset.res, chns).permute(0, 3, 1, 2)
                for i in range(n):
                    perf_errors[i] = nn_f.mse_loss(gt_images[i], out['color'][i]).item()
                    perf_ssims[i] = ssim(gt_images[i:i + 1], out['color'][i:i + 1]).item() * 100
            perf_mean_time = tot_time / n
            perf_mean_error = torch.mean(perf_errors).item()
            perf_name = f'perf_{output_dataset_id}_{perf_mean_time:.1f}ms_{perf_mean_error:.2e}.csv'
            # Remove old performance reports
            for file in output_dir.glob(f'perf_{output_dataset_id}*'):
                file.unlink()
            # Save new performance reports
            with (output_dir / perf_name).open('w') as fp:
                fp.write('View, PSNR, SSIM\n')
                fp.writelines([
                    f'{dataset.indices[i]}, '
                    f'{img.mse2psnr(perf_errors[i].item()):.2f}, {perf_ssims[i].item():.2f}\n'
                    for i in range(n)
                ])

        for output_type in ['color', 'diffuse', 'specular']:
            if not args.output_flags[output_type]:
                continue
            if args.output_type == 'video':
                output_file = output_dir / f"{output_dataset_id}_{output_type}.mp4"
                img.save_video(out[output_type], output_file, 30)
            else:
                output_subdir = output_dir / f"{output_dataset_id}_{output_type}"
                output_subdir.mkdir(exist_ok=True)
                img.save(out[output_type],
                         [f'{output_subdir}/{i:0>4d}.png' for i in dataset.indices])

        if args.output_flags['depth']:
            colored_depths = img.colorize_depthmap(out['depth'][..., 0], model_args['sample_range'])
            if args.output_type == 'video':
                output_file = output_dir / f"{output_dataset_id}_depth.mp4"
                img.save_video(colored_depths, output_file, 30)
            else:
                output_subdir = output_dir / f"{output_dataset_id}_depth"
                output_subdir.mkdir(exist_ok=True)
                img.save(colored_depths, [f'{output_subdir}/{i:0>4d}.png' for i in dataset.indices])
            #output_subdir = output_dir / f"{output_dataset_id}_bins"
            #output_dir.mkdir(exist_ok=True)
            #img.save(out['bins'], [f'{output_subdir}/{i:0>4d}.png' for i in dataset.indices])

        if args.time:
            s = "Performance Report ==>\n"
            res = get_perf_result()
            if res is None:
                s += "No available data.\n"
            else:
                for key, val in res.items():
                    path_segs = key.split("/")
                    s += " " * (len(path_segs) - 1) + f"{path_segs[-1]}: {val:.1f}ms\n"
            print(s)
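Based on the argparse definitions above, a typical invocation looks like the following (all paths are placeholders, not taken from the repo):

```
# Hedged usage sketch for test.py:
#   python test.py -m <model_dir>/checkpoint_50.tar -o color depth --views 0-4 <dataset>.json
# -m is resolved relative to the dataset's "_nets" directory,
# --views takes a range written as "start-stop", and --time prints
# a per-stage performance report after inference.
```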
""" """
Clean trained nets (*/model-epoch_#.pth) whose epoch is neither the largest nor a multiple of 50 Clean trained nets (*/checkpoint_#.tar) whose epoch is neither the largest nor a multiple of 10
""" """
import sys import sys
import os import os
sys.path.append(os.path.abspath(sys.path[0] + '/../')) base_dir = os.path.abspath(sys.path[0] + '/../')
sys.path.append(base_dir)
if __name__ == "__main__": if __name__ == "__main__":
for dirpath, _, filenames in os.walk('../data'): root = sys.argv[1] if len(sys.argv) > 1 else f'{base_dir}/data'
epoch_list = [int(filename[12:-4]) for filename in filenames print(f"Clean model files in {root}...")
if filename.startswith("model-epoch_")]
for dirpath, _, filenames in os.walk(root):
epoch_list = [int(filename[11:-4]) for filename in filenames
if filename.startswith("checkpoint_")]
if len(epoch_list) <= 1: if len(epoch_list) <= 1:
continue continue
epoch_list.sort() epoch_list.sort()
for epoch in epoch_list[:-1]: for epoch in epoch_list[:-1]:
if epoch % 50 != 0: if epoch % 10 != 0:
file_to_del = f"{dirpath}/model-epoch_{epoch}.pth" file_to_del = f"{dirpath}/checkpoint_{epoch}.tar"
print(f"Clean model file: {file_to_del}") print(f"Clean model file: {file_to_del}")
os.remove(file_to_del) os.remove(file_to_del)
\ No newline at end of file
print("Finished.")
\ No newline at end of file
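A brief usage sketch (the script name is a placeholder; the optional argument overrides the default data root resolved relative to the script):

```
# Hedged usage sketch:
#   python clean.py          # scans <repo>/data for checkpoint_*.tar files
#   python clean.py runs/    # scans the given directory tree instead
```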
@@ -18,6 +18,6 @@ os.chdir(in_set)
 depthmaps = img.load(img_names)
 depthmaps = torch.floor((depthmaps * 16)) / 16
-misc.create_dir(out_set)
+os.makedirs(out_set, exist_ok=True)
 os.chdir(out_set)
 img.save(depthmaps, img_names)
\ No newline at end of file
@@ -74,7 +74,7 @@ if __name__ == "__main__":
     # Load model
     net, name = load_net(model_file)
-    misc.create_dir(os.path.join(opt.outdir, config.to_id()))
+    os.makedirs(os.path.join(opt.outdir, config.to_id()), exist_ok=True)
     # Export Sampler
     export_net(ExportNet(net), 'msl', {
...
@@ -74,7 +74,7 @@ if __name__ == "__main__":
     # Load model
     net, name = load_net(model_file)
-    misc.create_dir(os.path.join(opt.outdir, config.to_id()))
+    os.makedirs(os.path.join(opt.outdir, config.to_id()), exist_ok=True)
     # Export Sampler
     export_net(Sampler(net), 'sampler', {
...
@@ -54,7 +54,7 @@ if __name__ == "__main__":
     rays_o = torch.empty(batch_size, 3, device=device.default())
     rays_d = torch.empty(batch_size, 3, device=device.default())
-    misc.create_dir(opt.outdir)
+    os.makedirs(opt.outdir, exist_ok=True)
     # Export the model
     outpath = os.path.join(opt.outdir, config.to_id() + ".onnx")
...
@@ -44,7 +44,7 @@ if not opt.output:
     else:
         outdir = f"{dir_path}/export"
         output = os.path.join(outdir, f"{model_file.split('@')[0]}@{batch_size_str}.onnx")
-        misc.create_dir(outdir)
+        os.makedirs(outdir, exist_ok=True)
 else:
     output = opt.output
 outname = os.path.splitext(os.path.split(output)[-1])[0]
...
@@ -172,7 +172,7 @@ print('Dataset loaded. Views:', n_views)
 videodir = os.path.dirname(os.path.abspath(opt.view_file))
-tempdir = '/dev/shm/dvs_tmp/realvideo'
+tempdir = '/dev/shm/dvs_tmp/video'
 videoname = f"{os.path.splitext(os.path.split(opt.view_file)[-1])[0]}_{'stereo' if opt.stereo else 'mono'}"
 gazeout = f"{videodir}/{videoname}_gaze.csv"
 if opt.noCE:
@@ -220,8 +220,8 @@ def add_hint(image, center, right_center=None):
     exit()
-misc.create_dir(os.path.dirname(inferout))
-misc.create_dir(os.path.dirname(hintout))
+os.makedirs(os.path.dirname(inferout), exist_ok=True)
+os.makedirs(os.path.dirname(hintout), exist_ok=True)
 hint_offset = infer_offset = 0
 if not opt.replace:
...
@@ -8,7 +8,7 @@ from utils import misc
 def batch_scale(src, target, size):
-    misc.create_dir(target)
+    os.makedirs(target, exist_ok=True)
     for file_name in os.listdir(src):
         postfix = os.path.splitext(file_name)[1]
         if postfix == '.jpg' or postfix == '.png':
...
@@ -11,7 +11,7 @@ from utils import misc
 def copy_images(src_path, dst_path, n, offset=0):
-    misc.create_dir(os.path.dirname(dst_path))
+    os.makedirs(os.path.dirname(dst_path), exist_ok=True)
     for i in range(n):
         copy(src_path % i, dst_path % (i + offset))
...
from pathlib import Path
import sys
import argparse
import math
import torch
import torchvision.transforms.functional as trans_F

sys.path.append(str(Path(sys.path[0]).parent.absolute()))

from utils import img

parser = argparse.ArgumentParser()
parser.add_argument('-o', '--output', type=str)
parser.add_argument('dir', type=str)
args = parser.parse_args()

data_dir = Path(args.dir)
output_dir = Path(args.output)
output_dir.mkdir(parents=True, exist_ok=True)
files = [file for file in data_dir.glob('*') if file.suffix == '.png' or file.suffix == '.jpg']
outfiles = [output_dir / file.name for file in data_dir.glob('*')
            if file.suffix == '.png' or file.suffix == '.jpg']
images = img.load(files)
print(f"{images.size(0)} images loaded.")
out_images = torch.zeros_like(images)
H, W = images.shape[-2:]
for row in range(H):
    phi = math.pi / H * (row + 0.5)
    length = math.ceil(math.sin(phi) * W * 0.5) * 2
    cols = slice((W - length) // 2, (W + length) // 2)
    out_images[..., row:row + 1, cols] = trans_F.resize(images[..., row:row + 1, :], [1, length])
    sys.stdout.write(f'{row + 1} / {H} processed. \r')
print('')
img.save(out_images, outfiles)
print(f"{images.size(0)} images saved.")
\ No newline at end of file
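The loop above stretches each row to a chord of the sphere: a row at polar angle phi is resized to sin(phi) of the full width, so rows near the equator stay almost full-width while rows near the poles shrink toward zero. A quick standalone check of the width formula (H and W are arbitrary example values):

```
import math

H, W = 4, 100
for row in range(H):
    phi = math.pi / H * (row + 0.5)                   # polar angle at the row center
    length = math.ceil(math.sin(phi) * W * 0.5) * 2   # even target width
    print(row, length)
# Output: 0 40, 1 94, 2 94, 3 40 -- symmetric about the equator
```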
@@ -4,26 +4,33 @@ import os
 import argparse
 import numpy as np
 import torch
+from itertools import product, repeat
+from pathlib import Path
 sys.path.append(os.path.abspath(sys.path[0] + '/../'))
-from utils import misc
 parser = argparse.ArgumentParser()
 parser.add_argument('-o', '--output', type=str, default='train1')
+parser.add_argument("-t", "--trans", type=float)
+parser.add_argument("-v", "--views", type=int)
+parser.add_argument('-g', '--grids', nargs='+', type=int)
 parser.add_argument('dataset', type=str)
 args = parser.parse_args()
+if not args.dataset.endswith(".json"):
+    args.dataset = args.dataset.rstrip("/") + ".json"
+if not args.output.endswith(".json"):
+    args.output = args.output.rstrip("/") + ".json"
-data_desc_path = args.dataset
-data_desc_name = os.path.splitext(os.path.basename(data_desc_path))[0]
-data_dir = os.path.dirname(data_desc_path) + '/'
+in_desc_path = Path(args.dataset)
+in_name = in_desc_path.stem
+root_dir = in_desc_path.parent
+out_desc_path: Path = root_dir / args.output
+out_dir = out_desc_path.with_suffix("")
-with open(data_desc_path, 'r') as fp:
+with open(in_desc_path, 'r') as fp:
     dataset_desc = json.load(fp)
-indices = torch.arange(len(dataset_desc['view_centers'])).view(dataset_desc['samples'])
 idx = 0
 '''
 for i in range(3):
@@ -40,7 +47,7 @@ for i in range(3):
     out_desc['view_rots'] = np.array(dataset_desc['view_rots'])[views].tolist()
     with open(os.path.join(data_dir, f'{out_desc_name}.json'), 'w') as fp:
         json.dump(out_desc, fp, indent=4)
-    misc.create_dir(os.path.join(data_dir, out_desc_name))
+    os.makedirs(os.path.join(data_dir, out_desc_name), exist_ok=True)
     for k in range(len(views)):
         os.symlink(os.path.join('..', dataset_desc['view_file_pattern'] % views[k]),
                    os.path.join(data_dir, out_desc['view_file_pattern'] % views[k]))
@@ -61,26 +68,62 @@ for xi in range(0, 4, 2):
     out_desc['view_rots'] = np.array(dataset_desc['view_rots'])[views].tolist()
     with open(os.path.join(data_dir, f'{out_desc_name}.json'), 'w') as fp:
         json.dump(out_desc, fp, indent=4)
-    misc.create_dir(os.path.join(data_dir, out_desc_name))
+    os.makedirs(os.path.join(data_dir, out_desc_name), exist_ok=True)
     for k in range(len(views)):
         os.symlink(os.path.join('..', dataset_desc['view_file_pattern'] % views[k]),
                    os.path.join(data_dir, out_desc['view_file_pattern'] % views[k]))
     idx += 1
 '''
-from itertools import product
-out_desc_name = args.output
+def extract_by_grid(*grid_indices):
+    indices = torch.arange(len(dataset_desc['view_centers'])).view(dataset_desc['samples'])
+    views = []
+    for idx in product(*grid_indices):
+        views += indices[idx].flatten().tolist()
+    return views
+def extract_by_trans(max_trans, max_views):
+    if max_trans is not None:
+        centers = np.array(dataset_desc['view_centers'])
+        trans = np.linalg.norm(centers, axis=-1)
+        indices = np.nonzero(trans <= max_trans)[0]
+    else:
+        indices = np.arange(len(dataset_desc['view_centers']))
+    if max_views is not None:
+        indices = np.sort(indices[np.random.permutation(indices.shape[0])[:max_views]])
+    return indices.tolist()
+if args.grids:
+    views = extract_by_grid(*repeat(args.grids, 3))  # , [0, 2, 3, 5], [1, 2, 3, 4])
+else:
+    views = extract_by_trans(args.trans, args.views)
+image_path = dataset_desc['view_file_pattern']
+if "/" not in image_path:
+    image_path = in_name + "/" + image_path
+# Save new dataset
 out_desc = dataset_desc.copy()
-out_desc['view_file_pattern'] = f"{out_desc_name}/{dataset_desc['view_file_pattern'].split('/')[-1]}"
+out_desc['view_file_pattern'] = image_path.split('/')[-1]
-views = []
-for idx in product([1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]):  # , [0, 2, 3, 5], [1, 2, 3, 4]):
-    views += indices[idx].flatten().tolist()
 out_desc['samples'] = [len(views)]
 out_desc['views'] = views
 out_desc['view_centers'] = np.array(dataset_desc['view_centers'])[views].tolist()
-out_desc['view_rots'] = np.array(dataset_desc['view_rots'])[views].tolist()
+if 'view_rots' in dataset_desc:
+    out_desc['view_rots'] = np.array(dataset_desc['view_rots'])[views].tolist()
-with open(os.path.join(data_dir, f'{out_desc_name}.json'), 'w') as fp:
+# Write new data desc
+with open(out_desc_path, 'w') as fp:
     json.dump(out_desc, fp, indent=4)
-misc.create_dir(os.path.join(data_dir, out_desc_name))
+# Create symbol links of images
+out_dir.mkdir()
 for k in range(len(views)):
-    os.symlink(os.path.join('..', dataset_desc['view_file_pattern'] % views[k]),
-               os.path.join(data_dir, out_desc['view_file_pattern'] % views[k]))
+    if out_dir.parent.absolute() == root_dir.absolute():
+        os.symlink(Path("..") / (image_path % views[k]),
+                   out_dir / (out_desc['view_file_pattern'] % views[k]))
+    else:
+        os.symlink(root_dir.absolute() / (image_path % views[k]),
+                   out_dir / (out_desc['view_file_pattern'] % views[k]))
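Based on the options added above, the rewritten script can select views either by grid indices or by translation radius; a hedged usage sketch (the script name is a placeholder):

```
# Pick grid indices 1-4 along each of the three sampling axes:
#   python extract_views.py -g 1 2 3 4 -o train1 dataset.json
# Pick at most 64 random views whose centers lie within radius 0.1:
#   python extract_views.py -t 0.1 -v 64 -o train1 dataset.json
```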
train.py 0 → 100644
import argparse
import logging
import os
import json
from pathlib import Path
import sys

import model as mdl
import train
from utils import color
from utils import device
from data.dataset_factory import *
from data.loader import DataLoader
from utils.misc import list_epochs, print_and_log

RAYS_PER_BATCH = 2 ** 16
DATA_LOADER_CHUNK_SIZE = 1e8

parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str,
                    help='Net config files')
parser.add_argument('-e', '--epochs', type=int, default=50,
                    help='Max epochs for train')
parser.add_argument('--perf', type=int, default=0,
                    help='Performance measurement frames (0 for disabling performance measurement)')
parser.add_argument('--prune', type=int, default=5,
                    help='Prune voxels on every # epochs')
parser.add_argument('--split', type=int, default=10,
                    help='Split voxels on every # epochs')
parser.add_argument('--views', type=str,
                    help='Specify the range of views to train')
parser.add_argument('path', type=str,
                    help='Dataset description file')
args = parser.parse_args()

argpath = Path(args.path)
# argpath may be a model path or a data path:
# 1) model path: continue training the specified model
# 2) data path: train a new model on the specified dataset
if argpath.suffix == ".tar":
    args.mdl_path = argpath
else:
    existed_epochs = list_epochs(argpath, "checkpoint_*.tar")
    args.mdl_path = argpath / f"checkpoint_{existed_epochs[-1]}.tar" if existed_epochs else None
if args.mdl_path:
    # Infer the dataset path from the model path.
    # A model path follows the rule: <dataset_dir>/_nets/<dataset_name>/<model_name>/checkpoint_*.tar
    dataset_name = args.mdl_path.parent.parent.name
    dataset_dir = args.mdl_path.parent.parent.parent.parent
    args.data_path = dataset_dir / dataset_name
    args.mdl_path = args.mdl_path.relative_to(dataset_dir)
else:
    args.data_path = argpath
args.views = range(*[int(val) for val in args.views.split('-')]) if args.views else None

dataset = DatasetFactory.load(args.data_path, views_to_load=args.views)
print(f"Dataset loaded: {dataset.root}/{dataset.name}")
os.chdir(dataset.root)

if args.mdl_path:
    # Load an existing model to continue training
    model, states = mdl.load(args.mdl_path)
    model_name = args.mdl_path.parent.name
    model_class = model.__class__.__name__
    model_args = model.args
else:
    # Create a new model from the specified configuration
    with Path(f'{sys.path[0]}/configs/{args.config}.json').open() as fp:
        config = json.load(fp)
    model_name = args.config
    model_class = config['model']
    model_args = config['args']
    model_args['bbox'] = dataset.bbox
    model_args['depth_range'] = dataset.depth_range
    model, states = mdl.create(model_class, model_args), None
model.to(device.default()).train()

run_dir = Path(f"_nets/{dataset.name}/{model_name}")
run_dir.mkdir(parents=True, exist_ok=True)
log_file = run_dir / "train.log"
logging.basicConfig(format='%(asctime)s[%(levelname)s] %(message)s', level=logging.INFO,
                    filename=log_file, filemode='a' if log_file.exists() else 'w')

print_and_log(f"model: {model_name} ({model_class})")
print_and_log(f"args: {json.dumps(model.args0)}")

if __name__ == "__main__":
    # 1. Initialize data loader
    data_loader = DataLoader(dataset, RAYS_PER_BATCH, chunk_max_items=DATA_LOADER_CHUNK_SIZE,
                             shuffle=True, enable_preload=True,
                             color=color.from_str(model.args['color']))

    # 2. Initialize model and trainer
    trainer = train.get_trainer(model, run_dir=run_dir, states=states, perf_frames=args.perf,
                                pruning_loop=args.prune, splitting_loop=args.split)

    # 3. Train
    trainer.train(data_loader, args.epochs)
\ No newline at end of file
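Following the path rule documented in the script, train.py accepts either a dataset or a run directory (all names below are placeholders):

```
# Train a new model from a config:
#   python train.py -c <config> -e 50 data/<dataset>.json
# Resume from the latest checkpoint_*.tar in a run directory:
#   python train.py data/_nets/<dataset>/<config>
```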
import importlib
import os

from model.base import BaseModel
from . import base

# Automatically import any Python files/subpackages in this directory
package_dir = os.path.dirname(__file__)
package = os.path.basename(package_dir)
for file in os.listdir(package_dir):
    path = os.path.join(package_dir, file)
    if file.startswith('_') or file.startswith('.'):
        continue
    if file.endswith('.py') or os.path.isdir(path):
        model_name = file[:-3] if file.endswith('.py') else file
        importlib.import_module(f'{package}.{model_name}')


def get_class(class_name: str) -> type:
    return base.train_classes[class_name]


def get_trainer(model: BaseModel, **kwargs) -> base.Train:
    train_class = get_class(model.trainer)
    return train_class(model, **kwargs)
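For reference, the metaclass registry in base.py means a new trainer only has to subclass Train inside this package to become available through get_trainer; a minimal sketch with a hypothetical MyTrain class:

```
# Minimal sketch (MyTrain is hypothetical): defining the class is enough,
# because BaseTrainMeta records every new subclass in train_classes by name.
from train.base import Train, train_classes

class MyTrain(Train):
    pass  # override _train_iter() etc. as needed

assert train_classes['MyTrain'] is MyTrain
# get_trainer(model) then instantiates MyTrain for any model whose
# `trainer` attribute is the string 'MyTrain'.
```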
import csv
import logging
import sys
import time
import torch
import torch.nn.functional as nn_f
from typing import Dict
from pathlib import Path

import loss
from utils.constants import HUGE_FLOAT
from utils.misc import format_time
from utils.progress_bar import progress_bar
from utils.perf import Perf, checkpoint, enable_perf, perf, get_perf_result
from data.loader import DataLoader
from model.base import BaseModel
from model import save

train_classes = {}


class BaseTrainMeta(type):

    def __new__(cls, name, bases, attrs):
        new_cls = type.__new__(cls, name, bases, attrs)
        train_classes[name] = new_cls
        return new_cls


class Train(object, metaclass=BaseTrainMeta):

    @property
    def perf_mode(self):
        return self.perf_frames > 0

    def __init__(self, model: BaseModel, *,
                 run_dir: Path, states: dict = None, perf_frames: int = 0) -> None:
        super().__init__()
        self.model = model
        self.epoch = 0
        self.iters = 0
        self.run_dir = run_dir
        self.model.train()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-4)
        if states:
            if 'epoch' in states:
                self.epoch = states['epoch']
            if 'iters' in states:
                self.iters = states['iters']
            if 'opti' in states:
                self.optimizer.load_state_dict(states['opti'])

        # For performance measurement
        self.perf_frames = perf_frames
        if self.perf_mode:
            enable_perf()

    def train(self, data_loader: DataLoader, max_epochs: int):
        self.data_loader = data_loader
        self.iters_per_epoch = self.perf_frames or len(data_loader)
        print("Begin training...")
        while self.epoch < max_epochs:
            self.epoch += 1
            self._train_epoch()
            self._save_checkpoint()
        print("Train finished")

    def _save_checkpoint(self):
        # Save the current checkpoint, then drop older ones whose epoch is
        # not a multiple of 10
        save(self.run_dir / f'checkpoint_{self.epoch}.tar', self.model, epoch=self.epoch,
             iters=self.iters, opti=self.optimizer.state_dict())
        for i in range(1, self.epoch):
            if i % 10 != 0:
                (self.run_dir / f'checkpoint_{i}.tar').unlink(missing_ok=True)

    def _show_progress(self, iters_in_epoch: int, loss: Dict[str, float] = {}):
        loss_val = loss.get('val', 0)
        loss_min = loss.get('min', 0)
        loss_max = loss.get('max', 0)
        loss_avg = loss.get('avg', 0)
        iters_per_epoch = self.perf_frames or len(self.data_loader)
        progress_bar(iters_in_epoch, iters_per_epoch,
                     f"Loss: {loss_val:.2e} ({loss_min:.2e}/{loss_avg:.2e}/{loss_max:.2e})",
                     f"Epoch {self.epoch:<3d}",
                     f" {self.run_dir}")

    def _show_perf(self):
        s = "Performance Report ==>\n"
        res = get_perf_result()
        if res is None:
            s += "No available data.\n"
        else:
            for key, val in res.items():
                path_segs = key.split("/")
                s += " " * (len(path_segs) - 1) + f"{path_segs[-1]}: {val:.1f}ms\n"
        print(s)

    @perf
    def _train_iter(self, rays_o: torch.Tensor, rays_d: torch.Tensor,
                    extra: Dict[str, torch.Tensor]) -> float:
        out = self.model(rays_o, rays_d, extra_outputs=['energies', 'speculars'])
        if 'rays_mask' in out:
            extra = {key: value[out['rays_mask']] for key, value in extra.items()}
        checkpoint("Forward")

        self.optimizer.zero_grad()
        loss_val = loss.mse_loss(out['color'], extra['color'])
        if self.model.args.get('density_regularization_weight'):
            loss_val += loss.cauchy_loss(out['energies'],
                                         s=self.model.args['density_regularization_scale']) \
                * self.model.args['density_regularization_weight']
        if self.model.args.get('specular_regularization_weight'):
            loss_val += loss.cauchy_loss(out['speculars'],
                                         s=self.model.args['specular_regularization_scale']) \
                * self.model.args['specular_regularization_weight']
        checkpoint("Compute loss")

        loss_val.backward()
        checkpoint("Backward")

        self.optimizer.step()
        checkpoint("Update")
        return loss_val.item()

    def _train_epoch(self):
        iters_in_epoch = 0
        loss_min = HUGE_FLOAT
        loss_max = 0
        loss_avg = 0
        train_epoch_node = Perf.Node("Train Epoch")
        self._show_progress(iters_in_epoch, loss={'val': 0, 'min': 0, 'max': 0, 'avg': 0})
        for idx, rays_o, rays_d, extra in self.data_loader:
            loss_val = self._train_iter(rays_o, rays_d, extra)
            loss_min = min(loss_min, loss_val)
            loss_max = max(loss_max, loss_val)
            loss_avg = (loss_avg * iters_in_epoch + loss_val) / (iters_in_epoch + 1)
            self.iters += 1
            iters_in_epoch += 1
            self._show_progress(iters_in_epoch, loss={
                'val': loss_val,
                'min': loss_min,
                'max': loss_max,
                'avg': loss_avg
            })
            if self.perf_mode and iters_in_epoch >= self.perf_frames:
                self._show_perf()
                exit()
        train_epoch_node.close()
        torch.cuda.synchronize()
        epoch_dur = train_epoch_node.duration() / 1000
        logging.info(f"Epoch {self.epoch} spent {format_time(epoch_dur)} "
                     f"(Avg. {format_time(epoch_dur / self.iters_per_epoch)}/iter). "
                     f"Loss is {loss_min:.2e}/{loss_avg:.2e}/{loss_max:.2e}")

    def _train_epoch_debug(self):  # TBR
        iters_in_epoch = 0
        loss_min = HUGE_FLOAT
        loss_max = 0
        loss_avg = 0
        self._show_progress(iters_in_epoch, loss={'val': 0, 'min': 0, 'max': 0, 'avg': 0})
        indices = []
        debug_data = []
        for idx, rays_o, rays_d, extra in self.data_loader:
            out = self.model(rays_o, rays_d, extra_outputs=['layers', 'weights'])
            loss_val = nn_f.mse_loss(out['color'], extra['color']).item()
            loss_min = min(loss_min, loss_val)
            loss_max = max(loss_max, loss_val)
            loss_avg = (loss_avg * iters_in_epoch + loss_val) / (iters_in_epoch + 1)
            self.iters += 1
            iters_in_epoch += 1
            self._show_progress(iters_in_epoch, loss={
                'val': loss_val,
                'min': loss_min,
                'max': loss_max,
                'avg': loss_avg
            })
            indices.append(idx)
            debug_data.append(torch.cat([
                extra['view_idx'][..., None],
                extra['pix_idx'][..., None],
                rays_d,
                #out['samples'].pts[:, 215:225].reshape(idx.size(0), -1),
                #out['samples'].dirs[:, :3].reshape(idx.size(0), -1),
                #out['samples'].voxel_indices[:, 215:225],
                out['states'].densities[:, 210:230].detach().reshape(idx.size(0), -1),
                out['states'].energies[:, 210:230].detach().reshape(idx.size(0), -1)
                # out['color'].detach()
            ], dim=-1))
            # states: VolumnRenderer.States = out['states']  # TBR
        indices = torch.cat(indices, dim=0)
        debug_data = torch.cat(debug_data, dim=0)
        indices, sort = indices.sort()
        debug_data = debug_data[sort]
        name = "rand.csv" if self.data_loader.shuffle else "seq.csv"
        with (self.run_dir / name).open("w") as fp:
            csv_writer = csv.writer(fp)
            csv_writer.writerows(torch.cat([indices[:20, None], debug_data[:20]], dim=-1).tolist())
        return
        # Unreachable debug code below, intentionally kept for reference (TBR)
        with (self.run_dir / 'states.csv').open("w") as fp:
            csv_writer = csv.writer(fp)
            for chunk_info in states.chunk_infos:
                csv_writer.writerow(
                    [*chunk_info['range'], chunk_info['hits'], chunk_info['core_i']])
                if chunk_info['hits'] > 0:
                    csv_writer.writerows(torch.cat([
                        chunk_info['samples'].pts,
                        chunk_info['samples'].dirs,
                        chunk_info['samples'].voxel_indices[:, None],
                        chunk_info['colors'],
                        chunk_info['energies']
                    ], dim=-1).tolist())
                csv_writer.writerow([])