from collections import namedtuple, defaultdict
import logging
import math
import os
import re
from subprocess import Popen, PIPE
from tempfile import NamedTemporaryFile
import pkg_resources
from allele_info import is_user_defined_allele, MHCIAlleleData
from iedbtools_utilities.sequence_io import SequenceOutput

logger = logging.getLogger(__name__)


# Filesystem path of the 'netMHCpan' resource shipped with the
# 'netmhcpan_2_8_executable' package; resolved once, at import time.
_executable_filename = \
    pkg_resources.resource_filename('netmhcpan_2_8_executable', 'netMHCpan')
# NOTE(review): if resource_filename returns an absolute path (it normally does),
# os.path.join discards the dirname(__file__) prefix and this is simply
# _executable_filename — confirm before relying on the relative-path case.
EXECUTABLE_PATH = os.path.join(os.path.dirname(__file__), _executable_filename)

# Key type of the dict returned by predict_many: one entry per
# (sequence, allele, binding_length) combination that produced scores.
PredictionInput = namedtuple('PredictionInput', ['sequence', 'allele', 'binding_length'])

def predict_many(sequence_list, allele_length_2tuple_list, delete_tempfiles=True):
    """ | *brief*: Provides python interface to netmhcpan 3rd-party tool.
        | *author*: Jivan
        | *created*: 2015-11-18

        Performs a prediction for every sequence / allele-length combination.
        Returns a dictionary { PredictionInput(sequence, allele, length): [score, score, ...], ... }
        with one entry per (sequence, allele, length) for which the netmhcpan output
        contained at least one score row.

        *sequence_list*: sequences to predict on; sequence i is written to a temporary
            fasta file as ``seq-<i>`` so output rows can be mapped back to it by index.
        *allele_length_2tuple_list*: iterable of (allele, binding_length) pairs.  Alleles
            for which ``is_user_defined_allele`` is true are passed to netmhcpan as a
            sequence (``-hlaseq``); all others are passed by name (``-a``) after
            ``strip_allele_name``.
        *delete_tempfiles*: when true (the default) every temporary fasta file created
            by this call is removed before returning, including when an error is raised.

        A non-zero netmhcpan exit status is logged as a warning; scores are still
        parsed from whatever netmhcpan wrote to stdout.
    """
    sequence_filepath = sequence_list_to_fasta_file(sequence_list)
    # Every temporary file created by this call; removed in the finally clause.
    to_delete = [sequence_filepath]
    try:
        all_scores = {}
        for allele, binding_length in allele_length_2tuple_list:
            # The call made to netmhcpan is different for a user-defined allele (hlaseq)
            #    vs an allele name.
            if is_user_defined_allele(allele):
                # A user-defined allele must be passed via a fasta file.
                user_defined_allele_filepath = sequence_list_to_fasta_file([allele])
                # Track it so it is cleaned up together with the sequence file.
                to_delete.append(user_defined_allele_filepath)
                cmd = [
                    EXECUTABLE_PATH, '-f', sequence_filepath,
                    '-hlaseq', user_defined_allele_filepath,
                    '-l', str(binding_length)
                ]
                get_scores_from_output = get_user_defined_allele_scores_from_netmhcpan_output
            else:
                # Temporary fix: see strip_allele_name.
                stripped_allele_name = strip_allele_name(allele)
                cmd = [
                    EXECUTABLE_PATH, '-f', sequence_filepath,
                    '-ic50', '-a', stripped_allele_name,
                    '-l', str(binding_length)
                ]
                get_scores_from_output = get_allele_name_scores_from_netmhcpan_output

            logger.info('Executing: "%s"', ' '.join(cmd))
            # universal_newlines=True makes stdout text (not bytes) on Python 3 too,
            # so the str regexes used by the output parsers can be applied to it.
            process = Popen(cmd, stdout=PIPE, universal_newlines=True)
            stdoutdata, _stderrdata_ignored = process.communicate()
            if process.returncode != 0:
                logger.warning(
                    'netMHCpan exited with status %s for command "%s"',
                    process.returncode, ' '.join(cmd))
            logger.debug('Raw output:\n%s', stdoutdata)
            scores_by_sequence_idx = get_scores_from_output(stdoutdata)
            # Wrap the scores up nicely to return to the caller.
            for seqidx, scores in scores_by_sequence_idx.items():
                prediction_input = PredictionInput(
                    sequence_list[seqidx], allele, binding_length)
                all_scores[prediction_input] = scores
    finally:
        if delete_tempfiles:
            for filepath in to_delete:
                os.unlink(filepath)
    return all_scores

def strip_allele_name(allele_name):
    """ | *brief*: Temporary hack to get the allele name right for netmhcpan executable.
        | *author*: Dorjee
        | *created*: 2016-09-12

        Looks up the allele's species through MHCIAlleleData and rewrites the '*'
        separator: for 'pig' and 'macaque' alleles every '*' becomes ':', for all
        other species every '*' is removed.

        TODO: A more permanent solution would be to create a column in the database for canonical allele name.
    """
    allele_data = MHCIAlleleData()
    species = allele_data.get_species_for_allele_name(allele_name=allele_name)
    separator = ':' if species in ('pig', 'macaque') else ''
    return allele_name.replace('*', separator)

def sequence_list_to_fasta_file(sequence_list):
    """ | *brief*: Writes the sequences in *sequence_list* as fasta sequences to a file and returns
        |    the filepath.
        | *author*: Jivan
        | *created*: 2015-11-18

        Sequence i is written under the identifier ``seq-<i>`` and followed by a blank
        line.  The file is not deleted automatically; the caller owns it and must
        remove it.  If writing fails, the partially written file is removed and the
        error is re-raised.
    """
    # Text mode so str content can be written on both Python 2 and 3
    # (the default 'w+b' mode rejects str on Python 3).
    t = NamedTemporaryFile(mode='w', delete=False)
    try:
        with t:
            for i, sequence in enumerate(sequence_list):
                t.write('>seq-{}\n'.format(i))
                t.write(sequence)
                t.write('\n\n')
    except BaseException:
        # Do not leave a half-written temp file behind.
        os.unlink(t.name)
        raise
    return t.name

def get_allele_name_scores_from_netmhcpan_output(netmhcpan_output):
    """ | *brief*: Parses the string *netmhcpan_output* for scores from an allele name
        |    prediction request.
        | *author*: Jivan
        | *created*: 2015-11-19

        Returns a defaultdict(list) mapping the integer N of each ``seq-N`` fasta
        identifier to the IC50 values (floats, in output order) found for it.
    """
    # Rows of interest in allele-name output look like:
    #  10  HLA-A*01:01    AKLAEQAER           seq-0         0.006     46989.19    50.00
    #  11  HLA-A*01:01    KLAEQAERY           seq-0         0.255      3169.09     1.50 <= WB
    # Group 1 captures N from 'seq-N' (4th column) and group 2 the IC50 (6th column);
    # the 5th column is matched but not captured.
    row_pattern = re.compile(
        r'^\s*\d+\s+\S+\s+\S+\s+seq-(\d+)\s+[\d\.]+\s+([\d\.]+).*$',
        flags=re.MULTILINE,
    )
    ic50_by_index = defaultdict(list)
    for seq_index_text, ic50_text in row_pattern.findall(netmhcpan_output):
        ic50_by_index[int(seq_index_text)].append(float(ic50_text))
    return ic50_by_index

def get_user_defined_allele_scores_from_netmhcpan_output(netmhcpan_output):
    """ | *brief*: Parses the string *netmhcpan_output* for resulting scores.
        | *author*: Dorjee

        Used for user-defined allele (``-hlaseq``) requests.  The IC50 is not read
        from the output; it is reconstructed from the 5th-column score as
        ``50000 ** (1 - score)``.

        Returns a defaultdict(list) mapping the integer N of each ``seq-N`` fasta
        identifier to its IC50 values (floats, in output order).
    """
    # Output rows of interest look like this (for user-defined alleles the allele
    # column holds a different token, but any non-blank token is accepted there):
    #  10  HLA-A*01:01    AKLAEQAER           seq-0         0.006     46989.19    50.00
    #  11  HLA-A*01:01    KLAEQAERY           seq-0         0.255      3169.09     1.50 <= WB
    # Group 1 captures N from 'seq-N' (4th column) and group 2 the binding score
    # (5th column; presumably netMHCpan's 1-log50k value).
    score_regex = r'^\s*\d+\s+\S+\s+\S+\s+seq-(\d+)\s+([\d\.]+).*$'
    p = re.compile(score_regex, flags=re.MULTILINE)
    scores = defaultdict(list)
    for m in p.finditer(netmhcpan_output):
        sequence_idx = int(m.group(1))
        binding_affinity = float(m.group(2))
        # Inverse of score = 1 - log(ic50) / log(50000).
        ic50_score = math.pow(50000, 1 - binding_affinity)
        scores[sequence_idx].append(ic50_score)
    return scores

