# import os, sys; MHCI_HOME_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)); sys.path.append(MHCI_HOME_DIR)

import os
import sys
import math
import bisect
from util import * #@UnusedWildImport

from netmhc_4_0_executable import predict as predict_netmhc
from netmhcpan_4_0_executable import predict_many as predict_netmhcpan
from pickpocket_1_1_executable import predict_single as predict_pickpocket
from netmhccons_1_1_executable import predict_sequence as predict_netmhccons
from netmhcstabpan_1_0_executable import predict as predict_netmhcstabpan

from mhci_netmhcpan_4_0_percentile_data import score_distributions as netmhcpan_4_score_distributions
from mhci_ann_predictor_percentile_data import score_distributions as netmhc_4_score_distributions
from mhci_netmhcstabpan_predictor_percentile_data import score_distributions as netmhcstabpan_score_distributions
from PercentilesCalculators import MHCIPercentilesCalculator

from iedbtools_utilities.sequence_io import FASTASequenceInput
from allele_info import MHCIAlleleData

class PredictorSet(object):
    """A set of predictors. An appropriate set of predictors should be loaded for each version.

    ``dic_predictor`` (method name -> predictor instance) is not built in
    ``__init__``. ``get_method_name_list`` and ``get_predictor_list`` build it
    lazily on first use and cache it on the instance.
    """
    def __init__(self, setupinfo):
        self.setupinfo = setupinfo
        self.path_method = setupinfo.path_method
        self.path_data = setupinfo.path_data
        self.version = setupinfo.version
        # dic_predictor is built on demand by _get_dic_predictor().

    def get_dic_predictor(self):
        """Read only those predictors which current version support.

        :returns: dict mapping each name from
            ``MethodSet().get_method_list(self.version)`` to a new predictor.
        """
        method_list = MethodSet().get_method_list(self.version)
        # get_predictor() reads the paths from self; it does not take them as
        # arguments (passing them positionally used to raise TypeError).
        return {method: self.get_predictor(method) for method in method_list}

    def get_predictor(self, method, method_used=None):
        """Return a new predictor instance for ``method``.

        Only the requested predictor is constructed.

        :param method: method name such as 'ann', 'smm', 'netmhcpan', 'consensus'.
        :param method_used: forwarded to the predictors that accept it
            (smm, smmpmbec, netmhcpan, pickpocket).
        :raises KeyError: if ``method`` is not a known method name.
        """
        path_method = self.path_method
        path_data = self.path_data
        factories = {
            'ann':                lambda: ANNPredictor(path_method, path_data),                                            # Must be included.
            'smm':                lambda: SMMMatrix(path_method, path_data, method_name='smm', method_used=method_used),      # Must be included.
            'smmpmbec':           lambda: SMMMatrix(path_method, path_data, method_name='smmpmbec', method_used=method_used), # Must be included.
            'arb':                lambda: ARBMatrix(path_method, path_data),                                               # Must be included.
            'netmhcpan':          lambda: NetMHCpanPredictor(path_method, path_data, method_used),                         # Experimental
            'comblib_sidney2008': lambda: CombinatorialLibrary(path_method, path_data, 'comblib_sidney2008'),              # Must be included.
            'comblib_udaka2000':  lambda: CombinatorialLibrary(path_method, path_data, 'comblib_udaka2000'),               # Must be included.
            'consensus':          lambda: ConsensusPredictor(self.setupinfo),                                              # Experimental
            'pickpocket':         lambda: PickPocketPredictor(path_method, path_data, method_used),
            'netmhccons':         lambda: NetMHCconsPredictor(path_method, path_data),
            'netmhcstabpan':      lambda: NetMHCStabPanPredictor(path_method, path_data),
        }
        return factories[method]()

    def _get_dic_predictor(self):
        """Return ``self.dic_predictor``, building and caching it on first use."""
        dic_predictor = getattr(self, 'dic_predictor', None)
        if dic_predictor is None:
            dic_predictor = self.get_dic_predictor()
            self.dic_predictor = dic_predictor
        return dic_predictor

    def get_method_name_list(self):
        """Return the sorted list of method names in this set."""
        return sorted(self._get_dic_predictor())

    def get_predictor_list(self, method_list):
        """Return the predictors for the names in ``method_list``, in order.

        Names that are not in the set are skipped.
        """
        dic_predictor = self._get_dic_predictor()
        return [dic_predictor[method] for method in method_list if method in dic_predictor]


class ARBMatrix(object):
    """Can load and save ARB matrices in multiple formats and use them to score sequences.

    ``mat`` maps each amino acid to a tuple of per-position values. A peptide
    is scored by summing those values, dividing by ``-length``, then applying
    ``(x - intercept) / slope`` and ``10 ** x``. The result is clamped to
    [0.0001, 1e6].
    """
    def __init__(self, path_method, path_data):
        # path dependent loading of files takes place in 'initialize'.
        self.slope = None
        self.intercept = None
        self.length = None
        self.mat = {}

        self.path_method = os.path.join(path_method, 'arb')
        self.path_data = os.path.join(path_data, 'arb')

    def initialize(self, mhc, length):
        """Load the pickled model for (mhc, length).

        The pickle also overwrites ``self.length``.
        """
        self.mhc = mhc
        self.length = length
        self.path_model = get_path_model(self.path_data, mhc, length)
        self.pickle_load(self.path_model)

    def get_score_unit(self):
        """The unit of prediction scores"""
        return 'IC50 (nM)'

    def predict_sequence(self, sequence, pred):
        """Given one protein sequence, break it up into peptides, return their predicted binding scores."""
        peptide_list = get_peptides(sequence, self.length)
        return self.predict_peptide_list(peptide_list)

    def predict_peptide_list(self, peptide_list):
        """Given a list of peptides, return corresponding list of predicted binding scores.

        :returns: tuple with one score per peptide.
        :raises PredictorError: if a residue has no entry in the matrix.
        """
        scores = []
        for peptide in peptide_list:
            score = 0.0
            for pos in range(self.length):
                amino_acid = peptide[pos]
                try:
                    score += self.mat[amino_acid][pos]
                except (KeyError, IndexError):
                    raise PredictorError("""Invalid character '%c' in sequence '%s'.""" % (amino_acid, peptide))
            score /= -self.length
            score -= self.intercept
            score /= self.slope
            score = math.pow(10, score)
            # Cap predictable values to [0.0001, 1e6].
            score = min(max(score, 0.0001), 1e6)
            scores.append(score)
        return tuple(scores)

    @staticmethod
    def _read_labeled_float(text, label):
        """Parse the float that follows ``label`` up to the end of its line.

        :raises PredictorError: if ``label`` does not occur in ``text``.
        """
        start = text.find(label)
        if start == -1:
            raise PredictorError("Missing " + label + " in ARB matrix.")
        end = text.find("\n", start + 1)
        if end == -1:
            # The label is on the last line and the text has no trailing newline.
            end = len(text)
        return float(text[start + len(label):end])

    def loadARBTrainOutput(self, infile):
        """Load an ARB training output file.

        The first 20 lines are '<aa> <v1> ... <vN>' rows; values are stored as
        log10. 'SLOPE <x>' and 'INTERCEPT <x>' lines are read from anywhere
        in the text.
        """
        text = infile.read()
        lines = text.split("\n")
        self.mat.clear()
        self.length = len(lines[0].split()) - 1
        for line in lines[0:20]:
            entries = line.split()
            numbers = [math.log10(float(e)) for e in entries[1:]]
            if len(numbers) != self.length:
                raise PredictorError("Invalid number of columns in ARB matrix: " + str(len(numbers))
                                     + " expected: " + str(self.length) + ".")
            self.mat[line[0]] = tuple(numbers)
        self.slope = self._read_labeled_float(text, "SLOPE")
        self.intercept = self._read_labeled_float(text, "INTERCEPT")

    def pickle_dump(self, file_name):
        """Pickle length, mat, slope and intercept (in that order) to ``file_name``."""
        with open(file_name, "wb") as fout:
            cPickle.dump(self.length, fout)
            cPickle.dump(self.mat, fout)
            cPickle.dump(self.slope, fout)
            cPickle.dump(self.intercept, fout)

    def pickle_load(self, file_name):
        """Load length, mat, slope and intercept written by ``pickle_dump``."""
        with open(file_name, "rb") as fin:
            self.length = cPickle.load(fin)
            self.mat = cPickle.load(fin)
            self.slope = cPickle.load(fin)
            self.intercept = cPickle.load(fin)


class SMMMatrix:
    """Can load and save SMM matrices in multiple formats and use them to score sequences.

    ``mat`` maps each amino acid to a tuple of per-position values. A peptide
    score is ``10 ** (offset + sum of its per-position values)``.
    """
    def __init__(self, path_method, path_data, method_name='smm', method_used=None):
        self.offset = None
        self.length = None
        self.mat = {}
        self.method_used = method_used
        self.path_method = os.path.join(path_method, method_name)
        self.path_data = os.path.join(path_data, method_name)

    def initialize(self, mhc, length):
        """Load the pickled matrix for (mhc, length).

        The pickle also overwrites ``self.length``.
        """
        self.mhc = mhc
        self.length = length
        self.path_model = get_path_model(self.path_data, mhc, length)
        self.pickle_load(self.path_model)

    def get_score_unit(self):
        """The unit of prediction scores"""
        return 'IC50 (nM)'

    def predict_sequence(self, sequence, pred):
        """Given one protein sequence, break it up into peptides, return their predicted binding scores.

        :returns: list of (score, percentile) pairs, one per peptide.
        """
        peptide_list = get_peptides(sequence, self.length)
        scores = self.predict_peptide_list(peptide_list)

        # get percentile scores
        # NOTE(review): the key always uses 'smm', even for an instance created
        # with method_name='smmpmbec'. Confirm this is intended.
        args = ('smm', self.mhc.replace("*", ""), self.length)
        ps = PercentileScore(os.path.dirname(self.path_data), 'consensus', args)
        percentile = ps.get_percentile_score(scores)
        # list() so callers can iterate the pairs more than once.
        return list(zip(scores, percentile))

    def predict_peptide_list(self, peptide_list):
        """Return a tuple with one score per peptide.

        :raises PredictorError: if a residue has no entry in the matrix.
        """
        scores = []
        for peptide in peptide_list:
            score = self.offset
            for pos in range(self.length):
                amino_acid = peptide[pos]
                try:
                    score += self.mat[amino_acid][pos]
                except (KeyError, IndexError):
                    raise PredictorError("""Invalid character '%c' in sequence '%s'.""" % (amino_acid, peptide))
            scores.append(math.pow(10, score))
        return tuple(scores)

    def load_text_file(self, infile):
        """Read a matrix in the text format written by ``save_text_file``.

        Line 0 is 'NumCols:' followed by the length. Lines 1-20 are
        '<aa> <v1> ... <vN>' rows. Line 21 is the offset.
        """
        lines = infile.readlines()
        self.mat.clear()
        self.length = int(lines[0].split()[1])
        for line in lines[1:21]:
            entries = line.split()
            numbers = [float(e) for e in entries[1:]]
            if len(numbers) != self.length:
                raise PredictorError("Invalid number of columns in SMM matrix: " + str(len(numbers))
                                     + " expected: " + str(self.length) + ".")
            self.mat[line[0]] = tuple(numbers)
        self.offset = float(lines[21])

    def save_text_file(self, outfile):
        """Write the matrix as tab-separated text readable by ``load_text_file``."""
        outfile.write("NumCols:\t" + str(self.length) + "\n")
        # Rows are written in sorted amino-acid order. This used to call the
        # nonexistent self.keys().
        for letter in sorted(self.mat):
            outfile.write(letter + "".join("\t" + str(val) for val in self.mat[letter]) + "\n")
        outfile.write(str(self.offset))

    def pickle_dump(self, file_name):
        """Pickle length, mat and offset (in that order) to ``file_name``."""
        with open(file_name, "wb") as fout:
            cPickle.dump(self.length, fout)
            cPickle.dump(self.mat, fout)
            cPickle.dump(self.offset, fout)

    def pickle_load(self, file_name):
        """Load length, mat and offset written by ``pickle_dump``."""
        with open(file_name, "rb") as fin:
            self.length = cPickle.load(fin)
            self.mat = cPickle.load(fin)
            self.offset = cPickle.load(fin)
        

class ANNPredictor:
    """Wrapper around the netMHC 4.0 executable (ann) with percentile ranks."""
    def __init__(self, path_method, path_data):
        """ Predictor for Artificial Neural Network (aka. ann). """
        self.path_data = os.path.join(path_data, 'ann')

    def initialize(self, mhc, length):
        """Remember the allele and binding length used by ``predict_sequence``."""
        self.mhc = mhc
        self.length = length

    def get_score_unit(self):
        """ The unit of prediction scores """
        return 'IC50 (nM)'

    def predict_sequence(self, sequence, pred):
        """ Given one protein sequence, break it up into peptides, return their predicted binding scores.

        :returns: list of (score, percentile) pairs, one per peptide.
        """
        # predict_netmhc gets the allele name with '*' and ':' removed.
        allele_name = self.mhc.replace('*', '').replace(':', '')
        scores = predict_netmhc(allele_name, str(self.length), sequence)

        # Percentiles come from the netMHC 4.0 score distributions. A second,
        # discarded lookup through PercentileScore used to run here; it did
        # needless file I/O and could raise for alleles missing from that file.
        allele_length_pair = (self.mhc.replace("*", ""), self.length)
        percentile = self.get_percentiles_for_scores(scores, allele_length_pair)
        return list(zip(scores, percentile))

    def get_percentiles_for_scores(self, raw_scores, allele_length_pair):
        ''' Returns the percentile scores for the raw scores passed.

        :param allele_length_pair: (allele name without '*', binding length).
        '''
        percentiles_calculator = MHCIPercentilesCalculator(netmhc_4_score_distributions)
        allele, binding_length = allele_length_pair
        percentiles = percentiles_calculator.get_percentile_scores(
                        raw_scores, 'ann', allele, binding_length)
        return percentiles



class NetMHCpanPredictor:
    """Wrapper around the netMHCpan 4.0 executable with percentile ranks."""
    def __init__(self, path_method, path_data, method_used=None):
        self.path_data = os.path.join(path_data)
        self.method = method_used

    def initialize(self, mhc, length, hla_seq=None):
        """Remember the allele (or "User defined" plus its sequence) and length."""
        self.mhc = mhc
        self.length = length
        self.hla_seq = hla_seq

    def predict_sequence(self, sequence, pred):
        """ Given one protein sequence, break it up into peptides, return their predicted binding scores.

        :returns: for a named allele, a list of (score, percentile) pairs; for a
            "User defined" MHC (sequence read from ``self.hla_seq``), the raw
            scores. Returns None if ``predict_netmhcpan`` returns no results.
        """
        user_defined = self.mhc == "User defined"
        if user_defined:
            si = FASTASequenceInput(self.hla_seq)
            allele_name_or_sequence = "".join(si.as_amino_acid_text())
        else:
            allele_name_or_sequence = self.mhc

        input_sequence_list = sequence.split()
        results = predict_netmhcpan(input_sequence_list, [(allele_name_or_sequence, self.length)])

        # A single (allele, length) pair is requested, so only the first
        # result is used.
        for scores in results.values():
            if user_defined:
                return scores
            allele_length_pair = (allele_name_or_sequence, self.length)
            percentile = self.get_percentiles_for_scores(scores, allele_length_pair)
            return list(zip(scores, percentile))

    def get_percentiles_for_scores(self, raw_scores, allele_length_pair):
        ''' Returns the percentile scores for the raw scores passed.

        :param allele_length_pair: (allele name or MHC sequence, binding length).
        :returns: percentiles, or a list of None when the calculator raises
            ValueError for an allele that ``is_user_defined_allele`` accepts.
        '''
        percentiles_calculator = MHCIPercentilesCalculator(netmhcpan_4_score_distributions)
        allele, binding_length = allele_length_pair
        try:
            percentiles = percentiles_calculator.get_percentile_scores(
                            raw_scores, 'netmhcpan', allele, binding_length)
        except ValueError:
            # allele_length_pair is a plain tuple; the previous
            # `allele_length_pair.allele` raised AttributeError here.
            if is_user_defined_allele(allele):
                percentiles = [None] * len(raw_scores)
            else:
                raise
        return percentiles

class PickPocketPredictor:
    """Wrapper around the PickPocket 1.1 executable.

    Named alleles get (score, percentile) pairs. A "User defined" MHC, whose
    sequence comes from ``hla_seq``, gets raw scores only.
    """
    def __init__(self, path_method, path_data, method_used=None):
        self.path_data = os.path.join(path_data)
        self.method = method_used

    def initialize(self, mhc, length, hla_seq=None):
        """Store the allele (or "User defined"), the length and the optional MHC sequence."""
        self.mhc, self.length, self.hla_seq = mhc, length, hla_seq

    def predict_sequence(self, sequence, pred):
        """ Given one protein sequence, break it up into peptides, return their predicted binding scores. """
        is_user_defined = (self.mhc == "User defined")
        if is_user_defined:
            fasta_input = FASTASequenceInput(self.hla_seq)
            mhc_query = "".join(fasta_input.as_amino_acid_text())
        else:
            mhc_query = self.mhc

        raw_scores = predict_pickpocket(sequence, (mhc_query, self.length))
        if is_user_defined:
            return raw_scores

        percentile_args = ("pickpocket", self.mhc, int(self.length))
        scorer = PercentileScore(self.path_data, "pickpocket", percentile_args)
        return zip(raw_scores, scorer.get_percentile_score(raw_scores))


class NetMHCStabPanPredictor:
    """Wrapper around the netMHCstabpan 1.0 executable with percentile ranks."""
    def __init__(self, path_method, path_data):
        """ Predictor netMHCstabpan. """
        self.path_data = os.path.join(path_data, 'netmhcstabpan')

    def initialize(self, mhc, length):
        """Store the allele and binding length used by ``predict_sequence``."""
        self.mhc, self.length = mhc, length

    def predict_sequence(self, sequence, pred):
        """ Given one protein sequence, break it up into peptides, return their predicted binding scores. """
        raw_scores = predict_netmhcstabpan(self.mhc, str(self.length), sequence)
        ranks = self.get_percentiles_for_scores(raw_scores, (self.mhc, self.length))
        return zip(raw_scores, ranks)

    def get_percentiles_for_scores(self, raw_scores, allele_length_pair):
        ''' Returns the percentile scores for the raw scores passed.

        :param allele_length_pair: (allele name, binding length).
        '''
        calculator = MHCIPercentilesCalculator(netmhcstabpan_score_distributions)
        allele, binding_length = allele_length_pair
        return calculator.get_percentile_scores(
            raw_scores, 'netmhcstabpan', allele, binding_length)
    

class NetMHCconsPredictor:
    """Wrapper around the netMHCcons 1.1 executable with percentile ranks."""
    def __init__(self, path_method, path_data):
        """ Predictor netMHCcons. """
        self.path_data = os.path.join(path_data, 'netmhccons')

    def initialize(self, mhc, length, hla_seq=None):
        """Store the allele (or "User defined"), the length and the optional MHC sequence."""
        self.mhc = mhc
        self.length = length
        self.hla_seq = hla_seq

    def predict_sequence(self, sequence, pred):
        """ Given one protein sequence, break it up into peptides, return their predicted binding scores.

        :returns: for a named allele, a list of (score, percentile) pairs; for a
            "User defined" MHC, the raw scores (same convention as the
            pickpocket and netmhcpan predictors).
        """
        user_defined = self.mhc == "User defined"
        if user_defined:
            si = FASTASequenceInput(self.hla_seq)
            allele_name_or_sequence = "".join(si.as_amino_acid_text())
        else:
            allele_name_or_sequence = self.mhc

        scores = predict_netmhccons(sequence, (allele_name_or_sequence, self.length))

        if user_defined:
            # The key ('netmhccons', 'User defined', length) is presumably absent
            # from the distribution file. In that case PercentileScore raises
            # ValueError, so return raw scores instead.
            return scores

        # get percentile scores
        args = ('netmhccons', self.mhc, self.length)
        ps = PercentileScore(os.path.dirname(self.path_data), 'netmhccons', args)
        percentile = ps.get_percentile_score(scores)
        return list(zip(scores, percentile))
        
    
class CombinatorialLibrary:
    """Scores peptides with PSSMs derived from combinatorial peptide libraries.

    ``mat`` maps each amino acid to a list of per-position values. A peptide
    score is ``10 ** (offset + sum of its per-position values)``. The active
    matrix can be pickled and loaded.
    """
    def __init__(self, path_method, path_data, lib_source):
        self.dic_pssm = None
        self.offset = None
        self.length = None
        self.mat = {}

        self.path_method = path_method
        self.path_data = path_data
        self.lib_source = lib_source

    def initialize(self, mhc, length):
        """Load the library and select the matrix for (mhc, length).

        :raises KeyError: if the library has no matrix for the derived key.
        """
        self.dic_pssm = self.read_pssm_comblib(self.lib_source)
        self.mhc = mhc
        self.length = length
        if re.search(r'H-2.*', self.mhc):
            # Keep everything through the first digit as-is; in the rest,
            # replace '-' with '_'.
            i = re.search(r'(?<=\d)', self.mhc).start()
            key = (self.mhc[:i] + self.mhc[i:].replace("-", "_"), self.length)
        else:
            key = (self.mhc.replace('-', '_').replace('*', '-').replace(':', ''), self.length)

        w = self.dic_pssm[key]
        (self.mat, self.offset) = self.get_dic_mat(w)

    def get_score_unit(self):
        """The unit of prediction scores"""
        return 'Score'

    def predict_sequence(self, sequence, pred):
        """Given one protein sequence, break it up into peptides, return their predicted binding scores.

        :returns: list of (score, percentile) pairs, one per peptide.
        """
        peptide_list = get_peptides(sequence, self.length)
        scores = self.predict_peptide_list(peptide_list)

        # get percentile scores
        args = ('comblib_sidney2008', self.mhc.replace("*", ""), self.length)
        ps = PercentileScore(self.path_data, 'consensus', args)
        percentile = ps.get_percentile_score(scores)
        return list(zip(scores, percentile))

    def predict_peptide_list(self, peptide_list):
        """Return a tuple with one score per peptide.

        :raises PredictorError: if a residue has no entry in the matrix.
        """
        scores = []
        for peptide in peptide_list:
            score = self.offset
            for pos in range(self.length):
                amino_acid = peptide[pos]
                try:
                    score += self.mat[amino_acid][pos]
                except (KeyError, IndexError):
                    raise PredictorError("""Invalid character '%c' in sequence '%s'.""" % (amino_acid, peptide))
            scores.append(math.pow(10, score))
        return tuple(scores)

    def read_data_cpickle(self, fname):
        """Return the object pickled in ``fname``."""
        # NOTE(review): unpickling is only safe for trusted, locally shipped files.
        # Binary mode is required for pickles.
        with open(fname, 'rb') as f:
            return cPickle.load(f)

    def read_pssm_comblib(self, lib_source):
        """Reads in all available pssms derived from combinatorial libraries.

        Only 'comblib_sidney2008' is supported; its matrix elements are
        multiplied by -1.0.

        :raises PredictorError: for any other ``lib_source`` (the udaka2000
            loader is disabled; it used to fail with AttributeError on None).
        """
        if lib_source != 'comblib_sidney2008':
            raise PredictorError("Unsupported combinatorial library: '%s'." % lib_source)
        factor = -1.0  # multiplied into every matrix element
        fname_sidney2008 = os.path.join(self.path_data, 'comblib_sidney2008', 'dic_pssm_sidney2008.cPickle')
        dic_pssm = self.read_data_cpickle(fname_sidney2008)
        return dict((key, [factor * val for val in w]) for key, w in dic_pssm.items())

    def get_dic_mat(self, w):
        """Converts 1-dimensional vector into a dictionary of lists key = [aa].

        ``w[0]`` is the offset. For position p and amino acid index a (in
        'ACDEFGHIKLMNPQRSTVWY' order), the value is ``w[1 + 20*p + a]``.

        :returns: (dict aa -> list of ``self.length`` values, offset)
        """
        offset = w[0]
        dic_mat = {}
        for aa_index, aa in enumerate("ACDEFGHIKLMNPQRSTVWY"):
            dic_mat[aa] = [w[1 + 20 * pos_index + aa_index] for pos_index in range(self.length)]
        return (dic_mat, offset)

    def pickle_dump(self, file_name):
        """Pickle length, mat and offset (in that order) to ``file_name``."""
        with open(file_name, "wb") as fout:
            cPickle.dump(self.length, fout)
            cPickle.dump(self.mat, fout)
            cPickle.dump(self.offset, fout)

    def pickle_load(self, file_name):
        """Load length, mat and offset written by ``pickle_dump``."""
        with open(file_name, "rb") as fin:
            self.length = cPickle.load(fin)
            self.mat = cPickle.load(fin)
            self.offset = cPickle.load(fin)


class ConsensusPredictor(object):
    """Consensus of the ann, smm and comblib_sidney2008 predictors.

    Each available method's scores are converted to percentile ranks. The
    per-peptide consensus score is the median of those ranks.
    Open question: should consensus return only its scores, or those of
    other predictors as well?
    """
    # Percentile rank for each position of a score distribution.
    _PERCENTILES = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9,
                    2, 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9,
                    4, 4.1, 4.2, 4.3, 4.4, 4.5, 4.6, 4.7, 4.8, 4.9, 5, 5.1, 5.2, 5.3, 5.4, 5.5, 5.6, 5.7, 5.8, 5.9,
                    6, 6.1, 6.2, 6.3, 6.4, 6.5, 6.6, 6.7, 6.8, 6.9, 7, 7.1, 7.2, 7.3, 7.4, 7.5, 7.6, 7.7, 7.8, 7.9,
                    8, 8.1, 8.2, 8.3, 8.4, 8.5, 8.6, 8.7, 8.8, 8.9, 9, 9.1, 9.2, 9.3, 9.4, 9.5, 9.6, 9.7, 9.8, 9.9, 10,
                    11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
                    31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
                    51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70,
                    71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90,
                    91, 92, 93, 94, 95, 96, 97, 98, 99, 100]

    def __init__(self, setupinfo):
        self.setupinfo = setupinfo
        self.path_method = setupinfo.path_method
        self.path_data = setupinfo.path_data
        self.score_array = [None, None]  # (scores_predictor, scores_consensus) Holds predictor specific prediction scores.

    def initialize(self, allele_name, length):
        """Load distributions and predictors, and initialize those available for (allele_name, length)."""
        self.dic_score_distributions = self.read_score_distributions()
        self.dic_predictor = self.get_dic_predictor()
        self.predictor_selection = PredictorSelectionB(self.setupinfo)
        self.methods_used_for_consensus = ['ann', 'smm', 'comblib_sidney2008']

        self.mhc = allele_name
        self.length = length

        # Which of the consensus methods are available for (allele_name, length)?
        self.available_method_list = self.predictor_selection.get_available_methods(allele_name, length, self.methods_used_for_consensus)
        self.predictor_list = [self.dic_predictor[method] for method in self.available_method_list]  # Get only those predictors for which (allele_name,length) is available.
        # Initialize older predictors requiring it.  TODO: This type of initialization shouldn't be
        #    required.  Alleles & the binding length are prediction parameters, not predictor
        #    parameters.  Move the passing of these parameters from .initialize() to .predict_*().
        for method_name, predictor in zip(self.available_method_list, self.predictor_list):
            try:
                predictor.initialize(allele_name, length)
            except AttributeError:
                # Tolerated for 'ann' only. The method name comes from
                # available_method_list; the previous `predictor.method_name`
                # is defined by no predictor and raised a second AttributeError.
                if method_name != 'ann':
                    raise

    def read_score_distributions(self):
        """Load the consensus score distributions keyed by (method, allele, length)."""
        fname = os.path.join(self.path_data, 'consensus', 'distribution_consensus_bin.cpickle')
        # Binary mode is required for pickles.
        with open(fname, 'rb') as f:
            return cPickle.load(f)

    def get_score_array(self):
        return self.score_array

    def get_score_unit(self):
        """The unit of prediction scores"""
        return 'Percentile'

    def predict_sequence(self, sequence, pred):
        """Predict all peptides of ``sequence`` with every available method and combine them.

        :returns: the tuple of consensus percentiles when
            ``pred == 'submit_processing'``; otherwise
            (consensus percentiles, list of (raw scores, percentiles) per method).
        :raises PredictorError: if no method is available.
        """
        scores_predictor = []  # percentile ranks per method; the lower, the better.
        ic50scores = []        # raw scores per method
        self.mhc = self.mhc.replace("*", "")
        for (predictor, method_name) in zip(self.predictor_list, self.available_method_list):
            if method_name in ('smm', 'ann', 'comblib_sidney2008'):
                # These predictors return (score, percentile) pairs. list() makes the
                # two passes below safe even if an iterator is returned.
                spercentile = list(predictor.predict_sequence(sequence, pred))
                ic50scores.append(tuple(sp[0] for sp in spercentile))
                scores_predictor.append(tuple(sp[1] for sp in spercentile))
            else:
                scores = predictor.predict_sequence(sequence, pred)
                ic50scores.append(scores)
                # Only methods that need it look up their distribution.
                score_distribution = self.dic_score_distributions[(method_name, self.mhc, self.length)]
                scores_percentile = [self.get_percentile_score(score, score_distribution) for score in scores]  # range = [0....100]
                scores_predictor.append(tuple(scores_percentile))

        if not scores_predictor:
            raise PredictorError("No consensus method available for '%s' with length %s." % (self.mhc, self.length))

        scores_consensus = []
        for i in range(len(scores_predictor[0])):
            score_row = [scores[i] for scores in scores_predictor]
            scores_consensus.append(median(score_row))

        self.score_array = (scores_predictor, scores_consensus)
        if pred == 'submit_processing':
            return tuple(scores_consensus)
        ic50_ranks = list(zip(ic50scores, scores_predictor))
        return tuple(scores_consensus), ic50_ranks

    def search(self, a, x):
        'Find leftmost value greater than x (or the last value if none is greater).'
        i = bisect.bisect_right(a, x)
        if i != len(a):
            return a[i]
        else:
            return a[i-1]

    def get_percentile_score(self, score, score_distributions):
        """Map ``score`` to a percentile rank using ``score_distributions``.

        If ``score`` occurs in the distribution, its first position is used.
        Otherwise the first value greater than it is used (or the last value).
        Smaller the score, the more significant.
        NOTE(review): assumes score_distributions is sorted ascending (it is
        bisected) and has at most 190 entries (one per percentile).
        """
        percentile = self._PERCENTILES
        if score in score_distributions:
            return percentile[score_distributions.index(score)]
        right_dist_score = self.search(score_distributions, score)
        return percentile[score_distributions.index(right_dist_score)]

    def get_dic_predictor(self):
        """Return new predictor instances keyed by method name."""
        path_method = self.path_method
        path_data = self.path_data
        dic = {}
        dic['ann']                = ANNPredictor(path_method, path_data) # Must be included.
        dic['smm']                = SMMMatrix(path_method, path_data) # Must be included.
        dic['arb']                = ARBMatrix(path_method, path_data) # Must be included.
        dic['netmhcpan']          = NetMHCpanPredictor(path_method, path_data) # Experimental
        dic['comblib_sidney2008'] = CombinatorialLibrary(path_method, path_data, 'comblib_sidney2008')
        dic['comblib_udaka2000']  = CombinatorialLibrary(path_method, path_data, 'comblib_udaka2000')
        return dic


class PercentileScore:
    """Convert raw prediction scores into percentile ranks.

    Distributions are loaded from
    ``<path_data>/<method>/distribution_<method>_bin.cpickle``. They are
    looked up by the key (method_name, allele_name without '*', binding_length).
    """
    # Percentile rank for each position of a score distribution.
    # TODO: this list and the one for consensus method should point to a same single list (or query from db?)
    _PERCENTILES = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9,
                    2, 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9,
                    4, 4.1, 4.2, 4.3, 4.4, 4.5, 4.6, 4.7, 4.8, 4.9, 5, 5.1, 5.2, 5.3, 5.4, 5.5, 5.6, 5.7, 5.8, 5.9,
                    6, 6.1, 6.2, 6.3, 6.4, 6.5, 6.6, 6.7, 6.8, 6.9, 7, 7.1, 7.2, 7.3, 7.4, 7.5, 7.6, 7.7, 7.8, 7.9,
                    8, 8.1, 8.2, 8.3, 8.4, 8.5, 8.6, 8.7, 8.8, 8.9, 9, 9.1, 9.2, 9.3, 9.4, 9.5, 9.6, 9.7, 9.8, 9.9, 10,
                    11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
                    31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
                    51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70,
                    71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90,
                    91, 92, 93, 94, 95, 96, 97, 98, 99, 100]

    def __init__(self, path_data, method, args):
        """
        :param path_data: directory that contains the ``<method>`` subdirectory.
        :param method: name of the distribution file set, e.g. 'consensus'.
        :param args: (method_name, allele_name, binding_length).
        """
        self.path_data = path_data
        self.method = method

        # Data is stored with a key of (method_name, allele_name without '*', binding_length)
        self.key = (args[0], args[1].replace('*', ''), args[2])
        self.score_distributions = self.read_score_distributions()

    def get_percentile_score(self, scores):
        """Return a tuple with the percentile rank of each score."""
        return tuple([self.scores(score) for score in scores])  # range = [0....100]

    def scores(self, score):
        """Map one score to its percentile rank in the distribution for ``self.key``.

        If ``score`` occurs in the distribution, its first position is used.
        Otherwise the first value greater than it is used (or the last value).
        Smaller the score, the more significant.

        :raises ValueError: if ``self.key`` has no distribution.
        """
        if self.key not in self.score_distributions:
            raise ValueError("{0} not found in the pickled file.".format(self.key))
        ds_for_key = self.score_distributions[self.key]
        percentile = self._PERCENTILES
        if score in ds_for_key:
            return percentile[ds_for_key.index(score)]
        right_dist_score = self.search(ds_for_key, score)
        return percentile[ds_for_key.index(right_dist_score)]

    def search(self, a, x):
        'Find leftmost value greater than x (or the last value if none is greater).'
        i = bisect.bisect_right(a, x)
        if i != len(a):
            return a[i]
        else:
            return a[i - 1]

    def read_score_distributions(self):
        """Load the pickled distributions for ``self.method``."""
        fname = os.path.join(self.path_data, self.method, 'distribution_{0}_bin.cpickle'.format(self.method))
        # Binary mode is required for pickles.
        with open(fname, 'rb') as f:
            return cPickle.load(f)
    
    
class MHCBindingPredictions:
    """A main class to access different MHC binding prediction methods.

    Also provides various types of checks such as whether a selected MHC
    molecule is accessible.
    """

    # Methods that may take part in a consensus prediction. method_lookup
    # reports available methods in exactly this order so the joined method
    # name is deterministic (a set intersection has hash-dependent order).
    CONSENSUS_METHODS = ('ann', 'smm', 'comblib_sidney2008', 'netmhcpan')

    def __init__(self, input):
        """Capture the prediction request and resolve (allele, length) pairs.

        ``input`` must provide ``version``, ``method``, ``species``, ``length``,
        ``input_protein``, ``input_protein_mhc``, ``mhc``, ``hla_seq``,
        ``freq``, ``negatives``, ``duplicates`` and ``tool``. (The parameter
        name shadows the ``input`` builtin; it is kept for compatibility.)
        """
        # TODO: the following set of variables should be cleaned up.
        self.version = input.version
        self.method = input.method                  # User-selected predictive method.
        self.method_set_selected = []               # For 'IEDB recommended': methods used across (mhc, length) pairs.
        self.species = input.species
        self.length = input.length
        self.proteins = input.input_protein         # Source of peptides to make binding predictions.
        self.proteins_mhc = input.input_protein_mhc # User-provided MHC sequence.
        self.setupinfo = SetupInfo(version=self.version)
        self.path_data = self.setupinfo.path_data
        self.predictor_set = PredictorSet(self.setupinfo)
        self.mhc = input.mhc                        # MHC allele.
        self.hla_seq = input.hla_seq                # User input MHC sequence.
        self.freq = input.freq                      # Boolean (true/false).
        self.negatives = input.negatives            # Boolean (true/false).
        self.duplicates = input.duplicates          # List of duplicate allele-length pairs.
        self.tool = input.tool

        # NOTE(review): 'Allele' looks like a form placeholder meaning
        # "no allele chosen"; it is normalised to None here.
        if self.mhc == 'Allele':
            self.mhc = None

        ps = PredictorSelectionB(self.setupinfo)
        self.tool_selection = []
        if self.hla_seq == '':
            self.hla_seq = None
        # (allele, length) pairs are only resolved when no user MHC sequence is given.
        if self.hla_seq is None:
            self.tool_selection.extend(
                ps.get_tool_selection(self.method, self.mhc, self.length, self.tool))

    def get_score_unit(self):
        """Return the unit of the prediction scores produced for ``self.method``."""
        if self.method == 'comblib_sidney2008':
            return 'score'
        if self.method == 'consensus':
            return 'consensus_percentile_rank'
        if self.method == 'IEDB_recommended':
            return 'percentile_rank'
        return 'ic50'

    def get_method_set_selected(self, method):
        """Return the sorted list of methods that ``method`` resolves to.

        For 'IEDB_recommended' or 'consensus', this is the union of
        method_lookup() over all (allele, length) pairs in ``self.tool_selection``,
        plus 'consensus' when two or more methods are involved. Any other
        method is returned on its own as ``[method]``.
        """
        if method not in ('IEDB_recommended', 'consensus'):
            return [method]
        selected = set()
        for allele, length in self.tool_selection:
            selected.update(self.method_lookup(allele, length))
        selected = list(selected)
        if len(selected) >= 2:
            selected.append('consensus')
        selected.sort()
        return selected

    @staticmethod
    def _score_sequences(predictor, sequence_list, pred):
        """Return ``predictor.predict_sequence(seq, pred)`` for each sequence, in order."""
        return [predictor.predict_sequence(sequence, pred) for sequence in sequence_list]

    def predict(self, sequence_list, pred=None):
        """Run binding predictions for every sequence in ``sequence_list``.

        Returns a list of ``(length, allele, scores, method_name)`` tuples, where
        ``scores`` holds one ``predict_sequence`` result per input sequence.
        There is one tuple per (allele, length) pair in ``self.tool_selection``.
        For a user-supplied MHC sequence there is a single tuple with allele
        'User defined' and an int length. ``pred`` is forwarded to
        ``predict_sequence``.
        """
        results = []

        if self.method == 'IEDB_recommended':
            for allele, length in self.tool_selection:
                method_name = ','.join(self.method_lookup(allele, length))
                # If 'recommended' is passed from MHC-I processing, force netmhcpan.
                if pred == 'submit_processing':
                    method_name = 'netmhcpan'
                # netmhcpan alone runs directly; anything else goes through consensus.
                predictor_method = 'netmhcpan' if method_name == 'netmhcpan' else 'consensus'
                predictor = self.predictor_set.get_predictor(predictor_method, 'IEDB_recommended')
                # Run the 'initialize' method of the specific predictor class.
                predictor.initialize(allele, int(length))
                scores = self._score_sequences(predictor, sequence_list, pred)
                results.append((length, allele, scores, method_name))
            return results

        # Any other method: a single predictor is used for every prediction.
        predictor = self.predictor_set.get_predictor(self.method)
        self.method_set_selected = [self.method]

        if self.hla_seq is None or self.hla_seq == "":
            # Loop over the user-selected (allele, length) pairs and score every
            # sequence for each pair.
            for allele, length in self.tool_selection:
                method_name = ','.join(self.method_lookup(allele, length)) or 'netmhcpan'
                if self.method == 'netmhcpan':
                    predictor.initialize(allele, int(length), self.hla_seq)
                else:
                    predictor.initialize(allele, int(length))
                scores = self._score_sequences(predictor, sequence_list, pred)
                results.append((length, allele, scores, method_name))
        else:
            # The user supplied their own MHC sequence.
            mhc = 'User defined'
            # NOTE(review): pop(0) consumes one entry of self.length per call,
            # so repeated calls use successive lengths. This is preserved as-is;
            # confirm that callers rely on it.
            length = self.length.pop(0)
            predictor.initialize(mhc, int(length), self.hla_seq)
            scores = self._score_sequences(predictor, sequence_list, pred)
            results.append((int(length), mhc, scores, "netmhcpan"))

        return results

    @classmethod
    def _select_consensus_methods(cls, method_list):
        """Filter ``method_list`` down to consensus-capable methods.

        The result follows CONSENSUS_METHODS order. netmhcpan is dropped
        whenever at least one other consensus method is also available.
        """
        available = set(method_list)
        method_used = [m for m in cls.CONSENSUS_METHODS if m in available]
        if len(method_used) > 1 and 'netmhcpan' in method_used:
            method_used.remove('netmhcpan')
        return method_used

    def method_lookup(self, allele, length):
        """Return the consensus-capable methods available for ``(allele, length)``.

        Methods are queried from MHCIAlleleData and filtered by
        _select_consensus_methods (deterministic order; netmhcpan only
        when nothing else applies).
        """
        miad = MHCIAlleleData()
        method_list = miad.get_method_names(allele_name=allele, binding_length=length)
        return self._select_consensus_methods(method_list)
