#!/usr/bin/env python
from __future__ import print_function
from bisect import bisect, bisect_left
import os
import json
import tempfile
import traceback, time
import math
import configparser
import sys
import re
import csv
import pickle
import logging
# Root logger: WARNING and above, stamped with time, level and source location.
logging.basicConfig(
    level=logging.WARNING,
    format='%(asctime)s,%(msecs)d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s',
    datefmt='%Y-%m-%d:%H:%M:%S',
)

from length_rescaling import calculate_length_rescaled

# adding all methods to the python path
script_dir = os.path.dirname(os.path.realpath(__file__))
# The method packages live in the sibling "method" directory (../method relative
# to this script). os.path.join supplies the path separator that plain string
# concatenation omitted (os.path.dirname() never returns a trailing slash).
methods_dir = os.path.join(script_dir, '..', 'method')
methods_base_dirs = (
    'allele-info',
    'iedbtools-utilities',
    'mhcii-tepitope-predictor',
    'mhcii-comblib-predictor',
    'mhcii-comblib-percentile-data',
    # disabled: 'mhcii-netmhciipan-percentile-data',
    # disabled: 'mhcii-nnalign-percentile-data',
    'mhcii-smmalign-percentile-data',
    'mhcii-tepitope-percentile-data',
    'mhcii-netmhciipan-3.2-percentile-data',
    'mhcii-netmhcii-2.3-percentile-data',
    'mhcii-predictor-data',
    # disabled: 'netmhciipan-3.1-executable',
    'netmhciipan-3.2-executable',
    'netmhcii-1.1-executable',
    # disabled: 'netmhcii-2.2-executable',
    'netmhcii-2.3-executable',
    'netmhciipan-4.1-executable',
    'mhcii-netmhciipan-4.1-ba-percentile-data',
    'mhcii-netmhciipan-4.1-el-percentile-data',
    'netmhciipan-4.2-executable',
    'mhcii-netmhciipan-4.2-ba-percentile-data',
    'mhcii-netmhciipan-4.2-el-percentile-data',
    'netmhciipan-4.3-executable',
    'mhcii-netmhciipan-4.3-ba-percentile-data',
    'mhcii-netmhciipan-4.3-el-percentile-data',
    'mhciinp',
)
# Appended (not prepended) so anything already on PYTHONPATH keeps priority.
for method_base_dir in methods_base_dirs:
    sys.path.append(os.path.join(methods_dir, method_base_dir))

from mhcii_predictor_data import get_method_allele_list

from mhcii_predictor import MhciiPredictor
from common_bio import Proteins

from mhciinp_package.mhc_ii_ligand_predictor import MHCIILigandPredict

def eprint(*args, **kwargs):
    """Write ``args`` to standard error; takes the same keywords as print()."""
    error_stream = sys.stderr
    print(*args, file=error_stream, **kwargs)


class InputError(Exception):
    """Exception raised for errors in the input.

    Attributes:
        value: the message (or object) describing the problem.
    """

    def __init__(self, value):
        # Forward to Exception so .args, repr() and pickling carry the value.
        super().__init__(value)
        self.value = value

    def __str__(self):
        # __str__ must return a str; coerce non-string values.
        return str(self.value)

class UnexpectedInputError(Exception):
    """Exception raised for errors in the input.

    Attributes:
        value: the message (or object) describing the problem.
    """

    def __init__(self, value):
        # Forward to Exception so .args, repr() and pickling carry the value.
        super().__init__(value)
        self.value = value

    def __str__(self):
        # __str__ must return a str; coerce non-string values.
        return str(self.value)


class PredictionTable:
    """Generates a table from a set of sequences and predictions"""
    def __init__(self):
        # Accumulated output rows: tuples for binding results, lists for
        # immunogenicity results (the first of which is a header row).
        self.row_data = []

    def add_rows_binding(self, allele, pep_length, proteins, scores, method):
        """Append one row per predicted peptide of length ``pep_length``.

        ``scores`` is parallel to ``proteins.sequences``; the k-th prediction
        for a sequence describes the peptide starting at 0-based offset k.

        * method "consensus3" / "IEDB_recommended": each prediction is a
          sequence of score fields, appended after
          (allele, seq_num, start, end, length, peptide).
        * any other method: each prediction is (core, score, rank) and the row
          is (allele, seq_num, start, end, length, core, peptide, score, rank).

        ``start``/``end`` are 1-based inclusive; ``length`` is stored as str.
        """
        logging.debug((allele, pep_length, proteins, scores))
        is_consensus = method in ("consensus3", "IEDB_recommended")
        for seq_index, (sequence, predictions) in enumerate(zip(proteins.sequences, scores)):
            seq_num = seq_index + 1
            for offset, prediction in enumerate(predictions):
                peptide_sequence = sequence[offset:offset + pep_length]
                start = offset + 1
                end = offset + pep_length
                if is_consensus:
                    self.row_data.append(
                        (allele, seq_num, start, end, str(pep_length), peptide_sequence)
                        + tuple(prediction))
                else:
                    core_sequence, score, rank = prediction
                    logging.debug((peptide_sequence, core_sequence, pep_length))
                    self.row_data.append(
                        (allele, seq_num, start, end, str(pep_length),
                         core_sequence, peptide_sequence, score, rank))

    def format_binding(self, proteins, results, method):
        """Build binding rows for every (pep_length, allele, scores) result.

        Non-mouse (H2) and non-bovine (BoLA) alleles are normalised to
        "HLA-<allele>" with '-' replaced by '/'. For methods other than the
        netmhciipan_el*/netmhciipan_ba* family an adjusted-rank column is added
        via add_column_adj_rank. Returns (and stores) ``self.row_data``.
        """
        for (pep_length, allele, scores) in results:
            if 'H2' not in allele and 'BoLA' not in allele:
                allele = "HLA-%s" % allele.replace('HLA-', '').replace("-", "/")
            self.add_rows_binding(allele, pep_length, proteins, scores, method)
        if not method.startswith('netmhciipan_el') and not method.startswith('netmhciipan_ba'):
            self.row_data = add_column_adj_rank(self.row_data, method)
        return self.row_data

    def format_immunogenicity(self, proteins, results, method):
        """Build immunogenicity rows (header row first) from imm_prediction results.

        * "immunogenicity": each result is
          (seq_num, description, peptide, start, end, score, core); output rows
          are [seq_num, peptide, start, end, peptide_length, core, score].
        * "7-allele": the first six fields of each result are
          (seq_num, description, peptide, start, end, median_percentile); output
          rows are [seq_num, peptide, start, end, peptide_length, median_percentile].

        Any other ``method`` leaves ``self.row_data`` unchanged.
        Returns ``self.row_data``.
        """
        if method == "immunogenicity":
            self.row_data.append([
                "seq_num",
                "peptide",
                "start",
                "end",
                "peptide_length",
                "core",
                "score",
            ])
            for (seq_num, seq_name, peptide, start, end, score, core) in results:
                self.row_data.append(
                    [seq_num, peptide, start, end, self._peptide_length(peptide), core, score, ])
        elif method == "7-allele":
            self.row_data.append([
                "seq_num",
                "peptide",
                "start",
                "end",
                "peptide_length",
                "median_percentile",
            ])
            for row in results:
                seq_num, seq_name, peptide, start, end, median_percentile = row[:6]
                self.row_data.append(
                    [seq_num, peptide, start, end, self._peptide_length(peptide), median_percentile, ])
        return self.row_data

    @staticmethod
    def _peptide_length(peptide):
        """Length of ``peptide``; falls back to 15 (the historic fixed value) for non-strings."""
        # generate_pep can produce lengths other than 15, so measure the peptide.
        return len(peptide) if isinstance(peptide, str) else 15

def print_input_page(sequence):
    """Print the CGI input page with ``sequence`` substituted into the template.

    The HTML root is read from ../setup.cfg (section "path", option "html"),
    and the template <html>/html/mhc_II_binding.seq is rendered with a
    text/html Content-Type header.
    """
    import html

    config_parser = configparser.ConfigParser()
    config_parser.read("../setup.cfg")
    html_path = config_parser.get("path", "html")
    # Close the template file deterministically instead of leaking the handle.
    with open(html_path + "/html/mhc_II_binding.seq", "r") as template_file:
        template = template_file.read()
    print("Content-Type: text/html")
    print("")
    # ``sequence`` comes from the CGI request; escape it before echoing it into
    # HTML so user input cannot inject markup or scripts.
    print(template % html.escape(sequence))


def sort_table(method, table_rows):
    """Return ``table_rows`` ordered best-prediction-first for ``method``.

    ``method`` is the user-facing method name (case-insensitive). Column
    meanings follow the row layouts built by PredictionTable.format_binding:

    * consensus3 / IEDB_recommended: ascending column 7.
    * sturniolo: ascending column 9, ties by descending score (column 7).
    * netmhciipan_el* (any version): ascending rank (column 8), ties by
      descending score (column 7).
    * netmhciipan_ba* (any version): ascending rank (column 8), ties by
      ascending ic50 (column 7).
    * everything else: ascending adjusted rank (column 9), ties by ic50 (column 7).

    For all but the consensus methods, the last column is the final tie-breaker.
    Empty input is returned unchanged.
    """
    if not table_rows:
        return table_rows
    method = method.lower()

    if method in ('consensus3', 'iedb_recommended'):
        return sorted(table_rows, key=lambda row: float(row[7]))

    # Python's sort is stable: each later sort is a more significant key, so
    # this first pass on the last column acts as the final tie-breaker.
    table_rows = sorted(table_rows, key=lambda row: float(row[-1]))
    if method == 'sturniolo':
        table_rows = sorted(table_rows, key=lambda row: float(row[7]), reverse=True)
        table_rows = sorted(table_rows, key=lambda row: float(row[9]))
    # startswith() so versioned names (e.g. netmhciipan_el-4.2) match too; their
    # rows have no adjusted-rank column, so column 9 does not exist.
    elif method.startswith('netmhciipan_el'):
        table_rows = sorted(table_rows, key=lambda row: float(row[7]), reverse=True)
        table_rows = sorted(table_rows, key=lambda row: float(row[8]))
    elif method.startswith('netmhciipan_ba'):
        table_rows = sorted(table_rows, key=lambda row: float(row[7]))
        table_rows = sorted(table_rows, key=lambda row: float(row[8]))
    else:
        table_rows = sorted(table_rows, key=lambda row: float(row[7]))
        table_rows = sorted(table_rows, key=lambda row: float(row[9]))
    return table_rows

def form_valid(form):
    """Validate a prediction request before any prediction is run.

    Checks that form["pred_method"] is a known method, that every allele in
    the comma-separated form["allele"] (with "HLA-" removed) is available for
    that method, and that every input sequence is at least 11 residues long
    and no shorter than the smallest requested peptide length.

    Returns:
        'DNA_sequence_input' if any sequence consists only of A/C/G/T,
        otherwise True.

    Raises:
        UnexpectedInputError: unknown method, unknown allele, or a sequence
            that is too short.
        ValueError: a length entry that is neither numeric/range nor 'asis'.
    """
    method = form["pred_method"]
    form_alleles = form["allele"].replace("HLA-", "")
    # Validate method selection: user-facing name -> allele-data method name.
    method_list = {
            "consensus3": 'consensus',
            "IEDB_recommended": 'netmhciipan_el-4.1',
            "IEDB_recommended_epitope": 'netmhciipan_el-4.1',
            "IEDB_recommended_binding": 'netmhciipan_ba-4.1',
            "NetMHCIIpan": 'NetMHCIIpan',
            "nn_align": "nn_align-2.3",
            "smm_align": "smm_align",
            "sturniolo": 'tepitope',
            "comblib": "comblib",
            "netmhciipan_el": 'netmhciipan_el-4.1',
            "netmhciipan_ba": 'netmhciipan_ba-4.1',
            "netmhciipan": 'netmhciipan_ba-4.1',
            "netmhciipan_el-4.2": 'netmhciipan_el-4.2',
            "netmhciipan_ba-4.2": 'netmhciipan_ba-4.2',
            "netmhciipan_el-4.3": 'netmhciipan_el-4.3',
            "netmhciipan_ba-4.3": 'netmhciipan_ba-4.3',
        }
    if method not in method_list:
        raise UnexpectedInputError("Selected prediction method '%s' does not exist. To see all available methods, please type command:\n$ python mhc_II_binding.py method\n" % method)
    method = method_list[method]

    allele_list = form_alleles.split(',')
    length = form['length']
    length_list = length if isinstance(length, list) else length.split(',')
    length_list = [item.strip() for item in length_list]
    # Smallest requested length ("12-18" counts as 12). 'asis'/'as_is' (which
    # get_allele_length_combo_list accepts) carries no number and is skipped;
    # if nothing numeric remains, only the 11-residue minimum below applies.
    numeric_lengths = [item for item in length_list if item not in ('asis', 'as_is')]
    min_length = min(int(item.split('-')[0]) for item in numeric_lengths) if numeric_lengths else 0

    # NOTE(review): MHCIIAlleleData is neither defined nor imported in the
    # visible part of this file — confirm where it is provided.
    miiad = MHCIIAlleleData()
    available_allele_names = miiad.get_allele_names(method_name=method)
    invalid_alleles = [allele for allele in allele_list if allele not in available_allele_names]
    if invalid_alleles:
        raise UnexpectedInputError('Invalid allele name "%s" for method "%s" was found. To see all available alleles, please type command:\n$ python mhc_II_binding.py allele' % (', '.join(invalid_alleles), form["pred_method"]))

    proteins = Proteins(form["sequences"])
    for sequence in proteins.sequences:
        residue_count = len(sequence.strip())
        if residue_count < 11:
            raise UnexpectedInputError('The length of the input protein sequence must be at least 11, please check your sequence: "%s".' % sequence)
        elif residue_count < min_length:
            raise UnexpectedInputError('The length of the input protein sequence should be no less than the input length, please check your sequence: "%s".' % sequence)
    # A sequence made only of nucleotide letters is probably DNA, not protein.
    if any(re.match('^[ACGT]+$', protein_sequence.upper()) for protein_sequence in proteins.sequences):
        return 'DNA_sequence_input'
    return True
   
def get_allele_length_combo_list(allele_list, length_list):
    """Expand alleles x requested peptide lengths into sorted unique pairs.

    Each entry of ``length_list`` may be a single length ("15"), an inclusive
    range ("12-18") or a comma-separated mix ("10,12-18,20"). If any entry is
    'asis' or 'as_is', every allele is paired with 'asis' instead, in input
    order (no sorting/dedup).

    Returns:
        list of (allele, int_length) tuples, sorted and de-duplicated.

    Raises:
        ValueError: for malformed lengths, including reversed ranges such as
            "18-12" (previously these silently produced no pairs).
    """
    # if one of the lengths is 'asis', all have to be 'asis'
    if 'asis' in length_list or 'as_is' in length_list:
        return [(allele, 'asis') for allele in allele_list]

    requested_lengths = set()
    for lengths in length_list:
        lengths = str(lengths).strip()
        if not re.match(r'^[\d,\s-]+$', lengths):
            raise ValueError('invalid length input: %s' % lengths)
        for token in lengths.split(','):
            token = token.strip()
            if "-" in token:
                if token.count("-") > 1:
                    raise ValueError('invalid length input: %s' % token)
                len_start, len_end = [part.strip() for part in token.split("-")]
                if not (len_start.isdigit() and len_end.isdigit()):
                    raise ValueError('invalid length input: %s' % token)
                if int(len_start) > int(len_end):
                    # An empty range would otherwise vanish silently.
                    raise ValueError('invalid length input: %s' % token)
                requested_lengths.update(range(int(len_start), int(len_end) + 1))
            else:
                if not token.isdigit():
                    raise ValueError('invalid length input: %s' % token)
                requested_lengths.add(int(token))

    return sorted({(allele, length) for allele in allele_list for length in requested_lengths})
 
def add_column_adj_rank(test_data, method=None):
    """Return ``test_data`` rows (as tuples) with length-rescaled rank columns.

    Column 4 of every row is the peptide length passed to
    calculate_length_rescaled.

    * 'iedb_recommended': a rescaled value is inserted after each of the rank
      columns 21, 18, 15, 12, 9 and 6.
    * 'consensus3': same, for rank columns 18, 15, 12, 9 and 6.
    * any other method (including None): one rescaled value of the last
      column is appended.

    ``method`` is compared case-insensitively; None is treated as "other"
    (previously it raised AttributeError).
    """
    logging.debug('add_column_adj_rank for method: %s' % method)
    method = (method or '').lower()
    if method == 'iedb_recommended':
        # header: allele, seq_num, start, end, length, method, peptide, percentile_rank, ...
        rank_columns = [21, 18, 15, 12, 9, 6]
    elif method == 'consensus3':
        # header: allele, seq_num, start, end, length, peptide, percentile_rank, ...
        rank_columns = [18, 15, 12, 9, 6]
    else:
        rank_columns = None

    result_list = []
    for row in test_data:
        row = list(row)
        if rank_columns is None:
            row.append(calculate_length_rescaled(length=row[4], rank=row[-1]))
        else:
            # Insert right-to-left so the lower indices still point at ranks.
            for i in rank_columns:
                row.insert(i + 1, calculate_length_rescaled(length=row[4], rank=row[i]))
        result_list.append(tuple(row))
    return result_list

def main(form):
    """Run the MHC-II binding prediction described by ``form`` and print TSV.

    If ``form`` contains "sequence", or lacks "sequence_format", the HTML input
    page is printed instead. Otherwise the form is validated, predictions are
    made for every allele/length combination, and a header line followed by
    one tab-separated line per peptide is written to stdout. Any exception
    raised while validating or predicting ends the process via
    ``sys.exit(exception)``.
    """
    if "sequence" in form:
        print_input_page(form["sequence"])
        return
    if "sequence_format" not in form:
        print_input_page("")
        return

    try:
        pred_method = form["pred_method"]
        input_params = form_valid(form)
        if not input_params:
            return

        proteins = Proteins(form["sequences"])

        # user-facing method name -> MhciiPredictor method name
        method_dict = {
            "consensus3": 'consensus',
            "IEDB_recommended": 'netmhciipan_el',
            "IEDB_recommended_epitope": 'netmhciipan_el',
            "IEDB_recommended_binding": 'netmhciipan_ba',
            "NetMHCIIpan": 'NetMHCIIpan',
            # form_valid accepts plain "netmhciipan" (as netmhciipan_ba-4.1),
            # so it must be mapped here too instead of raising KeyError.
            "netmhciipan": 'netmhciipan_ba',
            "nn_align": "nn_align",
            "comblib": "comblib",
            "smm_align": "smm_align",
            "sturniolo": 'tepitope',
            "netmhciipan_el": 'netmhciipan_el',
            "netmhciipan_ba": 'netmhciipan_ba',
            "netmhciipan_el-4.2": 'netmhciipan_el-4.2',
            "netmhciipan_ba-4.2": 'netmhciipan_ba-4.2',
            "netmhciipan_el-4.3": 'netmhciipan_el-4.3',
            "netmhciipan_ba-4.3": 'netmhciipan_ba-4.3',
        }
        method = method_dict[pred_method]

        alleles = form["allele"].replace('HLA-', '').replace("/", "-")
        allele_list = alleles.split(',')
        length = form['length']
        length_list = length if isinstance(length, list) else length.split(',')
        length_list = [item.strip() for item in length_list]

        allele_length_combo_list = get_allele_length_combo_list(allele_list, length_list)
        allele_list, length_list = zip(*allele_length_combo_list)

        predictor = MhciiPredictor(method, allele_list, length_list)
        mhc_scores = predictor.predict(proteins.sequences)

        # PredictionTable() takes no arguments, and format_binding needs the
        # user-facing method name to choose the row layout (the previous
        # PredictionTable(form) / two-argument call raised TypeError).
        table = PredictionTable()
        table_rows = table.format_binding(proteins, mhc_scores, pred_method)
        table_rows = sort_table(pred_method, table_rows)
    except Exception as inst:
        sys.exit(inst)

    # Versioned netmhciipan names (e.g. netmhciipan_el-4.2) share the layout of
    # their base method: startswith() keeps this consistent with format_binding.
    is_netmhciipan_el = pred_method.startswith("netmhciipan_el")
    is_netmhciipan_ba = pred_method.startswith("netmhciipan_ba")

    if pred_method == "consensus3":
        print("allele\tseq_num\tstart\tend\tlength\tpeptide\tconsensus_percentile_rank\tadjusted_consensus_percentile_rank\tcomblib_core\tcomblib_score\tcomblib_rank\tadjusted_comblib_rank\tsmm_align_core\tsmm_align_ic50\tsmm_align_rank\tadjusted_smm_align_rank\tnn_align_core\tnn_align_ic50\tnn_align_rank\tadjusted_nn_align_rank\tsturniolo_core\tsturniolo_score\tsturniolo_rank\tadjusted_sturniolo_rank")
    elif pred_method == "IEDB_recommended":
        print("allele\tseq_num\tstart\tend\tlength\tmethod\tpeptide\tconsensus_percentile_rank\tadjusted_consensus_percentile_rank\tcomblib_core\tcomblib_score\tcomblib_rank\tadjusted_comblib_rank\tsmm_align_core\tsmm_align_ic50\tsmm_align_rank\tadjusted_smm_align_rank\tnn_align_core\tnn_align_ic50\tnn_align_rank\tadjusted_nn_align_rank\tnetmhciipan_core\tnetmhciipan_ic50\tnetmhciipan_rank\tadjusted_netmhciipan_rank\tsturniolo_core\tsturniolo_score\tsturniolo_rank\tadjusted_sturniolo_rank")
    elif pred_method == "sturniolo":
        print("allele\tseq_num\tstart\tend\tlength\tcore_peptide\tpeptide\tscore\tpercentile_rank\tadjusted_rank")
    elif is_netmhciipan_el:
        print("allele\tseq_num\tstart\tend\tlength\tcore_peptide\tpeptide\tscore\tpercentile_rank")
    elif is_netmhciipan_ba:
        print("allele\tseq_num\tstart\tend\tlength\tcore_peptide\tpeptide\tic50\tpercentile_rank")
    else:
        print("allele\tseq_num\tstart\tend\tlength\tcore_peptide\tpeptide\tic50\tpercentile_rank\tadjusted_rank")

    # Alleles were already normalised ("HLA-..." except H2/BoLA) by
    # format_binding; re-prefixing here produced "HLA-HLA/...".
    for table_row in table_rows:
        table_row = list(table_row)
        if pred_method == "consensus3":
            print('\t'.join(map(str, table_row)))
        elif pred_method == "IEDB_recommended":
            # The last column names the method(s) used; show it as column 5.
            method_used = table_row[-1].upper().replace("STURNIOLO", "Sturniolo").replace('NETMHCIIPAN', 'NetMHCIIpan')
            if len(method_used.split(',')) > 1:
                method_used = "Consensus(" + method_used + ")"
            table_row.insert(5, method_used)
            print('\t'.join(map(str, table_row[:-1])))
        else:
            if is_netmhciipan_el:
                # 6 decimals for netmhciipan_el scores
                table_row[-2] = '%.6f' % float(table_row[-2])
            elif is_netmhciipan_ba:
                # 2 decimals for ic50: 1.607724271704614 -> 1.61
                table_row[-2] = '%.2f' % float(table_row[-2])
            else:
                table_row[-3] = '%.2f' % float(table_row[-3])
            print("\t".join(map(str, table_row)))

    if input_params == 'DNA_sequence_input':
        eprint('# Warning: Potential DNA sequence(s) found! This tool is intended to predict for amino acid sequences. Please double check your input fasta file.')

def commandline_help():
    """Print the boxed usage summary for the command-line interface."""
    help_lines = (
        " ________________________________________________________________________________________",
        "|****************************************************************************************|",
        "| * List all available commands.                                                         |",
        "| python mhc_II_binding.py                                                               |",
        "|________________________________________________________________________________________|",
        "| * List all available mhc_II prediction methods.                                        |",
        "| python mhc_II_binding.py method                                                        |",
        "|________________________________________________________________________________________|",
        "| * List all alleles.                                                                    |",
        "| python mhc_II_binding.py allele                                                        |",
        "|________________________________________________________________________________________|",
        "| * Make predictions given a file containing a list of sequences.                        |",
        "| python mhc_II_binding.py prediction_method_name allele_name input_sequence_file_name   |",
        "| Example: python mhc_II_binding.py consensus3 HLA-DRB1*03:01 test.fasta                 |",
        "|________________________________________________________________________________________|",
        "| * You may also redirect (pipe) the input file into the script.                         |",
        "| Example: echo -e test.fasta | python mhc_II_binding.py consensus3 HLA-DRB1*03:01       |",
        "|________________________________________________________________________________________|",
    )
    print("\n".join(help_lines))

def commandline_method():
    '''Print the names of all selectable prediction methods, framed by blank lines.'''
    method_listing = (
        "Available methods are:",
        "----------------------",
        "comblib",
        "consensus3",
        "netmhciipan_el (version 4.1, IEDB_recommended_epitope)",
        "netmhciipan_ba (version 4.1, IEDB_recommended_binding)",
        "netmhciipan_el-4.3",
        "netmhciipan_ba-4.3",
        "netmhciipan_el-4.2",
        "netmhciipan_ba-4.2",
        "nn_align",
        "smm_align",
        "sturniolo",
    )
    print()
    for entry in method_listing:
        print(entry)
    print()
 

def commandline_allele():
    '''Print the allele listing for the classic (non-netmhciipan) prediction methods.'''
    print(get_available_alleles(['consensus3', 'comblib', 'smm_align', 'nn_align', 'sturniolo']))

if __name__ == '__main__':
    # Command-line entry point:
    #   python mhc_II_binding.py                          -> usage help
    #   python mhc_II_binding.py method | allele          -> list methods / alleles
    #   python mhc_II_binding.py METHOD ALLELE [FILE [LENGTH]]
    # When FILE is omitted and stdin is not a terminal, the first line read
    # from stdin is used as the input file name. LENGTH defaults to '15'.
    # NOTE(review): generate_pep, imm_prediction, get_seq_arg_for_mhciinp and
    # input_data_prediction are defined further down this file, so they do not
    # exist yet while this block runs.
    debug = False

    method = allele = infile = None

    # Informational commands finish here; previously execution fell through
    # and also printed the "must specify method" error.
    if len(sys.argv) == 1:
        commandline_help()
        sys.exit(0)
    elif len(sys.argv) == 2 and sys.argv[1] == "method":
        commandline_method()
        sys.exit(0)
    elif len(sys.argv) == 2 and sys.argv[1] == "allele":
        commandline_allele()
        sys.exit(0)

    method = sys.argv[1]
    if len(sys.argv) > 2:
        allele = sys.argv[2]

    # Drop empty arguments, then read the input file name from a non-terminal
    # stdin only when it was not given on the command line (otherwise piped
    # content would end up being used as LENGTH).
    sys.argv = [arg for arg in sys.argv if arg]
    if len(sys.argv) <= 3 and not sys.stdin.isatty():
        piped_name = sys.stdin.readline().strip()
        if piped_name:
            sys.argv.append(piped_name)

    if len(sys.argv) > 3:
        infile = sys.argv[3]
    length = sys.argv[4] if len(sys.argv) > 4 else '15'

    if not method or not allele or not infile:
        eprint("# To run the predction, you must specify method, allele and input file name.")
        sys.exit(0)

    allele = allele.replace("_", "-").replace("H-2", "H2")
    with open(infile, 'r') as r_file:
        sequences = r_file.read()

    # for method NetMHCII: map legacy/alias names onto supported methods
    if 'netmhciipan' not in method.lower():
        method = (method.replace('netmhcii-2.3', 'nn_align')
                  .replace('netmhcii-1.1', 'smm_align')
                  .replace('netmhcii', 'nn_align')
                  .replace('IEDB_recommended_epitope', 'netmhciipan_el')
                  .replace('IEDB_recommended_binding', 'netmhciipan_ba')
                  .replace('IEDB_recommended', 'netmhciipan_el'))
    form = {
        'sequence_format': 'auto',
        'sort_output': 'position_in_sequence',
        'cutoff_type': 'none',
        'allele': allele,
        'sequence_file': infile,
        'pred_method': method,
        'sequences': sequences,
        'length': length,
    }

    if not debug:
        main(form)
    else:
        log_path = "logs/tool_log.txt"
        calltime = str(time.ctime(time.time()))
        with open(log_path, "a") as cgilog:
            cgilog.write('\n--- ' + calltime + ' ---\n')
            cgilog.write('mhc_II_binding.py\n')
            cgilog.flush()
            cgilog.write('Form keys: %d\n' % len(form))
            for key in form:
                cgilog.write('%s - "%s" - "%s"\n' % (calltime, key, form[key]))

        try:
            main(form)
        except Exception as inst:
            with open(log_path, "a") as cgilog:
                cgilog.write('XXX ' + calltime + ' XXX\n')
                cgilog.write("EXCEPTION: '%s'\n" % str(inst))
                traceback.print_exc(None, cgilog)
        else:
            with open(log_path, "a") as cgilog:
                cgilog.write('... ' + calltime + ' ...\n')



def generate_pep(sequences, names, length_list=None, peptide_shift=None, peptide_spacing=5):
    """Cut protein sequences into overlapping peptides via kmer.protein_to_peptide_kmer.

    Args:
        sequences: protein sequences (each converted with str()).
        names: descriptions paired positionally with ``sequences``.
        length_list: peptide lengths to generate; defaults to [15].
        peptide_shift: synonym for ``peptide_spacing``; overrides it when truthy.
        peptide_spacing: step between consecutive peptide start positions
            (passed to kmer() as ``inc``); must be >= 1.

    Returns:
        (pep_dict, protein_fasta): pep_dict merges the kmer() output for every
        sequence and length; protein_fasta maps each sequence string to the
        list of descriptions it was given.

    Raises:
        ValueError: if the effective spacing is less than 1.
    """
    from kmer import protein_to_peptide_kmer

    if length_list is None:
        length_list = [15]
    # peptide shift is a synonym for peptide_spacing
    if peptide_shift:
        peptide_spacing = peptide_shift
    peptide_spacing = int(peptide_spacing)
    if peptide_spacing < 1:
        raise ValueError('peptide spacing must be a positive integer, got %s' % peptide_spacing)

    pkmer = protein_to_peptide_kmer()
    pep_dict = {}
    protein_fasta = {}
    logging.info('Generating peptides and updating in seqs = %s', sequences)
    # seq_num identifies the protein, so it advances once per sequence (not
    # once per requested length), and each description is recorded once.
    for seq_num, (seq, description) in enumerate(zip(sequences, names), start=1):
        seq = str(seq)
        protein_fasta.setdefault(seq, []).append(description)
        for pep_length in length_list:
            pep_dict.update(pkmer.kmer(seq, seq_num, pepmer=pep_length, inc=peptide_spacing))
    return pep_dict, protein_fasta


def _split_peptide_details(df):
    """Expand each row's 'details' key into seq_num/position/sequence/start/end columns.

    The key is split on its first two '_' and position on its first '-'; the
    values stay strings. Explicit column assignment replaces tuple-unpacking of
    ``Series.str``, which pandas 2.x no longer supports.
    """
    details = df['details'].str.split('_', n=2, expand=True)
    df['seq_num'] = details[0]
    df['position'] = details[1]
    df['sequence'] = details[2]
    bounds = df['position'].str.split('-', n=1, expand=True)
    df['start'] = bounds[0]
    df['end'] = bounds[1]
    return df


def _order_cd4_rows(df, sorter, score_column):
    """Order rows by (seq_num, start) for 'position' or by ascending score for 'percentile_rank'."""
    if sorter == 'position':
        df = df.sort_values(['seq_num', 'start'], ascending=[True, True])
    if sorter == 'percentile_rank':
        df = df.sort_values([score_column], ascending=[True])
    return df


def imm_prediction(method, sequences, names, length_list, peptide_shift=None, peptide_spacing=5):
    """Run a CD4 T-cell immunogenicity prediction for the given proteins.

    Args:
        method: 'immunogenicity', '7-allele' or 'combined'.
        sequences, names, length_list: forwarded to generate_pep.
        peptide_shift: synonym for ``peptide_spacing``; overrides it when truthy.
        peptide_spacing: step between consecutive peptide starts.

    Returns:
        dict with keys 'result' (list of rows), 'method', 'num_proteins',
        'num_15mer', 'threshold', 'jobid', 'sorter'; ``{}`` for 'combined';
        ``{'err': message}`` when the backend produced no output.

    Raises:
        ValueError: for an unknown ``method`` (checked before any work is done).
    """
    import pandas as pd
    from imscore import Immunogenicity, SevenAllele

    if method not in ('immunogenicity', '7-allele', 'combined'):
        raise ValueError(f"immunogenicity method must be one of [immunogenicity, 7-allele, combined], but '{method}' is given.")

    # peptide shift is a synonym for peptide_spacing
    if peptide_shift:
        peptide_spacing = int(peptide_shift)

    threshold = 100
    sorter = 'position'

    # Scratch space for the backends; removed automatically when the block exits.
    with tempfile.TemporaryDirectory() as cd4_tmpdir:
        logging.debug('Temporary directory created: %s', cd4_tmpdir)
        jobid = cd4_tmpdir
        # The spacing is deliberately passed as generate_pep's peptide_shift
        # (its 4th positional parameter, as before), which int()-converts it.
        pep_dictionary, fasta_protein = generate_pep(sequences, names, length_list, peptide_shift=peptide_spacing)
        pep_df = pd.DataFrame(list(pep_dictionary.items()), columns=['details', 'peptide'])

        if method == 'combined':
            result_data = {}
        else:
            if method == 'immunogenicity':
                out = Immunogenicity().pep2txt(pep_dictionary, cd4_tmpdir, peptide_spacing)
                if out is None:
                    return {
                        'err': "The immunogenicity model is not running, please report this to us with job id =" + str(jobid), }
                df_imm = pd.read_csv(out, sep="\t", low_memory=False)
                df_imm.columns = ['peptide', 'observed', 'Score', 'core']
                # Right merge keeps every generated peptide, scored or not.
                mapped = _split_peptide_details(pd.merge(df_imm, pep_df, on='peptide', how='right'))
                mapped['description'] = mapped['sequence'].map(fasta_protein)
                mapped['Imm_score'] = round((1 - mapped['Score']) * 100, 4)
                final = mapped[['seq_num', 'description', 'peptide', 'start', 'end', 'Imm_score', 'core']]
                # .copy() so the numeric conversion below writes to a real frame.
                final = final.loc[final['Imm_score'] <= threshold].copy()
                numeric_columns = ['seq_num', 'start', 'Imm_score']
                final[numeric_columns] = final[numeric_columns].apply(pd.to_numeric)
                result = _order_cd4_rows(final, sorter, 'Imm_score').values.tolist()
            else:  # '7-allele'
                pep_pred_file = SevenAllele().do_prediction(pep_dictionary, cd4_tmpdir)
                try:
                    # Only verifies that the backend wrote a readable file.
                    with open(pep_pred_file, 'r'):
                        pass
                except (OSError, TypeError):
                    logging.error('Seven allele prediction produced no readable output for jobid = %s', jobid)
                    return {'err': 'Seven allele prediction is not running. Please report us with job id = ' + str(jobid), }
                df_7_allele = pd.read_csv(pep_pred_file)
                mapped = _split_peptide_details(pd.merge(df_7_allele, pep_df, on='peptide', how='right'))
                mapped['description'] = mapped['sequence'].map(fasta_protein)
                final = mapped[
                    ['seq_num', 'description', 'peptide', 'start', 'end', 'Median', 'HLA-DRB1*03:01', 'HLA-DRB1*07:01',
                     'HLA-DRB1*15:01', 'HLA-DRB3*01:01', 'HLA-DRB3*02:02', 'HLA-DRB4*01:01', 'HLA-DRB5*01:01']]
                final = final.loc[final['Median'] <= threshold].copy()
                numeric_columns = ['seq_num', 'start', 'Median']
                final[numeric_columns] = final[numeric_columns].apply(pd.to_numeric)
                result = _order_cd4_rows(final, sorter, 'Median').values.tolist()
            # 'description' stays a list of names per row (format_immunogenicity relies on row order only).
            result_data = {'result': result, 'method': method, 'num_proteins': len(fasta_protein), 'num_15mer': len(pep_dictionary),
                           'threshold': threshold, 'jobid': jobid, 'sorter': sorter}
    logging.debug("Temporary directory cleaned up.")
    return result_data

def get_seq_arg_for_mhciinp(fname, seq_file_type):
    """ Build the '<'-joined sequence argument passed to MHCIILigandPredict.

    Reads ``fname`` and extracts its sequences:

    * ``'fasta'``: one sequence per record introduced by ``>``.  The header
      line is dropped and the remaining lines are concatenated with all
      whitespace removed.  Empty/whitespace-only chunks (e.g. from ``>>``)
      are skipped; a header-only record yields an empty sequence.
    * ``'peptides'``: one peptide per line, stripped of surrounding
      whitespace; blank lines are skipped.

    :param fname: path of the input file.
    :param seq_file_type: ``'fasta'`` or ``'peptides'``.
    :returns: the extracted sequences joined with ``'<'``.
    :raises ValueError: if ``seq_file_type`` is not one of the above.
    """
    with open(fname, 'r') as r_file:
        sequence_text = r_file.read()
    seq_list = []
    if seq_file_type == 'fasta':
        # Anything before the first '>' is not part of a record.
        for record in sequence_text.split('>')[1:]:
            if not record.strip():
                continue
            # partition() keeps the whole header even without a trailing
            # newline (find() would return -1 and chop off a character).
            _header, _, body = record.partition('\n')
            seq_list.append(''.join(body.split()))
        # Yan: ignore fasta seq name/info and use seq_num instead to be consistent with other tool
    elif seq_file_type == 'peptides':
        seq_list = [line.strip() for line in sequence_text.split('\n') if line.strip()]
    else:
        raise ValueError('Invalid sequence file type: %s' % seq_file_type)
    return '<'.join(seq_list)

def input_data_prediction(options, args):
    """ Run the predictions requested on the command line or in a JSON file.

    Input is taken from the first of these options that is set:

    * ``options.json_filename`` -- a JSON job description.  The sequences
      come from ``input_sequence_text_file_path``, ``input_sequence_fasta_uri``,
      ``peptide_file_path``, ``input_sequence_text`` or ``peptide_list``
      (checked in that order), and every entry of ``predictors`` is run in
      turn.  The list of result tables built for the immunogenicity, mhciinp
      and binding predictors is returned.  The ``mhcnp`` predictor writes
      its own output and makes this function return ``None`` immediately.
    * ``options.filename_peptide`` -- a peptide file.
    * ``options.download_fasta_url`` -- a FASTA file fetched from a URI and
      converted to a peptide file.

    For the two non-JSON inputs the result is saved under ``output_prefix``
    (or printed) and ``None`` is returned.

    :param options: parsed command-line options; reads ``output_prefix``,
        ``output_format``, ``input_ic50_file``, ``json_filename``,
        ``filename_peptide``, ``download_fasta_url``, ``method``, ``allele``,
        ``length`` and ``assume_valid_flag``.
    :param args: positional arguments (unused here).
    :raises ValueError: for missing input, an unsupported sequence file type
        or predictor combination, a missing output prefix where one is
        required, or a method that needs JSON input.
    """

    from mhcii_predictor_data import get_method_allele_list

    from mhcii_predictor import MhciiPredictor
    from common_bio import Proteins
    # 1. read input params
    output_prefix = options.output_prefix
    output_format = options.output_format
    if output_format.lower() not in ['tsv', 'json']:
        eprint('The output format options are "tsv" or "json". Invalid format "%s" is given.' % output_format)
        return
    input_ic50_file = options.input_ic50_file
    output_json = ''
    if output_prefix:
        output_tsv = output_prefix+'.tsv'
        output_json = output_prefix+'.json'
    additional_result_info = {}
    warnings = []
    additional_result_info['warnings'] = warnings

    if options.json_filename:
        with open(options.json_filename, 'r') as r_file:
            input_data = json.load(r_file)
        peptide_length_range = input_data.get('peptide_length_range', None)
        if peptide_length_range:
            if peptide_length_range == "asis":
                maximum_length = minimum_length = 0
                lengths = []
            else:
                minimum_length, maximum_length = map(int, peptide_length_range)
                # comma-separated string of every length in the range
                lengths = ','.join(map(str, range(minimum_length, maximum_length+1)))
        else:
            # TODO: is this also required for mhcii?
            minimum_length, maximum_length = 8, 15
            lengths = ''

        if 'input_sequence_text_file_path' in input_data:
            fname = input_data['input_sequence_text_file_path']
            seq_file_type = 'fasta'
            # Proteins() below is built from 'input_sequence_text'; fill it in
            # from the file when it was not supplied directly.
            if 'input_sequence_text' not in input_data:
                with open(fname, 'r') as r_file:
                    input_data['input_sequence_text'] = r_file.read()
        elif 'input_sequence_fasta_uri' in input_data:
            fname = save_file_from_URI(input_data['input_sequence_fasta_uri'])
            seq_file_type = 'fasta'
            if 'input_sequence_text' not in input_data:
                with open(fname, 'r') as r_file:
                    input_data['input_sequence_text'] = r_file.read()
        elif 'peptide_file_path' in input_data:
            fname = input_data['peptide_file_path']
            seq_file_type = 'peptides'
            with open(fname, 'r') as r_file:
                input_data['input_sequence_text'] = r_file.read()
        elif 'input_sequence_text' in input_data:
            # TODO: temporary file created with delete=False is never removed.
            with tempfile.NamedTemporaryFile(mode='w', delete=False) as tmp_peptides_file:
                fname = tmp_peptides_file.name
                seq_file_type = 'fasta'
                tmp_peptides_file.write(input_data['input_sequence_text'])
        else:
            peptide_list = input_data.get('peptide_list')
            if peptide_list is None:
                raise ValueError('No input sequences found in the JSON input: one of "input_sequence_text_file_path", '
                                 '"input_sequence_fasta_uri", "peptide_file_path", "input_sequence_text" '
                                 'or "peptide_list" is required.')
            seq_file_type = 'peptides'

            if not options.assume_valid_flag and maximum_length:
                # Build a filtered copy: calling remove() on the list while
                # iterating over it skips the element after each removal.
                valid_peptides = []
                for peptide in peptide_list:
                    if minimum_length <= len(peptide) <= maximum_length:
                        valid_peptides.append(peptide)
                    else:
                        warnings.append('peptide "%s" length is out of valid range (%s)' % (peptide, '%d-%d' % (minimum_length,maximum_length)))
                peptide_list = valid_peptides
            peptide_text = '\n'.join(peptide_list)
            # TODO: temporary file created with delete=False is never removed.
            with tempfile.NamedTemporaryFile(mode='w', delete=False) as tmp_peptides_file:
                fname = tmp_peptides_file.name
                tmp_peptides_file.write(peptide_text)
            # Same text as the peptide file, so Proteins() below sees the same
            # input as the 'peptide_file_path' branch would.
            input_data['input_sequence_text'] = peptide_text

        input_allele = input_data.get('alleles')
        additional_result_info['warnings'] = warnings
        additional_result_info["results"] = []

        predictors = input_data.get('predictors')
        results = []
        for predictor in predictors:
            # add recommended as alias
            method = predictor.get('method', '')#.replace('recommended_epitope','netmhcpan_el').replace('recommended_binding','netmhcpan_ba')
            # TODO: update allele validation code
            #allele_warnings, allele_errors = validate_allele(input_allele, method)
            #if allele_errors:
            #    raise ValueError('\n'.join(allele_errors))
            #elif allele_warnings:
            #    additional_result_info.setdefault('warnings', []).extend(allele_warnings)
            predictor_type = predictor.get('type')
            # if not method or method==type: return type else: join
            predictor_source = '.'.join(filter(None, dict.fromkeys((predictor_type, method))))
            # NOTE(review): the `result` computed by the processing and mhcflurry
            #       branches below is never added to `results`; confirm intent.
            if predictor_type == 'processing' and method in ['netchop', 'netctl', 'netctlpan', 'basic_processing']:
                if method in ['netchop', 'netctl', 'netctlpan']:
                    predict_netchop = PredictNetChop()
                    netchop_options = dict(fasta_file=fname, allele=input_allele, peptide_length_range=peptide_length_range)
                    netchop_options.update(predictor)
                    netchop_options.pop("type")
                    netchop_options["noplot"] = True
                    result = predict_netchop.predict(NetchopOptions(**netchop_options), [])
                    if isinstance(result, tuple) and len(result) == 2:
                        result, plots = result
                        # add predictor_type info as well instead of only method?
                        plots_result = {
                            "type": "processing_plots",
                            "method": predictor_source,
                            "plots_data": plots,
                        }
                        additional_result_info['results'].append(plots_result)

                elif method == 'basic_processing':  # it would be base_processing
                    #if not input_ic50_file:
                    #    raise ValueError('Input ic50 file must be provided.')
                    result = processing_predict(predictor, fasta_file=fname, alleles=input_allele, peptide_length_range=peptide_length_range)
            elif predictor_type == 'immunogenicity':
                proteins=Proteins(input_data['input_sequence_text'])
                # peptide_shift is a synonym for peptide_spacing
                peptide_shift = predictor.get('peptide_shift', None)
                peptide_spacing = int(predictor.get('peptide_spacing', 5))
                if peptide_shift:
                    peptide_spacing = int(peptide_shift)
                length = lengths if lengths else 'asis'
                length_list = length if isinstance(length, list) else length.split(',')
                # NOTE(review): when peptide_length_range is "asis" or absent,
                #       length is 'asis' and int() raises here; confirm what
                #       imm_prediction expects in that case.
                length_list = [int(item.strip()) for item in length_list]

                # 'cd4episcore' is run as the 'immunogenicity' method; any
                # other method name is passed through unchanged (previously
                # imm_method was left unbound for those).
                imm_method = 'immunogenicity' if method == 'cd4episcore' else method
                imm_scores = imm_prediction(imm_method, proteins.sequences, proteins.names, length_list, peptide_spacing)
                table = PredictionTable()
                table_rows = table.format_immunogenicity(proteins, imm_scores['result'], imm_method)
                #table_rows = sort_table(form["pred_method"], table_rows)
                # the first row holds the column names
                table_columns = table_rows.pop(0)
                imm_result = {
                    "method": predictor_source,
                    "type": "peptide_table",
                    "table_columns": table_columns,
                    "table_data": table_rows
                }
                results.append(imm_result)
            elif predictor_type == 'binding' and method == 'mhcnp':
                from mhcnp_predicter import predict as mhcnp_predict
                if seq_file_type not in ['fasta','peptides']:
                    raise ValueError('can not accept seq_file_type: %s' % seq_file_type)
                if not output_prefix:
                    raise ValueError('Please use "-o" to specify output file prefix for method "%s"' % method)
                result = mhcnp_predict(input_allele=input_allele, lengths=lengths, fname=fname, seq_file_type=seq_file_type, output_path=output_json, **predictor)
                print('mhcnp prediction done.')
                return
            elif predictor_type == 'processing' and method == 'mhciinp':

                if seq_file_type not in ['fasta','peptides']:
                    raise ValueError('can not accept seq_file_type: %s' % seq_file_type)
                if not output_prefix:
                    raise ValueError('Please use "-o" to specify output file prefix for method "%s"' % method)
                # get seq_arg
                seq_arg = get_seq_arg_for_mhciinp(fname, seq_file_type)
                mhciinp_predict = MHCIILigandPredict()
                mhciinp_result_df = mhciinp_predict.predict_mhciiligands(seq_arg)[1]
                # transfer seq name ('seq_<n>') to an integer sequence_number
                mhciinp_result_df.rename(columns = {'Seq name':'sequence_number'}, inplace = True)
                mhciinp_result_df['sequence_number'] = mhciinp_result_df['sequence_number'].str.replace('seq_', '').astype(int)
                # transform the result to column names + a list of row lists
                table_columns = mhciinp_result_df.columns.tolist()
                table_columns = [column_name.lower().strip().replace(' ', '_').replace('peptide_start', 'start').replace('peptide_end', 'end') for column_name in table_columns]

                table_rows = mhciinp_result_df.values.tolist()
                result = {
                    "method": predictor_source,
                    "type": "peptide_table",
                    "table_columns": table_columns,
                    "table_data": table_rows
                }
                print('mhciinp prediction done.')
                results.append(result)
            elif predictor_type == 'binding' and method == 'mhcflurry':
                from mhcflurry_predicter import predict as mhcflurry_predict
                if seq_file_type not in ['fasta','peptides']:
                    raise ValueError('can not accept seq_file_type: %s' % seq_file_type)
                # mhcflurry_predict is given no lengths, so it presumably needs a
                # peptide file: FASTA input is converted first.  A local name keeps
                # fname/seq_file_type untouched for later predictors.
                mhcflurry_fname = fname
                if seq_file_type == 'fasta':
                    mhcflurry_fname = transfer_fasta_to_peptide_file(fname, lengths)
                #if not output_prefix:
                #    raise ValueError('Please use "-o" to specify output file prefix for method "%s"' % method)
                result = mhcflurry_predict(input_allele=input_allele, fname=mhcflurry_fname)
            else:

                proteins=Proteins(input_data['input_sequence_text'])

                # maps public method names to MhciiPredictor method names
                method_dict = {
                    "consensus3":'consensus',
                    "IEDB_recommended": 'netmhciipan_el',
                    "IEDB_recommended_epitope": 'netmhciipan_el',
                    "IEDB_recommended_binding": 'netmhciipan_ba',
                    "NetMHCIIpan":'NetMHCIIpan',
                    "nn_align":"nn_align",
                    "nn_align-2.2":"nn_align-2.2",
                    "comblib":"comblib",
                    "smm_align":"smm_align",
                    "tepitope":'tepitope',
                    "sturniolo":'tepitope',
                    "netmhciipan_el":'netmhciipan_el',
                    "netmhciipan_ba":'netmhciipan_ba',
                    "netmhciipan_el-4.2":'netmhciipan_el-4.2',
                    "netmhciipan_ba-4.2":'netmhciipan_ba-4.2',
                    "netmhciipan_el-4.3":'netmhciipan_el-4.3',
                    "netmhciipan_ba-4.3":'netmhciipan_ba-4.3',
                    "mhciinp":"mhciinp",
                }
                method = method_dict[method]
                alleles = input_allele.replace('HLA-','').replace("/","-")
                allele_list = alleles.split(',')
                length = lengths if lengths else 'asis'
                length_list = length if isinstance(length, list) else length.split(',')
                length_list = [item.strip() for item in length_list]

                allele_length_combo_list = get_allele_length_combo_list(allele_list, length_list)
                allele_list,length_list = zip(*allele_length_combo_list)

                pre=MhciiPredictor(method, allele_list, length_list)

                mhc_scores = pre.predict(proteins.sequences)

                table = PredictionTable()
                table_rows = table.format_binding(proteins, mhc_scores, method)
                score_unit = "ic50"
                if method.startswith('netmhciipan_el') or method in ("tepitope", "comblib"):
                    score_unit = "score"
                table_columns = [
                        "allele",
                        "seq_num",
                        "start",
                        "end",
                        "peptide_length",
                        "core",
                        "peptide",
                        score_unit,
                        "percentile"
                ]
                # rows with a 10th value carry an adjusted percentile
                if table_rows and len(table_rows[0]) == 10:
                    table_columns.append("adjusted_percentile")
                binding_result = {
                    "method": predictor_source,
                    "type": "peptide_table",
                    "table_columns": table_columns,
                    "table_data": table_rows
                }
                results.append(binding_result)
        return results
    elif options.filename_peptide:
        fname = options.filename_peptide
        seq_file_type = 'peptides'
        method = options.method
        input_allele = options.allele
        input_length = options.length

    elif options.download_fasta_url:
        input_sequence_text_file_path = save_file_from_URI(options.download_fasta_url)
        seq_file_type = 'peptides'
        method = options.method
        input_allele = options.allele
        input_length = options.length
        fname = transfer_fasta_to_peptide_file(input_sequence_text_file_path, input_length)
    else:
        # Without any input option, `method` etc. would be unbound below.
        raise ValueError('No input given: one of json_filename, filename_peptide or download_fasta_url must be set.')

    # 2 validation
    # these method only works with JSON input for time being
    if method in ["basic_processing", "netchop", "netctl", "netctlpan", "immunogenicity", "mhcnp", "mhcflurry"]:
        raise ValueError('The method %s requires JSON input, Please specify the path of the JSON input file with the -j option. For example: \n * python3 src/tcell_mhci.py  -j [input_json_file] -f json -o [output-prefix]' % method)

    # input validation
    if not options.assume_valid_flag:
        errors = input_validation(method, input_allele, input_length, fname, seq_file_type)
        if errors:
            eprint('validation error: %s' % errors)
            return
    allele_warnings, allele_errors = validate_allele(input_allele, method)
    if allele_errors:
        raise ValueError('\n'.join(allele_errors))
    elif allele_warnings:
        additional_result_info.setdefault('warnings', []).extend(allele_warnings)

    # 3. predict
    # NOTE(review): MHCIPredictor is only imported locally inside
    #       input_validation; confirm it is in scope here.
    predictor = MHCIPredictor(method)
    result = predictor.predict(input_allele, input_length, fname, seq_file_type)

    # 4. output
    if isinstance(result, tuple) and len(result) == 2:
        result, distances = result
        additional_result_info['allele_distances'] = distances
    if output_prefix:
        if output_format.lower()=='tsv':
            save_tsv(result, output_tsv)
        elif output_format.lower()=='json':
            save_json(result, output_json)
        else:
            raise ValueError('invalid output format: %s' % output_format)
        # NOTE(review): for json output this overwrites the file written just
        #       above with additional_result_info; confirm intent.
        save_json(additional_result_info, output_json)
    else:
        print_result(result)
    

def input_validation(method, input_allele, input_length, fname, seq_file_type='peptide'):
    '''Validate the input for ``method`` with the predictor's own checks.

    :param method: prediction method name passed to ``MHCIPredictor``.
    :param input_allele: allele(s) to validate.
    :param input_length: peptide length(s) to validate.
    :param fname: path of the input sequence/peptide file.
    :param seq_file_type: type of the input file; the caller in
        input_data_prediction passes it explicitly.
        NOTE(review): the default ``'peptide'`` (singular) differs from the
        ``'peptides'`` spelling used elsewhere in this file; confirm which
        one ``is_valid`` expects.
    :returns: the value of ``predictor.is_valid(...)``, which the caller
        treats as the validation errors (falsy means valid), or ``[]`` when
        the method has no underlying predictor.
    '''
    # TODO: determine which validations apply to all methods and which
    #       are specific.  E.g., allele, method, length validation is general
    #       and can be done before any method-specific validation using the
    #       allele-validator package
    from mhcipredictor import MHCIPredictor
    predictor = MHCIPredictor(method)
    if predictor.predictor:
        return predictor.is_valid(input_allele, input_length, fname, seq_file_type)
    else:
        return []
