from modules.sampler import Samples
from modules.space import Octree, Voxels
from utils.mem_profiler import MemProfiler
from utils.misc import print_and_log
from .base import *


class TrainWithSpace(Train):
    """Training loop that periodically refines the model's voxel space.

    Extends the base ``Train`` loop: at the start of selected epochs it asks
    the model to split voxels (densify the space) and/or prune voxels that
    rays never meaningfully hit. Both refinement operations run under
    ``torch.no_grad()`` and are skipped entirely in performance mode.
    """

    def __init__(self, model: BaseModel, pruning_loop: int = 10000, splitting_loop: int = 10000,
                 **kwargs) -> None:
        """Initialize the trainer.

        :param model: model being trained; its ``space`` attribute is refined
        :param pruning_loop: prune every ``pruning_loop`` epochs (1 = every epoch)
        :param splitting_loop: split every ``splitting_loop`` epochs (1 = every epoch)
        :param kwargs: forwarded unchanged to ``Train.__init__``
        """
        super().__init__(model, **kwargs)
        self.pruning_loop = pruning_loop
        self.splitting_loop = splitting_loop
        #MemProfiler.enable = True

    def _train_epoch(self):
        """Optionally split/prune the voxel space, then run the base epoch.

        Splitting/pruning is skipped on the first epoch and in perf mode.
        Spaces that do not implement an operation raise ``NotImplementedError``,
        which is logged and ignored.
        """
        if not self.perf_mode and self.epoch != 1:
            if self.splitting_loop == 1 or self.epoch % self.splitting_loop == 1:
                try:
                    with torch.no_grad():
                        before, after = self.model.splitting()
                    print_and_log(
                        f"Splitting done. # of voxels before: {before}, after: {after}")
                except NotImplementedError:
                    print_and_log(
                        "Note: The space does not support splitting operation. Just skip it.")
            if self.pruning_loop == 1 or self.epoch % self.pruning_loop == 1:
                try:
                    with torch.no_grad():
                        self._prune_voxels_by_weights()
                except NotImplementedError:
                    print_and_log(
                        "Note: The space does not support pruning operation. Just skip it.")

        super()._train_epoch()

    def _prune_inner_voxels(self):
        """Prune voxels that raymarching never reaches before early stop.

        Renders the whole dataset with a coarse, early-stopping forward pass
        that increments ``voxel_access_counts`` for every voxel touched, then
        removes voxels that were never accessed. Expects to be called under
        ``torch.no_grad()`` (the caller wraps it).
        """
        space: Voxels = self.model.space
        voxel_access_counts = torch.zeros(space.n_voxels, dtype=torch.long,
                                          device=space.voxels.device)
        iters_in_epoch = 0
        batch_size = self.data_loader.batch_size
        self.data_loader.batch_size = 2 ** 14  # large batches: inference only
        try:
            for _, rays_o, rays_d, _ in self.data_loader:
                self.model(rays_o, rays_d,
                           raymarching_early_stop_tolerance=0.01,
                           raymarching_chunk_size_or_sections=[1],
                           perturb_sample=False,
                           voxel_access_counts=voxel_access_counts,
                           voxel_access_tolerance=0)
                iters_in_epoch += 1
                percent = iters_in_epoch / len(self.data_loader) * 100
                sys.stdout.write(f'Pruning inner voxels...{percent:.1f}%   \r')
                sys.stdout.flush()
        finally:
            # Always restore the caller's batch size, even on error.
            self.data_loader.batch_size = batch_size
        before, after = space.prune(voxel_access_counts > 0)
        print_and_log(f"Prune inner voxels: {before} -> {after}")

    def _prune_voxels_by_weights(self):
        """Prune voxels whose samples never receive a significant weight.

        Renders the whole dataset without early stopping, counts per-voxel
        samples whose rendering weight exceeds 0.01, and prunes voxels with a
        zero count. Expects to be called under ``torch.no_grad()``.
        """
        space: Voxels = self.model.space
        voxel_access_counts = torch.zeros(space.n_voxels, dtype=torch.long,
                                          device=space.voxels.device)
        iters_in_epoch = 0
        batch_size = self.data_loader.batch_size
        self.data_loader.batch_size = 2 ** 14  # large batches: inference only
        try:
            for _, rays_o, rays_d, _ in self.data_loader:
                ret = self.model(rays_o, rays_d,
                                 raymarching_early_stop_tolerance=0,
                                 raymarching_chunk_size_or_sections=None,
                                 perturb_sample=False,
                                 extra_outputs=['weights'])
                valid_mask = ret['weights'][..., 0] > 0.01
                accessed_voxels = ret['samples'].voxel_indices[valid_mask]
                voxel_access_counts.index_add_(0, accessed_voxels,
                                               torch.ones_like(accessed_voxels))
                iters_in_epoch += 1
                percent = iters_in_epoch / len(self.data_loader) * 100
                sys.stdout.write(f'Pruning by weights...{percent:.1f}%   \r')
                sys.stdout.flush()
        finally:
            # Always restore the caller's batch size, even on error.
            self.data_loader.batch_size = batch_size
        before, after = space.prune(voxel_access_counts > 0)
        print_and_log(f"Prune by weights: {before} -> {after}")

    def _prune_voxels_by_voxel_weights(self):
        """Prune voxels by per-voxel accumulated rendering weight.

        Like :meth:`_prune_voxels_by_weights`, but accumulates weights per
        voxel along each ray (see
        :meth:`_accumulate_access_count_by_weight`) before deciding which
        voxels were meaningfully accessed.
        """
        space: Voxels = self.model.space
        voxel_access_counts = torch.zeros(space.n_voxels, dtype=torch.long,
                                          device=space.voxels.device)
        with torch.no_grad():
            batch_size = self.data_loader.batch_size
            self.data_loader.batch_size = 2 ** 14  # large batches: inference only
            iters_in_epoch = 0
            try:
                for _, rays_o, rays_d, _ in self.data_loader:
                    ret = self.model(rays_o, rays_d,
                                     raymarching_early_stop_tolerance=0,
                                     raymarching_chunk_size_or_sections=None,
                                     perturb_sample=False,
                                     extra_outputs=['weights'])
                    self._accumulate_access_count_by_weight(ret['samples'], ret['weights'][..., 0],
                                                            voxel_access_counts)
                    iters_in_epoch += 1
                    percent = iters_in_epoch / len(self.data_loader) * 100
                    sys.stdout.write(f'Pruning by voxel weights...{percent:.1f}%   \r')
                    sys.stdout.flush()
            finally:
                # Always restore the caller's batch size, even on error.
                self.data_loader.batch_size = batch_size
        before, after = space.prune(voxel_access_counts > 0)
        print_and_log(f"Prune by voxel weights: {before} -> {after}")

    def _accumulate_access_count_by_weight(self, samples: Samples, weights: torch.Tensor,
                                           voxel_access_counts: torch.Tensor):
        """Count voxels whose accumulated weight along a ray is significant.

        Consecutive samples of a ray that fall in the same voxel are merged
        (run-length compression) and their weights summed. A voxel counts as
        "accessed" on a ray iff its accumulated weight is at least 10% of the
        ray's maximum. Counts are added in place to ``voxel_access_counts``.

        :param samples: samples of shape (rays, samples-per-ray); its
            ``voxel_indices`` hold the voxel of each sample (-1 presumably
            marks invalid samples — matches the -1 placeholder used below)
        :param weights: per-sample rendering weights, same shape as
            ``samples.voxel_indices``
        :param voxel_access_counts: 1-D long tensor of per-voxel counters,
            updated in place
        """
        rows = torch.arange(samples.size[0], dtype=torch.long, device=samples.device)
        uni_vidxs = -torch.ones_like(samples.voxel_indices)  # -1 marks unused slots
        vidx_accu = torch.zeros_like(samples.voxel_indices, dtype=torch.float)
        head = torch.zeros_like(samples.voxel_indices[:, 0])  # per-ray write position
        uni_vidxs[:, 0] = samples.voxel_indices[:, 0]
        vidx_accu[:, 0] += weights[:, 0]
        # Start at 1: column 0 is already recorded above (starting at 0 would
        # double-count the first sample's weight).
        for i in range(1, samples.size[1]):
            # For rays whose voxel changed at step i, move the head forward.
            # NOTE: advanced indexing returns a copy, so `t[idx].add_(x)` is a
            # silent no-op; `+=` routes through index_put_ and writes back.
            moved = uni_vidxs[rows, head].ne(samples.voxel_indices[:, i])
            head += moved.long()
            # Record the voxel and accumulate its weight at the current head.
            uni_vidxs[rows, head] = samples.voxel_indices[:, i]
            vidx_accu[rows, head] += weights[:, i]
        # Drop voxels whose accumulated weight is below 10% of the ray's max.
        max_accu = vidx_accu.max(dim=1, keepdim=True)[0]
        uni_vidxs[vidx_accu < max_accu * 0.1] = -1
        access_voxels, access_count = uni_vidxs.unique(return_counts=True)
        # Mask out the -1 placeholder bucket explicitly (it may be absent, so
        # slicing off the first element would be wrong) and scatter-add counts.
        valid = access_voxels >= 0
        voxel_access_counts.index_add_(0, access_voxels[valid], access_count[valid])