snerf_fast.py 3.04 KB
Newer Older
Nianchen Deng's avatar
sync    
Nianchen Deng committed
1
2
from .__common__ import *
from .nerf import NeRF
Nianchen Deng's avatar
sync    
Nianchen Deng committed
3
4


Nianchen Deng's avatar
sync    
Nianchen Deng committed
5
class SnerfFast(NeRF):
    """Spherical NeRF variant whose samples along a ray are split into ``multi_nets`` parts.

    Each part holds ``samples_per_part`` consecutive samples. A core unit is created whose
    input/output widths are scaled by ``samples_per_part``, so one core evaluates the
    encodings of a whole part at once (see ``_create_core_unit`` and ``_input``).
    """

    def infer(self, *outputs: str, samples: Samples, inputs: NetInput = None, chunk_id: int,
              **kwargs) -> NetOutput:
        """Evaluate the core unit ``self.cores[chunk_id]`` and reshape its outputs.

        :param outputs: names of the outputs requested from the core
        :param samples: samples to evaluate; ``samples.size`` gives the leading output shape
        :param inputs: pre-computed network inputs; derived from ``samples`` when ``None``
        :param chunk_id: index into ``self.cores`` selecting the core unit to run
        :return: dict mapping each returned key to its value reshaped to ``(*samples.size, -1)``
        """
        if inputs is None:
            inputs = self.input(samples)
        return {
            key: value.reshape(*samples.size, -1)
            for key, value in self.cores[chunk_id](inputs, *outputs).items()
        }

    def _preprocess_args(self):
        """Force ``spherical`` on, run the parent preprocessing, then derive ``samples_per_part``."""
        self.args0["spherical"] = True
        super()._preprocess_args()
        # Integer division: if n_samples is not a multiple of multi_nets, the remainder is dropped.
        self.samples_per_part = self.args['n_samples'] // self.multi_nets

    def _init_chns(self):
        """Initialize channel counts with a 2-channel position input ``x``."""
        # Only the last 2 components of samples.pts are used as x (see _input); presumably the
        # angular coordinates of the spherical parameterization — confirm.
        super()._init_chns(x=2)

    def _create_core_unit(self):
        """Create a core unit whose channel widths cover all samples of one part."""
        return super()._create_core_unit(
            x_chns=self.x_encoder.out_dim * self.samples_per_part,
            density_chns=self.chns('density') * self.samples_per_part,
            color_chns=self.chns('color') * self.samples_per_part)

    def _input(self, samples: Samples, what: str) -> torch.Tensor | None:
        """Build the network input named ``what`` from ``samples``.

        ``"x"``: encodes the last ``chns("x")`` components of ``samples.pts`` and flattens
        dims 1 and 2 together, so a part's per-sample encodings are concatenated.
        ``"d"``: encodes ``samples.dirs[:, 0]`` (the direction at index 0 along dim 1), or
        returns ``None`` when there is no direction encoder or no directions.
        Any other name is delegated to the parent class.
        """
        if what == "x":
            return self._encode("x", samples.pts[..., -self.chns("x"):]).flatten(1, 2)
        elif what == "d":
            # Presumably all samples on a ray share one direction, so index 0 suffices — confirm.
            return self._encode("d", samples.dirs[:, 0])\
                if self.d_encoder and samples.dirs is not None else None
        else:
            return super()._input(samples, what)

    def _encode(self, what: str, val: torch.Tensor) -> torch.Tensor:
        """Encode ``val`` with the encoder for ``what``.

        Positions (``"x"``) are first mapped linearly from the space's bounding box
        (``bbox[0]`` -> ``in_range[0]``, ``bbox[1]`` -> ``in_range[1]``) into the x encoder's
        expected input range. Other inputs are delegated to the parent class.
        """
        if what == "x":
            # Normalize x according to the encoder's range requirement using space's bounding box
            bbox = self.space.bbox[:, -self.chns("x"):]
            val = (val - bbox[0]) / (bbox[1] - bbox[0])
            val = val * (self.x_encoder.in_range[1] - self.x_encoder.in_range[0])\
                + self.x_encoder.in_range[0]
            return self.x_encoder(val)
        return super()._encode(what, val)

    def _render(self, samples: Samples, *outputs: str, **extra_args) -> ReturnData:
        """Render via the parent, forcing ray marching to use sections of ``samples_per_part``.

        The forced ``raymarching_chunk_size_or_sections`` entry overrides any value of the
        same key passed in ``extra_args``.
        """
        # This must be a dict display. The previous `{key, value}` form was a set literal,
        # which raises TypeError (unhashable list) before rendering can start.
        return super()._render(samples, *outputs,
                               **(extra_args |
                                  {"raymarching_chunk_size_or_sections": [self.samples_per_part]}))

    def _sample(self, data: InputData, **extra_args) -> Samples:
        """Sample as the parent does, then set ``voxel_indices`` to ``0`` for every sample."""
        samples = super()._sample(data, **extra_args)
        # NOTE(review): presumably the whole space is treated as a single voxel here — confirm.
        samples.voxel_indices = 0
        return samples
Nianchen Deng's avatar
sync    
Nianchen Deng committed
56
57


Nianchen Deng's avatar
sync    
Nianchen Deng committed
58
class SnerfFastExport(torch.nn.Module):
    """Export-oriented wrapper that evaluates every sample part of a ``SnerfFast`` network.

    The output concatenates per-sample colors with one alpha channel along the last dim.
    """

    def __init__(self, net: SnerfFast):
        super().__init__()
        self.net = net

    def forward(self, coords_encoded, z_vals):
        """Run each part's MLP and return colors with alphas appended.

        :param coords_encoded: encoded coordinates; dim 1 is sliced into parts of
            ``samples_per_part`` (assumed ``(N, n_samples, enc_dim)`` — TODO confirm)
        :param z_vals: passed unchanged to ``rendering.density2alpha``
        :return: ``cat([colors, alphas[..., None]], -1)``, with colors and densities of all
            parts concatenated along dim 1
        """
        net = self.net
        part_colors = []
        part_densities = []
        for part in range(net.multi_nets):
            n = net.samples_per_part
            lo, hi = part * n, (part + 1) * n
            # NOTE(review): SnerfFast.infer evaluates `self.cores[...]`, while this wrapper reads
            # `nets` / `net` / `color_chns`; confirm these attributes still exist on the network.
            mlp = net.net if net.nets is None else net.nets[part]
            # Flatten dims 1-2 so all samples of the part form one input vector per row.
            color, density = mlp(coords_encoded[:, lo:hi].flatten(1, 2))
            part_colors.append(color.view(-1, n, net.color_chns))
            part_densities.append(density)
        colors = torch.cat(part_colors, 1)
        densities = torch.cat(part_densities, 1)
        alphas = net.rendering.density2alpha(densities, z_vals)
        return torch.cat([colors, alphas[..., None]], -1)