train_with_space.py

import sys

import torch

from modules.sampler import Samples
from modules.space import Octree, Voxels
from utils.mem_profiler import MemProfiler
from utils.misc import print_and_log
from .base import *


class TrainWithSpace(Train):
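    """
    Trainer that periodically refines the model's voxel-based space during training:
    voxels are split every `splitting_loop` epochs and pruned every `pruning_loop`
    epochs, provided the underlying space supports these operations.
    """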

    def __init__(self, model: BaseModel, pruning_loop: int = 10000, splitting_loop: int = 10000,
                 **kwargs) -> None:
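        """
        Initialize the trainer.

        :param model: the model to train
        :param pruning_loop: prune the space every `pruning_loop` epochs
        :param splitting_loop: split voxels every `splitting_loop` epochs
        :param kwargs: remaining arguments forwarded to `Train`
        """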
        super().__init__(model, **kwargs)
        self.pruning_loop = pruning_loop
        self.splitting_loop = splitting_loop
        #MemProfiler.enable = True

    def _train_epoch(self):
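        """
        Run one training epoch. Unless in performance mode, first perform the scheduled
        splitting and pruning operations (skipped on the first epoch), then delegate to
        the base class's epoch routine.
        """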
        if not self.perf_mode:
            if self.epoch != 1:
                if self.splitting_loop == 1 or self.epoch % self.splitting_loop == 1:
                    try:
                        with torch.no_grad():
                            before, after = self.model.splitting()
                        print_and_log(
                            f"Splitting done. # of voxels before: {before}, after: {after}")
                    except NotImplementedError:
                        print_and_log(
                            "Note: The space does not support splitting operation. Just skip it.")
                if self.pruning_loop == 1 or self.epoch % self.pruning_loop == 1:
                    try:
                        with torch.no_grad():
                            #before, after = self.model.pruning()
                            # print(f"Pruning by voxel densities done. # of voxels before: {before}, after: {after}")
                            # self._prune_inner_voxels()
                            self._prune_voxels_by_weights()
                    except NotImplementedError:
                        print_and_log(
                            "Note: The space does not support pruning operation. Just skip it.")

        super()._train_epoch()

    def _prune_inner_voxels(self):
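        """
        Prune voxels that are never accessed during ray marching with early stopping,
        i.e. inner voxels hidden behind surfaces. Access counts are collected by the
        model itself while rendering all training rays with a large batch size.
        """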
        space: Voxels = self.model.space
        voxel_access_counts = torch.zeros(space.n_voxels, dtype=torch.long,
                                          device=space.voxels.device)
        iters_in_epoch = 0
        batch_size = self.data_loader.batch_size
        self.data_loader.batch_size = 2 ** 14
        for _, rays_o, rays_d, _ in self.data_loader:
            self.model(rays_o, rays_d,
                       raymarching_early_stop_tolerance=0.01,
                       raymarching_chunk_size_or_sections=[1],
                       perturb_sample=False,
                       voxel_access_counts=voxel_access_counts,
                       voxel_access_tolerance=0)
            iters_in_epoch += 1
            percent = iters_in_epoch / len(self.data_loader) * 100
            sys.stdout.write(f'Pruning inner voxels...{percent:.1f}%   \r')
        self.data_loader.batch_size = batch_size
        before, after = space.prune(voxel_access_counts > 0)
        print(f"Prune inner voxels: {before} -> {after}")

    def _prune_voxels_by_weights(self):
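        """
        Prune voxels that contribute no significant rendering weight. All training rays
        are rendered once; a voxel is kept if at least one sample inside it has a
        rendering weight above 0.01.
        """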
        space: Voxels = self.model.space
        voxel_access_counts = torch.zeros(space.n_voxels, dtype=torch.long,
                                          device=space.voxels.device)
        iters_in_epoch = 0
        batch_size = self.data_loader.batch_size
        self.data_loader.batch_size = 2 ** 14
        for _, rays_o, rays_d, _ in self.data_loader:
            ret = self.model(rays_o, rays_d,
                             raymarching_early_stop_tolerance=0,
                             raymarching_chunk_size_or_sections=None,
                             perturb_sample=False,
                             extra_outputs=['weights'])
            valid_mask = ret['weights'][..., 0] > 0.01
            accessed_voxels = ret['samples'].voxel_indices[valid_mask]
            voxel_access_counts.index_add_(0, accessed_voxels, torch.ones_like(accessed_voxels))
            iters_in_epoch += 1
            percent = iters_in_epoch / len(self.data_loader) * 100
            sys.stdout.write(f'Pruning by weights...{percent:.1f}%   \r')
        self.data_loader.batch_size = batch_size
        before, after = space.prune(voxel_access_counts > 0)
        print_and_log(f"Prune by weights: {before} -> {after}")

    def _prune_voxels_by_voxel_weights(self):
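        """
        Prune voxels by per-voxel accumulated weights. All training rays are rendered
        once; along each ray, sample weights are accumulated per voxel and only voxels
        whose accumulated weight is significant relative to the ray's maximum are kept.
        """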
        space: Voxels = self.model.space
        voxel_access_counts = torch.zeros(space.n_voxels, dtype=torch.long,
                                          device=space.voxels.device)
        with torch.no_grad():
            batch_size = self.data_loader.batch_size
            self.data_loader.batch_size = 2 ** 14
            iters_in_epoch = 0
            for _, rays_o, rays_d, _ in self.data_loader:
                ret = self.model(rays_o, rays_d,
                                 raymarching_early_stop_tolerance=0,
                                 raymarching_chunk_size_or_sections=None,
                                 perturb_sample=False,
                                 extra_outputs=['weights'])
                self._accumulate_access_count_by_weight(ret['samples'], ret['weights'][..., 0],
                                                        voxel_access_counts)
                iters_in_epoch += 1
                percent = iters_in_epoch / len(self.data_loader) * 100
                sys.stdout.write(f'Pruning by voxel weights...{percent:.1f}%   \r')
            self.data_loader.batch_size = batch_size
        before, after = space.prune(voxel_access_counts > 0)
        print_and_log(f"Prune by voxel weights: {before} -> {after}")

    def _accumulate_access_count_by_weight(self, samples: Samples, weights: torch.Tensor,
                                           voxel_access_counts: torch.Tensor):
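        """
        Accumulate per-voxel access counts for a batch of rays. Consecutive samples
        falling in the same voxel are merged while summing their weights; voxels whose
        accumulated weight is below 10% of the ray's maximum are discarded, and the
        access counts of the remaining voxels are incremented.

        :param samples: samples generated along the rays
        :param weights: rendering weight of each sample
        :param voxel_access_counts: per-voxel access counters to update in place
        """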
        uni_vidxs = -torch.ones_like(samples.voxel_indices)
        vidx_accu = torch.zeros_like(samples.voxel_indices, dtype=torch.float)
        uni_vidxs_row = torch.arange(samples.size[0], dtype=torch.long, device=samples.device)
        uni_vidxs_head = torch.zeros_like(samples.voxel_indices[:, 0])
        uni_vidxs[:, 0] = samples.voxel_indices[:, 0]
        vidx_accu[:, 0].add_(weights[:, 0])
        # Start from the second sample; the first was handled during initialization above
        for i in range(1, samples.size[1]):
            # For rows whose voxel index changed, move the head one step forward
            next_voxel = uni_vidxs[uni_vidxs_row, uni_vidxs_head].ne(samples.voxel_indices[:, i])
            uni_vidxs_head += next_voxel.long()
            # Set voxel indices and accumulate weights; assign via indexing so the updates
            # are written back (in-place ops on advanced-indexed tensors act on a copy)
            uni_vidxs[uni_vidxs_row, uni_vidxs_head] = samples.voxel_indices[:, i]
            vidx_accu[uni_vidxs_row, uni_vidxs_head] += weights[:, i]
        # Suppress voxels whose accumulated weight is below 10% of the ray's maximum
        max_accu = vidx_accu.max(dim=1, keepdim=True)[0]
        uni_vidxs[vidx_accu < max_accu * 0.1] = -1
        access_voxels, access_count = uni_vidxs.unique(return_counts=True)
        # Drop the -1 entries (unused slots and suppressed voxels) and update the counters
        # with index_add_ so the result is written back in place
        valid = access_voxels >= 0
        voxel_access_counts.index_add_(0, access_voxels[valid], access_count[valid])