#!/usr/bin/env python2.7

"""
KZorp Daemon

This is a stand alone daemon with the following responsibilities:

 - initial zone download to KZorp
 - continuous update of hostname part of zones in KZorp

"""

from __future__ import absolute_import
from __future__ import division
from __future__ import generators
from __future__ import nested_scopes
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import with_statement

import argparse
import enum
import errno
import imp
import itertools
import os
import radix
import select
import signal
import socket
import sys
import systemd.daemon
import threading
import time
import traceback

from Zorp.InstancesConf import InstancesConf
from Zorp.Zone import Zone

from kzorp.netlink import NetlinkException

from zorpctl.ZorpctlConf import ZorpctlConfig

import Zorp.Common as Common
import Zorp.ResolverCache as ResolverCache

import kzorp.communication
import kzorp.messages
import kzorp.zoneupdate

# Tell the interpreter not to write .pyc/.pyo files for source modules
# imported after this point.
# NOTE(review): presumably aimed at the 'zones' policy module that is loaded
# dynamically from the policy directory — confirm.
sys.dont_write_bytecode = True


class ZoneDownloadFactory(object):
    """Singleton factory producing ZoneDownload adapters.

    The class must be configured once through init() before it can be
    instantiated; every instantiation afterwards yields the same object.
    """

    # The one shared instance, created lazily by __new__.
    _instance = None
    # Value handed over to every ZoneDownload created by the factory.
    _manage_caps = None
    # Flipped to True by init(); guards instantiation.
    _initialized = False

    @classmethod
    def init(cls, manage_caps):
        """Store the manage_caps setting and mark the factory as usable."""
        cls._manage_caps = manage_caps
        cls._initialized = True

    def __new__(cls, *args, **kwargs):
        """Return the shared instance, creating it on first use.

        Raises NotImplementedError when init() has not been called yet.
        """
        if not cls._initialized:
            raise NotImplementedError

        existing = cls._instance
        if existing:
            return existing

        cls._instance = super(ZoneDownloadFactory, cls).__new__(cls, *args, **kwargs)
        return cls._instance

    def makeZoneDownload(self):
        """Build a new ZoneDownload using the configured manage_caps value."""
        manage_caps = ZoneDownloadFactory._manage_caps
        return ZoneDownload(manage_caps)

class ZoneDownload(kzorp.communication.Adapter):
    """Communication adapter specialised for sending zone messages."""

    def __init__(self, manage_caps):
        """Forward manage_caps to the underlying adapter."""
        super(ZoneDownload, self).__init__(manage_caps=manage_caps)

    def initial(self, messages):
        """Send a flush-zones message followed by *messages* in one transaction."""
        flush_then_load = [kzorp.messages.KZorpFlushZonesMessage()] + messages
        self.send_messages_in_transaction(flush_then_load)

    def update(self, messages):
        """Send *messages* in one transaction, without flushing first."""
        self.send_messages_in_transaction(messages)


class DynamicZoneHandler(kzorp.zoneupdate.ZoneUpdateMessageCreator):
    """Zone update message creator that can push its initial state itself."""

    def __init__(self, zones, dnscache):
        """Hand the zones and the DNS cache over to the message creator base."""
        super(DynamicZoneHandler, self).__init__(zones, dnscache)

    def setup(self):
        """Download the initial zone set in two separate transactions.

        Static address messages go first through ZoneDownload.initial()
        (which prepends a flush), then the DNS cache is set up, and finally
        the dynamic address messages are sent through ZoneDownload.update().
        """
        # Step 1: flush and load the static address initialization messages.
        with ZoneDownloadFactory().makeZoneDownload() as downloader:
            downloader.initial(self.create_zone_static_address_initialization_messages())

        # Step 2: the DNS cache is prepared before dynamic messages are built.
        self.setup_dns_cache()

        # Step 3: add the dynamic address initialization messages.
        with ZoneDownloadFactory().makeZoneDownload() as downloader:
            downloader.update(self.create_zone_dynamic_address_initialization_messages())


class ConfigurationHandler():
    """Loads the 'zones' policy module and snapshots the global Zone state.

    Zone definitions live in class-level attributes of Zone (Zone.zones and
    Zone.zone_subnet_tree); this class resets them, re-imports the policy
    module and can save/restore the previous values around a reload.
    """

    def __init__(self):
        self.instances_conf = InstancesConf()
        self.init_state()

    def _import_zones(self):
        """Locate and import the 'zones' module from the policy directory.

        The directory is derived from the policy paths of the configured
        instances; when instances.conf cannot be read or defines no
        instances, the ZORP_SYSCONFDIR configuration directory is used.

        Raises:
            ImportError: when the instances refer to more than one policy
                directory, or when the module cannot be found or imported.
        """
        policy_dirs = set()
        zorpctlconf = ZorpctlConfig.Instance()

        try:
            for instance in self.instances_conf:
                policy_dirs.add(os.path.dirname(instance.zorp_process.args.policy))
        except IOError as e:
            configdir = zorpctlconf['ZORP_SYSCONFDIR']
            Common.log(None, Common.CORE_INFO, 1, "Unable to open instances.conf, falling back to configuration dir; error='%s', fallback='%s'" % (e, configdir))
            policy_dirs.add(configdir)

        if len(policy_dirs) > 1:
            raise ImportError('Different directories of policy files found in instances.conf; policy_dirs=%s' % policy_dirs)
        if not policy_dirs:
            Common.log(None, Common.CORE_INFO, 1, "No instances defined; instances_conf='%s'" % self.instances_conf.instances_conf_path)
            configdir = zorpctlconf['ZORP_SYSCONFDIR']
            Common.log(None, Common.CORE_INFO, 1, "Falling back to configuration directory; fallback='%s'" % (configdir,))
            policy_dirs.add(configdir)

        policy_dir = policy_dirs.pop()
        policy_module_name = 'zones'
        # Bound before the try so the finally clause can always inspect it,
        # whatever exception find_module() raises.
        fp = None
        try:
            fp, pathname, description = imp.find_module(policy_module_name, [policy_dir, ])
            imp.load_module(policy_module_name, fp, pathname, description)
        except ImportError as e:
            Common.log(None, Common.CORE_INFO, 1, "Unable to import zones.py; error='%s'" % (e))
            # Bare raise keeps the original traceback; fp is deliberately left
            # untouched so the file opened by find_module() is still closed
            # below when load_module() is the call that failed.
            raise
        finally:
            if fp:
                fp.close()

    def init_state(self):
        """Start with an empty snapshot."""
        self.saved_zones = {}
        self.saved_subnet_tree = radix.Radix()

    def save_state(self):
        """Remember the objects currently bound to the global Zone state."""
        self.saved_zones = Zone.zones
        self.saved_subnet_tree = Zone.zone_subnet_tree

    def restore_state(self):
        """Rebind the global Zone state to the last saved snapshot."""
        Zone.zones = self.saved_zones
        Zone.zone_subnet_tree = self.saved_subnet_tree

    def setup(self):
        """Reset the global Zone state and import the zones module."""
        Zone.zones = {}
        Zone.zone_subnet_tree = radix.Radix()
        self._import_zones()

    def reload(self):
        """Reload the configuration; identical to setup()."""
        self.setup()


@enum.unique
class DaemonEvent(enum.Enum):
    """Pending action recorded for the daemon's main loop."""

    # Configuration reload was requested.
    reload = 1
    # Daemon shutdown was requested.
    exit = 2


class DaemonMessage(object):
    """Text messages exchanged over the daemon's control socket."""

    # Request asking the daemon to reload its configuration.
    reload = "RELOAD\n"
    # Response sent after a reload request has been accepted.
    ok = "OK\n"

    @staticmethod
    def get_max_len():
        """Return the length of the longest message defined on this class.

        Only public, string-valued class attributes are considered, so
        methods and interpreter-supplied entries such as __module__ do not
        influence the result.
        """
        # Native and byte string types of the running interpreter.
        text_types = (type(b""), type(""))
        return max(
            len(value)
            for name, value in vars(DaemonMessage).items()
            if not name.startswith('_') and isinstance(value, text_types)
        )


class Daemon():
    """KZorp daemon: zone download, DNS refresh and control socket handling.

    setup() loads the configuration, installs the signal handlers and opens
    the control socket; run() is the main loop that refreshes the DNS cache
    and serves control connections until an exit event is recorded.
    """

    # Lower bound (in seconds) of the epoll timeout between DNS cache updates.
    min_sleep_in_sec = 60

    def __init__(self):
        Common.log(None, Common.CORE_INFO, 1, "KZorpd starting up...")

        self.conf_handler = ConfigurationHandler()
        self.dnscache = ResolverCache.ResolverCache(ResolverCache.DNSResolver())
        self.zone_handler = None
        self.sleep_sec = self.min_sleep_in_sec

        # Pending DaemonEvent set by signal handlers or control requests.
        self.event = None

        # Control socket state; the three dicts are keyed by connection fd.
        self.listen_socket = None
        self.epoll = None
        self.connections = {}
        self.requests = {}
        self.responses = {}

    def __del__(self):
        self._remove_control_socket(get_control_socket_path())

    def _sighup_handler(self, sig_num, frame):
        """Record a reload event; run() acts on it in its loop."""
        self.event = DaemonEvent.reload
        Common.log(None, Common.CORE_INFO, 1, 'Received SIGHUP, reloading configuration')

    def _sigint_handler(self, sig_num, frame):
        """Record an exit event; run() acts on it in its loop."""
        self.event = DaemonEvent.exit
        Common.log(None, Common.CORE_INFO, 1, 'Received SIGINT, loading static zones')

    def _sigterm_handler(self, sig_num, frame):
        """Record an exit event; run() acts on it in its loop."""
        self.event = DaemonEvent.exit
        Common.log(None, Common.CORE_INFO, 1, 'KZorpd shutting down...')

    def _download_static_zones(self):
        """Flush the zones and download only the static address messages."""
        if self.zone_handler:
            with ZoneDownloadFactory().makeZoneDownload() as zone_download:
                messages = self.zone_handler.create_zone_static_address_initialization_messages()
                zone_download.initial(messages)
        else:
            Common.log(None, Common.CORE_INFO, 1, 'Failed to set up zone handler, no static zones loaded')

    def _reinitialize_zone_handler(self):
        """Create a zone handler for the current Zone.zones and set it up."""
        self.zone_handler = DynamicZoneHandler(Zone.zones.values(), self.dnscache)
        self.zone_handler.setup()

    def _reload(self):
        """Reload the configuration, falling back to the saved state on error.

        systemd is notified with RELOADING=1 before and READY=1 after.
        """
        systemd.daemon.notify('RELOADING=1')

        saved_dnscache = self.dnscache
        self.conf_handler.save_state()
        self.dnscache = ResolverCache.ResolverCache(ResolverCache.DNSResolver())
        try:
            self.conf_handler.reload()
            self._reinitialize_zone_handler()
        except (ImportError, NetlinkException) as e:
            Common.log(None, Common.CORE_ERROR, 1, "Unable to load configuration, keep existing one; error='%s'" % (e))
            self.conf_handler.restore_state()
            self.dnscache = saved_dnscache
            # NOTE(review): self.zone_handler is not rolled back here. When
            # DynamicZoneHandler.setup() is the call that failed, it keeps
            # referring to the discarded DNS cache — confirm this is intended.

        systemd.daemon.notify('READY=1')

    def _update_dnscache(self):
        """Refresh the DNS cache, compute the next sleep time, push updates.

        The zone update in the finally clause runs whether or not the cache
        refresh succeeded: an expired hostname gets its addresses updated;
        without one, the hostname based addresses are reinitialized as long
        as the zone handler's cache holds hosts.
        """
        expired_hostname = None
        try:
            self.dnscache.update()
            try:
                expired_hostname, expiration_time = self.dnscache.getNextExpiration()
            except ValueError:
                # if no hosts are in the cache, a ValueError is raised, sleep for the minimum time
                self.sleep_sec = self.min_sleep_in_sec
                Common.log(None, Common.CORE_DEBUG, 6,
                           "No hostnames in cache, sleep minimum expiration; sleep_sec='%d'" %
                           (self.sleep_sec, ))
            else:
                self.sleep_sec = max(expiration_time - time.time(), self.min_sleep_in_sec)
                Common.log(None, Common.CORE_DEBUG, 6,
                           "Sleep until next DNS expiration; sleep_sec='%d', host='%s'" %
                           (self.sleep_sec, expired_hostname))
        except KeyError:
            self.sleep_sec = self.min_sleep_in_sec
            Common.log(None, Common.CORE_DEBUG, 6,
                       "Cache lookup failed, sleep minimum expiration; sleep_sec='%d'" %
                       (self.sleep_sec, ))
        except BaseException:
            self.sleep_sec = self.min_sleep_in_sec
            Common.log(None, Common.CORE_ERROR, 1, "Unexpected error; error='%s'" % (traceback.format_exc()))
        finally:
            if self.zone_handler is not None:
                if expired_hostname is not None:
                    Common.log(None, Common.CORE_INFO, 4,
                               "TTL for host expired, updating host; hostname='%s', ttl='%d'" % (expired_hostname, expiration_time))
                    messages = self.zone_handler.create_zone_update_messages(expired_hostname)
                    try:
                        with ZoneDownloadFactory().makeZoneDownload() as zone_download:
                            zone_download.update(messages)
                    except NetlinkException as e:
                        Common.log(None, Common.CORE_ERROR, 1, "Unable to update addresses in zones, keep existing ones; error='%s'" % (e))
                elif self.zone_handler.dnscache.hosts:
                    Common.log(None, Common.CORE_ERROR, 3,
                               "Name resolution has failed, reinitialize hostname based addresses;")
                    with ZoneDownloadFactory().makeZoneDownload() as zone_download:
                        messages = self.zone_handler.create_zone_dynamic_address_initialization_messages()
                        zone_download.update(messages)
                else:
                    Common.log(None, Common.CORE_INFO, 6, "No hostnames in cache, no update needed;")

    def _remove_control_socket(self, path):
        """Remove the file at *path*; a missing file is not an error."""
        # Attempt the removal directly instead of checking existence first,
        # so there is no window between the check and the removal.
        try:
            os.remove(path)
        except OSError as e:
            if e.errno != errno.ENOENT:
                raise

    def _server_init(self):
        """Create the non-blocking UNIX control socket and its epoll object."""
        ctl_socket_path = get_control_socket_path()
        # A leftover socket file would make bind() fail.
        self._remove_control_socket(ctl_socket_path)

        self.listen_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        self.listen_socket.bind(ctl_socket_path)
        self.listen_socket.listen(1)
        self.listen_socket.setblocking(0)

        self.epoll = select.epoll()
        self.epoll.register(self.listen_socket.fileno(), select.EPOLLIN)

    def _accept_connection(self):
        """Accept a control connection and start waiting for its request."""
        (connection, address) = self.listen_socket.accept()
        connection.setblocking(0)
        fd = connection.fileno()
        self.epoll.register(fd, select.EPOLLIN)
        self.connections[fd] = connection
        self.requests[fd] = b''
        self.responses[fd] = b''

    def _close_connection(self, fd):
        """Unregister and close the connection and drop its buffers."""
        self.epoll.unregister(fd)
        self.connections[fd].shutdown(socket.SHUT_RDWR)
        self.connections[fd].close()
        del self.connections[fd]
        del self.requests[fd]
        del self.responses[fd]

    def _read_request(self, fd):
        """Read request bytes from *fd* and act on a complete reload request.

        The connection is closed on a read error, when the peer has closed
        its end, or when the request grows beyond the longest known message.
        """
        max_request_len = DaemonMessage.get_max_len()
        try:
            chunk = self.connections[fd].recv(max_request_len)
        except socket.error as e:
            Common.log(None, Common.CORE_ERROR, 1, "Could not read request; error='{}'".format(e))
            self._close_connection(fd)
            return

        if not chunk:
            # An empty read means the peer closed the connection. Left
            # registered, the socket would be reported readable on every
            # poll and the main loop would never sleep.
            self._close_connection(fd)
            return

        self.requests[fd] += chunk
        request = self.requests[fd]

        if len(request) > max_request_len:
            self._close_connection(fd)
        # Compare as bytes: decoding arbitrary client input could raise
        # UnicodeDecodeError and take the whole daemon down.
        elif request == DaemonMessage.reload.encode('ascii'):
            self.event = DaemonEvent.reload
            self.responses[fd] = DaemonMessage.ok.encode('ascii')
            self.epoll.modify(fd, select.EPOLLOUT)

    def _write_response(self, fd):
        """Send the pending response; close the connection once it is sent."""
        byteswritten = 0
        try:
            byteswritten = self.connections[fd].send(self.responses[fd])
        except socket.error as e:
            Common.log(None, Common.CORE_ERROR, 1, "Could not write response; error='{}'".format(e))
            self._close_connection(fd)
            return

        # Keep whatever a partial send left over for the next EPOLLOUT event.
        self.responses[fd] = self.responses[fd][byteswritten:]
        if len(self.responses[fd]) == 0:
            self.epoll.modify(fd, 0)
            self._close_connection(fd)

    def setup(self):
        """Load the configuration, install signal handlers, open the socket.

        A configuration that fails to load is logged and the saved state is
        restored; the daemon still starts and reports READY=1 to systemd.
        """
        self.conf_handler.save_state()
        try:
            self.conf_handler.setup()
            self._reinitialize_zone_handler()
        except (ImportError, NetlinkException) as e:
            Common.log(None, Common.CORE_ERROR, 1, "Unable to load configuration, keep existing one; error='%s'" % (e))
            self.conf_handler.restore_state()

        signal.signal(signal.SIGHUP, self._sighup_handler)
        signal.signal(signal.SIGINT, self._sigint_handler)
        signal.signal(signal.SIGTERM, self._sigterm_handler)

        self._server_init()

        systemd.daemon.notify('READY=1')

    def run(self):
        """Main loop: handle events, refresh DNS, serve control connections.

        Returns after an exit event, once the static zones have been
        downloaded. The listening socket and the epoll object are closed on
        the way out, also when an exception ends the loop.
        """
        try:
            while True:
                systemd.daemon.notify('WATCHDOG=1')

                if self.event is DaemonEvent.reload:
                    self._reload()
                elif self.event is DaemonEvent.exit:
                    self._download_static_zones()
                    break

                self.event = None

                self._update_dnscache()

                epoll_events = []
                try:
                    epoll_events = self.epoll.poll(self.sleep_sec)
                except IOError as e:
                    # A signal interrupting the wait is expected: the handler
                    # has recorded the event, the next iteration handles it.
                    if e.errno != errno.EINTR:
                        raise

                for (fd, event_type) in epoll_events:
                    if fd == self.listen_socket.fileno():
                        self._accept_connection()
                    elif event_type & select.EPOLLIN:
                        self._read_request(fd)
                    elif event_type & select.EPOLLOUT:
                        self._write_response(fd)
                    elif event_type & (select.EPOLLHUP | select.EPOLLERR):
                        # Hang-up or error without pending data: drop the
                        # connection so it is not reported again.
                        self._close_connection(fd)
        finally:
            self.epoll.unregister(self.listen_socket.fileno())
            self.epoll.close()
            self.listen_socket.close()


def get_control_socket_path():
    """Return the path of the daemon's control socket.

    The socket is named "kzorpd" and lives in the directory configured as
    ZORP_PIDFILEDIR.

    Raises:
        RuntimeError: when ZORP_PIDFILEDIR is missing from the configuration.
    """
    config = ZorpctlConfig.Instance()
    try:
        pid_dir = config['ZORP_PIDFILEDIR']
    except KeyError:
        raise RuntimeError('Failed to lookup Zorp PID file directory path;')
    else:
        return os.path.join(pid_dir, "kzorpd")


def run(log_verbosity, log_spec, use_syslog):
    """Initialize logging and the download factory, then run the daemon.

    Args:
        log_verbosity: verbosity level passed to the logger.
        log_spec: log specification passed to the logger.
        use_syslog: passed to the logger as its syslog switch.
    """
    logger = Common.LoggerSingleton()
    logger.init("kzorpd", log_verbosity, log_spec, use_syslog)

    # The factory has to be initialized before anything instantiates it.
    ZoneDownloadFactory.init(manage_caps=False)

    daemon = Daemon()
    daemon.setup()
    daemon.run()


def run_reload():
    """Ask a running daemon to reload through its control socket.

    Sends the reload message and waits up to 60 seconds for the response.
    Exits the process with status 1 when the response is not the expected
    acknowledgement; socket errors (including timeouts) propagate.
    """
    max_wait_sec = 60.0

    ctl_socket_path = get_control_socket_path()
    ctl_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    try:
        ctl_socket.settimeout(max_wait_sec)
        ctl_socket.connect(ctl_socket_path)
        # Messages are defined as text; the wire format is ASCII bytes.
        ctl_socket.sendall(DaemonMessage.reload.encode('ascii'))
        response = ctl_socket.recv(1024)
    finally:
        # Release the socket even when connecting or talking to the daemon fails.
        ctl_socket.close()

    # Compare as bytes so an unexpected reply cannot cause a decoding error.
    if response != DaemonMessage.ok.encode('ascii'):
        sys.exit(1)


def process_command_line_arguments():
    """Parse the daemon's command line and return the argparse namespace.

    The namespace carries the attributes verbose, do_not_use_syslog,
    log_spec and reload.
    """
    # (option flags, add_argument keyword arguments), in help output order.
    option_table = (
        (('-v', '--verbose'),
         dict(action='store', dest='verbose', type=int, default=3,
              help='set verbosity level (default: %(default)d)')),
        (('-l', '--no-syslog'),
         dict(action="store_true", dest="do_not_use_syslog", default=False,
              help='do not send messages to syslog (default: %(default)s)')),
        (('-s', '--log-spec'),
         dict(action='store', dest='log_spec', type=str, default="core.accounting:4",
              help='set log specification (default: %(default)s)')),
        (('-r', '--reload'),
         dict(action='store_true', dest='reload', default=False,
              help='reload KZorp daemon')),
    )

    parser = argparse.ArgumentParser(description='KZorp daemon')
    for flags, settings in option_table:
        parser.add_argument(*flags, **settings)
    return parser.parse_args()


if __name__ == "__main__":
    args = process_command_line_arguments()

    if not args.reload:
        # Default mode: become the daemon. The command line switch disables
        # syslog, so its value is inverted for run().
        run(args.verbose, args.log_spec, not args.do_not_use_syslog)
    else:
        # Client mode: ask an already running daemon to reload.
        run_reload()
