# _*_ coding:utf-8 _*_
import logging
get_method_allele_list = None # temporary 
from mhcii_predictor_data import get_method_allele_list
# Timing coefficients keyed by (mhc class, method).
#
# Each value is a 3-tuple (c0, c1, c2) consumed by the estimators below:
#   * c0 / 100 is added once per estimate in
#     get_estimated_prediction_time_cost_with_actual_method;
#   * (c1 + c2 * windows) / 100 is the cost of one prediction in
#     get_estimated_single_prediction_time_cost, where
#     windows = sequence_length - predict_length + 1.
# NOTE(review): the units are not stated anywhere in this file; the division
# by 100 suggests the raw numbers were measured per 100 runs -- confirm.
standard_time_dict = {
        # MHC class I, keyed by method name.
        ('mhci', 'ann'): (233,6.7,0.0336),
        ('mhci', 'smm'): (239.9,0.01,0.04),
        ('mhci', 'recommended'): (463,216,0.0108),
        ('mhci', 'consensus'): (306,5.6,0.053),
        ('mhci', 'netmhcpan'): (257,80.5,0.0296),
        ('mhci', 'smmpmbec'): (239.1,0.01,0.02),
        ('mhci', 'comblib_sidney2008'): (236.5,0.000002,0),
        ('mhci', 'pickpocket'): (237,5.1,0),
        ('mhci', 'netmhccons'): (250,90.3,0.0021),
        ('mhci', 'netmhcstabpan'): (251,46,0.0017),
        # MHC class II, keyed by method name. Both spellings of NetMHCIIpan
        # are present and carry the same coefficients.
        ('mhcii', 'comblib'): (15.25,0.001525,0.035),
        ('mhcii', 'consensus'): (19.8,10.1,0.077),
        ('mhcii', 'NetMHCIIpan'): (18.3,175,0.99),
        ('mhcii', 'netmhciipan'): (18.3,175,0.99),
        ('mhcii', 'nn_align'): (15.1,5.1,0.081),
        ('mhcii', 'recommended'): (21.8,10.5,0.051),
        ('mhcii', 'smm_align'): (15,5.2,0.066),
        ('mhcii', 'tepitope'): (15.5,0.00162,0.027),
        #('mhcii', 'recommended2'): (18,180,1.7),
        # MHC class I, keyed by numeric-string method id. The values repeat
        # the named entries above ('4' = ann, '6' = smm, '1' = recommended,
        # '2' = consensus, '3' = netmhcpan, '5' = smmpmbec,
        # '7' = comblib_sidney2008, '9' = pickpocket, '10' = netmhccons,
        # '11' = netmhcstabpan).
        ('mhci', '4'): (233,6.7,0.0336),
        ('mhci', '6'): (239.9,0.01,0.04),
        ('mhci', '1'): (463,216,0.0108),
        ('mhci', '2'): (306,5.6,0.053),
        ('mhci', '3'): (257,80.5,0.0296),
        ('mhci', '5'): (239.1,0.01,0.02),
        ('mhci', '7'): (236.5,0.000002,0),
        ('mhci', '9'): (237,5.1,0),
        ('mhci', '10'): (250,90.3,0.0021),
        ('mhci', '11'): (251,46,0.0017),
        # MHC class II, keyed by numeric-string method id ('6' = comblib,
        # '2' = consensus, '3' = netmhciipan, '4' = nn_align,
        # '1' = recommended, '5' = smm_align, '7' = tepitope).
        ('mhcii', '6'): (15.25,0.001525,0.035),
        ('mhcii', '2'): (19.8,10.1,0.077),
        ('mhcii', '3'): (18.3,175,0.99),
        ('mhcii', '4'): (15.1,5.1,0.081),
        ('mhcii', '1'): (21.8,10.5,0.051),
        ('mhcii', '5'): (15,5.2,0.066),
        ('mhcii', '7'): (15.5,0.00162,0.027),
}

# Per-length timing coefficients for the MHC class I 'ann' method (also
# selected by method id '4'). Keys are the predicted peptide length as a
# string; values have the same (c0, c1, c2) layout as standard_time_dict.
# get_estimated_single_prediction_time_cost reads only c1 and c2 from here;
# the once-per-estimate c0 term is taken from standard_time_dict instead.
ann_time_dict = {
        '8': (233,5.5,0.1),
        '9': (233,6.7,0.0336),
        '10': (233,6.7,0.1),
        '11': (233,6.1,0.255),
        '12': (233,4.3,0.49),
        '13': (233,3.3,0.77),
        '14': (233,0.96,1.13),
}

# Per-length timing coefficients for the MHC class I 'netmhcpan' method (also
# selected by method id '3'). Keys are the predicted peptide length as a
# string; values have the same (c0, c1, c2) layout as standard_time_dict.
# get_estimated_single_prediction_time_cost reads only c1 and c2 from here;
# the once-per-estimate c0 term is taken from standard_time_dict instead.
netmhcpan_time_dict = {
        '8': (257,80.5,0.4),
        '9': (257,80.5,0.0296),
        '10': (257,80.5,0.44),
        '11': (257,80.5,1.16),
        '12': (257,72.5,2.26),
        '13': (257,66.5,3.71),
        '14': (257,53,5.5),
}

def split_int_to_list(input_int, split_num):
    """
    Distribute ``input_int`` over at most ``split_num`` positive integers.

    The parts differ by at most one, larger parts come first, and parts that
    would be zero are dropped, so the result may be shorter than
    ``split_num`` (and is empty when ``input_int`` is 0).

    Raises ValueError when ``input_int`` is not a plain ``int`` >= 0 or when
    ``split_num`` is not a plain ``int`` >= 1.

    >>> split_int_to_list(8, 2)
    [4, 4]
    >>> split_int_to_list(8, 3)
    [3, 3, 2]
    >>> split_int_to_list(8, 10)
    [1, 1, 1, 1, 1, 1, 1, 1]
    >>> split_int_to_list(0, 10)
    []
    >>> split_int_to_list(10, 0)
    Traceback (most recent call last):
        ...
    ValueError: Wrong split_num: "0".

    """
    # Exact type checks on purpose: bool and other int subclasses are rejected.
    if type(input_int) is not int or input_int < 0:
        raise ValueError('Wrong input_int: "%s".' % input_int)
    if type(split_num) is not int or split_num < 1:
        raise ValueError('Wrong split_num: "%s".' % split_num)

    base, remainder = divmod(input_int, split_num)
    # The first `remainder` parts carry one extra unit each.
    parts = [base + 1] * remainder + [base] * (split_num - remainder)
    return [part for part in parts if part]
    

def split_list(input_list, split_num):
    """
    Cut ``input_list`` into at most ``split_num`` consecutive chunks.

    When ``split_num`` is at least the length of the input, every element
    becomes its own one-element list. Otherwise chunk sizes come from
    ``split_int_to_list`` (which also validates ``split_num``) and each chunk
    is a slice of ``input_list``.

    >>> split_list(range(8), 2)
    [[0, 1, 2, 3], [4, 5, 6, 7]]
    >>> split_list(range(8), 3)
    [[0, 1, 2], [3, 4, 5], [6, 7]]
    >>> split_list(range(8), 10)
    [[0], [1], [2], [3], [4], [5], [6], [7]]
    >>> split_list([], 10)
    []
    >>> split_list(range(8), 0)
    Traceback (most recent call last):
        ...
    ValueError: Wrong split_num: "0".
    """
    total = len(input_list)
    if split_num >= total:
        return [[item] for item in input_list]

    chunks = []
    start = 0
    for chunk_size in split_int_to_list(total, split_num):
        stop = start + chunk_size
        chunks.append(input_list[start:stop])
        start = stop
    return chunks


def split_prediction_request_into_number(sequence_list, allele_length_2tuple_list, jobs_num='all'):
    """ Split a prediction request into at most ``jobs_num`` jobs.

    Each job is a ``(sequences, allele_length_2tuples)`` pair of lists.

    * ``jobs_num == 'all'``, or ``jobs_num`` at least
      ``len(sequence_list) * len(allele_length_2tuple_list)``: one job per
      (sequence, allele) combination.
    * ``jobs_num`` not larger than the number of sequences: the sequences are
      split into ``jobs_num`` groups and every group gets the full allele list.
    * otherwise: each sequence is its own group and its allele list is split
      so that the total number of jobs is ``jobs_num``.

    Raises ValueError when ``jobs_num`` is neither ``'all'`` nor a plain
    ``int`` >= 1.

    >>> sequence_list = list('0123456789')
    >>> allele_length_2tuple_list = list('0123456789')
    >>> kwargs_list = split_prediction_request_into_number(sequence_list, allele_length_2tuple_list, jobs_num=8)
    >>> len(kwargs_list)
    8
    >>> [sequences for sequences, alleles in kwargs_list]
    [['0', '1'], ['2', '3'], ['4'], ['5'], ['6'], ['7'], ['8'], ['9']]
    >>> all(alleles == allele_length_2tuple_list for sequences, alleles in kwargs_list)
    True
    >>> kwargs_list = split_prediction_request_into_number(sequence_list, allele_length_2tuple_list, jobs_num=15)
    >>> len(kwargs_list)
    15
    >>> kwargs_list[:2]
    [(['0'], ['0', '1', '2', '3', '4']), (['0'], ['5', '6', '7', '8', '9'])]
    >>> kwargs_list[-1]
    (['9'], ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'])
    >>> kwargs_list = split_prediction_request_into_number(sequence_list, allele_length_2tuple_list, jobs_num='all')
    >>> kwargs_list2 = split_prediction_request_into_number(sequence_list, allele_length_2tuple_list, jobs_num=200)
    >>> kwargs_list == kwargs_list2
    True
    >>> len(kwargs_list)
    100
    >>> kwargs_list[:3]
    [(['0'], ['0']), (['0'], ['1']), (['0'], ['2'])]
    """
    # JY(2017-03-09): use a new parameter to control the number of jobs splitted
    if jobs_num != 'all' and (type(jobs_num) != int or jobs_num < 1):
        raise ValueError('Wrong jobs numbers: "%s".' % jobs_num)

    num_of_allele = len(allele_length_2tuple_list)
    num_of_seq = len(sequence_list)

    # Enough jobs for every combination: one job per (sequence, allele) pair.
    # The 'all' test comes first so the string is never compared with an int.
    if jobs_num == 'all' or num_of_allele * num_of_seq <= jobs_num:
        return [([sequence], [allele_length_2tuple])
                for sequence in sequence_list
                for allele_length_2tuple in allele_length_2tuple_list]

    # Fewer jobs than sequences: group sequences, keep the allele list whole.
    if num_of_seq >= jobs_num:
        return [(sequences, allele_length_2tuple_list)
                for sequences in split_list(sequence_list, jobs_num)]

    # More jobs than sequences: spread the job budget over the sequences and
    # split each sequence's allele list into that many pieces. Here
    # jobs_num > num_of_seq, so there is exactly one count per sequence.
    split_request = []
    allele_splits_per_sequence = split_int_to_list(jobs_num, num_of_seq)
    for i, allele_split_num in enumerate(allele_splits_per_sequence):
        for alleles in split_list(allele_length_2tuple_list, allele_split_num):
            split_request.append((sequence_list[i:i+1], alleles))
    return split_request

def get_estimated_prediction_time_cost(sequence_list=None, allele_list=None, mhc='mhci', method='recommended', sequence_length_list=None, lengths=None):
    """
    Estimate the time cost of predicting every sequence against every allele.

    Sequences can be given either as strings (``sequence_list``) or only as
    their lengths (``sequence_length_list``). Method ``'1'`` is treated as
    ``'recommended'`` and any method starting with ``'netmhcpan'`` as
    ``'netmhcpan'``.

    For ``mhc='mhcii'`` with the recommended method, the alleles are looked
    up in the table returned by ``get_method_allele_list('recommended')``.
    Alleles whose row reads "0", "0", "0", "1" in its first four columns are
    costed as 'NetMHCIIpan', the remaining known alleles as 'recommended',
    and unknown alleles are skipped with a warning. The two estimates are
    added.
    NOTE(review): the columns presumably flag which underlying methods
    support the allele -- confirm against mhcii_predictor_data.

    Returns None (after logging a warning) when no sequences or no alleles
    are given. Raises ValueError when a sequence or an allele is not a
    ``str``.

    >>> seq_list = ['MGQIVTMFEALPHIIDEVINIVIIVLIVITGIKAVYNFATCGIFALISFLLLAGRSCGM',]
    >>> allele_list = ['DRB3*01:01',]
    >>> print(get_estimated_prediction_time_cost(seq_list, allele_list,'mhcii'))
    0.52895
    >>> print(get_estimated_prediction_time_cost(sequence_length_list=[59,], allele_list=allele_list,mhc='mhcii'))
    0.52895
    >>> allele_list = ['DRB3*02:02',]
    >>> print(get_estimated_prediction_time_cost(seq_list, allele_list,'mhcii'))
    2.5965
    >>> print(get_estimated_prediction_time_cost(sequence_length_list=[59,], allele_list=allele_list,mhc='mhcii'))
    2.5965
    >>> print(get_estimated_prediction_time_cost(sequence_length_list=[15,]*10000, allele_list=['DRB1*01:01',],mhc='mhcii'))
    1055.501
    """
    # --- input validation -------------------------------------------------
    if not (sequence_length_list or sequence_list):
        logging.warning('No input sequences for the time estimation')
        return None
    if not allele_list:
        logging.warning('No input alleles for the time estimation')
        return None
    if sequence_list and any(type(seq) != str for seq in sequence_list):
        raise ValueError('not all seq is string.')
    if any(type(allele) != str for allele in allele_list):
        raise ValueError('not all allele is string in allele_list: %s. and the type of it is%s.' % (allele_list, type(allele_list)))
    lengths = lengths or []

    # --- method name normalisation -----------------------------------------
    if method == '1':
        method = 'recommended'
    elif method.startswith('netmhcpan'):
        method = 'netmhcpan'

    if not (mhc == 'mhcii' and method == 'recommended'):
        return get_estimated_prediction_time_cost_with_actual_method(sequence_list, allele_list, mhc, method, sequence_length_list, lengths)

    # --- mhcii recommended: route each allele to the method that serves it --
    table_rows = (
        line.split('\t')
        for line in get_method_allele_list('recommended').strip().split('\n')
    )
    con_status = {row[0]: row[1:] for row in table_rows}

    recommended_allele_list = []
    netmhciipan_allele_list = []
    for raw_allele in allele_list:
        allele = raw_allele.replace('HLA-', '')
        if allele not in con_status:
            logging.warning('allele "%s" is not in con_status' % allele)
            continue
        flags = con_status[allele]
        netmhciipan_only = (flags[0] == "0" and flags[1] == "0"
                            and flags[2] == "0" and flags[3] == "1")
        if netmhciipan_only:
            netmhciipan_allele_list.append(allele)
        else:
            recommended_allele_list.append(allele)

    recommended_cost = get_estimated_prediction_time_cost_with_actual_method(sequence_list, recommended_allele_list, mhc, method, sequence_length_list, lengths)
    netmhciipan_cost = get_estimated_prediction_time_cost_with_actual_method(sequence_list, netmhciipan_allele_list, mhc, 'NetMHCIIpan', sequence_length_list, lengths)
    return recommended_cost + netmhciipan_cost

def get_estimated_prediction_time_cost_with_actual_method(sequence_list=None, allele_list='', mhc='mhci', method='recommended', sequence_length_list=None, lengths=None):
    """
    Estimate the time cost for one concrete (mhc, method) combination.

    The estimate is the sum of the single-prediction costs over every
    (allele, sequence) pair plus a once-per-call overhead taken from element 0
    of the ``standard_time_dict`` entry for ``(mhc, method)``.

    Parameters:
        sequence_list: sequences as strings; when non-empty their lengths
            replace ``sequence_length_list``.
        allele_list: must be a ``list``; only its length is used here.
        mhc, method: key into the timing tables.
        sequence_length_list: sequence lengths, used when ``sequence_list``
            is empty or missing.
        lengths: peptide length to predict for the i-th allele; missing
            positions default to 9. The string 'All lengths' expands to the
            lengths 8..14.

    Returns:
        None when there are no sequences, 0 when ``allele_list`` is not a
        list or ``(mhc, method)`` has no entry in ``standard_time_dict``,
        otherwise the non-negative estimate.

    NOTE: the doctest values below correspond to integer division of the
    overhead term (463/100 == 4, as Python 2 evaluates it); with true
    division the same calls give roughly 6.7955.

    >>> seq_list = ['MGQIVTMFEALPHIIDEVINIVIIVLIVITGIKAVYNFATCGIFALISFLLLAGRSCGM',]
    >>> allele_list = ['HLA-A*01:01',]
    >>> print(get_estimated_prediction_time_cost_with_actual_method(seq_list, allele_list))
    6.165508
    >>> print(get_estimated_prediction_time_cost_with_actual_method(sequence_length_list=[59,], allele_list=allele_list))
    6.165508

    """
    if sequence_list:
        sequence_length_list = [len(seq) for seq in sequence_list]
    if not sequence_length_list:
        logging.warning('No sequence_list input')
        return None
    if type(allele_list) != list:
        logging.warning('No allele_list input or wrong format (allele_list: %s)' % allele_list)
        return 0
    if not lengths:
        lengths = []

    # Expand every 'All lengths' marker into the seven concrete lengths and
    # count seven extra predictions for it. The caller's list is not mutated.
    # NOTE(review): each marker adds 7 iterations but only 6 net entries to
    # `lengths` (the marker itself is removed), so with one length per allele
    # the last iteration(s) fall back to the default length 9 -- confirm that
    # this over-count is intended.
    all_lengths_num = sum(1 for length in lengths if length == 'All lengths')
    allele_num = len(allele_list) + (all_lengths_num * 7)
    lengths = [length for length in lengths if length != 'All lengths']
    for _ in range(all_lengths_num):
        lengths.extend([8, 9, 10, 11, 12, 13, 14])

    approximate_time_cost = 0
    for i in range(allele_num):
        predict_length = lengths[i] if i < len(lengths) else 9
        for sequence_length in sequence_length_list:
            single_prediction_time = get_estimated_single_prediction_time_cost(mhc, method, sequence_length, predict_length=predict_length)
            # None (unknown method/length) contributes nothing.
            approximate_time_cost += single_prediction_time or 0

    time_coefficients = standard_time_dict.get((mhc, method), 'wrong_input')
    if time_coefficients == 'wrong_input':
        return 0
    # Once-per-call overhead.
    approximate_time_cost += (time_coefficients[0]/100)

    if approximate_time_cost < 0:
        logging.warning('time estimated < 0:"%s". sequence_list=%s, allele_list=%s, mhc=%s, method=%s, sequence_length_list=%s,lengths=%s' % (approximate_time_cost, sequence_list, allele_list, mhc, method, sequence_length_list,lengths))
        approximate_time_cost = 0
    return approximate_time_cost
  
def get_estimated_single_prediction_time_cost(mhc,method,sequence_length,predict_length=9):
    """
    Estimate the time cost of one sequence against one allele.

    Coefficients come from ``ann_time_dict`` / ``netmhcpan_time_dict`` (keyed
    by ``str(predict_length)``) for the mhci 'ann' / 'netmhcpan' methods and
    their ids '4' / '3', and from ``standard_time_dict`` otherwise. For
    ``mhc == 'mhcii'`` the predicted length is always taken as 15.

    With ``windows = sequence_length - predict_length + 1`` the result is
    ``(windows * c2 + c1) / 100``. Returns 0 when the sequence is shorter
    than the predicted length and None (after a warning) when no
    coefficients are found.

    >>> print(get_estimated_single_prediction_time_cost('mhci', 'netmhcpan', 514))
    0.954776

    """
    table_key = (mhc, method)
    if table_key in (('mhci', 'ann'), ('mhci', '4')):
        coefficients = ann_time_dict.get(str(predict_length), 'wrong_input')
    elif table_key in (('mhci', 'netmhcpan'), ('mhci', '3')):
        coefficients = netmhcpan_time_dict.get(str(predict_length), 'wrong_input')
    else:
        coefficients = standard_time_dict.get(table_key, 'wrong_input')

    if coefficients == 'wrong_input':
        logging.warning('wrong_inpput mhc: "%s" or method "%s" or length "%s".' % (mhc, method, predict_length))
        return None

    # Class II predictions ignore the requested length.
    if mhc == 'mhcii':
        predict_length = 15
    logging.debug('sequence_length: %s - predict_length: %s', repr(sequence_length), repr(predict_length))
    predict_length = int(predict_length)

    # Number of peptides of the predicted length contained in the sequence.
    window_count = sequence_length - predict_length + 1
    if window_count < 1:
        logging.debug('sequence length is too short (%s) for %s and method "%s" and  length "%s"' % (sequence_length,mhc,method,predict_length))
        return 0

    fixed_cost = coefficients[1]
    per_window_cost = coefficients[2]
    return (window_count * per_window_cost + fixed_cost) / 100

def get_proper_batch_size(mhc_class, method, sequence_lengths, alleles):
    """
    Recommend how many jobs a prediction request should be split into.

    The methods 'smm', 'comblib_sidney2008', 'tepitope' and 'comblib' always
    get a single job. For every other method the estimate from
    ``get_estimated_prediction_time_cost`` decides: below 5 -> 1 job,
    below 10 -> 2, below 20 -> 4, otherwise 8. When no estimate is available
    (the estimator returns None for empty sequences or alleles) a single job
    is recommended.

    >>> print(get_proper_batch_size('mhci', 'recommended', [498,], ['HLA-A*02:01']))
    2
    >>> print(get_proper_batch_size('mhci', 'recommended', [498, 554], ['HLA-A*02:01']))
    2
    >>> print(get_proper_batch_size('mhci', 'netmhcpan', [498, 554], ['HLA-A*02:01']))
    1
    >>> print(get_proper_batch_size('mhcii', 'consensus', [498] * 100, ['DRB1*01:01']))
    8
    """
    if method in ['smm', 'comblib_sidney2008', 'tepitope', 'comblib']:
        job_num = 1
    else:
        time_cost = get_estimated_prediction_time_cost(mhc=mhc_class, method=method, sequence_length_list=sequence_lengths, allele_list=alleles)
        if time_cost is None:
            # Nothing to estimate. Comparing None with a number raises
            # TypeError on Python 3, so treat it as "no cost" explicitly.
            time_cost = 0
        if time_cost < 5:
            job_num = 1
        elif time_cost < 10:
            job_num = 2
        elif time_cost < 20:
            job_num = 4
        else:
            job_num = 8
    logging.info('For method "%s", the recommended job_num to splitted into is "%s".' % (method, job_num))
    return job_num

		

def main():
    """Run this module's doctests and print whether they passed."""
    import doctest
    results = doctest.testmod()
    # testmod() returns (failed, attempted); only claim success when nothing
    # failed instead of printing "passed" unconditionally.
    if results.failed:
        print("doctest failed: %d of %d examples" % (results.failed, results.attempted))
    else:
        print("doctest passed")


# Run the doctests when the file is executed as a script.
if __name__ == '__main__':
    main()

