#!/usr/bin/env python
# Copyright (c) 2011-2012, Universite de Versailles St-Quentin-en-Yvelines
#
# This file is part of ASK.  ASK is free software: you can redistribute
# it and/or modify it under the terms of the GNU General Public
# License as published by the Free Software Foundation, version 2.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License along with
# this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

"""
This module call the hierarchical sampling module to find interesting
points to label
"""

import numpy as np
from random import randint, random

from common.configuration import Configuration
from common.util import fatal


def draw_point(factors):
    """Draw one random point with one coordinate per factor.

    Each factor is a mapping with a "type" key:

    * "integer": the coordinate is a random integer between
      int(range.min) and int(range.max), both bounds included;
    * "float": the coordinate is a random float drawn between
      range.min and range.max;
    * "categorical": the coordinate is a random index into the
      factor's "values" list.

    Any other type is reported through fatal().

    Returns the coordinates as a tuple, in the same order as factors.
    """
    coordinates = []
    for factor in factors:
        kind = factor["type"]

        if kind == "integer":
            bounds = factor["range"]
            coordinates.append(randint(int(bounds["min"]),
                                       int(bounds["max"])))
            continue

        if kind == "float":
            # draw in [0, 1) first, then stretch it onto the range
            unit = random()
            bounds = factor["range"]
            coordinates.append(unit *
                               (bounds["max"] - bounds["min"]) +
                               bounds["min"])
            continue

        if kind == "categorical":
            # categorical factors are encoded by the index of the value
            last_index = len(factor["values"]) - 1
            coordinates.append(randint(0, last_index))
            continue

        fatal("Wrong factor type: <{0}>"
              .format(kind))
    return tuple(coordinates)


def generate_possible_points(factors, labelled, n):
    """Lazily yield n random points, avoiding labelled ones if possible.

    Points are drawn with draw_point(factors). A draw that is already
    in labelled is rejected and redrawn, but only up to 50 times in a
    row: after that the next draw is yielded whatever its value, so
    the generator always terminates (and may then yield a point that
    is already labelled).

    Membership in labelled is tested at each draw, so points that the
    caller adds to labelled between two yields are taken into account.
    """
    max_rejections = 50
    produced = 0
    rejections = 0
    while produced < n:
        candidate = draw_point(factors)
        if candidate not in labelled or rejections >= max_rejections:
            # either a fresh point, or we gave up looking for one
            rejections = 0
            produced += 1
            yield candidate
        else:
            rejections += 1


def randomsample(configuration, input_file, output_file, n):
    """Write n random points to label into output_file.

    configuration -- mapping whose "factors" entry describes the
                     factors, as expected by draw_point
    input_file    -- whitespace separated file of labelled points,
                     read with numpy.genfromtxt; every column but the
                     last is taken as a point coordinate (the last
                     column is presumably the measured response)
    output_file   -- file to write; one point per line, coordinates
                     separated by a single space
    n             -- number of points to write

    Points already present in input_file, or already written during
    this call, are avoided as far as generate_possible_points allows.
    """
    # read the labelled points
    labelled = np.genfromtxt(input_file)

    # when there is only one observation genfromtxt returns a 1-D
    # array: reshape it into a one-row matrix so that it is
    # multidimensional like the general case
    if len(labelled.shape) == 1:
        labelled = labelled.reshape(1, labelled.shape[0])

    # take all the labelled points (without their last column)
    # and create a set
    already_labelled = set(tuple(v) for v in labelled[:, :-1])

    # write the suggested points; the with block guarantees that the
    # file is flushed and closed even if drawing or writing fails
    with open(output_file, "w") as of:
        for p in generate_possible_points(configuration["factors"],
                                          already_labelled,
                                          n):
            # remember the point so that it is not suggested twice
            already_labelled.add(p)
            of.write(" ".join(map(str, p)) + "\n")


if __name__ == "__main__":
    # Command line entry point: read the configuration, then write the
    # requested number of random points to the output file.
    import argparse

    cli = argparse.ArgumentParser(
        description="Draws random points which are not among the"
        " currently labelled points")
    # three mandatory positional arguments, in this order
    for positional in ('configuration', 'input_file', 'output_file'):
        cli.add_argument(positional)
    options = cli.parse_args()

    conf = Configuration(options.configuration)
    # the number of points to draw comes from the configuration file
    randomsample(conf,
                 options.input_file,
                 options.output_file,
                 conf("modules.sampler.params.n", int))
