#!/usr/bin/env python3

# Copyright (c) 2020-2023 by the Zeek Project. See LICENSE for details.
#
# Turn the output of spicy-doc into reST.

import argparse
import copy
import filecmp
import json
import os.path
import os
import re
import sys
import textwrap

from typing import Dict, List, Pattern, Set


def fatalError(message: str):
    """Write *message* to stderr and terminate the script with exit status 1."""
    print(message, file=sys.stderr)
    raise SystemExit(1)


def call(name):
    """Return a formatter rendering an operator as ``name(<operand>)``.

    The returned callable takes an operator object and renders its first
    operand in operator context without type markup.
    """
    def render(op):
        operand = op.operands[0].rst(in_operator=True, markup=False)
        return f"{name}({operand})"

    return render


def keyword(name):
    """Return a formatter rendering an operator as ``name <sp> <operand>``.

    The returned callable takes an operator object and renders its first
    operand in operator context with the default type markup.
    """
    def render(op):
        operand = op.operands[0].rst(in_operator=True)
        return f"{name} <sp> {operand}"

    return render


def unary(prefix, postfix=""):
    """Return a formatter for a unary operator.

    The result has the shape ``op:<prefix> <operand> op:<postfix>``; both
    ``op:`` markers are always emitted, even when the respective token is
    empty.
    """
    def render(op):
        operand = op.operands[0].rst(in_operator=True)
        return f"op:{prefix} {operand} op:{postfix}"

    return render


def binary(token):
    """Return a formatter for the infix operator *token*.

    The result has the shape ``<lhs> <sp> op:<token> <sp> <rhs>``. The
    marker `` $commutative$`` is appended when the operator is flagged
    commutative and its two operands render differently.
    """
    def render(op):
        lhs = op.operands[0].rst(in_operator=True)
        rhs = op.operands[1].rst(in_operator=True)

        # Swapping identical operands yields the same signature, so the
        # marker is only added when the two sides differ.
        suffix = " $commutative$" if op.commutative and lhs != rhs else ""

        return f"{lhs} <sp> op:{token} <sp> {rhs}{suffix}"

    return render


# Maps an operator kind (the "kind" field of an operator entry) to a callable
# that takes the operator object and returns the signature text embedded in
# its reST directive. Kinds missing from this table are reported as
# unsupported when an operator is rendered.
Operators = {
    "Add": lambda op: "add <sp> {}[{}]".format(
        op.operands[0].rst(in_operator=True),
        op.operands[1].rst(in_operator=True, markup=False)
    ),
    "Begin": call("begin"),
    "BitAnd": binary("&"),
    "BitOr": binary("|"),
    "BitXor": binary("^"),
    "End": call("end"),
    # First operand is the callee; all remaining operands are its arguments.
    "Call": lambda op: "{}({})".format(
        op.operands[0].rst(in_operator=True,
                           markup=False),
        ", ".join(arg.rst(in_operator=True, markup=False)
                  for arg in op.operands[1:])),
    # The target type arrives as `type<T>`; only `T` is shown inside `cast<>`.
    "Cast": lambda op: "cast<{}>({})".format(
        TypedType.sub("\\1", op.operands[1].rst(
            in_operator=True, markup=False)),
        op.operands[0].rst(in_operator=True, markup=False)),
    "CustomAssign": lambda op: "{} = {}".format(
        op.operands[0].rst(in_operator=True),
        op.operands[1].rst(in_operator=True)
    ),

    "Delete": lambda op: "delete <sp> {}[{}]".format(
        op.operands[0].rst(in_operator=True),
        op.operands[1].rst(in_operator=True, markup=False)
    ),
    "Deref": unary("*"),
    "DecrPostfix": unary("", "--"),
    # Prefix decrement is `--`; it was previously rendered as `++`, making it
    # indistinguishable from `IncrPrefix` in the generated documentation.
    "DecrPrefix": unary("--"),
    "Difference": binary("-"),
    "DifferenceAssign": binary("-="),
    "Division": binary("/"),
    "DivisionAssign": binary("/="),
    "Equal": binary("=="),
    "Greater": binary(">"),
    "GreaterEqual": binary(">="),
    "In": binary("in"),
    # Operator generated here; named so it is sorted after `In`.
    "InInv": binary("!in"),
    "HasMember": binary("?."),
    "TryMember": binary(".?"),
    "Member": binary("."),
    "Index": lambda op: "{}[{}]".format(
        op.operands[0].rst(in_operator=True),
        op.operands[1].rst(in_operator=True, markup=False)),
    "IndexAssign": lambda op: "{}[{}] = {}".format(
        op.operands[0].rst(in_operator=True),
        op.operands[1].rst(in_operator=True, markup=False),
        op.operands[2].rst(in_operator=True, markup=False)),
    "IncrPostfix": unary("", "++"),
    "IncrPrefix": unary("++"),
    "LogicalAnd": binary("&&"),
    "LogicalOr": binary("||"),
    "Lower": binary("<"),
    "LowerEqual": binary("<="),
    "Modulo": binary("%"),
    "Multiple": binary("*"),
    "MultipleAssign": binary("*="),
    "Negate": unary("~"),
    "New": keyword("new"),
    "Pack": keyword("pack"),
    "Power": binary("**"),
    "Unpack": keyword("unpack"),
    "Unset": lambda op: "unset <sp> {}.{}".format(
        op.operands[0].rst(in_operator=True),
        op.operands[1].rst(in_operator=True, markup=False)
    ),
    "SignNeg": unary("-"),
    "Size": unary("|", "|"),
    "ShiftLeft": binary("<<"),
    "ShiftRight": binary(">>"),
    "Sum": binary("+"),
    "SumAssign": binary("+="),
    "Unequal": binary("!="),
}

# Namespace names that are replaced by another name before being used for
# grouping entries and naming output files.
NamespaceMappings = dict(
    signed_integer="integer",
    unsigned_integer="integer",
    struct_="struct",
)

# Type names that are replaced wholesale by the name on the right when a type
# is formatted.
TypeMappings = {
    "::hilti::rt::bytes::Side": "spicy::Side",
    "::hilti::rt::regexp::MatchState": "spicy::MatchState",
}

# Matches `__library_type("NAME")`, capturing NAME.
LibraryType: Pattern = re.compile(r'__library_type\("(.*)"\)')

# Matches `type<T>`, capturing T.
TypedType: Pattern = re.compile(r"type<(.*)>")


def namespace(ns):
    """Return the replacement for *ns* from NamespaceMappings, or *ns* itself."""
    try:
        return NamespaceMappings[ns]
    except KeyError:
        return ns


def rstHeading(title, level):
    """Return *title* underlined as a reST heading, with a trailing newline.

    *level* indexes into the string ``"==-~"`` to pick the underline
    character, so levels 0 and 1 both use ``=``.
    """
    underline_char = "==-~"[level]
    underline = underline_char * len(title)
    return f"{title}\n{underline}\n"


def fmtDoc(doc):
    """Format a documentation string as the indented body of a directive.

    *doc* is split into paragraphs at blank lines. Each paragraph is
    dedented, stripped, re-wrapped with ``textwrap.fill`` (default width) and
    indented by four spaces. Paragraphs that end up empty are dropped; the
    rest are joined by a blank line.

    Returns an empty string when *doc* is ``None`` or empty, so entries
    without documentation render with an empty body instead of raising.
    """
    if not doc:
        return ""

    paragraphs = []
    for chunk in doc.split("\n\n"):
        text = textwrap.dedent(chunk).strip()
        if text:
            wrapped = textwrap.fill(text)
            paragraphs.append(textwrap.indent(wrapped, prefix="    "))

    return "\n\n".join(paragraphs)


def fmtType(ty):
    """Normalize a type string for use inside a reST directive.

    Steps, in order:

    * ``__library_type("NAME")`` is reduced to ``NAME``;
    * a name listed in TypeMappings is replaced by its mapped name;
    * a missing or empty type becomes ``<no-type>``;
    * the type ``any`` is returned as ``<any>``;
    * ``<*>`` is removed, an empty ``{ }`` body is dropped, and
      ``type<X>`` (with no digits in ``X``) becomes ``X-type``;
    * remaining spaces are replaced by ``~``.

    *ty* may be ``None``; it is then treated like an empty string so the
    ``<no-type>`` fallback applies instead of the regex call raising.
    """
    ty = LibraryType.sub("\\1", ty or "")
    ty = TypeMappings.get(ty, ty)

    if not ty:
        ty = "<no-type>"

    if ty == "any":
        return "<any>"

    ty = ty.replace("<*>", "")
    ty = re.sub(r"\s+{\s+}", "", ty)  # e.g., "enum { }" -> "enum"
    # e.g., type<enum> -> "enum-type"
    ty = re.sub(r"type<([^>0-9]+)>", r"\1-type", ty)
    return ty.replace(" ", "~")


class Operand:
    """A single operand or argument, built from one decoded JSON object."""

    # Keys copied from the input mapping onto same-named attributes.
    _FIELDS = ("const", "mutable", "default", "id", "optional", "doc", "type")

    def __init__(self, m):
        # Keys that are absent from `m` end up as None.
        for field in self._FIELDS:
            setattr(self, field, m.get(field))

    def rst(self, in_operator=False, prefix="", markup=True):
        """Render this operand as text.

        With ``in_operator`` set, only the formatted type is produced,
        prefixed by ``t:`` when ``markup`` is true. Otherwise the result is
        ``<id>: <type>`` followed by `` = <default>`` if a default is set.
        *prefix* is prepended verbatim, and optional operands are wrapped in
        ``[ ... ]``.
        """
        # A "doc" entry, when present, takes precedence over the raw type.
        rendered_type = fmtType(self.doc if self.doc else self.type)

        if in_operator:
            text = "t:" + rendered_type if markup else rendered_type
        else:
            default = " = {}".format(self.default) if self.default else ""
            text = f"{self.id}: {rendered_type}{default}".strip()

        text = f"{prefix}{text}"
        if self.optional:
            return f"[ {text} ]"
        return text


class Operator:
    """An operator entry, built from one decoded JSON object.

    Renders itself as a ``.. spicy:operator::`` reST directive.
    """

    def __init__(self, m):
        self.doc = m.get("doc")
        self.kind = m.get("kind")
        # Stored after applying the namespace replacements.
        self.namespace = namespace(m.get("namespace"))
        self.operands = [Operand(i) for i in m.get("operands")]
        self.operator = m.get("operator")
        self.rtype = m.get("rtype")
        self.commutative = m.get("commutative")

    def rst(self):
        """Return the reST directive for this operator, including its doc body.

        Terminates the script with an error message if no formatter is
        registered in ``Operators`` for this operator's kind.
        """
        # Look the formatter up separately from calling it, so that only a
        # missing table entry is reported as "not supported"; an error raised
        # while formatting surfaces as itself.
        formatter = Operators.get(self.kind)
        if formatter is None:
            fatalError(
                "error: "
                "operator {} not supported by spicy-doc-to-rst yet".format(
                    self.kind))

        sig = formatter(self)
        result = fmtType(self.rtype)
        return ".. spicy:operator:: "\
            "{ns}::{kind} {result} {sig}\n\n{doc}".format(
                ns=self.namespace,
                kind=self.kind,
                result=result,
                sig=sig,
                doc=fmtDoc(self.doc))

    def __lt__(self, other):
        # Sort by string representation to make sure
        # the rendered output has a fixed order.
        return self.rst() < other.rst()


class Method:
    """A method entry, built from one decoded JSON object.

    Renders itself as a ``.. spicy:method::`` reST directive.
    """

    def __init__(self, m):
        self.args = list(map(Operand, m.get("args")))
        self.doc = m.get("doc")
        self.id = m.get("id")
        self.kind = m.get("kind")
        # Stored after applying the namespace replacements.
        self.namespace = namespace(m.get("namespace"))
        self.rtype = m.get("rtype")
        # The operand describing the object the method is called on.
        self.self = Operand(m.get("self"))

    def rst(self):
        """Return the reST directive for this method, including its doc body."""
        rendered_args = []
        for a in self.args:
            # Only arguments that are mutable and not const are tagged `inout`.
            is_inout = a.mutable and not a.const
            rendered_args.append(a.rst(prefix="inout " if is_inout else ""))
        arg_list = ", ".join(rendered_args)

        # A bool; it is formatted into the directive as `True` or `False`.
        is_const = self.self.const == "const"
        receiver = fmtType(self.self.type)
        returns = fmtType(self.rtype)

        return ".. spicy:method:: "\
            "{ns}::{id} {self} {id} {const} {result} ({args})\n\n{doc}".format(
                ns=self.namespace,
                id=self.id,
                self=receiver,
                const=is_const,
                result=returns,
                args=arg_list,
                doc=fmtDoc(self.doc))

    def __lt__(self, other):
        # Compare the rendered text so that sorting yields a fixed,
        # reproducible order in the output.
        return self.rst() < other.rst()

# Main


def _write_section(out, title, items, seen):
    """Write a rubric named *title* and the unique renderings of *items*.

    Each item is rendered once via its ``rst()`` method and the results are
    written to *out* in sorted order. Renderings already contained in *seen*
    are skipped; newly written ones are added to it. Nothing is written when
    *items* is empty.
    """
    # Render once up front: sorting the objects themselves would re-render
    # both sides on every comparison.
    rendered = sorted(item.rst() for item in items)
    if not rendered:
        return

    print(".. rubric:: {}\n".format(title), file=out)

    for text in rendered:
        if text not in seen:
            print(text + "\n", file=out)
            seen.add(text)


parser = argparse.ArgumentParser(
    description="Converts the output of spicy-doc on stdin into reST")
parser.add_argument("-d", action="store", dest="dir", metavar="DIR",
                    help="create output for all types in given directory")
parser.add_argument("-t", action="store", dest="types", metavar="TYPES",
                    help="create output for specified, comma-separated types; "
                    "without -d, output goes to stdout")
args = parser.parse_args()

if not args.dir and not args.types:
    fatalError("need -t <type> or -d <dir>")

try:
    meta = json.load(sys.stdin)
except ValueError as e:
    fatalError("cannot parse input: {}".format(e))

# Entries grouped by their (already mapped) namespace.
operators: Dict[str, List[Operator]] = {}
methods: Dict[str, List[Method]] = {}

for op in meta:
    if op["kind"] == "MemberCall":
        method = Method(op)
        methods.setdefault(method.namespace, []).append(method)
        continue

    operator = Operator(op)
    operators.setdefault(operator.namespace, []).append(operator)

    # If we have a `in` operator automatically generate docs for `!in`.
    if operator.kind == 'In':
        inverse = copy.copy(operator)
        inverse.kind = 'InInv'
        inverse.doc = "Performs the inverse of the corresponding ``in`` operation."
        operators.setdefault(inverse.namespace, []).append(inverse)

# Namespaces to generate output for.
keys: Set[str] = set()

if args.dir:
    keys = set(operators) | set(methods)
    try:
        # An already existing directory is fine; anything else is fatal
        # because no output file could be written afterwards.
        os.makedirs(args.dir, exist_ok=True)
    except OSError as e:
        fatalError(
            "cannot create output directory {}: {}".format(args.dir, e))

if args.types:
    known_namespaces = operators.keys() | methods.keys()
    for wanted in args.types.split(","):
        # Each requested type selects every namespace it is a prefix of.
        keys.update(n for n in known_namespaces if n.startswith(wanted))

for ns in sorted(keys):
    fname = None
    fname_tmp = None
    if args.dir:
        fname = namespace(ns)

        if fname.endswith("_"):
            fname = fname[:-1]

        fname = fname.lower().replace("::", "-").replace("_", "-") + ".rst"
        fname = os.path.join(args.dir, fname)
        # Write to a temporary file first so that an existing output file is
        # only replaced when its content actually changes.
        fname_tmp = fname + ".tmp"
        out = open(fname_tmp, "w")
    else:
        out = sys.stdout

    try:
        prefix = ""

        if "::view" in ns:
            prefix = "View "

        if "::iterator" in ns:
            prefix = "Iterator "

        # In the following we remove duplicate entries where multiple items
        # end up turning into the same reST rendering. This happens when the
        # difference is only in typing issues that we don't track in the
        # documentation. An example is the vector's index operators for
        # constant and non-constant instances, respectively. Other
        # duplications are coming from joining namespaces for integers.
        already_recorded: Set[str] = set()

        _write_section(
            out, prefix + "Methods", methods.get(ns, []), already_recorded)
        _write_section(
            out, prefix + "Operators", operators.get(ns, []), already_recorded)
    finally:
        # Close the temporary file even if rendering failed; stdout is left
        # open.
        if fname_tmp:
            out.close()

    if fname and fname_tmp:
        # shallow=False forces a content comparison instead of trusting
        # matching stat() data.
        unchanged = os.path.exists(fname) and filecmp.cmp(
            fname, fname_tmp, shallow=False)

        if unchanged:
            os.unlink(fname_tmp)
        else:
            # Unlike os.rename, os.replace also overwrites an existing
            # destination on Windows.
            os.replace(fname_tmp, fname)