'''
Created on Nov 10, 2015

@author: jivan
'''
from abc import ABCMeta, abstractmethod
from collections import namedtuple
import logging
import os
from time import sleep
from celery_tasks import capp
from celery import group
from celery.exceptions import TimeoutError, SoftTimeLimitExceeded, TimeLimitExceeded, WorkerLostError

#from requests.exceptions import ConnectionError


from PercentilesCalculators import MHCIPercentilesCalculator
from allele_info.allele_info import is_user_defined_allele

from protection import split_prediction_request_into_number

# for length rescaled predictions 
from length_rescaling.rescale_peptide import calculate_length_rescaled

logger = logging.getLogger(__name__)

# Environment-variable strings (compared case-insensitively, surrounding
#    whitespace ignored) that mean "disabled".
_FALSE_ENV_VALUES = ('', '0', 'false', 'no', 'off')


def _parse_env_flag(value):
    ''' Convert a raw environment-variable value to a python boolean.

        *value* is what os.environ.get() returned; False or None (variable
        unset) map to False.  Any string in _FALSE_ENV_VALUES (e.g. 'False',
        '0', 'no') also maps to False; every other string maps to True.
    '''
    if value is None or value is False:
        return False
    return str(value).strip().lower() not in _FALSE_ENV_VALUES


# Convert contents of environment variable USE_DISTRIBUTED_PROCESSING to
#    python boolean.  If it doesn't exist default to False.
_udp = os.environ.get('USE_DISTRIBUTED_PROCESSING', False)
USE_DISTRIBUTED_PROCESSING = _parse_env_flag(_udp)

# Key identifying a single prediction request.
MHCIPredictionParams = namedtuple('MHCIPredictionParams',
                        ['method_name', 'sequence', 'allele', 'binding_length'])

AlleleLengthPair = namedtuple('AlleleLengthPair', ['allele', 'binding_length'])

@capp.task
def get_scores_and_percentiles(raw_scores, allele, binding_length, method, score_distributions):
    '''
    Pair each raw score with its percentile rank (celery task).

    :param raw_scores: sequence of raw predictor scores.
    :param allele: allele name, or a user-defined allele sequence.
    :param binding_length: peptide length the scores were computed for.
    :param method: prediction method name, forwarded to the percentile calculator.
    :param score_distributions: score distribution data passed to
        MHCIPercentilesCalculator's constructor.
    :returns: list of (score, percentile) tuples in the order of *raw_scores*.
        For user-defined alleles every percentile is None.
    :raises ValueError: re-raised from the percentile calculator when *allele*
        is not a user-defined allele.
    '''
    msg = 'Calculating percentiles for {}, {} scores.'.format(method, len(raw_scores))
    logger.info(msg)
    percentiles_calculator = MHCIPercentilesCalculator(score_distributions)

    try:
        percentiles = percentiles_calculator.get_percentile_scores(
                        raw_scores, method, allele, binding_length)
    except ValueError:
        # The calculator rejected this allele/length pair.  Percentiles
        #    cannot be computed for user-defined alleles, so report None
        #    for each score; anything else is a genuine error.
        if not is_user_defined_allele(allele):
            raise
        logger.debug("allele '%s' is_user_defined_allele", allele)
        percentiles = [None] * len(raw_scores)
    # Lazy %-args: the percentile list can be large, only format if DEBUG is on.
    logger.debug('got percentiles for method %s: %s', method, percentiles)
    logger.info('Done calculating percentiles for {}'.format(method))
    return list(zip(raw_scores, percentiles))

class MHCIPredictor():
    ''' | *author*: Jivan
        | *created*: 2015-11-10
        | *brief*: Base class for MHCI Predictors

        Derived classes set *method_name* (and the other class attributes below)
        and implement get_score_unit() and _do_prediction().  This base class
        provides local vs. celery dispatch (predict_many), percentile lookup
        and conversion of prediction results into output rows.
    '''
    __metaclass__ = ABCMeta
    # Set this in the derived class to the name of the method the predictor provides
    method_name = 'predictor_base'
    # This is the number of decimal places that scores for this predictor are
    #    typically rounded to in user output.
    score_decimal_places = 0
    # Set this to True in the derived class if the predictor accepts
    #    named alleles like 'HLA-B40:13'
    accepts_named_alleles = False
    # Set this to True in the derived class if the predictor accepts
    #    user-defined sequences as allele input.
    accepts_user_defined_alleles = False
    # Set this to the method's score distribution data dictionary in the derived class.
    # This is used as the value passed to MHCIPercentilesCalculator's constructor.
    _score_distributions = None
    # Methods whose _do_prediction() accepts the 'with_distance_info' keyword.
    #    Under distributed processing with with_distance_info=True, each
    #    subtask result is expected to be a (results, distances) pair.
    _distance_info_methods = ('netmhcpan', 'netmhcpan_el', 'netmhcpan_ba')

    @classmethod
    def get_percentiles_for_scores(cls, raw_scores, allele_length_pair, use_distributed_processing=False, interface='web'):
        ''' | *author*: Jivan
            | *created*: 2015-12-07
            | *brief*: Returns (score, percentile) pairs for the raw scores passed,
            |    as produced by the get_scores_and_percentiles task.
            | *allele_length_pair*: (allele, binding_length) 2-tuple.
            | *use_distributed_processing*: if True, run the task on celery queue
            |    *interface* and wait for it; otherwise call it in-process.
            | *note*: Percentiles can't be calculated for user-defined alleles.
            |    For those, each pair's percentile is None.
            | *raises*: ValueError from the percentile calculation for
            |    alleles that are not user-defined.
        '''
        msg = 'Calculating percentiles for {}, {} scores.'.format(cls.method_name, len(raw_scores))
        logger.info(msg)
        allele, binding_length = allele_length_pair
        # Pass by keyword so the values always match the task's
        #    (raw_scores, allele, binding_length, method, score_distributions)
        #    signature regardless of parameter order.
        task_kwargs = {
            'raw_scores': raw_scores,
            'allele': allele,
            'binding_length': binding_length,
            'method': cls.method_name,
            'score_distributions': cls._score_distributions,
        }
        try:
            if use_distributed_processing:
                percentiles = get_scores_and_percentiles.apply_async(kwargs=task_kwargs, queue=interface).get()
            else:
                percentiles = get_scores_and_percentiles(**task_kwargs)
        except ValueError:
            if is_user_defined_allele(allele):
                # Same (score, None) shape the task itself produces for user-defined alleles.
                percentiles = [(score, None) for score in raw_scores]
            else:
                raise

        logger.info('Done calculating percentiles for {}'.format(cls.method_name))
        return percentiles

    @abstractmethod
    def get_score_unit(self):
        ''' | *author*: Jivan
            | *created*: 2015-11-10
            | *brief*: Returns a string representing the unit for scores returned by this predictor.
            | *note*: This must be overridden in subclasses.
        '''
        pass

    def predict_many(self, sequence_list, allele_length_2tuple_list, interface='celery', split_job_num=24, use_distributed_processing=False, el=False , with_distance_info=False): # (JY)"celery" means default queue.
        ''' | *author: Jivan
            | *created: 2015-11-10
            | *brief: Chooses to predict locally or perform predictions via celery workers.

            Calls _do_prediction() to actually perform predictions.

            | *interface*: celery queue the subtasks are sent to.
            | *split_job_num*: forwarded to split_prediction_request_into_number()
            |    to divide the request into subtasks.
            | *el*: not used by this implementation.
            | *with_distance_info*: forwarded only to methods listed in
            |    _distance_info_methods and ignored for all others.  Under distributed
            |    processing, when it is forwarded and True, the return value is a
            |    (results, distances) tuple.
            | *raises*: ValueError if retrieving the celery group result fails;
            |    Exception if the group finishes with failed subtasks.
        '''
        # JYan (2020-09-30) API interface don't need the distance info, but this function still return it as part of the result.
        logger.info('running prediction with distance info: %s', with_distance_info)

        supports_distance = self.method_name in self._distance_info_methods
        # Only predictors that understand the keyword receive it, in both the
        #    local and the distributed path.
        prediction_kwargs = {'with_distance_info': with_distance_info} if supports_distance else {}

        if not use_distributed_processing:
            return self._do_prediction(sequence_list, allele_length_2tuple_list, **prediction_kwargs)

        subtask_args_list = split_prediction_request_into_number(sequence_list, allele_length_2tuple_list, split_job_num)
        logger.info('Starting group_task with {} tasks'.format(len(subtask_args_list)))
        group_task = group(
                [ self._do_prediction.subtask(subtask_args, prediction_kwargs, queue=interface)
                    for subtask_args in subtask_args_list ]
        )
        task_worker_timeout = 900
        task_request_timeout = 900
        seconds_between_polling = 1.0
        logger.info('Performing group_task.apply_async()')
        async_result = group_task.apply_async(
                           expires=task_worker_timeout, interval=seconds_between_polling
                       )
        try:
            logger.info('Performing aysync_result.get()')
            result = async_result.get(timeout=task_request_timeout)
        except WorkerLostError:
            logger.error('worker lost.')
            raise ValueError('Prediction failed, Please try again later or contact us for help.')
        except SoftTimeLimitExceeded:
            logger.error('Soft Time Limit Exceeded')
            raise ValueError('Prediction failed, Please try again later or contact us for help.')
        except TimeLimitExceeded:
            logger.error('Time Limit Exceeded')
            raise ValueError('Prediction failed, Please try again later or contact us for help.')
        except OSError as e:
            logger.error('OSError: %s' % e)
            raise ValueError('Prediction failed, Please try again later or contact us for help.')
        except Exception as e:
            logger.error('celery error "%s": %s' % (type(e),e))
            raise ValueError('The prediction failed, Please try again later or check the help page for further information.')

        if async_result.failed():
            logger.error('group_task unsuccessful.')
            for n, r in enumerate(async_result.results):
                if r.successful():
                    logger.error('subtask {} was successful.'.format(n))
                else:
                    logger.error('subtask {} failed:\n{}\n'.format(n, r.state))

            # The per-subtask results live on the group result; `result` is
            #    just the plain list returned by get().
            msgs = [ '{}\n'.format(r.traceback) for r in async_result.results if r.failed() ]
            msg = 'Celery error while attempting distributed processing:\n{}'\
                       .format(''.join(msgs))
            logger.error(msg)
            raise Exception('group_task unsuccessful, see log for details.')

        logger.info('group_task successful, result count: {}'.format(len(result)))
        results = []
        if with_distance_info and supports_distance:
            # Each subtask returned (results, distances); merge both parts.
            distances = {}
            for sub_results, sub_distances in result:
                results.extend(sub_results)
                distances.update(sub_distances)
            return (results, distances)
        for sub_results in result:
            results.extend(sub_results)
        return results

    def split_prediction_request(self, sequence_list, allele_length_2tuple_list, size=1):
        """ | *author*: Jivan
            | *created*: 2015-11-10
            | *brief*: Breaks prediction parameters into pieces for distributed processing.

            Each predictor should override this for efficient decomposition of parameters.
            This is the fallback if a predictor doesn't and simply breaks a bulk request
            into a list of (sequence-list, allele-length-tuple-list) arguments.

            | *size*: target number of (sequence, allele) combinations per piece.
            |    With more alleles than *size*, each sequence gets its own pieces
            |    holding at most *size* alleles; otherwise each piece holds all
            |    alleles and size // number-of-alleles sequences.
            | *returns*: list of ([sequences], [allele-length tuples]) tuples;
            |    empty if there are no alleles.
            | *raises*: ValueError if *size* is less than 1 (and alleles were given).
        """
        # JY(2017-01-27): adding a new parameter to control the size of split
        num_of_allele = len(allele_length_2tuple_list)
        if num_of_allele == 0:
            return []
        if size < 1:
            raise ValueError('size must be a positive integer, got {!r}'.format(size))

        split_request = []
        if num_of_allele > size:
            for sequence in sequence_list:
                for i in range(0, num_of_allele, size):
                    split_request.append(([sequence], allele_length_2tuple_list[i:i+size]))
        else:
            # num_of_allele <= size here, so the step is always >= 1.
            sequences_per_piece = size // num_of_allele
            for i in range(0, len(sequence_list), sequences_per_piece):
                split_request.append((sequence_list[i:i+sequences_per_piece], allele_length_2tuple_list))
        return split_request

    @abstractmethod
    def _do_prediction(self, sequence_list, allele_length_2tuple_list):
        """ | *author*: Jivan
            | *created*: 2015-11-10
            | *brief*: Performs binding predictions.

            This must be overridden for each predictor.  The output should be a dictionary
            of the form: {
                <MHCIPredictionParams>: <score-tuple>,
                ...
            }

            NOTE(review): predict_many() combines distributed subtask outputs with
            list.extend() and, for methods in _distance_info_methods, may pass a
            with_distance_info keyword; confirm overrides' return types agree.
        """
        pass

    @classmethod
    def results_to_data_rows(cls, sequence_list, results, sorted_by=None, use_distributed_processing=False, interface='web'):
        """ | *author*: Jivan
            | *created*: 2016-01-15
            | *brief*: Converts prediction results from _do_prediction() format
            |   to an output-friendly format consisting of a list of dictionaries.
            | *note*: This function is very similar to the one of the same name in
            |   ConsensusPredictor, but operates on a different *results* format
            |   and returns slightly different dictionaries.
            | *sequence_list* indicates the order the sequences will be displayed
            |    so the result rows can reference them by number.
            | *results* maps (method_name, sequence, allele, binding_length) keys to
            |    lists of per-position entries, either (score, percentile) or
            |    (score, percentile, core, icore).  Keys with an empty list produce no rows.
            | *sorted_by* is a sequence of row keys to sort by; defaults to
            |    ('score', 'sequence_number', 'start').

            Each dictionary in the list returned represents one result row and has the form: {
                'allele': <allele-name>,
                'sequence_number': <sequence-number>,
                'start': <start-index>,
                'stop': <stop-index>,
                'binding_length': <binding-length>,
                'peptide': <peptide-as-string>,
                'score': <predictor-score>,
                'percentile': <percentile-of-score>,
            }
            4-element entries additionally add 'core' and 'icore'; 2-element
            entries from method 'netmhcpan 2.8' additionally add 'adj_rank'.
        """
        data_rows = []

        for (method_name, sequence, allele, binding_length), scores_and_percentiles in results.items():
            if not scores_and_percentiles:
                # Nothing to report for this key (and nothing to inspect below).
                continue
            # 1-based position of the sequence in display order; computed once per key.
            sequence_number = sequence_list.index(sequence) + 1
            # netmhcpan 4.0 entries carry core/icore; netmhcpan 2.8 / 3.0 entries do not.
            has_core = len(scores_and_percentiles[0]) == 4

            for score_index, entry in enumerate(scores_and_percentiles):
                start = score_index + 1
                row = {
                    'allele': allele,
                    'sequence_number': sequence_number,
                    'start': start,
                    'stop': start + binding_length - 1,
                    'binding_length': binding_length,
                    'peptide': sequence[score_index:(score_index + binding_length)],
                }
                if has_core:
                    score, percentile, core, icore = entry
                    row['core'] = core
                    row['icore'] = icore
                else:
                    score, percentile = entry
                row['score'] = score
                row['percentile'] = percentile
                if not has_core and method_name == 'netmhcpan 2.8':
                    row['adj_rank'] = calculate_length_rescaled(pep=row['peptide'],hla=row['allele'],rank=row['percentile'])
                data_rows.append(row)

        # Set default sort order
        if sorted_by is None:
            sorted_by = ('score', 'sequence_number', 'start')
        data_rows.sort(key=lambda d: [ d[sort_field] for sort_field in sorted_by ])

        return data_rows
