import os

from logging import getLogger

logger = getLogger(__name__)
from mhcii_predictor import MhciiPredictor
from length_rescaling import calculate_length_rescaled

class mhcII_SA():
    """Seven-allele MHC-II prediction that writes its result table to a file."""

    # NOTE(review): this class-level import only binds the attribute
    # ``mhcII_SA.os``; nothing in the class uses it, but it is kept so any
    # external attribute access keeps working.
    import os

    def pred(self, pred_dictionary, pred_outfile,
             sel_alleles="HLA-DRB1*03:01,HLA-DRB1*07:01,HLA-DRB1*15:01,HLA-DRB3*01:01,HLA-DRB3*02:02,HLA-DRB4*01:01,HLA-DRB5*01:01",
             sel_method="recommended 2.22"):
        """Predict binding for every sequence and write the rows as TSV.

        :param pred_dictionary: mapping of sequence name -> sequence; only the
            values are sent to the predictor, in the dict's iteration order.
        :param pred_outfile: path of the output file; it is overwritten with
            one tab-separated line per row of the result content (the header
            row produced by ``get_result`` comes first).
        :param sel_alleles: NOTE(review): currently ignored -- the fixed
            seven-allele panel below is always used. The default is the same
            panel with an ``HLA-`` prefix; confirm whether callers expect a
            custom value to be honoured.
        :param sel_method: NOTE(review): currently ignored -- method
            'recommended', version '2.22' is always used.
        :returns: ``None`` after writing the file, or ``{'err': message}`` if
            the prediction result is a string.
        """
        # Fixed seven-allele panel: the default ``sel_alleles`` without "HLA-".
        sel_alleles = "DRB1*03:01,DRB1*07:01,DRB1*15:01,DRB3*01:01,DRB3*02:02,DRB4*01:01,DRB5*01:01"
        sequence_list = list(pred_dictionary.values())
        r = get_result(method="recommended", method_version="2.22",
                       input_sequences=sequence_list,
                       allele_list=sel_alleles.split(','))
        # isinstance() is the intended check; ``r is str`` would compare r with
        # the ``str`` type object itself and could never be true.
        if isinstance(r, str):
            return {'err': 'The sequences dataset is too large for the Seven allele. Please split it and try again. Please contact us if you have questions.' }
        # ``with`` guarantees the file is closed even if formatting a row fails.
        with open(pred_outfile, 'w') as peptide_pred:
            for result in r['content']:
                peptide_pred.write('\t'.join(str(i) for i in result) + "\n")
        # print pred_outfile



def get_result(method, method_version, input_sequences, allele_list):
    """Run an MHC-II binding prediction and return the formatted result table.

    :param method: prediction method name (e.g. 'recommended', 'consensus',
        'netmhciipan'); some method/version pairs are remapped below.
    :param method_version: version string for *method*.
    :param input_sequences: the sequences to predict; peptides are sliced
        from them by :class:`FormatedPredictionResult`.
    :param allele_list: allele names passed to :class:`MhciiPredictor`.
    :returns: ``{'content': [header, row, row, ...]}`` where ``header`` is a
        tuple of column names and each row is a tuple of (formatted) values.
    """
    # Every allele is predicted with a fixed peptide length of 15.
    length_list = [15] * len(allele_list)

    # Translate user-facing method/version pairs into predictor method names.
    if method == 'recommended' and method_version == '2023.05':
        method = 'netmhciipan_el'
        method_version = '4.1'
    elif method == 'netmhciipan' and method_version in ['4.0', '4.1', '4.2', '4.3']:
        # After this remap no ('netmhciipan', '4.x') pair reaches the code below.
        method = 'netmhciipan_ba'

    predictor = MhciiPredictor(method, allele_list, length_list, method_version)
    mhc_scores = predictor.predict(input_sequences)
    fpr = FormatedPredictionResult(method, input_sequences, mhc_scores)
    table_rows = fpr.result_rows

    # Rows begin (allele, seq_num, start, end, length, ...); the sort keys
    # account for the 'length' column at index 4.
    if method == "recommended" or method == 'consensus':
        # index 5 is the peptide; index 6 is percentile_rank (ascending).
        table_rows = sorted(table_rows, key=lambda tup: tup[6])
    elif method in ('tepitope', 'netmhciipan_el', 'netmhcpan_el'):
        # (.., core_peptide, peptide, score, rank): index 7 is the score,
        # sorted descending. 'netmhcpan_el' is kept for backward compatibility.
        table_rows = sorted(table_rows, key=lambda tup: tup[7], reverse=True)
    else:
        # (.., core_peptide, peptide, ic50, rank): index 7 is the ic50, ascending.
        table_rows = sorted(table_rows, key=lambda tup: tup[7])

    unit = get_unit(method)

    if method == 'recommended':
        # Moves the trailing method label into column 5.
        table_rows = add_method_used(table_rows)
        header = (
        'allele', 'seq_num', 'start', 'end', 'length', 'method', 'peptide', 'percentile_rank', unit, 'comblib_core',
        'comblib_score', 'comblib_rank', 'comblib_adjusted_rank', 'smm_align_core', 'smm_align_ic50',
        'smm_align_rank', 'smm_align_adjusted_rank', 'nn_align_core', 'nn_align_ic50', 'nn_align_rank',
        'nn_align_adjusted_rank', 'netmhciipan_core', 'netmhciipan_ic50', 'netmhciipan_rank',
        'netmhciipan_adjusted_rank', 'sturniolo_core', 'sturniolo_score', 'sturniolo_rank',
        'sturniolo_adjusted_rank')
    elif method == 'consensus':
        header = ('allele', 'seq_num', 'start', 'end', 'length', 'peptide', 'percentile_rank', unit, 'comblib_core',
                  'comblib_score', 'comblib_rank', 'comblib_adjusted_rank', 'smm_align_core', 'smm_align_ic50',
                  'smm_align_rank', 'smm_align_adjusted_rank', 'nn_align_core', 'nn_align_ic50', 'nn_align_rank',
                  'nn_align_adjusted_rank')
    elif method in ('netmhciipan_ba', 'netmhciipan_el'):
        header = ('allele', 'seq_num', 'start', 'end', 'length', 'core_peptide', 'peptide', unit, 'rank')
    else:
        header = (
        'allele', 'seq_num', 'start', 'end', 'length', 'core_peptide', 'peptide', unit, 'rank', 'adjusted_rank')

    # The NetMHCIIpan headers above have no adjusted-rank column, so no
    # length-rescaled rank is added for them.
    if method not in ('netmhciipan_ba', 'netmhciipan_el'):
        table_rows = add_column_adj_rank(table_rows, method)

    content = create_result_list(sigfig(table_rows, method), header)
    return dict(content=content)

def get_unit(method_selected):
    """Return the name of the value column reported for *method_selected*.

    'recommended' and 'consensus' report an adjusted rank; 'tepitope',
    'comblib' and 'netmhciipan_el' report a score; everything else an ic50.
    """
    if method_selected in ('recommended', 'consensus'):
        return "adjusted_rank"
    if method_selected in ('tepitope', 'comblib', 'netmhciipan_el'):
        return "score"
    return "ic50"

def add_method_used(table_rows):
    """Move each row's trailing method label into column 5.

    A label containing '-' (several methods combined) is rewritten as
    ``"Consensus (a/b/c)"``; any other label is kept verbatim.

    :param table_rows: iterable of row sequences whose last cell is the label.
    :returns: list of new row tuples.
    """
    relabelled = []
    for row in table_rows:
        cells = list(row)
        label = cells[-1]
        if '-' in label:
            label = "Consensus (" + label.replace("-", "/") + ")"
        # Insert first, then drop the last cell (the original label).
        cells.insert(5, label)
        cells.pop()
        relabelled.append(tuple(cells))
    return relabelled

def add_column_adj_rank(test_data, method=None):
    """Add length-rescaled ("adjusted") rank values to each row.

    For 'recommended' and 'consensus' rows a rescaled value is inserted right
    after each rank column listed below. For any other method a single
    rescaled value, computed from the row's last cell, is appended. The
    peptide length is always read from column 4.

    :param test_data: iterable of row sequences.
    :param method: prediction method name selecting the column layout.
    :returns: list of new row tuples.
    """
    if method == 'recommended':
        # Rank columns after add_method_used: percentile, comblib, smm_align,
        # nn_align, netmhciipan, sturniolo. They are visited right-to-left so an
        # insertion never shifts a position that is still to be processed.
        rank_positions = (22, 19, 16, 13, 10, 7)
    elif method == 'consensus':
        rank_positions = (18, 15, 12, 9, 6)
    else:
        rank_positions = None

    rescaled_rows = []
    for source_row in test_data:
        cells = list(source_row)
        if rank_positions is None:
            cells.append(calculate_length_rescaled(length=cells[4], rank=cells[-1]))
        else:
            for pos in rank_positions:
                cells.insert(pos + 1, calculate_length_rescaled(length=cells[4], rank=cells[pos]))
        rescaled_rows.append(tuple(cells))
    return rescaled_rows

def create_result_list(scores, header):
    """Return a new list holding *header* followed by every row of *scores*."""
    table = [header]
    table.extend(scores)
    return table


def sigfig(table_rows, method):
    """Render the numeric columns of each row as fixed-precision strings.

    * 'recommended' / 'consensus': selected score/rank columns get two
      decimals (``"%.2f"``).
    * any other method: only column 7 (the ic50/score value) is formatted.
      Ints use ``"%d"``, values below 1 get four decimals, and all others
      get two decimals.

    Cells equal to the placeholder ``'-'`` are left unchanged.

    :param table_rows: iterable of row sequences.
    :param method: prediction method name selecting the column layout.
    :returns: list of new row tuples.
    """
    logger.debug('method:%s', method)

    if method == 'recommended':
        # Same positions as 'consensus', shifted right by one for the
        # 'method' column that recommended rows carry at index 5.
        float_columns = (7, 10, 11, 14, 15, 18, 19, 22, 23, 26, 27)
    elif method == 'consensus':
        float_columns = (6, 9, 10, 13, 14, 17, 18, 21, 22)
    else:
        float_columns = None

    formated_scores = []
    for row in table_rows:
        lis = list(row)
        if float_columns is not None:
            for i in float_columns:
                if lis[i] != '-':
                    lis[i] = "%.2f" % lis[i]
        else:
            value = lis[7]
            # Check the placeholder before comparing numerically: '-' < 1
            # raises TypeError on Python 3.
            if value != '-':
                if isinstance(value, int):
                    lis[7] = "%d" % value
                elif value < 1:
                    lis[7] = "%.4f" % value
                else:
                    lis[7] = "%.2f" % value
        formated_scores.append(tuple(lis))
    return formated_scores

class FormatedPredictionResult(object):
    r"""
    Class to change the results' format.

    Generates a list of row tuples (``self.result_rows``) from a set of
    sequences and per-allele predictions. Every row starts with
    ``(allele, seq_num, start, end, length, ...)``:

    * 'consensus' / 'recommended': the peptide follows, then every value of
      that peptide's prediction tuple;
    * any other method: ``core_peptide, peptide, prediction, rank`` follow.

    Allele names that contain neither 'H2' nor 'BoLA' get an ``HLA-`` prefix
    and have '-' replaced by '/'.

    >>> p = ['FNCLGMSNRDFLEGVSG']
    >>> scores = [(15, 'DRB1*01:01', [((50.25, 'FNCLGMSNR', 'comb.lib.-smm-nn'), (71.98, 'NRDFLEGVS', 'comb.lib.-smm-nn'))])]
    >>> fpr = FormatedPredictionResult('recommended', p, scores)
    >>> fpr.result_rows
    [('HLA-DRB1*01:01', 1, 1, 15, 15, 'FNCLGMSNRDFLEGV', 50.25, 'FNCLGMSNR', 'comb.lib.-smm-nn'), ('HLA-DRB1*01:01', 1, 2, 16, 15, 'NCLGMSNRDFLEGVS', 71.98, 'NRDFLEGVS', 'comb.lib.-smm-nn')]
    """

    def __init__(self, method, protein_sequences, results, seq_nums=None):
        """
        :param method: prediction method name; selects the row layout.
        :param protein_sequences: sequences aligned with each allele's scores;
            peptides are sliced from them.
        :param results: iterable of ``(pep_length, allele, scores)`` triples.
        :param seq_nums: optional sequence numbers reported instead of the
            1-based position of each sequence.
        """
        self.result_rows = []
        self.pred_method = method
        self.protein_sequences = protein_sequences
        self.results = results
        self.seq_nums = seq_nums
        self.format_binding()

    def add_rows_binding(self, allele, pep_length, scores, protein_sequences):
        """Append one row per predicted peptide of *allele* to ``result_rows``.

        :param allele: allele name written to the first column.
        :param pep_length: length of every peptide window.
        :param scores: per-sequence lists of per-peptide predictions, aligned
            with *protein_sequences*; peptide k of a sequence starts at k.
        :param protein_sequences: the sequences the peptides are sliced from.
        """
        seq_nums = self.seq_nums
        consensus_layout = self.pred_method in ("consensus", "recommended")
        for (i, (sequence, predictions)) in enumerate(zip(protein_sequences, scores)):
            for (k, prediction_values) in enumerate(predictions):
                peptide_sequence = sequence[k: k + pep_length]
                peptide_source = seq_nums[i] if seq_nums else (i + 1)
                peptide_start = k + 1
                peptide_end = k + pep_length
                prefix = (allele, peptide_source, peptide_start, peptide_end, pep_length)

                if consensus_layout:
                    # Every prediction value is copied after the peptide.
                    self.result_rows.append(prefix + (peptide_sequence,) + tuple(prediction_values))
                    continue

                core_sequence, prediction, rank = prediction_values
                try:
                    peptide_sequence.index(core_sequence)
                except (ValueError, TypeError):
                    # Best effort: report the mismatch but keep the row.
                    logger.warning("Core sequence and peptide sequence doesn't match.")
                    # TODO(JY): uncomment this:
                    # raise ValueError("Core sequence and peptide sequence doesn't match.")
                self.result_rows.append(prefix + (core_sequence, peptide_sequence, prediction, rank))

    def format_binding(self):
        """Rebuild ``self.result_rows`` from ``self.results`` and return it."""
        self.result_rows = []
        for (pep_length, allele, scores) in self.results:
            if 'H2' not in allele and 'BoLA' not in allele:
                allele = "HLA-%s" % allele.replace("-", "/")
            self.add_rows_binding(allele, pep_length, scores, self.protein_sequences)
        return self.result_rows