import os
import sys
import re
import warnings
import math
import pandas as pd
import numpy as np
from scipy.stats import poisson
from Bio import SeqIO
from tempfile import NamedTemporaryFile
from pathlib import Path
PDIR = Path(os.path.abspath(__file__)).parent.absolute()
DATA_DIR = str(PDIR) + "/data"
sys.path.insert(0, str(PDIR))
from rankMap_ic50 import rankMap
#from predictors.NetMHCPan4elPredictor import NetMHCPan4elPredictor
from netmhcpan_4_0_executable import predict_many
from netmhcpan_4_1_executable import predict_many_peptides_file


class Axelf:
    """Axel-F scoring helpers backed by a TCGA median-expression table.

    ``tcga_data`` is read from ``DATA_DIR/tcga_medians_dev.csv``. The methods
    below expect it to have a ``gene`` column and a ``gene.id`` column. Every
    column from the third one onward is treated as a cancer type (see
    ``validate_cancer_type``).
    """

    # Case-insensitive check that a sequence uses only the 20 standard
    # amino-acid letters. The empty string also matches, so callers must
    # reject empty sequences separately.
    _PROTEIN_PATTERN = re.compile('^[acdefghiklmnpqrstvwy]*$', re.I)

    def __init__(self):
        self.tcga_data = pd.read_csv(DATA_DIR + "/tcga_medians_dev.csv")

    # -------------------------------------------------------------------------
    # Validation
    # -------------------------------------------------------------------------
    @staticmethod
    def _check_available(value, available, description):
        """Return True if *value* is in *available*; otherwise raise ValueError naming it."""
        if value not in available:
            raise ValueError("{} does not exists. Please enter a valid {}.".format(value, description))
        return True

    def validate_fasta(self, input_file):
        """Check that a FASTA file contains only well-formed protein records.

        Returns False as soon as a record has an empty ID, an empty sequence,
        or a character outside the 20 standard amino-acid letters
        (case-insensitive). Otherwise returns True, including for a file with
        no records.
        """
        # The context manager closes the handle even when we return early.
        with open(input_file) as handle:
            for record in SeqIO.parse(handle, 'fasta'):
                name, sequence = record.id, str(record.seq)
                if not name or not sequence:
                    return False
                if self._PROTEIN_PATTERN.search(sequence) is None:
                    return False
        return True

    def validate_tpm(self, tpm_value):
        """Return True if *tpm_value* can be converted with ``float()``, else False."""
        try:
            float(tpm_value)
        except (TypeError, ValueError, OverflowError):
            return False
        return True

    def validate_gene_name(self, name):
        """Return True if *name* appears in the ``gene`` column; otherwise raise ValueError."""
        return self._check_available(name, self.tcga_data["gene"].values, "gene name")

    def validate_ensemble_id(self, id):
        """Return True if *id* appears in the ``gene.id`` column; otherwise raise ValueError."""
        return self._check_available(id, self.tcga_data["gene.id"].values, "Ensemble ID")

    def validate_cancer_type(self, cancer):
        """Return True if *cancer* is one of the cancer-type columns; otherwise raise ValueError."""
        # Every column after the first two (gene, gene.id) is a cancer type.
        available_cancertypes = list(self.tcga_data.columns)[2:]
        return self._check_available(cancer, available_cancertypes, "cancer type")

    def validate_peptide_length(self, pep_length):
        """Return True if *pep_length* is an integer (Python or NumPy), excluding bool."""
        # bool is a subclass of int, so exclude it explicitly.
        # NumPy integers are not int instances, so allow them explicitly.
        return isinstance(pep_length, (int, np.integer)) and not isinstance(pep_length, bool)

    # -------------------------------------------------------------------------
    # Lookups
    # -------------------------------------------------------------------------
    def get_cancer_types(self):
        """Return (abbreviation, full name) pairs from ``tcga_abbr.txt``, sorted by abbreviation.

        Each non-blank line must contain exactly two tab-separated fields;
        otherwise unpacking raises ValueError.
        """
        cancers = []
        with open(DATA_DIR + "/tcga_abbr.txt", "r") as f:
            for cancer_info in f:
                cancer_info = cancer_info.strip()
                if not cancer_info:
                    # Tolerate blank lines, e.g. an extra empty line at the end of the file.
                    continue
                cancer_abbr, cancer_fullname = cancer_info.split("\t")
                cancers.append((cancer_abbr, cancer_fullname))
        return sorted(cancers, key=lambda pair: pair[0])

    def get_gene_names(self):
        """Return the distinct values of the ``gene`` column, in first-seen order."""
        return self.tcga_data["gene"].drop_duplicates().tolist()

    def search_gene_name(self, query, max_results=10):
        """Return gene names that contain *query* as a case-insensitive substring.

        Results follow the order in which genes first appear in the table, so
        the output is deterministic. At most *max_results* names are returned;
        pass None to get every match. If several spellings differ only in
        case, the one appearing last in the table is returned.
        """
        genes = self.tcga_data["gene"].dropna().drop_duplicates()
        lowered = genes.str.lower().tolist()
        original_by_lower = dict(zip(lowered, genes.tolist()))
        needle = query.lower()
        # dict.fromkeys de-duplicates the lowercase names while keeping order.
        matches = [
            original_by_lower[name]
            for name in dict.fromkeys(lowered)
            if isinstance(name, str) and needle in name
        ]
        return matches[:max_results]

    def get_tcga(self, gene_name, cancer_type, ensemble_id=None):
        """Return the largest *cancer_type* value among rows matching the gene.

        *ensemble_id* takes precedence over *gene_name* when both are given.
        The result starts at 0, so NaN and non-positive values never replace
        it. A warning is issued when several rows match.

        Raises:
            ValueError - if neither identifier is given or no row matches.
            KeyError   - if *cancer_type* is not a column of the table.
        """
        if ensemble_id:
            column, identifier, description = "gene.id", ensemble_id, "Ensemble ID"
        elif gene_name:
            column, identifier, description = "gene", gene_name, "gene name"
        else:
            raise ValueError("Please enter either the gene name or the Ensemble ID.")

        # Positional row numbers of every match.
        positions = np.flatnonzero(self.tcga_data[column].to_numpy() == identifier)
        if positions.size == 0:
            raise ValueError("{} does not exists. Please enter a valid {}.".format(identifier, description))
        if positions.size > 1:
            warnings.warn("There are multiple entries for {}. Axel-F will use the largest TCGA value.".format(identifier))

        # Return the highest TCGA value when there are duplicates.
        tcga_value = 0
        for curr_tcga in self.tcga_data[cancer_type].to_numpy()[positions]:
            if tcga_value < curr_tcga:
                tcga_value = curr_tcga
        return tcga_value

    # -------------------------------------------------------------------------
    # Scoring
    # -------------------------------------------------------------------------
    def calculate_axelf_score(self, tpm, ic50):
        """Return the Axel-F presentation score for one peptide.

        The estimated ligand count is
        ``ALPHA * max(tpm, MIN_TPM) * exp(-log10(ic50) / KT)``. The score is
        the Poisson probability of observing at least one ligand,
        ``1 - P(X = 0)``.

        Parameters:
            tpm  - expression value, clamped from below at MIN_TPM. Assumed not NaN.
            ic50 - predicted IC50. Must be positive, because math.log10 raises
                   ValueError otherwise.
        """
        ALPHA = 1.233333
        KT = 0.1555556
        MIN_TPM = 0.5666667

        est_ligands = ALPHA * max(tpm, MIN_TPM) * math.exp(-math.log10(ic50) / KT)
        # sf(0, mu) equals 1 - cdf(0, mu), but stays accurate when est_ligands
        # is tiny, where 1 - cdf would cancel to 0.
        return poisson.sf(0, est_ligands)

    def calculate_ic50_score(self, peptides, alleles):
        """Score each (peptide, allele) pair with an EL rank and a rank-mapped IC50.

        Peptides are grouped by allele. ``predict_many_peptides_file()`` from
        netmhcpan_4_1_executable then runs once per allele on a temporary
        peptide file, which is faster than predicting one peptide at a time.
        Example call: ``predict_many_peptides_file('test.pep', ['HLA-A*0101'])``.

        Parameters:
            peptides - list of peptide strings.
            alleles  - list of allele names; alleles[i] is the allele for peptides[i].

        Returns:
            A list parallel to ``peptides``. Each entry is a
            ``(rank_el, mapped_ic50)`` tuple, or None where the predictor
            returned no result for that peptide.

        Raises:
            ValueError - if ``peptides`` and ``alleles`` differ in length.
        """
        if len(peptides) != len(alleles):
            raise ValueError("peptides and alleles must have the same length.")
        if not peptides:
            return []

        scores_list = [None] * len(peptides)

        # Record every input position per allele, so repeated peptides (with
        # the same or different alleles) each receive their own score.
        positions_by_allele = {}
        for position, allele in enumerate(alleles):
            positions_by_allele.setdefault(allele, []).append(position)

        rank_mapper = rankMap()
        for allele, positions in positions_by_allele.items():
            # De-duplicate while keeping first-seen order; the predictor only
            # needs each peptide once per allele.
            unique_peptides = list(dict.fromkeys(peptides[p] for p in positions))
            scores_by_peptide = self._predict_allele_scores(unique_peptides, allele, rank_mapper)
            for position in positions:
                scores_list[position] = scores_by_peptide.get(peptides[position])

        return scores_list

    @staticmethod
    def _predict_allele_scores(peptides, allele, rank_mapper):
        """Run EL prediction for one allele and map each peptide to (rank_el, mapped_ic50)."""
        tmp = NamedTemporaryFile(delete=False, mode='w')
        try:
            # Close the file before the predictor reads it.
            with tmp:
                tmp.write("".join(peptide + "\n" for peptide in peptides))
            rank_el_data = predict_many_peptides_file(tmp.name, [allele], el=True)
        finally:
            try:
                os.remove(tmp.name)
            except OSError:
                # Best-effort cleanup: neither mask a prediction error nor fail
                # a successful run because the temp file could not be deleted.
                warnings.warn("Could not remove temporary file %s." % tmp.name)

        # Assumes rank_el_data maps keys whose element [1] is the peptide to
        # values whose element [0] is the EL rank. Also assumes rank2ic50
        # returns one IC50 per input rank, in the same order.
        rank_el = [value[0] for value in rank_el_data.values()]
        mapped_ic50 = rank_mapper.rank2ic50(rank_el)

        scores_by_peptide = {}
        for key, rank, ic50 in zip(rank_el_data.keys(), rank_el, mapped_ic50):
            # Keep the first result if the predictor reports a peptide twice.
            scores_by_peptide.setdefault(key[1], (rank, ic50))
        return scores_by_peptide