#!/usr/bin/env python
# -*- Mode: Python; coding: utf-8; indent-tabs-mode: nil; tab-width: 4 -*-
# Copyright 2012 Canonical
# Author: Thomi Richards
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License version 3, as published
# by the Free Software Foundation.
#

"""A simple unit-test runner that 'Just Works' for the most common use cases."""

import argparse
from argparse import ArgumentParser
from contextlib import contextmanager
import logging
import sys
from testtools import iterate_tests
from unittest.loader import TestLoader
from unittest import TestSuite
import os


logger = logging.getLogger(__name__)

# Lazily-created output streams, keyed by name. Each entry starts as None and
# is filled in on first use by get_output_stream().
_output_descriptors = dict.fromkeys(
    ('_output_stream', '_formatted_output_stream'))


def main():
    """Command-line entry point: load the requested tests, run them, and exit
    with status 1 if the run was not successful.

    Log output goes to the stream chosen by ``--output`` (stdout by default);
    test results go to the stream chosen by ``--formatted-output``.
    """
    args = parse_args()

    logging.basicConfig(stream=get_output_stream(args.output, '_output_stream'))
    test_suite = load_test_suite_from_name(args.suite)
    runner = construct_test_runner(args)
    success = runner.run(test_suite).wasSuccessful()
    if not success:
        # Use sys.exit rather than the site-provided exit() builtin, which is
        # absent under ``python -S`` and in some frozen/embedded interpreters.
        sys.exit(1)


def get_output_stream(output, name):
    """Return the stream cached under *name*, creating it on first request.

    The first call for a given *name* creates the stream: a file at path
    *output*, opened with mode 'w', when *output* is truthy, otherwise
    ``sys.stdout``. Every later call returns that same object and ignores
    *output*. An unknown *name* raises KeyError.
    """
    stream = _output_descriptors[name]
    if stream is None:
        stream = open(output, 'w') if output else sys.stdout
        # The dict is mutated in place, so no ``global`` declaration is needed.
        _output_descriptors[name] = stream
    return stream


def patch_python_path():
    """Make the current working directory importable.

    If the working directory is not already on ``sys.path`` it is inserted at
    the front, so that tests can be loaded and run when the caller is in the
    parent directory of the test package.

    """
    cwd = os.getcwd()
    if cwd in sys.path:
        return
    sys.path.insert(0, cwd)


def load_test_suite_from_name(test_names):
    """Return a test suite containing the tests selected by dotted names.

    :param test_names: a single dotted test name (string) or a list of them,
        e.g. ``'mypkg.tests.test_foo'``. The top-level package of each name is
        imported and test discovery runs over it; only tests whose id starts
        with one of the requested names are kept.
    :raises TypeError: if *test_names* is neither a string nor a list.

    Tests that fail to load are left out of the returned suite; the exception
    raised while loading them is printed so the user can see what is broken.
    The returned suite is ordered by test id.
    """
    # ``basestring`` only exists on Python 2; fall back to ``str`` on Python 3.
    try:
        string_types = basestring
    except NameError:
        string_types = str

    # Validate input before touching sys.path or importing anything.
    if isinstance(test_names, string_types):
        # Wrap the name; list('a.b') would split it into single characters.
        test_names = [test_names]
    elif not isinstance(test_names, list):
        raise TypeError("test_names must be either a string or list, not %r"
                        % (type(test_names)))

    patch_python_path()
    loader = TestLoader()

    tests = []
    test_package_locations = []
    for test_name in test_names:
        top_level_pkg = test_name.split('.')[0]
        package = __import__(top_level_pkg)
        pkg_file = package.__file__
        if os.path.basename(pkg_file).startswith('__init__'):
            # A package: discovery's top level is the directory above it.
            package_parent_path = os.path.abspath(
                os.path.join(os.path.dirname(pkg_file), '..'))
        else:
            # A plain module: its own directory is the top level.
            package_parent_path = os.path.abspath(os.path.dirname(pkg_file))
        if package_parent_path not in test_package_locations:
            test_package_locations.append(package_parent_path)

        tests.append(loader.discover(top_level_pkg,
                                     top_level_dir=package_parent_path))
    all_tests = TestSuite(tests)

    test_dirs = ", ".join(sorted(test_package_locations))
    logger.info("Loading tests from: %s\n", test_dirs)

    # str.startswith accepts a tuple of prefixes; build it once.
    name_prefixes = tuple(test_names)
    requested_tests = {}
    for test in iterate_tests(all_tests):
        test_id = test.id()
        # The test loader returns tests that start with 'unittest.loader' if for
        # whatever reason the test failed to load. We run the tests without the
        # built-in exception catching turned on, so we can get at the raised
        # exception, which we print, so the user knows that something in their
        # tests is broken.
        if test_id.startswith('unittest.loader'):
            try:
                test.debug()
            except Exception as e:
                print(e)
        elif test_id.startswith(name_prefixes):
            # Keyed by id so a test discovered twice is only run once.
            requested_tests[test_id] = test

    # Sort by id so the run order does not depend on dict ordering.
    return TestSuite([requested_tests[test_id]
                      for test_id in sorted(requested_tests)])


def construct_test_runner(args):
    """Build a ConfigurableTestRunner from the parsed command-line *args*.

    Results are written to the '_formatted_output_stream' stream. Coverage
    settings are forwarded only when the corresponding attributes exist on
    *args*. parse_args only defines them when the coverage module is
    installed.
    """
    runner_kwargs = {
        'stdout': get_output_stream(args.formatted_output,
                                    '_formatted_output_stream'),
        'output_format': args.format,
    }

    if 'coverage' in args:
        runner_kwargs['coverage'] = args.coverage
        for option in ('cover_format', 'cover_output', 'cover_exclude'):
            if option in args:
                runner_kwargs[option] = getattr(args, option)

    return ConfigurableTestRunner(**runner_kwargs)


class ConfigurableTestRunner(object):
    """A configurable test runner class.

    This class allows us to configure the output format and whether or not we
    collect coverage information for the test run.

    """

    def __init__(self, stdout, output_format, coverage=False, cover_format=None,
                 cover_output='-', cover_exclude=''):
        """
        :param stdout: stream passed to the result class's constructor.
        :param output_format: result class; instantiated as
            ``output_format(stdout)`` for every run.
        :param coverage: if true, collect coverage while the tests run.
        :param cover_format: coverage report format, passed to enable_coverage.
        :param cover_output: coverage report destination, passed to
            enable_coverage.
        :param cover_exclude: comma-separated filename patterns to omit from
            coverage. Surrounding whitespace and empty entries are ignored,
            and an empty string or None means "omit nothing".
        """
        self.stdout = stdout
        self.result_class = output_format
        self.coverage = coverage
        self.cover_format = cover_format
        self.cover_output = cover_output
        self.cover_exclude = self._split_patterns(cover_exclude)

    @staticmethod
    def _split_patterns(patterns):
        """Turn a comma-separated pattern string into a list of non-empty patterns."""
        if not patterns:
            # ''.split(',') would give [''], a bogus empty omit pattern.
            return []
        return [p.strip() for p in patterns.split(',') if p.strip()]

    def run(self, test):
        """Run the given test case or test suite.

        Returns whatever ``test.run(result)`` returns. For standard unittest
        tests and suites, that is the result object. ``startTestRun`` and
        ``stopTestRun`` are always paired, even if the run raises.
        """
        result = self.result_class(self.stdout)
        result.startTestRun()
        try:
            if self.coverage:
                with enable_coverage(self.cover_output, self.cover_format,
                                     self.cover_exclude):
                    return test.run(result)
            return test.run(result)
        finally:
            result.stopTestRun()


@contextmanager
def enable_coverage(filename, format, exclude_list):
    """A context manager that collects coverage for the code run inside it.

    :param filename: report destination.
        - 'xml': '.xml' is appended unless already present.
        - 'html': used as the output directory.
        - Anything else: a plain-text report, with '.txt' appended unless
          already present.
    :param format: 'xml', 'html', or anything else (including None) for a
        plain-text report.
    :param exclude_list: filename patterns passed to coverage as ``omit``.

    Coverage collection is always stopped when the block exits. The report is
    written only if the block completes without raising.
    """

    from coverage import coverage
    cov = coverage(omit=exclude_list)
    cov.start()
    try:
        yield
    finally:
        # Stop tracing even if the wrapped code raised.
        cov.stop()

    if format == 'xml':
        if not filename.endswith('.xml'):
            filename += '.xml'
        cov.xml_report(outfile=filename)
    elif format == 'html':
        cov.html_report(directory=filename)
    else:
        # Compare against '.txt' directly: '.' + format raised TypeError when
        # format was None (ConfigurableTestRunner's default cover_format).
        if not filename.endswith('.txt'):
            filename += '.txt'
        # Close the report file deterministically instead of leaking it.
        with open(filename, 'w') as report_file:
            cov.report(file=report_file)


class FormatAction(argparse.Action):
    """An argparse action that stores the output format class object.

    Instead of the format name, the parsed namespace receives the result class
    looked up in the class attribute ``supported_formats``. That attribute must
    be assigned before parsing; parse_args does this.
    """

    # Maps format name -> result class. It must be a mapping, since __call__
    # uses .get(); an empty dict means "no formats registered yet".
    supported_formats = {}

    def __call__(self, parser, namespace, values, option_string=None):
        # Unknown names store None; argparse's ``choices`` normally prevents that.
        setattr(namespace, self.dest, self.supported_formats.get(values))


def parse_args(argv=None):
    """Create the argument parser object, parse the command line and return
    the result.

    :param argv: list of argument strings to parse. When None (the default),
        argparse reads ``sys.argv[1:]``.

    This also registers the available output formats on FormatAction. The
    coverage options are only added when the coverage module is installed.
    """
    parser = ArgumentParser(description=__doc__)

    parser.add_argument('suite', nargs='+', help="Specify test suite(s) to run.")

    supported_formats = get_supported_output_formats()
    FormatAction.supported_formats = supported_formats

    # Sorted so the choices appear in a stable order in --help / error output.
    parser.add_argument('-f', '--format', choices=sorted(supported_formats),
        action=FormatAction, default=supported_formats['text'],
        help="""Specify what format the test results should be presented in.""")
    parser.add_argument('-fo', '--formatted-output', type=str,
        help="""Specify where formatted output (e.g. xml) should go. If left
        unspecified, stdout is used.""")
    parser.add_argument('-o', '--output', help="""Specify the location where test
        output should go. If left unspecified, stdout is used.""")
    if have_coverage():
        parser.add_argument('-c', '--coverage', action='store_true', default=False,
            help="""Enable coverage collection for this test run.""")
        parser.add_argument('-cf', '--cover-format', type=str, choices=['html', 'xml', 'txt'],
            help="Specify coverage report format. Default is txt.", default='txt',
            dest='cover_format')
        parser.add_argument('-co', '--cover-output', type=str, default='coverage',
            help="Specify the file path where the coverage report should go. \
            The default is 'coverage.XXX'.", dest='cover_output')
        parser.add_argument('-ce', '--cover-exclude', type=str, help="Omit files \
        when their filename matches one of these patterns. Usually needs quoting \
        on the command line. Multiple patterns can be specified in a comma-separated \
        list.", dest='cover_exclude', default='')
    return parser.parse_args(argv)


def get_supported_output_formats():
    """Return a dictionary mapping a short name to a Result class for each
    format we support.

    A format is left out when its implementing package cannot be imported:
    'text' comes from testtools and 'xml' from junitxml.

    """
    formats = {}
    # Catch only ImportError; a bare except would also hide real errors and
    # swallow KeyboardInterrupt/SystemExit.
    try:
        from testtools import TextTestResult
    except ImportError:
        pass
    else:
        formats['text'] = TextTestResult
    try:
        from junitxml import JUnitXmlResult
    except ImportError:
        pass
    else:
        formats['xml'] = JUnitXmlResult
    return formats


def have_coverage():
    """Report whether the coverage module can be imported.

    Returns True if ``import coverage`` succeeds, False if it raises
    ImportError.
    """
    try:
        import coverage  # noqa: F401 -- imported only to probe availability
    except ImportError:
        return False
    else:
        return True


if __name__ == '__main__':
    # Run the test runner only when this file is executed as a script.
    main()
