#!/usr/bin/env python

"""
Augment the alignment_summary.gff file with consensus and variants information.
"""

from collections import namedtuple
import argparse
import logging
import json
import gzip
import sys

import numpy as np

from pbcommand.utils import setup_log
from pbcommand.cli import pacbio_args_or_contract_runner, get_default_argparser
from pbcommand.models import FileTypes, get_pbparser
from pbcommand.common_options import add_resolved_tool_contract_option
from pbcore.io import GffReader, GffWriter, Gff3Record

from GenomicConsensus.utils import error_probability_to_qv
from GenomicConsensus import __VERSION__

#
# Region identifies one alignment-summary interval by sequence id and
# start/end position; it is used as the dictionary key for the per-region
# summaries built in run().
#
# Note: GFF-style coordinates -- start and end are copied unchanged from the
# GFF records, and run() treats both ends as inclusive when matching variants.
#
Region = namedtuple("Region", ("seqid", "start", "end"))

class Constants(object):
    """Fixed identifiers used when describing this tool in a tool contract."""

    # Command prefix handed to get_pbparser() as the driver executable.
    DRIVER_EXE = "summarizeConsensus --resolved-tool-contract "

    # Tool contract id of this task.
    TOOL_ID = "genomic_consensus.tasks.summarize_consensus"

def get_argument_parser():
    """
    Build the argparse parser used to process the command line.

    The parser starts from get_default_argparser() and adds the positional
    alignment summary GFF, the required --variantsGff option, the --output
    option, the resolved-tool-contract option, and --emit-tool-contract.
    """
    # FIXME temporary workaround for parser chaos
    class EmitToolContractAction(argparse.Action):
        """Dump the tool contract as JSON on stdout, then exit with status 0."""
        def __call__(self, parser, namespace, values, option_string=None):
            contract = get_contract_parser().to_contract().to_dict()
            sys.stdout.write(json.dumps(contract, indent=4)+'\n')
            sys.exit(0)

    parser = get_default_argparser(__VERSION__, __doc__)
    parser.add_argument(
        "inputAlignmentSummaryGff",
        type=str,
        help="Input alignment_summary.gff filename")
    # FIXME not optional, should be positional
    parser.add_argument(
        "--variantsGff",
        type=str,
        required=True,
        help="Input variants.gff or variants.gff.gz filename")
    parser.add_argument(
        "--output", "-o",
        type=str,
        help="Output alignment_summary.gff filename")
    add_resolved_tool_contract_option(parser)
    # nargs=0 makes this a flag; the action itself never returns.
    parser.add_argument(
        "--emit-tool-contract",
        nargs=0,
        action=EmitToolContractAction)
    return parser

def get_contract_parser():
    """
    Used to generate emitted tool contract, but not (yet) to actually process
    command-line options.

    Returns the parser object created by get_pbparser(), declaring two GFF
    inputs (alignment summary, variants) and one GFF output.
    """
    # The tool id and driver both come from Constants so the contract cannot
    # drift from the values declared there.  (An unused local that named the
    # "variantCaller" driver used to live here; it was never passed on.)
    p = get_pbparser(
        Constants.TOOL_ID,
        __VERSION__,
        "Summarize Consensus",
        __doc__,
        Constants.DRIVER_EXE)
    p.add_input_file_type(FileTypes.GFF, "alignment_summary",
        "Alignment summary GFF", "Alignment summary GFF file")
    p.add_input_file_type(FileTypes.GFF, "variants",
        "Variants GFF", "Variants GFF file")
    p.add_output_file_type(FileTypes.GFF, "output",
        name="Output GFF file",
        description="New alignment summary GFF file",
        default_name="alignment_summary_variants.gff")
    return p

def get_args_from_resolved_tool_contract(resolved_tool_contract):
    """
    Translate a resolved tool contract into the parsed options run() expects.

    The first input file is the alignment summary GFF, the second input file
    is the variants GFF, and the first output file is the output GFF.  The
    equivalent command line is fed through the regular argument parser.
    """
    parser = get_argument_parser()
    task = resolved_tool_contract.task
    command_line = [task.input_files[0]]
    command_line.extend(["--variantsGff", task.input_files[1]])
    command_line.extend(["--output", task.output_files[0]])
    return parser.parse_args(command_line)

def run(options):
    """
    Write an augmented copy of the alignment summary GFF.

    Every "region" record of the input gets a ``cQv`` attribute and the
    ``ins``/``del``/``sub`` counters appended; lines starting with "#" are
    copied through unchanged, and the tool-specific headers are inserted at
    the end of the input header.

    ``options`` must provide the attributes ``inputAlignmentSummaryGff``,
    ``variantsGff`` and ``output`` (all file names).

    Returns 0 on completion.
    """
    headers = [
        ("source", "GenomicConsensus %s" % __VERSION__),
        ("pacbio-alignment-summary-version", "0.6"),
        ("source-commandline", " ".join(sys.argv)),
        ]

    counterNames = { "insertion"    : "ins",
                     "deletion"     : "del",
                     "substitution" : "sub" }

    # Both readers are closed even if parsing fails part-way through.
    inputVariantsGff = GffReader(options.variantsGff)
    try:
        inputAlignmentSummaryGff = GffReader(options.inputAlignmentSummaryGff)
        try:
            # One summary per record of the alignment summary, keyed by Region.
            summaries = {}
            for gffRecord in inputAlignmentSummaryGff:
                region = Region(gffRecord.seqid, gffRecord.start, gffRecord.end)
                summaries[region] = { "ins" : 0,
                                      "del" : 0,
                                      "sub" : 0,
                                      "cQv" : (0, 0, 0)
                                     }
        finally:
            inputAlignmentSummaryGff.close()

        for variantGffRecord in inputVariantsGff:
            for region in summaries:
                summary = summaries[region]
                # A variant is attributed to a region by its start position
                # only; both region bounds are inclusive.
                if (region.seqid == variantGffRecord.seqid and
                    region.start <= variantGffRecord.start <= region.end):
                    counterName = counterNames[variantGffRecord.type]
                    variantLength = max(len(variantGffRecord.reference),
                                        len(variantGffRecord.variantSeq))
                    summary[counterName] += variantLength
                # TODO: base consensusQV on effective coverage
                # NOTE(review): this runs only while iterating variants, so
                # with an empty variants file every region keeps (0, 0, 0) --
                # confirm whether that is intended.
                summary["cQv"] = (20, 20, 20)
    finally:
        inputVariantsGff.close()

    def writeHeaders(out):
        # Emit the tool-specific "##key value" header lines.
        for k, v in headers:
            out.write("##%s %s\n" % (k, v))

    # Second pass over the raw text so the original lines are preserved.
    # The context managers guarantee the output is flushed and closed.
    with open(options.inputAlignmentSummaryGff) as inputAlignmentSummaryGff:
        with open(options.output, "w") as outputAlignmentSummaryGff:
            inHeader = True

            for line in inputAlignmentSummaryGff:
                line = line.rstrip()

                # Blank lines carry no record; indexing them used to raise
                # IndexError.
                if not line:
                    continue

                # Pass any metadata line straight through
                if line.startswith("#"):
                    outputAlignmentSummaryGff.write(line + "\n")
                    continue

                if inHeader:
                    # We are at the end of the header -- write the
                    # tool-specific headers
                    writeHeaders(outputAlignmentSummaryGff)
                    inHeader = False

                # Parse the line
                rec = Gff3Record.fromString(line)

                # Only "region" records are written; other record types are
                # dropped from the output.
                if rec.type == "region":
                    # A plain tuple compares and hashes equal to the Region key.
                    summary = summaries[(rec.seqid, rec.start, rec.end)]
                    cQvTuple = summary["cQv"]
                    line += ";%s=%s" % ("cQv", ",".join(str(int(f)) for f in cQvTuple))
                    for counterName in counterNames.values():
                        line += ";%s=%d" % (counterName, summary[counterName])
                    outputAlignmentSummaryGff.write(line + "\n")

            if inHeader:
                # The input had no record lines at all; still emit the
                # tool-specific headers after the copied metadata.
                writeHeaders(outputAlignmentSummaryGff)
    return 0

def args_runner(args):
    """Run the tool from already-parsed command-line options; returns run()'s result."""
    exit_code = run(options=args)
    return exit_code

def resolved_tool_contract_runner(resolved_tool_contract):
    """
    Run the tool from a resolved tool contract.

    The contract is first converted to parsed options, which are then handed
    to run(); run()'s result is returned.
    """
    return run(
        options=get_args_from_resolved_tool_contract(resolved_tool_contract))

def main(argv=sys.argv):
    """
    Command-line entry point.

    Configures logging at WARN level and delegates to
    pacbio_args_or_contract_runner(), passing every element of ``argv``
    after the first.  Returns whatever that call returns.
    """
    def ignore_log_setup(*args, **kwargs):
        # Accepts any arguments and does nothing.
        pass

    logging.basicConfig(level=logging.WARN)
    parser = get_argument_parser()
    root_logger = logging.getLogger()
    cli_args = argv[1:]
    return pacbio_args_or_contract_runner(
        cli_args,
        parser,
        args_runner,
        resolved_tool_contract_runner,
        root_logger,
        ignore_log_setup)

# When executed as a script, run main() and use its return value as the
# process exit status.
if __name__ == "__main__":
    sys.exit(main())
