from .__common__ import *
from .base import BaseModel
from operator import itemgetter

from utils import math
from utils.misc import masked_scatter, merge


class NeRF(BaseModel):
    """
    NeRF-style model built from a space, input encoders, one or more core networks,
    a sampler and a renderer.

    Subclasses may set ``SamplerClass`` / ``RendererClass`` to replace the sampler or
    renderer; when left as ``None`` the defaults chosen in ``_init_sampler`` /
    ``_init_renderer`` are used.
    """

    TrainerClass = "TrainWithSpace"
    SamplerClass = None
    RendererClass = None

    space: Union[Space, Voxels, Octree]

    @property
    def multi_nets(self) -> int:
        """Number of core networks (``args["multi_nets"]``, defaults to 1)."""
        return self.args.get("multi_nets", 1)

    def __init__(self, args0: dict, args1: dict = None):
        """
        Initialize a NeRF model

        :param args0 `dict`: basic arguments
        :param args1 `dict`: extra arguments, defaults to {}
        """
        super().__init__(args0, args1)

        # Initialize components. Order matters: the cores read the encoders'
        # ``out_dim`` in ``_create_core_unit``, so encoders must exist first.
        self._init_space()
        self._init_encoders()
        self._init_core()
        self._init_sampler()
        self._init_renderer()

    @perf
    def infer(self, *outputs: str, samples: Samples, inputs: NetInput = None, **kwargs) -> NetOutput:
        """
        Run the core network(s) on ``samples``.

        :param outputs `str...`: names of the outputs to compute, forwarded to the core(s)
        :param samples `Samples`: samples to evaluate
        :param inputs `NetInput`: precomputed network input; built with
                                  ``self.input(samples)`` when omitted
        :return `NetOutput`: output of the single core, or the merged output of all cores
        """
        # Explicit ``is None``: ``inputs or ...`` would also replace a provided input
        # object whose truth value is falsy (or ambiguous for tensor-like objects).
        if inputs is None:
            inputs = self.input(samples)
        if len(self.cores) == 1:
            return self.cores[0](inputs, *outputs, samples=samples, **kwargs)
        return self._multi_infer(inputs, *outputs, samples=samples, **kwargs)

    @torch.no_grad()
    def split(self):
        """
        Split the space and refine the sampling arguments to match.

        Each adjustment only happens when the key is present in ``args0``:
        ``n_samples`` is doubled, ``voxel_size`` is halved (and, if
        ``sample_step_ratio`` is given, ``args1["sample_step"]`` is re-derived from the
        new voxel size), and ``sample_step`` is halved.

        :return: whatever ``self.space.split()`` returns
        """
        ret = self.space.split()
        if 'n_samples' in self.args0:
            self.args0['n_samples'] *= 2
        if 'voxel_size' in self.args0:
            self.args0['voxel_size'] /= 2
            if "sample_step_ratio" in self.args0:
                self.args1["sample_step"] = self.args0["voxel_size"] \
                    * self.args0["sample_step_ratio"]
        if 'sample_step' in self.args0:
            self.args0['sample_step'] /= 2
        return ret

    def _preprocess_args(self):
        """
        Derive extra arguments into ``args1`` (presumably invoked by
        ``BaseModel.__init__`` before components are built — TODO confirm).

        - ``sample_step`` = ``voxel_size * sample_step_ratio`` when a ratio is given.
        - For ``spherical`` models, ``bbox`` and ``sample_range`` are set; the first
          coordinate is inverse depth (``1 / depth_range``, or [1, 0] by default) and the
          other two are angles spanning [-180, 180] x [-90, 90] degrees, in radians.
        - ``net_bounds`` is registered (``None`` when absent) and ``multi_nets`` is set to
          the number of bound boxes, or 1.
        """
        if "sample_step_ratio" in self.args0:
            self.args1["sample_step"] = self.args0["voxel_size"] * self.args0["sample_step_ratio"]
        if self.args0.get("spherical"):
            sample_range = [1 / self.args0['depth_range'][0], 1 / self.args0['depth_range'][1]] \
                if self.args0.get('depth_range') else [1, 0]
            rot_range = [[-180, -90], [180, 90]]
            self.args1['bbox'] = [
                [sample_range[0], math.radians(rot_range[0][0]), math.radians(rot_range[0][1])],
                [sample_range[1], math.radians(rot_range[1][0]), math.radians(rot_range[1][1])]
            ]
            self.args1['sample_range'] = sample_range
        if not self.args.get("net_bounds"):
            self.register_temp("net_bounds", None)
            self.args1["multi_nets"] = 1
        else:
            self.register_temp("net_bounds", torch.tensor(self.args["net_bounds"]))
            self.args1["multi_nets"] = self.net_bounds.size(0)

    def _init_chns(self, **chns):
        """
        Declare input channel counts: ``x`` uses ``n_featdim`` (or 3 when unset) and
        ``d`` uses 3 when ``encode_d`` is set, else 0. ``chns`` overrides either.
        """
        super()._init_chns(**{
            "x": self.args.get('n_featdim') or 3,
            "d": 3 if self.args.get('encode_d') else 0,
            **chns
        })

    def _init_space(self):
        """Create the space from args and, if ``n_featdim`` is set, attach an embedding."""
        self.space = Space.create(self.args)
        if self.args.get('n_featdim'):
            self.space.create_embedding(self.args['n_featdim'])

    def _init_encoders(self):
        """Create the position encoder and, only when ``d`` has channels, the direction encoder."""
        self.x_encoder = InputEncoder(self.chns("x"), self.args['encode_x'], cat_input=True)
        self.d_encoder = InputEncoder(self.chns("d"), self.args['encode_d'])\
            if self.chns("d") > 0 else None

    def _init_core(self):
        """Create ``self.multi_nets`` core networks via ``_create_core_unit``."""
        self.cores = self.create_multiple(self._create_core_unit, self.multi_nets)

    def _init_sampler(self):
        """Instantiate ``SamplerClass`` (default ``Sampler``) with the model args."""
        SamplerClass = Sampler if self.SamplerClass is None else self.SamplerClass
        self.sampler = SamplerClass(**self.args)

    def _init_renderer(self):
        """
        Instantiate ``RendererClass`` with the model args. When unset, the
        ``nerfadv`` core gets ``DensityFirstVolumnRenderer`` and others ``VolumnRenderer``.
        """
        if self.RendererClass is not None:
            RendererClass = self.RendererClass
        elif self.args.get("core") == "nerfadv":
            RendererClass = DensityFirstVolumnRenderer
        else:
            RendererClass = VolumnRenderer
        self.renderer = RendererClass(**self.args)

    def _create_core_unit(self, core_params: dict = None, **args):
        """
        Build one core network.

        :param core_params `dict`: core parameters; falls back to ``args["core_params"]``
                                   when ``None`` or empty
        :param args: extra keyword arguments overriding everything else
        :return: a ``NerfAdvCore`` when ``args["core"] == "nerfadv"``, else a ``NerfCore``
        """
        core_params = core_params or self.args["core_params"]
        CoreClass = NerfAdvCore if self.args.get("core") == "nerfadv" else NerfCore
        return CoreClass(**{
            "x_chns": self.x_encoder.out_dim,
            # d_encoder is None when the model takes no direction input.
            "d_chns": self.d_encoder.out_dim if self.d_encoder else 0,
            "density_chns": self.chns('density'),
            "color_chns": self.chns('color'),
            **core_params,
            **args
        })

    @perf
    def _sample(self, data: InputData, **extra_args) -> Samples:
        """
        Sample along the rays ``data["rays_o"]`` / ``data["rays_d"]`` through
        ``self.space`` using the model args merged with ``extra_args``.
        """
        return self.sampler(*itemgetter("rays_o", "rays_d")(data), self.space,
                            **merge(self.args, extra_args))

    @perf
    def _render(self, samples: Samples, *outputs: str, **extra_args) -> ReturnData:
        """
        Produce ``outputs`` for ``samples``.

        When ``samples.size`` has a single dimension, the samples are evaluated directly
        with ``infer`` (the renderer is skipped and ``extra_args`` are not forwarded);
        otherwise the renderer is called with the model args merged with ``extra_args``.
        """
        if len(samples.size) == 1:
            return self.infer(*outputs, samples=samples)
        return self.renderer(self, samples, *outputs, **merge(self.args, extra_args))

    def _input(self, samples: Samples, what: str) -> Optional[torch.Tensor]:
        """
        Build the encoded network input named ``what`` from ``samples``.

        ``"x"``: encoded space embedding (when ``n_featdim`` is set) or encoded positions.
        ``"d"``: encoded directions, or ``None`` without a direction encoder/directions.
        ``"f"``: always ``None``.

        :raises ValueError: for any other ``what``
        """
        if what == "x":
            if self.args.get('n_featdim'):
                return self._encode("emb", self.space.extract_embedding(
                    samples.pts, samples.voxel_indices))
            else:
                return self._encode("x", samples.pts)
        elif what == "d":
            if self.d_encoder and samples.dirs is not None:
                return self._encode("d", samples.dirs)
            else:
                return None
        elif what == "f":
            return None
        else:
            raise ValueError(f"Don't know how to process input \"{what}\"")

    def _encode(self, what: str, val: torch.Tensor) -> torch.Tensor:
        """
        Encode ``val`` with the encoder matching ``what`` (``"x"``, ``"emb"`` or ``"d"``).

        :raises ValueError: for any other ``what``
        """
        if what == "x":
            if self.args.get("spherical"):
                sr = self.args['sample_range']
                # scale val.r: [sr[0], sr[1]] -> [-PI/2, PI/2]
                # clone so the caller's tensor is not modified in place
                val = val.clone()
                val[..., 0] = ((val[..., 0] - sr[0]) / (sr[1] - sr[0]) - .5) * math.pi
                return self.x_encoder(val)
            else:
                return self.x_encoder(val * math.pi)
        elif what == "emb":
            return self.x_encoder(val * math.pi)
        elif what == "d":
            return self.d_encoder(val, angular=True)
        else:
            raise ValueError(f"Don't know how to encode \"{what}\"")

    @perf
    def _multi_infer(self, inputs: NetInput, *outputs: str, samples: Samples, **kwargs) -> NetOutput:
        """
        Evaluate each core only on the samples inside its ``net_bounds`` box and merge
        the partial results into full-size outputs.

        A sample belongs to core ``i`` when every coordinate of ``samples.pts`` satisfies
        ``net_bounds[i, 0] <= p < net_bounds[i, 1]`` (assumes ``net_bounds`` is
        (multi_nets, 2, n_dims) — TODO confirm). Outputs are zero-initialized and filled
        per core with ``masked_scatter``.
        """
        ret: NetOutput = {}
        for i, core in enumerate(self.cores):
            # Element-wise ``&``: Python ``and`` on multi-element tensors raises
            # "Boolean value of Tensor ... is ambiguous".
            selector = ((samples.pts >= self.net_bounds[i, 0])
                        & (samples.pts < self.net_bounds[i, 1])).all(-1)
            partial_ret: NetOutput = core(inputs[selector], *outputs, samples=samples[selector],
                                          **kwargs)
            for key, value in partial_ret.items():
                if key not in ret:
                    ret[key] = value.new_zeros(*inputs.shape, value.shape[-1])
                ret[key] = masked_scatter(selector, value, ret[key])
        return ret


class NSVF(NeRF):

    SamplerClass = VoxelSampler