#!/usr/bin/python3
#

# Copyright (C) 2006, 2007, 2008, 2009, 2010, 2011 Google Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
# IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
# TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# pylint: disable=C0103

"""confd client program

This can be used to test and debug confd daemon functionality.

"""

from __future__ import print_function

import sys
import optparse
import time

from ganeti import constants
from ganeti import cli
from ganeti import utils
from ganeti import pathutils

from ganeti.confd import client as confd_client

# One-line synopsis shown by Usage() and used as the optparse usage text.
USAGE = "\tconfd-client [--addr=host] [--hmac=key]"

# Prefix printed in front of a log message, keyed by indent level. Log()
# falls back to two spaces for indent levels that are not listed here.
LOG_HEADERS = {0: "- ", 1: "* ", 2: ""}

# Command line options accepted by the client; handed to optparse by
# TestClient.ParseOptions.
OPTIONS = [
  # HMAC key used for the confd requests; when omitted, ParseOptions reads
  # it from pathutils.CONFD_HMAC_KEY.
  cli.cli_option("--hmac", dest="hmac", default=None,
                 help="Specify HMAC key instead of reading"
                 " it from the filesystem",
                 metavar="<KEY>"),
  # Address of the confd server to query. The help text states the real
  # default ("localhost") rather than a hard-coded IP address.
  cli.cli_option("-a", "--address", dest="mc", default="localhost",
                 help="Server address to query (default: localhost)",
                 metavar="<ADDRESS>"),
  # Number of iterations per timing test; TestTiming skips the timing
  # tests entirely when this is zero or negative.
  cli.cli_option("-r", "--requests", dest="requests", default=100,
                 help="Number of requests for the timing tests",
                 type="int", metavar="<REQUESTS>"),
  ]


def Log(msg, *args, **kwargs):
  """Write one message line to stdout and flush it.

  @param msg: the message; treated as a %-format string when positional
      arguments are given, otherwise printed verbatim
  @param args: optional values interpolated into C{msg}
  @keyword indent: indent level (default 0); the line is padded by two
      spaces per level and prefixed with the matching C{LOG_HEADERS}
      entry, or two spaces for levels without an entry

  """
  text = msg % args if args else msg
  level = kwargs.get("indent", 0)
  padding = "%*s" % (2 * level, "")
  header = LOG_HEADERS.get(level, "  ")
  sys.stdout.write("%s%s%s\n" % (padding, header, text))
  sys.stdout.flush()


def LogAtMost(msgs, count, **kwargs):
  """Log the first C{count} messages, plus an ellipsis if any were cut.

  @param msgs: sliceable sequence of messages
  @param count: maximum number of messages to log
  @param kwargs: passed through unchanged to L{Log} (e.g. C{indent})

  """
  shown = msgs[:count]
  for entry in shown:
    Log(entry, **kwargs)
  truncated = count < len(msgs)
  if truncated:
    Log("...", **kwargs)


def Err(msg, exit_code=1):
  """Report a fatal error on stderr and terminate the program.

  @param msg: error text; a newline is appended before writing
  @param exit_code: process exit status passed to C{sys.exit} (default 1)

  """
  line = msg + "\n"
  err_stream = sys.stderr
  err_stream.write(line)
  err_stream.flush()
  sys.exit(exit_code)


def Usage():
  """Print the usage summary to stderr and exit with status 2."""
  for line in ("Usage:", USAGE):
    print(line, file=sys.stderr)
  sys.exit(2)


class TestClient(object):
  """Confd test client.

  Sends one request of every query type listed in L{TestConfd} to the
  configured confd server, logs the answers, and afterwards runs the
  timing loops in L{TestTiming}.

  """

  def __init__(self):
    """Constructor.

    Initializes all instance state and parses the command line. Note that
    L{ParseOptions} may terminate the program on invalid arguments.

    """
    self.opts = None
    # Filled in by ParseOptions
    self.hmac_key = None
    self.mc_list = None
    # Filled in by TestConfd
    self.confd_client = None
    self.confd_counting_callback = None
    # Filled in by ConfdCallback from the replies of earlier queries; later
    # queries in TestConfd use them as their arguments
    self.cluster_master = None
    self.instance_ips = None
    # While true, ConfdCallback stays silent on successful replies
    self.is_timing = False
    self.ParseOptions()

  def ParseOptions(self):
    """Parses the command line options.

    In case of command line errors, it will show the usage and exit the
    program.

    Sets C{self.opts}, C{self.hmac_key} and C{self.mc_list}. If no HMAC key
    was given on the command line, it is read from
    C{pathutils.CONFD_HMAC_KEY}.

    """
    parser = optparse.OptionParser(usage="\n%s" % USAGE,
                                   version=("%%prog (ganeti) %s" %
                                            constants.RELEASE_VERSION),
                                   option_list=OPTIONS)

    options, args = parser.parse_args()
    # The client takes options only; any positional argument is an error
    if args:
      Usage()

    if options.hmac is None:
      options.hmac = utils.ReadFile(pathutils.CONFD_HMAC_KEY)
    self.hmac_key = options.hmac

    self.mc_list = [options.mc]

    self.opts = options

  @staticmethod
  def _ProcessNotOk(reply, reqtype, is_timing):
    """Handle a reply whose status is not OK.

    The failure is always logged. The program is then terminated (via
    L{Err}) if the timing tests are running or if the failed query is one
    that later queries depend on; for any other query type the function
    simply returns.

    @param reply: the upcall reply object
    @param reqtype: type of the original request
    @param is_timing: whether the timing tests are currently running

    """
    Log("Query %s gave non-ok status %s: %s" % (reply.orig_request,
                                                reply.server_reply.status,
                                                reply.server_reply))
    if is_timing:
      Err("Aborting timing tests")
    if reqtype == constants.CONFD_REQ_CLUSTER_MASTER:
      Err("Cannot continue after master query failure")
    if reqtype == constants.CONFD_REQ_INSTANCES_IPS_LIST:
      Err("Cannot continue after instance IP list query failure")

  @staticmethod
  def _ProcessIpList(answer):
    """Log the answer of an instance IP list query.

    @param answer: the answer of the server reply; at most five entries
        are logged

    """
    Log("Instance primary IP query: OK")
    if not answer:
      Log("no IPs received", indent=1)
    else:
      LogAtMost(answer, 5, indent=1)

  @staticmethod
  def _ProcessMapping(answer):
    """Log the answer of an instance IP to node IP mapping query.

    @param answer: the answer of the server reply; at most five entries
        are logged

    """
    Log("Instance IP to node IP query: OK")
    if not answer:
      Log("no mapping received", indent=1)
    else:
      LogAtMost(answer, 5, indent=1)

  def ConfdCallback(self, reply):
    """Callback for confd queries.

    Only upcalls of type C{UPCALL_REPLY} are processed; anything else is
    ignored. Non-OK replies are passed to L{_ProcessNotOk}. Successful
    replies are logged according to their request type, except during the
    timing tests, where they are dropped silently.

    @param reply: the upcall object passed in by the confd client library

    """
    if reply.type != confd_client.UPCALL_REPLY:
      return

    answer = reply.server_reply.answer
    reqtype = reply.orig_request.type
    if reply.server_reply.status != constants.CONFD_REPL_STATUS_OK:
      self._ProcessNotOk(reply, reqtype, self.is_timing)
      return
    if self.is_timing:
      return

    if reqtype == constants.CONFD_REQ_PING:
      Log("Ping: OK")
    elif reqtype == constants.CONFD_REQ_CLUSTER_MASTER:
      Log("Master: OK (%s)", answer)
      if self.cluster_master is None:
        # only assign the first time, in the plain query
        self.cluster_master = answer
    elif reqtype == constants.CONFD_REQ_NODE_ROLE_BYNAME:
      if answer == constants.CONFD_NODE_ROLE_MASTER:
        Log("Node role for master: OK")
      else:
        Err("Node role for master: wrong: %s" % answer)
    elif reqtype == constants.CONFD_REQ_NODE_PIP_LIST:
      Log("Node primary ip query: OK")
      LogAtMost(answer, 5, indent=1)
    elif reqtype == constants.CONFD_REQ_MC_PIP_LIST:
      Log("Master candidates primary IP query: OK")
      LogAtMost(answer, 5, indent=1)
    elif reqtype == constants.CONFD_REQ_INSTANCES_IPS_LIST:
      self._ProcessIpList(answer)
      # Remember the list itself (the logging helper returns nothing); it
      # is the argument of the later instance-IP-to-node-IP query
      self.instance_ips = answer
    elif reqtype == constants.CONFD_REQ_NODE_PIP_BY_INSTANCE_IP:
      self._ProcessMapping(answer)
    else:
      Log("Unhandled reply %s, please fix the client", reqtype)
      print(answer)

  def DoConfdRequestReply(self, req):
    """Send one request and wait until it has been answered.

    The request's salt is registered with the counting callback, the
    request is sent, and replies are received until the counting callback
    reports that everything was answered. If C{ReceiveReply} returns a
    false value before that, the program is terminated via L{Err}.

    @param req: the confd client request to send

    """
    self.confd_counting_callback.RegisterQuery(req.rsalt)
    self.confd_client.SendRequest(req, async_=False)
    while not self.confd_counting_callback.AllAnswered():
      if not self.confd_client.ReceiveReply():
        Err("Did not receive all expected confd replies")
        # Err() exits; the break only guards against a non-exiting Err
        break

  def TestConfd(self):
    """Run confd queries for the cluster.

    Creates the confd client and its callback chain, then sends the test
    requests one at a time, in order. The order matters: the node role
    query needs the master name learned from the first cluster master
    query, and the IP mapping query needs the list learned from the
    instance IP list query.

    """
    Log("Checking confd results")

    filter_callback = confd_client.ConfdFilterCallback(self.ConfdCallback)
    counting_callback = confd_client.ConfdCountingCallback(filter_callback)
    self.confd_counting_callback = counting_callback

    self.confd_client = confd_client.ConfdClient(self.hmac_key,
                                                 self.mc_list,
                                                 counting_callback)

    tests = [
      {"type": constants.CONFD_REQ_PING},
      {"type": constants.CONFD_REQ_CLUSTER_MASTER},
      {"type": constants.CONFD_REQ_CLUSTER_MASTER,
       "query": {constants.CONFD_REQQ_FIELDS:
                 [str(constants.CONFD_REQFIELD_NAME),
                  str(constants.CONFD_REQFIELD_IP),
                  str(constants.CONFD_REQFIELD_MNODE_PIP),
                  ]}},
      {"type": constants.CONFD_REQ_NODE_ROLE_BYNAME},
      {"type": constants.CONFD_REQ_NODE_PIP_LIST},
      {"type": constants.CONFD_REQ_MC_PIP_LIST},
      {"type": constants.CONFD_REQ_INSTANCES_IPS_LIST,
       "query": None},
      {"type": constants.CONFD_REQ_NODE_PIP_BY_INSTANCE_IP},
      ]

    for kwargs in tests:
      # Queries depending on earlier answers get their argument filled in
      # just before they are sent
      if kwargs["type"] == constants.CONFD_REQ_NODE_ROLE_BYNAME:
        # Explicit check instead of an assert, which "python -O" strips
        if self.cluster_master is None:
          Err("Cannot query the node role: cluster master is not known")
        kwargs["query"] = self.cluster_master
      elif kwargs["type"] == constants.CONFD_REQ_NODE_PIP_BY_INSTANCE_IP:
        kwargs["query"] = {constants.CONFD_REQQ_IPLIST: self.instance_ips}

      req = confd_client.ConfdClientRequest(**kwargs)
      self.DoConfdRequestReply(req)

  def TestTiming(self):
    """Run timing tests.

    Does nothing if the requested number of iterations is zero or
    negative. Otherwise switches the callback into timing mode and times
    the ping and the instance IP list queries.

    """
    # timing tests
    if self.opts.requests <= 0:
      return
    Log("Timing tests")
    self.is_timing = True
    self.TimingOp("ping", {"type": constants.CONFD_REQ_PING})
    self.TimingOp("instance ips",
                  {"type": constants.CONFD_REQ_INSTANCES_IPS_LIST})

  def TimingOp(self, name, kwargs):
    """Run a single timing test.

    Sends C{self.opts.requests} identical requests one after the other and
    logs the average wall-clock time per request in milliseconds.

    @param name: human-readable name of the operation, used in the log
    @param kwargs: keyword arguments for C{ConfdClientRequest}

    """
    start = time.time()
    for _ in range(self.opts.requests):
      req = confd_client.ConfdClientRequest(**kwargs)
      self.DoConfdRequestReply(req)
    stop = time.time()
    # TestTiming only calls this with a positive request count
    per_req = 1000 * (stop - start) / self.opts.requests
    Log("%.3fms per %s request", per_req, name, indent=1)

  def Run(self):
    """Run all the tests.

    """
    self.TestConfd()
    self.TestTiming()


def main():
  """Program entry point.

  Builds the test client (which parses the command line) and runs all of
  its tests.

  @return: whatever L{TestClient.Run} returns

  """
  client = TestClient()
  return client.Run()


# Run the client only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
  main()
