#!/usr/bin/python3
from __future__ import annotations
from typing import Dict, List, Tuple
import argparse
import pstats
import marshal
import tempfile
import subprocess
import pprint
import sys
import os
# from pstats import SortKey
import cmd


# Key identifying a profiled function or scope:
# (file name, line number, function or scope name)
Func = Tuple[str, int, str]

# Collected profiler statistics, in the shape this file reads from
# pstats.Stats.stats or from a marshalled profile file
Stats = Dict[
    # Function or scope the statistics refer to (the callee)
    Func,
    Tuple[
        # Aggregated statistics, unpacked by Graph as (cc, nc, tt, ct).
        # NOTE(review): presumably primitive calls, total calls, own time and
        # cumulative time as in pstats — confirm against the pstats docs
        int, int, float, float,
        # Callers
        Dict[
            # Caller function or scope
            Func,
            # Call statistics collected for this caller → callee
            Tuple[int, int, float, float]
        ]
    ]
]


class Node:
    """
    Node in the caller → callee graph.

    A node represents one profiled function or scope, identified by its file
    name, line number and scope name. It holds the aggregated statistics for
    that function and the arcs linking it to its callers and callees.
    """
    def __init__(self, func: Func):
        """
        Create a node with empty statistics and no arcs.

        :param func: (file name, line number, scope name) of the function
        """
        self.fname, self.lineno, self.scope = func
        # Arcs in which this node is the caller
        self.callees: List["Call"] = []
        # Arcs in which this node is the callee
        self.callers: List["Call"] = []
        # Aggregated statistics, zero until whoever builds the graph fills
        # them in
        self.cc = 0
        self.nc = 0
        self.tt = 0.0
        self.ct = 0.0

    def __str__(self):
        """
        Return a compact description of the function.

        Three formats are produced:

        * ``[builtin]:scope`` when the file name is ``~`` and the line is 0
        * ``[sys]/relative/path:line|scope`` when the file lives under one of
          the directories in ``sys.path``
        * ``path:line|scope`` otherwise
        """
        # Builtin function
        if self.fname == "~" and self.lineno == 0:
            return f"[builtin]:{self.scope}"

        # Shorten file names from system libraries. Longest entries are tried
        # first so that the most specific directory wins.
        for path in sorted(sys.path, key=len, reverse=True):
            if not path or not self.fname.startswith(path):
                continue
            rest = self.fname[len(path):]
            # Only accept a match that ends on a path component boundary:
            # "/a/proj" must not be taken as the parent of "/a/project2/x.py"
            if path.endswith(os.sep) or rest.startswith(os.sep):
                return f"[sys]{rest}:{self.lineno}|{self.scope}"

        # File in the local project
        return f"{self.fname}:{self.lineno}|{self.scope}"


class Call:
    """
    Arc in the caller → callee graph.

    Stores the two endpoints of the arc together with the four statistics
    (cc, nc, tt, ct) recorded for this specific caller → callee pair.
    """
    def __init__(self, caller: Node, callee: Node, cc: int, nc: int, tt: float, ct: float):
        # Endpoints of the arc
        self.caller, self.callee = caller, callee
        # Call counters and timings, stored exactly as received
        self.cc, self.nc = cc, nc
        self.tt, self.ct = tt, ct


class Graph:
    """
    Graph of callers and callees.

    Each node in the graph represents a function, with its aggregated
    statistics.

    Each arc in the graph represents the statistics collected for one
    caller → callee pair.
    """
    def __init__(self, stats: Stats):
        """
        Build the graph from a raw statistics dictionary.

        :param stats: mapping of function → (cc, nc, tt, ct, callers), in the
                      layout produced by cProfile and exposed as
                      ``pstats.Stats.stats``
        """
        # Index of all nodes in the graph
        self.nodes: Dict[Func, Node] = {}
        # Total execution time: the sum of the own time (tt) of every function
        self.total_time = 0.0

        # Build the graph
        for callee, (cc, nc, tt, ct, callers) in stats.items():
            self.total_time += tt

            # Get the callee and fill its aggregated stats
            ncallee = self.node(callee)
            ncallee.cc = cc
            ncallee.nc = nc
            ncallee.tt = tt
            ncallee.ct = ct

            # Create caller → callee arcs.
            # cProfile stores each per-caller tuple as (nc, cc, tt, ct): the
            # two call counters come in the opposite order compared to the
            # aggregated entry above, so they are swapped back here.
            # NOTE: this expects the 4-tuple arc format written by cProfile.
            for caller, (arc_nc, arc_cc, arc_tt, arc_ct) in callers.items():
                ncaller = self.node(caller)
                call = Call(ncaller, ncallee, arc_cc, arc_nc, arc_tt, arc_ct)
                ncallee.callers.append(call)
                ncaller.callees.append(call)

    def node(self, fun: Func) -> Node:
        """
        Lookup or create a node.

        :param fun: (file name, line number, scope name) key of the function
        :returns: the node registered for ``fun``; a new empty node is created
                  and registered the first time a key is seen
        """
        res = self.nodes.get(fun)
        if res is None:
            res = Node(fun)
            self.nodes[fun] = res
        return res

    @classmethod
    def load(cls, pathname: str) -> "Graph":
        """
        Build a Graph from profile statistics saved on a file.

        :param pathname: file containing a marshalled statistics dictionary
        """
        # NOTE: marshal is not safe against malicious data; only load profile
        # files coming from a trusted source
        with open(pathname, "rb") as fd:
            return cls(marshal.load(fd))

    @classmethod
    def from_pstats(cls, stats: pstats.Stats) -> "Graph":
        """
        Build a Graph from an existing pstats.Stats structure.
        """
        return cls(stats.stats)


class Menu:
    """
    Numbered table of graph nodes with their statistics.

    Entries are appended with add(), shown with print(), and can then be
    looked up again by the number printed in front of them with get_node().
    """
    # Markers
    CALLER = "←"
    CALLEE = "→"
    NODE = "•"

    def __init__(self):
        # Rows, stored as (node, index, description, cc, nc, tt, ct, perc)
        self.entries = []

    def add(self, marker: str, node: Node, cc: int, nc: int, tt: float, ct: float, perc: float, indent=0):
        """
        Append a row to the menu.

        :param marker: symbol printed before the node description; an empty
                       string prints no marker
        :param node: node the row refers to; it is shown using str(node)
        :param cc, nc, tt, ct: statistics to show in the row
        :param perc: percentage to show in the last column
        :param indent: nesting level; levels above 0 get tree-drawing
                       characters in front of the description
        """
        if indent == 0:
            lead = ""
        else:
            lead = "┆ " * (indent - 1) + "├─"
        prefix = marker + " " if marker else ""
        description = lead + prefix + str(node)
        self.entries.append(
            (node, len(self.entries), description, cc, nc, tt, ct, perc))

    def print(self):
        """
        Print the menu to standard output, one numbered row per entry.

        An empty menu prints only the header line.
        """
        # Width of the description column. default=0 keeps max() from raising
        # ValueError when the menu has no entries.
        flen = max((len(entry[2]) for entry in self.entries), default=0)
        print(f"     {'Function / Scope'.ljust(flen)} {'cc'.rjust(6)} {'nc'.rjust(6)}  tt      ct       %")
        for _node, idx, desc, cc, nc, tt, ct, perc in self.entries:
            print(f"{idx:3d}: {desc.ljust(flen)} {cc:6d} {nc:6d} {tt:7.3f} {ct:7.3f} {perc:5.1f}")

    def get_node(self, pos):
        """
        Return the node shown at the given row number.

        :param pos: row number, as an int or as a string holding an int
        :raises ValueError: if pos cannot be converted to an int
        :raises IndexError: if there is no row with that number
        """
        return self.entries[int(pos)][0]


class Browser(cmd.Cmd):
    """
    Interactive command line browser for profile statistics.

    Every command that lists functions stores the table it printed in
    self.menu; the commands that take an argument (calls, tree) expect the
    number of a row of that last table.
    """
    # Number of functions listed by the index when none of the well-known
    # entry points is found in the profile
    INDEX_FALLBACK_SIZE = 20
    # Maximum nesting level shown by the tree command
    MAX_TREE_DEPTH = 3

    def __init__(self, stats: pstats.Stats):
        """
        :param stats: loaded profile statistics to browse
        """
        super().__init__()
        self.stats = stats
        self.graph = Graph.from_pstats(stats)
        # Last menu shown
        self.menu = None

    @staticmethod
    def _percent(value, total):
        """
        Return value as a percentage of total, or 0.0 when total is zero.
        """
        if not total:
            return 0.0
        return value * 100.0 / total

    def _select_node(self, arg):
        """
        Return the node at row ``arg`` of the last menu, or None (after
        printing an explanation) when the selection is not valid.
        """
        if self.menu is None:
            print("No menu has been shown yet: run 'index' first")
            return None
        try:
            return self.menu.get_node(arg)
        except (ValueError, IndexError):
            # A typo must not take the whole command loop down
            print(f"Invalid entry {arg!r}: use one of the numbers shown in the last menu")
            return None

    def _show_menu(self):
        """
        Print the current menu, or a short notice if it has no rows.
        """
        if not self.menu.entries:
            print("No functions to show")
            return
        self.menu.print()

    def print_caller_table(self, calls):
        """
        Add one row per call to the current menu, showing the caller side.
        """
        for c in calls:
            self.menu.add(Menu.CALLER, c.caller, c.cc, c.nc, c.tt, c.ct,
                          self._percent(c.ct, self.graph.total_time))

    def print_callee_table(self, calls):
        """
        Add one row per call to the current menu, showing the callee side.
        """
        for c in calls:
            self.menu.add(Menu.CALLEE, c.callee, c.cc, c.nc, c.tt, c.ct,
                          self._percent(c.ct, self.graph.total_time))

    def print_node(self, node):
        """
        Add a row with the aggregated statistics of a node to the current menu.
        """
        self.menu.add(Menu.NODE, node, node.cc, node.nc, node.tt, node.ct,
                      self._percent(node.ct, self.graph.total_time))

    def do_calls(self, arg):
        """
        calls N: print the top callers and callees of entry N of the last menu
        """
        node = self._select_node(arg)
        if node is None:
            return
        self.menu = Menu()
        self.print_caller_table(sorted(node.callers, key=lambda x: -x.ct)[:10])
        self.print_node(node)
        self.print_callee_table(sorted(node.callees, key=lambda x: -x.ct)[:10])
        self.menu.print()

    @staticmethod
    def _is_index_entry(node):
        """
        Check if a node is one of the well-known staticsite entry points.
        """
        basename = os.path.basename(node.fname)
        if basename == "ssite" and node.scope == "main":
            return True
        if basename == "site.py" and node.scope in (
                "load", "scan_content", "load_content", "load_theme", "analyze"):
            return True
        if os.path.dirname(node.fname).endswith("staticsite/cmd") and node.scope == "run":
            return True
        return False

    def do_index(self, arg):
        """
        index: show an index of relevant functions
        """
        self.menu = Menu()
        nodes = [n for n in self.graph.nodes.values() if self._is_index_entry(n)]
        if not nodes:
            # Not a staticsite profile: fall back to the functions with the
            # highest cumulative time, so there is something to navigate from
            nodes = sorted(self.graph.nodes.values(), key=lambda n: -n.ct)[:self.INDEX_FALLBACK_SIZE]
        nodes.sort(key=lambda n: -n.ct)
        for node in nodes:
            self.print_node(node)
        self._show_menu()

    def do_tree(self, arg):
        """
        tree N: show a callee tree with the top-called functions of entry N
        """
        node = self._select_node(arg)
        if node is None:
            return
        self.menu = Menu()
        self.print_node(node)

        def add_callees(n, depth=0):
            # Max tree depth; this also stops recursive call cycles
            if depth > self.MAX_TREE_DEPTH:
                return

            # Add callees to the menu by decreasing cumulative time, dropping
            # those that account for less than 1% of the root node
            for c in sorted(n.callees, key=lambda x: -x.ct):
                if c.ct < node.ct * 0.01:
                    break
                # Percentages in the tree are relative to the root node
                self.menu.add("┬", c.callee, c.cc, c.nc, c.tt, c.ct,
                              self._percent(c.ct, node.ct), indent=depth)
                add_callees(c.callee, depth=depth + 1)

        add_callees(node, depth=1)
        self.menu.print()

    def do_pprint(self, arg):
        """
        pprint: pretty print the raw collected statistics
        """
        with tempfile.NamedTemporaryFile("wt") as fd:
            pprint.pprint(self.stats.stats, stream=fd)
            fd.flush()
            try:
                subprocess.run(["less", fd.name])
            except FileNotFoundError:
                # No pager installed: print straight to the terminal instead
                pprint.pprint(self.stats.stats)

    def do_EOF(self, arg):
        """
        Quit
        """
        print()
        return True


def main():
    """
    Command line entry point.

    Loads the profile named on the command line (profile.out by default),
    prints the initial index and starts the interactive browser.
    """
    parser = argparse.ArgumentParser(description="Interactively browse cProfile statistics.")
    parser.add_argument("profile", nargs="?", default="profile.out",
                        help="cProfile output to read")
    args = parser.parse_args()

    try:
        stats = pstats.Stats(args.profile)
    except OSError as e:
        # Missing or unreadable file: report it as a usage error instead of
        # showing a traceback
        parser.error(f"cannot read {args.profile}: {e}")

    # Not called "cmd", to avoid shadowing the cmd module
    browser = Browser(stats)
    browser.do_index("")
    browser.cmdloop()


if __name__ == "__main__":
    main()
