import json
import csv
import os
import sys
import subprocess
import tempfile
import textwrap
import preprocess
import postprocess
import pandas as pd
import warnings
import utils
from io import StringIO
import core.set_pythonpath  # This automatically configures PYTHONPATH
from PhbrArgumentParser import PhbrArgumentParser
from validators import InputManager, MHCClass, validate_alleles, validate_sequence_table, validate_mutation_position
from pathlib import Path
from typing import Dict, List, Union, Optional, Any, Tuple
from itertools import product
from dotenv import load_dotenv
load_dotenv()

# Add the nxg-tools directory to Python path
sys.path.insert(0, os.path.join(os.getenv('APP_ROOT'), 'libs', 'nxg-tools'))
from nxg_common import nxg_common as common


def save_phbr_to_formatted_json(phbr_output_df: pd.DataFrame, output_file: str) -> None:
    """Write a PHBR result table to *output_file* as a single ``peptide_table`` JSON result.

    The document has the shape::

        {"warnings": [],
         "results": [{"result_type": "peptide_table",
                      "table_columns": [...],
                      "table_data": [[...], ...]}]}

    Missing cells (NaN/None, e.g. from an outer merge of class I and class II
    results) are written as JSON ``null``. ``json.dump`` would otherwise emit the
    non-standard ``NaN`` token, which strict JSON parsers reject. Columns are cast
    to ``object`` individually so integer columns (such as ``seq #``) stay integers
    instead of being upcast to float together with the numeric PHBR columns.

    Args:
        phbr_output_df: Table to serialize; column order is preserved.
        output_file: Destination path; an existing file is overwritten.
    """
    table_data = [
        [None if (pd.api.types.is_scalar(value) and pd.isna(value)) else value for value in row]
        for row in phbr_output_df.astype(object).values.tolist()
    ]

    output_dict = {
        "warnings": [],
        "results": [
            {
                "result_type": "peptide_table",
                "table_columns": phbr_output_df.columns.tolist(),
                "table_data": table_data,
            }
        ],
    }

    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(output_dict, f, indent=2)

    print(f"PHBR output saved to {output_file}")

def _warn_invalid_alleles(invalid_alleles: List[str]) -> None:
    """Emit one UserWarning listing the alleles that will be excluded from predictions."""
    if not invalid_alleles:
        return
    warning_msg = "The following alleles are invalid and will be excluded from predictions:\n"
    warning_msg += "\n".join(f"  - {allele}" for allele in invalid_alleles)
    warnings.warn(warning_msg, UserWarning)


def _split_allele_string(allele_string: Optional[str]) -> List[str]:
    """Split a comma-separated allele string into whitespace-stripped, non-empty names."""
    if not allele_string:
        return []
    return [allele.strip() for allele in allele_string.split(',') if allele.strip()]


def _collect_mhcii_alleles(alleles: Dict[str, str]) -> List[str]:
    """Return the DRB alleles followed by every DPA/DPB and DQA/DQB alpha/beta pairing.

    Paired names have the form "<alpha>/<beta>", with any 'HLA-' prefix removed from
    the beta allele. A pairing already present in the list is not added twice.
    """
    allele_list = _split_allele_string(alleles.get('DRB'))
    for alpha_key, beta_key in (('DPA', 'DPB'), ('DQA', 'DQB')):
        alpha_alleles = _split_allele_string(alleles.get(alpha_key))
        beta_alleles = _split_allele_string(alleles.get(beta_key))
        for alpha_allele, beta_allele in product(alpha_alleles, beta_alleles):
            if beta_allele.startswith('HLA'):
                beta_allele = beta_allele.replace('HLA-', '')
            paired_allele = f"{alpha_allele}/{beta_allele}"
            if paired_allele not in allele_list:
                allele_list.append(paired_allele)
    return allele_list


def run_mhc_binding(input_data: Dict[str, Any], class_type: MHCClass, is_mut_vs_ref_table: bool = False) -> Path:
    """Run the T-cell MHC binding tool for one MHC class on the input sequence table.

    Steps:

    1. Read the sequence TSV and build the FASTA input.
    2. Collect and validate the alleles for the chosen prediction method; invalid
       ones are excluded with a UserWarning.
    3. Check that at least one sequence meets the class's default peptide length
       range.
    4. Write the binding payload to ``<output_dir>/<mhci|mhcii>/tmp``.
    5. Run ``tcell_mhci.py`` / ``tcell_mhcii.py`` with ``--split`` to produce job
       descriptions.

    Args:
        input_data: Parsed input JSON. Must contain ``mhc_sequence_tsv`` (for class II
            a ``class_ii.mhc_sequence_tsv`` entry takes precedence). It must also
            contain the ``class_i`` / ``class_ii`` section with ``alleles`` and
            ``prediction_method.method``, plus ``metadata.output_dir`` and
            ``metadata.subcommand``.
        class_type: MHCClass.MHCI or MHCClass.MHCII.
        is_mut_vs_ref_table: Currently unused; kept for interface compatibility.

    Returns:
        For the 'preprocess' subcommand, ``<output_dir>/job_descriptions.json``; the
        split jobs are not run. Otherwise every split job is run and the split
        directory (inside a fresh temporary directory) is returned.

    Raises:
        ValueError: If no sequence table is given, *class_type* is not a known MHC
            class, or no sequence meets the peptide length range.
        subprocess.CalledProcessError: If the split step or any job fails.
    """
    # Local import: only needed to shell-quote paths in the generated bash commands
    import shlex

    MHCI_DEFAULT_PEPTIDE_LENGTH_RANGE = [8, 11]
    MHCII_DEFAULT_PEPTIDE_LENGTH_RANGE = [15, 15]
    APP_ROOT = os.getenv('APP_ROOT')

    if 'mhc_sequence_tsv' not in input_data:
        raise ValueError("No sequence table found in the input JSON file.")
    sequence_file = input_data['mhc_sequence_tsv']

    if class_type == MHCClass.MHCI:
        print('Running MHCI....')
        class_key = 'class_i'
        # Every value of the class I allele dict is a comma-separated allele string
        candidate_alleles = [
            allele
            for allele_group in input_data[class_key]['alleles'].values()
            for allele in _split_allele_string(allele_group)
        ]
        peptide_length_range = MHCI_DEFAULT_PEPTIDE_LENGTH_RANGE
        mhc_dir_name = 'mhci'
        env_script = 'setup_tcell_class_i_env.sh'
        # Double-quoted so bash still expands the variable but keeps the path one word
        tool_script = '"$TCELL_CLASS_I_PATH/src/tcell_mhci.py"'
    elif class_type == MHCClass.MHCII:
        print('Running MHCII....')
        class_key = 'class_ii'
        # A combined class I + II input may supply a class-II-specific sequence table
        if 'mhc_sequence_tsv' in input_data[class_key]:
            sequence_file = input_data[class_key]['mhc_sequence_tsv']
        candidate_alleles = _collect_mhcii_alleles(input_data[class_key]['alleles'])
        peptide_length_range = MHCII_DEFAULT_PEPTIDE_LENGTH_RANGE
        mhc_dir_name = 'mhcii'
        env_script = 'setup_tcell_class_ii_env.sh'
        tool_script = '"$TCELL_CLASS_II_PATH/src/tcell_mhcii.py"'
    else:
        raise ValueError(f"Unsupported MHC class: {class_type}")

    sequence_table_df = pd.read_csv(sequence_file, sep='\t')
    fasta_string = utils.create_fasta_string(sequence_table_df)
    method = input_data[class_key]['prediction_method']['method']

    valid_alleles, invalid_alleles = validate_alleles(candidate_alleles, class_type=class_type, method=method)
    alleles = ','.join(valid_alleles)
    print('valid_alleles: ', valid_alleles)
    print('invalid_alleles: ', invalid_alleles)
    _warn_invalid_alleles(invalid_alleles)

    # The table only needs at least one sequence within the length range
    valid_sequence_df = validate_sequence_table(sequence_table_df, peptide_length_range)
    print('valid_sequence_df: \n', valid_sequence_df)
    print('--------------------------------')
    if valid_sequence_df.empty:
        raise ValueError(f"No valid sequences that meets the length range requirements found in the sequence table for MHC class {class_type.value.upper()}.")

    mhc_binding_payload = {
        'input_sequence_text': fasta_string,
        'alleles': alleles,
        'predictors': [
            {
                'type': 'binding',
                'method': method
            }
        ],
        'peptide_length_range': peptide_length_range
    }

    output_dir = Path(input_data['metadata']['output_dir'])
    is_preprocess = input_data['metadata']['subcommand'] == 'preprocess'

    mhc_input_dir = output_dir / mhc_dir_name / 'tmp'
    mhc_input_dir.mkdir(parents=True, exist_ok=True)
    temp_file = utils.save_json_to_temporary_file(mhc_binding_payload, mhc_input_dir)
    print('temp_file: ', temp_file)

    # 'preprocess' writes the split results under the user's output directory.
    # 'predict' has no output directory of its own for them, so it uses a fresh
    # temporary directory.
    if is_preprocess:
        split_dir = output_dir / mhc_dir_name
    else:
        tmpdir = tempfile.mkdtemp(prefix=f"mhc{class_type.value}_output_", suffix="_for_phbr")
        split_dir = Path(tmpdir) / mhc_dir_name
        split_dir.mkdir(parents=True, exist_ok=True)

    # Quote every interpolated path so spaces or shell metacharacters in directory
    # names cannot break (or inject into) the bash command line.
    env_script_path = f"{APP_ROOT}/{env_script}"
    setup_cmd = f"source {shlex.quote(env_script_path)}"
    split_cmd = (
        f"python3 {tool_script} -j {shlex.quote(str(temp_file))} "
        f"--split --split-dir={shlex.quote(str(split_dir))} --keep-empty-row"
    )
    subprocess.run(f"{setup_cmd}\n{split_cmd}\n", shell=True, executable="/bin/bash", check=True)

    if is_preprocess:
        return output_dir / 'job_descriptions.json'

    # The split step writes job_descriptions.json next to the split directory
    job_description_path = split_dir.parent / "job_descriptions.json"
    with open(job_description_path, 'r') as f:
        jd_content = json.load(f)

    # Each job runs in a shell that has sourced the class-specific environment first
    for i, job in enumerate(jd_content):
        sh_cmd = ' '.join([setup_cmd, '&&', 'python3', job['shell_cmd']])
        print(f'JOB {i} >> {sh_cmd}')
        subprocess.run(sh_cmd, shell=True, executable="/bin/bash", check=True)

    return split_dir


def mhc_binding_result_json2tsv(params_dir: Path, t_cell_class: MHCClass, is_aggregated_result: bool = False) -> Optional[Path]:
    """Convert an aggregated MHC binding JSON result into a TSV peptide table.

    The first result whose ``result_type`` is ``peptide_table`` is written as TSV.
    Its ``table_columns`` are renamed through a header mapping (core columns plus the
    class-specific binding columns).

    Args:
        params_dir: Normally a directory whose sibling ``aggregate/aggregated_result.json``
            holds the result. When *is_aggregated_result* is True, the path of that
            JSON file itself.
        t_cell_class: MHC class selecting which binding columns get friendly headers.
        is_aggregated_result: Read *params_dir* directly as the JSON file and write the
            TSV to a new temporary file, instead of ``aggregate/peptide_table.tsv``.

    Returns:
        The path of the written TSV, or None if the JSON holds no ``peptide_table``
        result. The path is a ``str`` in the temporary-file case and a ``Path``
        otherwise.
    """
    if is_aggregated_result:
        result_json_file = params_dir
    else:
        result_json_file = params_dir.parent / 'aggregate' / 'aggregated_result.json'

    with open(result_json_file, 'r') as f:
        data = json.load(f)

    header_mapping = {
        "core.sequence_number": "seq #",
        "core.peptide": "peptide",
        "core.mut_peptide": "mutant peptide",
        "core.ref_peptide": "ref_peptide",
        "core.start": "start",
        "core.end": "end",
        "core.length": "peptide length",
        "core.allele": "allele",
        "core.peptide_index": "peptide index",
    }

    # Compare enum members, not `.value`. The value is a plain string (it is
    # upper-cased and put into directory names), so comparing it to a member only
    # matches if MHCClass also subclasses str.
    if t_cell_class == MHCClass.MHCI:
        header_mapping.update({
            "binding.median_percentile": "median binding percentile",
            "binding.netmhcpan_el.core": "netmhcpan_el core",
            "binding.netmhcpan_el.icore": "netmhcpan_el icore",
            "binding.netmhcpan_el.score": "netmhcpan_el score",
            "binding.netmhcpan_el.percentile": "netmhcpan_el percentile",
        })
    elif t_cell_class == MHCClass.MHCII:
        header_mapping.update({
            "binding.median_percentile": "median binding percentile",
            "binding.netmhciipan_el.core": "netmhciipan_el core",
            "binding.netmhciipan_el.score": "netmhciipan_el score",
            "binding.netmhciipan_el.percentile": "netmhciipan_el percentile",
        })

    peptide_table = next(
        (res for res in data["results"] if res["result_type"] == "peptide_table"),
        None
    )
    if peptide_table is None:
        print("No peptide_table found in the JSON data.")
        return None

    # Columns without a mapping keep their original name
    custom_headers = [header_mapping.get(col, col) for col in peptide_table["table_columns"]]

    if is_aggregated_result:
        # Only the reserved path is needed; close the handle immediately so it is not leaked
        with tempfile.NamedTemporaryFile(mode='w', suffix='.tsv', delete=False) as tmp_file:
            new_result_file = tmp_file.name
        print('Creating temporary file: ', new_result_file)
    else:
        new_result_file = params_dir.parent / 'aggregate' / 'peptide_table.tsv'

    with open(new_result_file, 'w', newline='') as f_out:
        writer = csv.writer(f_out, delimiter='\t')
        writer.writerow(custom_headers)
        writer.writerows(peptide_table["table_data"])
    print(f'successfully written to {new_result_file}')

    return new_result_file


def _run_pipeline_command(command: List[Any]) -> None:
    """Run an external pipeline script, echoing stderr and re-raising if it fails.

    Raises:
        subprocess.CalledProcessError: When the script exits non-zero. Every later
            step reads the script's output file, so continuing would only fail further
            on with a less informative error.
    """
    try:
        subprocess.run(command, capture_output=True, text=True, check=True)
    except subprocess.CalledProcessError as e:
        print('Error:', e.stderr)
        print('Return Code:', e.returncode)
        raise


def run_prediction(data: dict, im: InputManager = None, mhc_class: MHCClass = MHCClass.MHCI):
    """Compute PHBR scores for one MHC class and return them as a DataFrame.

    The peptide and sequence tables come from one of three sources, depending on
    ``im.category``:

    - PEPTIDE_SEQUENCE_TABLE: taken from the input directly.
    - BINDING_RESULT_URI: derived from the binding-result URI.
    - Any other category: produced by running MHC binding on a sequence table.

    The tables are converted with ``mhc2phbr.py`` and scored with ``phbr.py``. The
    PHBR table then:

    - gets sequence numbers;
    - has its homozygous-loci columns filtered;
    - has its values formatted via ``utils.format_phbr_values``;
    - for MUT_VS_REF_TABLE inputs, gains the original reference sequences.

    Args:
        data: Parsed input JSON. ``data['metadata']`` must provide ``output_dir``,
            ``output_prefix``, ``output_format`` and ``subcommand``. For categories
            that run MHC binding, ``data['mhc_sequence_tsv']`` is set (mutated)
            before binding.
        im: InputManager describing the input; required despite the ``None`` default.
        mhc_class: Which MHC class to score.

    Returns:
        The PHBR DataFrame. When the subcommand is 'preprocess' and MHC binding is
        needed, returns whatever ``run_mhc_binding`` returns instead (the
        job-description path).

    Raises:
        KeyError: If ``phbr.py`` cannot be found under APP_ROOT.
        subprocess.CalledProcessError: If ``mhc2phbr.py`` or ``phbr.py`` fails.
    """
    print("Running prediction")
    print(data)
    print(im.describe())
    print(im.category.name)

    category = im.category.name
    sequence_table_path = None
    peptide_table_path = None
    output_dir = Path(data['metadata']['output_dir'])
    output_prefix = data['metadata']['output_prefix']
    output_format = data['metadata']['output_format']
    is_mut_vs_ref_table = False
    mut_peptide_ref_peptide_dict = None

    # Scratch path for phbr.py's output. An absolute or nested prefix
    # (e.g. "/abs/path/name") is used as-is, and a bare name goes under output_dir.
    # A missing prefix falls back to 'phbr_output' rather than the literal 'None'.
    prefix_path = Path(output_prefix) if output_prefix else Path('phbr_output')
    if prefix_path.is_absolute() or prefix_path.parent != Path('.'):
        output_file_path = prefix_path.with_suffix(f".{output_format}")
    else:
        output_file_path = output_dir / f"{prefix_path}.{output_format}"
    output_file_path.parent.mkdir(parents=True, exist_ok=True)
    APP_ROOT = os.getenv('APP_ROOT')

    if category == "BINDING_RESULT_URI":
        df = utils.uri_to_df(data['mhc_result_uri'])
        # TODO: Ideally, should go to the preprocess' output_dir/tmp directory
        output_dir = tempfile.mkdtemp(prefix="phbr-uri-output-", dir=output_dir)
        peptide_table_path, sequence_table_path = utils.create_peptide_and_sequence_files(df, output_dir)

    elif category == "PEPTIDE_SEQUENCE_TABLE":
        peptide_table_path = data['mhc_peptide_tsv']
        sequence_table_path = data['mhc_sequence_tsv']

    else:
        # Remaining categories provide (or can produce) a sequence table but no
        # peptide table, so MHC binding is run to obtain the peptide table.
        if category == "MHC_SEQUENCE_TABLE":
            sequence_table_path = data['mhc_sequence_tsv']

        if category == "NEOEPITOPES_STRING":
            df = pd.read_csv(StringIO(data['input_neoepitopes']), sep="\t")
            # Downstream helpers expect a 'sequence' column, so rename any
            # '*peptide*' column unless the table is a mut-vs-ref table.
            if im.atomic_table.name != "MUT_VS_REF_TABLE":
                df = df.rename(columns=lambda x: 'sequence' if 'peptide' in x.lower() else x)
            sequence_table_path = utils.save_df_to_temporary_file(df, output_dir)

        is_mut_vs_ref_table = im.atomic_table.name == "MUT_VS_REF_TABLE"

        # run_mhc_binding reads its sequence table from data['mhc_sequence_tsv']
        data['mhc_sequence_tsv'] = sequence_table_path

        if is_mut_vs_ref_table:
            sequence_table_df = pd.read_csv(sequence_table_path, sep='\t')
            # Keep the mutant -> reference peptide mapping to restore reference sequences at the end
            mut_peptide_ref_peptide_dict = utils.map_mut_peptide_ref_peptide(sequence_table_df)
            sequence_table_df = utils.turn_mut_vs_ref_table_into_sequence_table(sequence_table_df)
            print('sequence_table_df: \n', sequence_table_df)
            sequence_table_path = utils.save_df_to_temporary_file(sequence_table_df, output_dir)

        mhc_binding_output_dir = run_mhc_binding(data, mhc_class, is_mut_vs_ref_table)

        # For 'preprocess', run_mhc_binding returns the job-description path; stop here
        if data['metadata']['subcommand'] == 'preprocess':
            return mhc_binding_output_dir

        peptide_table_path = mhc_binding_result_json2tsv(mhc_binding_output_dir, mhc_class)

        print('peptide_table_path: ', peptide_table_path)
        print('sequence_table_path: ', sequence_table_path)

    # Map of sequence -> sequence number, e.g. {'MGQIVTMFEALPHIIDEVINIVIIVLIVITGI': 1}
    sequence_id_dict = utils.create_sequence_id_dict(sequence_table_path)
    print('sequence_id_dict: ', sequence_id_dict)
    print('--------------------------------')

    # mhc2phbr.py merges on 'seq #'. Add it to a private copy of the sequence table
    # so a user-supplied input file is never modified in place.
    seq_df_check = pd.read_csv(sequence_table_path, sep='\t')
    seq_df_check = utils.add_sequence_number_to_sequence_table(seq_df_check)
    with tempfile.NamedTemporaryFile(prefix="phbr-sequence-", suffix=".tsv", delete=False, dir=output_dir) as numbered_file:
        numbered_sequence_path = numbered_file.name
    seq_df_check.to_csv(numbered_sequence_path, sep='\t', index=False)
    sequence_table_path = numbered_sequence_path

    # Convert the MHC binding result into PHBR input format with mhc2phbr.py.
    # Only the reserved file name is needed, so the handle is closed right away.
    with tempfile.NamedTemporaryFile(prefix="phbr-input-", delete=False, dir=output_dir) as phbr_input_file:
        phbr_input_path = phbr_input_file.name
    mhc2phbr_fpath = utils.find_file_path(start_dir=APP_ROOT, filename='mhc2phbr.py')

    class_key = 'class_i' if mhc_class == MHCClass.MHCI else 'class_ii'
    method = data[class_key].get('prediction_method', {}).get('method', '')

    rank_colname = utils.find_ranking_column_name(peptide_table_path, method)
    print('rank_colname: ', rank_colname)
    print('--------------------------------')

    mhc_args = [
        '--peptide-output', peptide_table_path,
        '--sequence-output', sequence_table_path,
        '--phbr-input', phbr_input_path,
        '--rank-colname', rank_colname,
    ]

    # If neither a mutation position nor a position column is given, mhc2phbr.py
    # uses the central position.
    mut_pos = data.get('mutation_position')
    if mut_pos and isinstance(mut_pos, str):
        mut_pos = int(mut_pos)
    mut_pos_col = data.get('mutation_position_colname')

    if mut_pos_col or mut_pos:
        seq_df = pd.read_csv(sequence_table_path, sep='\t')
        validate_mutation_position(seq_df, mut_pos_col, mut_pos)

    # An explicit position takes precedence over a position column
    if mut_pos:
        mhc_args += ['--mutation-position', str(mut_pos)]
    elif mut_pos_col:
        mhc_args += ['--sequence-mutation-position-colname', mut_pos_col]

    command = ['python', mhc2phbr_fpath] + mhc_args
    print('--------------------------------------------------')
    print("Running command for mhc2phbr.py:\n", ' '.join(str(x) for x in command))
    print('--------------------------------------------------')
    _run_pipeline_command(command)
    utils.remove_empty_rank_rows(phbr_input_path)
    print(f'Result of \'mhc2phbr.py\' saved to {phbr_input_path}')

    # The output of mhc2phbr.py is the input for phbr.py
    phbr_fpath = utils.find_file_path(start_dir=APP_ROOT, filename='phbr.py')
    if not phbr_fpath:
        message = f"""
            phbr.py not found. Please check if phbr.py exists 
            in this project.
        """
        raise KeyError(utils.format_message(message))

    phbr_args = [
        '--mhc-predictions', phbr_input_path,
        '--output-file', output_file_path,
        '--mhci' if mhc_class == MHCClass.MHCI else '--mhcii',
    ]

    homozygous_loci = data[class_key].get('homozygous_loci')
    if homozygous_loci:
        phbr_args += ['--homozygous-loci', homozygous_loci]

    command = ['python', phbr_fpath] + phbr_args
    print('--------------------------------------------------')
    print("Running PHBR command:\n", command)
    print('--------------------------------------------------')
    _run_pipeline_command(command)
    print(f'Final result of saved to {output_file_path}')

    # Read phbr.py's output once and apply all formatting steps in memory
    phbr_output_df = pd.read_csv(output_file_path, sep='\t')
    phbr_output_df = utils.add_sequence_number_to_phbr_output(phbr_output_df, sequence_id_dict)
    # Drop homozygous-loci columns such as #A, #B, #C, #DP, #DQ, #DR
    phbr_output_df = utils.filter_homozygous_loci_columns_from_phbr_output(phbr_output_df, mhc_class)
    phbr_output_df = utils.format_phbr_values(phbr_output_df)

    if im.atomic_table.name == "MUT_VS_REF_TABLE":
        # Requires the "seq #" column added above
        phbr_output_df = utils.add_original_reference_sequence_to_phbr_output(phbr_output_df, mut_peptide_ref_peptide_dict)

    # The caller writes the final output, so the scratch file from phbr.py is removed
    os.remove(output_file_path)

    return phbr_output_df


def main():
    """Command-line entry point; dispatches on the parsed subcommand.

    ``predict``:

    - computes PHBR for each MHC class present in the input JSON, renaming the
      ``PHBR`` column to ``PHBR-I`` / ``PHBR-II``;
    - outer-merges the two tables on ``seq #`` and ``mutant peptide`` when both
      classes are present;
    - writes the final table once, as TSV or formatted JSON.

    ``preprocess`` and ``postprocess`` are delegated to their modules with all
    parsed arguments.
    """
    parser = PhbrArgumentParser()
    args = parser.parse_args()
    print(args)

    if args.subcommand == 'predict':
        data = json.load(args.input_json)

        # Scratch space for intermediate files; a non-default output dir gets its
        # own 'phbr/tmp' sub-directory.
        if args.output_dir != '/tmp':
            tmp_output_dir = args.output_dir + '/phbr/tmp'
            os.makedirs(tmp_output_dir, exist_ok=True)
        else:
            tmp_output_dir = args.output_dir

        data['metadata'] = {
            'output_dir': tmp_output_dir,
            'output_prefix': args.output_prefix,
            'output_format': args.output_format,
            'subcommand': args.subcommand,
        }

        # Final output path: an absolute or nested prefix is used as-is, a bare name
        # is placed in tmp_output_dir, and a missing prefix defaults to 'phbr_output'.
        out_prefix = args.output_prefix if args.output_prefix is not None else 'phbr_output'
        out_prefix_path = Path(out_prefix)
        if out_prefix_path.is_absolute() or out_prefix_path.parent != Path('.'):
            phbr_output_file_path = out_prefix_path.with_suffix(f".{args.output_format}")
        else:
            phbr_output_file_path = Path(tmp_output_dir) / f"{out_prefix}.{args.output_format}"
        phbr_output_file_path.parent.mkdir(parents=True, exist_ok=True)
        phbr_output_file = str(phbr_output_file_path)

        # Determine the user-input type and which MHC classes are requested
        input_manager = InputManager(data)

        phbr_i_output_df = None
        phbr_ii_output_df = None
        if input_manager.has_mhci:
            phbr_i_output_df = run_prediction(data, input_manager, MHCClass.MHCI)
            phbr_i_output_df = phbr_i_output_df.rename(columns={'PHBR': 'PHBR-I'})
            print('phbr_i_output_df: \n', phbr_i_output_df)

        if input_manager.has_mhcii:
            phbr_ii_output_df = run_prediction(data, input_manager, MHCClass.MHCII)
            phbr_ii_output_df = phbr_ii_output_df.rename(columns={'PHBR': 'PHBR-II'})
            print('phbr_ii_output_df: \n', phbr_ii_output_df)

        # Write the result once, after every class has run. run_prediction derives
        # its phbr.py scratch path from the same metadata and deletes that file when
        # done. That path can coincide with phbr_output_file, so a per-class write
        # made before the next class runs could be clobbered.
        if phbr_i_output_df is not None and phbr_ii_output_df is not None:
            final_df = pd.merge(phbr_i_output_df, phbr_ii_output_df, on=['seq #', 'mutant peptide'], how='outer', suffixes=('', '_II'))
            print('--------------------------------')
            print(final_df)
        elif phbr_i_output_df is not None:
            final_df = phbr_i_output_df
        else:
            final_df = phbr_ii_output_df

        if final_df is not None:
            if args.output_format == 'json':
                save_phbr_to_formatted_json(final_df, phbr_output_file)
            else:
                final_df.to_csv(phbr_output_file, sep='\t', index=False)
                print(f'PHBR output saved to {phbr_output_file}')

    elif args.subcommand == 'preprocess':
        # Input splitting logic lives in preprocess.py
        preprocess.run(**vars(args))

    elif args.subcommand == 'postprocess':
        # Result combination logic lives in postprocess.py
        postprocess.run(**vars(args))


if __name__ == '__main__':
    main()