
import os
import sys
import shutil
import json
import tempfile
import logging
from subprocess import Popen, PIPE
import pandas as pd
from path_config import NETMHCPANPATH, PEPXDBPATH

# Export the tool locations from path_config so the icerfire prediction can
# read them from the process environment.
os.environ.update({'NETMHCPANPATH': NETMHCPANPATH, 'PEPXDBPATH': PEPXDBPATH})

# Make the nxg-tools package importable: prefer the directory named by the
# NXG_TOOLS_PATH environment variable, otherwise fall back to the copy under
# ../method/nxg-tools relative to this file (when that directory exists).
NXG_TOOLS_PATH = os.environ.get('NXG_TOOLS_PATH')
_bundled_nxg_tools = os.path.join(os.path.dirname(__file__), '..', 'method', 'nxg-tools')
if NXG_TOOLS_PATH and os.path.isdir(NXG_TOOLS_PATH):
    logging.debug('load NXG_TOOLS_PATH: %s' % NXG_TOOLS_PATH)
    sys.path.append(NXG_TOOLS_PATH)
elif os.path.isdir(_bundled_nxg_tools):
    sys.path.append(_bundled_nxg_tools)
# from nxg-tools package
from nxg_common.nxg_common import save_file_from_URI

# Column whitelist applied by tsv_to_dict; an empty list keeps every column.
SELECTED_COLUMNS = []

def tsv_to_dict(tsv_file, sep='\t'):
    """Load a delimited result table and return it as a 'split'-oriented dict.

    Column names are normalised first: ``<N>mer_ref_peptides`` becomes
    ``reference_peptide``, ``<N>mer_mut_peptides`` becomes ``mutant_peptide``,
    ``feature_id`` becomes ``transcript_id`` and ``SerialNumber`` becomes
    ``peptide_pair_id``. When ``SELECTED_COLUMNS`` is non-empty only the
    columns it names are kept. Missing values are replaced with ``'-'``.

    Args:
        tsv_file: Path (or file-like object) handed to ``pandas.read_csv``.
        sep: Field delimiter of the input table.

    Returns:
        The dict produced by ``DataFrame.to_dict(orient='split')``, i.e. with
        the keys ``index``, ``columns`` and ``data``.
    """
    table = pd.read_csv(tsv_file, sep=sep, low_memory=False)

    # (pattern, replacement, extra keyword arguments) applied in this order.
    header_rewrites = (
        (r'\d+mer_ref_peptides', 'reference_peptide', {'regex': True}),
        (r'\d+mer_mut_peptides', 'mutant_peptide', {'regex': True}),
        ('feature_id', 'transcript_id', {}),
        ('SerialNumber', 'peptide_pair_id', {}),
    )
    header = table.columns
    for pattern, replacement, options in header_rewrites:
        header = header.str.replace(pattern, replacement, **options)
    table.columns = header

    # Optional column whitelist.
    if SELECTED_COLUMNS:
        wanted = table.columns.isin(SELECTED_COLUMNS)
        table = table.loc[:, wanted]

    # NA/NaT/NaN cells are shown as '-' in the exported table.
    return table.fillna('-').to_dict(orient='split')

def save_json(result, output_path):
    """Serialise ``result`` as indented JSON at ``output_path``.

    The parent directory is created when the path has one and it does not
    exist yet.

    Args:
        result: JSON-serialisable object to write.
        output_path: Destination file path.

    Returns:
        The absolute form of ``output_path``.
    """
    parent_dir = os.path.dirname(output_path)
    # A bare file name has no directory component, so there is nothing to create.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(output_path, 'w') as handle:
        json.dump(result, handle, indent=2)

    absolute_path = os.path.abspath(output_path)
    return absolute_path

def _run_and_log(cmd):
    """Run ``cmd`` (an argument list), log its stdout at debug level and return the exit code."""
    logging.debug(' '.join(cmd))
    process = Popen(cmd, stdout=PIPE)
    stdoutdata, _ = process.communicate()
    logging.debug('Raw output:\n{}'.format(stdoutdata.decode()))
    if process.returncode != 0:
        # Surface failures instead of silently continuing; the caller still
        # fails later if an expected output file is missing.
        logging.warning('command exited with status %s: %s', process.returncode, ' '.join(cmd))
    return process.returncode


def run_all(json_filename, output_prefix, output_format, assume_valid_flag):
    """Run the prediction described by a JSON parameter file.

    When the first predictor's method is ``icerfire`` the work is delegated to
    ``run_icerfire``. Otherwise ``run_pvc.py`` (next to this file) is invoked
    with ``--split``, every job listed in the resulting
    ``job_descriptions.json`` is executed in order, and the first expected
    output of the last job is copied to ``<output_prefix>.json``.

    Args:
        json_filename: Path of the JSON parameter file.
        output_prefix: Output path without extension; must be a non-blank string.
        output_format: Forwarded to ``run_icerfire``; the non-icerfire path
            always writes ``<output_prefix>.json``.
        assume_valid_flag: Forwarded to ``run_icerfire``; otherwise unused.

    Raises:
        ValueError: If ``output_prefix`` is None, not a string, or blank.
        RuntimeError: If the split step produced no jobs.
    """
    if output_prefix is None:
        raise ValueError("output_prefix is required and cannot be None")
    if not isinstance(output_prefix, str) or not output_prefix.strip():
        raise ValueError("output_prefix cannot be an empty string")
    print('run all')
    with open(json_filename, 'r') as r_file:
        params = json.load(r_file)
    if params['predictors'][0]['method'] == 'icerfire':
        run_icerfire(params, output_prefix, output_format, assume_valid_flag)
        return

    pvc_executable_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'run_pvc.py'))
    output_dir = os.path.abspath(os.path.dirname(output_prefix))
    os.makedirs(output_dir, exist_ok=True)

    job_descriptions_path = None
    try:
        with tempfile.TemporaryDirectory() as tmp_dir:
            # The job list is read from the PARENT of the split directory.
            # NOTE(review): presumably that is where run_pvc.py --split writes
            # it; the location is shared between concurrent runs - confirm.
            job_descriptions_path = os.path.abspath(os.path.join(tmp_dir, '..', 'job_descriptions.json'))
            # Drop a stale job list so only this run's split output is read.
            if os.path.exists(job_descriptions_path):
                os.remove(job_descriptions_path)

            # Run the split step.
            split_cmd = [pvc_executable_path, '-j', json_filename, '--split', '--split-dir=%s' % tmp_dir]
            _run_and_log(split_cmd)

            with open(job_descriptions_path, 'r') as r_file:
                job_descriptions = json.load(r_file)
            if not job_descriptions:
                raise RuntimeError("split step produced no jobs in %s" % job_descriptions_path)

            result_output_path = None
            for job in job_descriptions:
                cmd = job['shell_cmd'].split()
                result_output_path = job['expected_outputs']
                _run_and_log(cmd)

            # Copy the final result (first output of the last job) to the
            # expected path while the temporary directory still exists.
            shutil.copy2(result_output_path[0], f"{output_prefix}.json")
    finally:
        # Remove the job list even when a step above failed.
        if job_descriptions_path and os.path.exists(job_descriptions_path):
            os.remove(job_descriptions_path)
    print("prediction done")

def _parse_peptide_pairs(input_sequence_text):
    """Return unique, upper-cased CSV rows whose first two fields are 8-14 characters long."""
    unique_rows = {}
    # splitlines() copes with both LF and CRLF input.
    for line in input_sequence_text.upper().splitlines():
        fields = tuple(field.strip() for field in line.split(','))
        # Blank lines and rows without a second peptide cannot form a pair.
        if len(fields) < 2:
            continue
        if 8 <= len(fields[0]) <= 14 and 8 <= len(fields[1]) <= 14:
            # dict keys keep first-seen order while dropping duplicates.
            unique_rows.setdefault(fields, None)
    return [list(fields) for fields in unique_rows]


def run_icerfire(params, output_prefix, output_format='json', assume_valid_flag=True):
    """Run the ICERFIRE prediction and write ``<output_prefix>.tsv`` or ``.json``.

    The peptide-pair text is taken from ``params['input_sequence_text']``;
    when that is empty it is downloaded from the URI derived from
    ``params['input_sequence_fasta_uri']``, or read from
    ``params['input_sequence_text_file_path']``. Rows are upper-cased,
    de-duplicated and kept only when both of their first two comma-separated
    fields are 8 to 14 characters long. The filtered text is stored back into
    ``params['input_sequence_text']`` before ``icerfire_prediction(params)``
    is called.

    Args:
        params: Prediction parameters; mutated in place (URI / file-path keys
            are popped and ``input_sequence_text`` is replaced).
        output_prefix: Output path without extension.
        output_format: ``'tsv'`` moves the raw result file to
            ``<output_prefix>.tsv``; any other value writes a JSON document.
        assume_valid_flag: Accepted for interface compatibility; not used here.

    Raises:
        ValueError: If no row passes the peptide length filter.
    """
    # Make the icerfire package importable: environment override first, then
    # the copy under ../method relative to this file.
    icerfire_path = os.environ.get('ICERFIREPATH')
    bundled_path = os.path.join(os.path.dirname(__file__), '..', 'method', 'icerfire-1.0-executable')
    if icerfire_path and os.path.isdir(icerfire_path):
        search_path = icerfire_path
    elif os.path.isdir(bundled_path):
        search_path = bundled_path
    else:
        search_path = None
    # Avoid growing sys.path by one entry per call.
    if search_path and search_path not in sys.path:
        sys.path.append(search_path)

    to_delete = []
    output_dir = os.path.abspath(os.path.dirname(output_prefix))
    try:
        input_sequence_fasta_uri = params.pop('input_sequence_fasta_uri', '') or ''
        input_sequence_text_uri = input_sequence_fasta_uri.replace('sequence_list_fasta', 'download_sequences')
        input_sequence_text = params.get('input_sequence_text', '') or ''
        if not input_sequence_text and input_sequence_text_uri:
            # The download target must exist before the file is saved into it.
            os.makedirs(output_dir, exist_ok=True)
            input_sequence_text_file_path = save_file_from_URI(input_sequence_text_uri, target_dir=output_dir)
            to_delete.append(input_sequence_text_file_path)
            with open(input_sequence_text_file_path, 'r') as r_file:
                input_sequence_text = r_file.read()
        elif 'input_sequence_text_file_path' in params:
            input_sequence_text_file_path = params.pop('input_sequence_text_file_path')
            with open(input_sequence_text_file_path, 'r') as r_file:
                input_sequence_text = r_file.read()

        # Keep only unique rows whose two peptides are 8-14 residues long.
        peptides_list = _parse_peptide_pairs(input_sequence_text)
        if not peptides_list:
            raise ValueError("No valid peptides found in input_sequence_text. Peptides must be between 8 and 14 amino acids long.")
        params['input_sequence_text'] = '\n'.join(','.join(peptide) for peptide in peptides_list)

        # Imported lazily because the package only becomes importable after
        # the sys.path adjustment above.
        from icerfire_1_0_executable import icerfire_prediction
        result_tsv_file = icerfire_prediction(params)

        if output_format == 'tsv':
            os.makedirs(output_dir, exist_ok=True)
            # shutil.move also works when source and target are on different
            # filesystems, where os.rename fails.
            shutil.move(result_tsv_file, f'{output_prefix}.tsv')
        else:
            dict_result = tsv_to_dict(result_tsv_file, sep=',')
            del dict_result['index']
            dict_result['result_type'] = 'peptide_table'
            dict_result['table_columns'] = dict_result.pop('columns')
            # NOTE(review): rows can mix numbers with '-' placeholders; sorted()
            # raises TypeError if two rows tie up to such a cell - confirm
            # this cannot happen for real result tables.
            dict_result['table_data'] = sorted(dict_result.pop('data'))
            output_result = {
                "results": [dict_result, ],
                "warnings": []
            }
            save_json(output_result, f'{output_prefix}.json')
    finally:
        # Remove downloaded inputs; best effort so cleanup never masks the
        # real error.
        for path in to_delete:
            try:
                os.remove(path)
            except OSError:
                logging.debug('could not remove temporary file: %s', path)