import pandas as pd
import os
import sys
from pathlib import Path
sys.path.insert(0, str(Path().absolute()))
PDIR = str(Path(__file__).absolute().parent.parent)
sys.path.insert(0, PDIR)
# from axelf import calculate_axelf_score, calculate_ic50_score, get_tcga_medians
from axelf_executable import Axelf

def sort_axelf_result(df, tpmSource):
    """Rank a scored table by Axel-F and format it for display.

    The rows of *df* are ordered by the ``axelf`` column, highest first, and
    the index is renumbered from 0. *df* itself is not modified.

    When *tpmSource* is ``"custom"`` or ``"tcga"`` the expression, ``ic50``
    and ``axelf`` columns are converted to fixed-precision strings (2, 3 and
    5 decimals respectively), the columns are put in presentation order and
    renamed to their display headers. Any other *tpmSource* yields the
    sorted table with no formatting applied.

    Args:
        df: DataFrame with an ``axelf`` column plus the columns required by
            the selected source (``tpm`` for custom; ``tcga`` and
            ``gene name`` for tcga).
        tpmSource: ``"custom"`` or ``"tcga"``.

    Returns:
        A new, sorted (and possibly formatted) DataFrame.
    """
    ranked = df.sort_values("axelf", ascending=False).reset_index(drop=True)

    # Select the per-source layout: which column holds the expression value,
    # the presentation order, and the display headers.
    if tpmSource == "custom":
        expression_col = "tpm"
        layout = ['allele', 'peptide', 'tpm', 'rank_el', 'ic50', 'axelf']
        headers = {
            'peptide': 'Peptide',
            'allele': 'Allele',
            'tpm': 'TPM',
            'rank_el': 'Rank_el',
            'ic50': 'Rank_mapped_to_IC50',
            'axelf': 'Axelf',
        }
    elif tpmSource == "tcga":
        expression_col = "tcga"
        layout = ['allele', 'peptide', 'gene name', 'tcga', 'rank_el', 'ic50', 'axelf']
        headers = {
            'peptide': 'Peptide',
            'allele': 'Allele',
            'gene name': 'Gene Name',
            'tcga': 'TPM',
            'rank_el': 'Rank_el',
            'ic50': 'Rank_mapped_to_IC50',
            'axelf': 'Axelf',
        }
    else:
        # Unrecognised source: hand back the sorted table untouched.
        return ranked

    # Sorting is already done on the numeric values, so the numbers can now
    # safely be turned into fixed-precision strings.
    precisions = (
        (expression_col, '{:.2f}'),
        ("ic50", '{:.3f}'),
        ("axelf", '{:.5f}'),
    )
    for column, template in precisions:
        ranked[column] = ranked[column].map(template.format)

    ranked = ranked.reindex(columns=layout)
    return ranked.rename(columns=headers)

def sequence_to_kmers(fastaFile, kmer_length):
    """Return every overlapping k-mer of the sequence(s) in a FASTA file.

    The first line of the file is always treated as a header and skipped.
    Any later line starting with ``>`` begins a new record; k-mers are
    generated within each record separately, so header text never ends up
    inside a k-mer and no k-mer spans two records. Sequence lines belonging
    to one record are joined after stripping surrounding whitespace.

    Args:
        fastaFile: Path of the FASTA file to read.
        kmer_length: Window size; must be a positive integer.

    Returns:
        List of k-mers in file order. Empty when no record is at least
        ``kmer_length`` characters long.

    Raises:
        ValueError: If ``kmer_length`` is not positive.
    """
    if kmer_length <= 0:
        raise ValueError(
            "kmer_length must be a positive integer, got %r" % (kmer_length,))

    with open(fastaFile, 'r') as f:
        lines = [line.strip() for line in f]

    # lines[0] is the header of the first record; each further '>' line
    # closes the current record and opens the next one.
    records = [[]]
    for line in lines[1:]:
        if line.startswith(">"):
            records.append([])
        else:
            records[-1].append(line)

    kmers = []
    for parts in records:
        sequence = "".join(parts)
        # range() is empty when the record is shorter than the window.
        kmers.extend(
            sequence[start:start + kmer_length]
            for start in range(len(sequence) - kmer_length + 1)
        )

    return kmers

def _load_params_into_env(param_file):
    """Copy ``KEY=VALUE`` lines from *param_file* into ``os.environ``."""
    with open(param_file, 'r') as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            # Blank lines and '#' comments carry no setting.
            if not line or line.startswith("#"):
                continue
            # Split on the first '=' only, so values may contain '=' themselves.
            key, sep, value = line.partition("=")
            key = key.strip()
            if not sep or not key:
                raise ValueError(
                    "Malformed line %d in %s: expected KEY=VALUE, got %r"
                    % (line_no, param_file, line))
            os.environ[key] = value.strip()


def _required_env(name):
    """Return environment variable *name*, raising ValueError if unset or empty."""
    value = os.environ.get(name)
    if not value:
        raise ValueError("Missing required parameter: %s" % name)
    return value


def _add_scores(df, axelf_obj, tpm_column):
    """Add ``rank_el``, ``ic50`` and ``axelf`` columns to *df* in place."""
    # NOTE(review): calculate_ic50_score is used here in its batched form
    # (list of peptides, list of alleles) and is assumed to return one
    # (rank_el, ic50) pair per input row - confirm against Axelf.
    pairs = list(axelf_obj.calculate_ic50_score(
        df["peptide"].tolist(), df["allele"].tolist()))
    df["rank_el"] = [pair[0] for pair in pairs]
    df["ic50"] = [pair[1] for pair in pairs]

    # Calculating Axel-F
    df["axelf"] = [
        axelf_obj.calculate_axelf_score(tpm, ic50)
        for tpm, ic50 in zip(df[tpm_column], df["ic50"])
    ]
    return df


def main(example_file):
    """Score the peptides described by a parameter file and return the table.

    The parameter file (``PDIR + example_file``) holds ``KEY=VALUE`` lines
    that are exported into ``os.environ`` and then read back:

    * ``FILE_PATH``      - input path relative to ``PDIR`` (``.csv`` or ``.fasta``)
    * ``TPM_SOURCE``     - ``custom`` or ``tcga``
    * ``TPM``            - expression value (FASTA + custom only)
    * ``GENE_NAME``      - gene name (FASTA + tcga only)
    * ``CANCER_TYPE``    - passed to ``Axelf.get_tcga`` (tcga only)
    * ``PEPTIDE_LENGTH`` - k-mer length (FASTA only)
    * ``ALLELE_LIST``    - allele value assigned to every k-mer (FASTA only)
    * ``OUTPUT``         - read but not used by this function

    CSV input must already contain ``peptide`` and ``allele`` columns plus
    ``tpm`` (custom) or ``gene name`` (tcga). FASTA input is cut into
    k-mers, each paired with the configured allele and TPM / gene name.

    Args:
        example_file: Parameter-file path, appended verbatim to ``PDIR``
            (so it is expected to start with ``/``).

    Returns:
        DataFrame with ``rank_el``, ``ic50`` and ``axelf`` columns added
        (and ``tcga`` for the tcga source). The table is printed but neither
        sorted nor display-formatted; ``sort_axelf_result`` can be applied
        by the caller for that.

    Raises:
        ValueError: On a malformed parameter line, a missing required
            parameter, or an unsupported file extension / TPM source.
    """
    # Parse the environment file first.  example_file starts with '/', so
    # plain concatenation is required (Path '/' would discard PDIR).
    _load_params_into_env(str(PDIR) + example_file)

    # Get all env variables
    arg_dict = {
        'input_file': PDIR + _required_env('FILE_PATH'),
        'tpm_source': os.environ.get('TPM_SOURCE'),
        'tpm': os.environ.get('TPM'),
        'gene_name': os.environ.get('GENE_NAME'),
        'cancer_type': os.environ.get('CANCER_TYPE'),
        'peptide_length': os.environ.get('PEPTIDE_LENGTH'),
        'allele_list': os.environ.get('ALLELE_LIST'),
        'output_file': os.environ.get('OUTPUT'),
    }

    file_extension = Path(arg_dict['input_file']).suffix
    tpm_source = arg_dict['tpm_source']

    # Fail early instead of silently returning an unscored table.
    if file_extension not in (".csv", ".fasta"):
        raise ValueError(
            "Unsupported input file extension: %r" % (file_extension,))
    if tpm_source not in ("custom", "tcga"):
        raise ValueError("Unsupported TPM_SOURCE: %r" % (tpm_source,))

    axelf_obj = Axelf()

    print("FILE EXETIONS : %s" %(file_extension))

    if file_extension == ".csv":
        # Only CSV input is parsed with pandas; a FASTA file is not CSV.
        input_df = pd.read_csv(arg_dict['input_file'])
    else:
        peptides = sequence_to_kmers(
            arg_dict['input_file'], int(_required_env('PEPTIDE_LENGTH')))
        columns = {
            'peptide': peptides,
            'allele': [arg_dict['allele_list']] * len(peptides),
        }
        if tpm_source == "custom":
            # Environment values are strings; the score needs a number.
            columns['tpm'] = [float(_required_env('TPM'))] * len(peptides)
        else:
            columns['gene name'] = [arg_dict['gene_name']] * len(peptides)
        input_df = pd.DataFrame(columns)

    if tpm_source == "tcga":
        # Look up the expression value for each row's gene.
        cancer_type = arg_dict['cancer_type']
        input_df['tcga'] = [
            axelf_obj.get_tcga(gene, cancer_type)
            for gene in input_df["gene name"]
        ]
        tpm_column = "tcga"
    else:
        tpm_column = "tpm"

    _add_scores(input_df, axelf_obj, tpm_column)

    print("-----------------------")
    print(input_df)
    print("-----------------------")

    return input_df

if __name__ == "__main__" :
    # Run main() once per intended scenario, in order.
    # NOTE(review): all four entries currently name the same parameter file,
    # so the same configuration is executed four times; point each entry at
    # its own parameter file to actually exercise the labelled combination.
    scenario_param_files = [
        "/tests/axelf_params",  # CSV + custom TPM
        "/tests/axelf_params",  # CSV + TCGA
        "/tests/axelf_params",  # FASTA + custom TPM
        "/tests/axelf_params",  # FASTA + TCGA
    ]
    for scenario_param_file in scenario_param_files:
        main(scenario_param_file)