import sys
import torch

from modules.sampler import Samples
from modules.space import Octree, Voxels
from utils.mem_profiler import MemProfiler
from utils.misc import print_and_log
from .base import *


class TrainWithSpace(Train):

    def __init__(self, model: BaseModel, pruning_loop: int = 10000, splitting_loop: int = 10000,
                 **kwargs) -> None:
        super().__init__(model, **kwargs)
        self.pruning_loop = pruning_loop
        self.splitting_loop = splitting_loop
        #MemProfiler.enable = True

    def _train_epoch(self):
        if not self.perf_mode and self.epoch != 1:
            # Split the space at the start of every `splitting_loop`-th epoch
            if self.splitting_loop == 1 or self.epoch % self.splitting_loop == 1:
                try:
                    with torch.no_grad():
                        before, after = self.model.splitting()
                    print_and_log(
                        f"Splitting done. # of voxels before: {before}, after: {after}")
                except NotImplementedError:
                    print_and_log("Note: the space does not support splitting; skipped.")
            # Prune the space at the start of every `pruning_loop`-th epoch
            if self.pruning_loop == 1 or self.epoch % self.pruning_loop == 1:
                try:
                    with torch.no_grad():
                        # Alternative pruning strategies, kept for reference:
                        #before, after = self.model.pruning()
                        #print(f"Pruning by voxel densities done. "
                        #      f"# of voxels before: {before}, after: {after}")
                        #self._prune_inner_voxels()
                        self._prune_voxels_by_weights()
                except NotImplementedError:
                    print_and_log("Note: the space does not support pruning; skipped.")
        super()._train_epoch()

    def _prune_inner_voxels(self):
        """Prune voxels that are never accessed before ray marching stops early."""
        space: Voxels = self.model.space
        voxel_access_counts = torch.zeros(space.n_voxels, dtype=torch.long,
                                          device=space.voxels.device)
        iters_in_epoch = 0
        batch_size = self.data_loader.batch_size
        self.data_loader.batch_size = 2 ** 14
        for _, rays_o, rays_d, _ in self.data_loader:
            self.model(rays_o, rays_d,
                       raymarching_early_stop_tolerance=0.01,
                       raymarching_chunk_size_or_sections=[1],
                       perturb_sample=False,
                       voxel_access_counts=voxel_access_counts,
                       voxel_access_tolerance=0)
            iters_in_epoch += 1
            percent = iters_in_epoch / len(self.data_loader) * 100
            sys.stdout.write(f"Pruning inner voxels... {percent:.1f}%   \r")
        self.data_loader.batch_size = batch_size
        before, after = space.prune(voxel_access_counts > 0)
        print_and_log(f"Prune inner voxels: {before} -> {after}")

    def _prune_voxels_by_weights(self):
        """Prune voxels whose samples never receive a blending weight above 0.01."""
        space: Voxels = self.model.space
        voxel_access_counts = torch.zeros(space.n_voxels, dtype=torch.long,
                                          device=space.voxels.device)
        iters_in_epoch = 0
        batch_size = self.data_loader.batch_size
        self.data_loader.batch_size = 2 ** 14
        for _, rays_o, rays_d, _ in self.data_loader:
            ret = self.model(rays_o, rays_d,
                             raymarching_early_stop_tolerance=0,
                             raymarching_chunk_size_or_sections=None,
                             perturb_sample=False,
                             extra_outputs=['weights'])
            valid_mask = ret['weights'][..., 0] > 0.01
            accessed_voxels = ret['samples'].voxel_indices[valid_mask]
            voxel_access_counts.index_add_(0, accessed_voxels,
                                           torch.ones_like(accessed_voxels))
            iters_in_epoch += 1
            percent = iters_in_epoch / len(self.data_loader) * 100
            sys.stdout.write(f"Pruning by weights... {percent:.1f}%   \r")
        self.data_loader.batch_size = batch_size
        before, after = space.prune(voxel_access_counts > 0)
        print_and_log(f"Prune by weights: {before} -> {after}")

    def _prune_voxels_by_voxel_weights(self):
        """Prune voxels whose accumulated per-voxel weight is relatively low on every ray."""
        space: Voxels = self.model.space
        voxel_access_counts = torch.zeros(space.n_voxels, dtype=torch.long,
                                          device=space.voxels.device)
        with torch.no_grad():
            batch_size = self.data_loader.batch_size
            self.data_loader.batch_size = 2 ** 14
            iters_in_epoch = 0
            for _, rays_o, rays_d, _ in self.data_loader:
                ret = self.model(rays_o, rays_d,
                                 raymarching_early_stop_tolerance=0,
                                 raymarching_chunk_size_or_sections=None,
                                 perturb_sample=False,
                                 extra_outputs=['weights'])
                self._accumulate_access_count_by_weight(
                    ret['samples'], ret['weights'][..., 0], voxel_access_counts)
                iters_in_epoch += 1
                percent = iters_in_epoch / len(self.data_loader) * 100
                sys.stdout.write(f"Pruning by voxel weights... {percent:.1f}%   \r")
            self.data_loader.batch_size = batch_size
            before, after = space.prune(voxel_access_counts > 0)
            print_and_log(f"Prune by voxel weights: {before} -> {after}")

    def _accumulate_access_count_by_weight(self, samples: Samples, weights: torch.Tensor,
                                           voxel_access_counts: torch.Tensor):
        """Accumulate weights per voxel along each ray, then count a voxel as accessed
        only if its accumulated weight exceeds 10% of the ray's maximum.

        Samples along a ray visit voxels in contiguous runs, so identical consecutive
        voxel indices are collapsed into one slot while their weights are summed.
        """
        uni_vidxs = -torch.ones_like(samples.voxel_indices)
        vidx_accu = torch.zeros_like(samples.voxel_indices, dtype=torch.float)
        uni_vidxs_row = torch.arange(samples.size[0], dtype=torch.long, device=samples.device)
        uni_vidxs_head = torch.zeros_like(samples.voxel_indices[:, 0])
        # Initialize with the first sample of every ray
        uni_vidxs[:, 0] = samples.voxel_indices[:, 0]
        vidx_accu[:, 0].add_(weights[:, 0])
        # Start from 1: sample 0 is already accumulated above (starting from 0 would
        # double-count its weight)
        for i in range(1, samples.size[1]):
            # For rows where the voxel changed, move the head one step forward
            next_voxel = uni_vidxs[uni_vidxs_row, uni_vidxs_head].ne(samples.voxel_indices[:, i])
            uni_vidxs_head[next_voxel].add_(1)
            # Record the voxel index at the head and accumulate its weight
            uni_vidxs[uni_vidxs_row, uni_vidxs_head] = samples.voxel_indices[:, i]
            vidx_accu[uni_vidxs_row, uni_vidxs_head].add_(weights[:, i])
        # Discard voxels whose accumulated weight is below 10% of the ray's maximum
        max_accu = vidx_accu.max(dim=1, keepdim=True)[0]
        uni_vidxs[vidx_accu < max_accu * 0.1] = -1
        access_voxels, access_count = uni_vidxs.unique(return_counts=True)
        if access_voxels[0] == -1:  # drop the -1 entry produced by padding/filtering
            access_voxels, access_count = access_voxels[1:], access_count[1:]
        # index_add_ updates in place; fancy indexing followed by add_ would only
        # modify a copy and leave voxel_access_counts unchanged
        voxel_access_counts.index_add_(0, access_voxels, access_count)
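
# A minimal, self-contained sketch (not part of the training pipeline) of the
# run-collapsing accumulation performed by `_accumulate_access_count_by_weight`:
# consecutive samples along a ray that fall in the same voxel have their weights
# summed into one slot per run. The toy tensors below are assumptions purely for
# illustration. Since this module uses relative imports, copy the snippet into a
# REPL to try it rather than executing the file directly.
if __name__ == "__main__":
    import torch

    # Two rays, six samples each; runs of equal indices share a voxel
    voxel_indices = torch.tensor([[3, 3, 5, 5, 5, 2],
                                  [7, 7, 7, 1, 1, 1]])
    weights = torch.tensor([[.40, .30, .10, .10, .05, .010],
                            [.50, .20, .10, .01, .01, .005]])
    n_rays, n_samples = voxel_indices.shape
    uni_vidxs = -torch.ones_like(voxel_indices)   # run-compressed voxel indices
    vidx_accu = torch.zeros_like(weights)         # accumulated weight per run
    rows = torch.arange(n_rays)
    head = torch.zeros(n_rays, dtype=torch.long)  # current run slot per ray
    uni_vidxs[:, 0] = voxel_indices[:, 0]
    vidx_accu[:, 0] += weights[:, 0]
    for i in range(1, n_samples):
        moved = uni_vidxs[rows, head].ne(voxel_indices[:, i])
        head[moved] += 1
        uni_vidxs[rows, head] = voxel_indices[:, i]
        vidx_accu[rows, head] += weights[:, i]
    print(uni_vidxs)  # [[3, 5, 2, -1, -1, -1], [7, 1, -1, -1, -1, -1]]
    print(vidx_accu)  # [[0.70, 0.25, 0.01, 0, 0, 0], [0.80, 0.025, 0, 0, 0, 0]]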