#!/usr/bin/env python2.7
#
# Copyright (C) 2006-2015 BalaBit IT Security, 2015-2017 BalaSys IT Security.
# This program/include file is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License as published
# by the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program/include file is distributed in the hope that it will be
# useful, but WITHOUT ANY WARRANTY; without even the implied warranty
# of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
#

import sys
sys.dont_write_bytecode = True
import csv
import os

from kzorp.messages import *

class Counter(object):
    """One sampled counter value, identified by ``id``.

    Hashing and comparison look at ``id`` only; ``timestamp`` and ``count``
    are ignored.  A set therefore holds at most one Counter per id, which is
    what the merge logic of the CSV updater relies on.
    """

    def __init__(self, id, timestamp, count):
        """Store the sample.

        id        -- kept unchanged when it is a string, otherwise converted
                     with long()
        timestamp -- stored unchanged
        count     -- converted with long()

        Raises ValueError or TypeError when a value cannot be converted.
        """
        # isinstance() instead of an exact type check, so that str subclasses
        # are kept as names too instead of being fed to long().
        self.id = id if isinstance(id, str) else long(id)
        self.timestamp = timestamp
        self.count = long(count)

    def __repr__(self):
        return "%s(id=%r, timestamp=%r, count=%r)" % (
            type(self).__name__, self.id, self.timestamp, self.count)

    def __hash__(self):
        return hash(self.id)

    def __eq__(self, other):
        # Let Python fall back to its default handling for foreign types
        # instead of failing with AttributeError on ``other.id``.
        if not isinstance(other, Counter):
            return NotImplemented
        return self.id == other.id

    def __ne__(self, other):
        equal = self.__eq__(other)
        if equal is NotImplemented:
            return NotImplemented
        return not equal

    def __cmp__(self, other):
        # Only consulted by Python 2 for the ordering operators.
        if not isinstance(other, Counter):
            return NotImplemented
        return cmp(self.id, other.id)


class DumpBase(object):
    """Send a dump request to kzorp and turn the matching replies into Counters."""

    def __init__(self, dump_message, message_class, id_name_in_reply_message):
        """Remember what to ask for and how to interpret the answers.

        dump_message             -- message passed to the handle's dump()
        message_class            -- only replies that are instances of this
                                    class are turned into counters
        id_name_in_reply_message -- name of the reply attribute that is used
                                    as the counter id
        """
        self.dump_message = dump_message
        self.message_class = message_class
        self.id_name_in_reply_message = id_name_in_reply_message

    def dump(self):
        """Return a set of Counter objects built from the dump replies.

        All counters of one call share the same timestamp, taken after the
        replies have been collected.  When the dump raises NetlinkException
        the error is written to stderr and an empty set is returned.
        """
        import time
        import kzorp.communication
        import kzorp.netlink

        try:
            connection = kzorp.communication.Handle()
            answers = list(connection.dump(self.dump_message))
        except kzorp.netlink.NetlinkException as error:
            result = int(error.detail)
            # The detail is negated before being handed to os.strerror().
            sys.stderr.write("Dump failed: result='%d' error='%s'\n" % (result, os.strerror(-result)))
            return set()

        sampled_at = time.time()
        id_attribute = self.id_name_in_reply_message

        collected = set()
        for answer in answers:
            if not isinstance(answer, self.message_class):
                continue
            collected.add(Counter(id=getattr(answer, id_attribute),
                                  timestamp=sampled_at,
                                  count=answer.count))

        return collected


class DumpRuleCounters(DumpBase):
    """Dumper that sends KZorpGetDispatcherMessage and counts KZorpAddRuleMessage
    replies, identified by their 'rule_id' attribute."""

    def __init__(self):
        super(DumpRuleCounters, self).__init__(
            dump_message=KZorpGetDispatcherMessage(),
            message_class=KZorpAddRuleMessage,
            id_name_in_reply_message='rule_id',
        )

class DumpZoneCounters(DumpBase):
    """Dumper that sends KZorpGetZoneMessage and counts KZorpAddZoneMessage
    replies, identified by their 'name' attribute."""

    def __init__(self):
        super(DumpZoneCounters, self).__init__(
            dump_message=KZorpGetZoneMessage(),
            message_class=KZorpAddZoneMessage,
            id_name_in_reply_message='name',
        )

class CounterCSVUpdater(object):
    """Merge freshly dumped counters into a CSV file of earlier samples.

    The file is tab separated, uses '|' as quote character and holds the
    columns timestamp, count and id.  Strings are quoted and numbers are not
    (csv.QUOTE_NONNUMERIC), so unquoted fields are read back as numbers.
    """

    # Shared by the reader and the writer so that both agree on the layout.
    csv_params = {
                   'fieldnames' : ['timestamp', 'count', 'id'],
                   'quoting'    : csv.QUOTE_NONNUMERIC,
                   'quotechar'  : '|',
                   'delimiter'  : '\t',
                 }

    def __init__(self, filename, counter_dumper):
        """filename is the CSV file to maintain; counter_dumper is an object
        whose dump() returns the current counters."""
        self.filename = filename
        self.counter_dumper = counter_dumper

    def update(self):
        """Read the stored counters, merge in a fresh dump and write the result.

        NOTE(review): when dump() returns an empty set, every stored counter
        is dropped and the file is rewritten with only the header line --
        confirm this is intended when the dumper uses an empty set to signal
        a failure.
        """
        old_counters = self.__read_counters()
        actual_counters = self.counter_dumper.dump()
        updated_counters = self.__update_couters(old_counters, actual_counters)
        self.__write_counters(updated_counters)

    def __update_couters(self, old_counters, actual_counters):
        """Return the counters to persist.

        Only ids present in actual_counters survive.  For each of them the
        fresh sample is used, unless its count is not positive and an older
        sample with the same id exists; then the older sample (with its older
        timestamp) is kept.
        """
        previous_by_id = dict((counter.id, counter) for counter in old_counters)

        updated_counters = set()
        for counter in actual_counters:
            previous = previous_by_id.get(counter.id)
            if previous is None or counter.count > 0:
                updated_counters.add(counter)
            else:
                updated_counters.add(previous)

        return updated_counters

    def __write_counters(self, counters):
        """Write counters to '<filename>.new' and rename it over the file.

        Failures are reported on stderr; the existing file is left untouched
        in that case.
        """
        new_filename = self.filename + '.new'
        try:
            with open(new_filename, 'w') as f:
                csvwriter = csv.DictWriter(f, **CounterCSVUpdater.csv_params)
                csvwriter.writeheader()
                for counter in counters:
                    csvwriter.writerow({'id' : counter.id, 'timestamp' : counter.timestamp, 'count' : counter.count})
            # Replace the old file only after the new one is written completely.
            os.rename(new_filename, self.filename)
        except (IOError, OSError) as e:
            sys.stderr.write("Error writing file; file='%s', error='%s'\n" % (self.filename, e.strerror))

    def __read_counters(self):
        """Return the set of counters stored in the file.

        Lines that cannot be parsed or converted are reported on stderr and
        skipped.  When the file cannot be opened or read, the error is
        reported and the counters read so far (possibly none) are returned.
        """
        counters = set()
        fieldnames = set(self.csv_params['fieldnames'])
        try:
            with open(self.filename, 'r') as f:
                csvreader = csv.DictReader(f, **CounterCSVUpdater.csv_params)
                while True:
                    # The reader is advanced inside the try block: errors such
                    # as a non-numeric unquoted field are raised while the
                    # line is fetched, not while the Counter is built.
                    try:
                        row = next(csvreader)
                        if set(row.values()) == fieldnames:
                            # header line, read as data because the field
                            # names are passed to the reader explicitly
                            continue
                        counters.add(Counter(**row))
                    except StopIteration:
                        break
                    except (csv.Error, TypeError, ValueError) as e:
                        # TypeError covers lines with missing or extra fields.
                        sys.stderr.write("Error reading file; file='%s', line='%d', error='%s'\n" % (self.filename, csvreader.line_num, e))
        except IOError as e:
            sys.stderr.write("Error reading file; file='%s', error='%s'\n" % (self.filename, e.strerror))

        return counters


def main(args):
    """Refresh the zone statistics file, then the rule statistics file.

    args is accepted but not used.
    """
    targets = (
        ("/var/lib/zorp/kzorp/zone_statistics", DumpZoneCounters),
        ("/var/lib/zorp/kzorp/rule_statistics", DumpRuleCounters),
    )

    for statistics_file, dumper_class in targets:
        CounterCSVUpdater(statistics_file, dumper_class()).update()

if __name__ == "__main__":
    # Whatever main() returns becomes the exit status of the script.
    sys.exit(main(sys.argv))
