from utils.misc import dump_tensors_to_csv
from .__common__ import *
from .base import BaseModel
from typing import Callable

from .nerf import NeRF
from utils.voxels import trilinear_interp


class CNeRF(BaseModel):
    """Cascade of per-level `NeRF` sub-models.

    Sub-model ``i`` handles samples whose ``level`` is ``i``. For ``level > 0`` the
    sub-model also receives an ``"f"`` input: features computed by the sub-model of
    the level below, either inferred directly at the sample points or trilinearly
    interpolated from features evaluated on voxel corners (see `InterpSpace`).
    Sub-models with index below the current ``stage`` are frozen.
    """

    TrainerClass = "TrainMultiScale"

    class InterpSpace(object):
        """Per-corner features of a set of voxels, used for trilinear interpolation."""

        def __init__(self, space: Voxels, vidxs: Tensor, feats_fn: Callable[[Any], Tensor]) -> None:
            """
            :param space: voxel space whose voxel corners are used
            :param vidxs: indices (into ``space``) of the voxels whose corners are needed
            :param feats_fn: called once with the corners returned by ``space.get_corners``;
                its result is stored as the features on those corners
            """
            super().__init__()
            self.space = space
            self.corner_indices, self.corners = space.get_corners(vidxs)
            self.feats_on_corners = feats_fn(self.corners)

        @perf
        def interp(self, samples: Samples) -> Tensor:
            """Trilinearly interpolate the corner features at ``samples.pts``.

            Uses ``samples.interp_vidxs`` to select, for every sample, the voxel it
            lies in (indices into ``self.space``).
            """
            with perf("Prepare for coarse interpolation"):
                voxels = self.space.voxels[samples.interp_vidxs]
                cidxs = self.corner_indices[samples.interp_vidxs]  # (N, 8)
                feats_on_corners = self.feats_on_corners[cidxs]  # (N, 8, X)
                # (N, 3) normed-coords in voxel
                p = (samples.pts - voxels) / self.space.voxel_size + .5

            with perf("Interpolate features"):
                return trilinear_interp(p, feats_on_corners)

    @property
    def stage(self):
        """Current training stage (``self.args["stage"]``, default 0)."""
        return self.args.get("stage", 0)

    def __init__(self, args0: dict, args1: Optional[dict] = None):
        """
        :param args0: model configuration. Every key except "sub_models" and
            "interp_on_coarse" is used as a default for each sub-model's config.
        :param args1: forwarded unchanged to `BaseModel` and to every sub-model

        ``self.args["sub_models"]`` must be a list of per-level config dicts. Each
        entry is replaced in place by the shared defaults overlaid with that entry.
        """
        super().__init__(args0, args1)
        shared_args = {
            key: value for key, value in args0.items()
            if key not in ("sub_models", "interp_on_coarse")
        }
        sub_model_args = self.args["sub_models"]
        sub_models = []
        for i, level_args in enumerate(sub_model_args):
            # Per-level keys override the shared defaults; the merged dict is
            # written back so the stored config reflects what the sub-model got.
            sub_model_args[i] = {**shared_args, **level_args}
            sub_models.append(NeRF(sub_model_args[i], args1))
        self.sub_models = torch.nn.ModuleList(sub_models)
        for i in range(self.stage):
            print(f"__init__: freeze model {i}")
            self.model(i).freeze()

    def model(self, level: int) -> NeRF:
        """Return the sub-model for ``level``."""
        return self.sub_models[level]

    def trigger_stage(self, stage: int):
        """Advance training to ``stage``.

        Freezes sub-model ``stage - 1``, gives sub-model ``stage`` a clone of its
        space, and records ``stage`` in ``args0``. Stage 0 has no coarser sub-model,
        so only the stage is recorded.

        :raises IndexError: if ``stage`` is not a valid sub-model index
        """
        num_models = len(self.sub_models)
        if not 0 <= stage < num_models:
            # A negative index would silently wrap around to the last sub-model.
            raise IndexError(f"trigger_stage: stage {stage} out of range [0, {num_models})")
        if stage > 0:
            print(f"trigger_stage: freeze model {stage - 1}")
            self.model(stage - 1).freeze()
            self.model(stage).space = self.model(stage - 1).space.clone()
        self.args0["stage"] = stage

    @perf
    def infer(self, *outputs: str, samples: Samples, inputs: Optional[NetInput] = None,
              **kwargs) -> NetOutput:
        """Run the sub-model of ``samples.level``.

        ``inputs`` are built from ``samples`` when not given.
        """
        if inputs is None:
            inputs = self.input(samples)
        return self.model(samples.level).infer(*outputs, samples=samples, inputs=inputs, **kwargs)

    def print_config(self):
        """Print the configuration of every sub-model."""
        for i, model in enumerate(self.sub_models):
            print(f"Model {i} =====>")
            model.print_config()

    @torch.no_grad()
    def split(self):
        """Delegate ``split`` to the sub-model of the current stage."""
        return self.model(self.stage).split()

    def _input(self, samples: Samples, what: str) -> Optional[Tensor]:
        """Build the network input named ``what`` for ``samples``.

        ``"f"`` is the feature input from the level below:
          - ``None`` at level 0;
          - otherwise, interpolated through ``samples.interp_space`` if one was prepared;
          - otherwise, inferred directly at ``samples.pts`` by the level-below sub-model.

        Any other input is delegated to the sub-model of ``samples.level``.
        """
        if what != "f":
            return self.model(samples.level)._input(samples, what)
        if samples.level == 0:
            return None
        if samples.interp_space is None:
            return self._infer_features(pts=samples.pts, level=samples.level - 1)
        return samples.interp_space.interp(samples)

    @perf
    def _sample(self, data: InputData, **extra_args) -> Samples:
        """Sample with the sub-model selected by ``data["level"]`` and tag samples with that level."""
        level = data["level"]
        samples: Samples = self.model(level)._sample(data, **extra_args)
        samples.level = level
        return samples

    @perf
    def _render(self, samples: Samples, *outputs: str, **extra_args) -> ReturnData:
        """Prepare feature interpolation, then render with the sub-model of ``samples.level``.

        The sub-model's ``args`` are passed to its renderer, overridden by ``extra_args``.
        """
        self._prepare_interp(samples, on_coarse=self.args.get("interp_on_coarse"))
        model = self.model(samples.level)
        return model.renderer(self, samples, *outputs, **{**model.args, **extra_args})

    def _infer_features(self, samples: Optional[Samples] = None, **sample_data) -> Tensor:
        """Infer the "features" output for ``samples``.

        :param samples: samples to evaluate; when omitted, built as ``Samples(**sample_data)``
        :return: the "features" entry of the sub-model output
        """
        if samples is None:
            samples = Samples(**sample_data)
        if self.args.get("interp_on_coarse"):
            self._prepare_interp(samples, on_coarse=True)
        inputs = self.input(samples, "x", "f")
        return self.infer("features", samples=samples, inputs=inputs)["features"]

    def _prepare_interp(self, samples: Samples, on_coarse: bool):
        """Attach an `InterpSpace` to ``samples`` for interpolating level-below features.

        No-op at level 0. With ``on_coarse``, interpolation uses the voxels of the
        level-below sub-model that contain ``samples.pts``. Otherwise it uses the
        current level's voxels (``samples.voxel_indices``). In both cases the corner
        features are inferred by the level-below sub-model.
        """
        if samples.level == 0:
            return
        if on_coarse:
            interp_space = self.model(samples.level - 1).space
            samples.interp_vidxs = interp_space.get_voxel_indices(samples.pts)
        else:
            interp_space = self.model(samples.level).space
            samples.interp_vidxs = samples.voxel_indices
        samples.interp_space = CNeRF.InterpSpace(interp_space, samples.interp_vidxs,
                                                 lambda corners: self._infer_features(
                                                     pts=corners, level=samples.level - 1))

    def _after_load_state_dict(self) -> None:
        """Post-load hook; CNeRF performs no extra processing after loading weights."""