import functools
import os
import time
from itertools import groupby
from typing import TypeVar

from .types import *


class Profiler(object):
    """Hierarchical host/device timing profiler organised into frames.

    Nodes are opened with :meth:`enter_node` and closed with
    :meth:`leave_node`. Closing the outermost (root) node ends a *frame*.
    The first ``warmup_frames`` frames are discarded; later frames are stored
    in :attr:`frames` as ResultNode trees. Once
    ``warmup_frames + record_frames`` frames have completed, the profiler
    disables itself and, if a ``then`` callback was given, calls it with the
    result of :meth:`get_result`.
    """

    # One ResultNode tree per recorded frame, in recording order.
    frames: list["Profiler.ResultNode"]

    class Node:
        """A timing node measuring host wall-clock time and CUDA event time.

        Timing starts when the node is constructed and stops on
        :meth:`close`. Durations are reported in milliseconds and cached
        after the first read.
        """

        @property
        def host_duration(self) -> float:
            """Host wall-clock time between creation and close, in ms.

            Raises:
                RuntimeError: if the node has not been closed.
            """
            if hasattr(self, "_host_duration"):
                return self._host_duration
            if not self.closed:
                raise RuntimeError("Cannot get host duration of an unclosed node")
            # perf_counter() is in seconds; convert to milliseconds.
            self._host_duration = (self.host_end_time - self.host_start_time) * 1000
            return self._host_duration

        @property
        def device_duration(self) -> float:
            """Elapsed time between the start and end CUDA events, in ms.

            Both events must have completed on the device before this is
            read; ``Profiler.leave_node`` synchronizes the device first.

            Raises:
                RuntimeError: if the node has not been closed.
            """
            if hasattr(self, "_device_duration"):
                return self._device_duration
            if not self.closed:
                raise RuntimeError("Cannot get device duration of an unclosed node")
            self._device_duration = self.device_start_event.elapsed_time(
                self.device_end_event)
            return self._device_duration

        def __init__(self, name, parent: "Profiler.Node" = None) -> None:
            self.name = name
            self.parent = parent
            self.child_nodes: list[Profiler.Node] = []
            # enable_timing=True so elapsed_time() can be queried later.
            self.device_start_event = torch.cuda.Event(True)
            self.device_start_event.record()
            self.host_start_time = time.perf_counter()
            self.closed = False

        def add_child(self, name):
            """Create, attach and return a new open child node called *name*.

            Raises:
                RuntimeError: if this node is already closed.
            """
            if self.closed:
                raise RuntimeError("Cannot add child to a closed node")
            child = Profiler.Node(name, self)
            self.child_nodes.append(child)
            return child

        def close(self):
            """Stop timing this node and return its parent (None for a root).

            Raises:
                RuntimeError: if the node was already closed.
            """
            if self.closed:
                raise RuntimeError("The node has been closed")
            self.closed = True
            self.host_end_time = time.perf_counter()
            self.device_end_event = torch.cuda.Event(True)
            self.device_end_event.record()
            return self.parent

        def get_result_node(self, parent_path: str = "") -> "Profiler.ResultNode":
            """Convert this closed subtree into a ResultNode tree.

            Each node's path is ``parent_path + "/" + name``, e.g. "/root/child".

            Raises:
                RuntimeError: if this node is not closed.
            """
            if not self.closed:
                raise RuntimeError("Cannot get result of an unclosed node")
            path = f"{parent_path}/{self.name}"
            return Profiler.ResultNode(
                path, [child.get_result_node(parent_path=path) for child in self.child_nodes],
                host_duration=self.host_duration, device_duration=self.device_duration
            )

        def __repr__(self) -> str:
            ret = f"{self.__class__.__name__} \"{self.name}\" "
            if self.closed:
                ret += f"[host spent {self.host_duration:.2f}ms, device spent {self.device_duration:.2f}ms]"
            else:
                ret += "[Not closed]"
            return ret

    class ResultNode(object):
        """Immutable snapshot of a closed Node: its path, children and timings."""

        # Slash-separated path from the root, e.g. "/root/child".
        path: str
        # Named measurements, e.g. "host_duration" / "device_duration" in ms.
        data: dict[str, float]

        @property
        def name(self) -> str:
            """Last component of :attr:`path`."""
            return os.path.split(self.path)[1]

        def __init__(self, path: str, child_nodes: list["Profiler.ResultNode"],
                     **data: float) -> None:
            self.path = path
            self.child_nodes = child_nodes
            self.data = data

        def value(self, key: str) -> float:
            """Return ``data[key]``.

            A ``"self_"`` prefix (e.g. ``"self_device_duration"``) returns the
            value minus the sum of the children's values for the same key,
            i.e. the part not attributed to any child node.
            """
            if key.startswith("self_"):
                key = key[5:]
                return self.data[key] - sum([child.value(key) for child in self.child_nodes])
            return self.data[key]

        def flatten(self) -> list["Profiler.ResultNode"]:
            """Return this node and all descendants in depth-first pre-order."""
            result_list = [self]
            for child in self.child_nodes:
                result_list += child.flatten()
            return result_list

    class ResultNodeGroup(object):
        """A non-empty run of ResultNodes (from one frame) sharing one path."""

        @property
        def path(self) -> str:
            return self.nodes[0].path

        @property
        def name(self) -> str:
            return self.nodes[0].name

        def __init__(self, nodes: Iterable["Profiler.ResultNode"]) -> None:
            self.nodes = list(nodes)
            self.count = len(self.nodes)

        def total(self, key: str):
            """Sum of ``node.value(key)`` over the group."""
            return sum([node.value(key) for node in self.nodes])

        def average(self, key: str):
            """Mean of ``node.value(key)`` over the group."""
            return self.total(key) / self.count

    class ResultNodeGroupList(list["Profiler.ResultNodeGroup"]):
        """Per-frame ResultNodeGroups for one path, indexed by frame.

        An entry is ``None`` when the path did not occur in that frame.
        """

        def __init__(self, path: str, __iterable: Iterable["Profiler.ResultNodeGroup"]) -> None:
            super().__init__(__iterable)
            self.path = path

        def total_in_frame(self, key: str) -> list[float]:
            """Per-frame group totals (``None`` for frames without the path)."""
            return [node_group and node_group.total(key) for node_group in self]

        def average_in_frame(self, key: str) -> list[float]:
            """Per-frame group averages (``None`` for frames without the path)."""
            return [node_group and node_group.average(key) for node_group in self]

        def count_in_frame(self) -> list[int]:
            """Per-frame node counts (``None`` for frames without the path)."""
            return [node_group and node_group.count for node_group in self]

        def total(self, key: str) -> float:
            """Sum of *key* over all frames."""
            return sum(filter(None, self.total_in_frame(key)))

        # NOTE: this overrides list.count(value) with a zero-argument method.
        def count(self) -> int:
            """Total number of nodes over all frames."""
            return sum(filter(None, self.count_in_frame()))

        def average_by_frame(self, key: str) -> float:
            """Total of *key* divided by the number of frames containing the path."""
            n_frames = len(list(filter(None, self)))
            return self.total(key) / n_frames

        def average_by_node(self, key: str) -> float:
            """Total of *key* divided by the total number of nodes."""
            return self.total(key) / self.count()

    class ProfileResult(list["ResultNodeGroupList"]):
        """Ordered list of ResultNodeGroupLists, one per distinct node position."""

        def get_index_by_path(self, path: str, start: int = 0, end: int = None) -> int:
            """Index of the first entry in ``[start, end)`` with *path*, else -1.

            A negative *end* counts from the end of the list.
            """
            if end is None:
                end = len(self)
            elif end < 0:
                end = len(self) + end
            for i in range(start, end):
                if self[i].path == path:
                    return i
            return -1

        def get_report(self):
            """Render an indented tree of per-frame average device durations."""
            s = "Performance Report:\n"
            if len(self) == 0:
                s += "No available data.\n"
                return s
            for node_group_list in self:
                parts = node_group_list.path.split("/")
                s += f"{'  ' * (len(parts) - 2)}{parts[-1]}: "\
                    f"{node_group_list.average_by_frame('device_duration'):.2f}ms\n"
            return s

    def __init__(self, warmup_frames: int = 0, record_frames: int = 0,
                 then: Callable[["Profiler.ProfileResult"], Any] = None) -> None:
        super().__init__()
        self.root_node = None
        self.current_node = None
        self.frames = []
        self.warmup_frames = warmup_frames
        self.record_frames = record_frames
        self.frame_counter = 0
        self.enabled = True
        self.then_fn = then

    def enter_node(self, name):
        """Open a node called *name*; a node opened with none open starts a frame."""
        if not self.enabled:
            return
        if self.current_node is None:
            self.root_node = self.current_node = Profiler.Node(name)
        else:
            self.current_node = self.current_node.add_child(name)

    def leave_node(self):
        """Close the current node; closing the root node ends the frame.

        At the end of a frame the device is synchronized, the frame is stored
        if warm-up is over, and once enough frames have been seen the profiler
        disables itself and passes the result to the ``then`` callback.

        Raises:
            RuntimeError: if no node is currently open.
        """
        if not self.enabled:
            return
        if self.current_node is None:
            raise RuntimeError("leave_node called without a matching enter_node")
        self.current_node = self.current_node.close()
        if self.current_node is None:
            # Wait for all recorded CUDA events so device durations are readable.
            torch.cuda.synchronize()
            if self.frame_counter >= self.warmup_frames:
                self.frames.append(self.root_node.get_result_node())
            self.frame_counter += 1
            if self.frame_counter >= self.warmup_frames + self.record_frames:
                self.enabled = False
                # `then` is optional; don't call None.
                if self.then_fn is not None:
                    self.then_fn(self.get_result())

    def get_result(self) -> "Profiler.ProfileResult":
        """Merge the recorded frames into a frame-aligned ProfileResult.

        Every ResultNodeGroupList in the result has exactly one entry per
        recorded frame (``None`` where the path did not occur in that frame),
        so index ``i`` of any list always refers to frame ``i``.
        """
        flat_frames = [frame.flatten() for frame in self.frames]
        # groupby only merges *adjacent* nodes with the same path (e.g. the
        # same child entered several times in a row); a path that reappears
        # later in the same frame forms a separate group.
        grouped_frames = [
            [
                Profiler.ResultNodeGroup(node_iter)
                for _, node_iter in groupby(frame, lambda item: item.path)
            ] for frame in flat_frames
        ]
        profile_result = Profiler.ProfileResult()
        for frame_index, frame in enumerate(grouped_frames):
            # Entries before target_head have already received this frame's
            # slot; the search for the next group starts after them.
            target_head = 0
            for node_group in frame:
                matched_index = profile_result.get_index_by_path(node_group.path, target_head)
                if matched_index == -1:
                    profile_result.insert(target_head,
                                          Profiler.ResultNodeGroupList(node_group.path,
                                                                       [None] * frame_index + [node_group]))
                    target_head += 1
                else:
                    # Paths skipped over did not occur at this point of the frame.
                    for group_list in profile_result[target_head:matched_index]:
                        group_list.append(None)
                    profile_result[matched_index].append(node_group)
                    target_head = matched_index + 1
            # Paths after the last match are absent from this frame; pad them
            # so every list stays aligned to frame indices.
            for group_list in profile_result[target_head:]:
                group_list.append(None)
        return profile_result


# Module-wide profiler read by start_profile_node / end_profile_node;
# while it is None those helpers do nothing.
default_profiler = None


def enable_profile(warmup_frames: int = 0, record_frames: int = 0,
                   then: Callable[["Profiler.ProfileResult"], Any] = None):
    """Install a fresh Profiler as the module-wide default.

    Args:
        warmup_frames: number of leading frames to discard.
        record_frames: number of frames to record after warm-up.
        then: callback handed the ProfileResult once recording finishes.
    """
    global default_profiler
    new_profiler = Profiler(warmup_frames, record_frames, then)
    default_profiler = new_profiler


class _ProfileWrap(object):
    def __init__(self, fn: Callable = None, name: str = None) -> None:
        super().__init__()
        self.fn = fn
        self.name = name

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        if self.fn == None and len(args) == 1 and isinstance(args[0], Callable):
            self.fn = args[0]
            return lambda *args, **kwargs: self(*args, **kwargs)
        self.__enter__()
        ret = self.fn(*args, **kwargs)
        self.__exit__()
        return ret

    def __enter__(self):
        #print(f"Start node \"{self.name or self.fn.__qualname__}\"")
        start_profile_node(self.name or self.fn.__qualname__)
        return self

    def __exit__(self, *args: Any, **kwargs: Any):
        #print(f"End node \"{self.name or self.fn.__qualname__}\"")
        end_profile_node()


class _DebugProfileWrap(object):
    def __init__(self, fn: Callable = None, name: str = None) -> None:
        super().__init__()
        self.fn = fn
        self.name = name

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        if self.fn == None and len(args) == 1 and isinstance(args[0], Callable):
            self.fn = args[0]
            return lambda *args, **kwargs: self(*args, **kwargs)
        self.__enter__()
        ret = self.fn(*args, **kwargs)
        self.__exit__()
        return ret

    def __enter__(self):
        #print(f"Start node \"{self.name or self.fn.__qualname__}\"")
        self.node = Profiler.Node(self.name)
        return self

    def __exit__(self, *args: Any, **kwargs: Any):
        #print(f"End node \"{self.name or self.fn.__qualname__}\"")
        self.node.close()
        torch.cuda.synchronize()
        print(f"Node {self.name}: host duration {self.node.host_duration:.1f}ms, "
              f"device duration {self.node.device_duration:.1f}ms")


FnRet = TypeVar("FnRet")
def profile(arg: str | Callable[..., FnRet]) -> _ProfileWrap | Callable[..., FnRet]:
    if isinstance(arg, str):
        return _ProfileWrap(name=arg)
    else:
        return lambda *args, **kwargs: _ProfileWrap(fn=arg)(*args, **kwargs)


def debug_profile(arg: str | Callable):
    if isinstance(arg, str):
        return _DebugProfileWrap(name=arg)
    else:
        return lambda *args, **kwargs: _DebugProfileWrap(fn=arg)(*args, **kwargs)


def start_profile_node(name):
    """Open a node called *name* on the module-wide profiler, if one is set."""
    profiler = default_profiler
    if profiler is None:
        return
    profiler.enter_node(name)


def end_profile_node():
    """Close the current node of the module-wide profiler, if one is set."""
    profiler = default_profiler
    if profiler is None:
        return
    profiler.leave_node()


def get_profile_result():
    """Return the default profiler's merged ``Profiler.ProfileResult``.

    Returns None when no default profiler has been installed. Call
    ``.get_report()`` on the result for a printable summary.
    """
    if default_profiler is not None:
        # Profiler exposes get_result(); there is no get_result_report().
        return default_profiler.get_result()
    return None