'''
This script adds PHBR allele rows (tools 'phbr_i' for MHC-I and 'phbr_ii' for
MHC-II) to the data/tools-mapping.tsv file and writes the file back in place.
'''

import re
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Set

import pandas as pd


def strip_special_chars(text: str) -> str:
    '''Return *text* with every character other than ASCII letters and digits removed.

    Underscores, punctuation, whitespace and non-ASCII characters are all
    dropped, e.g. ``"HLA-DRB1*01:01"`` becomes ``"HLADRB10101"``.
    '''
    non_alphanumeric = re.compile(r'[^a-zA-Z0-9]')
    return non_alphanumeric.sub('', text)


def read_tools_mapping() -> pd.DataFrame:
    '''Load tools-mapping.tsv from the project's data directory.

    The data directory is the ``data`` folder inside the parent of the
    directory that contains this script.

    Returns:
        The tab-separated mapping table as a DataFrame.
    '''
    project_root = Path(__file__).resolve().parents[1]
    data_folder = f"{project_root}/data"
    return pd.read_csv(f"{data_folder}/tools-mapping.tsv", sep='\t')


def _format_lengths(raw_lengths: Any) -> str:
    '''Normalise a lengths cell such as ``"{'9'}"`` into a plain string like ``"9"``.

    A value written as a set literal (``"{'8', '9'}"``) has its braces and
    quotes removed and its items joined with ','. Any other value is returned
    as its stripped string form.
    '''
    text: str = str(raw_lengths).strip()
    if text.startswith('{') and text.endswith('}'):
        items: List[str] = [item.strip().strip("'\"") for item in text[1:-1].split(',')]
        # NOTE(review): ',' as the separator for several lengths is an assumption;
        # confirm it matches the 'Lengths' format used elsewhere in tools-mapping.tsv.
        return ','.join(item for item in items if item)
    return text


def _read_phbr_i_lengths() -> str:
    '''Read the ``phbr_i`` peptide lengths from data/allele-lengths.xlsx.

    Raises:
        ValueError: if the spreadsheet has no row whose 'method' is 'phbr_i'.
    '''
    allele_lengths_path: str = str(Path(__file__).resolve().parents[1]) + "/data/allele-lengths.xlsx"
    allele_lengths_df: pd.DataFrame = pd.read_excel(allele_lengths_path, engine='openpyxl')
    phbr_i_rows: pd.DataFrame = allele_lengths_df[allele_lengths_df['method'] == 'phbr_i']
    if phbr_i_rows.empty:
        raise ValueError(f"No 'phbr_i' row found in {allele_lengths_path}")
    return _format_lengths(phbr_i_rows['lengths'].values[0])


def add_mhci_alleles_to_phbr(tools_mapping_df: pd.DataFrame,
                             allele_lengths: Optional[str] = None) -> pd.DataFrame:
    '''Append a ``phbr_i`` row for every MHC-I allele label PHBR does not list yet.

    Each distinct 'Tool Label' among rows whose 'Tool Group' is 'mhci' gets one
    new row (Tool Group 'phbr', Tool 'phbr_i', Tool Version '1.0'). The row
    carries the 'MRO ID' of the label's first occurrence. Labels already present
    under Tool 'phbr_i' are skipped.

    Args:
        tools_mapping_df: The tools-mapping table. It is not modified.
        allele_lengths: Value for the new rows' 'Lengths' column. When omitted,
            it is read from the 'phbr_i' row of data/allele-lengths.xlsx.

    Returns:
        A new DataFrame with the extra rows appended and a fresh 0..n-1 index.
    '''
    if allele_lengths is None:
        allele_lengths = _read_phbr_i_lengths()

    mhci_alleles_df: pd.DataFrame = tools_mapping_df[tools_mapping_df['Tool Group'] == 'mhci']

    # Track labels already mapped to phbr_i AND labels added in this loop. A label
    # can appear in several 'mhci' rows, and must still yield only one phbr_i row.
    seen_labels: Set[Any] = set(
        tools_mapping_df.loc[tools_mapping_df['Tool'] == 'phbr_i', 'Tool Label']
    )

    phbr_additional_rows: List[Dict[str, Any]] = []
    for tool_label, mro_id in zip(mhci_alleles_df['Tool Label'], mhci_alleles_df['MRO ID']):
        if tool_label in seen_labels:
            continue
        seen_labels.add(tool_label)
        phbr_additional_rows.append({
            'Tool Group': 'phbr',
            'Tool': 'phbr_i',
            'Tool Version': '1.0',
            'Tool Label': tool_label,
            'MRO ID': mro_id,
            'Lengths': allele_lengths,
        })

    if not phbr_additional_rows:
        # Nothing to add: avoid concatenating an empty frame, but keep the
        # fresh-index contract of the concat path.
        return tools_mapping_df.reset_index(drop=True)

    return pd.concat([tools_mapping_df, pd.DataFrame(phbr_additional_rows)], ignore_index=True)


def get_single_chain_alleles(mhcii_alleles_df: pd.DataFrame) -> List[str]:
    '''Collect the distinct single-chain allele names from MHC-II tool labels.

    A label such as ``"DPA1*01:03/DPB1*02:01"`` contributes each of its
    '/'-separated chains. A label without '/' contributes itself. Empty chain
    names (e.g. from a trailing '/') are ignored.

    Args:
        mhcii_alleles_df: DataFrame with a 'Tool Label' column of strings.

    Returns:
        The unique chain names in order of first appearance.
    '''
    # A dict keeps insertion order. Iterating a set of strings is not reproducible
    # across runs (hash randomization), which would reorder the generated TSV.
    single_chain_alleles: Dict[str, None] = {}
    for tool_label in mhcii_alleles_df['Tool Label']:
        for chain in tool_label.split('/'):
            if chain:
                single_chain_alleles.setdefault(chain, None)

    return list(single_chain_alleles)


def _load_mro_ids_by_stripped_label() -> Dict[str, Any]:
    '''Map each stripped 'IEDB Label' in data/mhc_alleles.tsv to its 'MRO ID'.

    Labels are stripped with ``strip_special_chars``. When several labels strip
    to the same key, the first row in the file wins. Rows with a non-string
    label (e.g. an empty cell read as NaN) are skipped.
    '''
    mhc_alleles_path: str = str(Path(__file__).resolve().parents[1]) + "/data/mhc_alleles.tsv"
    mhc_alleles_df: pd.DataFrame = pd.read_csv(mhc_alleles_path, sep='\t')

    mro_by_stripped_label: Dict[str, Any] = {}
    for iedb_label, mro_id in zip(mhc_alleles_df['IEDB Label'], mhc_alleles_df['MRO ID']):
        if isinstance(iedb_label, str):
            mro_by_stripped_label.setdefault(strip_special_chars(iedb_label), mro_id)
    return mro_by_stripped_label


def add_mhcii_alleles_to_phbr(tools_mapping_df: pd.DataFrame) -> pd.DataFrame:
    '''Append a ``phbr_ii`` row for every single MHC-II chain PHBR does not list yet.

    Labels of rows whose 'Tool Group' is 'mhcii' are split into individual
    chains with ``get_single_chain_alleles``. Each chain's MRO ID is then looked
    up in data/mhc_alleles.tsv by comparing labels with all non-alphanumeric
    characters removed. If there is no direct match, the lookup is retried with
    an "HLA" prefix. Chains with no match are reported on stdout and skipped.

    Args:
        tools_mapping_df: The tools-mapping table. It is not modified.

    Returns:
        A new DataFrame with the extra rows appended (Tool Group 'phbr', Tool
        'phbr_ii', Tool Version '1.0', Lengths '15') and a fresh 0..n-1 index.
    '''
    mro_by_stripped_label: Dict[str, Any] = _load_mro_ids_by_stripped_label()

    mhcii_alleles_df: pd.DataFrame = tools_mapping_df[tools_mapping_df['Tool Group'] == 'mhcii']
    existing_phbr_ii_labels: Set[Any] = set(
        tools_mapping_df.loc[tools_mapping_df['Tool'] == 'phbr_ii', 'Tool Label']
    )

    phbr_additional_rows: List[Dict[str, Any]] = []
    for allele in get_single_chain_alleles(mhcii_alleles_df):
        if allele in existing_phbr_ii_labels:
            continue

        stripped_allele: str = strip_special_chars(allele)
        hla_prefixed: str = f"HLA{stripped_allele}"
        if stripped_allele in mro_by_stripped_label:
            mro_id = mro_by_stripped_label[stripped_allele]
        elif hla_prefixed in mro_by_stripped_label:
            mro_id = mro_by_stripped_label[hla_prefixed]
        else:
            # Not fatal: the chain is reported and left out of the PHBR mapping.
            print(f"Excluding: No matching IEDB Label found for stripped allele: {stripped_allele}: {allele}")
            continue

        phbr_additional_rows.append({
            'Tool Group': 'phbr',
            'Tool': 'phbr_ii',
            'Tool Version': '1.0',
            'Tool Label': allele,
            'MRO ID': mro_id,
            'Lengths': '15',
        })

    if not phbr_additional_rows:
        # Nothing to add: avoid concatenating an empty frame, but keep the
        # fresh-index contract of the concat path.
        return tools_mapping_df.reset_index(drop=True)

    return pd.concat([tools_mapping_df, pd.DataFrame(phbr_additional_rows)], ignore_index=True)


def main() -> None:
    '''Regenerate data/tools-mapping.tsv with PHBR allele rows.

    Reads the current mapping and appends the MHC-I (phbr_i) rows, then the
    MHC-II (phbr_ii) rows. The result is written back over the same file
    without the index column.
    '''
    output_path = str(Path(__file__).resolve().parents[1]) + "/data/tools-mapping.tsv"

    updated_df = read_tools_mapping()
    # Order matters only for row order in the output: MHC-I rows come first.
    for append_phbr_rows in (add_mhci_alleles_to_phbr, add_mhcii_alleles_to_phbr):
        updated_df = append_phbr_rows(updated_df)

    updated_df.to_csv(output_path, sep='\t', index=False)


if __name__ == "__main__":
    main()