import sys
from collections import defaultdict
from statistics import mean
from tqdm import tqdm
from tensorboardX import SummaryWriter
import torch
from torch.optim import Optimizer, Adam
from torch.optim.lr_scheduler import _LRScheduler, ExponentialLR

from utils import netio, logging, loss, misc
from utils.args import BaseArgs
from utils.profile import Profiler, enable_profile, profile
from utils.types import *
from data import Dataset, RaysLoader
from model import Model


trainer_classes: dict[str, Type["Trainer"]] = {}


class Trainer:
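    """Base class for training a `Model` on a ray dataset.

    Concrete trainer classes are looked up by name in the `trainer_classes`
    registry via `get_class` and `create`.
    """
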
    class Args(BaseArgs):
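        """Training hyper-parameters.

        `profile_iters`, `max_iters` and `max_epochs` select the training mode
        (profile, iters or epochs, in that order of precedence); entries of
        `loss` must be keys of `Trainer.loss_defs`.
        """
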
        max_iters: int | None
        max_epochs: int = 20
        checkpoint_interval: int | None
        batch_size: int = 4096
        loss: list[str] = ["Color_L2", "CoarseColor_L2"]
        lr: float = 5e-4
        lr_decay: float | None
        profile_iters: int | None
        trainset: str

    args: Args
    states: dict[str, Any]
    optimizer: Optimizer
    scheduler: _LRScheduler | None
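    # Maps each selectable loss name (see `Args.loss`) to its loss function and
    # the model outputs it requires from the forward pass.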
    loss_defs = {
        "Color_L2": {
            "fn": lambda out, gt: loss.mse_loss(out["color"], gt["color"]),
            "required_outputs": ["color"]
        },
        "CoarseColor_L2": {
            "fn": lambda out, gt: loss.mse_loss(out["coarse_color"], gt["color"]),
            "required_outputs": ["coarse_color"]
        },
        "Density_Reg": {
            "fn": lambda out, gt: loss.cauchy_loss(out["densities"], 1e4) * 1e-4,
            "required_outputs": ["densities"]
        }
    }

    @property
    def profile_mode(self) -> bool:
        return self.profile_iters is not None

    @staticmethod
    def get_class(typename: str) -> Type["Trainer"] | None:
        return trainer_classes.get(typename)

    def __init__(self, model: Model, run_dir: Path, args: Args | None = None) -> None:
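        """Set up the trainer in one of three modes:

        - profile mode (`profile_iters` set): run a fixed number of warm-up and
          profiling iterations, print the profiler report and exit;
        - iters mode (`max_iters` set): train for a fixed number of iterations,
          checkpointing every `checkpoint_interval` iterations (default 10000);
        - epochs mode (otherwise): train for `max_epochs` epochs, checkpointing
          every epoch and keeping every `checkpoint_interval`-th one (default 10).
        """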
        self.model = model.train()
        self.run_dir = run_dir
        self.args = args or self.__class__.Args()
        self.epoch = 0
        self.iters = 0

        self.profile_warmup_iters = 10
        self.profile_iters = self.args.profile_iters
        if self.profile_mode:  # profile mode
            self.max_iters = self.profile_warmup_iters + self.profile_iters
            self.max_epochs = None
            self.checkpoint_interval = None
        elif self.args.max_iters:  # iters mode
            self.max_iters = self.args.max_iters
            self.max_epochs = None
            self.checkpoint_interval = self.args.checkpoint_interval or 10000
        else:  # epochs mode
            self.max_iters = None
            self.max_epochs = self.args.max_epochs
            self.checkpoint_interval = self.args.checkpoint_interval or 10

        self._init_optimizer()
        self._init_scheduler()

        self.required_outputs = []
        for key in self.args.loss:
            self.required_outputs += self.loss_defs[key]["required_outputs"]
        self.required_outputs = list(set(self.required_outputs))

        if self.profile_mode:  # Enable performance measurement in profile mode
            def handle_profile_result(result: Profiler.ProfileResult):
                print(result.get_report())
                sys.exit()
            enable_profile(self.profile_warmup_iters, self.profile_iters, handle_profile_result)
        else:  # Enable logging (Tensorboard & txt) in normal mode
            tb_log_dir = self.run_dir / "_log"
            tb_log_dir.mkdir(exist_ok=True)
            self.tb_writer = SummaryWriter(tb_log_dir, purge_step=0)
            logging.initialize(self.run_dir / "train.log")

        logging.print_and_log(f"Model arguments: {self.model.args}")
        logging.print_and_log(f"Trainer arguments: {self.args}")

        # Debug: print model structure
        print(model)

    def state_dict(self) -> dict[str, Any]:
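        """Bundle the model, optimizer, scheduler and progress counters for checkpointing."""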
        return {
            "model": self.model.state_dict(),
            "epoch": self.epoch,
            "iters": self.iters,
            "optimizer": self.optimizer.state_dict(),
            "scheduler": self.scheduler.state_dict() if self.scheduler else None
        }

    def load_state_dict(self, state_dict: dict[str, Any]):
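        """Restore trainer state from a checkpoint; missing entries are left unchanged."""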
        self.epoch = state_dict.get("epoch", self.epoch)
        self.iters = state_dict.get("iters", self.iters)
        if "model" in state_dict:
            self.model.load_state_dict(state_dict["model"])
        if "optimizer" in state_dict:
            self.optimizer.load_state_dict(state_dict["optimizer"])
        if self.scheduler and "scheduler" in state_dict:
            self.scheduler.load_state_dict(state_dict["scheduler"])

    def reset_optimizer(self):
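        """Re-create the optimizer while keeping the scheduler state, so the
        learning-rate schedule continues from its current point."""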
        self._init_optimizer()
        if self.scheduler is not None:
            scheduler_state = self.scheduler.state_dict()
            self._init_scheduler()
            self.scheduler.load_state_dict(scheduler_state)

    def train(self, dataset: Dataset):
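        """Train on `dataset` until `max_iters` (iters/profile mode) or
        `max_epochs` (epochs mode) is reached, saving checkpoints along the way."""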
        self.rays_loader = RaysLoader(dataset, self.args.batch_size, shuffle=True,
                                      device=self.model.device)
        self.forward_chunk_size = self.args.batch_size

        if self.max_iters:
            print(f"Begin training... Max iters: {self.max_iters}")
            self.progress = tqdm(total=self.max_iters, dynamic_ncols=True)
            self.rays_iter = iter(self.rays_loader)
            while self.iters < self.max_iters:
                iters_to_run = self.max_iters - self.iters
                if self.checkpoint_interval:
                    iters_to_run = min(iters_to_run, self.checkpoint_interval)
                self._train_iters(iters_to_run)
                self._save_checkpoint()
        else:
            print(f"Begin training... Max epochs: {self.max_epochs}")
            while self.epoch < self.max_epochs:
                self._train_epoch()
                self._save_checkpoint()

        print("Train finished")

    @staticmethod
    def create(model: Model, run_dir: PathLike, typename: str, args: dict[str, Any] | None = None) -> "Trainer":
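        """Instantiate the trainer class registered under `typename`."""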
        if typename not in trainer_classes:
            raise ValueError(f"Class {typename} is not found")
        return trainer_classes[typename](model, Path(run_dir), args)

    def _init_scheduler(self):
        self.scheduler = ExponentialLR(self.optimizer, self.args.lr_decay) if self.args.lr_decay else None

    def _init_optimizer(self):
        self.optimizer = Adam(self.model.parameters(), lr=self.args.lr)

    def _save_checkpoint(self):
        if self.checkpoint_interval is None:
            return

        ckpt = {
            "args": {
                "model": self.model.__class__.__name__,
                "model_args": vars(self.model.args),
                "trainer": self.__class__.__name__,
                "trainer_args": vars(self.args)
            },
            "states": self.state_dict()
        }

        if self.max_iters:
            # For iters mode, a checkpoint will be saved every `checkpoint_interval` iterations
            netio.save_checkpoint(ckpt, self.run_dir, self.iters)
        else:
            # For epochs mode, a checkpoint will be saved every epoch.
            # Checkpoints which don't match `checkpoint_interval` will be cleaned later
            netio.clean_checkpoint(self.run_dir, self.checkpoint_interval)
            netio.save_checkpoint(ckpt, self.run_dir, self.epoch)

    def _update_progress(self, loss_val: float = 0):
        self.progress.set_postfix_str(f"Loss: {loss_val:.2e}" if loss_val > 0 else "")
        self.progress.update()

    @profile("Forward")
    def _forward(self, rays: Rays) -> ReturnData:
        return self.model(rays, *self.required_outputs)

    @profile("Compute Loss")
    def _compute_loss(self, rays: Rays, out: ReturnData) -> dict[str, torch.Tensor]:
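        """Evaluate every configured loss term against the (optionally filtered)
        ground-truth rays. Terms whose required outputs are missing from `out`
        are silently skipped."""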
        gt = rays.select(out["rays_filter"]) if "rays_filter" in out else rays
        loss_terms: dict[str, torch.Tensor] = {}
        for key in self.args.loss:
            try:
                loss_terms[key] = self.loss_defs[key]["fn"](out, gt)
            except KeyError:
                pass
        # Debug: print loss terms
        #self.progress.write(",".join([f"{key}: {value.item():.2e}" for key, value in loss_terms.items()]))
        return loss_terms

    @profile("Train iteration")
    def _train_iter(self, rays: Rays) -> float:
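        """Run a single optimization step on a batch of rays.

        The batch is forwarded in chunks of `forward_chunk_size`, accumulating
        gradients across chunks. On a CUDA out-of-memory error the chunk size is
        halved and the whole iteration is retried. Returns the batch's overall
        loss, averaged over chunks.
        """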
        try:
            self.optimizer.zero_grad(set_to_none=True)
            loss_terms = defaultdict(list)
            for offset in range(0, rays.shape[0], self.forward_chunk_size):
                rays_chunk = rays.select(slice(offset, offset + self.forward_chunk_size))
                out_chunk = self._forward(rays_chunk)
                loss_chunk = self._compute_loss(rays_chunk, out_chunk)
                loss_value = sum(loss_chunk.values())
                with profile("Backward"):
                    loss_value.backward()
                loss_terms["Overall_Loss"].append(loss_value.item())
                for key, value in loss_chunk.items():
                    loss_terms[key].append(value.item())
            loss_terms = {key: mean(value) for key, value in loss_terms.items()}
            
            with profile("Update"):
                self.optimizer.step()
                if self.scheduler:
                    self.scheduler.step()
                    # Debug: print lr
                    #self.progress.write(f"Learning rate: {self.optimizer.param_groups[0]['lr']}")
                self.iters += 1

            if hasattr(self, "tb_writer"):
                for key, value in loss_terms.items():
                    self.tb_writer.add_scalar(f"Loss/{key}", value, self.iters)

            return loss_terms["Overall_Loss"]
        except RuntimeError as e:
            if not str(e).startswith("CUDA out of memory"):
                raise
        self.progress.write("CUDA out of memory, half forward batch and retry.")
        logging.warning("CUDA out of memory, half forward batch and retry.")
        self.forward_chunk_size = max(1, self.forward_chunk_size // 2)
        torch.cuda.empty_cache()
        return self._train_iter(rays)

    def _train_iters(self, iters: int):
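        """Run `iters` training iterations, restarting the rays iterator at the
        end of each epoch, then log the average iteration time and loss."""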
        recent_loss_list = []
        tot_loss = 0
        train_iters_node = Profiler.Node("Train Iterations")
        for _ in range(iters):
            try:
                rays = next(self.rays_iter)
            except StopIteration:
                self.rays_iter = iter(self.rays_loader)  # A new epoch
                rays = next(self.rays_iter)
            loss_val = self._train_iter(rays)
            recent_loss_list = (recent_loss_list + [loss_val])[-50:]  # Keep recent 50 iterations
            recent_avg_loss = sum(recent_loss_list) / len(recent_loss_list)
            tot_loss += loss_val
            self._update_progress(recent_avg_loss)
        train_iters_node.close()
        torch.cuda.synchronize()
        avg_time = train_iters_node.device_duration / 1000 / iters
        avg_loss = tot_loss / iters
        state_str = f"Iter {self.iters}: Avg. {misc.format_time(avg_time)}/iter; Loss: {avg_loss:.2e}"
        self.progress.write(state_str)
        logging.info(state_str)

    def _train_epoch(self):
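        """Run one full pass over the rays loader, then log the epoch duration
        and the average per-iteration time and loss."""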
        iters_per_epoch = len(self.rays_loader)
        recent_loss_list = []
        tot_loss = 0

        self.progress = tqdm(total=iters_per_epoch, desc=f"Epoch {self.epoch + 1:<3d}", leave=False,
                             dynamic_ncols=True)
        train_epoch_node = Profiler.Node("Train Epoch")
        for rays in self.rays_loader:
            with profile("Train iteration"):
                loss_val = self._train_iter(rays)
            recent_loss_list = (recent_loss_list + [loss_val])[-50:]
            recent_avg_loss = sum(recent_loss_list) / len(recent_loss_list)
            tot_loss += loss_val
            self._update_progress(recent_avg_loss)
        self.progress.close()
        train_epoch_node.close()
        torch.cuda.synchronize()
        self.epoch += 1
        epoch_time = train_epoch_node.device_duration / 1000
        avg_time = epoch_time / iters_per_epoch
        avg_loss = tot_loss / iters_per_epoch
        state_str = f"Epoch {self.epoch} spent {misc.format_time(epoch_time)} "\
            f"(Avg. {misc.format_time(avg_time)}/iter). Loss is {avg_loss:.2e}."
        logging.print_and_log(state_str)