"""Interactive Dash viewer for a trained ray-based model.

The model checkpoint at ``model_path`` is loaded once at import time. The app
renders the view selected by five sliders, shows a weighted sum over a chosen
range of the model's ``'layers'`` output, and, for a clicked pixel, plots the
per-sample weights and densities along that pixel's ray plus an image of the
point ``ray_o + ray_d`` seen from a grid of origins on the z=0 plane.
"""
import os
import torch
import json
import dash
import dash_core_components as dcc
import dash_html_components as html
import plotly.express as px
import numpy as np
from pathlib import Path
from dash.dependencies import Input, Output
from dash.exceptions import PreventUpdate

# The app only runs inference; disable autograd globally.
torch.set_grad_enabled(False)

from utils import device
from utils import view
from utils import img
from utils import misc
# Imported as ``mdl`` because the name ``model`` is rebound to the loaded
# network further down.
import model as mdl

datadir = Path('data/__object/christmas')
data_desc_file = 'test.json'
net_config = 'fovea@snerffast4-rgb_e6_fc512x4_d2.00-50.00_s64_~p'
model_path = datadir / 'snerf_voxels/checkpoint_50.tar'
fov = 40
res = (256, 256)          # resolution of the main rendered view
pix_img_res = (256, 256)  # resolution of the "Pixel View" image
center = (0, 0)


def load_data_desc(data_desc_file) -> 'tuple[torch.Tensor | None, view.Trans]':
    """Load the view range and view transforms from a dataset description JSON.

    Args:
        data_desc_file: file name of the description, relative to ``datadir``.

    Returns:
        ``(view_range, views)``. ``view_range`` is a tensor stacking the
        ``range.min`` and ``range.max`` entries (row 0 = min, row 1 = max), or
        ``None`` when the file has no ``'range'`` key. ``views`` is a
        ``view.Trans`` built from ``'view_centers'`` (reshaped to (N, 3)) and
        ``'view_rots'`` (reshaped to (N, 3, 3)).
    """
    # ``datadir`` is a Path; ``Path + str`` raises TypeError, so join with ``/``.
    with open(datadir / data_desc_file, 'r', encoding='utf-8') as file:
        data_desc = json.load(file)

    view_range = torch.tensor([data_desc['range']['min'], data_desc['range']['max']]) \
        if 'range' in data_desc else None
    view_centers = torch.tensor(
        data_desc['view_centers'], device=device.default()).view(-1, 3)

    rots = data_desc['view_rots']
    if len(rots[0]) == 2:
        # Two-angle entries are converted to rotation matrices; the sign of
        # rot[1] is flipped unless the file sets 'gl_coord'.
        gl_coord = data_desc.get('gl_coord')
        rots = [
            view.euler_to_matrix([rot[1] if gl_coord else -rot[1], rot[0], 0])
            for rot in rots
        ]
    view_rots = torch.tensor(rots, device=device.default()).view(-1, 3, 3)  # (N, 3, 3)
    return view_range, view.Trans(view_centers, view_rots)


cam = view.Camera({
    'fov': fov,
    'cx': 0.5,
    'cy': 0.5,
    'normalized': True
}, res, device=device.default())
model, _ = mdl.load(model_path, {
    "perturb_sample": False
})

# Global states shared between callbacks
x = y = None          # last clicked pixel in the main image
test_view = None      # view.Trans of the currently rendered view
layers = None         # model 'layers' output of the current view, (*res, L, 4)
layer_weights = None  # model 'weight' output of the current view, (*res, L)

view_range, views = load_data_desc(data_desc_file)
if view_range is None:
    raise ValueError(
        f"{data_desc_file} has no 'range' entry; it is required to build the view sliders")
view_range_size = view_range[1] - view_range[0]
print('%d Views loaded.' % views.size()[0])

fig = px.imshow(np.empty([res[0], res[1], 3]))
layer_img = px.imshow(np.empty([res[0], res[1], 3]))
pix_img = px.imshow(np.empty([pix_img_res[0], pix_img_res[1], 3]))
plot_alpha = px.scatter(x=[0, 1, 2], y=[2, 0, 1])
plot_density = px.scatter(x=[0, 1, 2], y=[2, 0, 1])

# One slider per view dimension (tx, ty, tz, rx, ry), each spanning the
# corresponding column of view_range and starting at its midpoint.
view_sliders = [
    dcc.Slider(className='slider',
               id=f'view-slider-{i}',
               min=view_range[0, i].item(),
               max=view_range[1, i].item(),
               step=view_range_size[i].item() / 100,
               marks={
                   view_range[0, i].item(): f'{view_range[0, i].item()}',
                   view_range[1, i].item(): f'{view_range[1, i].item()}'
               },
               value=0.5 * (view_range[1, i] + view_range[0, i]).item(),
               tooltip={'always_visible': True})
    for i in range(5)
]

app = dash.Dash(__name__, external_stylesheets=['https://codepen.io/chriddyp/pen/bWLwgP.css'])
app.layout = html.Div([
    html.H3("Drag and draw annotations"),
    html.Div(className='row', children=[
        html.Div(className='six columns', children=[html.Label("View")] + view_sliders),
        html.Div(className='three columns', children=[dcc.Graph(id='image', figure=fig)]),
        html.Div(className='three columns', children=[
            dcc.Graph(id='layer-image', figure=layer_img),
            dcc.RangeSlider(id='layer-slider', min=0, max=63, step=1, value=[0, 63],
                            tooltip={})
        ])
    ]),
    html.Div(className='row', children=[
        html.Div(className='six columns', children=[
            dcc.Graph(id='scatter', figure=plot_alpha),
            dcc.Graph(id='scatter1', figure=plot_density),
            # step=None restricts the slider to the marks, so the default
            # value must be one of them (was 33, which is not a mark).
            dcc.Slider(id='samples-slider', min=4, max=128, step=None,
                       marks={
                           4: '4',
                           8: '8',
                           16: '16',
                           32: '32',
                           64: '64',
                           128: '128',
                       },
                       value=32,
                       updatemode='drag')
        ]),
        html.Div(className='six columns', children=[
            html.Label("Pixel View"),
            dcc.Graph(id='pix-image', figure=pix_img)
        ])
    ]),
])


def _styled_sample_plot(xs, ys, marker_colors, color_scale):
    """Build a lines+markers scatter of per-sample values with the shared styling."""
    plot = px.scatter(x=xs, y=ys, color=marker_colors, color_continuous_scale=color_scale)
    plot.update_traces(mode='lines+markers')
    plot.update_xaxes(showgrid=False)
    plot.update_yaxes(type='linear')
    plot.update_layout(height=225, margin={'l': 20, 'b': 30, 'r': 10, 't': 10})
    return plot


def plot_alpha_and_density(ray_o, ray_d):
    """Plot the per-sample weights and densities along one ray.

    Both plots use inverse sample depth (``1 / depth``) on the x axis. Each
    marker is colored with the RGB of the corresponding sample taken from the
    first three channels of the model's ``'layers'`` output.

    Returns:
        ``(plot_alpha, plot_density)`` plotly figures.
    """
    ret = model(ray_o, ray_d, extra_outputs=['depth', 'layers'])
    # NOTE(review): 'sample_densities', 'sample_depths' and 'weight' are read
    # without being listed in extra_outputs; confirm the model always returns them.
    colors = ret['layers'][..., :3]
    densities = ret['sample_densities']
    depths = ret['sample_depths']
    alphas = ret['weight']

    scatter_x = misc.torch2np(1 / depths[0])
    scatter_y = misc.torch2np(alphas[0])
    scatter_y1 = misc.torch2np(densities[0])
    scatter_color = misc.torch2np(colors[0] * 255)

    # Color index i maps onto the i-th entry of the per-sample color scale.
    marker_colors = list(range(scatter_color.shape[0]))
    marker_colors_str = ['rgb(%d,%d,%d)' % (c[0], c[1], c[2]) for c in scatter_color]

    return (_styled_sample_plot(scatter_x, scatter_y, marker_colors, marker_colors_str),
            _styled_sample_plot(scatter_x, scatter_y1, marker_colors, marker_colors_str))


def plot_pixel_image(ray_o, ray_d, r=1):
    """Render the point ``ray_o + ray_d * r`` as seen from a grid of origins.

    The origins lie on the z=0 plane and span the first two dimensions of
    ``view_range``. Every ray is aimed at the target point and normalized
    before being passed to the model.

    Returns:
        A plotly figure of the rendered ``pix_img_res`` image.
    """
    with torch.no_grad():
        pixel_point = ray_o + ray_d * r
        # presumably grid2d(normalize=True) yields coordinates in [0, 1] — TODO confirm
        rays_o = torch.cat([
            misc.grid2d(*pix_img_res, normalize=True) * view_range_size[:2] + view_range[0, :2],
            torch.zeros(*pix_img_res, 1)
        ], dim=-1).to(device.default())
        rays_d = pixel_point - rays_o
        rays_d /= rays_d.norm(dim=-1, keepdim=True)
        image = model(rays_o.view(-1, 3), rays_d.view(-1, 3))['color'] \
            .view(1, *pix_img_res, -1).permute(0, 3, 1, 2)
        fig = px.imshow(img.torch2np(image)[0])
        return fig


def draw_cross(fig, x, y, size=5):
    """Mark ``(x, y)`` on ``fig`` with a cross, replacing previously drawn shapes.

    Args:
        fig: plotly figure to draw on (modified in place).
        x, y: data coordinates of the cross centre.
        size: half-length of each arm, in data units.
    """
    # Remove earlier shapes instead of only hiding them, so repeated clicks do
    # not keep growing the figure that is sent to the browser.
    fig.layout.shapes = []
    line = {'color': 'LightSeaGreen', 'width': 3}
    fig.add_shape(type="line", xref="x", yref="y",
                  x0=x, y0=y - size, x1=x, y1=y + size, line=line)
    fig.add_shape(type="line", xref="x", yref="y",
                  x0=x - size, y0=y, x1=x + size, y1=y, line=line)


def render_view(tx, ty, tz, rx, ry):
    """Render the view at translation (tx, ty, tz) and rotation (rx, ry).

    Side effects: updates the globals ``test_view``, ``layers`` and
    ``layer_weights`` so that later layer/pixel queries use this view.

    Raises:
        PreventUpdate: if the sliders have no value yet (``tx is None``).
    """
    global test_view, layers, layer_weights
    if tx is None:
        raise PreventUpdate
    with torch.no_grad():
        test_view = view.Trans(
            torch.tensor([[tx, ty, tz]], device=device.default()),
            torch.tensor(view.euler_to_matrix([ry, rx, 0]),
                         device=device.default()).view(-1, 3, 3)
        )
        rays_o, rays_d = cam.get_global_rays(test_view, True)
        ret = model(rays_o.view(-1, 3), rays_d.view(-1, 3),
                    extra_outputs=['layers', 'weights'])
        image = ret['color'].view(1, *res, 3).permute(0, 3, 1, 2)
        # NOTE(review): 'weights' is requested but 'weight' is read; confirm
        # against the model's output keys.
        layers = ret['layers'].view(*res, -1, 4)
        layer_weights = ret['weight'].view(*res, -1)
        fig = px.imshow(img.torch2np(image)[0])
        return fig


def render_layer(layer):
    """Render the weighted RGB sum over an inclusive range of layers.

    Args:
        layer: ``[first, last]`` value of the layer RangeSlider (both ends
            inclusive), or ``None``.

    Returns:
        A plotly figure, or ``None`` if no range is given or no view has been
        rendered yet.
    """
    if layer is None or layers is None or layer_weights is None:
        return None
    first, last = layer
    # RangeSlider bounds are inclusive, so include ``last``. The weights need a
    # trailing axis to broadcast against the per-layer RGB channels.
    selected = slice(first, last + 1)
    layer_data = torch.sum(
        layers[..., selected, :3] * layer_weights[..., selected].unsqueeze(-1), dim=-2)
    fig = px.imshow(img.torch2np(layer_data))
    return fig


def view_pixel(fig, x, y, samples):
    """Inspect pixel ``(x, y)`` of the current view.

    Draws a cross on ``fig`` and builds the per-ray plots for that pixel.
    ``samples`` is accepted for the callback wiring but is currently unused.

    Returns:
        ``(fig, pix_img, plot_alpha, plot_density)``, or ``None`` when no pixel
        has been selected or no view has been rendered yet.
    """
    if x is None or y is None or test_view is None:
        return None
    p = torch.tensor([x, y], device=device.default())
    ray_d = test_view.trans_vector(cam.unproj(p))
    ray_o = test_view.t
    draw_cross(fig, x, y)
    alpha_fig, density_fig = plot_alpha_and_density(ray_o, ray_d)
    pixel_fig = plot_pixel_image(ray_o, ray_d)
    return fig, pixel_fig, alpha_fig, density_fig


@app.callback(
    [Output('image', 'figure'),
     Output('layer-image', 'figure'),
     Output('pix-image', 'figure'),
     Output('scatter', 'figure'),
     Output('scatter1', 'figure')],
    [Input(f'view-slider-{i}', 'value') for i in range(5)] +
    [Input('image', 'clickData'),
     Input('samples-slider', 'value'),
     Input('layer-slider', 'value')]
)
def callback(tx, ty, tz, rx, ry, clickData, samples, layer):
    """Dispatch on the triggering input and return all five figures.

    - initial call or a view slider: re-render the view and the layer image;
    - layer slider: re-render only the layer image;
    - anything else (image click, samples slider): re-inspect the last
      clicked pixel.
    """
    global x, y, fig, layer_img, pix_img, plot_alpha, plot_density
    ctx = dash.callback_context
    # Dash may report the initial call as an empty/falsy ``triggered`` list or
    # with the placeholder prop_id '.'; treat both as the initial render.
    trigger = ctx.triggered[0]['prop_id'] if ctx.triggered else '.'
    if trigger == '.' or trigger.startswith('view-slider'):
        fig = render_view(tx, ty, tz, rx, ry)
        new_layer_img = render_layer(layer)
        if new_layer_img is not None:
            layer_img = new_layer_img
    elif trigger == 'layer-slider.value':
        new_layer_img = render_layer(layer)
        if new_layer_img is not None:
            layer_img = new_layer_img
    else:
        if clickData:
            x = clickData['points'][0]['x']
            y = clickData['points'][0]['y']
        ret = view_pixel(fig, x, y, samples)
        if ret is not None:
            fig, pix_img, plot_alpha, plot_density = ret
    return fig, layer_img, pix_img, plot_alpha, plot_density


if __name__ == '__main__':
    app.run_server(debug=True)