#!/usr/bin/env python
# Copyright (c) 2011-2012, Universite de Versailles St-Quentin-en-Yvelines
#
# This file is part of ASK.  ASK is free software: you can redistribute
# it and/or modify it under the terms of the GNU General Public
# License as published by the Free Software Foundation, version 2.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License along with
# this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

import logging
import os
import os.path
import select
import shutil
import subprocess
import sys

from common.configuration import Configuration
from common.util import fatal

# Version string reported by the --version command line option
VERSION="1.0.1"

###
# Driver Main Loop
#

# Configuration (C) and logger (L) singletons. Both are None at import time
# and are assigned by main() before the driver loop is entered.
C, L = None, None


def call_module(module, params=None):
    """Run one ASK module as a subprocess and relay its output to the log.

    The command line is made of the executable configured for `module`,
    the configuration file path and `params`, joined with single spaces and
    executed through the shell. Non-blank lines the module writes to stdout
    are logged at INFO level, non-blank lines written to stderr at WARNING
    level.

    Arguments:
        module: name of the module, used as a key in the "modules"
            configuration section.
        params: optional list of extra command-line arguments (strings)
            appended after the configuration file.

    Exit status handling:
        * 0: the function returns normally.
        * 254 from the "control" module: the experiments are considered
          finished and the driver exits with status 0.
        * anything else: fatal() is called with the module's return code.
    """
    if params is None:
        params = []

    L.debug("Calling {0} module".format(module))

    # Command line: executable, configuration file, then input/output params.
    # NOTE(review): arguments are joined without quoting and run with
    # shell=True, so a path containing spaces or shell metacharacters will be
    # split/interpreted by the shell. Kept as is because the configured
    # executable may itself rely on shell parsing -- confirm before changing.
    cmd = [C("modules")[module]["executable"], C("configuration_file")]
    cmd += params
    cmd = " ".join(cmd)
    L.debug("  - cmd : <{0}>".format(cmd))

    try:
        p = subprocess.Popen(cmd,
                             stderr=subprocess.PIPE,
                             stdout=subprocess.PIPE,
                             bufsize=1,
                             shell=True)
    except OSError:
        # presumably fatal() does not return -- otherwise `p` is unbound below
        fatal("Could not execute module {0}\n"
              "  {1}".format(module, cmd))

    # Streams still open on the child side. A stream is dropped as soon as it
    # reaches EOF: a closed pipe is always reported readable by select(), so
    # keeping it in the set would make this loop spin until the other stream
    # closes too.
    open_streams = [p.stdout, p.stderr]
    while open_streams:
        for fd in select.select(open_streams, [], [])[0]:
            line = fd.readline()
            if not line:
                open_streams.remove(fd)
                continue

            # Only strip the line terminator; a last line that is not
            # newline-terminated must keep its final character.
            text = line.rstrip("\n")
            if not text.strip():
                continue

            if fd is p.stderr:
                L.warning("\t({0}) {1}".format(module, text))
            else:
                L.info("\t({0}) {1}".format(module, text))

    p.wait()
    if p.returncode != 0:
        if module == "control" and p.returncode == 254:
            # 254 is the control module's way of requesting a normal stop
            L.info("Control module decided to stop experiments")
            print("Experiments finished normally")
            sys.exit(0)
        else:
            fatal("Module {0} stopped with code {1}"
                  .format(module, p.returncode), p.returncode)


def driver_loop(args):
    """Main loop of the driver.

    Each iteration runs, in order: bootstrap (iteration 0) or sampler
    (later iterations), source, model, reporter and control. The loop has
    no exit condition of its own: it ends when call_module() terminates the
    process (control module asking to stop, or a module failure).

    Files used, all inside the configured output directory:
        labelled.data       -- every labelled point so far (truncated at start)
        requestedNNNNN.data -- points requested at iteration NNNNN
        labelledNNNNN.data  -- points labelled at iteration NNNNN
        modelNNNNN.data     -- model file passed to the model module

    Arguments:
        args: parsed command line; the attributes replay_only, skip_model
            and skip_reporter are read.
    """

    L.debug("Starting driver loop")

    out_dir = C("output_directory")
    iteration = 0
    labelled = "{0}/labelled.data".format(out_dir)
    # Blank labelled file (needed in replay)
    open(labelled, 'w').close()
    while True:
        L.debug("**** Starting ITERATION {0} ****".format(iteration))
        model = "{0}/model{1:05d}.data".format(out_dir, iteration)
        newly_labelled = "{0}/labelled{1:05d}.data".format(out_dir, iteration)
        requested = "{0}/requested{1:05d}.data".format(out_dir, iteration)

        if not args.replay_only:
            if iteration == 0:
                # Call bootstrap module to get the initial data
                call_module("bootstrap", params=[requested])
                if not os.path.exists(requested):
                    fatal("Bootstrap did not produce points "
                          "({0} does not exist)".format(requested))
            else:
                # Ask sampler to request new points
                call_module("sampler", params=[labelled, requested])
                if not os.path.exists(requested):
                    fatal("Oracle did not gave back requests "
                          "({0} does not exist)".format(requested))

            # Call source to fulfill the requests
            call_module("source", params=[requested, newly_labelled])
            if not os.path.exists(newly_labelled):
                fatal("Source did not fullfill requests "
                      "({0} does not exist)".format(newly_labelled))
        elif not os.path.exists(newly_labelled):
            # In replay mode the per-iteration file must come from a previous
            # run; report its absence explicitly instead of failing with a
            # raw IOError when opening it below.
            fatal("Cannot replay iteration {0} "
                  "({1} does not exist)".format(iteration, newly_labelled))

        # Append newly_labelled points to labelled. The with block closes
        # both files even if the copy fails.
        L.debug("Updating labelled points")
        with open(newly_labelled, "rb") as orig, open(labelled, "ab") as dest:
            shutil.copyfileobj(orig, dest)

        # Call model module
        if not args.skip_model:
            call_module("model", params=[labelled, model])

        # Call reporter module
        if not args.skip_reporter:
            call_module("reporter", params=[str(iteration),
                                            labelled,
                                            newly_labelled,
                                            model])

        # Call control module (which will decide when to exit this loop)
        call_module("control", params=[str(iteration), labelled, model])
        iteration += 1


def main():
    """Command line entry point of the driver.

    Parses the command line, exports ASKHOME (directory of this file) and
    adds it to PYTHONPATH in the environment, loads the configuration into
    the global C, sets up the file logger in the global L and finally
    enters driver_loop().
    """
    global C, L
    import argparse
    parser = argparse.ArgumentParser(
        description="Helps choosing sampling points during experiments.")
    parser.add_argument(
        '--force_overwrite',
        action='store_true',
        help="If a previous output directory is found its contents"
        " will be overwritten")
    parser.add_argument(
        '--replay_only',
        action='store_true',
        help="(Must be run over a full previous ask run). Skip bootstrap and"
             " sampler modules, but runs previous labels through the model and"
             " reporter. This option is useful if one wants to run sampling"
             " and analyze results separately.")

    parser.add_argument(
        '--skip_model',
        action='store_true',
        help="Skip model module. (WARNING: if the chosen control module"
             " relies on the output of the model module, this option"
             " fails.)")

    parser.add_argument(
        '--skip_reporter',
        action='store_true',
        help="Skip reporter module. (WARNING: if the chosen control module"
             " relies on the output of the reporter module, this option"
             " fails.)")

    parser.add_argument(
        '--version',
        action='version',
        version="ask {}".format(VERSION))

    parser.add_argument('configuration',
                        help="the experiment configuration file")

    args = parser.parse_args()
    if args.replay_only and args.force_overwrite:
        fatal("--force_overwrite and --replay_only are exclusive options.")

    # Setup the ASKHOME
    askhome = os.path.dirname(os.path.realpath(__file__))
    os.environ["ASKHOME"] = askhome

    # Setup PYTHONPATH: append ASKHOME using the platform's path separator
    # (":" on POSIX systems)
    if "PYTHONPATH" in os.environ:
        os.environ["PYTHONPATH"] += os.pathsep + askhome
    else:
        os.environ["PYTHONPATH"] = askhome

    # Read Configuration
    C = Configuration(user_configuration=args.configuration,
                      reload_live_configuration=False,
                      force_overwrite=args.force_overwrite,
                      replay_only=args.replay_only)

    # Start Logger
    logfile = C("log.logfile")
    print("Logging to {0}".format(logfile))
    L = logging.getLogger(__name__)
    fh = logging.FileHandler(logfile)
    formatter = logging.Formatter('[%(asctime)s %(levelname)-8s] %(message)s',
                                  datefmt='%H:%M')
    fh.setFormatter(formatter)

    # The level name comes from the configuration file. It used to be
    # resolved with eval("logging." + name), which executes arbitrary
    # configuration text as code; a plain attribute lookup on the logging
    # module resolves the same level names (DEBUG, INFO, ...) safely.
    level = getattr(logging, C("log.level").strip())
    L.setLevel(level)
    fh.setLevel(level)
    L.addHandler(fh)
    L.debug("Logger started at level {0}".format(C("log.level")))

    # Enter driver loop
    driver_loop(args)


# Run the driver only when this file is executed as a script
if __name__ == "__main__":
    main()
