#!/usr/bin/env python

'''
Created on 03.01.2017
'''

from __future__ import print_function

from collections import namedtuple
import os
import pickle


def read_pickle_file(file_name):
    """Unpickle and return the single object stored in *file_name*.

    The file is opened in binary mode and always closed, even if
    unpickling fails. Only use this on trusted files: unpickling can
    execute arbitrary code.
    """
    pickle_handle = open(file_name, 'rb')
    try:
        return pickle.load(pickle_handle)
    finally:
        pickle_handle.close()

class MHCIAlleleData(object):
    """Lookup helpers for MHC class I method / allele / species metadata.

    All data comes from the pickled dictionary ``pickles/mhci_info_dict.p``
    located next to this module; it is loaded lazily on first access to
    :attr:`data`. The ``connection`` parameters accepted by the public
    methods are unused and kept only for interface compatibility.
    """

    # Lower-cased two-character allele-name prefixes mapped to species names.
    _SPECIES_BY_INDICATOR = {
        'hl': 'human',
        'h-': 'mouse',
        'pa': 'chimpanzee',
        'ma': 'macaque',
        'go': 'gorilla',
        'sl': 'pig',
        'bo': 'cow',
        'rt': 'rat',
    }

    # Alleles of the MHCI reference set; each one is paired with every
    # length in _REFERENCE_LENGTHS (lengths are strings, as historically returned).
    _REFERENCE_ALLELES = (
        'HLA-A*01:01', 'HLA-A*02:01', 'HLA-A*02:03', 'HLA-A*02:06', 'HLA-A*03:01',
        'HLA-A*11:01', 'HLA-A*23:01', 'HLA-A*24:02', 'HLA-A*26:01', 'HLA-A*30:01',
        'HLA-A*30:02', 'HLA-A*31:01', 'HLA-A*32:01', 'HLA-A*33:01', 'HLA-A*68:01',
        'HLA-A*68:02', 'HLA-B*07:02', 'HLA-B*08:01', 'HLA-B*15:01', 'HLA-B*35:01',
        'HLA-B*40:01', 'HLA-B*44:02', 'HLA-B*44:03', 'HLA-B*51:01', 'HLA-B*53:01',
        'HLA-B*57:01', 'HLA-B*58:01',
    )
    _REFERENCE_LENGTHS = ('9', '10')

    # Named tuple type returned by get_reference_set(); built once, not per call.
    _AlleleLengthTuple = namedtuple('AlleleLengthTuple', ['allele_name', 'binding_length'])

    def __init__(self):
        self.pickle_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'pickles/mhci_info_dict.p')
        self._data = None
        # Methods omitted by get_method_names() unless include_hidden=True.
        self.hidden_method = ['arb']

    @property
    def data(self):
        """Return the unpickled info dict, reading the pickle on first access only."""
        # Compare with None so an empty (falsy) dict is not re-read on every access.
        if self._data is None:
            self._data = read_pickle_file(self.pickle_file)
        return self._data

    def _method_supports_length(self, method, binding_length):
        """Return True if *method* supports *binding_length* for at least one allele."""
        allele_lengths = self.data['method_allele_length_dict'][method]
        return any(binding_length in lengths for lengths in allele_lengths.values())

    def get_method_names(self, allele_name=None, binding_length=None, include_hidden=False, connection=None):
        """Return MHCI method names, optionally filtered.

        :param allele_name: keep only methods that support this allele.
        :param binding_length: keep only methods supporting this peptide length
            (for *allele_name* when given, otherwise for any allele).
        :param include_hidden: also return methods listed in ``self.hidden_method``.
        :param connection: unused.
        :returns: a new list of method names, in pickle order.
        """
        length_dict = self.data['method_allele_length_dict']
        # Copy so callers can never mutate the cached pickle data.
        method_list = list(self.data['method_list'])

        if not include_hidden:
            method_list = [method for method in method_list if method not in self.hidden_method]
        if allele_name:
            # A real list (not a lazy Python 3 filter object) is returned.
            method_list = [method for method in method_list if allele_name in length_dict[method]]
            if binding_length:
                method_list = [method for method in method_list
                               if binding_length in length_dict[method][allele_name]]
        if binding_length:
            method_list = [method for method in method_list
                           if self._method_supports_length(method, binding_length)]
        return method_list

    def get_all_allele_names(self, connection=None):
        """Return a sorted list of every allele name known to any method."""
        # Union the per-method allele keys; works on Python 2 and 3 (dict views
        # cannot be concatenated with '+') and on an empty mapping.
        all_alleles = set()
        for allele_lengths in self.data['method_allele_length_dict'].values():
            all_alleles.update(allele_lengths)
        return sorted(all_alleles)

    def get_allele_names_for_method(self, method_name, binding_length=None, connection=None):
        """Return the allele names valid for *method_name*.

        This lets get_allowed_peptide_lengths() validate method/allele pairs.

        :param method_name: method name, matched case-insensitively (hidden methods allowed).
        :param binding_length: if given, must be an ``int``; keep only alleles
            supporting this peptide length.
        :raises ValueError: for a missing/unknown method or a non-int binding_length.
        """
        method = method_name.lower() if method_name else method_name
        if not method or method not in self.get_method_names(include_hidden=True):
            raise ValueError('Invalid method_name: {}'.format(method_name))
        if binding_length and type(binding_length) != int:
            raise ValueError('Invalid binding_length: {}'.format(binding_length))

        allele_list = self.data['method_alleles_dict'][method]
        if binding_length:
            allele_lengths_dict = self.data['method_allele_length_dict'][method]
            allele_list = [allele for allele in allele_list if binding_length in allele_lengths_dict[allele]]
        return allele_list

    def get_allele_names(self, species=None, method_name=None, frequency_cutoff=None,
                         connection=None, include_hidden=False):
        """Return the allele names for *species*, optionally filtered.

        :param species: required; must be in get_species_list().
        :param method_name: keep only alleles supported by this method (case-insensitive).
        :param frequency_cutoff: keep only alleles whose recorded frequency is
            truthy and >= this value; alleles without a frequency are dropped.
        :param include_hidden: allow hidden methods as *method_name*.
        :raises ValueError: for a missing/unknown species or unknown method.
        """
        if method_name:
            method_name = method_name.lower()
        if not species:
            raise ValueError('species is a required keyword argument')
        if species not in self.get_species_list():
            raise ValueError('Invalid species: {}'.format(species))
        if method_name and method_name not in self.get_method_names(include_hidden=include_hidden):
            raise ValueError('Invalid method_name: {}'.format(method_name))

        alleles = self.data['species_alleles_dict'][species]
        if method_name:
            method_alleles = set(self.data['method_alleles_dict'][method_name])
            alleles = [allele for allele in alleles if allele in method_alleles]
        if frequency_cutoff:
            frequencies = self.data['allele_frequency_dict']
            # .get(): alleles with no frequency entry are filtered out instead of raising KeyError.
            alleles = [allele for allele in alleles
                       if frequencies.get(allele) and frequencies[allele] >= frequency_cutoff]
        return alleles

    def get_allele_frequencies(self, allele_name_list, connection=None):
        """Return the frequency of each allele in *allele_name_list*, in order.

        Alleles with no frequency, or unknown alleles, map to ``None``.
        """
        frequencies = self.data['allele_frequency_dict']
        return [frequencies.get(allele) for allele in allele_name_list]

    def get_species_list(self, method_name=None, connection=None, include_hidden=False):
        """Return the MHCI species list, optionally restricted to *method_name*.

        :raises ValueError: if *method_name* is given but unknown.
        """
        if not method_name:
            return self.data['species_list']
        method_name = method_name.lower()
        if method_name not in self.get_method_names(include_hidden=include_hidden):
            # ValueError subclasses Exception, so existing `except Exception` callers still work.
            raise ValueError('Invalid method name: {}'.format(method_name))
        return self.data['method_species_dict'][method_name]

    @staticmethod
    def get_species_for_allele_name(allele_name):
        """Return the species implied by the first two characters of *allele_name*.

        Leading/trailing whitespace is stripped first. Returns ``None`` for an
        unrecognised prefix.

        :raises ValueError: if *allele_name* is empty or whitespace-only.
        """
        allele_name = allele_name.strip()
        if not allele_name:
            raise ValueError('allele_name is a required keyword argument')
        return MHCIAlleleData._SPECIES_BY_INDICATOR.get(allele_name[:2].lower())

    def get_allowed_peptide_lengths(self, method_name, allele_name, connection=None, include_hidden=False):
        """Return the valid peptide lengths for *allele_name* under *method_name*.

        :raises ValueError: for an unknown method or an allele the method does not support.
        """
        method_name = method_name.lower()
        if method_name not in self.get_method_names(include_hidden=include_hidden):
            raise ValueError('invalid method_name: {}'.format(method_name))
        if allele_name not in self.get_allele_names_for_method(method_name):
            raise ValueError('invalid allele_name: {}'.format(allele_name))
        return self.data['method_allele_length_dict'][method_name][allele_name]

    def get_reference_set(self):
        """Return the MHCI reference set as (allele_name, binding_length) named tuples.

        Each reference allele appears twice, once with '9' and once with '10'.
        """
        return [self._AlleleLengthTuple(allele, length)
                for allele in self._REFERENCE_ALLELES
                for length in self._REFERENCE_LENGTHS]


class MHCIIAlleleData(object):
    """Lookup helpers for MHC class II method / locus / allele / chain metadata.

    Data comes from the pickled dictionary ``pickles/mhcii_info_dict.p`` next
    to this module, loaded lazily on first access to :attr:`data`. The
    ``connection`` parameters are unused and kept for interface compatibility.
    """

    def __init__(self):
        self.pickle_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'pickles/mhcii_info_dict.p')
        self._data = None

    @property
    def data(self):
        """Return the unpickled info dict, reading the pickle on first access only."""
        # Compare with None so an empty (falsy) dict is not re-read on every access.
        if self._data is None:
            self._data = read_pickle_file(self.pickle_file)
        return self._data

    def _validated_method(self, method_name):
        """Return *method_name* lower-cased, raising ValueError if it is missing or unknown."""
        if not method_name:
            raise ValueError('method_name is a required keyword argument')
        method_name = method_name.lower()
        if method_name not in self.get_method_names():
            raise ValueError('Invalid method_name: {}'.format(method_name))
        return method_name

    def get_method_names(self, connection=None):
        """Return the list of valid MHC II method names."""
        return self.data['method_list']

    def get_locus_names(self, connection=None):
        """Return the list of valid MHC II locus names (e.g. 'DP', 'DQ', 'DR', 'H2')."""
        return self.data['locus_list']

    def get_locus_names_for_method(self, method_name=None, connection=None):
        """Return the locus names supported by *method_name* (whitespace-stripped, case-insensitive).

        :raises ValueError: if *method_name* is None, empty or whitespace-only.
        """
        # Check before calling .strip() so None raises ValueError, not AttributeError.
        if not method_name or not method_name.strip():
            raise ValueError('method_name is a required keyword argument')
        method_name = method_name.strip().lower()
        return self.data['method_locus_dict'][method_name]

    def get_allele_names(self, method_name=None, locus_name=None, connection=None):
        """Return the allele names for *method_name*, optionally limited to *locus_name*.

        :raises ValueError: for a missing/unknown method or an unknown locus.
        """
        method_name = self._validated_method(method_name)
        if locus_name and locus_name not in self.get_locus_names():
            raise ValueError('Invalid locus_name: {}'.format(locus_name))

        allele_list = self.data['method_allele_dict'][method_name]
        if locus_name:
            # A real list (not a lazy Python 3 filter object) is returned.
            allele_list = [allele for allele in allele_list
                           if self.get_locus_for_allele_name(allele) == locus_name]
        return allele_list

    def get_alpha_chain(self, method_name=None, locus_name=None, connection=None):
        """Return the alpha chains for the (*method_name*, *locus_name*) pair.

        :raises ValueError: for a missing/unknown method or a missing/unknown locus.
        """
        method_name = self._validated_method(method_name)
        if not locus_name:
            raise ValueError('locus_name is a required keyword argument')
        if locus_name not in self.get_locus_names():
            raise ValueError('Invalid locus_name: {}'.format(locus_name))
        return self.data['method_alpha_chain_dict'][(method_name, locus_name)]

    def get_beta_chain(self, method_name=None, allele_name=None, connection=None):
        """Return the beta chains for the (*method_name*, *allele_name*) pair.

        :raises ValueError: for a missing/unknown method or a missing allele name.
        """
        method_name = self._validated_method(method_name)
        if not allele_name:
            raise ValueError('allele_name is a required keyword argument')
        return self.data['method_beta_chain_dict'][(method_name, allele_name)]

    def get_alpha_chains_for_locus(self, locus_name=None, connection=None):
        """Return the alpha chains for *locus_name* (case-insensitive)."""
        if not locus_name:
            raise ValueError('locus_name is a required keyword argument')
        return self.data['locus_alpha_chain_dict'][locus_name.upper()]

    def get_beta_chains_for_locus(self, locus_name=None, connection=None):
        """Return the beta chains for *locus_name* (case-insensitive)."""
        if not locus_name:
            raise ValueError('locus_name is a required keyword argument')
        return self.data['locus_beta_chain_dict'][locus_name.upper()]

    def get_locus_for_allele_name(self, allele_name):
        """Return the locus (first two characters, upper-cased) of *allele_name*.

        Surrounding whitespace and an optional leading 'HLA-' prefix are ignored,
        so both 'HLA-DRB1*01:01' and 'DRB1*01:01' map to 'DR'.

        :raises ValueError: if the name is empty or its locus is not in get_locus_names().
        """
        allele_name = allele_name.strip()
        if not allele_name:
            raise ValueError('allele_name is a required keyword argument')
        # Human names carry an 'HLA-' species prefix (see get_reference_set());
        # without stripping it the locus would always read as 'HL'.
        if allele_name.upper().startswith('HLA-'):
            allele_name = allele_name[4:]
        locus_name = allele_name[:2].upper()
        if locus_name not in self.get_locus_names():
            raise ValueError("unknown allele name")
        return locus_name

    def get_reference_set(self):
        """Return the list of allele names forming the MHCII reference set."""
        reference_set = [
            'HLA-DPA1*01/DPB1*04:01', 'HLA-DPA1*01:03/DPB1*02:01', 'HLA-DPA1*02:01/DPB1*01:01',
            'HLA-DPA1*02:01/DPB1*05:01', 'HLA-DPA1*03:01/DPB1*04:02',
            'HLA-DQA1*01:01/DQB1*05:01', 'HLA-DQA1*01:02/DQB1*06:02', 'HLA-DQA1*03:01/DQB1*03:02',
            'HLA-DQA1*04:01/DQB1*04:02', 'HLA-DQA1*05:01/DQB1*02:01', 'HLA-DQA1*05:01/DQB1*03:01',
            'HLA-DRB1*01:01', 'HLA-DRB1*03:01', 'HLA-DRB1*04:01', 'HLA-DRB1*04:05', 'HLA-DRB1*07:01',
            'HLA-DRB1*08:02', 'HLA-DRB1*09:01', 'HLA-DRB1*11:01', 'HLA-DRB1*12:01', 'HLA-DRB1*13:02',
            'HLA-DRB1*15:01', 'HLA-DRB3*01:01', 'HLA-DRB3*02:02', 'HLA-DRB4*01:01', 'HLA-DRB5*01:01',
        ]
        return reference_set


class MHCNPAlleleData(object):
    """Lookup helpers for MHC-NP method / allele / peptide-length metadata.

    Backed by the pickle ``pickles/mhcnp_info_dict.p`` next to this module,
    read lazily on first use of :attr:`data`. ``connection`` arguments are unused.
    """

    def __init__(self):
        module_dir = os.path.dirname(os.path.abspath(__file__))
        self.pickle_file = os.path.join(module_dir, 'pickles/mhcnp_info_dict.p')
        self._data = None

    @property
    def data(self):
        """The unpickled info dict; (re)loaded whenever the cached value is falsy."""
        cached = self._data
        if not cached:
            cached = self._data = read_pickle_file(self.pickle_file)
        return cached

    def get_method_names(self, connection=None):
        """Return the list of MHC-NP method names."""
        return self.data['method_list']

    def get_allele_names(self, method_name, connection=None):
        """Return the alleles for *method_name* (case-insensitive); ValueError if unknown."""
        key = method_name.lower()
        if key in self.get_method_names():
            return self.data['method_alleles_dict'][key]
        raise ValueError('Invalid method_name: {}'.format(key))

    def get_allowed_peptide_lengths(self, method_name, allele_name, connection=None):
        """Return the peptide lengths allowed for (*method_name*, *allele_name*).

        Raises ValueError when the method or the allele (for that method) is unknown.
        """
        key = method_name.lower() if method_name else method_name
        if key not in self.get_method_names():
            raise ValueError('Invalid method_name: {}'.format(key))
        known_alleles = self.get_allele_names(key)
        if allele_name not in known_alleles:
            raise ValueError('Invalid allele_name: {}'.format(allele_name))
        lengths_by_pair = self.data['method_allele_lengths_dict']
        return lengths_by_pair[key, allele_name]


class NetCTLpanAlleleData(object):
    """Lookup helpers for NetCTLpan method / species / allele metadata.

    Backed by the pickle ``pickles/netctlpan_info_dict.p`` next to this
    module, loaded lazily on first access to :attr:`data`. ``connection``
    arguments are unused and kept for interface compatibility.
    """

    # Lower-cased two-character allele-name prefixes mapped to species names.
    _SPECIES_BY_INDICATOR = {
        'hl': 'human',
        'h-': 'mouse',
        'pa': 'chimpanzee',
        'ma': 'macaque',
        'go': 'gorilla',
        'sl': 'pig',
        'bo': 'cow',
        'rt': 'rat',
    }

    def __init__(self):
        self.pickle_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'pickles/netctlpan_info_dict.p')
        self._data = None

    @property
    def data(self):
        """Return the unpickled info dict, reading the pickle on first access only."""
        # Compare with None so an empty (falsy) dict is not re-read on every access.
        if self._data is None:
            self._data = read_pickle_file(self.pickle_file)
        return self._data

    def get_method_names(self, connection=None):
        """Return the list of NetCTLpan method names."""
        return self.data['method_list']

    def get_species_list(self, connection=None):
        """Return the list of supported species."""
        return self.data['species_list']

    def get_allele_names_for_species(self, species, connection=None):
        """Return the allele names for *species*.

        NOTE(review): no explicit validation; an unknown species raises KeyError.
        """
        return self.data['species_alleles_dict'][species]

    def get_allowed_peptide_lengths(self, allele_name, connection=None):
        """Return the allowed peptide lengths for *allele_name*.

        NOTE(review): no explicit validation; an unknown allele raises KeyError.
        """
        return self.data['allele_lengths_dict'][allele_name]

    @staticmethod
    def get_species_for_allele_name(allele_name):
        """Return the species implied by the first two characters of *allele_name*.

        Leading/trailing whitespace is stripped first. Returns ``None`` for an
        unrecognised prefix, matching MHCIAlleleData.get_species_for_allele_name.

        :raises ValueError: if *allele_name* is empty or whitespace-only.
        """
        allele_name = allele_name.strip()
        if not allele_name:
            raise ValueError('allele_name is a required keyword argument')
        return NetCTLpanAlleleData._SPECIES_BY_INDICATOR.get(allele_name[:2].lower())
    
def is_user_defined_allele(allele):
    ''' | *brief*: Tell whether *allele* is a user-supplied sequence rather than
        |    an allele name. Anything containing '-', '*' or the digit '1' is
        |    treated as a name (False); everything else is a sequence (True).
        | *author*: Jivan
        | *created*: 2015-11-19
    '''
    # Characters whose presence marks the input as an allele name.
    name_markers = ('-', '*', '1')
    return not any(marker in allele for marker in name_markers)

