#!/usr/bin/env python3

# Copyright (c) 2020-2023 by the Zeek Project. See LICENSE for details.

import argparse
import string
import sys
import re

# C++ skeleton of the generated type-erased class. Placeholders written as
# `$name` / `${name}` are filled in through `string.Template` at the bottom of
# this script; any `$class` still present after that substitution is replaced
# with the class name in a second pass.
Template = """
class ${class};

${class_comment}class Concept : public hilti::util::type_erasure::ConceptBase {
public:
    $concept_ctors
    virtual ~Concept() = default;
    $concept_methods
    $concept_operators
    virtual $class _clone() const = 0;
    virtual hilti::IntrusivePtr<Concept> _clone_ptr() const = 0;
private:
    $concept_private_attrs
};

template<typename T>
using ModelBase = hilti::util::type_erasure::ModelBase<T, Concept${attrs_types}>;

template<class T>
class Model : public ModelBase<T> {
public:
    using ModelBase<T>::ModelBase;
    using ModelBase<T>::data;

    $model_methods
    $class _clone() const final;
    hilti::IntrusivePtr<Concept> _clone_ptr() const final;
};

using ErasedBase = hilti::util::type_erasure::ErasedBase<$class_requires, Concept, Model${attrs_types}>;

class $class : public ErasedBase$class_traits {
public:
    using ErasedBase::ErasedBase;

    $class() noexcept = default;
    $class($class&&) noexcept = default;
    $class& operator=($class&&) = default;
    ~$class() override = default;
    ${template_addl}

    $class_methods
    $class _clone() const;

    hilti::IntrusivePtr<Concept> _clone_ptr() const;

protected:
    $class_members
};

$model_methods_impl
template<typename T>
$class Model<T>::_clone() const { return T(data()); }

template<typename T>
hilti::IntrusivePtr<Concept> Model<T>::_clone_ptr() const {
    if constexpr ( std::is_base_of<hilti::util::type_erasure::trait::Singleton, T>::value )
        return nullptr;
    else
        return hilti::make_intrusive<Model>(T(data()));
}

$class_methods_impl
inline $class $class::_clone() const { return data() ? data()->_clone() : $class(); }

inline hilti::IntrusivePtr<Concept> $class::_clone_ptr() const {
    if ( data() )
        return data()->_clone_ptr();
    else
        return nullptr;
}


"""

# Copy constructor and copy assignment substituted for `${template_addl}` when
# `--constant` is given: both are simply defaulted.
TemplateAddlConst = """
    $class(const $class& other) = default;
    $class& operator=(const $class& other) = default;
"""

# Copy constructor and copy assignment substituted for `${template_addl}` when
# `--constant` is not given: they take the result of `other._clone_ptr()` if it
# is non-null and otherwise fall back to `other.data()`.
TemplateAddlNonConst = """
    $class(const $class& other) {
        if ( auto x = other._clone_ptr() )
            ErasedBase::operator=(x);
        else
            ErasedBase::operator=(other.data());
    }

    $class& operator=(const $class& other) {
        if ( this != &other ) {
            if ( auto x = other._clone_ptr() )
                ErasedBase::operator=(x);
            else
                ErasedBase::operator=(other.data());
        }

        return *this;
    }
"""


def parseVariable(s):
    """Split a ``<type> <name>`` declaration into its two parts.

    The last whitespace-separated token of *s* becomes ``name``; everything
    before it (with surrounding whitespace removed) becomes ``type``.

    Returns an object carrying the attributes ``type`` and ``name``.
    Raises ``ValueError`` if *s* contains fewer than two tokens.
    """
    class Variable:
        pass

    # Splitting from the right keeps multi-word types ("const Foo&") intact.
    pieces = s.rsplit(maxsplit=1)

    var = Variable()
    var.type, var.name = (piece.strip() for piece in pieces)
    return var


def varsToString(vars):
    """Render parsed variables as a comma-separated ``<type> <name>`` list.

    Each element of *vars* must provide ``type`` and ``name`` attributes.
    An empty sequence yields an empty string.
    """
    rendered = []
    for var in vars:
        rendered.append("{} {}".format(var.type, var.name))

    return ", ".join(rendered)


def forwardParam(p):
    """Return the C++ expression used to pass parameter *p* on to another call.

    Parameters whose type text ends in ``&`` are passed by name; all others
    are wrapped in ``std::move(...)``.
    """
    passed_by_reference = p.type.endswith("&")
    return p.name if passed_by_reference else "std::move({})".format(p.name)


def parseMethod(s):
    """Parse one method declaration from the API specification.

    The accepted shape is::

        <result> <name>(<params>) [<attrs>] [if <trait> [or <trait>]... [else <default>]] [with default {...}]

    Returns an object with these attributes:

    * ``result``  -- return type text
    * ``name``    -- method name
    * ``params``  -- list of objects produced by ``parseVariable``
    * ``attrs``   -- text following the parameter list (e.g. ``const``)
    * ``traits``  -- list of trait names from the ``if`` clause (may be empty)
    * ``default`` -- default expression; taken from the ``else`` clause if
      present, otherwise from ``with default {...}``, otherwise ``""``

    Raises ``ValueError``/``IndexError`` if the declaration lacks a return
    type or a parameter list.
    """
    class Method:
        pass

    m = Method()

    # Cut off a trailing "with default {...}" clause before splitting on
    # parentheses, so that parentheses inside the default expression do not
    # take part in the split below. Previously the captured value was
    # overwritten unconditionally further down and therefore always lost.
    fallback = ""
    clause = re.search("(.*) *with *default *({.*)$", s)

    if clause:
        s = clause.group(1)
        fallback = clause.group(2).strip()

    # x[0]: "<result> <name>", x[1]: parameter list, x[2]: everything after ")".
    x = re.split("[()]", s)
    (m.result, m.name) = [i.strip() for i in x[0].rsplit(sep=None, maxsplit=1)]
    m.params = [parseVariable(v)
                for v in re.split(" *, *", x[1])] if x[1] else []

    m.attrs = ""
    m.traits = []
    m.default = fallback

    if len(x) >= 3:
        # Everything before the last "if" is attribute text; what follows
        # lists the traits and, optionally, an "else <default>".
        y = x[2].rsplit("if", maxsplit=1)
        m.attrs = y[0].strip()

        if len(y) >= 2:
            z = y[1].strip().split("else", maxsplit=1)
            m.traits = z[0].strip().split(" or ")

            if len(z) >= 2:
                # An explicit "else" keeps precedence over "with default".
                m.default = z[1].strip()

    return m


def parseTrait(s):
    """Parse a ``trait <method>[()] from <trait>`` declaration.

    Returns an object with the attributes ``method`` (method name with a
    trailing ``()`` removed) and ``trait`` (the trait name).

    Raises ``ValueError`` if *s* does not have the expected shape.
    """
    class Trait:
        pass

    # Parse the argument itself rather than relying on a module-level
    # variable that merely happens to hold the same text at the call site.
    x = s.split()

    if len(x) < 4 or x[2] != "from":
        raise ValueError("malformed trait declaration: {!r}".format(s))

    t = Trait()
    t.method = x[1]
    t.trait = x[3]

    if t.method.endswith("()"):
        t.method = t.method[:-2]

    return t


def parseMember(s):
    """Parse a ``member <declaration>`` line.

    The first whitespace-separated token of *s* is dropped; the remaining
    tokens, joined by single spaces, are stored as ``declaration`` on the
    returned object.
    """
    class Member:
        pass

    remaining_tokens = s.split()[1:]

    member = Member()
    member.declaration = " ".join(remaining_tokens)
    return member

# Main


def error(msg):
    """Write ``error: <msg>`` to stderr and terminate with exit status 1."""
    sys.stderr.write("error: {}\n".format(msg))
    sys.exit(1)


# Command line: the API specification to read, where to write the generated
# code, and whether instances of the generated class are immutable.
parser = argparse.ArgumentParser()
parser.add_argument("api", metavar="file", action="store",
                    help="File with API specification")
parser.add_argument("--output", dest="output", action="store",
                    default=None, required=False, help="Output file (default is stdout)")
parser.add_argument("--constant", dest="constant", action="store_true",
                    required=False, help="Assume instances of class are non-mutable")

args = parser.parse_args()

# State collected while reading the API specification.
methods = []         # method declarations (results of parseMethod)
traits = []          # trait named after ":" in the class line, if any
attrs = []           # attribute declarations (results of parseVariable)
class_ = None        # class name; None until the "class" line has been seen
trait_methods = []   # "trait ..." declarations (results of parseTrait)
class_members = []   # "member ..." declarations (results of parseMember)

in_comment = False   # True while inside a /* ... */ block
comment = ""         # latest block comment, attached to the next class/method

# The file is read inside a `with` block so its handle is closed on every
# exit path, including the early `break` and calls to error().
with open(args.api) as api_file:
    for line in api_file:
        line = line.strip()
        if not line or line.startswith("//"):
            continue

        if line.startswith("/*"):
            in_comment = True
            comment = ""

        if in_comment:
            comment += (line + "\n")

            if line.endswith("*/"):
                in_comment = False
                if comment:
                    # Indentation for whatever declaration follows the comment.
                    comment += "    "

            continue

        line = line.rstrip(";")

        if line.startswith("class"):
            # Either "class Name(requires) {" or
            # "class Name(requires) : trait {".
            m = line.split()
            if len(m) == 3:
                class_ = m[1]
            elif len(m) == 5 and m[2] == ":":
                class_ = m[1]
                traits = [m[3]]
            else:
                error("cannot parse class declaration: {}".format(line))

            if "(" not in class_:
                error("class declaration lacks '(...)': {}".format(line))

            m = re.split("[()]", class_)
            class_ = m[0]
            # Drop the four spaces of indentation appended after the comment.
            class_comment = comment[:-4]
            requires = m[1]
            comment = ""
            continue

        if line.startswith("trait"):
            t = parseTrait(line)
            trait_methods.append(t)
            comment = ""
            continue

        if line.startswith("member"):
            m = parseMember(line)
            class_members.append(m)
            comment = ""
            continue

        if line.startswith("}"):
            break

        if class_ and "(" in line:
            m = parseMethod(line)
            methods.append(m)
            m.comment = comment
            comment = ""
            continue

        if class_:
            m = parseVariable(line)
            attrs.append(m)
            comment = ""
            continue

        # Only reached for content that appears before any class declaration.
        error("parse error")

if class_ is None:
    error("no class declaration found in {}".format(args.api))

# Pure virtual declarations for Concept: one per API method (preceded by the
# method's comment) followed by one boolean query per trait method.
concept_methods = []
concept_methods.extend(
    "{}virtual {} {}({}) {} = 0;".format(
        meth.comment, meth.result, meth.name, varsToString(meth.params), meth.attrs)
    for meth in methods)
concept_methods.extend(
    "virtual bool {}() const = 0;".format(tm.method) for tm in trait_methods)

# Declarations for the public, type-erased class.
class_methods = []
class_methods.extend(
    "{} {}({}) {};".format(
        meth.result, meth.name, varsToString(meth.params), meth.attrs)
    for meth in methods)
class_methods.extend(
    "bool {}() const;".format(tm.method) for tm in trait_methods)

# Inline definitions for the public class; each one delegates to data().
class_methods_impl = []
class_methods_impl.extend(
    "inline {} {}::{}({}) {} {{ return data()->{}({}); }}\n".format(
        meth.result, class_, meth.name, varsToString(meth.params), meth.attrs,
        meth.name, ", ".join(forwardParam(p) for p in meth.params))
    for meth in methods)
class_methods_impl.extend(
    "inline bool {}::{}() const {{ return data()->{}(); }}\n".format(
        class_, tm.method, tm.method)
    for tm in trait_methods)

# From here on class_members holds rendered declarations, not parsed objects.
class_members = list(
    "{};\n".format(member.declaration) for member in class_members)

# Overriding declarations for Model<T>.
model_methods = []
model_methods.extend(
    "{} {}({}) {} final;".format(
        meth.result, meth.name, varsToString(meth.params), meth.attrs)
    for meth in methods)
model_methods.extend(
    "bool {}() const final;".format(tm.method) for tm in trait_methods)

def mmi(m):
    """Return the C++ body of ``Model<T>``'s implementation of method *m*.

    Without traits the body just forwards to ``data()``. With traits it is an
    ``if constexpr`` chain that forwards when ``T`` derives from one of the
    traits; the final ``else`` returns ``m.default`` if that is non-empty and
    throws ``std::bad_cast`` otherwise.
    """
    arg_names = [param.name for param in m.params]
    forward = "return data().{}({});".format(m.name, ", ".join(arg_names))

    if m.default:
        fallback = "return {};".format(m.default)
    else:
        fallback = "throw std::bad_cast();"

    if not m.traits:
        return forward

    # One "constexpr ( ... )" branch per trait, all forwarding the same way.
    branches = []
    for trait in m.traits:
        branches.append(
            "constexpr ( std::is_base_of<{}, T>::value )\n        {}\n".format(trait, forward))

    chain = "    else if ".join(branches)
    return "\n    if {}    else\n        {}\n".format(chain, fallback)


def tmi(t):
    """Return the C++ body of ``Model<T>``'s implementation of trait query *t*.

    The generated statement reports whether ``T`` derives from ``t.trait``.
    """
    body = "return std::is_base_of<{}, T>::value;".format(t.trait)
    return body


# Out-of-line definitions of Model<T>'s methods.
model_methods_impl = []
model_methods_impl.extend(
    "template<typename T>\n{} Model<T>::{}({}) {} {{ {} }}\n".format(
        meth.result, meth.name, varsToString(meth.params), meth.attrs, mmi(meth))
    for meth in methods)
model_methods_impl.extend(
    "template<typename T>\nbool Model<T>::{}() const {{ {} }}\n".format(
        tm.method, tmi(tm))
    for tm in trait_methods)

# Each attribute becomes a private "_<name>" member of Concept that the
# constructor initializes by moving from the parameter of the same name.
ctor_attrs = list(
    "_{}(std::move({}))".format(attr.name, attr.name) for attr in attrs)
concept_private_attrs = list(
    "{} _{};".format(attr.type, attr.name) for attr in attrs)

if not attrs:
    concept_ctors = ["Concept() = default;"]
else:
    concept_ctors = ["Concept({}){} {{}}".format(
        varsToString(attrs), " : " + ", ".join(ctor_attrs))]

concept_ctors.append("Concept(const Concept&) = default;")
concept_ctors.append("Concept(Concept&&) = default;")

concept_operators = [
    "Concept& operator=(const Concept&) = default;",
    "Concept& operator=(Concept&&) = default;",
]

# Accessors for the attributes, on Concept and on the public class.
concept_methods.extend(
    "const {}& {}() {{ return _{}; }}".format(attr.type, attr.name, attr.name)
    for attr in attrs)
class_methods.extend(
    "const {}& {}() {{ return data()->{}(); }}".format(attr.type, attr.name, attr.name)
    for attr in attrs)

class_traits = list(", public {}".format(base) for base in traits)
class_requires = requires
attrs_types = "".join(", " + attr.type for attr in attrs)

# Values for the placeholders in Template.
vals = {
    "class": class_,
    "class_comment": class_comment,
    "concept_methods": "\n    ".join(concept_methods),
    "model_methods": "\n    ".join(model_methods),
    "model_methods_impl": "\n".join(model_methods_impl),
    "class_methods": "\n    ".join(class_methods),
    "class_methods_impl": "\n".join(class_methods_impl),
    "class_requires": class_requires,
    "class_traits": "".join(class_traits),
    "class_members": "\n    ".join(class_members),
    "concept_ctors": "\n    ".join(concept_ctors),
    "concept_operators": "\n    ".join(concept_operators),
    "concept_private_attrs": "\n    ".join(concept_private_attrs),
    "attrs_types": attrs_types,
    "template_addl": (TemplateAddlConst.strip() if args.constant else TemplateAddlNonConst.strip())
}

# Render before touching the output file, so a failing substitution cannot
# leave behind a truncated file. Substituted values (such as template_addl)
# contain `$class` themselves, which string.Template does not expand again;
# the trailing replace() takes care of those occurrences.
rendered = string.Template(Template).substitute(
    vals).replace("$class", class_)

if args.output:
    # The `with` block makes sure the file is flushed and closed.
    with open(args.output, "w") as out:
        print(rendered, file=out)
else:
    print(rendered, file=sys.stdout)
