import torch

import model
from .base import *
from modules import *
from utils.mem_profiler import MemProfiler
from utils.perf import perf
from utils.misc import masked_scatter


class NeRF(BaseModel):

    trainer = "TrainWithSpace"
    SamplerClass = Sampler
    RendererClass = VolumnRenderer

    def __init__(self, args0: dict, args1: dict = None):
        """
        Initialize a NeRF model

        :param args0 `dict`: basic arguments
        :param args1 `dict`: extra arguments, defaults to None
        """
        args1 = args1 or {}  # avoid mutating a shared default dict across instances
        if "sample_step_ratio" in args0:
            args1["sample_step"] = args0["voxel_size"] * args0["sample_step_ratio"]
        super().__init__(args0, args1)

        # Initialize components
        self._init_space()
        self._init_encoders()
        self._init_core()
        self.sampler = self.SamplerClass(**self.args)
        self.rendering = self.RendererClass(**self.args)
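
    # A minimal construction sketch (keys taken from how `self.args` is used in
    # this class; the concrete values are illustrative, not defaults):
    #
    #   model = NeRF({
    #       'n_pot_encode': 10,
    #       'n_dir_encode': 4,       # optional; enables view-dependent color
    #       'fc_params': {'nf': 256, 'n_layers': 8, 'activation': 'relu', 'skips': [4]},
    #       'space': 'octree',       # or 'voxels', or a path to another model
    #       'voxel_size': 0.1,
    #       'sample_step_ratio': 0.5,
    #   })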

    def _init_encoders(self):
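        # Positional encoder for sample coordinates (or for learned per-voxel
        # features when `n_featdim` is set; see `encode_x`).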
        self.pot_encoder = InputEncoder.Get(self.args['n_pot_encode'],
                                            self.args.get('n_featdim') or 3)
        if self.args.get('n_dir_encode'):
            self.dir_chns = 3
            self.dir_encoder = InputEncoder.Get(self.args['n_dir_encode'], self.dir_chns)
        else:
            self.dir_chns = 0
            self.dir_encoder = None

    def _init_space(self):
        if 'space' not in self.args:
            self.space = Space(**self.args)
        elif self.args['space'] == 'octree':
            self.space = Octree(**self.args)
        elif self.args['space'] == 'voxels':
            self.space = Voxels(**self.args)
        else:
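            # Otherwise `space` is treated as a reference to another saved model,
            # whose (already built) space structure is reused.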
            self.space = model.load(self.args['space'])[0].space
        if self.args.get('n_featdim'):
            self.space.create_embedding(self.args['n_featdim'])

    def _new_core_unit(self):
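        # One NeRF MLP: encoded coordinates in, density out, with color optionally
        # conditioned on the encoded view direction through a smaller branch.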
        return NerfCore(coord_chns=self.pot_encoder.out_dim,
                        density_chns=self.chns('density'),
                        color_chns=self.chns('color'),
                        core_nf=self.args['fc_params']['nf'],
                        core_layers=self.args['fc_params']['n_layers'],
                        dir_chns=self.dir_encoder.out_dim if self.dir_encoder else 0,
                        dir_nf=self.args['fc_params']['nf'] // 2,
                        act=self.args['fc_params']['activation'],
                        skips=self.args['fc_params']['skips'])

    def _create_core(self, n_nets=1):
        return self._new_core_unit() if n_nets == 1 else nn.ModuleList([
            self._new_core_unit() for _ in range(n_nets)
        ])

    def _init_core(self):
        if not self.args.get("net_bounds"):
            self.core = self._create_core()
        else:
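            # `net_bounds` is expected to be a (K, 2, 3) tensor of per-core AABBs:
            # [i, 0] is the minimum corner and [i, 1] the maximum (see `infer`).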
            self.register_buffer("net_bounds", torch.tensor(self.args["net_bounds"]),
                                 persistent=False)
            self.cores = self._create_core(self.net_bounds.size(0))

    def render(self, samples: Samples, *outputs: str, **kwargs) -> Dict[str, torch.Tensor]:
        """
        Render colors, energies and other values (specified by `outputs`) of samples 
        (invalid items are filtered out)

        :param samples `Samples(N)`: samples
        :param outputs `str...`: which types of inferred data should be returned
        :return `Dict[str, Tensor(N, *)]`: outputs of cores
        """
        x = self.encode_x(samples)
        d = self.encode_d(samples)
        return self.infer(x, d, *outputs, pts=samples.pts, **kwargs)
    
    def infer(self, x: torch.Tensor, d: torch.Tensor, *outputs, pts: torch.Tensor,
              **kwargs) -> Dict[str, torch.Tensor]:
        """
        Infer colors, energies and other values (specified by `outputs`) for samples,
        given their encoded positions and directions; invalid samples are filtered out.

        :param x `Tensor(N, Ex)`: encoded positions
        :param d `Tensor(N, Ed)`: encoded directions
        :param outputs `str...`: which types of inferred data should be returned
        :param pts `Tensor(N, 3)`: raw sample positions
        :return `Dict[str, Tensor(N, *)]`: outputs of cores
        """
        if getattr(self, "core", None):
            return self.core(x, d, outputs)
        ret = {}
        for i, core in enumerate(self.cores):
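            # Route each sample to the core whose axis-aligned bounding box contains it.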
            selector = ((pts >= self.net_bounds[i, 0]) & (pts < self.net_bounds[i, 1])).all(-1)
            partial_ret = core(x[selector], d[selector], outputs)
            for key, value in partial_ret.items():
                if value is None:
                    ret[key] = None
                    continue
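                # Lazily allocate a full-size output buffer, then scatter this
                # core's partial results back into the rows selected above.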
                if key not in ret:
                    ret[key] = torch.zeros(*x.shape[:-1], value.shape[-1], device=x.device)
                ret[key] = masked_scatter(selector, value, ret[key])
        return ret

    def embed(self, samples: Samples) -> torch.Tensor:
        return self.space.extract_embedding(samples.pts, samples.voxel_indices)

    def encode_x(self, samples: Samples) -> torch.Tensor:
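        # Use learned per-voxel embeddings when enabled, raw positions otherwise.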
        x = self.embed(samples) if self.args.get('n_featdim') else samples.pts
        return self.pot_encoder(x)

    def encode_d(self, samples: Samples) -> torch.Tensor:
        return self.dir_encoder(samples.dirs) if self.dir_encoder is not None else None

    @torch.no_grad()
    def get_scores(self, sampled_points: torch.Tensor, sampled_voxel_indices: torch.Tensor) -> torch.Tensor:
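        # Query densities only, then map them to opacity-like scores in [0, 1)
        # via alpha = 1 - exp(-sigma); the space uses these scores for pruning.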
        densities = self.render(Samples(sampled_points, None, None, None, sampled_voxel_indices),
                                'density')['density']
        return 1 - (-densities).exp()

    @torch.no_grad()
    def pruning(self, threshold: float = 0.5, train_stats=False):
        return self.space.pruning(self.get_scores, threshold, train_stats)

    @torch.no_grad()
    def splitting(self):
        ret = self.space.splitting()
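        # Each voxel has split, so double the per-ray sample count and halve the
        # step sizes to keep the sampling rate matched to the finer grid.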
        if 'n_samples' in self.args0:
            self.args0['n_samples'] *= 2
        if 'voxel_size' in self.args0:
            self.args0['voxel_size'] /= 2
            if "sample_step_ratio" in self.args0:
                self.args1["sample_step"] = self.args0["voxel_size"] \
                    * self.args0["sample_step_ratio"]
        if 'sample_step' in self.args0:
            self.args0['sample_step'] /= 2
        self.sampler = self.SamplerClass(**self.args)
        return ret

    @torch.no_grad()
    def double_samples(self):
        pass

    @perf
    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, *,
                extra_outputs: List[str] = None, **kwargs) -> Dict[str, torch.Tensor]:
        """
        Perform rendering for given rays.

        :param rays_o `Tensor(N, 3)`: rays' origin
        :param rays_d `Tensor(N, 3)`: rays' direction
        :param extra_outputs `list[str]`: extra items to include in the rendering
                                          result, defaults to None
        :return `dict[str, Tensor]`: the rendering result, see corresponding Renderer implementation
        """
        args = {**self.args, **kwargs}
        with MemProfiler(f"{self.__class__}.forward: before sampling"):
            samples, rays_mask = self.sampler(rays_o, rays_d, self.space, **args)
        MemProfiler.print_memory_stats(f"{self.__class__}.forward: after sampling")
        with MemProfiler(f"{self.__class__}.forward: rendering"):
            if samples is None:
                return None
            return {
                **self.rendering(self, samples, extra_outputs or [], **args),
                'samples': samples,
                'rays_mask': rays_mask
            }
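
# A usage sketch (tensor shapes per the docstrings; variable names and the
# 'depth' extra output are illustrative):
#
#   rays_o, rays_d = ...                      # Tensor(N, 3) origins and directions
#   out = model(rays_o, rays_d, extra_outputs=['depth'])
#   if out is not None:                       # None means no ray yielded samples
#       samples, rays_mask = out['samples'], out['rays_mask']
#       # remaining keys depend on the Renderer implementation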