# Here, the code for the following logic should be implemented.
# * Take the input JSON file and parse the sequences. Depending on the tool, split the sequences if needed, and store them under
#     the 'preprocess_job/input_units' folder.
# * Take the rest of the parameters from the input JSON file and split them into atomic job units.
#     * Make sure each atomic job unit has a key/value pair pointing to the input sequence files under 'preprocess_job/input_units'.
#     * Each job unit should be stored under 'preprocess_job/parameter_units'.
# * Lastly, it should create the 'job_descriptions.json' file under 'preprocess_job/'.
#     * This file will have a list of descriptions, one for each job unit.
#     * Each description will contain a command that runs a single prediction (using the 'predict' subcommand).
#     * Note that the last command in the description file will use the 'postprocess' subcommand.
import json
import os, sys
from tempfile import tempdir
import pandas as pd
import core.set_pythonpath  # This automatically configures PYTHONPATH
import utils
import shutil
from validators import InputManager
from typing import List, Dict, Any, Union
from pathlib import Path
from run_phbr import run_prediction
from validators import MHCClass


def save_json(result: Dict[str, Any], output_path: Union[str, Path]) -> None:
    """Write ``result`` to ``output_path`` as JSON indented by two spaces.

    Any missing parent directories of ``output_path`` are created first. When
    ``output_path`` has no directory component, nothing is created and the file
    is opened as given.
    """
    parent_folder = os.path.dirname(output_path)
    # An empty dirname means there is no directory part to create.
    if parent_folder:
        os.makedirs(parent_folder, exist_ok=True)

    with open(output_path, mode='w') as sink:
        json.dump(result, sink, indent=2)


def update_job_descriptions_with_new_paths(jobs: List[Dict[str, Any]], mhc_dir: Path) -> List[Dict[str, Any]]:
    """Re-point the paths in every job description at the layout under ``mhc_dir``.

    For each job (the jobs are modified in place):
      * the token following ``-j`` in ``shell_cmd`` becomes
        ``<mhc_dir>/predict-inputs/params/<job_id>.json``;
      * the token following ``-o`` becomes ``<mhc_dir>/predict-outputs/<job_id>``;
      * ``expected_outputs``, when present, becomes
        ``[<mhc_dir>/predict-outputs/<job_id>.json]``.

    The last job is then treated as the aggregate job: its
    ``--aggregate-input-dir``, ``--aggregate-result-dir`` and ``--job-desc-file``
    tokens are replaced with ``<flag>=<new path>`` and its ``expected_outputs``
    is set to ``[<mhc_dir>/aggregate/aggregated_result.json]``.

    Finally the fully updated list is written to
    ``<mhc_dir>/job_descriptions.json`` so that the file on disk matches the
    returned list (including the aggregate job).

    Args:
        jobs: Job descriptions, each with at least ``job_id`` and ``shell_cmd``.
            An empty list is accepted and written out unchanged.
        mhc_dir: Existing directory that holds the restructured files.

    Returns:
        The same ``jobs`` list, updated in place.

    NOTE: ``shell_cmd`` is tokenized on single spaces, so paths containing
    spaces are not supported.
    """
    # Set paths
    param_dir = mhc_dir / 'predict-inputs' / 'params'
    predict_output_dir = mhc_dir / 'predict-outputs'
    result_dir = mhc_dir / 'results'
    aggregate_dir = mhc_dir / 'aggregate'
    local_jd_file = mhc_dir / 'job_descriptions.json'

    for job in jobs:
        job_id = job['job_id']
        job_cmd = job['shell_cmd'].split(' ')

        # Replace the value that follows each short option, when the option is present.
        new_values = (
            ('-j', str(param_dir / f'{job_id}.json')),
            ('-o', str(predict_output_dir / f'{job_id}')),
        )
        for flag, new_value in new_values:
            if flag in job_cmd:
                job_cmd[job_cmd.index(flag) + 1] = new_value

        # Update expected outputs
        if 'expected_outputs' in job:
            job['expected_outputs'] = [str(predict_output_dir / f'{job_id}.json')]

        job['shell_cmd'] = ' '.join(job_cmd)

    # Handle the aggregate job (by convention the last job in the list).
    if jobs:
        aggregate_job = jobs[-1]
        aggregate_cmd = aggregate_job['shell_cmd'].split(' ')

        long_options = (
            ('--aggregate-input-dir', result_dir),
            ('--aggregate-result-dir', aggregate_dir),
            ('--job-desc-file', local_jd_file),
        )
        for flag, new_path in long_options:
            # The first token containing the flag is replaced wholesale.
            flag_index = next((i for i, arg in enumerate(aggregate_cmd) if flag in arg), None)
            if flag_index is not None:
                aggregate_cmd[flag_index] = f'{flag}={new_path}'

        aggregate_job['shell_cmd'] = ' '.join(aggregate_cmd)
        # The per-job loop above pointed this at predict-outputs; the aggregate
        # job actually produces a single aggregated result file.
        aggregate_job['expected_outputs'] = [str(aggregate_dir / 'aggregated_result.json')]

    # Persist only after the aggregate job has been rewritten; otherwise the
    # file would still carry the aggregate job's stale paths.
    with open(local_jd_file, 'w') as f:
        json.dump(jobs, f, indent=2)

    return jobs


def reset_job_ids(jobs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Set ``job_id`` to 0 on every job in ``jobs`` (in place) and return the list.

    NOTE(review): all jobs end up sharing the same ID. If sequential
    renumbering was intended, this function is incomplete -- confirm with
    the callers before relying on it.
    """
    for entry in jobs:
        entry['job_id'] = 0

    return jobs

def restructure_folder_structure_post_prediction(output_dir: Path) -> None:
    """Rearrange the files in and next to ``output_dir`` into a fixed layout.

    After the call ``output_dir`` contains:
      * ``predict-inputs/data``   -- every entry of ``output_dir`` whose name starts with ``tmp``
      * ``predict-inputs/params`` -- every ``<digits>.json`` file of ``output_dir``, with its
        ``peptide_file_path`` rewritten to point into ``predict-inputs/data``
      * ``job_descriptions.json`` -- moved from ``output_dir.parent``
      * ``predict-outputs`` and ``aggregate`` -- newly created, empty
      * ``results``               -- moved from ``output_dir.parent``

    Raises:
        FileExistsError: if any of the directories to be created already exists
            (they are created with ``exist_ok=False``).
    """
    print('Restructuring folder structure post prediction...')
    print('output_dir: ', output_dir)

    inputs_root = output_dir / 'predict-inputs'
    data_path = inputs_root / 'data'
    params_path = inputs_root / 'params'

    # Both input sub-directories must be new; an existing one is an error.
    for new_dir in (data_path, params_path):
        new_dir.mkdir(parents=True, exist_ok=False)

    # Entries named 'tmp*' go to the data directory.
    for tmp_entry in output_dir.glob('tmp*'):
        shutil.move(str(tmp_entry), str(data_path))

    # Parameter files are the JSON files whose stem is purely numeric,
    # e.g. 1.json, 2.json, ... 10.json, 11.json.
    for candidate in output_dir.glob('*.json'):
        if not (candidate.stem.isdigit() and candidate.suffix == '.json'):
            continue
        shutil.move(str(candidate), str(params_path))

    # Each moved parameter file must reference the peptide file at its new location.
    for param_file in params_path.glob('*.json'):
        with open(param_file, 'r') as reader:
            params = json.load(reader)

        peptide_file_name = params['peptide_file_path'].split('/')[-1]
        params['peptide_file_path'] = str(data_path / peptide_file_name)

        with open(param_file, 'w') as writer:
            json.dump(params, writer, indent=2)

    # The job description file is produced one level up; bring it into output_dir.
    shutil.move(
        str(output_dir.parent / 'job_descriptions.json'),
        str(output_dir / 'job_descriptions.json'),
    )

    for new_dir in (output_dir / 'predict-outputs', output_dir / 'aggregate'):
        new_dir.mkdir(parents=True, exist_ok=False)

    # The 'results' directory also lives one level up; move it under output_dir.
    shutil.move(str(output_dir.parent / 'results'), str(output_dir))

    print('Folder structure restructured successfully!')


def create_and_add_phbr_job(kwargs: dict, job_description: List[Dict[str, Any]], mhc_class: MHCClass) -> List[Dict[str, Any]]:
    """Append the PHBR ``predict`` job for ``mhc_class`` to ``job_description``.

    The new job is numbered one past the last job in ``job_description`` and
    depends on that last job. When ``job_description`` is empty (no binding
    jobs were needed), the new job gets ID 0 and no dependencies instead of
    failing with ``IndexError``.

    Args:
        kwargs: Run arguments. ``kwargs['input_json']`` must be a seekable,
            readable file object holding the input JSON, and
            ``kwargs['output_dir']`` a ``Path``.
        job_description: Jobs created so far for this MHC class; modified in place.
        mhc_class: Which class the PHBR job is for. For ``MHCClass.MHCI`` the
            ``class_ii`` section of the input is dropped, for ``MHCClass.MHCII``
            the ``class_i`` section.

    Returns:
        ``job_description`` with the new job appended.

    Environment:
        ``APP_ROOT`` is used as the prefix of the ``run_phbr.py`` path and
        ``APP_NAME`` (default ``'app'``) as the sub-directory of the outputs.
    """
    APP_ROOT = os.getenv('APP_ROOT')
    APP_NAME = os.getenv('APP_NAME', 'app')
    output_dir = kwargs['output_dir']
    result_dir_phbr = output_dir / APP_NAME / 'predict-outputs'

    if job_description:
        last_job_id = job_description[-1]['job_id']
        new_job_id = last_job_id + 1
        depends_on_job_ids = [last_job_id]
    else:
        # No preceding binding jobs: this is the first job and it waits on nothing.
        new_job_id = 0
        depends_on_job_ids = []

    # Depending on the 'mhc_class', only the relevant parameters are kept and
    # saved as a separate input json file. For example, when
    # 'mhc-combined-binding-input.json' is passed with MHCI, the 'class_ii'
    # parameters are removed; when called with MHCII, the 'class_i' parameters
    # are removed. The resulting file is what the job's 'shell_cmd' receives.
    # The handle may already have been read by an earlier call, so rewind it.
    kwargs['input_json'].seek(0)
    mhc_spec_input = json.load(kwargs['input_json'])
    if mhc_class == MHCClass.MHCI:
        mhc_spec_input.pop('class_ii', None)

    if mhc_class == MHCClass.MHCII:
        mhc_spec_input.pop('class_i', None)

    # Save the mhc_spec_input to a temporary file
    temp_file = utils.save_json_to_temporary_file(mhc_spec_input, output_dir)

    output_prefix = f'{result_dir_phbr.resolve()}/{new_job_id}'
    job = {
        "shell_cmd": f"{APP_ROOT}/src/run_phbr.py predict -j {temp_file} -d {output_dir} -o {output_prefix} -f json",
        'job_id': new_job_id,
        'job_type': 'predict',
        'depends_on_job_ids': depends_on_job_ids,
        'expected_outputs': [
            f'{output_prefix}.json'
        ]
    }

    job_description.append(job)
    return job_description


def create_phbr_job(kwargs: dict, prev_job_id: int, mhc_class: MHCClass) -> Dict[str, Any]:
    """Build (without appending anywhere) the PHBR ``predict`` job for ``mhc_class``.

    Args:
        kwargs: Run arguments. ``kwargs['input_json']`` must be a seekable,
            readable file object holding the input JSON, and
            ``kwargs['output_dir']`` a ``Path``.
        prev_job_id: ID of the job this one follows. A non-negative value makes
            the new job ``prev_job_id + 1`` and dependent on ``prev_job_id``.
            A negative value means no binding jobs precede it, so the new job
            gets ID 0 and no dependencies.
        mhc_class: For ``MHCClass.MHCI`` the ``class_ii`` section of the input
            is dropped, for ``MHCClass.MHCII`` the ``class_i`` section.

    Returns:
        The job description dict. Its ``job_id``, ``depends_on_job_ids``,
        ``-o`` output prefix and ``expected_outputs`` all use the same job ID.

    Environment:
        ``APP_ROOT`` is used as the prefix of the ``run_phbr.py`` path and
        ``APP_NAME`` (default ``'app'``) as the sub-directory of the outputs.
    """
    APP_ROOT = os.getenv('APP_ROOT')
    APP_NAME = os.getenv('APP_NAME', 'app')

    job_type = 'predict'

    if prev_job_id >= 0:
        curr_job_id = prev_job_id + 1
        depends_on_job_ids = [prev_job_id]
    else:
        # Negative job ID means that the input file does not require binding jobs.
        # Therefore, the current job ID should be 0.
        # And the dependent job IDs should be empty.
        curr_job_id = 0
        depends_on_job_ids = []

    output_dir = kwargs['output_dir']
    result_dir_phbr = output_dir / APP_NAME / 'predict-outputs'

    # Depending on the 'mhc_class', only the relevant parameters are kept and
    # saved as a separate input json file. For example, when
    # 'mhc-combined-binding-input.json' is passed with MHCI, the 'class_ii'
    # parameters are removed; when called with MHCII, the 'class_i' parameters
    # are removed. The resulting file is what the job's 'shell_cmd' receives.
    # The handle may already have been read by an earlier call, so rewind it.
    kwargs['input_json'].seek(0)
    mhc_spec_input = json.load(kwargs['input_json'])
    if mhc_class == MHCClass.MHCI:
        mhc_spec_input.pop('class_ii', None)

    if mhc_class == MHCClass.MHCII:
        mhc_spec_input.pop('class_i', None)

    # Save the mhc_spec_input to a temporary file
    temp_file = utils.save_json_to_temporary_file(mhc_spec_input, output_dir)

    # One prefix feeds both '-o' and 'expected_outputs' so they cannot drift apart.
    output_prefix = f'{result_dir_phbr.resolve()}/{curr_job_id}'
    job = {
        "shell_cmd": f"{APP_ROOT}/src/run_phbr.py predict -j {temp_file} -d {output_dir} -o {output_prefix} -f json",
        'job_id': curr_job_id,
        'job_type': job_type,
        'depends_on_job_ids': depends_on_job_ids,
        'expected_outputs': [
            f'{output_prefix}.json'
        ]
    }

    return job


def create_mhc_job(data: dict, input_manager: InputManager, mhc_class: MHCClass) -> List[Dict[str, Any]]:
    """Create the MHC binding jobs for ``mhc_class``, if the input needs any.

    When the input category is ``BINDING_RESULT_URI`` or
    ``PEPTIDE_SEQUENCE_TABLE`` no binding run is required and an empty list is
    returned. Otherwise ``run_prediction`` is called, the job description file
    it reports is loaded, and -- for ``MHCClass.MHCI`` / ``MHCClass.MHCII`` --
    the files are reorganized under ``<output_dir>/mhci`` / ``<output_dir>/mhcii``
    and the job paths are updated to match.

    Args:
        data: Input JSON merged with a ``metadata`` dict that contains at least
            ``output_dir`` (a ``Path``). ``metadata['output_prefix']`` and
            ``metadata['output_format']`` are overwritten in place when binding
            is required.
        input_manager: Provides ``category.name`` for the input.
        mhc_class: The MHC class to create binding jobs for.

    Returns:
        The list of binding job descriptions (possibly empty).
    """
    output_dir = data['metadata']['output_dir']

    # NOTE: For these categories, we can directly use the input file as is (No binding required)
    # - BINDING_RESULT_URI: The input file is already a binding result URI.
    # - PEPTIDE_SEQUENCE_TABLE: The input file is already a peptide sequence table.
    if input_manager.category.name in ("BINDING_RESULT_URI", "PEPTIDE_SEQUENCE_TABLE"):
        return []

    # Requires MHC binding to be run first
    # To run prediction, we will pass in some dummy data to satisfy the function signature
    data['metadata']['output_prefix'] = 'dummy'
    data['metadata']['output_format'] = 'json'
    job_description_path = run_prediction(data, input_manager, mhc_class)

    with open(job_description_path, 'r') as f:
        jobs = json.load(f)

    # Organize the files and folders after prediction is done
    if mhc_class == MHCClass.MHCI:
        class_dir = output_dir / 'mhci'
    elif mhc_class == MHCClass.MHCII:
        class_dir = output_dir / 'mhcii'
    else:
        # Any other class: hand back the jobs exactly as loaded.
        return jobs

    restructure_folder_structure_post_prediction(class_dir)
    return update_job_descriptions_with_new_paths(jobs, class_dir)


def run(**kwargs):
    """Build and save the master ``job_descriptions.json`` for a PHBR run.

    For each requested MHC class (MHCI first, then MHCII) the binding jobs are
    created via ``create_mhc_job`` and followed by one PHBR ``predict`` job via
    ``create_and_add_phbr_job``. Job IDs continue across classes. A final
    ``postprocess`` job depending on every PHBR predict job is appended, and the
    whole list is written to ``<output_dir>/job_descriptions.json``.

    Keyword Args:
        input_json: Readable file object with the input JSON.
        output_dir: ``Path`` of the directory to write into.
        Any other keyword argument is passed along unchanged as metadata.

    Raises:
        ValueError: if the input requests neither MHCI nor MHCII, since there
            would be no prediction for the postprocess job to consume.

    Environment:
        ``APP_ROOT`` is used as the prefix of the ``run_phbr.py`` path and
        ``APP_NAME`` (default ``'app'``) as the sub-directory of the outputs.
    """
    APP_NAME = os.getenv('APP_NAME', 'app')
    APP_ROOT = os.getenv('APP_ROOT')
    print('--------------------------------')
    print(kwargs)

    data = json.load(kwargs['input_json'])
    input_manager = InputManager(data)
    # Merge all the inputs into one json obj
    data = {**data, 'metadata': {**kwargs}}
    print('data_with_metadata: \n', data)

    output_dir = data['metadata']['output_dir']
    job_description_path = output_dir / 'job_descriptions.json'
    result_dir_phbr = output_dir / APP_NAME / 'predict-outputs'
    result_dir_phbr_postprocess = output_dir / APP_NAME / 'results'

    requested_classes = []
    if input_manager.has_mhci:
        requested_classes.append(MHCClass.MHCI)
    if input_manager.has_mhcii:
        requested_classes.append(MHCClass.MHCII)
    if not requested_classes:
        raise ValueError('Input requests neither MHCI nor MHCII; no PHBR jobs can be created.')

    aggregated_job_description: List[Dict[str, Any]] = []
    # ID of the most recently created job; -1 means no job exists yet.
    last_job_id = -1

    for mhc_class in requested_classes:
        # Create the binding jobs for this class.
        # NOTE: It will return empty jobs if binding is not required.
        jobs = create_mhc_job(data, input_manager, mhc_class)

        # Binding jobs are numbered from 0 per class; shift them past the IDs
        # already used by a previous class.
        id_offset = last_job_id + 1
        if jobs and id_offset > 0:
            for job in jobs:
                job['job_id'] = job['job_id'] + id_offset

            # The last job is the aggregate job; it depends on all "predict"
            # jobs before it, i.e. jobs[0] .. jobs[-2].
            if len(jobs) >= 2:
                starting_job_id = jobs[0]['job_id']
                ending_job_id = jobs[-2]['job_id']
                jobs[-1]['depends_on_job_ids'] = list(range(starting_job_id, ending_job_id + 1))

        # Add the PHBR job for this class.
        if jobs:
            class_jobs = create_and_add_phbr_job(kwargs, list(jobs), mhc_class)
        else:
            # No binding jobs: create_and_add_phbr_job numbers the new job from
            # the last entry of the list it receives, so pass a placeholder
            # carrying the last used ID. Keep only the new job and clear its
            # dependency, because the placeholder is not a real job.
            placeholder = [{'job_id': last_job_id}]
            phbr_job = create_and_add_phbr_job(kwargs, placeholder, mhc_class)[-1]
            phbr_job['depends_on_job_ids'] = []
            class_jobs = [phbr_job]

        aggregated_job_description.extend(class_jobs)
        last_job_id = aggregated_job_description[-1]['job_id']

    # Need to collect all the PHBR prediction job IDs as the postprocessing job depends on them.
    phbr_prediction_job_ids = [
        job['job_id']
        for job in aggregated_job_description
        if job['job_type'] == 'predict' and 'run_phbr.py' in job['shell_cmd']
    ]

    # Add the PHBR postprocess job
    curr_job_id = last_job_id + 1
    postprocess_job = {
        "shell_cmd": f"{APP_ROOT}/src/run_phbr.py postprocess --job-desc-file {job_description_path} --postprocessed-results-dir {result_dir_phbr} -o {result_dir_phbr_postprocess / str(curr_job_id)} -f json --include-mhci-mhcii-result",
        'job_id': curr_job_id,
        'job_type': 'postprocess',
        'depends_on_job_ids': phbr_prediction_job_ids,
        'expected_outputs': [
            f'{result_dir_phbr_postprocess / str(curr_job_id)}.json',
        ]
    }

    aggregated_job_description.append(postprocess_job)

    # Save the master_job_description to a file
    save_json(aggregated_job_description, job_description_path)
    print('job_description file saved to: ', job_description_path)