# dash_test.py: Dash app for interactively inspecting the renders of a trained model.
import os
import json
from pathlib import Path

import numpy as np
import torch
import dash
import dash_core_components as dcc
import dash_html_components as html
import plotly.express as px
from dash.dependencies import Input, Output
from dash.exceptions import PreventUpdate
# from skimage import data

torch.set_grad_enabled(False)

from utils import device
from utils import view
from utils import img
from utils import misc
import model
import model as mdl


# Dataset directory and the JSON file (inside it) describing the view range and poses.
datadir = Path('data/__object/christmas')
data_desc_file = 'test.json'
# NOTE(review): not referenced by the code visible here; presumably records the
# configuration of the checkpoint at ``model_path``.
net_config = 'fovea@snerffast4-rgb_e6_fc512x4_d2.00-50.00_s64_~p'
model_path = datadir / 'snerf_voxels/checkpoint_50.tar'
# Field of view passed to ``view.Camera`` (presumably degrees -- confirm).
fov = 40
# Resolution of the main rendered view.
res = (256, 256)
# Resolution of the "Pixel View" image built by ``plot_pixel_image``.
pix_img_res = (256, 256)
# NOTE(review): not referenced by the code visible here.
center = (0, 0)


def load_data_desc(data_desc_file) -> 'tuple[torch.Tensor | None, view.Trans]':
    """Load the view range and the view poses from a JSON description file.

    :param data_desc_file: name of the description file, relative to ``datadir``
    :return: ``(view_range, views)``. ``view_range`` is a tensor stacking
             ``range.min`` (row 0) and ``range.max`` (row 1), or ``None`` when the
             file has no ``range`` entry. ``views`` is a ``view.Trans`` built from
             ``view_centers`` (reshaped to (N, 3)) and ``view_rots`` (N, 3, 3).
    """
    # ``datadir`` is a ``Path``: join with ``/`` (``Path + str`` raises TypeError).
    with open(datadir / data_desc_file, 'r', encoding='utf-8') as file:
        data_desc = json.load(file)

    view_range = None
    if 'range' in data_desc:
        view_range = torch.tensor([data_desc['range']['min'], data_desc['range']['max']])

    view_centers = torch.tensor(
        data_desc['view_centers'], device=device.default()).view(-1, 3)

    rots = data_desc['view_rots']
    if len(rots[0]) == 2:
        # Two numbers per view are Euler angles; convert them to rotation matrices.
        # The sign of rot[1] is flipped unless the file sets ``gl_coord``.
        rots = [
            view.euler_to_matrix([rot[1] if data_desc.get('gl_coord') else -rot[1], rot[0], 0])
            for rot in rots
        ]
    view_rots = torch.tensor(rots, device=device.default()).view(-1, 3, 3)  # (N, 3, 3)
    return view_range, view.Trans(view_centers, view_rots)
cam = view.Camera({
    'fov': fov,
    'cx': 0.5,
    'cy': 0.5,
    'normalized': True
}, res, device=device.default())
# ``mdl`` is the ``model`` package; the loaded network is bound to the name ``model``.
model, _ = mdl.load(model_path, {
    "perturb_sample": False
})

# Global states, mutated by the Dash callbacks below
x = y = None  # last clicked pixel of the main image
test_view = None  # pose of the currently rendered view (view.Trans)
layers = None  # per-pixel layers of the current view, shape (*res, n_layers, 4)
layer_weights = None  # per-pixel layer weights of the current view, shape (*res, n_layers)

view_range, views = load_data_desc(data_desc_file)
if view_range is None:
    # The view sliders are built from the range, so it is required here.
    raise ValueError(f"'{datadir / data_desc_file}' has no 'range' entry")
view_range_size = view_range[1] - view_range[0]
print('%d Views loaded.' % views.size()[0])

# Placeholder figures, replaced by the callback once views are rendered.
# ``np.zeros`` (not ``np.empty``) so placeholders are black instead of showing
# uninitialised memory.
fig = px.imshow(np.zeros([res[0], res[1], 3]))
layer_img = px.imshow(np.zeros([res[0], res[1], 3]))
pix_img = px.imshow(np.zeros([pix_img_res[0], pix_img_res[1], 3]))
plot_alpha = px.scatter(x=[0, 1, 2], y=[2, 0, 1])
plot_density = px.scatter(x=[0, 1, 2], y=[2, 0, 1])
# One slider per pose component (tx, ty, tz, rx, ry, in callback argument order),
# spanning ``view_range`` in 100 steps and starting at the middle of the range.
view_sliders = [
    dcc.Slider(className='slider', id=f'view-slider-{i}',
               min=view_range[0, i].item(), max=view_range[1, i].item(),
               step=view_range_size[i].item() / 100,
               marks={
                   view_range[0, i].item(): f'{view_range[0, i].item()}',
                   view_range[1, i].item(): f'{view_range[1, i].item()}'
               },
               value=0.5 * (view_range[1, i] + view_range[0, i]).item(),
               tooltip={'always_visible': True})
    for i in range(5)
]
app = dash.Dash(__name__, external_stylesheets=['https://codepen.io/chriddyp/pen/bWLwgP.css'])
app.layout = html.Div([
    html.H3("Drag and draw annotations"),
    html.Div(className='row', children=[
        html.Div(className='six columns', children=[html.Label("View")] + view_sliders),
        html.Div(className='three columns', children=[dcc.Graph(id='image', figure=fig)]),
        html.Div(className='three columns', children=[
            dcc.Graph(id='layer-image', figure=layer_img),
            # [first, last] layer indices passed to render_layer.
            dcc.RangeSlider(id='layer-slider', min=0, max=63, step=1, value=[0, 63], tooltip={})
        ])
    ]),
    html.Div(className='row', children=[
        html.Div(className='six columns', children=[
            dcc.Graph(id='scatter', figure=plot_alpha),
            dcc.Graph(id='scatter1', figure=plot_density),
            # With ``step=None`` the slider can only take the values of its marks,
            # so the initial value must be one of them.
            dcc.Slider(id='samples-slider', min=4, max=128, step=None,
                       marks={
                           4: '4',
                           8: '8',
                           16: '16',
                           32: '32',
                           64: '64',
                           128: '128',
                       },
                       value=32,
                       updatemode='drag')
        ]),
        html.Div(className='six columns', children=[
            html.Label("Pixel View"),
            dcc.Graph(id='pix-image', figure=pix_img)
        ])
    ]),
])


def _sample_plot(xs, ys, marker_colors, colorscale):
    """Build one lines+markers scatter of per-sample values along a ray."""
    plot = px.scatter(x=xs, y=ys, color=marker_colors, color_continuous_scale=colorscale)
    plot.update_traces(mode='lines+markers')
    plot.update_xaxes(showgrid=False)
    plot.update_yaxes(type='linear')
    plot.update_layout(height=225, margin={'l': 20, 'b': 30, 'r': 10, 't': 10})
    return plot


def plot_alpha_and_density(ray_o, ray_d):
    """Plot the per-sample weights and densities along a ray.

    Runs ``model`` on the ray(s) and plots, for the first ray only, the sample
    weights (``ret['weight']``) and densities (``ret['sample_densities']``)
    against inverse sample depth (``1 / ret['sample_depths']``). Each marker is
    coloured with the RGB of its sample layer.

    :param ray_o: ray origin(s), forwarded to ``model``
    :param ray_d: ray direction(s), forwarded to ``model``
    :return: ``(plot_alpha, plot_density)`` plotly figures
    """
    ret = model(ray_o, ray_d, extra_outputs=['depth', 'layers'])
    # NOTE(review): 'weight', 'sample_densities' and 'sample_depths' are read although
    # only 'depth' and 'layers' are requested -- presumably always returned; confirm.
    colors = ret['layers'][..., :3]
    densities = ret['sample_densities']
    depths = ret['sample_depths']
    alphas = ret['weight']

    inv_depths = misc.torch2np(1 / depths[0])
    sample_rgb = misc.torch2np(colors[0] * 255)
    # Marker i gets value i on a continuous colour scale whose i-th stop is the
    # i-th sample's RGB, so each marker shows its own sample colour.
    marker_colors = list(range(sample_rgb.shape[0]))
    marker_colors_str = ['rgb(%d,%d,%d)' % (c[0], c[1], c[2]) for c in sample_rgb]

    plot_alpha = _sample_plot(inv_depths, misc.torch2np(alphas[0]),
                              marker_colors, marker_colors_str)
    plot_density = _sample_plot(inv_depths, misc.torch2np(densities[0]),
                                marker_colors, marker_colors_str)
    return plot_alpha, plot_density


def plot_pixel_image(ray_o, ray_d, r=1):
    """Render the point ``ray_o + ray_d * r`` as seen from a grid of origins.

    Ray origins form a ``pix_img_res`` grid that spans the first two axes of
    ``view_range``, with the third coordinate set to 0. Every ray is aimed at
    the target point, and the grid is rendered with ``model``.

    :param ray_o: origin of the inspected ray
    :param ray_d: direction of the inspected ray
    :param r: distance along the ray of the target point
    :return: plotly image of the rendered grid
    """
    with torch.no_grad():
        target = ray_o + ray_d * r
        # NOTE(review): ``normalize=True`` presumably yields grid coords in [0, 1].
        grid_xy = misc.grid2d(*pix_img_res, normalize=True) * view_range_size[:2] + view_range[0, :2]
        rays_o = torch.cat([grid_xy, torch.zeros(*pix_img_res, 1)], dim=-1).to(device.default())
        rays_d = target - rays_o
        rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True)
        image = model(rays_o.view(-1, 3), rays_d.view(-1, 3))['color'] \
            .view(1, *pix_img_res, -1).permute(0, 3, 1, 2)
        fig = px.imshow(img.torch2np(image)[0])
    return fig


def draw_cross(fig, x, y):
    """Mark (x, y) on ``fig`` with a small cross, hiding previously drawn shapes.

    The cross is two 10-unit line shapes in data coordinates (``xref``/``yref``
    of ``"x"``/``"y"``), drawn vertical first, then horizontal.
    """
    arm = 5
    fig.update_shapes({'visible': False})
    strokes = (
        (x, y - arm, x, y + arm),  # vertical stroke
        (x - arm, y, x + arm, y),  # horizontal stroke
    )
    for sx0, sy0, sx1, sy1 in strokes:
        fig.add_shape(type="line", xref="x", yref="y",
                      x0=sx0, y0=sy0, x1=sx1, y1=sy1,
                      line={'color': 'LightSeaGreen', 'width': 3})


def render_view(tx, ty, tz, rx, ry):
    """Render the main image for the pose given by the view sliders.

    Side effects: stores the pose in the global ``test_view``, and the model's
    per-pixel ``layers`` (``(*res, n, 4)``) and ``layer_weights`` (``(*res, n)``)
    in the globals of the same name, for ``render_layer`` and ``view_pixel``.

    :param tx, ty, tz: view position
    :param rx, ry: view rotation, converted with ``view.euler_to_matrix([ry, rx, 0])``
    :return: plotly image of the rendered view
    :raises PreventUpdate: when any pose component is None (slider not ready)
    """
    global test_view, layers, layer_weights
    # torch.tensor cannot hold None, so every component must be available.
    if any(v is None for v in (tx, ty, tz, rx, ry)):
        raise PreventUpdate
    with torch.no_grad():
        test_view = view.Trans(
            torch.tensor([[tx, ty, tz]], device=device.default()),
            torch.tensor(view.euler_to_matrix([ry, rx, 0]), device=device.default()).view(-1, 3, 3)
        )
        rays_o, rays_d = cam.get_global_rays(test_view, True)
        ret = model(rays_o.view(-1, 3), rays_d.view(-1, 3), extra_outputs=['layers', 'weights'])
        image = ret['color'].view(1, *res, 3).permute(0, 3, 1, 2)
        layers = ret['layers'].view(*res, -1, 4)
        layer_weights = ret['weight'].view(*res, -1)
        fig = px.imshow(img.torch2np(image)[0])
    return fig


def render_layer(layer):
    """Render the weighted RGB sum of a range of layers of the current view.

    :param layer: ``[first, last]`` value of the ``layer-slider`` RangeSlider;
                  both ends are included
    :return: plotly image of ``sum_i weight_i * rgb_i`` over the selected layers,
             or ``None`` when ``layer`` is None or no view has been rendered yet
    """
    if layer is None or layers is None:
        return None
    first, last = (int(v) for v in layer)
    # ``layers`` has 4 channels per layer while ``layer_weights`` has one value per
    # layer: add a trailing axis so each weight broadcasts over the RGB channels.
    weighted_rgb = layers[..., first:last + 1, :3] * layer_weights[..., first:last + 1, None]
    layer_data = torch.sum(weighted_rgb, dim=-2)
    fig = px.imshow(img.torch2np(layer_data))
    return fig


def view_pixel(fig, x, y, samples):
    """Inspect the ray through pixel (x, y) of the current view.

    Draws a cross at (x, y) on ``fig`` and builds the per-sample plots and the
    pixel-view image for the ray from ``test_view.t`` through that pixel.

    :param fig: main-view figure, annotated in place
    :param x, y: selected pixel (``None`` when nothing is selected yet)
    :param samples: currently unused; kept for the callback's call signature
    :return: ``(fig, pix_img, plot_alpha, plot_density)``, or ``None`` when no pixel
             is selected or no view has been rendered yet
    """
    if x is None or y is None or test_view is None:
        return None
    p = torch.tensor([x, y], device=device.default())
    ray_d = test_view.trans_vector(cam.unproj(p))
    ray_o = test_view.t
    draw_cross(fig, x, y)
    plot_alpha, plot_density = plot_alpha_and_density(ray_o, ray_d)
    pix_img = plot_pixel_image(ray_o, ray_d)
    return fig, pix_img, plot_alpha, plot_density
@app.callback(
    [Output('image', 'figure'),
     Output('layer-image', 'figure'),
     Output('pix-image', 'figure'),
     Output('scatter', 'figure'),
     Output('scatter1', 'figure')],
    [Input(f'view-slider-{i}', 'value') for i in range(5)] +
    [Input('image', 'clickData'),
     Input('samples-slider', 'value'),
     Input('layer-slider', 'value')]
)
def callback(tx, ty, tz, rx, ry, clickData, samples, layer):
    """Update the app's figures in response to whichever input fired.

    - Initial call or a view slider: re-render the main view and the layer image.
    - Layer slider: re-render only the layer image. The other figures do not
      depend on the layer range, so they are returned unchanged.
    - Otherwise (click on the main image, samples slider): remember the clicked
      pixel, if any.

    Except in the layer-slider case, the pixel inspection (cross, pixel view and
    per-sample plots) is then refreshed for the remembered pixel. The figures are
    kept in module globals so outputs unaffected by a trigger keep their value.
    """
    global x, y, fig, layer_img, pix_img, plot_alpha, plot_density
    ctx = dash.callback_context
    # Some Dash versions report the initial call as one trigger with prop_id '.'.
    trigger = ctx.triggered[0]['prop_id'] if ctx.triggered else '.'
    if trigger == '.' or trigger.startswith('view-slider'):
        fig = render_view(tx, ty, tz, rx, ry)
        layer_img = render_layer(layer) or layer_img
    elif trigger == 'layer-slider.value':
        layer_img = render_layer(layer) or layer_img
        return fig, layer_img, pix_img, plot_alpha, plot_density
    elif clickData:
        x = clickData['points'][0]['x']
        y = clickData['points'][0]['y']
    ret = view_pixel(fig, x, y, samples)
    if ret is not None:
        fig, pix_img, plot_alpha, plot_density = ret
    return fig, layer_img, pix_img, plot_alpha, plot_density


if __name__ == '__main__':
    app.run_server(debug=True)