# _*_ coding:utf-8 _*_
'''
Created on Nov 9, 2016
@author: Yan
MHC-II predictor for all methods.
'''

# built-in
import re
import math
from bisect import bisect, bisect_left
from decimal import Decimal
from collections import namedtuple, defaultdict
import logging

logger = logging.getLogger(__name__)
# bioinformatics tools
from allele_info import MHCIIAlleleData

from mhcii_tepitope_predictor import single_prediction_tepitope
from mhcii_comblib_predictor import single_prediction_comblib

from mhcii_tepitope_percentile_data import percentile_manager as mhcii_tepitope_percentile_manager
from mhcii_comblib_percentile_data import percentile_manager as mhcii_comblib_percentile_manager
#from mhcii_netmhciipan_percentile_data import percentile_manager as mhcii_netmhciipan_31_percentile_manager
#from mhcii_nnalign_percentile_data import percentile_manager as mhcii_nnalign_22_percentile_manager
from mhcii_smmalign_percentile_data import percentile_manager as mhcii_smmalign_percentile_manager

from mhcii_netmhciipan_32_percentile_data import percentile_manager as mhcii_netmhciipan_32_percentile_manager
from mhcii_netmhcii_2_3_percentile_data import percentile_manager as mhcii_nnalign_23_percentile_manager

from mhcii_netmhciipan_4_1_el_percentile_data import percentile_manager as mhcii_netmhciipan_41_el_percentile_manager
from mhcii_netmhciipan_4_1_ba_percentile_data import percentile_manager as mhcii_netmhciipan_41_ba_percentile_manager

# from the same package import for the doc test
from common_bio import Proteins
# a class with 4 properties: method_name, id, version, predictor
from method_version_info import MHCIIMethod

# constants (or constant-like values)
# Identifies one prediction; instances are used as the keys of the result dicts
# produced by the do_*_prediction methods.
MHCIIPredictionParams = namedtuple(
    'MHCIIPredictionParams',
    'method_name sequence allele coreseq_length binding_length',
)
# (allele, binding_length) pair handed to the percentile lookup.
AlleleLengthPair = namedtuple('AlleleLengthPair', 'allele binding_length')
# Numeric method id (as a string) -> method name.
# NOTE (Yan 2016-09-02): be very careful here -- '3' maps to the mixed-case
# 'NetMHCIIpan' while every other name is lower case.
METHOD_DICT = {
    str(number): name
    for number, name in enumerate(
        ('recommended', 'consensus', 'NetMHCIIpan', 'nn_align',
         'smm_align', 'comblib', 'tepitope'),
        1,
    )
}
MHCII_AVAILABLE_LENGTHS = range(10, 31)

# Used instead of local data files
from mhcii_predictor_data import get_method_allele_list

from netmhcii_11_executable import single_prediction as single_prediction_smm
from netmhciipan_3_2_executable import single_prediction as single_prediction_netmhciipan32
from netmhciipan_4_1_executable import single_prediction as single_prediction_netmhciipan41
from netmhcii_23_executable import single_prediction as single_prediction_nn23

# Module-level placeholders, both initialised to None.
# NOTE(review): not referenced in the visible part of this module; confirm
# whether callers still rely on them before removing.
Consensus20Predictor = Recommended20Predictor = None


class UnexpectedInputError(Exception):
    """Exception raised for errors in the input.

    :param value: description of the problem. It is kept as ``self.value``
        and is also forwarded to ``Exception.__init__``, so ``exc.args`` is
        populated and the exception survives a pickle round-trip.
    """

    def __init__(self, value):
        super(UnexpectedInputError, self).__init__(value)
        self.value = value

    def __str__(self):
        # str() so that a non-string value does not make str(exc) raise TypeError.
        return str(self.value)


class MhciiPredictorBase(object):
    """Shared machinery for the MHC-II binding predictors.

    Every version of a method is treated as a separate method here, because
    each version has its own predictor package and percentile-rank data
    package, and may produce a differently shaped result table.

    ``predict`` calls ``self.do_prediction``, which is ``None`` on this base
    class, so it has to be assigned (e.g. to one of the ``do_*_prediction``
    methods) before ``predict`` is used.
    """
    do_prediction = None

    def __init__(self, method, allele_list, length_list='default', version=None):
        """
        :param method: user-selected prediction method name.
        :param allele_list: alleles to predict for.
        :param length_list: binding lengths, one per allele. ``None``/empty or
            ``'default'`` gives 15 for every allele; if it contains ``'asis'``
            every allele gets ``'asis'``; otherwise each entry is passed
            through ``int``.
        :param version: optional method version string.
        """
        self.method = method  # user-selected predictive method.
        # For 'iedb recommended': different predictive methods may be
        # associated with the selected alleles.
        self.method_set_selected = []
        self.allele_list = allele_list
        self.tool_selection = [(allele_list, 9)]
        if (not length_list) or (length_list == 'default'):
            # Default binding length.
            self.length_list = [15] * len(allele_list)
        elif 'asis' in length_list:
            # If 'asis' appears anywhere, all lengths have to be 'asis'.
            self.length_list = ['asis', ] * len(allele_list)
        else:
            # Materialise as a list: on Python 3 ``map`` is a one-shot iterator,
            # so the lengths would be gone after the first predict() call.
            self.length_list = list(map(int, length_list))
        self.version = version

    def get_method_set_selected(self):
        """Return the sorted, de-duplicated methods recorded in ``method_set_selected``.

        Used for dynamic citation generation, in particular when 'iedb recommended'
        is selected, because a different set of methods can be chosen for each
        (mhc, length) pair. Entries may be comma-joined, for example
        ['ann,smm', 'ann,smm,comblib', 'netmhcpan'].

        HACK: if two or more methods other than 'netmhciipan' are present,
        'consensus' is appended so that citations are generated for the
        individual methods as well as for consensus.
        """
        method_set_selected = list(set(self.method_set_selected))
        num_methods = len(method_set_selected)
        if 'netmhciipan' in method_set_selected:
            num_methods -= 1
        if num_methods >= 2:
            method_set_selected.append('consensus')
        method_set_selected.sort()
        return method_set_selected

    def strip_allele_name(self, allele_name):
        """Convert an allele name to the form used by netmhciipan's data.

        '*' becomes '_' and ':' is removed. H2 names are rewritten to 'H-2';
        names not starting with 'DRB' get an 'HLA-' prefix and lose their '_'.
        """
        stripped_allele_name = allele_name.replace('*', '_').replace(':', '')
        if stripped_allele_name.startswith("H2") or stripped_allele_name.startswith("H-2"):
            stripped_allele_name = stripped_allele_name.replace("H2", "H-2")
        elif not allele_name.startswith("DRB"):
            stripped_allele_name = "HLA-%s" % stripped_allele_name.replace('_', '')
        return stripped_allele_name

    @staticmethod
    def _percentile_from_bucket_index(index):
        """Map an index into a 280-entry score distribution to a percentile rank.

        Index 0 is clamped to 0.01. Indices 1-99 map to 0.01-0.99 in 0.01 steps,
        100-189 map to 1.0-9.9 in 0.1 steps, and 190-280 map to 10-100 in
        steps of 1. Callers only pass indices in 0..280.
        """
        if index < 1:
            return 0.01
        if index < 100:
            return index * 0.01
        if index < 190:
            return 1.0 + (index - 100) * 0.1
        return 10 + (index - 190)

    def get_percentiles_for_scores(self, method_name, scores, allele_length_pair, method_version=None):
        """Return the percentile ranks for the raw ``scores`` passed.

        :param method_name: predictor name; lower-cased and cut at the first '-'.
        :param scores: iterable of raw scores.
        :param allele_length_pair: (allele, binding_length) 2-tuple.
        :param method_version: optional version; without it the un-versioned
            name selects the default percentile data.
        :returns: tuple of percentile ranks, each rounded to 2 decimals.
        :raises ValueError: for an unknown method name, or when a reverse-ranked
            (tepitope / netmhciipan_el) distribution does not have 280 entries.
        """
        # TODO (jy): move this to subclass's __init__ as self.percentile_manager
        # TODO: modify it as method_version later
        # TODO(JY): update the default version to latest ones later.
        method_name = method_name.lower().split('-')[0]
        if method_version:
            method_name_version = '%s %s' % (method_name, method_version)
        else:
            method_name_version = method_name

        if method_name_version in ('tepitope 1.0', 'tepitope'):
            percentile_manager = mhcii_tepitope_percentile_manager
        elif method_name_version in ('comblib 1.0', 'comblib'):
            percentile_manager = mhcii_comblib_percentile_manager
        elif method_name_version == 'netmhciipan 3.1':
            # NOTE(review): the import for this name is commented out at the top
            # of the module, so this branch raises NameError unless it is
            # defined elsewhere.
            percentile_manager = mhcii_netmhciipan_31_percentile_manager
        elif method_name_version in ('netmhciipan 3.2', 'netmhciipan'):
            # The un-versioned 'netmhciipan' name resolves to the 3.2 data.
            percentile_manager = mhcii_netmhciipan_32_percentile_manager
        elif method_name_version in ('netmhciipan_el 4.1', 'netmhciipan_el'):
            percentile_manager = mhcii_netmhciipan_41_el_percentile_manager
        elif method_name_version in ('netmhciipan_ba 4.1', 'netmhciipan_ba', 'netmhciipan 4.1'):
            # NOTE(review): 'netmhciipan' used to be listed here as well, but the
            # 3.2 branch above always matched it first; confirm which default
            # is intended before changing it.
            percentile_manager = mhcii_netmhciipan_41_ba_percentile_manager
        elif method_name_version == 'nn_align 2.2':
            # NOTE(review): the import for this name is commented out at the top
            # of the module, so this branch raises NameError unless it is
            # defined elsewhere.
            percentile_manager = mhcii_nnalign_22_percentile_manager
        elif method_name_version in ('nn_align 2.3', 'nn_align'):
            percentile_manager = mhcii_nnalign_23_percentile_manager
        elif method_name_version in ('smm_align 1.1', 'smm_align'):
            percentile_manager = mhcii_smmalign_percentile_manager
        else:
            raise ValueError('unknown method name: %s' % method_name)

        allele, binding_length = allele_length_pair
        allele = str(allele)

        if method_name.startswith('netmhciipan') or method_name_version == 'nn_align 2.3':
            allele = self.strip_allele_name(allele)
        # TODO(JY 2018-11-26) use each binding_length instead of 15

        key = (str(method_name), str(allele), binding_length)

        if method_name_version == 'nn_align 2.2':
            stripped_allele_name = allele.replace('*', '').replace(':', '')
            if stripped_allele_name.startswith("H2") or stripped_allele_name.startswith("H-2"):
                stripped_allele_name = stripped_allele_name.replace("H2", "H-2")
            else:
                stripped_allele_name = "HLA-%s" % stripped_allele_name.replace('_', '')
            key = (str(method_name), stripped_allele_name, binding_length)

        score_distributions = percentile_manager.get_distributions(allele)
        if key not in score_distributions:
            # netmhciipan_ba distributions are stored under '<method>_ba'.
            key = (str(method_name) + '_ba', str(allele), binding_length)
        score_distribution = score_distributions[key]

        if method_name in ('tepitope', 'netmhciipan_el'):
            # These rank in reverse (for tepitope a higher score means a better
            # binder), so the rank is counted from the top of the distribution.
            if len(score_distribution) != 280:
                raise ValueError('len(score_distribution)!=280, it is %s' % len(score_distribution))
            # Sort once rather than once per score.
            ascending = sorted(score_distribution)
            percentiles = [
                self._percentile_from_bucket_index(280 - bisect_left(ascending, score))
                for score in scores
            ]
        elif len(score_distribution) == 280:
            ascending = sorted(score_distribution)
            percentiles = [
                self._percentile_from_bucket_index(bisect_left(ascending, score))
                for score in scores
            ]
        else:
            # Fine-grained distribution, used as-is (assumed already sorted): the
            # bisect_left index is the number of hundredths of a percentile below
            # the score. Values are kept at 0.01 or higher.
            percentiles = [
                max([bisect_left(score_distribution, score) * 0.01, 0.01]) for score in scores
            ]
        return tuple([float('%.2f' % x) for x in percentiles])

    def get_old_results(self, method_name, results, input_sequences):
        """Convert prediction results to the legacy result layout.

        :param results: {MHCIIPredictionParams: [(core sequence, score), ...]}.
        :param input_sequences: sequences, in the order the output should use.
        :returns: list of (binding_length, allele, triplets). ``triplets`` holds,
            for each input sequence in order, a tuple of
            (core sequence, score, percentile) triples.
        :raises KeyError: if a sequence has no result for some (allele, length).
        """
        converter = defaultdict(dict)
        # Prediction parameters & resulting core_sequence/score 2-tuples
        for pps, coreseq_score_2tuples in results.items():
            alp = AlleleLengthPair(pps.allele, pps.binding_length)
            coreseqs = [cs[0] for cs in coreseq_score_2tuples]
            scores = [cs[1] for cs in coreseq_score_2tuples]
            percentiles = self.get_percentiles_for_scores(method_name, scores, alp, self.version)
            converter[(pps.allele, pps.binding_length)][pps.sequence] = {
                'scores': scores,
                'percentiles': percentiles,
                'coreseqs': coreseqs,
            }
        old_results = []
        for (allele_name, binding_length), sequence_details in converter.items():
            # Generate the triplets (peptide, score, percentile) for the sequences in order.
            triplets = []
            for sequence in input_sequences:
                details = sequence_details[sequence]
                triplets.append(tuple(zip(details['coreseqs'], details['scores'], details['percentiles'])))
            old_results.append((binding_length, allele_name, triplets))
        return old_results

    def do_tepitope_prediction(self, sequence_list, allele_length_2tuple_list, coreseq_len=9):
        """Run tepitope and key its results by MHCIIPredictionParams.

        Keys of the raw result are indexed as (sequence, allele, binding_length).
        """
        method_name = 'tepitope'
        single_prediction_result = single_prediction_tepitope(sequence_list, allele_length_2tuple_list, coreseq_len)
        predictions = {}
        for key, value in single_prediction_result.items():
            pp = MHCIIPredictionParams(
                method_name=method_name,
                sequence=key[0],
                allele=key[1],
                coreseq_length=coreseq_len,
                binding_length=key[2],
            )
            predictions[pp] = value
        return predictions

    def do_comblib_prediction(self, sequences, allele_length_pairs, coreseq_len=9):
        """Run comblib and key its results by MHCIIPredictionParams.

        Keys of the raw result are indexed as (sequence, allele, binding_length).
        """
        method_name = 'comblib'
        single_prediction_result = single_prediction_comblib(sequences, allele_length_pairs, coreseq_len)
        analysis_results = {}
        for key, value in single_prediction_result.items():
            mpp = MHCIIPredictionParams(
                method_name=method_name,
                sequence=key[0],
                allele=key[1],
                coreseq_length=coreseq_len,
                binding_length=key[2],
            )
            analysis_results[mpp] = value
        return analysis_results

    @staticmethod
    def _check_netmhciipan_lengths(allele_length_pairs, coreseq_len):
        """Raise ``Exception`` if a binding length is outside 11-30 or ``coreseq_len`` != 9.

        Nothing is checked when ``allele_length_pairs`` is empty.
        """
        for _allele, binding_length in allele_length_pairs:
            if binding_length > 30 or binding_length < 11:
                raise Exception('Only a binding_length between 11-30 is supported, found {}'
                                .format(binding_length))
            if coreseq_len != 9:
                raise Exception('Only a core sequence length of 9 is supported, found {}'
                                .format(coreseq_len))

    @staticmethod
    def _wrap_netmhciipan_results(method_name, single_prediction_result, coreseq_len):
        """Re-key a netmhciipan executable result by MHCIIPredictionParams.

        Keys of ``single_prediction_result`` must expose ``sequence``, ``allele``
        and ``binding_length`` attributes; values ([(core, score), ...]) are kept.
        """
        full_results = {}
        for key, value in single_prediction_result.items():
            mpp = MHCIIPredictionParams(
                method_name=method_name,
                sequence=key.sequence,
                allele=key.allele,
                coreseq_length=coreseq_len,
                binding_length=key.binding_length,
            )
            full_results[mpp] = value
        return full_results

    def do_netmhciipan_prediction(self, sequence_list, allele_length_pairs, coreseq_len=9):
        """Run the un-versioned netMHCIIpan predictor.

        NOTE(review): ``single_prediction_netmhciipan`` is not imported at the top
        of this module, so this raises NameError unless it is defined elsewhere.
        """
        self._check_netmhciipan_lengths(allele_length_pairs, coreseq_len)
        single_prediction_result = single_prediction_netmhciipan(sequence_list, allele_length_pairs)
        return self._wrap_netmhciipan_results('netmhciipan', single_prediction_result, coreseq_len)

    # merge this with do_netmhciipan_prediction later, retrieve version info from self.version
    def do_netmhciipan32_prediction(self, sequence_list, allele_length_pairs, coreseq_len=9):
        """Run netMHCIIpan 3.2; binding lengths must be 11-30 and core length 9."""
        self._check_netmhciipan_lengths(allele_length_pairs, coreseq_len)
        single_prediction_result = single_prediction_netmhciipan32(sequence_list, allele_length_pairs)
        return self._wrap_netmhciipan_results('netmhciipan', single_prediction_result, coreseq_len)

    # merge this with do_netmhciipan_prediction later, retrieve version info from self.version
    def do_netmhciipan_41_el_prediction(self, sequence_list, allele_length_pairs, coreseq_len=9):
        """Run netMHCIIpan 4.1 with ``el=True``; binding lengths 11-30, core length 9."""
        self._check_netmhciipan_lengths(allele_length_pairs, coreseq_len)
        single_prediction_result = single_prediction_netmhciipan41(sequence_list, allele_length_pairs, el=True)
        return self._wrap_netmhciipan_results('netmhciipan', single_prediction_result, coreseq_len)

    # merge this with do_netmhciipan_prediction later, retrieve version info from self.version
    def do_netmhciipan_41_ba_prediction(self, sequence_list, allele_length_pairs, coreseq_len=9):
        """Run netMHCIIpan 4.1 with ``el=False``; binding lengths 11-30, core length 9."""
        self._check_netmhciipan_lengths(allele_length_pairs, coreseq_len)
        single_prediction_result = single_prediction_netmhciipan41(sequence_list, allele_length_pairs, el=False)
        return self._wrap_netmhciipan_results('netmhciipan', single_prediction_result, coreseq_len)

    def do_nnalign_prediction(self, sequences, allele_length_pairs, coreseq_len=9):
        """Run nn_align (2.2 path) per (sequence, allele, length), keeping the first result.

        Combinations with an empty result are omitted.
        NOTE(review): ``single_prediction_nn`` is not imported at the top of this
        module, so this raises NameError unless it is defined elsewhere.
        """
        method_name = 'nn_align'
        prediction_results = {}
        for seq in sequences:
            for allele, binding_length in allele_length_pairs:
                pps = MHCIIPredictionParams(
                    method_name=method_name,
                    sequence=seq,
                    allele=allele,
                    coreseq_length=coreseq_len,
                    binding_length=binding_length,
                )
                single_prediction_results = single_prediction_nn(seq, allele, binding_length)
                if single_prediction_results:
                    prediction_results[pps] = single_prediction_results[0]
        return prediction_results

    def do_nnalign23_prediction(self, sequences, allele_length_pairs, coreseq_len=9):
        """Run nn_align 2.3 per (sequence, allele, length), keeping the first result.

        Combinations with an empty result are omitted.
        """
        method_name = 'nn_align'
        prediction_results = {}
        for seq in sequences:
            for allele, binding_length in allele_length_pairs:
                pps = MHCIIPredictionParams(
                    method_name=method_name,
                    sequence=seq,
                    allele=allele,
                    coreseq_length=coreseq_len,
                    binding_length=binding_length,
                )
                single_prediction_results = single_prediction_nn23(seq, allele, binding_length)
                if single_prediction_results:
                    prediction_results[pps] = single_prediction_results[0]
        return prediction_results

    def do_smmalign_prediction(self, sequence_list, allele_length_2tuple_list, coreseq_len=9):
        """Run smm_align per (sequence, allele, length), keeping the first result.

        ``single_prediction_smm`` is the module-level import from
        ``netmhcii_11_executable``. Unlike the nn_align variants, an empty
        result raises IndexError.
        """
        method_name = 'smm_align'
        prediction_results = {}
        for seq in sequence_list:
            for allele, binding_length in allele_length_2tuple_list:
                pps = MHCIIPredictionParams(
                    method_name=method_name,
                    sequence=seq,
                    allele=allele,
                    coreseq_length=coreseq_len,
                    binding_length=binding_length
                )
                single_prediction_results = single_prediction_smm(seq, allele, binding_length)
                prediction_results[pps] = single_prediction_results[0]
        return prediction_results

    def tool_location(self, allele, peplength):
        """Path of the data for ``self.method`` / allele / peptide length."""
        return ("/www/data/djangotools/MHCII/%s/%s-%s" % (self.method, allele, peplength))

    def comblib_tool_location(self, allele, peplength):
        """Same path layout as tool_location (kept for comblib callers)."""
        return ("/www/data/djangotools/MHCII/%s/%s-%s" % (self.method, allele, peplength))

    def consensus_tool_location(self, allele, peplength, method):
        """Path of the data for an explicit ``method`` / allele / peptide length."""
        return ("/www/data/djangotools/MHCII/%s/%s-%s" % (method, allele, peplength))

    def consensus_perc_location(self, allele, method):
        """Path of the percentile data file for ``method`` and ``allele``."""
        return ("/www/data/djangotools/MHCII/%s/perc_%s" % (method, allele))

    def getMedian(self, values):
        """Return the median of ``values``.

        The values are sorted here, so callers need not pre-sort them. For an
        odd count the middle element is returned unchanged. For an even count
        the mean of the two middle elements is returned as a float.

        :raises ValueError: if ``values`` is empty.
        """
        ordered = sorted(values)
        count = len(ordered)
        if count == 0:
            raise ValueError('getMedian() requires at least one value')
        middle = count // 2
        if count % 2 == 1:
            return ordered[middle]
        return float(ordered[middle - 1] + ordered[middle]) / 2

    def predict(self, sequence_list):
        """Run ``self.do_prediction`` for every (allele, length) pair.

        :returns: the output of get_old_results for the raw predictions.
        """
        self.method_set_selected = [self.method]
        allele_length_2tuple_list = list(zip(list(self.allele_list), list(self.length_list)))
        logger.debug('allele_length_2tuple_list=%s', allele_length_2tuple_list)
        r = self.do_prediction(sequence_list, allele_length_2tuple_list)
        logger.debug('r:%s', r)
        results = self.get_old_results(self.method, r, sequence_list)
        logger.debug('results:%s', results)
        return results


class ConsensusPredictor(MhciiPredictorBase):
    """Consensus of the comblib, smm_align, nn_align and tepitope predictions.

    For every allele, the methods flagged in the 'consensus' allele table are
    run, each score is converted to a percentile rank, and the median of those
    ranks is reported as the consensus percentile.
    """

    # TODO(JY): add version info judgement later.
    def predict(self, sequence_list, length_list='default', version=None):
        """Run the consensus prediction for every (allele, length) pair.

        Uses do_comblib_prediction, do_smmalign_prediction,
        do_nnalign23_prediction / do_nnalign_prediction and do_tepitope_prediction.

        :param sequence_list: input sequences.
        :param length_list: unused; lengths come from ``self.length_list``.
        :param version: consensus version, defaulting to '2.22'. '2.22' selects
            nn_align 2.3; any other value selects the nn_align 2.2 path.
        :returns: list of (length, allele, scores). ``scores`` holds one tuple
            of rows per sequence. Each row starts with the consensus (median)
            percentile, followed by core/score/percentile for each method slot
            (comblib, smm, nn, sturniolo), with "-" placeholders for methods
            that were not run.
        :raises ValueError: when methods return different numbers of
            predictions for a sequence longer than 15.
        """
        consensus_version = version if version else '2.22'
        results = []
        self.method_set_selected = [self.method]

        # Allele name -> remaining tab-separated columns of its row in the
        # 'consensus' table. Column 1 enables comblib, column 2 enables
        # smm_align + nn_align, and column 3 enables tepitope (only consulted
        # when comblib is off).
        con_status = {}
        for con_element in get_method_allele_list('consensus').strip().split("\n"):
            con_arr = con_element.split("\t")
            con_status[con_arr[0]] = con_arr[1:]

        # retrieve allele and length from allele_list instead of from self.tool_selection
        for allele, length in zip(self.allele_list, self.length_list):
            scores = []
            for sequence in sequence_list:
                seq_scores = []
                temp_scores = []  # (core, score) lists that must agree in length
                method_list = []  # one label per method slot, "-" when not run

                if con_status[allele][1] == "1":
                    comb_results = self.do_comblib_prediction([sequence], [(allele, length)])
                    # A single sequence was predicted: take its (core, score) list.
                    comb_scores = list(comb_results.values())[0]
                    temp_scores.append(comb_scores)
                    method_list.append("comb.lib.")
                else:
                    method_list.append("-")

                if con_status[allele][2] == "1":
                    smm_results = self.do_smmalign_prediction([sequence], [(allele, length)])
                    # Results should be for a single prediction
                    if len(smm_results) > 1:
                        raise Exception('Expected a single smm result')
                    smmscores = tuple(list(smm_results.values())[0])
                    temp_scores.append(smmscores)
                    method_list.append("smm")

                    if consensus_version == '2.22':
                        nn_results = self.do_nnalign23_prediction([sequence], [(allele, length)])
                    else:
                        nn_results = self.do_nnalign_prediction([sequence], [(allele, length)])
                    nn_scores = list(nn_results.values())[0] if nn_results else None
                    if nn_scores:
                        nnscores = tuple(nn_scores)
                        method_list.append('nn')
                    else:
                        method_list.append('-')
                else:
                    method_list.extend(("-", "-"))

                if con_status[allele][1] == "0" and con_status[allele][3] == "1":
                    tepitope_result = self.do_tepitope_prediction([sequence], [(allele, length)])
                    # A single sequence was predicted: take its (core, score) list.
                    tepitope_scores = list(tepitope_result.values())[0]
                    temp_scores.append(tepitope_scores)
                    method_list.append("sturniolo")
                else:
                    method_list.append("-")

                logger.debug("List of available consensus methods: %s",
                             [m.strip('-') for m in method_list if m.strip('-')])

                # All methods in temp_scores must return the same number of
                # predictions. Compare the set of lengths: summing differences
                # from the last entry could cancel out (e.g. lengths 3, 5, 4).
                if len(set(len(s) for s in temp_scores)) > 1:
                    if len(sequence) > 15:
                        raise ValueError("Methods return different number of predictions!!! %d." % len(temp_scores[1]))

                # NOTE(review): raises IndexError if no score-producing method is
                # enabled for this allele (temp_scores empty).
                for i in range(len(temp_scores[0])):
                    temp_result = []
                    consensus_percs = []
                    for m_name in method_list:
                        method_version = None
                        if m_name == "comb.lib.":
                            temp_core, temp_score = comb_scores[i]
                            method_name = 'comblib'
                        elif m_name == "smm":
                            temp_core, temp_score = smmscores[i]
                            method_name = 'smm_align'
                        elif m_name == "nn":
                            temp_core, temp_score = nnscores[i]
                            method_name = 'nn_align'
                            method_version = '2.3' if consensus_version == '2.22' else '2.2'
                        elif m_name == "sturniolo":
                            temp_core, temp_score = tepitope_scores[i]
                            method_name = 'tepitope'
                        else:
                            # Method not run: placeholder core/score/percentile.
                            temp_result.extend(("-", "-", "-"))
                            continue

                        temp_result.append(temp_core)
                        temp_result.append(round(temp_score, 2))
                        # The percentile is computed from the unrounded score.
                        perc = self.get_percentiles_for_scores(method_name, [temp_score], (allele, length),
                                                               method_version)[0]
                        temp_result.append(perc)
                        consensus_percs.append(perc)

                    consensus_percs.sort()
                    temp_result.insert(0, self.getMedian(consensus_percs))
                    seq_scores.append(tuple(temp_result))
                scores.append(tuple(seq_scores))
            results.append((length, allele, scores))
        return results


class RecommendedPredictor(MhciiPredictorBase):
    """
    IEDB-recommended MHC-II predictor.

    For every (allele, length) pair the methods to run are chosen from that
    allele's row in the 'recommended' allele list. Each method's score is
    converted to a percentile rank, and the median of those percentiles is
    reported as the consensus value of a row.
    """

    # TODO(JY): add version info judgement later.
    def predict(self, sequence_list, length_list='default', version=None):
        """
        Run the recommended combination of MHC-II methods.

        Candidate methods: comblib, smm_align, nn_align, NetMHCIIpan and
        tepitope (Sturniolo). The flags of an allele's row in the recommended
        allele list (order: comblib, smm_align/nn_align, tepitope, netmhcpan)
        decide which of them run:
          * comblib              when the comblib flag is "1";
          * smm_align + nn_align when the smm_align/nn_align flag is "1";
          * NetMHCIIpan          only when the first three flags are "0" and
                                 the netmhcpan flag is "1";
          * tepitope             when the comblib flag is "0" and the tepitope
                                 flag is "1".

        :param sequence_list: sequences to predict on; each is passed to the
            per-method ``do_*_prediction`` helpers as ``[sequence]``.
        :param length_list: not used here; alleles and lengths are taken
            pairwise from ``self.allele_list`` and ``self.length_list``.
        :param version: recommended-method version, default '2.22'. '2.22'
            runs nn_align 2.3 / NetMHCIIpan 3.2, any other value runs
            nn_align 2.2 / NetMHCIIpan 3.1.
        :return: list of ``(length, allele, scores)``. ``scores`` holds one
            tuple per sequence, each containing one row tuple per prediction:
            the consensus (median) percentile, then ``core, score, percentile``
            for each of the five method slots (comblib, smm, nn, NetMHCIIpan,
            tepitope) with ``'-', '-', '-'`` for slots that were not run, and
            finally the '-'-joined names of the methods used.
        :raises ValueError: if the methods return different numbers of
            predictions for a sequence longer than 15 characters.
        :raises Exception: if smm_align returns more than one result set.
        """
        recommended_version = version if version else '2.22'
        results = []

        # File format of 'recommended_allele_list.txt' (tab-separated):
        #     mhc comblib smm_align/nn_align tepitope netmhcpan
        #     H2-IAb    0    1    0    0
        con_status = {}
        for con_element in get_method_allele_list('recommended').strip().split('\n'):
            con_arr = con_element.split('\t')
            con_status[con_arr[0]] = con_arr[1:]

        # Alleles and lengths are paired positionally.
        for allele, length in zip(self.allele_list, self.length_list):
            # TODO - Important: allele names in recommended/recommended_allele_list.txt
            # and/or file names of MHCII/<method name>.txt need to be changed to one
            # common naming convention.
            scores = []
            for sequence in sequence_list:
                status = con_status[allele]
                seq_scores = []
                # Score lists used for the prediction-count consistency check.
                temp_scores = []
                # One entry per method slot; '-' marks a slot that is not run.
                method_list = []

                # --- comblib ---
                if status[0] == "1":
                    comb_results = self.do_comblib_prediction([sequence], [(allele, length)])
                    # A single sequence/allele pair yields one entry whose value
                    # is the sequence of (core, score) pairs.
                    comb_scores = list(comb_results.values())[0]
                    temp_scores.append(comb_scores)
                    method_list.append("comb.lib.")
                    self.method_set_selected.append('comblib')  # For dynamic citation generation.
                else:
                    method_list.append("-")

                # --- smm_align and nn_align (shared flag) ---
                if status[1] == "1":
                    smm_results = self.do_smmalign_prediction([sequence], [(allele, length)])
                    # Results should be for a single prediction
                    if len(smm_results) > 1:
                        raise Exception('Expected a single smm result')
                    prediction_params = list(smm_results.keys())[0]
                    smmscores = tuple(smm_results[prediction_params])
                    temp_scores.append(smmscores)
                    method_list.append("smm")
                    self.method_set_selected.append('smm_align')  # For dynamic citation generation.

                    if recommended_version == '2.22':
                        prediction_results = self.do_nnalign23_prediction([sequence], [(allele, length)])
                    else:
                        prediction_results = self.do_nnalign_prediction([sequence], [(allele, length)])

                    nn_values = list(prediction_results.values())
                    nn_scores = nn_values[0] if nn_values else None
                    if not nn_scores:
                        method_list.append('-')
                    else:
                        # NOTE(review): nn scores are not added to temp_scores,
                        # so they are excluded from the count check below.
                        nnscores = tuple(nn_scores)
                        method_list.append('nn')
                        self.method_set_selected.append('nn_align')  # For dynamic citation generation.
                else:
                    method_list.extend(("-", "-"))

                # --- NetMHCIIpan: only when comblib, smm/nn and tepitope are all unavailable ---
                if status[0] == "0" and status[1] == "0" and status[2] == "0" and status[3] == "1":
                    if recommended_version == '2.22':
                        prediction_results = self.do_netmhciipan32_prediction([sequence], [(allele, length)])
                    else:
                        prediction_results = self.do_netmhciipan_prediction([sequence], [(allele, length)])
                    # Expected to be a sequence of (core, score) pairs.
                    net2scores = list(prediction_results.values())[0]
                    temp_scores.append(net2scores)
                    method_list.append('NetMHCIIpan')
                    self.method_set_selected.append('netmhciipan')  # For dynamic citation generation.
                else:
                    method_list.append('-')

                # --- tepitope (Sturniolo): only when comblib is unavailable ---
                if status[0] == "0" and status[2] == "1":
                    tepitope_result = self.do_tepitope_prediction([sequence], [(allele, length)])
                    tepitope_scores = list(tepitope_result.values())[0]
                    temp_scores.append(tepitope_scores)
                    method_list.append('sturniolo')
                    self.method_set_selected.append('tepitope')  # For dynamic citation generation.
                else:
                    method_list.append('-')

                # Materialize the names: a bare filter object would log as "<filter object ...>".
                logger.debug("List of available IEDB recommended methods: %s",
                             [m.strip('-') for m in method_list if m.strip('-')])

                # Compare distinct counts rather than summing length differences.
                # Summed differences can cancel out and hide a mismatch.
                prediction_counts = set(len(method_scores) for method_scores in temp_scores)
                if len(prediction_counts) > 1 and len(sequence) > 15:
                    raise ValueError("Methods return different number of predictions!!! %f." % len(temp_scores[1]))

                # NOTE(review): temp_scores[0] raises IndexError if no method was run.
                for row_index in range(len(temp_scores[0])):
                    temp_result = []
                    consensus_percs = []
                    method_used = []
                    for m_name in method_list:
                        method_version = None
                        if m_name == "comb.lib.":
                            temp_core, temp_score = comb_scores[row_index]
                            method_name = 'comblib'
                        elif m_name == "smm":
                            temp_core, temp_score = smmscores[row_index]
                            method_name = 'smm_align'
                        elif m_name == "nn":
                            temp_core, temp_score = nnscores[row_index]
                            method_name = 'nn_align'
                            method_version = '2.3' if recommended_version == '2.22' else '2.2'
                        elif m_name == "NetMHCIIpan":
                            temp_core, temp_score = net2scores[row_index]
                            method_name = 'netmhciipan'
                            method_version = '3.2' if recommended_version == '2.22' else '3.1'
                        elif m_name == "sturniolo":
                            temp_core, temp_score = tepitope_scores[row_index]
                            method_name = 'tepitope'
                        else:
                            # Slot not run for this allele: core, score, percentile placeholders.
                            temp_result.extend(("-", "-", "-"))
                            continue

                        temp_result.append(temp_core)
                        temp_result.append(round(temp_score, 2))
                        perc = self.get_percentiles_for_scores(method_name, [temp_score], (allele, length),
                                                               method_version)[0]
                        temp_result.append(perc)
                        consensus_percs.append(perc)
                        method_used.append(m_name)

                    consensus_percs.sort()
                    temp_result.insert(0, self.getMedian(consensus_percs))
                    temp_result.append("-".join(method_used))
                    seq_scores.append(tuple(temp_result))

                scores.append(tuple(seq_scores))

            results.append((length, allele, scores))
        return results


class SmmPredictor(MhciiPredictorBase):
    """Predictor whose ``do_prediction`` is bound to ``do_smmalign_prediction`` (SMM-align)."""

    def __init__(self, method, allele_list, length_list='default', version=None):
        base_args = (method, allele_list, length_list, version)
        super(SmmPredictor, self).__init__(*base_args)
        # Route the generic prediction entry point to the SMM-align backend.
        self.do_prediction = self.do_smmalign_prediction


class Nn23Predictor(MhciiPredictorBase):
    """Predictor whose ``do_prediction`` is bound to ``do_nnalign23_prediction`` (NN-align 2.3)."""

    def __init__(self, method, allele_list, length_list='default', version=None):
        base_args = (method, allele_list, length_list, version)
        super(Nn23Predictor, self).__init__(*base_args)
        # Route the generic prediction entry point to the NN-align 2.3 backend.
        self.do_prediction = self.do_nnalign23_prediction


class NnPredictor(MhciiPredictorBase):
    """Predictor whose ``do_prediction`` is bound to ``do_nnalign_prediction`` (NN-align)."""

    def __init__(self, method, allele_list, length_list='default', version=None):
        base_args = (method, allele_list, length_list, version)
        super(NnPredictor, self).__init__(*base_args)
        # Route the generic prediction entry point to the NN-align backend.
        self.do_prediction = self.do_nnalign_prediction


class Net32Predictor(MhciiPredictorBase):
    """Predictor whose ``do_prediction`` is bound to ``do_netmhciipan32_prediction`` (NetMHCIIpan 3.2)."""

    def __init__(self, method, allele_list, length_list='default', version=None):
        base_args = (method, allele_list, length_list, version)
        super(Net32Predictor, self).__init__(*base_args)
        # Route the generic prediction entry point to the NetMHCIIpan 3.2 backend.
        self.do_prediction = self.do_netmhciipan32_prediction

class Net41ELPredictor(MhciiPredictorBase):
    """Predictor whose ``do_prediction`` is bound to ``do_netmhciipan_41_el_prediction`` (NetMHCIIpan 4.1 EL)."""

    def __init__(self, method, allele_list, length_list='default', version=None):
        base_args = (method, allele_list, length_list, version)
        super(Net41ELPredictor, self).__init__(*base_args)
        # Route the generic prediction entry point to the NetMHCIIpan 4.1 EL backend.
        self.do_prediction = self.do_netmhciipan_41_el_prediction

class Net41BAPredictor(MhciiPredictorBase):
    """Predictor whose ``do_prediction`` is bound to ``do_netmhciipan_41_ba_prediction`` (NetMHCIIpan 4.1 BA)."""

    def __init__(self, method, allele_list, length_list='default', version=None):
        base_args = (method, allele_list, length_list, version)
        super(Net41BAPredictor, self).__init__(*base_args)
        # Route the generic prediction entry point to the NetMHCIIpan 4.1 BA backend.
        self.do_prediction = self.do_netmhciipan_41_ba_prediction

class NetPredictor(MhciiPredictorBase):
    """Predictor whose ``do_prediction`` is bound to ``do_netmhciipan_prediction`` (NetMHCIIpan 3.1 key)."""

    def __init__(self, method, allele_list, length_list='default', version=None):
        base_args = (method, allele_list, length_list, version)
        super(NetPredictor, self).__init__(*base_args)
        # Route the generic prediction entry point to the NetMHCIIpan backend.
        self.do_prediction = self.do_netmhciipan_prediction


class ComblibPredictor(MhciiPredictorBase):
    """Predictor whose ``do_prediction`` is bound to ``do_comblib_prediction`` (combinatorial library)."""

    def __init__(self, method, allele_list, length_list='default', version=None):
        base_args = (method, allele_list, length_list, version)
        super(ComblibPredictor, self).__init__(*base_args)
        # Route the generic prediction entry point to the comblib backend.
        self.do_prediction = self.do_comblib_prediction


class TepitopePredictor(MhciiPredictorBase):
    """Predictor whose ``do_prediction`` is bound to ``do_tepitope_prediction`` (TEPITOPE / Sturniolo)."""

    def __init__(self, method, allele_list, length_list='default', version=None):
        base_args = (method, allele_list, length_list, version)
        super(TepitopePredictor, self).__init__(*base_args)
        # Route the generic prediction entry point to the tepitope backend.
        self.do_prediction = self.do_tepitope_prediction


##################################################################
class MhciiPredictor(object):
    """
    Created  on 2016-08-25.  Yan
    @brief: Class to generate MHC binding predictions result.
    The method name/version is resolved through MHCIIMethod; the predictor class
    registered in PREDICTORS under "<method>-<version>" is built and every
    predict() call is forwarded to it.

    >>> p=Proteins(">TestProtein\\nFNCLGMSNRDFLEGVSG")

    >>> pre=MhciiPredictor('recommended', ['DRB1*01:01'])
    >>> pre.predict(p.sequences)
    [(15, 'DRB1*01:01', [((61.0, 'FNCLGMSNR', 1806.34, 64.0, 'CLGMSNRDF', 630.0, 61.0, 'CLGMSNRDF', 34.5, 19.0, '-', '-', '-', '-', '-', '-', 'comb.lib.-smm-nn'), (84.0, 'NRDFLEGVS', 1000000.0, 92.0, 'CLGMSNRDF', 2085.0, 84.0, 'CLGMSNRDF', 124.3, 37.0, '-', '-', '-', '-', '-', '-', 'comb.lib.-smm-nn'), (82.0, 'RDFLEGVSG', 1000000.0, 92.0, 'CLGMSNRDF', 1717.0, 82.0, 'CLGMSNRDF', 250.4, 49.0, '-', '-', '-', '-', '-', '-', 'comb.lib.-smm-nn'))])]

    >>> pre=MhciiPredictor('recommended', ['DRB1*01:01'], version='2.22')
    >>> pre.predict(p.sequences)
    [(15, 'DRB1*01:01', [((61.0, 'FNCLGMSNR', 1806.34, 64.0, 'CLGMSNRDF', 630.0, 61.0, 'CLGMSNRDF', 34.5, 19.0, '-', '-', '-', '-', '-', '-', 'comb.lib.-smm-nn'), (84.0, 'NRDFLEGVS', 1000000.0, 92.0, 'CLGMSNRDF', 2085.0, 84.0, 'CLGMSNRDF', 124.3, 37.0, '-', '-', '-', '-', '-', '-', 'comb.lib.-smm-nn'), (82.0, 'RDFLEGVSG', 1000000.0, 92.0, 'CLGMSNRDF', 1717.0, 82.0, 'CLGMSNRDF', 250.4, 49.0, '-', '-', '-', '-', '-', '-', 'comb.lib.-smm-nn'))])]

    >>> pre=MhciiPredictor('recommended', ['DRB1*01:01'], [16], version='2.22')
    >>> pre.predict(p.sequences)
    [(16, 'DRB1*01:01', [((64.0, 'FNCLGMSNR', 1806.34, 64.0, 'FNCLGMSNR', 1066.0, 72.0, 'CLGMSNRDF', 55.6, 28.0, '-', '-', '-', '-', '-', '-', 'comb.lib.-smm-nn'), (87.0, 'RDFLEGVSG', 1000000.0, 92.0, 'CLGMSNRDF', 2595.0, 87.0, 'CLGMSNRDF', 228.4, 50.0, '-', '-', '-', '-', '-', '-', 'comb.lib.-smm-nn'))])]

    >>> pre=MhciiPredictor('consensus', ['DRB1*01:01'])
    >>> pre.predict(p.sequences)
    [(15, 'DRB1*01:01', [((61.0, 'FNCLGMSNR', 1806.34, 64.0, 'CLGMSNRDF', 630.0, 61.0, 'CLGMSNRDF', 34.5, 19.0, '-', '-', '-'), (84.0, 'NRDFLEGVS', 1000000.0, 92.0, 'CLGMSNRDF', 2085.0, 84.0, 'CLGMSNRDF', 124.3, 37.0, '-', '-', '-'), (82.0, 'RDFLEGVSG', 1000000.0, 92.0, 'CLGMSNRDF', 1717.0, 82.0, 'CLGMSNRDF', 250.4, 49.0, '-', '-', '-'))])]

    >>> pre=MhciiPredictor('consensus', ['DRB1*01:01'], [16])
    >>> pre.predict(p.sequences)
    [(16, 'DRB1*01:01', [((64.0, 'FNCLGMSNR', 1806.34, 64.0, 'FNCLGMSNR', 1066.0, 72.0, 'CLGMSNRDF', 55.6, 28.0, '-', '-', '-'), (87.0, 'RDFLEGVSG', 1000000.0, 92.0, 'CLGMSNRDF', 2595.0, 87.0, 'CLGMSNRDF', 228.4, 50.0, '-', '-', '-'))])]


    >>> pre=MhciiPredictor('consensus', ['DRB1*01:01'], version='2.22')
    >>> pre.predict(p.sequences)
    [(15, 'DRB1*01:01', [((61.0, 'FNCLGMSNR', 1806.34, 64.0, 'CLGMSNRDF', 630.0, 61.0, 'CLGMSNRDF', 34.5, 19.0, '-', '-', '-'), (84.0, 'NRDFLEGVS', 1000000.0, 92.0, 'CLGMSNRDF', 2085.0, 84.0, 'CLGMSNRDF', 124.3, 37.0, '-', '-', '-'), (82.0, 'RDFLEGVSG', 1000000.0, 92.0, 'CLGMSNRDF', 1717.0, 82.0, 'CLGMSNRDF', 250.4, 49.0, '-', '-', '-'))])]

    >>> pre=MhciiPredictor('smm_align', ['DRB1*01:01'])
    >>> pre.predict(p.sequences)
    [(15, 'DRB1*01:01', [(('CLGMSNRDF', 630.0, 61.0), ('CLGMSNRDF', 2085.0, 84.0), ('CLGMSNRDF', 1717.0, 82.0))])]

    >>> pre=MhciiPredictor('nn_align', ['DRB1*01:01'])
    >>> pre.predict(p.sequences)
    [(15, 'DRB1*01:01', [(('CLGMSNRDF', 34.5, 19.0), ('CLGMSNRDF', 124.3, 37.0), ('CLGMSNRDF', 250.4, 49.0))])]

    >>> pre=MhciiPredictor('nn_align', ['DRB1*01:01'], version='2.3')
    >>> pre.predict(p.sequences)
    [(15, 'DRB1*01:01', [(('CLGMSNRDF', 338.7, 100.0), ('CLGMSNRDF', 970.4, 100.0), ('CLGMSNRDF', 2033.6, 100.0))])]

    >>> pre=MhciiPredictor('NetMHCIIpan', ['DRB1*01:01'])
    >>> pre.predict(p.sequences)
    [(15, 'DRB1*01:01', [(('CLGMSNRDF', 117.56, 30.39), ('CLGMSNRDF', 350.64, 50.6), ('CLGMSNRDF', 862.87, 66.74))])]

    >>> pre=MhciiPredictor('comblib', ['DRB1*01:01'])
    >>> pre.predict(p.sequences)
    [(15, 'DRB1*01:01', [(('FNCLGMSNR', 1806.3413388580746, 64.0), ('NRDFLEGVS', 1000000, 92.0), ('RDFLEGVSG', 1000000, 92.0))])]

    >>> pre=MhciiPredictor('nn_align', ['DRB1*01:01'], [17])
    >>> pre.predict(p.sequences)
    [(17, 'DRB1*01:01', [(('CLGMSNRDF', 89.6, 38.0),)])]

    >>> pre=MhciiPredictor('nn_align', ['DRB1*01:01'])
    >>> pre.predict(p.sequences)
    [(15, 'DRB1*01:01', [(('CLGMSNRDF', 34.5, 19.0), ('CLGMSNRDF', 124.3, 37.0), ('CLGMSNRDF', 250.4, 49.0))])]

    >>> pre=MhciiPredictor('tepitope', ['DRB1*01:02'])
    >>> pre.predict(p.sequences)
    [(15, 'DRB1*01:02', [(('FNCLGMSNR', -1.2999999999999998, 57.0), ('LGMSNRDFL', -2.1, 70.0), ('LGMSNRDFL', -2.1, 70.0))])]

    >>> pre=MhciiPredictor('input_error_example', ['DRB1*01:02'])
    Traceback (most recent call last):
        ...
    ValueError: Selected prediction method "input_error_example" does not exist.
    """
    # Maps method names (optionally suffixed with "-<version>") to predictor classes.
    # __init__ looks up the "<method>-<version>" form produced by MHCIIMethod.
    PREDICTORS = {
        'consensus': ConsensusPredictor,
        'recommended': RecommendedPredictor,
        'smm_align': SmmPredictor,
        'nn_align': NnPredictor,

        'NetMHCIIpan': Net41BAPredictor,
        'netmhciipan': Net41BAPredictor,

        'comblib': ComblibPredictor,
        'tepitope': TepitopePredictor,
        'NetMHCIIpan-3.1': NetPredictor,
        'NetMHCIIpan-3.2': Net32Predictor,
        'netmhciipan-3.1': NetPredictor,
        'netmhciipan-3.2': Net32Predictor,

        'NetMHCIIpan-4.1': Net41BAPredictor,
        'netmhciipan-4.1': Net41BAPredictor,

        'NetMHCIIpan_el-4.1': Net41ELPredictor,
        'NetMHCIIpan_ba-4.1': Net41BAPredictor,
        'netmhciipan_el-4.1': Net41ELPredictor,
        'netmhciipan_ba-4.1': Net41BAPredictor,

        'consensus-2.18': ConsensusPredictor,
        'recommended-2.18': RecommendedPredictor,
        # TODO(JY): implement the new version class and update the version info later.
        'consensus-2.22': ConsensusPredictor,
        'recommended-2.22': RecommendedPredictor,
        'smm_align-1.1': SmmPredictor,
        'nn_align-2.2': NnPredictor,
        'nn_align-2.3': Nn23Predictor,
        'comblib-1.0': ComblibPredictor,
        'tepitope-1.0': TepitopePredictor,
    }

    def __init__(self, method, allele_list, length_list='default', version=None):
        """
        Resolve ``method``/``version`` through MHCIIMethod and build the matching predictor.

        :param method: prediction method name as given by the user.
        :param allele_list: alleles to predict for.
        :param length_list: peptide lengths, or 'default'.
        :param version: requested method version; None lets MHCIIMethod choose.
        :raises ValueError: if the resolved "<method>-<version>" combination has
            no entry in ``PREDICTORS`` (unknown method names are rejected by
            MHCIIMethod itself, see the doctest above).
        """
        mm = MHCIIMethod(method, version)
        self.method = mm.name()
        self.method_version = mm.version
        predictor_key = '%s-%s' % (self.method, self.method_version)
        predictor_class = self.PREDICTORS.get(predictor_key)
        if predictor_class is None:
            # A clear error instead of a bare KeyError for unsupported versions.
            raise ValueError('Prediction method "%s" is not available in version "%s".'
                             % (self.method, self.method_version))
        self.predictor = predictor_class(self.method, allele_list, length_list, self.method_version)

    def predict(self, sequence_list):
        """Forward ``sequence_list`` to the selected predictor and return its results."""
        logger.debug("Method used: {}".format(self.method))
        return self.predictor.predict(sequence_list)

    def get_method_set_selected(self):
        """Return the predictor's ``method_set_selected`` list (used for citation generation)."""
        return self.predictor.method_set_selected

    ##################################################################


if __name__ == '__main__':
    # Run the module's doctests (the MhciiPredictor docstring examples).
    import doctest

    print('doc test proceed...')
    doctest.testmod()
    print('done!')
