#!/usr/bin/env python2.7
#
# Copyright (C) 2006-2015 BalaBit IT Security, 2015-2017 BalaSys IT Security.
# This program/include file is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License as published
# by the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program/include file is distributed in the hope that it will be
# useful, but WITHOUT ANY WARRANTY; without even the implied warranty
# of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation,Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
#

import sys
sys.dont_write_bytecode = True

import errno
import os
import optparse
import types
import socket
import kzorp.messages
import kzorp.communication
import kzorp.netlink

import time

# Positional arguments expected after -e/--evaluate; reused in the option's
# help text and in the error raised when the argument count is wrong.
EvaluateArgumentDescription = "<protocol> <client address> <server address> <interface name>"

def inet_ntoa(a):
    """Format an integer as a dotted-quad IPv4 address string.

    The byte at bits 24-31 of ``a`` becomes the first octet and the lowest
    byte the last one; any bits above the low 32 are ignored.
    """
    octets = [(a >> shift) & 0xff for shift in (24, 16, 8, 0)]
    return ".".join("%s" % octet for octet in octets)

def size_to_mask(family, size):
    """Convert a prefix length into a printable netmask.

    family -- socket.AF_INET or socket.AF_INET6
    size   -- number of leading one bits in the mask (the prefix length)

    Returns the mask formatted by socket.inet_ntop(), e.g. '255.255.255.0'
    for (AF_INET, 24).

    Raises ValueError when the family is not supported or when size is
    negative or larger than the address width of the family.
    """
    if family == socket.AF_INET:
        max_size = 32
    elif family == socket.AF_INET6:
        max_size = 128
    else:
        raise ValueError("address family not supported; family='%d'" % family)

    # A negative prefix length used to fall through every loop and silently
    # produce an all-zero mask; reject it like any other out-of-range value.
    if size < 0:
        raise ValueError("network size must not be negative; size='%d'" % size)

    if size > max_size:
        raise ValueError("network size is greater than the maximal size; size='%d', max_size='%d'" % (size, max_size))

    full_bytes, partial_bits = divmod(size, 8)

    packed_mask = bytearray(b'\xff' * full_bytes)
    if partial_bits:
        # Keep only the top `partial_bits` bits of the boundary byte.
        packed_mask.append((0xff << (8 - partial_bits)) & 0xff)
    # Pad with zero bytes up to the full address length.
    packed_mask += b'\x00' * (max_size // 8 - len(packed_mask))

    return socket.inet_ntop(family, bytes(packed_mask))

class DumpBase:
    """Common driver for the dump commands.

    Holds the request message and a kzorp.communication.Handle, and prints
    every reply obtained for that request unless quiet mode is set.
    """

    def __init__(self, quiet, msg):
        """Remember the quiet flag and request message and open a handle."""
        self.quiet = quiet
        self.msg = msg
        self.handle = kzorp.communication.Handle()

    def _get_replies(self):
        """Return the iterable of replies for self.msg (overridable)."""
        return self.handle.dump(self.msg)

    def dump(self):
        """Iterate over the replies, printing each one unless quiet.

        Returns 0 on success.  When a NetlinkException is raised the error
        is reported on stderr and 1 is returned.
        """
        try:
            # The replies are walked even in quiet mode so that errors
            # raised while fetching them are still detected.
            for reply in self._get_replies():
                if not self.quiet:
                    print(str(reply))
        except kzorp.netlink.NetlinkException as e:
            errcode = int(e.detail)
            # The detail is negated before being handed to os.strerror().
            message = "Dump failed: result='%d' error='%s'\n" % (errcode, os.strerror(-errcode))
            sys.stderr.write(message)
            return 1
        else:
            return 0

class DumpVersion(DumpBase):
    """Dump command built around a KZorpGetVersionMessage request."""

    def __init__(self, quiet):
        version_request = kzorp.messages.KZorpGetVersionMessage()
        DumpBase.__init__(self, quiet, version_request)

    def _get_replies(self):
        # This request goes through exchange() instead of dump(); its result
        # is wrapped in a one-element list so DumpBase.dump() can iterate it.
        reply = self.handle.exchange(self.msg)
        return [reply]

class DumpZones(DumpBase):
    """Dump command built around a KZorpGetZoneMessage request."""

    def __init__(self, quiet):
        zone_request = kzorp.messages.KZorpGetZoneMessage()
        DumpBase.__init__(self, quiet, zone_request)

class DumpServices(DumpBase):
    """Dump command built around a KZorpGetServiceMessage request."""

    def __init__(self, quiet):
        service_request = kzorp.messages.KZorpGetServiceMessage()
        DumpBase.__init__(self, quiet, service_request)

class DumpDispatchers(DumpBase):
    """Dump command for dispatchers, their rules and the rule entries.

    Rule entry replies are not printed one by one: they are aggregated per
    rule and printed as one "name=value" line per dimension.
    """

    def __init__(self, quiet):
        # DumpBase is always put in quiet mode because the printing is done
        # by the overridden dump(); the caller's flag is kept separately.
        DumpBase.__init__(self, True, kzorp.messages.KZorpGetDispatcherMessage())
        self.__real_quiet = quiet

    @staticmethod
    def __rule_entries_to_str(rule_entries):
        """Format aggregated rule entries, one indented line per dimension.

        rule_entries maps dimension attribute ids to their aggregated
        values.  Returns '' when none of the known dimensions is present.
        """
        dimensions = [
            (kzorp.messages.KZNL_ATTR_N_DIMENSION_REQID    ,   'reqid'),
            (kzorp.messages.KZNL_ATTR_N_DIMENSION_IFACE    ,   'src_iface'),
            (kzorp.messages.KZNL_ATTR_N_DIMENSION_IFGROUP  ,   'src_ifgroup'),
            (kzorp.messages.KZNL_ATTR_N_DIMENSION_PROTO    ,   'proto'),
            (kzorp.messages.KZNL_ATTR_N_DIMENSION_PROTO_TYPE,  'proto_type'),
            (kzorp.messages.KZNL_ATTR_N_DIMENSION_PROTO_SUBTYPE, 'proto_subtype'),
            (kzorp.messages.KZNL_ATTR_N_DIMENSION_SRC_PORT ,   'src_port'),
            (kzorp.messages.KZNL_ATTR_N_DIMENSION_DST_PORT ,   'dst_port'),
            (kzorp.messages.KZNL_ATTR_N_DIMENSION_SRC_IP   ,   'src_subnet'),
            (kzorp.messages.KZNL_ATTR_N_DIMENSION_SRC_IP6  ,   'src_subnet'),
            (kzorp.messages.KZNL_ATTR_N_DIMENSION_SRC_ZONE ,   'src_zone'),
            (kzorp.messages.KZNL_ATTR_N_DIMENSION_DST_IP   ,   'dst_subnet'),
            (kzorp.messages.KZNL_ATTR_N_DIMENSION_DST_IP6  ,   'dst_subnet'),
            (kzorp.messages.KZNL_ATTR_N_DIMENSION_DST_IFACE,   'dst_iface'),
            (kzorp.messages.KZNL_ATTR_N_DIMENSION_DST_IFGROUP, 'dst_ifgroup'),
            (kzorp.messages.KZNL_ATTR_N_DIMENSION_DST_ZONE ,   'dst_zone')
        ]

        entries = dict(rule_entries)

        # When both the IPv4 and the IPv6 variant of a subnet dimension are
        # present they are merged so that only one line is printed for them.
        # NOTE: dict() above is a shallow copy, so `+=` may extend the value
        # held by rule_entries in place if it is a mutable sequence; dump()
        # discards rule_entries right after formatting it.
        for ipv4_dim, ipv6_dim in [
            (kzorp.messages.KZNL_ATTR_N_DIMENSION_SRC_IP, kzorp.messages.KZNL_ATTR_N_DIMENSION_SRC_IP6),
            (kzorp.messages.KZNL_ATTR_N_DIMENSION_DST_IP, kzorp.messages.KZNL_ATTR_N_DIMENSION_DST_IP6)
        ]:
            if ipv4_dim in rule_entries and ipv6_dim in rule_entries:
                entries[ipv4_dim] += rule_entries[ipv6_dim]
                del entries[ipv6_dim]

        lines = ["           %s=%s\n" % (dim_name, entries[dim_id])
                 for dim_id, dim_name in dimensions
                 if dim_id in entries]
        return "".join(lines)

    def dump(self):
        """Print the dispatchers, rules and aggregated rule entries.

        Returns 0 on success and 1 when a NetlinkException is raised (the
        error is reported on stderr), like DumpBase.dump().  In quiet mode
        the replies are only iterated, nothing is printed.
        """
        if self.__real_quiet:
            # self.quiet is always True here, so this only walks the replies
            # and reports the status.
            return DumpBase.dump(self)

        pending_entries = {}
        try:
            for reply in self._get_replies():
                if reply.command == kzorp.messages.KZorpAddRuleEntryMessage.command:
                    reply.aggregate_rule_entries(pending_entries)
                    continue

                # Any other reply ends the entry list collected so far:
                # print it before the reply, so the entries are never shown
                # below a later header.
                entries_text = DumpDispatchers.__rule_entries_to_str(pending_entries)
                pending_entries = {}
                if reply.command == kzorp.messages.KZorpAddRuleMessage.command:
                    # Printed even when empty: this yields the blank line
                    # that has always preceded a rule.
                    print(entries_text)
                elif entries_text:
                    print(entries_text)
                print(reply)
        except kzorp.netlink.NetlinkException as e:
            res = int(e.detail)
            sys.stderr.write("Dump failed: result='%d' error='%s'\n" % (res, os.strerror(-res)))
            return 1

        # Entries of the very last rule.
        res = DumpDispatchers.__rule_entries_to_str(pending_entries)
        if res:
            print(res)

        return 0

class DumpBinds(DumpBase):
    """Dump command built around a KZorpGetBindMessage request."""

    def __init__(self, quiet):
        bind_request = kzorp.messages.KZorpGetBindMessage()
        DumpBase.__init__(self, quiet, bind_request)

def upload_zones(fname):
    """Replace the KZorp zones with the ones described in a file.

    Each non-comment line of the file has three ';' separated fields: the
    zone name, the parent zone name (may be empty) and a ',' separated list
    of subnets (may be empty).  Names may be enclosed in double quotes.
    A subnet is an IPv4/IPv6 address with an optional '/<prefix length>'.
    Empty lines and lines starting with '#' are skipped.

    The zones are flushed and re-added inside one transaction, which is
    committed only when every line has been processed.

    Returns 0 on success, 1 when a line cannot be processed (the error is
    reported on stderr and the transaction is not committed).
    """

    def parse_range(r):
        """Return (family, packed address, packed mask) for a subnet string.

        Raises ValueError for an unparsable address or prefix length.
        """
        addr, separator, mask = r.partition('/')
        if not separator:
            # simple IP address
            mask = None

        for family in (socket.AF_INET, socket.AF_INET6):
            try:
                packed_addr = socket.inet_pton(family, addr)
                break
            except socket.error:
                pass
        else:
            # Reported as ValueError so that the per-line error handling
            # below catches it instead of the upload ending in a traceback.
            raise ValueError("invalid IP address: '%s'" % addr)

        if mask is None:
            mask = 32 if family == socket.AF_INET else 128

        mask = size_to_mask(family, int(mask))
        return (family, packed_addr, socket.inet_pton(family, mask))

    def process_line(handle, line):
        """Send the add-zone and add-subnet messages for one file line."""
        # skip empty lines and comments
        if not line or line.startswith("#"):
            return

        zone, parent, s = line.split(";")

        zone = zone.strip('"')
        parent = parent.strip('"')
        if parent == "":
            parent = None

        # an empty third field means the zone has no subnets
        subnets = s.split(",") if s else []

        # we send the "parent" first
        handle.send(kzorp.messages.KZorpAddZoneMessage(name=zone, pname=parent, subnet_num=len(subnets)))
        print(zone)
        for subnet in subnets:
            family, addr, mask = parse_range(subnet)
            handle.send(kzorp.messages.KZorpAddZoneSubnetMessage(zone, family=family, address=addr, mask=mask))

    handle = kzorp.communication.Handle()
    handle.send(kzorp.messages.KZorpStartTransactionMessage(kzorp.messages.KZ_INSTANCE_GLOBAL))
    handle.send(kzorp.messages.KZorpFlushZonesMessage())

    # process each zone; the file is closed on every exit path
    with open(fname) as f:
        for line in f:
            line = line.strip()

            try:
                process_line(handle, line)
            except ValueError as e:
                sys.stderr.write("Error while processing the following line: %s\n%s\n" % (e, line))
                return 1

    handle.send(kzorp.messages.KZorpCommitTransactionMessage())

    return 0

def flush(flushable_option):
    """Flush one kind of KZorp object, or every kind when given 'all'.

    Zones are flushed in a transaction of their own; binds, dispatchers and
    services share a second transaction.  Both transactions are opened for
    the global instance.
    """
    if flushable_option == 'all':
        pending = ['binds', 'dispatchers', 'services', 'zones']
    else:
        pending = [flushable_option]

    handle = kzorp.communication.Handle()

    if 'zones' in pending:
        handle.send(kzorp.messages.KZorpStartTransactionMessage(kzorp.messages.KZ_INSTANCE_GLOBAL))
        handle.send(kzorp.messages.KZorpFlushZonesMessage())
        handle.send(kzorp.messages.KZorpCommitTransactionMessage())
        pending.remove('zones')

    if not pending:
        return

    # Order in which the remaining kinds are flushed inside the transaction.
    flush_message_names = (
        ('binds', 'KZorpFlushBindsMessage'),
        ('dispatchers', 'KZorpFlushDispatchersMessage'),
        ('services', 'KZorpFlushServicesMessage'),
    )

    handle.send(kzorp.messages.KZorpStartTransactionMessage(kzorp.messages.KZ_INSTANCE_GLOBAL))
    for kind, message_name in flush_message_names:
        if kind in pending:
            handle.send(getattr(kzorp.messages, message_name)())
    handle.send(kzorp.messages.KZorpCommitTransactionMessage())

def evaluate_option_parser_cb(option, opt_str, value, parser):
    """optparse callback collecting the four arguments of -e/--evaluate.

    Takes the leading arguments of parser.rargs up to the first one that
    looks like an option (starts with '-' and is longer than one character),
    removes them from parser.rargs and stores them as a list under
    option.dest.  Raises optparse.OptionValueError unless exactly four
    arguments were collected.
    """
    assert value is None

    collected = []
    for candidate in parser.rargs:
        # Stops on both "-a" and "--foo" style options; a lone "-" is
        # treated as an ordinary argument.
        if candidate.startswith("-") and len(candidate) > 1:
            break
        collected.append(candidate)

    if len(collected) != 4:
        raise optparse.OptionValueError("-e option requires 4 arguments: " + EvaluateArgumentDescription)

    del parser.rargs[:len(collected)]
    setattr(parser.values, option.dest, collected)

def parse_ip_option(parser, ip, description):
    """Parse an IPv4 or IPv6 address given on the command line.

    Returns (address family, packed address).  IPv4 is tried first, then
    IPv6; when neither accepts the string, parser.error() is called with a
    message that names the address using ``description``.
    """
    for family in (socket.AF_INET, socket.AF_INET6):
        try:
            packed = socket.inet_pton(family, ip)
        except socket.error:
            continue
        return (family, packed)

    parser.error("invalid %s ip: %s" % (description, ip))


def evaluate(parser, options, quiet):
    """Ask KZorp how it would evaluate a connection and print the answer.

    parser  -- the option parser, used to report invalid arguments
    options -- parsed options; options.evaluate holds the protocol, client
               address, server address and interface name, while sport,
               dport, proto_type, proto_subtype and reqid are optional
    quiet   -- when true, the line describing the query is not printed

    Invalid arguments are reported through parser.error().  Returns 0 when
    the query was answered and 1 when it failed with a NetlinkException (the
    failure is printed).  A NetlinkAttributeException raised while building
    the query is re-raised as socket.error.
    """
    EVAL_ARG_NUM_PROTO    = 0
    EVAL_ARG_NUM_SRC_IP   = 1
    EVAL_ARG_NUM_DST_IP   = 2
    EVAL_ARG_NUM_IFACE    = 3
    EVAL_ARG_NUM          = 4

    def parse_num(parser, num_str, max_num_bits, description):
        """Return num_str as an int in [0, 2**max_num_bits) or report it."""
        try:
            num = int(num_str)
        except (TypeError, ValueError):
            num = None
        if num is not None and 0 <= num < 2 ** max_num_bits:
            return num
        # An out-of-range value used to be dropped silently, which made the
        # query ignore a constraint the user explicitly asked for.
        parser.error("invalid %s: %s" % (description, num_str))
        return None

    def parse_reqid(parser, reqid, description):
        return parse_num(parser, reqid, 32, description)

    def parse_proto_type(parser, proto_type, description):
        return parse_num(parser, proto_type, 8, description)

    def parse_proto_subtype(parser, proto_subtype, description):
        return parse_num(parser, proto_subtype, 8, description)

    def parse_port(parser, port, description):
        port = parse_num(parser, port, 16, description)
        # port 0 is passed on as None
        if port == 0:
            return None
        return port

    def parse_proto(parser, proto):
        """Return the protocol number for a protocol name or number."""
        try:
            return socket.getprotobyname(proto)
        except socket.error:
            try:
                return int(proto)
            except ValueError:
                parser.error('protocol must be protocol name or number')

    if len(options.evaluate) < EVAL_ARG_NUM:
        # parser.error() takes a single, already formatted message
        parser.error('evaluate requires %d parameters' % EVAL_ARG_NUM)

    kws = {}
    kws['proto'] = parse_proto(parser, options.evaluate[EVAL_ARG_NUM_PROTO])

    saddr_str = options.evaluate[EVAL_ARG_NUM_SRC_IP]
    (sfamily, saddr) = parse_ip_option(parser, saddr_str, "client")
    kws['family'] = sfamily
    kws['saddr'] = saddr

    daddr_str = options.evaluate[EVAL_ARG_NUM_DST_IP]
    (dfamily, daddr) = parse_ip_option(parser, daddr_str, "server")
    kws['daddr'] = daddr

    if options.sport is not None:
        kws['sport'] = parse_port(parser, options.sport, "client port")
    if options.dport is not None:
        kws['dport'] = parse_port(parser, options.dport, "server port")

    if options.proto_type is not None:
        kws['proto_type'] = parse_proto_type(parser, options.proto_type, "ICMP type")
    if options.proto_subtype is not None:
        kws['proto_subtype'] = parse_proto_subtype(parser, options.proto_subtype, "ICMP code")

    kws['iface'] = options.evaluate[EVAL_ARG_NUM_IFACE]
    if len(options.evaluate[EVAL_ARG_NUM_IFACE]) > 15:
        parser.error('invalid interface name (>15 characters)')

    if options.reqid is not None:
        kws['reqid'] = parse_reqid(parser, options.reqid, "request id")

    if sfamily != dfamily:
        parser.error("family of source and destination address is not the same")

    try:
        query_msg = kzorp.messages.KZorpQueryMessage(**kws)
    except kzorp.netlink.NetlinkAttributeException as e:
        raise socket.error(e.detail)

    if not quiet:
        # Values used only for the progress line.  The ports default to None
        # so that a query without --src-port/--dst-port does not fail with a
        # KeyError while being described.
        shown = dict(kws, saddr_str=saddr_str, daddr_str=daddr_str)
        shown.setdefault('sport', None)
        shown.setdefault('dport', None)

        if kws['proto'] == socket.IPPROTO_TCP or kws['proto'] == socket.IPPROTO_UDP:
            eval_str = "evaluating %(saddr_str)s:%(sport)s -> %(daddr_str)s:%(dport)s on %(iface)s" % shown
        else:
            eval_str = "evaluating %(saddr_str)s -> %(daddr_str)s on %(iface)s" % shown
        if 'proto_type' in kws:
            eval_str += " and the protocol type is %s" % (kws['proto_type'])
        if 'proto_subtype' in kws:
            eval_str += " and the protocol subtype is %s" % (kws['proto_subtype'])
        if 'reqid' in kws:
            eval_str += " and the request id is %s" % (kws['reqid'])
        print(eval_str)

    try:
        print(str(kzorp.communication.Handle().exchange(query_msg)))
    except kzorp.netlink.NetlinkException as e:
        # The detail is negated before being handed to os.strerror(), as in
        # the dump commands.
        print("Evaluate failed: result='%s', error='%s'" % (e.detail, os.strerror(-int(e.detail))))
        return 1

    return 0

def LookupZone(parser, address):
    """Print the reply of a KZorp zone lookup for the given IP address.

    parser  -- the option parser, used to report an invalid address
    address -- IPv4 or IPv6 address as a string

    A NetlinkException whose detail is -ENOENT is ignored, in which case
    nothing is printed; any other NetlinkException propagates.  Always
    returns 0.
    """
    handle = kzorp.communication.Handle()

    (family, addr) = parse_ip_option(parser, address, 'lookup')
    lookup_msg = kzorp.messages.KZorpLookupZoneMessage(family, addr)
    try:
        # Use the handle opened above instead of opening a second one.
        print(str(handle.exchange(lookup_msg)))
    except kzorp.netlink.NetlinkException as e:
        if e.detail != -errno.ENOENT:
            # bare raise keeps the original traceback
            raise

    return 0

def main(args):
    """Command line entry point.

    args -- the full argument vector, program name first (as sys.argv)

    Runs every requested action in a fixed order and returns the process
    exit status:
      0 -- every action succeeded
      2 -- not running as root, or the connection to KZorp was refused
      3 -- no action was run
      otherwise the status of the first action that failed
    """
    option_list = [
        optparse.make_option("-v", "--version",
                             action="store_true", dest="version",
                             default=False,
                             help="dump KZorp version "
                                  "[default: %default]"),
        optparse.make_option("-z", "--zones",
                             action="store_true", dest="zones",
                             default=False,
                             help="dump KZorp zones "
                                  "[default: %default]"),
        optparse.make_option("-s", "--services",
                             action="store_true", dest="services",
                             default=False,
                             help="dump KZorp services "
                                  "[default: %default]"),
        optparse.make_option("-d", "--dispatchers",
                             action="store_true", dest="dispatchers",
                             default=False,
                             help="dump KZorp dispatchers "
                                  "[default: %default]"),
        optparse.make_option("-b", "--binds",
                             action="store_true", dest="binds",
                             default=False,
                             help="dump KZorp instance bind parameters "
                                  "[default: %default]"),
        optparse.make_option("-e", "--evaluate",
                             dest="evaluate",
                             action="callback",
                             default=None,
                             callback=evaluate_option_parser_cb,
                             help="evaluate arguments: " + EvaluateArgumentDescription + "\n"
                                  "optional parameters of the evaluation might be set by src-port, dst-port, icmp-type, icmp-code, request-id parameters"),
        optparse.make_option(None, "--src-port",
                             action="store", type="int", dest="sport",
                             default=1024,
                             help="source port in case of evaluation "
                                  "[default: %default]"),
        optparse.make_option(None, "--dst-port",
                             action="store", type="int", dest="dport",
                             default=None,
                             help="destination port in case of evaluation "
                                  "[default: %default]"),
        optparse.make_option(None, "--icmp-type",
                             action="store", type="int", dest="proto_type",
                             default=None,
                             help="ICMP type in case of evaluation "
                                  "[default: %default]"),
        optparse.make_option(None, "--icmp-code",
                             action="store", type="int", dest="proto_subtype",
                             default=None,
                             help="ICMP code in case of evaluation "
                                  "[default: %default]"),
        optparse.make_option(None, "--request-id",
                             action="store", type="int", dest="reqid",
                             default=None,
                             help="IPSec request number in case of evaluation "
                                  "[default: %default]"),
        optparse.make_option("-q", "--quiet",
                             action="store_true", dest="quiet",
                             default=False,
                             help="quiet operation "
                                  "[default: %default]"),
        optparse.make_option("-u", "--upload",
                             action="store", type="string", dest="upload",
                             default=None,
                             help="upload KZorp zone structure from file "
                                  "[default: %default]"),
        optparse.make_option("-l", "--lookup",
                             action="store", dest="lookup_ip",
                             default=None,
                             help="dump KZorp zone best matches the given IP "
                                  "[default: %default]"),
        optparse.make_option("-f", "--flush",
                             action="store", type="choice", dest="flush",
                             choices=['binds', 'dispatchers', 'services', 'zones', 'all'],
                             default=None,
                             help="flush objects from KZorp "
                                  "[default: %default]"),
    ]

    parser = optparse.OptionParser(option_list=option_list, prog="kzorp", usage="usage: %prog [options]")
    # Parse the arguments that were passed in (minus the program name)
    # rather than silently falling back to sys.argv.
    (options, args) = parser.parse_args(args[1:])

    action_requested = (options.version or options.zones or options.services
                        or options.dispatchers or options.binds
                        or options.upload is not None
                        or options.evaluate is not None
                        or options.lookup_ip is not None
                        or options.flush is not None)
    if not action_requested:
        parser.error("at least one option must be set")

    if os.getuid() != 0:
        sys.stderr.write("kzorp must be run as root\n")
        return 2

    # Status of every action that was run, in execution order.  An action
    # returning None is treated as successful.
    statuses = []
    try:
        if options.version:
            statuses.append(DumpVersion(options.quiet).dump())
        if options.flush:
            flush(options.flush)
            # flush() reports problems by raising, so getting here is success
            statuses.append(0)
        if options.zones:
            statuses.append(DumpZones(options.quiet).dump())
        if options.services:
            statuses.append(DumpServices(options.quiet).dump())
        if options.dispatchers:
            statuses.append(DumpDispatchers(options.quiet).dump())
        if options.binds:
            statuses.append(DumpBinds(options.quiet).dump())
        if options.upload:
            statuses.append(upload_zones(options.upload))
        if options.evaluate:
            statuses.append(evaluate(parser, options, options.quiet))
        if options.lookup_ip:
            statuses.append(LookupZone(parser, options.lookup_ip))
    except socket.error as e:
        # A refused connection is reported as missing kernel support.
        if e.args and e.args[0] == errno.ECONNREFUSED:
            sys.stderr.write("KZorp support not present in kernel\n")
            return 2
        raise

    if not statuses:
        return 3

    # A later success must not hide an earlier failure.
    for status in statuses:
        if status:
            return status
    return 0

if __name__ == "__main__":
    # Run the tool and hand main()'s return value to the shell as exit status.
    sys.exit(main(sys.argv))
