import csv
import os
import sys
from pathlib import Path

import pandas as pd
# Absolute path of the bundled 'method' directory next to this script, with a trailing
# slash so that sub-directory names can be appended directly.
METHOD_DIR = str(Path(__file__).absolute().parent) + '/method/'

# Bundled method packages that must be importable. Each one is inserted at the FRONT of
# sys.path, so after the loop the last name listed here has the highest lookup priority.
method_base_dirs = ('axelf-executable', 'netmhcpan-4.0-executable', 'netmhcpan-4.1-executable',
                    'allele-info', 'iedbtools-utilities')
for method_base_dir in method_base_dirs :
    sys.path.insert(0, METHOD_DIR + method_base_dir)
from argument_parser import ArgumentParser
from axelf_executable import Axelf

# Short usage banner appended to validation error messages.
USAGE_INFO = (
    "usage: axelf.py [-h] [-p] inputFile -s [TPM_SOURCE] [-a "
    "[ALLELE_LIST]] [-c [CANCER_TYPE]] [-l [PEPTIDE_LENGTH]] "
    "[-g [GENE_NAME]] [-o [OUTPUT]]\n\n"
    "Please use the following command for detailed usage information: \n"
    "\t> axelf.py --help\n"
)

def main() :
    '''
    Command-line entry point: build the peptide table from a CSV or FASTA input,
    score every peptide and report the Axel-F results.

    Examples :
        python run_axelf.py tests/data/input/sample_input.csv -s custom -a "HLA-A*01:01"
        python run_axelf.py tests/data/input/sample_input.fasta -s custom -a "HLA-A*01:03" -t 3.5 -l 9
        python run_axelf.py tests/data/input/sample_input.fasta -s tcga -a "HLA-A*01:03" -t 3.5 -l 9 -g TIGAR -c BRCA

    The input type is chosen from the file extension ('csv' or 'fasta', case-insensitive).
    With -o/--output the sorted table is written as tab-separated text. Otherwise it is
    printed to stdout, preceded by a notice listing any peptide dropped for being shorter
    than 8 residues.

    Raises :
        ValueError : unsupported file extension, missing/conflicting flags, or no peptide
                     of length >= 8 left to score.
    '''
    args = ArgumentParser().parse_arguments()
    axelf_obj = Axelf()

    input_file = args.inputFile[0].name
    file_extension = input_file.split('.')[-1].lower()

    # Column whose expression value is combined with the IC50 into the Axel-F score.
    tcol = "tpm" if (args.tpm_source != "tcga") else "tcga"

    if file_extension == "csv" :
        input_df = _build_csv_input(args, axelf_obj, input_file)
    elif file_extension == "fasta" :
        input_df = _build_fasta_input(args, axelf_obj, input_file)
    else :
        # Any other extension is rejected up front instead of failing later on a None table.
        raise ValueError("Unsupported input file type '.%s'. Please provide a CSV (.csv) or FASTA (.fasta) file." \
                         "\n\n%s" %(file_extension, USAGE_INFO))

    # Peptides shorter than 8 residues are not scored; keep them for the notice below.
    input_df, invalid_peptides = _drop_short_peptides(input_df, 8)
    if len(input_df) == 0 :
        raise ValueError("No results available. Please verify that provided sequences are at least length of 8.")

    # Fill out common data. Each value returned by calculate_ic50_score is indexed below as
    # [0] -> EL rank and [1] -> IC50.
    input_df["tmp_el"] = axelf_obj.calculate_ic50_score(input_df["peptide"].tolist(), input_df["allele"].tolist())
    input_df["rank_el"] = input_df["tmp_el"].map(lambda scores : scores[0])
    input_df["ic50"] = input_df["tmp_el"].map(lambda scores : scores[1])
    input_df.drop("tmp_el", inplace=True, axis=1)

    # Calculating Axel-F
    input_df["axelf"] = input_df.apply(
        lambda row : axelf_obj.calculate_axelf_score(float(row[tcol]), float(row["ic50"])), axis=1)

    sorted_df = sort_axelf_result(input_df, args.tpm_source)

    if args.output :
        sorted_df.to_csv(args.output, sep='\t', index=False, encoding='utf-8')
    else :
        if invalid_peptides :
            print("ATTENTION: Sequence(s) %s is/are too short for at least one of the length selection, " \
            "and has therefore not been included in any prediction." %(", ".join(invalid_peptides)))

        print(sorted_df.to_string(index=False))


def _build_csv_input(args, axelf_obj, input_file):
    '''
    Build the scoring table from a CSV input (expected header: peptide[, allele][, tpm][, gene name]).

    -a/--allele-list, -t/--tpm and -g/--gene-name override the matching CSV column, or
    create it when the file does not have one. For the "tcga" source a 'tcga' column is
    filled via Axelf.get_tcga.

    Raises :
        ValueError : allele count does not match the row count, or a required value
                     (allele, TPM, cancer type, gene name) is available from neither the
                     file nor the flags.
    '''
    check_flags_options(args, "csv")
    input_df = get_input_file(input_file, "csv")

    if args.allele_list :
        alleles = args.allele_list.split(',')
        if len(alleles) == 1 :
            # A scalar broadcasts to every row, creating the column if the CSV has none.
            input_df["allele"] = alleles[0]
        elif len(alleles) == len(input_df) :
            input_df["allele"] = alleles
        else :
            raise ValueError("Please provide a single allele or %s alleles separated by comma." %(len(input_df)))
    elif 'allele' not in input_df.columns :
        # Allele list is not given and yet CSV file lacks allele column
        raise ValueError("Please provide a single allele or %s alleles separated by comma." %(len(input_df)))

    if args.tpm :
        input_df["tpm"] = float(args.tpm)
    elif 'tpm' not in input_df.columns and args.tpm_source != "tcga" :
        raise ValueError("Please provide TPM value by using the '-t/--tpm' flag. " \
                         "If you are planning to retrieve TPM value from a different source, " \
                         "please specify TPM source flag by using '-s/--tpm-source'. \n%s" %(USAGE_INFO))

    if args.tpm_source == "tcga" :
        if args.cancer_type is None :
            raise ValueError("Must specify cancer type if you want to use tcga data.")
        if args.gene_name :
            input_df["gene name"] = args.gene_name
        elif "gene name" not in input_df.columns :
            raise ValueError("Gene name is not found. Please specify gene name." \
                             "\n\n%s" %(USAGE_INFO))
        input_df["tcga"] = _lookup_tcga(axelf_obj, input_df["gene name"], args.cancer_type)

    return input_df


def _build_fasta_input(args, axelf_obj, input_file):
    '''
    Build the scoring table by splitting the FASTA sequence into peptides of -l/--peptide-length.

    Every peptide gets the -a/--allele-list value verbatim. For the "custom" source each
    peptide gets the -t/--tpm value. For "tcga" each gets -g/--gene-name plus the value
    looked up via Axelf.get_tcga.

    Raises :
        ValueError : sequence shorter than the peptide length, a missing TPM / cancer type /
                     gene name, or a TPM source other than "custom" or "tcga".
    '''
    check_flags_options(args, "fasta")
    axelf_obj.validate_fasta(input_file)

    sequence = get_input_file(input_file, "fasta")[1]
    peptide_length = int(args.peptide_length)

    # Error handling : when fasta sequence is too short to yield a single peptide
    if len(sequence) < peptide_length :
        raise ValueError("Provided FASTA sequence is too short. Please provide sequence with minimum length of %d." %(peptide_length))

    peptides = sequence_to_kmers(sequence, peptide_length)
    alleles = [args.allele_list] * len(peptides)

    if args.tpm_source == "custom" :
        if args.tpm is None :
            raise ValueError("Please provide TPM value by using the '-t/--tpm' flag. " \
                             "If you are planning to retrieve TPM value from a different source, " \
                             "please specify TPM source flag by using '-s/--tpm-source'. \n%s" %(USAGE_INFO))
        tpms = [float(args.tpm)] * len(peptides)
        return pd.DataFrame(list(zip(peptides, alleles, tpms)), columns=['peptide', 'allele', 'tpm'])

    if args.tpm_source == "tcga" :
        if args.cancer_type is None :
            raise ValueError("Must specify cancer type if you want to use tcga data.")
        if not args.gene_name :
            raise ValueError("Gene name is not found. Please specify gene name." \
                             "\n\n%s" %(USAGE_INFO))
        gene_names = [args.gene_name] * len(peptides)
        input_df = pd.DataFrame(list(zip(peptides, alleles, gene_names)), columns=['peptide', 'allele', 'gene name'])
        input_df['tcga'] = _lookup_tcga(axelf_obj, input_df["gene name"], args.cancer_type)
        return input_df

    raise ValueError("For FASTA input, TPM source must be either 'custom' or 'tcga'." \
                     "\n\n%s" %(USAGE_INFO))


def _lookup_tcga(axelf_obj, gene_names, cancer_type):
    '''Return, per gene name in the Series `gene_names`, the value of axelf_obj.get_tcga(gene, cancer_type).'''
    return gene_names.map(lambda gene : axelf_obj.get_tcga(gene, cancer_type))


def _drop_short_peptides(input_df, min_length=8):
    '''
    Split off rows whose 'peptide' is shorter than min_length (or not a string).

    Returns :
        (kept_df, dropped) where kept_df has a fresh 0..n-1 index and dropped is the list
        of removed peptides as strings, in their original order.
    '''
    # Select by column name, not position: CSV columns follow the user's header order.
    long_enough = input_df["peptide"].map(
        lambda peptide : isinstance(peptide, str) and len(peptide) >= min_length).astype(bool)
    dropped = [str(peptide) for peptide in input_df.loc[~long_enough, "peptide"]]
    # Reset the index so list/Series results assigned later line up row-by-row.
    kept = input_df.loc[long_enough].reset_index(drop=True)
    return kept, dropped


def check_flags_options(args, file_type):
    '''
    Validate that the command-line flags fit the input type before any work is done.

    Parameters :
        args      : parsed command-line arguments (reads peptide_length, allele_list,
                    tpm_source, cancer_type, gene_name and inputFile).
        file_type : "csv"; any other value is validated as FASTA input.

    Raises :
        ValueError : a required flag is missing, or a flag not allowed for this input is set.
    '''
    if file_type == "csv" :
        # Length flag should be disabled
        if args.peptide_length :
            raise ValueError("Please remove peptide length flag (-l / --peptide-length). For CSV input, peptide length is not needed." \
                            "\n\n%s" %(USAGE_INFO))
    else :
        # Length flag MUST be specified: the FASTA sequence is cut into peptides of this length.
        if not args.peptide_length :
            raise ValueError("Peptide length is missing. For FASTA input, peptide length must be specified." \
                            "\n\n%s" %(USAGE_INFO))

        # Allele flag MUST be specified: the FASTA path takes every peptide's allele from it.
        if not args.allele_list :
            raise ValueError("Allele is missing. For FASTA input, allele(s) must be provided." \
                            "\n\n%s" %(USAGE_INFO))

    if args.tpm_source != "tcga" :
        return

    # TCGA lookups need both a cancer type and a gene name. The header line is checked for
    # 'gene name' / 'cancer type' columns; any value not covered by a column must be given
    # via -c / -g.
    with open(args.inputFile[0].name, 'r') as f :
        header = f.readline()
    has_gene_column = 'gene name' in header
    has_cancer_column = 'cancer type' in header

    if has_gene_column and not has_cancer_column and not args.cancer_type :
        raise ValueError("Cancer type is not found. Please specify cancer type." \
                    "\n\n%s" %(USAGE_INFO))

    if has_cancer_column and not has_gene_column and not args.gene_name :
        raise ValueError("Gene name is not found. Please specify gene name." \
                    "\n\n%s" %(USAGE_INFO))

    # Neither column is present, so BOTH -c and -g are required.
    if not has_cancer_column and not has_gene_column and not (args.cancer_type and args.gene_name) :
        raise ValueError("To use TCGA data, cancer type and gene name must be specified." \
                    "\n\n%s" %(USAGE_INFO))
                
def sequence_to_kmers(sequence, kmer_length):
    '''
    Return every contiguous substring of `sequence` of length `kmer_length`,
    ordered by start position (a sliding window with step 1).
    An empty list is returned when the sequence is shorter than `kmer_length`.
    '''
    window_starts = range(len(sequence) - kmer_length + 1)
    return [sequence[start:start + kmer_length] for start in window_starts]

def sort_axelf_result(df, tpmSource):
    '''
    Sort scored peptides by Axel-F score (highest first) and format them for display.

    Parameters :
        df        : scored table with 'peptide', 'allele', 'rank_el', 'ic50' and 'axelf'
                    columns, plus 'gene name' and 'tcga' when tpmSource is "tcga", or
                    'tpm' for any other source.
        tpmSource : TPM source flag; "tcga" selects the TCGA column layout.

    Returns :
        A new DataFrame (df itself is not modified) with a 0..n-1 index, numeric columns
        rendered as strings, only the display columns kept, and those columns renamed to
        their display headers.
    '''
    # sort_values returns a new frame, so the formatting below never touches the caller's df.
    sorted_df = df.sort_values("axelf", ascending=False).reset_index(drop=True)

    if tpmSource == "tcga" :
        sorted_df["tcga"] = sorted_df["tcga"].map('{:.2f}'.format)
        col_order = ['peptide', 'allele', 'gene name', 'tcga', 'rank_el', 'ic50', 'axelf']
    else :
        # manual or from_file. TPM may still be text (CSV cells are read as str), so parse
        # it. Keep float64: downcasting to float32 rounds large values before formatting
        # (999999.99 would print as 1000000.00).
        sorted_df["tpm"] = pd.to_numeric(sorted_df["tpm"]).astype(float).map('{:.2f}'.format)
        col_order = ['allele', 'peptide', 'tpm', 'rank_el', 'ic50', 'axelf']

    # Formatting common columns
    sorted_df["rank_el"] = sorted_df["rank_el"].map(rankELformat)
    sorted_df["ic50"] = sorted_df["ic50"].map('{:.3f}'.format)
    sorted_df["axelf"] = sorted_df["axelf"].map('{:.5f}'.format)

    # After reindex only one of 'tcga' / 'tpm' (and 'gene name' only for TCGA) is present, so
    # a single mapping serves both layouts; rename() ignores keys that are absent.
    display_names = {
        'peptide': 'Peptide',
        'allele': 'Allele',
        'gene name': 'Gene Name',
        'tcga': 'TPM',
        'tpm': 'TPM',
        'rank_el': 'Rank EL',
        'ic50': 'Rank mapped to IC50',
        'axelf': 'Axelf',
    }
    return sorted_df.reindex(columns=col_order).rename(columns=display_names)

def rankELformat(x) :
    '''
    Render an EL rank value as display text.

    Rules, applied to str(x) in this order:
      * no '.' at all (e.g. '1e-05', '5')              -> str(x) unchanged
      * an 'e' that is not the last character ('1.5e-05') -> str(x) unchanged
      * more than two zeros right after the first '.'  -> scientific, 3 decimals ('4.000e-04')
      * anything else                                   -> fixed point, 3 decimals ('0.500')
    '''
    text = str(x)
    if text.rfind('.') == -1 :
        return text

    # An 'e' anywhere before the final character marks exponent notation.
    exponent_at = text.rfind('e')
    if -1 < exponent_at < len(text) - 1 :
        return text

    zero_run = 0
    if text.rfind('.') < len(text) - 1 :
        fraction = text.split('.')[1]
        zero_run = len(fraction) - len(fraction.lstrip('0'))

    if zero_run > 2 :
        return "{:.3e}".format(x)
    return "{:.3f}".format(x)


def get_input_file(fpath, extension):
    '''
    Read an input file into memory.

    Parameters :
        fpath     : path of the file to read.
        extension : "csv" parses a comma-separated table; any other value reads FASTA.

    Returns :
        "csv"  -> DataFrame whose column names come from the first non-blank row and whose
                  rows are the following non-blank rows. Every cell is a str, and quoted
                  fields (e.g. "a,b") follow csv-module rules.
        other  -> [header, sequence]: the first line as-is, and every later line stripped of
                  surrounding whitespace and concatenated into one sequence string.

    Raises :
        ValueError : the file is empty (for CSV: contains only blank lines).
    '''
    with open(fpath, 'r') as f :
        lines = [line.strip('\r\n') for line in f]

    if extension == "csv" :
        # csv.reader honours quoting; it yields an empty list for a blank line, which is skipped.
        rows = [row for row in csv.reader(lines) if row]
        if not rows :
            raise ValueError("Input file %s is empty." %(fpath))
        return pd.DataFrame(rows[1:], columns=rows[0])

    if not lines :
        raise ValueError("Input file %s is empty." %(fpath))

    # NOTE(review): every line after the first is treated as sequence, so a multi-record FASTA
    # would have its later '>' header lines merged into the sequence. Confirm validate_fasta rejects that.
    return [lines[0], ''.join(line.strip() for line in lines[1:])]


# Run the command-line pipeline only when executed as a script, not when imported.
if __name__ == "__main__" :
    main()
