"""Hierarchical host/device profiler built on CUDA events.

Scopes are opened and closed through :func:`start_profile_node` /
:func:`end_profile_node` (or the :func:`profile` decorator / context manager).
Every time the outermost scope closes, one "frame" is recorded; frames are
later merged into a :class:`Profiler.ProfileResult`.
"""
import functools
import os
import time
from itertools import groupby
from typing import Any, Callable, Iterable, TypeVar

import torch

# Kept last on purpose: any name exported by ``.types`` still takes precedence
# over the explicit imports above, exactly as before they were added.
from .types import *  # noqa: F401,F403


class Profiler(object):
    """Collects nested timing scopes, frame by frame.

    A frame starts when the first scope is entered while no scope is open and
    ends when that root scope is left. The first ``warmup_frames`` frames are
    discarded; recording stops (and ``then`` is invoked, if given) once
    ``warmup_frames + record_frames`` frames have been seen.
    """

    frames: list["Profiler.ResultNode"]

    class Node:
        """A live timing scope measuring host time and device (CUDA) time."""

        def __init__(self, name, parent: "Profiler.Node | None" = None) -> None:
            """Open the scope: record the device start event, then the host clock."""
            self.name = name
            self.parent = parent
            self.child_nodes: list["Profiler.Node"] = []
            # First positional argument of torch.cuda.Event is enable_timing.
            self.device_start_event = torch.cuda.Event(True)
            self.device_start_event.record()
            self.host_start_time = time.perf_counter()
            self.closed = False

        @property
        def host_duration(self) -> float:
            """Host wall time of the scope in milliseconds (cached after first read).

            Raises:
                RuntimeError: if the node has not been closed yet.
            """
            if not hasattr(self, "_host_duration"):
                if not self.closed:
                    raise RuntimeError("Cannot get host duration of an unclosed node")
                # perf_counter() is in seconds; convert to milliseconds.
                self._host_duration = (self.host_end_time - self.host_start_time) * 1000
            return self._host_duration

        @property
        def device_duration(self) -> float:
            """Device time between the start and end events (cached after first read).

            NOTE(review): ``Event.elapsed_time`` needs both events to have
            completed; callers in this file call ``torch.cuda.synchronize()``
            before reading this property.

            Raises:
                RuntimeError: if the node has not been closed yet.
            """
            if not hasattr(self, "_device_duration"):
                if not self.closed:
                    raise RuntimeError("Cannot get device duration of an unclosed node")
                self._device_duration = self.device_start_event.elapsed_time(
                    self.device_end_event)
            return self._device_duration

        def add_child(self, name):
            """Open and return a child scope nested under this one.

            Raises:
                RuntimeError: if this node is already closed.
            """
            if self.closed:
                raise RuntimeError("Cannot add child to a closed node")
            child = Profiler.Node(name, self)
            self.child_nodes.append(child)
            return child

        def close(self):
            """Close the scope and return its parent (``None`` for a root node).

            Raises:
                RuntimeError: if the node was already closed.
            """
            if self.closed:
                raise RuntimeError("The node has been closed")
            self.closed = True
            self.host_end_time = time.perf_counter()
            self.device_end_event = torch.cuda.Event(True)
            self.device_end_event.record()
            return self.parent

        def get_result_node(self, parent_path: str = "") -> "Profiler.ResultNode":
            """Convert this closed subtree into immutable ``ResultNode`` objects.

            The node's path is ``f"{parent_path}/{name}"``, so a root node named
            ``"a"`` gets the path ``"/a"``.

            Raises:
                RuntimeError: if the node has not been closed yet.
            """
            if not self.closed:
                raise RuntimeError("Cannot get result of an unclosed node")
            path = f"{parent_path}/{self.name}"
            children = [child.get_result_node(parent_path=path)
                        for child in self.child_nodes]
            return Profiler.ResultNode(
                path,
                children,
                host_duration=self.host_duration,
                device_duration=self.device_duration,
            )

        def __repr__(self) -> str:
            # Fixed: the string was built but never returned, and the format
            # spec was ":2f" (width 2) instead of ".2f" (two decimals).
            ret = f"{self.__class__.__name__} \"{self.name}\" "
            if self.closed:
                ret += (f"[host spent {self.host_duration:.2f}ms, "
                        f"device spent {self.device_duration:.2f}ms]")
            else:
                ret += "[Not closed]"
            return ret

    class ResultNode(object):
        """Recorded measurements of one scope, plus its recorded children."""

        path: str
        data: dict[str, float]

        def __init__(self, path: str, child_nodes: list["Profiler.ResultNode"],
                     **data: float) -> None:
            self.path = path
            self.child_nodes = child_nodes
            self.data = data

        @property
        def name(self) -> str:
            """Last component of ``path``."""
            # Paths are always joined with "/" (see Node.get_result_node), so
            # split on "/" directly instead of the OS-dependent os.path.split.
            return self.path.rsplit("/", 1)[-1]

        def value(self, key: str) -> float:
            """Return the measurement stored under ``key``.

            A ``"self_"`` prefix (e.g. ``"self_device_duration"``) yields the
            exclusive value: this node's value minus the inclusive values of
            its direct children.
            """
            if key.startswith("self_"):
                key = key[5:]
                return self.data[key] - sum(child.value(key)
                                            for child in self.child_nodes)
            return self.data[key]

        def flatten(self) -> list["Profiler.ResultNode"]:
            """Return this node followed by all descendants, depth-first, pre-order."""
            result_list = [self]
            for child in self.child_nodes:
                result_list.extend(child.flatten())
            return result_list

    class ResultNodeGroup(object):
        """Result nodes of a single frame that share one path."""

        def __init__(self, nodes: Iterable["Profiler.ResultNode"]) -> None:
            self.nodes = list(nodes)
            self.count = len(self.nodes)

        @property
        def path(self) -> str:
            return self.nodes[0].path

        @property
        def name(self) -> str:
            return self.nodes[0].name

        def total(self, key: str):
            """Sum of ``value(key)`` over the nodes of the group."""
            return sum(node.value(key) for node in self.nodes)

        def average(self, key: str):
            """Mean of ``value(key)`` over the nodes of the group."""
            return self.total(key) / self.count

    class ResultNodeGroupList(list["Profiler.ResultNodeGroup"]):
        """Per-frame groups for one path; an entry is ``None`` when the path
        did not occur in that frame."""

        def __init__(self, path: str,
                     __iterable: Iterable["Profiler.ResultNodeGroup"]) -> None:
            super().__init__(__iterable)
            self.path = path

        def total_in_frame(self, key: str) -> list[float]:
            """Per-frame totals (``None`` for frames without this path)."""
            return [None if group is None else group.total(key) for group in self]

        def average_in_frame(self, key: str) -> list[float]:
            """Per-frame averages (``None`` for frames without this path)."""
            return [None if group is None else group.average(key) for group in self]

        def count_in_frame(self) -> list[int]:
            """Per-frame node counts (``None`` for frames without this path)."""
            return [None if group is None else group.count for group in self]

        def total(self, key: str) -> float:
            """Total of ``key`` over every frame in which the path occurred."""
            return sum(v for v in self.total_in_frame(key) if v is not None)

        def count(self) -> int:
            """Number of nodes over every frame in which the path occurred."""
            return sum(c for c in self.count_in_frame() if c is not None)

        def average_by_frame(self, key: str) -> float:
            """Total of ``key`` divided by the number of frames containing the path."""
            n_frames = sum(1 for group in self if group is not None)
            return self.total(key) / n_frames

        def average_by_node(self, key: str) -> float:
            """Total of ``key`` divided by the total number of nodes."""
            return self.total(key) / self.count()

    class ProfileResult(list["ResultNodeGroupList"]):
        """Ordered list of ``ResultNodeGroupList`` objects, one per path occurrence."""

        def get_index_by_path(self, path: str, start: int = 0,
                              end: int | None = None) -> int:
            """Return the first index in ``[start, end)`` whose entry has ``path``.

            ``end`` defaults to ``len(self)``; a negative ``end`` counts from
            the end. Returns ``-1`` when nothing matches.
            """
            if end is None:
                end = len(self)
            elif end < 0:
                end = len(self) + end
            for i in range(start, end):
                if self[i].path == path:
                    return i
            return -1

        def get_report(self):
            """Format the per-frame average device duration of every path.

            Each entry is indented by one space per nesting level below the root.
            """
            lines = ["Performance Report:\n"]
            if not self:
                lines.append("No available data.\n")
                return "".join(lines)
            for node_group_list in self:
                parts = node_group_list.path.split("/")
                average = node_group_list.average_by_frame('device_duration')
                lines.append(f"{' ' * (len(parts) - 2)}{parts[-1]}: {average:.2f}ms\n")
            return "".join(lines)

    def __init__(self, warmup_frames: int = 0, record_frames: int = 0,
                 then: "Callable[[Profiler.ProfileResult], Any] | None" = None) -> None:
        """
        Args:
            warmup_frames: number of leading frames to discard.
            record_frames: number of frames to record after the warm-up.
            then: optional callback that receives the merged result once
                recording has finished.
        """
        super().__init__()
        self.root_node = None
        self.current_node = None
        self.frames = []
        self.warmup_frames = warmup_frames
        self.record_frames = record_frames
        self.frame_counter = 0
        self.enabled = True
        self.then_fn = then

    def enter_node(self, name):
        """Open a scope named ``name`` under the current one (or as a new root)."""
        if not self.enabled:
            return
        if self.current_node is None:
            self.root_node = self.current_node = Profiler.Node(name)
        else:
            self.current_node = self.current_node.add_child(name)

    def leave_node(self):
        """Close the current scope; when the root closes, finish the frame."""
        if not self.enabled:
            return
        if self.current_node is None:
            # Unmatched leave, e.g. profiling was enabled while a profiled
            # scope was already running. There is nothing to close.
            return
        self.current_node = self.current_node.close()
        if self.current_node is not None:
            return

        # Root closed: wait for the device so the event timings can be read.
        torch.cuda.synchronize()
        if self.frame_counter >= self.warmup_frames:
            self.frames.append(self.root_node.get_result_node())
        self.frame_counter += 1
        if self.frame_counter >= self.warmup_frames + self.record_frames:
            self.enabled = False
            # Fixed: ``then`` defaults to None and must not be called blindly.
            if self.then_fn is not None:
                self.then_fn(self.get_result())

    def get_result(self) -> "Profiler.ProfileResult":
        """Merge all recorded frames into one ``ProfileResult``.

        Within a frame, *consecutive* nodes (in depth-first order) that share a
        path form one group; non-consecutive repeats stay separate entries.
        Frames are aligned by path in order of appearance, and every entry
        holds exactly one item per frame (``None`` where the path is absent).
        """
        if not self.frames:
            return Profiler.ProfileResult()

        grouped_frames = [
            [
                Profiler.ResultNodeGroup(node_iter)
                for _, node_iter in groupby(frame.flatten(), lambda item: item.path)
            ]
            for frame in self.frames
        ]

        profile_result = Profiler.ProfileResult(
            Profiler.ResultNodeGroupList(node_group.path, [node_group])
            for node_group in grouped_frames[0]
        )
        for frame_index, frame in enumerate(grouped_frames[1:], start=1):
            target_head = 0
            for node_group in frame:
                matched_index = profile_result.get_index_by_path(
                    node_group.path, target_head)
                if matched_index == -1:
                    # Unseen path: insert at the cursor, back-filled with None
                    # for all earlier frames.
                    profile_result.insert(
                        target_head,
                        Profiler.ResultNodeGroupList(
                            node_group.path, [None] * frame_index + [node_group]))
                    target_head += 1
                else:
                    profile_result[matched_index].append(node_group)
                    target_head = matched_index + 1
            # Fixed: pad every entry this frame did not touch, including the
            # ones after the last match, so index i always means frame i.
            for node_group_list in profile_result:
                while len(node_group_list) <= frame_index:
                    node_group_list.append(None)
        return profile_result

    def get_result_report(self):
        """Return the textual report of the merged result."""
        return self.get_result().get_report()


default_profiler: "Profiler | None" = None


def enable_profile(warmup_frames: int = 0, record_frames: int = 0,
                   then: "Callable[[Profiler.ProfileResult], Any] | None" = None):
    """Install a fresh module-wide profiler (replacing any previous one)."""
    global default_profiler
    default_profiler = Profiler(warmup_frames, record_frames, then)


class _ProfileWrap(object):
    """Decorator / context manager that reports a scope to the default profiler."""

    def __init__(self, fn: Callable = None, name: str = None) -> None:
        super().__init__()
        self.fn = fn
        self.name = name

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Decorator-with-name usage: ``@profile("name")`` calls the wrap with
        # the function to decorate as its only positional argument.
        if self.fn is None and len(args) == 1 and callable(args[0]):
            fn = args[0]
            self.fn = fn

            @functools.wraps(fn)
            def wrapper(*call_args, **call_kwargs):
                return self(*call_args, **call_kwargs)
            return wrapper

        if self.fn is None:
            raise TypeError("No function to profile: 'NoneType' object is not callable")
        self.__enter__()
        try:
            return self.fn(*args, **kwargs)
        finally:
            # Always close the scope so an exception cannot leave it open.
            self.__exit__()

    def __enter__(self):
        start_profile_node(self.name or self.fn.__qualname__)
        return self

    def __exit__(self, *args: Any, **kwargs: Any):
        end_profile_node()


class _DebugProfileWrap(object):
    """Decorator / context manager that times a scope and prints the result
    immediately, independent of the default profiler."""

    def __init__(self, fn: Callable = None, name: str = None) -> None:
        super().__init__()
        self.fn = fn
        self.name = name
        # Stack of open nodes so nested / recursive use of one instance works.
        self._open_nodes: list = []

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        if self.fn is None and len(args) == 1 and callable(args[0]):
            fn = args[0]
            self.fn = fn

            @functools.wraps(fn)
            def wrapper(*call_args, **call_kwargs):
                return self(*call_args, **call_kwargs)
            return wrapper

        if self.fn is None:
            raise TypeError("No function to profile: 'NoneType' object is not callable")
        self.__enter__()
        try:
            return self.fn(*args, **kwargs)
        finally:
            self.__exit__()

    def __enter__(self):
        # Fall back to the function's qualified name, as _ProfileWrap does,
        # instead of labelling the node "None".
        label = self.name or getattr(self.fn, "__qualname__", None)
        self.node = Profiler.Node(label)
        self._open_nodes.append(self.node)
        return self

    def __exit__(self, *args: Any, **kwargs: Any):
        node = self._open_nodes.pop()
        node.close()
        # Wait for the device so the end event has completed before reading it.
        torch.cuda.synchronize()
        print(f"Node {node.name}: host duration {node.host_duration:.1f}ms, "
              f"device duration {node.device_duration:.1f}ms")


FnRet = TypeVar("FnRet")


def profile(arg: str | Callable[..., FnRet]) -> _ProfileWrap | Callable[..., FnRet]:
    """Profile a scope with the default profiler.

    * ``profile("name")`` returns an object usable as a context manager or as
      a decorator; the scope is called ``name``.
    * ``profile(fn)`` returns a wrapper around ``fn``; the scope is called
      ``fn.__qualname__``.
    """
    if isinstance(arg, str):
        return _ProfileWrap(name=arg)

    @functools.wraps(arg)
    def wrapper(*args, **kwargs):
        return _ProfileWrap(fn=arg)(*args, **kwargs)
    return wrapper


def debug_profile(arg: str | Callable):
    """Like :func:`profile`, but prints the timing of each scope right away."""
    if isinstance(arg, str):
        return _DebugProfileWrap(name=arg)

    @functools.wraps(arg)
    def wrapper(*args, **kwargs):
        return _DebugProfileWrap(fn=arg)(*args, **kwargs)
    return wrapper


def start_profile_node(name):
    """Enter a scope on the default profiler; no-op when none is installed."""
    if default_profiler is not None:
        default_profiler.enter_node(name)


def end_profile_node():
    """Leave the current scope on the default profiler; no-op when none is installed."""
    if default_profiler is not None:
        default_profiler.leave_node()


def get_profile_result():
    """Return the default profiler's textual report, or ``None`` if no
    profiler has been enabled."""
    if default_profiler is not None:
        return default_profiler.get_result_report()
    return None