#!/usr/bin/env python

"""
Augment the alignment_summary.gff file with consensus and variants information.
"""
from __future__ import absolute_import, division, print_function

from collections import namedtuple, defaultdict
import argparse
import logging
import bisect
import json
import gzip
import sys

import numpy as np

from pbcommand.utils import setup_log
from pbcommand.cli import pbparser_runner, get_default_argparser
from pbcommand.models import FileTypes, get_pbparser
from pbcommand.common_options import add_resolved_tool_contract_option
from pbcore.io import GffReader, GffWriter, Gff3Record

from GenomicConsensus.utils import error_probability_to_qv
from GenomicConsensus import __VERSION__

#
# Note: GFF-style coordinates -- start and end are stored exactly as read
# from the GFF records.  run() uses Region tuples as the dictionary keys of
# its per-region summaries.
#
Region = namedtuple("Region", ("seqid", "start", "end"))

# Module-level logger; main() hands it to pbparser_runner.
log = logging.getLogger(__name__)

class Constants(object):
    """
    Fixed identifiers passed to get_pbparser() in get_contract_parser().
    """
    # Tool contract identifier for this task.
    TOOL_ID = "genomic_consensus.tasks.summarize_consensus"
    # Driver command prefix.  The trailing space is part of the value --
    # presumably the resolved tool contract path is appended directly after
    # it (TODO confirm against pbcommand).
    DRIVER_EXE = "summarizeConsensus --resolved-tool-contract "


def get_contract_parser():
    """
    Build the pbcommand parser describing this tool's inputs and outputs.

    The alignment summary GFF is registered through the parser's own
    add_input_file_type().  The variants GFF and the output GFF are each
    declared twice: once on the tool contract parser (as file types) and
    once on the underlying argument parser (as --variantsGff and
    -o/--output).

    Returns the parser object created by get_pbparser().
    """
    parser = get_pbparser(
        Constants.TOOL_ID,
        __VERSION__,
        "Summarize Consensus",
        __doc__,
        Constants.DRIVER_EXE,
        default_level="ERROR")
    tool_contract = parser.tool_contract_parser
    command_line = parser.arg_parser.parser

    # Input 0: alignment summary.
    parser.add_input_file_type(
        FileTypes.GFF,
        "alignment_summary",
        "Alignment summary GFF",
        "Alignment summary GFF file")

    # Input 1: variants (file type for the tool contract, required option
    # for the command line).
    tool_contract.add_input_file_type(
        FileTypes.GFF,
        "variants",
        "Variants GFF",
        "Variants GFF file")
    command_line.add_argument(
        "--variantsGff",
        type=str,
        help="Input variants.gff or variants.gff.gz filename",
        required=True)

    # Output 0: the augmented alignment summary.
    tool_contract.add_output_file_type(
        FileTypes.GFF,
        "output",
        name="Coverage and Variant Call Summary",
        description="Coverage and variant call summary for regions (bins) spanning the reference",
        default_name="alignment_summary_variants")
    command_line.add_argument(
        "-o", "--output",
        type=str,
        help="Output alignment_summary.gff filename")

    return parser

def get_args_from_resolved_tool_contract(resolved_tool_contract):
    """
    Translate a resolved tool contract into parsed command-line options.

    The first input file is used as the positional alignment summary, the
    second input file as --variantsGff and the first output file as
    --output.  The resulting argument list is parsed with the argument
    parser obtained from get_contract_parser(), and the parsed options are
    returned.
    """
    arg_parser = get_contract_parser().arg_parser.parser
    task = resolved_tool_contract.task
    alignment_summary = task.input_files[0]
    variants_gff = task.input_files[1]
    output_gff = task.output_files[0]
    command_line = [
        alignment_summary,
        "--variantsGff", variants_gff,
        "--output", output_gff,
    ]
    return arg_parser.parse_args(command_line)

def run(options):
    """
    Augment the alignment summary GFF with per-region variant counts.

    Reads every record of ``options.alignment_summary`` to create one
    summary per (seqid, start, end) region, tallies the variants found in
    ``options.variantsGff`` into the region containing each variant's start
    position, and writes a copy of the alignment summary to
    ``options.output`` in which every "region" line gets additional
    ``cQv``, ``ins``, ``del`` and ``sub`` attributes.

    Each variant contributes max(len(reference), len(variantSeq)) to the
    counter matching its type.

    Only lines whose record type is "region" are copied to the output;
    comment/metadata lines (starting with "#") are passed through unchanged
    and blank lines are skipped.  The tool-specific headers are written
    after the input's leading metadata lines.

    Returns 0 on success.

    Raises:
        KeyError: if a variant refers to a contig that has no region in the
            alignment summary, or has a type other than insertion, deletion
            or substitution.
        AssertionError: if a variant's start lies outside the region that
            was located for it.
    """
    headers = [
        ("source", "GenomicConsensus %s" % __VERSION__),
        ("pacbio-alignment-summary-version", "0.6"),
        ("source-commandline", " ".join(sys.argv)),
        ]

    # Pass 1 over the alignment summary: one counter set per region.
    summaries = {}
    inputAlignmentSummaryGff = GffReader(options.alignment_summary)
    try:
        for gffRecord in inputAlignmentSummaryGff:
            region = Region(gffRecord.seqid, gffRecord.start, gffRecord.end)
            summaries[region] = { "ins" : 0,
                                  "del" : 0,
                                  "sub" : 0,
                                  # TODO: base consensusQV on effective coverage
                                  "cQv" : (20, 20, 20)
                                 }
    finally:
        inputAlignmentSummaryGff.close()

    counterNames = { "insertion"    : "ins",
                     "deletion"     : "del",
                     "substitution" : "sub" }

    # Group the regions by contig, sorted by start, and precompute the list
    # of start positions once per contig so that each variant lookup is a
    # single binary search instead of a rebuild of that list.
    regions_by_contig = defaultdict(list)
    for region in summaries:
        regions_by_contig[region.seqid].append(region)
    starts_by_contig = {}
    for seqid, contig_regions in regions_by_contig.items():
        # key= works on both Python 2 and 3 (a positional cmp function and
        # the cmp() builtin exist only on Python 2).
        contig_regions.sort(key=lambda r: r.start)
        starts_by_contig[seqid] = [r.start for r in contig_regions]

    log.info("Processing variant records")
    inputVariantsGff = GffReader(options.variantsGff)
    try:
        for i, variantGffRecord in enumerate(inputVariantsGff, 1):
            seqid = variantGffRecord.seqid
            if seqid not in starts_by_contig:
                raise KeyError(
                    "Can't find alignment summary for contig '{s}'".format(
                    s=seqid))
            idx = bisect.bisect_right(starts_by_contig[seqid],
                                      variantGffRecord.start) - 1
            # XXX we have to be a little careful here - an insertion at the
            # start of a contig will have start=0 versus start=1 for the
            # first region
            if idx < 0:
                idx = 0
            region = regions_by_contig[seqid][idx]
            assert ((region.start <= variantGffRecord.start <= region.end) or
                    (region.start == 1 and variantGffRecord.start == 0 and
                     variantGffRecord.type == "insertion")), \
                (seqid, region.start, variantGffRecord.start,
                 region.end, variantGffRecord.type, idx)
            summary = summaries[region]
            counterName = counterNames[variantGffRecord.type]
            variantLength = max(len(variantGffRecord.reference),
                                len(variantGffRecord.variantSeq))
            summary[counterName] += variantLength
            if i % 1000 == 0:
                log.info("%d records...", i)
    finally:
        inputVariantsGff.close()

    # Pass 2 over the alignment summary, this time as plain text so that
    # the original lines can be extended in place.  The with-statement
    # guarantees both files are closed (and the output flushed) even if a
    # line fails to parse.
    with open(options.alignment_summary) as inputAlignmentSummaryGff, \
            open(options.output, "w") as outputAlignmentSummaryGff:

        def writeToolHeaders():
            for k, v in headers:
                print("##%s %s" % (k, v), file=outputAlignmentSummaryGff)

        inHeader = True

        for line in inputAlignmentSummaryGff:
            line = line.rstrip()

            # Blank lines carry no record; indexing into them would fail.
            if not line:
                continue

            # Pass any metadata line straight through
            if line.startswith("#"):
                print(line, file=outputAlignmentSummaryGff)
                continue

            if inHeader:
                # We are at the end of the header -- write the
                # tool-specific headers
                writeToolHeaders()
                inHeader = False

            # Parse the line
            rec = Gff3Record.fromString(line)

            if rec.type == "region":
                summary = summaries[(rec.seqid, rec.start, rec.end)]
                if "cQv" in summary:
                    cQvTuple = summary["cQv"]
                    line += ";%s=%s" % ("cQv", ",".join(str(int(f)) for f in cQvTuple))
                for counterName in counterNames.values():
                    if counterName in summary:
                        line += ";%s=%d" % (counterName, summary[counterName])
                print(line, file=outputAlignmentSummaryGff)

        if inHeader:
            # The input contained no feature lines at all; still record the
            # tool-specific headers after whatever metadata was copied.
            writeToolHeaders()

    return 0


def args_runner(args):
    """
    Command-line entry point: call run() with the already-parsed options
    and return its result.
    """
    exit_code = run(options=args)
    return exit_code


def resolved_tool_contract_runner(resolved_tool_contract):
    """
    Tool-contract entry point: convert the resolved tool contract into
    parsed options, call run() with them and return its result.
    """
    return run(
        options=get_args_from_resolved_tool_contract(resolved_tool_contract))


def main(argv=sys.argv):
    """
    Hand the command line to pbparser_runner and return its result.

    argv is the full argument vector; its first element is dropped before
    the remainder is passed on, together with the contract parser, the two
    runner callbacks, the module logger and setup_log.
    """
    cli_args = argv[1:]
    contract_parser = get_contract_parser()
    return pbparser_runner(
        cli_args,
        contract_parser,
        args_runner,
        resolved_tool_contract_runner,
        log,
        setup_log)

# When executed as a script, run main() and pass its return value to
# sys.exit() as the process exit status.
if __name__ == "__main__":
    sys.exit(main())
