Commit dcba5844 authored by Nianchen Deng

update run_spherical_view_syn.py

parent 2a1d7973
%% Cell type:code id: tags:
```
import sys
import os
import json  # used by load_views below
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

rootdir = os.path.abspath(sys.path[0] + '/../')
sys.path.append(rootdir)
torch.cuda.set_device(0)
print("Set CUDA:%d as current device." % torch.cuda.current_device())
torch.autograd.set_grad_enabled(False)

from data.spherical_view_syn import *
from configs.spherical_view_syn import SphericalViewSynConfig
from utils import netio
from utils import img
from utils import device
from utils import misc  # used by the save branches in later cells
from utils.view import *
from components.fnr import FoveatedNeuralRenderer


def load_net(path):
    # Rebuild the network from the config id encoded in the filename.
    config = SphericalViewSynConfig()
    config.from_id(os.path.splitext(path)[0])
    config.SAMPLE_PARAMS['perturb_sample'] = False
    net = config.create_net().to(device.default())
    netio.load(path, net)
    return net


def find_file(prefix):
    # Return the first file in the working directory matching the prefix.
    for path in os.listdir():
        if path.startswith(prefix):
            return path
    return None


def load_views(data_desc_file) -> Trans:
    # Read view centers and rotations from a dataset description JSON.
    with open(data_desc_file, 'r', encoding='utf-8') as file:
        data_desc = json.loads(file.read())
    view_centers = torch.tensor(
        data_desc['view_centers'], device=device.default()).view(-1, 3)
    view_rots = torch.tensor(
        data_desc['view_rots'], device=device.default()).view(-1, 3, 3)
    return Trans(view_centers, view_rots)


def plot_images(images):
    # Show the three foveated layers side by side, then the blended frame.
    plt.figure(figsize=(12, 4))
    plt.subplot(131)
    img.plot(images['layers_img'][0])
    plt.subplot(132)
    img.plot(images['layers_img'][1])
    plt.subplot(133)
    img.plot(images['layers_img'][2])
    plt.figure(figsize=(12, 12))
    img.plot(images['blended'])


scenes = {
    'classroom': 'classroom_all',
    'stones': 'stones_all',
    'barbershop': 'barbershop_all',
    'lobby': 'lobby_all'
}

fov_list = [20, 45, 110]
res_list = [(256, 256), (256, 256), (256, 230)]
res_full = (1600, 1440)
```
%% Output
Set CUDA:0 as current device.
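
The three layers trade field of view against angular resolution: the fovea covers a narrow 20° cone, while the outer layer spans 110° at a similar pixel count. As a rough sanity check (illustrative only, and assuming each entry in `fov_list` spans the first dimension of the matching entry in `res_list`), the per-layer angular resolution is:

```
# Illustrative check, not part of the original notebook: approximate
# degrees per pixel for each foveated layer (fovea, mid, periph).
for fov, res in zip([20, 45, 110], [(256, 256), (256, 256), (256, 230)]):
    print('%3d deg layer at %s: %.3f deg/px' % (fov, res, fov / res[0]))
```

The fovea thus renders at roughly 5.5x the angular density of the outer layer (0.078 vs. 0.430 deg/px), which is consistent with the smaller periph network (fc256 vs. fc512) seen in the checkpoint names below.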
%% Cell type:code id: tags:
```
scene = 'barbershop'

os.chdir(f'{rootdir}/data/__new/{scenes[scene]}')
print('Change working directory to ', os.getcwd())

fovea_net = load_net(find_file('fovea'))
periph_net = load_net(find_file('periph'))
renderer = FoveatedNeuralRenderer(fov_list, res_list,
                                  nn.ModuleList([fovea_net, periph_net, periph_net]),
                                  res_full, using_mask=False, device=device.default())
```
%% Output
Change working directory to /home/dengnc/dvs/data/__new/barbershop_all
Load net from fovea@snerffast4-rgb_e6_fc512x4_d1.20-6.00_s64_~p.pth ...
Load net from periph@snerffast4-rgb_e6_fc256x4_d1.20-6.00_s64_~p.pth ...
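
`find_file` returns the first file whose name starts with the given prefix, so each scene directory is expected to contain exactly one `fovea@...` and one `periph@...` checkpoint. `load_net` then reconstructs the network from the filename alone, since the config id is the filename stem. A hypothetical walk-through using the checkpoint name from the output above (the exact fields `from_id` parses are an assumption here):

```
# Hypothetical illustration of load_net's filename handling.
path = 'fovea@snerffast4-rgb_e6_fc512x4_d1.20-6.00_s64_~p.pth'
config_id = os.path.splitext(path)[0]  # strip the '.pth' extension
print(config_id)  # fovea@snerffast4-rgb_e6_fc512x4_d1.20-6.00_s64_~p
# config.from_id(config_id) would restore the architecture encoded in
# the id, e.g. the fc512x4 fully-connected layout (assumption).
```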
%% Cell type:code id: tags:
```
params = {
    'classroom': [
        [0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, -53, 0, 0, 0],
        [0, 0, 0, 20, -20, 0, 0]
    ],
    'stones': [
        [0, 0, 0, 0, 10, -300, -50],
        [0, 0, 0, 0, 10, 200, -50]
    ],
    'barbershop': [
        [0, 0, 0, 0, 0, 0, 0],
        #[0, 0, 0, 20, 0, -300, 50],
        #[0, 0, 0, -140, -30, 150, -250],
        #[0, 0, 0, -60, -30, 75, -125],
    ],
    'lobby': [
        #[0, 0, 0, 0, 0, 75, 0],
        #[0, 0, 0, 0, 0, 5, 150],
        [0, 0, 0, -120, 0, 75, 50],
    ]
}

for i, param in enumerate(params[scene]):
    view = Trans(torch.tensor(param[:3], device=device.default()),
                 torch.tensor(euler_to_matrix([-param[4], param[3], 0]),
                              device=device.default()).view(3, 3))
    images = renderer(view, param[-2:])
    if False:
        outputdir = '../__demo/mono/'
        misc.create_dir(outputdir)
        img.save(images['layers_img'][0], f'{outputdir}{scene}_{i}_fovea.png')
        img.save(images['layers_img'][1], f'{outputdir}{scene}_{i}_mid.png')
        img.save(images['layers_img'][2], f'{outputdir}{scene}_{i}_periph.png')
        img.save(images['blended'], f'{outputdir}{scene}_{i}_blended.png')
    else:
        plot_images(images)
```
%% Output
/home/dengnc/miniconda3/lib/python3.8/site-packages/torch/nn/functional.py:3828: UserWarning: Default grid_sample and affine_grid behavior has changed to align_corners=False since 1.3.0. Please specify align_corners=True if the old behavior is desired. See the documentation of grid_sample for details.
warnings.warn(
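
Each seven-element entry in `params` is consumed by the loop as `[x, y, z, yaw, pitch, gaze_x, gaze_y]`: the first three values become the view position, the next two feed `euler_to_matrix` as rotation angles, and the last two are handed to the renderer as the gaze (foveation) center. This reading is inferred from the indexing (`param[:3]`, `param[3]`, `param[4]`, `param[-2:]`), not documented in the notebook; a minimal sketch using the lobby entry:

```
# Inferred parameter layout (assumption based on the loop above).
x, y, z, yaw, pitch, gaze_x, gaze_y = 0, 0, 0, -120, 0, 75, 50
view = Trans(torch.tensor([x, y, z], device=device.default()),
             torch.tensor(euler_to_matrix([-pitch, yaw, 0]),
                          device=device.default()).view(3, 3))
images = renderer(view, (gaze_x, gaze_y))  # gaze center steers the fovea
```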
%% Cell type:code id: tags:
```
# Load Dataset
views = load_views('train.json')
print('Dataset loaded.')
print('views:', views.size())

for view_idx in range(views.size()[0]):
    center = (0, 0)
    test_view = views.get(view_idx)
    render_view(test_view, center)
    '''
    images = gen(center, test_view)
    outputdir = '../__2_demo/layer_blend/'
    misc.create_dir(outputdir)
    for key in images:
        img.save(images[key], outputdir + '%s_view%04d_%s.png' % (scene, view_idx, key))
    '''
    '''
    images = gen(
        center, test_view,
        mono_trans=Trans(test_view.trans_point(
            torch.tensor([0.03, 0, 0], device=device.default())
        ), test_view.r))
    outputdir = '../__2_demo/output_mono/ref_as_right_eye/'
    misc.create_dir(outputdir)
    for key in images:
        key = 'blended'
        img.save(images[key], outputdir + '%s_view%04d_%s.png' % (scene, view_idx, key))
    '''
    '''
    left_images = gen(center,
                      Trans(
                          test_view.trans_point(
                              torch.tensor([-0.03, 0, 0], device=device.default())
                          ),
                          test_view.r),
                      mono_trans=test_view)
    right_images = gen(center, Trans(
        test_view.trans_point(
            torch.tensor([0.03, 0, 0], device=device.default())
        ), test_view.r), mono_trans=test_view)
    outputdir = '../__2_demo/mono_periph/stereo/'
    misc.create_dir(outputdir)
    key = 'blended'
    img.save(left_images[key], outputdir + '%s_view%04d_%s_l.png' % (scene, view_idx, key))
    img.save(right_images[key], outputdir + '%s_view%04d_%s_r.png' % (scene, view_idx, key))
    '''
```
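
The commented-out variants in this cell build stereo pairs by shifting the camera by ±0.03 scene units along its local x axis (about a 6 cm baseline if the scene is metric) while passing the unshifted reference view as `mono_trans`, apparently so the periphery can be rendered once from a single monocular viewpoint. A sketch of just the eye-pose construction, using only helpers already imported in this notebook:

```
# Sketch only: derive left/right eye poses from a reference view.
# Assumes test_view is a Trans and 0.03 is half the interpupillary distance.
ipd_half = 0.03
left_eye = Trans(test_view.trans_point(
    torch.tensor([-ipd_half, 0, 0], device=device.default())), test_view.r)
right_eye = Trans(test_view.trans_point(
    torch.tensor([ipd_half, 0, 0], device=device.default())), test_view.r)
```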
...
@@ -406,7 +406,8 @@ def test():
             out[key][global_idx] = ret[key]
         if args.output_flags['perf']:
             perf_times[i] = perf.checkpoint()
-        progress_bar(i, n, 'Inferring...')
+        if not args.log_redirect:
+            progress_bar(i, n, 'Inferring...')
         i += 1
         global_offset += n_rays
...