renderer.py 14.9 KB
Newer Older
Nianchen Deng's avatar
Nianchen Deng committed
1
import torch
Nianchen Deng's avatar
sync    
Nianchen Deng committed
2
3
4
5
from itertools import cycle
from typing import Dict, Set, Tuple, Union

from utils.type import NetInput, ReturnData
Nianchen Deng's avatar
sync    
Nianchen Deng committed
6

Nianchen Deng's avatar
Nianchen Deng committed
7
from .generic import *
Nianchen Deng's avatar
sync    
Nianchen Deng committed
8
9
10
11
12
from model.base import BaseModel
from utils import math
from utils.module import Module
from utils.perf import checkpoint, perf
from utils.samples import Samples
Nianchen Deng's avatar
sync    
Nianchen Deng committed
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43


def density2energy(densities: torch.Tensor, dists: torch.Tensor, raw_noise_std: float = 0):
    """
    Calculate energies (optical depths) from densities inferred by model.

    :param densities `Tensor(N..., 1)`: model's output densities
    :param dists `Tensor(N...)`: integration times
    :param raw_noise_std `float`: std of the Gaussian noise added to densities to regularize the
                                  network during training (prevents floater artifacts),
                                  defaults to 0, which means no noise is added
    :return `Tensor(N..., 1)`: energies which block light rays
    """
    if raw_noise_std > 0:
        # Add noise to model's predictions for density. Can be used to
        # regularize network during training (prevents floater artifacts).
        # `randn_like` keeps the noise on the same device/dtype as `densities`
        # (`torch.normal(mean, std, size)` would always allocate on the default device).
        densities = densities + torch.randn_like(densities) * raw_noise_std
    return densities * dists[..., None]


def density2alpha(densities: torch.Tensor, dists: torch.Tensor, raw_noise_std: float = 0):
    """
    Convert densities inferred by model into per-sample opacities (alphas).

    alpha = 1 - exp(-density * dist), where the energy is computed by `density2energy`.

    :param densities `Tensor(N..., 1)`: model's output densities
    :param dists `Tensor(N...)`: integration times
    :param raw_noise_std `float`: std of the noise used to regularize the network during training
                                  (prevents floater artifacts), defaults to 0 (no noise)
    :return `Tensor(N..., 1)`: alphas
    """
    transmittance = torch.exp(-density2energy(densities, dists, raw_noise_std))
    return 1.0 - transmittance
Nianchen Deng's avatar
Nianchen Deng committed
44
45


Nianchen Deng's avatar
sync    
Nianchen Deng committed
46
class AlphaComposition(Module):
    """
    Front-to-back alpha compositing of per-sample colors along the sample axis (dim -2).
    """

    def __init__(self):
        super().__init__()

    def forward(self, colors, alphas, bg=None):
        """
        Composite samples along each ray.

        :param colors `Tensor(N, P, C)`: colors of samples
        :param alphas `Tensor(N, P, 1)`: opacities of samples
        :param bg `Tensor([N, ]C)`: background color blended in with the opacity left after all
                samples, defaults to None (no background)
        :return `dict`: { 'color': `Tensor(N, C)` composited colors,
                          'weights': `Tensor(N, P, 1)` contribution of each sample }
        """
        # Compute weight for RGB of each sample along each ray. A cumprod() is
        # used to express the idea of the ray not having reflected up to this
        # sample yet. `math.tiny` is presumably a small epsilon from utils.math.
        transmittance = torch.cumprod(1 - alphas[..., :-1, :] + math.tiny, dim=-2)
        # The first sample always sees the full ray. Take the "ones" column from `alphas` so it
        # still exists when P == 1 (then `transmittance` has zero samples).
        transmittance = torch.cat([
            torch.ones_like(alphas[..., :1, :]),
            transmittance
        ], dim=-2)
        weights = alphas * transmittance  # (N, P, 1)

        # (N, C), computed weighted color of each sample along each ray.
        final_color = torch.sum(weights * colors, dim=-2)

        # To composite onto a background, use the accumulated alpha map.
        if bg is not None:
            # Sum of weights over the samples of each ray: (N, 1).
            # This value is in [0, 1] up to numerical error.
            acc_map = torch.sum(weights, dim=-2)
            final_color = final_color + bg * (1. - acc_map)

        return {
            'color': final_color,
            'weights': weights,
        }


Nianchen Deng's avatar
sync    
Nianchen Deng committed
85
class VolumnRenderer(Module):
    """
    Volume renderer that ray-marches a batch of samples chunk by chunk along the sample axis.

    For every chunk, only samples whose `hit_mask` entry is set are sent to the kernel. Per-sample
    weights are then accumulated from the inferred densities. Optionally, rays whose
    transmittance drops below a tolerance stop evaluating their remaining samples.
    """

    class States:
        """
        Mutable state shared by all chunks of a single `VolumnRenderer.forward()` call.
        """
        kernel: BaseModel  # model that infers per-sample outputs
        samples: Samples  # samples of size (N, P)
        early_stop_tolerance: float  # transmittance threshold of early stop (0 disables it)
        outputs: Set[str]  # names of per-sample quantities requested from the kernel
        hit_mask: torch.Tensor  # (N, P) | bool, samples that still need to be evaluated
        N: int
        P: int
        device: torch.device

        colors: torch.Tensor  # (N, P, C)
        densities: torch.Tensor  # (N, P, 1)
        energies: torch.Tensor  # (N, P, 1), density * integration time of each sample
        weights: torch.Tensor  # (N, P, 1), contribution of each sample to composited outputs
        cum_energies: torch.Tensor  # (N, P + 1, 1), [:, i] = sum of energies before sample i
        exp_energies: torch.Tensor  # (N, P + 1, 1), exp(-cum_energies), i.e. transmittance

        tot_evaluations: Dict[str, int]  # number of kernel evaluations, keyed by output kind

        chunk: Tuple[slice, slice]  # (all rays, sample range) of the current chunk
        cum_chunk: Tuple[slice, slice]  # `chunk` shifted by one, indexing the (P + 1) tensors
        cum_last: Tuple[slice, slice]  # the single cumulative entry just before `chunk`
        chunk_id: int  # index of the current chunk, -1 before the first one

        @property
        def start(self) -> int:
            """First sample index of the current chunk."""
            return self.chunk[1].start

        @property
        def end(self) -> int:
            """One past the last sample index of the current chunk."""
            return self.chunk[1].stop

        def __init__(self, kernel: BaseModel, samples: Samples, early_stop_tolerance: float,
                     outputs: Set[str]) -> None:
            self.kernel = kernel
            self.samples = samples
            self.early_stop_tolerance = early_stop_tolerance
            self.outputs = outputs

            N, P = samples.size

            self.device = self.samples.device
            self.hit_mask = samples.voxel_indices != -1  # (N, P) | bool
            self.colors = torch.zeros(N, P, kernel.chns('color'), device=samples.device)
            self.densities = torch.zeros(N, P, 1, device=samples.device)
            self.energies = torch.zeros(N, P, 1, device=samples.device)
            self.weights = torch.zeros(N, P, 1, device=samples.device)
            self.cum_energies = torch.zeros(N, P + 1, 1, device=samples.device)
            self.exp_energies = torch.ones(N, P + 1, 1, device=samples.device)
            self.tot_evaluations = {}
            self.N, self.P = N, P
            self.chunk_id = -1

        def n_hits(self, index: Union[int, slice] = None) -> int:
            """
            Count samples to be evaluated, either in all sample columns or in column(s) `index`.

            :param index `int|slice`: sample index or range, defaults to None (all samples)
            :return `int`: number of hit samples
            """
            if isinstance(self.hit_mask, torch.Tensor):
                mask = self.hit_mask if index is None else self.hit_mask[:, index]
                return mask.count_nonzero().item()
            # Without a mask tensor every sample counts as a hit
            if index is None:
                return self.N * self.P
            # Indexing a range validates `index` and gives a range for slices or an int otherwise
            columns = range(self.P)[index]
            return self.N * (len(columns) if isinstance(columns, range) else 1)

        def accumulate_tot_evaluations(self, key: str, n: int):
            """Add `n` to the evaluation counter `key`, starting from 0."""
            self.tot_evaluations[key] = self.tot_evaluations.get(key, 0) + n

        def next_chunk(self, *, length=None, end=None):
            """
            Advance to the chunk right after the current one (or to the first chunk).

            :param length `int`: number of samples in the chunk, defaults to (or if 0) P
            :param end `int`: end index of the chunk, overrides `length` if given (and non-zero)
            :return `States`: self, for chaining
            """
            start = 0 if not hasattr(self, "chunk") else self.end
            length = length or self.P
            end = min(end or start + length, self.P)
            self.chunk = slice(None), slice(start, end)
            self.cum_chunk = slice(None), slice(start + 1, end + 1)
            self.cum_last = slice(None), slice(start, start + 1)
            self.chunk_id += 1
            return self

        def put(self, key: str, values: torch.Tensor,
                indices: Union[Tuple[torch.Tensor, torch.Tensor], Tuple[slice, slice]]):
            """
            Write `values` into the (N, P, C) tensor attribute named `key` at `indices`.

            On first use the attribute is created as zeros with `values.shape[-1]` channels.
            """
            if not hasattr(self, key):
                setattr(self, key,
                        torch.zeros(self.N, self.P, values.shape[-1], device=self.device))
            tensor: torch.Tensor = getattr(self, key)
            tensor[indices] = values

    def __init__(self, **kwargs):
        super().__init__()

    @perf
    def forward(self, kernel: BaseModel, samples: Samples, *outputs: str,
                raymarching_early_stop_tolerance: float = 0,
                raymarching_chunk_size_or_sections: Union[int, List[int]] = None,
                **kwargs) -> ReturnData:
        """
        Perform volume rendering.

        :param kernel `BaseModel`: render kernel
        :param samples `Samples(N, P)`: samples
        :param outputs `str...`: items that should be contained in the result dict.
                Optional values include 'color', 'depth', 'diffuse', 'specular', 'layers',
                'states' and attribute names in class `States` (e.g. 'weights'). Defaults to []
        :param raymarching_early_stop_tolerance `float`: tolerance of raymarching early stop.
                Should be between 0 and 1 (0 means no early stop). Defaults to 0
        :param raymarching_chunk_size_or_sections `int|list[int]`: indicates how to split the
                raymarching process. Use a list (or tuple) of integers to specify the number of
                samples in every chunk (cycled until all samples are covered), or a positive
                integer to specify the number of chunks. Use a negative integer to split by
                number of hits in chunks; its absolute value is the maximum number of hits in a
                chunk. 0 and `None` mean not splitting the raymarching process. Defaults to `None`
        :return `dict`: render result { 'color'[, 'depth', 'layers', 'states', ...] },
                or None if there are no samples
        """
        if samples.size[1] == 0:
            print("VolumnRenderer.forward(): # of samples is zero")
            return None

        s = VolumnRenderer.States(kernel, samples, raymarching_early_stop_tolerance,
                                  self._infer_outputs(outputs))

        checkpoint("Prepare states object")

        sections = raymarching_chunk_size_or_sections
        if not sections:
            sections = [s.P]
        elif isinstance(sections, int) and sections > 0:
            # A positive integer is the number of chunks: ceil(P / n) samples per chunk
            sections = [(s.P + sections - 1) // sections]

        if isinstance(sections, (list, tuple)):
            for chunk_samples in cycle(sections):
                self._forward_chunk(s.next_chunk(length=chunk_samples))
                if s.end >= s.P:
                    break
        else:
            self._forward_chunks_by_hits(s, -sections)

        checkpoint("Run forward chunks")

        ret = {}
        for key in outputs:
            if key == 'color':
                ret['color'] = torch.sum(s.colors * s.weights, 1)
            elif key == 'depth':
                ret['depth'] = torch.sum(s.samples.depths[..., None] * s.weights, 1)
            elif key == 'diffuse' and hasattr(s, "diffuses"):
                ret['diffuse'] = torch.sum(s.diffuses * s.weights, 1)
            elif key == 'specular' and hasattr(s, "speculars"):
                ret['specular'] = torch.sum(s.speculars * s.weights, 1)
            elif key == 'layers':
                ret['layers'] = torch.cat([s.colors, 1 - torch.exp(-s.energies)], dim=-1)
            elif key == 'states':
                ret['states'] = s
            elif hasattr(s, key):
                ret[key] = getattr(s, key)

        checkpoint("Set return data")

        return ret

    @staticmethod
    def _infer_outputs(outputs) -> Set[str]:
        """
        Map requested render outputs to the per-sample quantities the kernel must infer.

        Keys without an entry here (e.g. 'weights') are forwarded unchanged.
        """
        required = {
            "color": ("colors", "densities"),
            "specular": ("speculars", "densities"),
            "diffuse": ("diffuses", "densities"),
            "depth": ("densities",),
            # 'layers' is assembled from colors and energies (computed from densities)
            "layers": ("colors", "densities"),
            # 'states' is the renderer's own state object, nothing to infer
            "states": (),
        }
        infer_outputs = set()
        for key in outputs:
            infer_outputs.update(required.get(key, (key,)))
        return infer_outputs

    def _forward_chunks_by_hits(self, s: States, max_hits: int):
        """
        Run chunks such that each contains at most `max_hits` hit samples. A single sample column
        with more hits than `max_hits` still forms a chunk of its own.

        :param s `States`: states
        :param max_hits `int`: maximum number of hit samples in a chunk
        """
        chunk_hits = s.n_hits(0)
        for i in range(1, s.P):
            n_hits = s.n_hits(i)
            if chunk_hits + n_hits > max_hits:
                self._forward_chunk(s.next_chunk(end=i))
                # Early stop applied by the chunk just run may have cleared hits of column i
                n_hits = s.n_hits(i)
                chunk_hits = 0
            chunk_hits += n_hits
        self._forward_chunk(s.next_chunk())

    @perf
    def _calc_weights(self, s: States):
        """
        Calculate weights of the current chunk's samples in composited outputs.

        energy_i = density_i * dist_i; T_i = exp(-sum_{j<i} energy_j) is stored in `exp_energies`,
        and weight_i = T_i - T_{i+1}.

        :param s `States`: states, positioned at the chunk to process
        """
        s.energies[s.chunk] = density2energy(s.densities[s.chunk], s.samples.dists[s.chunk])
        # Continue the cumulative sum from the entry just before this chunk
        s.cum_energies[s.cum_chunk] = torch.cumsum(s.energies[s.chunk], 1) \
            + s.cum_energies[s.cum_last]
        s.exp_energies[s.cum_chunk] = (-s.cum_energies[s.cum_chunk]).exp()
        s.weights[s.chunk] = s.exp_energies[s.chunk] - s.exp_energies[s.cum_chunk]

    @perf
    def _apply_early_stop(self, s: States):
        """
        Stop rays whose transmittance at the end of the current chunk is below
        `s.early_stop_tolerance` (i.e. accumulated opacity above 1 - tolerance) by clearing
        their remaining entries in the hit mask.

        :param s `States`: states, positioned at the chunk just processed
        """
        if s.end < s.P and s.early_stop_tolerance > 0 and isinstance(s.hit_mask, torch.Tensor):
            rays_to_stop = s.exp_energies[:, s.end, 0] < s.early_stop_tolerance
            s.hit_mask[rays_to_stop, s.end:] = 0

    @perf
    def _forward_chunk(self, s: States) -> None:
        """
        Evaluate the kernel on the current chunk and store its outputs, then update weights and
        apply early stop.

        :param s `States`: states, positioned at the chunk to process
        """
        if isinstance(s.hit_mask, torch.Tensor):
            fi_idxs: Tuple[torch.Tensor, ...] = s.hit_mask[s.chunk].nonzero(as_tuple=True)
            if fi_idxs[0].size(0) == 0:
                # Nothing to evaluate: this chunk's energies stay zero, so cumulative values
                # simply carry over from the entry before the chunk
                s.cum_energies[s.cum_chunk] = s.cum_energies[s.cum_last]
                s.exp_energies[s.cum_chunk] = s.exp_energies[s.cum_last]
                return
            fi_idxs[1].add_(s.start)  # chunk-local -> global sample indices
            s.accumulate_tot_evaluations("colors", fi_idxs[0].size(0))
        else:
            fi_idxs = s.chunk
            s.accumulate_tot_evaluations("colors", s.n_hits(s.chunk[1]))

        fi_outputs = s.kernel.infer(*s.outputs, samples=s.samples[fi_idxs], chunk_id=s.chunk_id)
        for key, value in fi_outputs.items():
            s.put(key, value, fi_idxs)

        self._calc_weights(s)
        self._apply_early_stop(s)


class DensityFirstVolumnRenderer(VolumnRenderer):
    """
    Volume renderer that first infers densities of all hit samples in a chunk, then infers the
    remaining outputs (e.g. colors) only for samples whose weight reaches `weight_threshold`.

    Requires `States.hit_mask` to be a tensor.
    """

    # Class-level default, so instances created without `__init__` still have a threshold
    weight_threshold: float = 0.01

    def __init__(self, *, weight_threshold: float = 0.01, **kwargs):
        """
        :param weight_threshold `float`: samples whose weight is below this value are excluded
                from the second (appearance) inference pass, defaults to 0.01
        """
        super().__init__(**kwargs)
        self.weight_threshold = weight_threshold

    def _forward_chunk(self, s: VolumnRenderer.States) -> None:
        """
        Run density inference, weight computation, early stop and pruning for the current chunk,
        then run appearance inference on the surviving samples.

        :param s `VolumnRenderer.States`: states, positioned at the chunk to process
        """
        # fi_* means "filtered" by hit mask
        fi_idxs: Tuple[torch.Tensor, ...] = s.hit_mask[s.chunk].nonzero(as_tuple=True)  # (N')
        if fi_idxs[0].size(0) == 0:
            # Nothing to evaluate: cumulative values carry over from the entry before the chunk
            s.cum_energies[s.cum_chunk] = s.cum_energies[s.cum_last]
            s.exp_energies[s.cum_chunk] = s.exp_energies[s.cum_last]
            return
        fi_idxs[1].add_(s.start)  # chunk-local -> global sample indices

        fi_samples = s.samples[fi_idxs]  # N -> N'

        # For all valid samples: encode X
        density_inputs = s.kernel.input(fi_samples, "x", "f")  # (N', Ex)

        # Infer densities (shape)
        density_outputs = s.kernel.infer('densities', 'features', samples=fi_samples,
                                         inputs=density_inputs, chunk_id=s.chunk_id)
        s.put('densities', density_outputs['densities'], fi_idxs)
        s.accumulate_tot_evaluations("densities", fi_idxs[0].size(0))

        self._calc_weights(s)
        self._apply_early_stop(s)

        # Remove samples whose weights are less than a threshold.
        # `s.hit_mask[s.chunk]` uses basic slicing, i.e. it is a view, so this writes through.
        s.hit_mask[s.chunk][s.weights[s.chunk][..., 0] < self.weight_threshold] = 0

        # Update "filtered" tensors
        fi_mask = s.hit_mask[fi_idxs]
        fi_idxs = (fi_idxs[0][fi_mask], fi_idxs[1][fi_mask])  # N' -> N"
        s.accumulate_tot_evaluations("colors", fi_idxs[0].size(0))
        if fi_idxs[0].size(0) == 0:
            return  # every sample was pruned: skip appearance inference on an empty batch

        fi_samples = s.samples[fi_idxs]  # N -> N"
        fi_features = density_outputs['features'][fi_mask]
        color_inputs = s.kernel.input(fi_samples, "d")  # (N")
        color_inputs.x = density_inputs.x[fi_mask]  # reuse the already encoded positions

        # Infer colors (appearance); densities were already inferred above
        outputs = s.outputs.copy()
        outputs.discard('densities')
        color_outputs = s.kernel.infer(*outputs, samples=fi_samples, inputs=color_inputs,
                                       chunk_id=s.chunk_id, features=fi_features)
        for key, value in color_outputs.items():
            s.put(key, value, fi_idxs)