#!/usr/bin/env python3

# Copyright (c) 2020-2023 by the Zeek Project. See LICENSE for details.

import argparse
import sys

# Main


def error(msg):
    """Report *msg* on stderr with an "error: " prefix and exit with status 1."""
    text = "error: {}".format(msg)
    print(text, file=sys.stderr)
    raise SystemExit(1)


# Command line: any number of nodes.decl inputs, a mandatory output path, and
# zero or more headers (the option may be repeated) to include in the output.
# Keyword arguments that merely restate argparse's defaults are omitted.
parser = argparse.ArgumentParser()
parser.add_argument("nodes", nargs="*", help="nodes.decl file")
parser.add_argument("--output", required=True, help="Output file")
parser.add_argument("--header", dest="headers", action="append", default=[],
                    help="Header to include in generated file")

args = parser.parse_args()
# Read and validate every declaration file before the output is opened, so
# that a bad input does not leave a truncated header behind.
#
# Two kinds of (non-empty, non-"//") lines are recognized:
#   trait <class> <trait-name>   -- registers <class> under <trait-name>
#   <node-class> [: <trait-name>] -- registers a node; trait defaults to isNode
traits = {}  # trait name -> type tested against Erased via std::is_base_of
nodes = []   # (node class, trait name) pairs in declaration order

for f in args.nodes:
    # The context manager closes each input as soon as it has been read.
    with open(f) as decl:
        for line in decl:
            line = line.strip()
            if not line or line.startswith("//"):
                continue

            fields = line.split()
            # Compare the whole first token: a node class whose name merely
            # begins with "trait" must not be taken for a trait definition.
            if fields[0] == "trait":
                if len(fields) < 3:
                    print("Malformed 'trait' line in {}: {}".format(f, line),
                          file=sys.stderr)
                    sys.exit(1)

                traits[fields[2]] = fields[1]
                continue

            m = line.split(" : ")
            cls = m[0]
            trait = m[1] if len(m) > 1 else "isNode"

            nodes.append((cls, trait))

# Every trait referenced by a node (including the implicit "isNode") needs a
# matching 'trait' line; otherwise the node would silently be left out.
for (cls, trait) in nodes:
    if trait not in traits:
        print("No 'trait' definition for {}".format(trait), file=sys.stderr)
        sys.exit(1)

# Emit the header. Every line of the macro body ends in " \" so that the
# preprocessor joins them into the single VISITOR_DISPATCHERS definition.
with open(args.output, "w") as out:
    print('#pragma once\n', file=out)

    for header in args.headers:
        print("#include <{}>\n".format(header), file=out)

    print("#define VISITOR_DISPATCHERS \\", file=out)

    for (btrait, bcls) in traits.items():
        print("  if constexpr ( std::is_base_of<{}, Erased>::value ) {{ \\".format(
            bcls), file=out)

        # One dispatch attempt per node declared with this trait, in
        # declaration order.
        for (cls, trait) in nodes:
            if trait != btrait:
                continue

            print(
                "    if ( auto r = do_dispatch_one<Result, {cls}, Erased, Dispatcher, Iterator>(n, tn, d, i, no_match_so_far) ) return r;".format(cls=cls) + " \\", file=out)

        print("  } \\", file=out)
        # Continuation-only separator line. It is written directly rather than
        # through a helper defined inside the node loop, which did not exist
        # when a trait had no nodes.
        print(" \\", file=out)

    # The blank line (no trailing backslash) terminates the macro.
    print("", file=out)
