# This module should implement the following logic:
# * Read all the result files created by each job unit (under 'preprocess_job/results/').
# * Combine those results into a single file; the combining logic differs for every tool.
#     * The combined file should be saved under 'postprocess_job/aggregated_result.json'.
import json
import os
import re
import sys
from enum import Enum
from pathlib import Path
from pprint import pprint
from typing import Dict, Any

import pandas as pd


class TCellClass(str, Enum):
    """
    MHC class of a T-cell result set.

    Because the enum mixes in ``str``, each member compares equal to its
    string value (e.g. ``TCellClass.I == 'i'``).
    """

    # MHC class I
    I = 'i'
    # MHC class II
    II = 'ii'


def get_column_indices(table_columns, target_columns):
    """
    Locate the first position of each requested column name.

    Args:
        table_columns: Ordered column names to search (compared with ``==``).
        target_columns: Column names to look up.

    Returns:
        Dict from each target name that was found to the index of its first
        occurrence in ``table_columns``, in ``target_columns`` order.
        Targets that do not occur are omitted.
    """
    positions = {}
    for wanted in target_columns:
        # First index whose name equals the target, or None if absent.
        hit = next(
            (pos for pos, name in enumerate(table_columns) if name == wanted),
            None,
        )
        if hit is not None:
            positions[wanted] = hit
    return positions


def reformat_phbr_output(phbr_output_content):
    '''
    Reformat a raw PHBR result document into the display-ready structure.

    The first entry of ``phbr_output_content['results']`` is the PHBR table.
    In the returned copy of that section:
      * the 'method' key is removed (if present);
      * columns are renamed 'peptide' -> 'core.peptide',
        'PHBR-I' -> 'phbr.PHBR-I', 'PHBR-II' -> 'phbr.PHBR-II';
      * 'field_ranges' holds the min/max of every PHBR column present;
      * 'unique_vals' holds the distinct peptides in first-seen order.

    The input document is not modified. The returned section is a shallow
    copy, so its 'table_data' list is shared with the input.

    Args:
        phbr_output_content: Parsed result JSON with a 'results' list and,
            optionally, a 'warnings' list.

    Returns:
        Dict with 'warnings' and a one-element 'results' list.

    Raises:
        ValueError: if the table has no rows, contains no PHBR column, or a
            PHBR value cannot be converted to float.
    '''
    column_renames = {
        'peptide': 'core.peptide',
        'PHBR-I': 'phbr.PHBR-I',
        'PHBR-II': 'phbr.PHBR-II',
    }

    # Shallow-copy the section so the caller's document is left untouched.
    phbr_results_section = dict(phbr_output_content['results'][0])
    phbr_results_section.pop('method', None)

    table_columns = [
        column_renames.get(name, name)
        for name in phbr_results_section['table_columns']
    ]
    table_data = phbr_results_section['table_data']
    phbr_results_section['table_columns'] = table_columns
    phbr_results_section['field_ranges'] = {}

    if not table_data:
        raise ValueError("No table data found in PHBR results")

    column_indices = get_column_indices(
        table_columns, ['phbr.PHBR-I', 'phbr.PHBR-II', 'core.peptide']
    )

    # (namespaced column, label used in error messages), in output order.
    phbr_columns = [
        (column, label)
        for column, label in (('phbr.PHBR-I', 'PHBR-I'), ('phbr.PHBR-II', 'PHBR-II'))
        if column in column_indices
    ]
    if not phbr_columns:
        raise ValueError("No PHBR columns found in table_columns. Expected 'phbr.PHBR-I' or 'phbr.PHBR-II'")

    for column, label in phbr_columns:
        idx = column_indices[column]
        try:
            values = [float(row[idx]) for row in table_data]
        except (ValueError, TypeError) as e:
            raise ValueError(f"{label} values are not numeric: {e}") from e
        # table_data is non-empty (checked above), so min/max are defined.
        phbr_results_section['field_ranges'][column] = {
            'min': min(values),
            'max': max(values),
        }

    peptide_idx = column_indices.get('core.peptide')
    peptides = []
    if peptide_idx is not None:
        # dict.fromkeys de-duplicates in O(n) while keeping first-seen order.
        peptides = list(dict.fromkeys(row[peptide_idx] for row in table_data))

    phbr_results_section['unique_vals'] = {
        'core.peptide': peptides
    }

    return {
        'warnings': phbr_output_content.get('warnings', []),
        'results': [phbr_results_section],
    }


def replace_extension_with_json(filepath: str) -> str:
    """
    Replace the extension of a file path with '.json'.

    Only the last suffix is replaced ('a.b.tsv' -> 'a.b.json'). A path with
    no suffix gets '.json' appended. ``pathlib.Path`` objects are accepted
    as well as strings.

    Args:
        filepath: Path to the file.

    Returns:
        The path, as a string, with a '.json' extension.
    """
    return str(Path(filepath).with_suffix('.json'))


def save_json_to_file(json_data: Dict[str, Any], output_file: str) -> None:
    """
    Write ``json_data`` as indented JSON to the '.json' variant of ``output_file``.

    The destination is ``output_file`` with its extension replaced by
    '.json' (e.g. 'result.tsv' -> 'result.json'). Any existing file at that
    path is overwritten.

    Args:
        json_data: JSON-serializable dictionary to write.
        output_file: Path whose extension is swapped to '.json' to form the
            destination.
    """
    json_output_file = replace_extension_with_json(output_file)
    with open(json_output_file, 'w', encoding='utf-8') as f:
        json.dump(json_data, f, indent=2)

    # The previous message ("Converted JSON to TSV") described the opposite
    # of what happens here.
    print(f"Saved JSON to: {json_output_file}")


def df_to_json(df: pd.DataFrame) -> Dict[str, Any]:
    """
    Wrap a DataFrame in the PHBR result-document layout.

    Args:
        df: Table whose column labels and row values form the result section.

    Returns:
        Dict with an empty 'warnings' list and a single 'results' entry of
        method 'phbr' and type 'peptide_table'. That entry holds the column
        labels under 'table_columns' and the rows (as lists) under
        'table_data'.
    """
    section = {"method": "phbr", "type": "peptide_table"}
    section["table_columns"] = df.columns.tolist()
    section["table_data"] = df.values.tolist()

    return {"warnings": [], "results": [section]}


def combine_phbr_results(mhci_output: str, mhcii_output: str, output_file: str) -> None:
    """
    Outer-join the MHCI and MHCII PHBR tables on 'peptide' and save as TSV.

    Both inputs are tab-separated files. Before the join, each file's 'PHBR'
    column is renamed to 'PHBR-I' (MHCI) or 'PHBR-II' (MHCII). The output has
    exactly these columns, in this order:
    peptide, #A, #B, #C, #DP, #DQ, #DR, PHBR-I, PHBR-II.
    Any of those missing after the join is created as 0.0. Gaps left by
    peptides present in only one input are filled with 0.0. All other
    columns are dropped.

    Args:
        mhci_output: Path to the MHCI PHBR TSV.
        mhcii_output: Path to the MHCII PHBR TSV.
        output_file: Destination path for the combined TSV.
    """
    print('--------------------------------------------------')
    print("Combining MHCI and MHCII results...")

    mhci_table, mhcii_table = (
        pd.read_csv(path, sep='\t') for path in (mhci_output, mhcii_output)
    )

    print(f"MHCI columns: {mhci_table.columns.tolist()}")
    print(f"MHCII columns: {mhcii_table.columns.tolist()}")

    merged = pd.merge(
        mhci_table.rename(columns={'PHBR': 'PHBR-I'}),
        mhcii_table.rename(columns={'PHBR': 'PHBR-II'}),
        on=['peptide'],
        how='outer',
        # Only MHCII columns that collide with an MHCI name get a suffix.
        suffixes=('', '-II'),
    )

    final_columns = [
        'peptide',
        '#A', '#B', '#C',        # MHCI columns
        '#DP', '#DQ', '#DR',     # MHCII columns
        'PHBR-I', 'PHBR-II',     # PHBR scores
    ]
    for absent in [name for name in final_columns if name not in merged.columns]:
        merged[absent] = 0.0

    merged[final_columns].fillna(0.0).to_csv(output_file, sep='\t', index=False)


def filter_output_file(output_file: str, t_cell_class: TCellClass) -> None:
    """
    Drop one MHC class's per-locus columns from a PHBR TSV and normalize that
    class's PHBR column. The file is rewritten in place.

    Class I drops '#A', '#B', '#C' and uses 'PHBR-I'. Class II drops '#DP',
    '#DQ', '#DR' and uses 'PHBR-II'. Columns that are already absent are
    ignored. A bare 'PHBR' column is renamed to the class-specific name, and
    the class-specific PHBR column is then written with 4 decimal places.

    Args:
        output_file: Path of the tab-separated file to rewrite.
        t_cell_class: A TCellClass member, or its string value ('i' / 'ii').

    Raises:
        ValueError: if ``t_cell_class`` is not a valid TCellClass value.
        KeyError: if the class-specific PHBR column is not in the file.
    """
    per_class = {
        TCellClass.I: (['#A', '#B', '#C'], 'PHBR-I'),
        TCellClass.II: (['#DP', '#DQ', '#DR'], 'PHBR-II'),
    }
    # TCellClass(...) accepts members or their string values and raises
    # ValueError for anything else. Previously an unknown class fell through
    # to df[None].
    columns_to_remove, phbr_column_name = per_class[TCellClass(t_cell_class)]

    df = pd.read_csv(output_file, sep='\t')

    # errors='ignore' lets the function run on files that never had some of
    # these columns (e.g. a second pass over an already-filtered file).
    df = df.drop(columns=columns_to_remove, errors='ignore')

    if 'PHBR' in df.columns:
        df = df.rename(columns={'PHBR': phbr_column_name})

    df[phbr_column_name] = df[phbr_column_name].apply(lambda x: f"{float(x):.4f}")

    df.to_csv(output_file, sep='\t', index=False)




def _natural_sort_key(file_name):
    """Sort key comparing digit runs numerically, so '9.tsv' < '12.tsv'."""
    # re.split with a capturing group alternates text / digit-run chunks,
    # so odd positions are always the digit runs.
    return [
        int(chunk) if position % 2 else chunk
        for position, chunk in enumerate(re.split(r'(\d+)', file_name))
    ]


def _load_job_descriptions(job_desc_file):
    """Parse job-description JSON from an open file object or a filesystem path."""
    if hasattr(job_desc_file, 'read'):
        return json.load(job_desc_file)
    with open(job_desc_file, 'r') as f:
        return json.load(f)


def _collect_peptide_tables(result_file, table_type):
    """
    Return copies of the 'peptide_table' results in an aggregated tool JSON,
    with 'result_type' removed and 'type' set to ``table_type``.
    """
    with open(result_file, 'r') as f:
        tool_result = json.load(f)

    tables = []
    for result in tool_result.get('results', []):
        if result.get('result_type') != 'peptide_table':
            continue
        table = result.copy()  # shallow copy; the source dict is not modified
        table.pop('result_type', None)
        table['type'] = table_type
        tables.append(table)
    return tables


def run(**kwargs):
    '''
    Build the final formatted PHBR result file.

    Keyword args (these presumably mirror the CLI options; confirm against
    the argument parser):
        job_desc_file: Job-description JSON, as an open file object or a path.
            Only read when include_mhci_mhcii_result is true.
        include_mhci_mhcii_result: If true, also append the peptide tables
            from the 'tcell_mhci.py' / 'tcell_mhcii.py' aggregate jobs
            (Default=False).
        postprocess_result_dir: Directory scanned for PHBR '.tsv' / '.json'
            result files.
        output_prefix: Output file prefix (Default='formatted_phbr_result').
        output_format: Output file extension (Default=json). The content is
            always written as JSON.

    If ``postprocess_result_dir`` holds one '.tsv', the first '.json' there
    (in natural filename order) is used as the PHBR result. If it holds two
    or more, the first two '.tsv' files in natural order are treated as MHCI
    and MHCII, merged, filtered and converted to JSON. The result is then
    reformatted and written to '<output_prefix>.<output_format>'.

    Raises:
        FileNotFoundError: if no usable PHBR result file is found.
    '''
    include_mhci_mhcii_result = kwargs.get('include_mhci_mhcii_result', False)
    postprocess_result_dir = kwargs.get('postprocess_result_dir')
    # `or` also covers keys that are present but set to None (unset options).
    output_prefix = kwargs.get('output_prefix') or 'formatted_phbr_result'
    output_format = kwargs.get('output_format') or 'json'
    output_file_name = f'{output_prefix}.{output_format}'

    print('postprocess_result_dir: ', postprocess_result_dir)
    # Natural order: a plain string sort would put '12.tsv' before '9.tsv'
    # and swap MHCI/MHCII.
    result_files = sorted(os.listdir(postprocess_result_dir), key=_natural_sort_key)
    tsv_files = [f for f in result_files if f.endswith('.tsv')]
    json_files = [f for f in result_files if f.endswith('.json')]

    print('tsv_files: ', tsv_files)

    if not tsv_files:
        raise FileNotFoundError(f"No PHBR '.tsv' result files found in {postprocess_result_dir}")

    if len(tsv_files) == 1:
        if not json_files:
            raise FileNotFoundError(f"No PHBR '.json' result file found in {postprocess_result_dir}")
        phbr_job = Path(postprocess_result_dir) / json_files[0]
    else:
        if len(tsv_files) > 2:
            print(f"Warning: expected 2 PHBR result files, found {len(tsv_files)}; using {tsv_files[:2]}")

        # MHCI always precedes MHCII, so the lower-numbered file is MHCI.
        phbr_mhci_output = Path(postprocess_result_dir) / tsv_files[0]
        phbr_mhcii_output = Path(postprocess_result_dir) / tsv_files[1]
        print(phbr_mhci_output)
        print(phbr_mhcii_output)

        combine_phbr_results(phbr_mhci_output, phbr_mhcii_output, output_file_name)
        print(f"Combined results saved to: {output_file_name}")
        # The class I pass drops '#A', '#B', '#C'; the class II pass drops
        # '#DP', '#DQ', '#DR'.
        filter_output_file(output_file_name, TCellClass.I)
        filter_output_file(output_file_name, TCellClass.II)

        # Convert the combined TSV back to a JSON document.
        df = pd.read_csv(output_file_name, sep='\t')
        save_json_to_file(df_to_json(df), output_file_name)

        # save_json_to_file writes to the '.json' variant of output_file_name.
        # Read from there so a non-json output_format does not make us
        # json.load the TSV.
        phbr_job = replace_extension_with_json(output_file_name)

    with open(phbr_job, 'r') as f:
        phbr_output_content = json.load(f)

    formatted_output_content = reformat_phbr_output(phbr_output_content)

    if include_mhci_mhcii_result:
        # Example aggregate job entry in the job-description file:
        # {
        #     "shell_cmd": ".../tcell_mhcii.py --aggregate --job-desc-file=... ",
        #     "job_id": 32,
        #     "job_type": "aggregate",
        #     "depends_on_job_ids": [26, 27, 28, 29, 30, 31],
        #     "expected_outputs": [".../mhcii/aggregate/aggregated_result.json"]
        # }
        jd_content = _load_job_descriptions(kwargs.get('job_desc_file'))
        for job_data in jd_content:
            if job_data.get('job_type') != 'aggregate':
                continue
            shell_cmd = job_data.get('shell_cmd', '')
            # 'tcell_mhci.py' is not a substring of 'tcell_mhcii.py', so the
            # two checks cannot both match.
            if 'tcell_mhci.py' in shell_cmd:
                table_type = 'tc1_peptide_table'
            elif 'tcell_mhcii.py' in shell_cmd:
                table_type = 'tc2_peptide_table'
            else:
                continue
            formatted_output_content['results'].extend(
                _collect_peptide_tables(job_data['expected_outputs'][0], table_type)
            )

    print('output_file_path: ', output_file_name)

    with open(output_file_name, 'w') as f:
        json.dump(formatted_output_content, f, indent=2)

    print(f"Saved formatted output to: {output_file_name}")
    
