"""
Created on Jan 29, 2016

@author: jivan
"""
import bisect
import logging
from allele_info import MHCIAlleleData
logger = logging.getLogger(__name__)
#from celery.contrib.methods import task
from celery_tasks import capp

def get_index_descend(score, ds_list):
    """ Return the index of the last value >= *score* once *ds_list* is in descending order.

    Equivalently: (number of values ``v`` in *ds_list* with ``score <= v``) - 1,
    clamped to 0 when no value reaches *score*. This is the index used for
    predictors where a higher score is better (see
    MHCIPercentilesCalculator.percentile_scores).

    *ds_list* may be in any order and is NOT modified. Earlier versions sorted
    the caller's list in place, which mutated the shared score distributions.

    >>> get_index_descend(3, [1, 2, 3, 4, 5])
    2
    >>> get_index_descend(0.5, [5, 4, 3, 2, 1])
    4
    >>> get_index_descend(6, [5, 4, 3, 2, 1])
    0
    """
    # A single linear count is cheaper than sorting, and it keeps the exact
    # comparison (`score <= v`) the previous recursive binary search used.
    # For example, a NaN score matches nothing and maps to index 0.
    n_at_least = sum(1 for value in ds_list if score <= value)
    return max(n_at_least - 1, 0)

class MHCIPercentilesCalculator:
    """ Converts MHC class I predictor scores into percentile ranks.

    Each score is located within a pre-computed score distribution, looked up
    by ``(method_name, allele_name, binding_length)`` in the mapping given to
    the constructor. The position found is then mapped onto one of the
    class-level percentile ladders below.

    Open question from the original author: should consensus return only its
    scores, or those of other predictors as well?
    """

    # TODO: this list and the one for consensus method should point to a same single list (or query from db?)
    percentile = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9,
                  2, 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9,
                  4, 4.1, 4.2, 4.3, 4.4, 4.5, 4.6, 4.7, 4.8, 4.9, 5, 5.1, 5.2, 5.3, 5.4, 5.5, 5.6, 5.7, 5.8, 5.9,
                  6, 6.1, 6.2, 6.3, 6.4, 6.5, 6.6, 6.7, 6.8, 6.9, 7, 7.1, 7.2, 7.3, 7.4, 7.5, 7.6, 7.7, 7.8, 7.9,
                  8, 8.1, 8.2, 8.3, 8.4, 8.5, 8.6, 8.7, 8.8, 8.9, 9, 9.1, 9.2, 9.3, 9.4, 9.5, 9.6, 9.7, 9.8, 9.9, 10,
                  11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
                  31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
                  51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70,
                  71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90,
                  91, 92, 93, 94, 95, 96, 97, 98, 99, 100]

    # Data is keyed with a tuple: (method_name, allele_name without '*' or with ':', binding_length)
    #    Except for netmhcpan, pickpocket & netmhccons which use the standard allele name.
    #    NetMHCCons also appears to be using pickpocket distributions.
    # TODO: Update the remaining percentile packages to use standard allele names.

    # Ladder for percentile data that was recalculated and always keeps 2 digits
    # (used when a distribution has exactly 280 entries).
    percentile_280 = [0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1, 0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19, 0.2,
                      0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.3, 0.31, 0.32, 0.33, 0.34, 0.35, 0.36, 0.37, 0.38, 0.39, 0.4,
                      0.41, 0.42, 0.43, 0.44, 0.45, 0.46, 0.47, 0.48, 0.49, 0.5, 0.51, 0.52, 0.53, 0.54, 0.55, 0.56, 0.57, 0.58, 0.59, 0.6,
                      0.61, 0.62, 0.63, 0.64, 0.65, 0.66, 0.67, 0.68, 0.69, 0.7, 0.71, 0.72, 0.73, 0.74, 0.75, 0.76, 0.77, 0.78, 0.79, 0.8,
                      0.81, 0.82, 0.83, 0.84, 0.85, 0.86, 0.87, 0.88, 0.89, 0.9, 0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98, 0.99, 1.0,
                      1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0, 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3.0, 3.1, 3.2, 3.3, 3.4,
                      3.5, 3.6, 3.7, 3.8, 3.9, 4.0, 4.1, 4.2, 4.3, 4.4, 4.5, 4.6, 4.7, 4.8, 4.9, 5.0, 5.1, 5.2, 5.3, 5.4, 5.5, 5.6, 5.7, 5.8,
                      5.9, 6.0, 6.1, 6.2, 6.3, 6.4, 6.5, 6.6, 6.7, 6.8, 6.9, 7.0, 7.1, 7.2, 7.3, 7.4, 7.5, 7.6, 7.7, 7.8, 7.9, 8.0, 8.1, 8.2,
                      8.3, 8.4, 8.5, 8.6, 8.7, 8.8, 8.9, 9.0, 9.1, 9.2, 9.3, 9.4, 9.5, 9.6, 9.7, 9.8, 9.9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
                      19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
                      49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78,
                      79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100]

    # Methods whose distributions are keyed by the standard allele name (see strip_allele_name).
    _STANDARD_ALLELE_NAME_METHODS = ('netmhcpan', 'pickpocket', 'netmhccons', 'netmhcpan_el', 'netmhcpan_ba')
    # Methods ranked with get_index_descend: a higher score yields a lower percentile.
    _DESCENDING_METHODS = ('netmhcstabpan', 'netmhcpan_el')

    def __init__(self, score_distributions):
        # Mapping of (method_name, allele_name, binding_length) -> list of cutoff scores.
        self._score_distributions = score_distributions

    def get_percentile_scores(self, scores, method_name, allele_name, binding_length):
        """ Return the percentile rank of every score in *scores*.

        :param scores: sequence of raw predictor scores.
        :param method_name: predictor name; part of the distribution key.
        :param allele_name: allele name as supplied by the caller. It is
            normalised here to the form the distribution keys use.
        :param binding_length: peptide length; part of the distribution key.
        :returns: a tuple with one percentile per score. If no distribution
            exists for the key, nor for its ``<method_name>_ba`` fallback, the
            return is a list of ``None`` with one entry per score.
        """
        if method_name in self._STANDARD_ALLELE_NAME_METHODS:
            allele_name = self.strip_allele_name(allele_name)
        else:
            allele_name = allele_name.replace('*', '')

        data_key = (method_name, allele_name, binding_length)

        # Fall back to the '<method>_ba' distribution, to keep the result as before.
        if data_key not in self._score_distributions:
            data_key = (method_name + '_ba', allele_name, binding_length)
            if data_key not in self._score_distributions:
                logger.info("%s not found in the score distributions for %s.", data_key, method_name)
                # NOTE(review): this path returns a list while the normal path
                # returns a tuple; kept as-is for existing callers.
                return [None, ] * len(scores)

        ds_for_key = self._score_distributions[data_key]
        # Recalculated distributions (2-digit percentiles) have exactly 280 cutoffs.
        if len(ds_for_key) == 280:
            percentile = self.percentile_280
        else:
            percentile = self.percentile

        return tuple(
            self.percentile_scores(score, ds_for_key, percentile, method_name)
            for score in scores
        )

    def percentile_scores(self, score, ds_for_key, percentile, method_name):
        """ Map a single *score* onto the *percentile* ladder using *ds_for_key*.

        For netmhcstabpan / netmhcpan_el a higher score is better. The index is
        (number of distribution values >= score) - 1, clamped to 0 (see
        get_index_descend).

        For every other method a smaller score is more significant, and
        *ds_for_key* must be sorted ascending. The index is the first position
        holding *score*. If *score* is absent, it is the first position holding
        a larger value. Scores beyond the last value map to the first
        occurrence of the last value.

        :raises IndexError: if *ds_for_key* is empty (ascending branch).
        """
        if method_name in self._DESCENDING_METHODS:
            return percentile[get_index_descend(score, ds_for_key)]

        # Same indices as the former search() + `in` + list.index() lookups,
        # using binary searches instead of two linear scans per score.
        insert_at = bisect.bisect_right(ds_for_key, score)
        if insert_at and ds_for_key[insert_at - 1] == score:
            # Exact hit: use the first occurrence of the score.
            index = bisect.bisect_left(ds_for_key, score)
        elif insert_at < len(ds_for_key):
            # First value strictly greater than the score.
            index = insert_at
        else:
            # Score lies past every value (or is unorderable, e.g. NaN): clamp
            # to the first occurrence of the last value.
            index = bisect.bisect_left(ds_for_key, ds_for_key[-1])
        return percentile[index]

    def strip_allele_name(self, allele_name):
        """ | *brief*: Temporary hack to get the allele name right for netmhcpan, pickpocket, netmhccons executable.
            | *author*: Dorjee
            | *created*: 2016-09-13

            Macaque and pig alleles get '*' replaced by ':'; all others get '*' removed.

            TODO: A more permanent solution would be to create a column in the database for canonical allele name.
        """
        miad = MHCIAlleleData()
        species = miad.get_species_for_allele_name(allele_name=allele_name)
        if species in ['macaque', 'pig']:
            stripped_allele_name = allele_name.replace('*', ':')
        else:
            stripped_allele_name = allele_name.replace('*', '')
        return stripped_allele_name

    def search(self, a, x):
        """ Return the leftmost value of the ascending list *a* that is greater than *x*.

        If no value is greater, return the last value of *a*.

        :raises IndexError: if *a* is empty.
        """
        i = bisect.bisect_right(a, x)
        if i != len(a):
            return a[i]
        else:
            return a[i - 1]


class MHCIIPercentilesCalculator:
    """ Converts MHC class II predictor scores into percentile ranks.

    Score distributions are fetched per allele from the *percentile_manager*
    given to the constructor, through its ``get_distributions(allele_name)``.

    Open question from the original author: should consensus return only its
    scores, or those of other predictors as well?
    """

    def __init__(self, percentile_manager):
        self._percentile_manager = percentile_manager

    def get_percentile_scores(self, scores, method_name, allele_name, binding_length):
        """ | *brief*: Returns the percentile equivalents for *scores*.
            | *author*: Jivan
            | *created*: 2016-01-27

        Each distribution holds a cutoff value for every 0.01 of a percentile,
        so a score's ``bisect_left`` insertion point counts hundredths of a
        percentile. Returned values are never below 0.01.

        :param binding_length: currently unused; the lookup always uses length 15.
        :returns: a tuple with one percentile per score.
        :raises: whatever the distributions mapping raises for a missing key
            (KeyError if it is a dict).
        """
        by_key = self._percentile_manager.get_distributions(allele_name)

        # netmhciipan distributions are keyed by the stripped allele name.
        lookup_allele = allele_name
        if method_name == 'netmhciipan':
            lookup_allele = self.strip_allele_name(allele_name)

        # ***TODO: For some reason a binding length of 9 is passed when the default
        #    length of 15 is being used, so the lookup is pinned to 15.
        cutoffs = by_key[(method_name, lookup_allele, 15)]

        if method_name == 'tepitope':
            # tepitope ranks in reverse (a higher IC50 score means a better binder),
            # so count the cutoffs at or above the score instead.
            total = len(cutoffs)
            ranks = [max((total - bisect.bisect_left(cutoffs, s)) * 0.01, 0.01) for s in scores]
        else:
            ranks = [max(bisect.bisect_left(cutoffs, s) * 0.01, 0.01) for s in scores]
        return tuple(ranks)

    @staticmethod
    def strip_allele_name(allele_name):
        """ Convert *allele_name* to the form netmhciipan uses.

        '*' becomes '_' and ':' is dropped. Names starting with "H2"/"H-2" get
        "H2" rewritten as "H-2". Names starting with "DRB" are returned as
        cleaned. Every other name has '_' removed and gets an "HLA-" prefix.
        """
        cleaned = allele_name.replace('*', '_').replace(':', '')
        if cleaned.startswith(("H2", "H-2")):
            return cleaned.replace("H2", "H-2")
        if allele_name.startswith("DRB"):
            return cleaned
        return "HLA-" + cleaned.replace('_', '')
