# NOTE:
# This script rewrites tools-mapping.tsv so that its netmhciipan rows only
# cover alleles listed in the netmhciipan-4.1-alleles-name.txt,
# netmhciipan-4.2-alleles-name.txt, and netmhciipan-4.3-alleles-name.txt
# files: every netmhciipan row is removed, then the existing rows whose label
# matches a listed allele (for the same version) are added back.

import sys
import re
import pandas as pd
import numpy as np
from pathlib import Path
from itertools import product
# Parent of this script's directory (symlinks resolved), prepended at index 1
# of sys.path so modules from that directory can be imported.
PROJECT_DIR = str(Path(__file__).resolve().parents[1])
sys.path.insert(1, PROJECT_DIR)

# Get allele data.
# PARENT_DIR is resolved the same way as PROJECT_DIR so DATA_DIR always equals
# <PROJECT_DIR>/data, regardless of whether __file__ was relative, absolute,
# or reached through a symlink.
PARENT_DIR = Path(__file__).resolve().parent
DATA_DIR = PARENT_DIR.parent / "data"
NETMHCIIPAN_43 = DATA_DIR / "netmhciipan-4.3"
NETMHCIIPAN_42 = DATA_DIR / "netmhciipan-4.2"
NETMHCIIPAN_41 = DATA_DIR / "netmhciipan-4.1"
ORIG_TM_FILE = DATA_DIR / "Tools_MRO_mapping.xlsx"
TOOLS_MAPPING_FILE = DATA_DIR / "tools-mapping.tsv"
MRO_MOLECULES_FILE = DATA_DIR / "mro_molecules.tsv"


def get_netmhciipan_4_1_alleles():
    """Read the NetMHCIIpan 4.1 allele-name list and group names by prefix.

    The first line of ``netmhciipan-4.1-alleles-name.txt`` (under
    NETMHCIIPAN_41) is skipped as a header. Each remaining line is split on
    any run of whitespace, and every non-empty token is filed by its leading
    text:

        DRB -> 'DR', DQA -> 'DQ alpha', DQB -> 'DQ beta',
        DPA -> 'DP alpha', DPB -> 'DP beta', H-2 -> 'Mouse'

    Tokens matching none of these prefixes are ignored.

    Returns:
        dict[str, list[str]]: One list per column above. Every key is always
        present, possibly with an empty list, and entries are in file order.
    """
    allele_file = NETMHCIIPAN_41 / "netmhciipan-4.1-alleles-name.txt"

    # (prefix, column) pairs; the order also fixes the key order of the result.
    prefix_to_column = (
        ('DRB', 'DR'),
        ('DQA', 'DQ alpha'),
        ('DQB', 'DQ beta'),
        ('DPA', 'DP alpha'),
        ('DPB', 'DP beta'),
        ('H-2', 'Mouse'),
    )
    allele_dict = {column: [] for _, column in prefix_to_column}

    with open(allele_file, 'r') as handle:
        rows = handle.readlines()

    # NOTE(review): this splits on single whitespace (\s{1,}), while the 4.2 and
    # 4.3 readers split on runs of 2+ whitespace. Confirm that the difference
    # matches the file layouts.
    for row in rows[1:]:
        for token in re.split(r'\s{1,}', row.strip()):
            if not token:
                continue
            for prefix, column in prefix_to_column:
                if token.startswith(prefix):
                    allele_dict[column].append(token)
                    break

    return allele_dict

def create_proper_netmhciipan_4_1_alleles(single_allele_dict):
    """Turn per-chain 4.1 allele lists into predictor-ready allele lists.

    Args:
        single_allele_dict: Output of ``get_netmhciipan_4_1_alleles``. It must
            contain the keys 'DR', 'Mouse', 'DQ alpha', 'DQ beta', 'DP alpha'
            and 'DP beta'.

    Returns:
        dict: Keys in the order 'DR', 'Mouse', 'DQ', 'DP'. 'DR' and 'Mouse'
        are the input lists themselves, not copies. 'DQ' and 'DP' contain
        every ``"alpha/beta"`` pairing, with alpha as the outer loop.
    """
    def paired(alpha_key, beta_key):
        # Full cross product: every alpha chain with every beta chain.
        return [
            f"{alpha}/{beta}"
            for alpha, beta in product(single_allele_dict[alpha_key], single_allele_dict[beta_key])
        ]

    return {
        'DR': single_allele_dict['DR'],
        'Mouse': single_allele_dict['Mouse'],
        'DQ': paired('DQ alpha', 'DQ beta'),
        'DP': paired('DP alpha', 'DP beta'),
    }



def get_netmhciipan_4_2_alleles():
    """Read the NetMHCIIpan 4.2 allele-name list and group names by prefix.

    The first line of ``netmhciipan-4.2-alleles-name.txt`` (under
    NETMHCIIPAN_42) is skipped as a header. Each remaining line is split on
    runs of two or more whitespace characters, and every non-empty token is
    filed by its leading text:

        DRB -> 'DR', DQA -> 'DQ alpha', DQB -> 'DQ beta',
        DPA -> 'DP alpha', DPB -> 'DP beta', H-2 -> 'Mouse'

    Tokens matching none of these prefixes are ignored.

    Returns:
        dict[str, list[str]]: One list per column above. Every key is always
        present, possibly with an empty list, and entries are in file order.
    """
    allele_file = NETMHCIIPAN_42 / "netmhciipan-4.2-alleles-name.txt"

    # Insertion order here fixes both the lookup priority and the result's key order.
    prefix_columns = {
        'DRB': 'DR',
        'DQA': 'DQ alpha',
        'DQB': 'DQ beta',
        'DPA': 'DP alpha',
        'DPB': 'DP beta',
        'H-2': 'Mouse',
    }
    allele_dict = {column: [] for column in prefix_columns.values()}

    with open(allele_file, 'r') as handle:
        rows = handle.readlines()

    for row in rows[1:]:
        # filter(None, ...) drops the empty strings re.split can produce.
        for token in filter(None, re.split(r'\s{2,}', row.strip())):
            column = next(
                (col for prefix, col in prefix_columns.items() if token.startswith(prefix)),
                None,
            )
            if column is not None:
                allele_dict[column].append(token)

    return allele_dict

def create_proper_netmhciipan_4_2_alleles(single_allele_dict):
    """Turn per-chain 4.2 allele lists into predictor-ready allele lists.

    Args:
        single_allele_dict: Output of ``get_netmhciipan_4_2_alleles``. It must
            contain the keys 'DR', 'Mouse', 'DQ alpha', 'DQ beta', 'DP alpha'
            and 'DP beta'.

    Returns:
        dict: Keys in the order 'DR', 'Mouse', 'DQ', 'DP'. 'DR' and 'Mouse'
        are the input lists themselves, not copies. 'DQ' and 'DP' contain
        every ``"alpha/beta"`` pairing, with alpha as the outer loop.
    """
    proper_dict = {
        'DR': single_allele_dict['DR'],
        'Mouse': single_allele_dict['Mouse'],
    }

    # Each heterodimer locus is expanded into the full alpha x beta cross product.
    for locus, alpha_key, beta_key in (('DQ', 'DQ alpha', 'DQ beta'),
                                       ('DP', 'DP alpha', 'DP beta')):
        alphas = single_allele_dict[alpha_key]
        betas = single_allele_dict[beta_key]
        proper_dict[locus] = [f"{alpha}/{beta}" for alpha, beta in product(alphas, betas)]

    return proper_dict



def get_netmhciipan_4_3_alleles():
    """Read the NetMHCIIpan 4.3 allele-name list and group names by prefix.

    The first line of ``netmhciipan-4.3-alleles-name.txt`` (under
    NETMHCIIPAN_43) is skipped as a header. Each remaining line is split on
    runs of two or more whitespace characters, and every non-empty token is
    filed by its leading text:

        DRB -> 'DR', DQA -> 'DQ alpha', DQB -> 'DQ beta',
        DPA -> 'DP alpha', DPB -> 'DP beta', H-2 -> 'Mouse', BoLA -> 'BoLA'

    Tokens matching none of these prefixes are ignored.

    Returns:
        dict[str, list[str]]: One list per column above. Every key is always
        present, possibly with an empty list, and entries are in file order.
    """
    allele_file = NETMHCIIPAN_43 / "netmhciipan-4.3-alleles-name.txt"

    # Insertion order here fixes both the lookup priority and the result's key order.
    prefix_columns = {
        'DRB': 'DR',
        'DQA': 'DQ alpha',
        'DQB': 'DQ beta',
        'DPA': 'DP alpha',
        'DPB': 'DP beta',
        'H-2': 'Mouse',
        'BoLA': 'BoLA',
    }
    allele_dict = {column: [] for column in prefix_columns.values()}

    with open(allele_file, 'r') as handle:
        rows = handle.readlines()

    for row in rows[1:]:
        # filter(None, ...) drops the empty strings re.split can produce.
        for token in filter(None, re.split(r'\s{2,}', row.strip())):
            column = next(
                (col for prefix, col in prefix_columns.items() if token.startswith(prefix)),
                None,
            )
            if column is not None:
                allele_dict[column].append(token)

    return allele_dict

def create_proper_netmhciipan_4_3_alleles(single_allele_dict):
    """Turn per-chain 4.3 allele lists into predictor-ready allele lists.

    Args:
        single_allele_dict: Output of ``get_netmhciipan_4_3_alleles``. It must
            contain the keys 'DR', 'Mouse', 'BoLA', 'DQ alpha', 'DQ beta',
            'DP alpha' and 'DP beta'.

    Returns:
        dict: Keys in the order 'DR', 'Mouse', 'BoLA', 'DQ', 'DP'. 'DR',
        'Mouse' and 'BoLA' are the input lists themselves, not copies. 'DQ' and
        'DP' contain every ``"alpha/beta"`` pairing, with alpha as the outer
        loop.
    """
    # Single-chain groups pass straight through.
    proper_dict = {
        key: single_allele_dict[key]
        for key in ('DR', 'Mouse', 'BoLA')
    }

    # Heterodimer loci: the full alpha x beta cross product.
    for locus, alpha_key, beta_key in (('DQ', 'DQ alpha', 'DQ beta'),
                                       ('DP', 'DP alpha', 'DP beta')):
        proper_dict[locus] = [
            f"{alpha}/{beta}"
            for alpha, beta in product(single_allele_dict[alpha_key], single_allele_dict[beta_key])
        ]

    return proper_dict




def clean_allele_name(allele: str) -> str:
    """
    Normalize an allele name so differently formatted spellings compare equal.

    The name is lower-cased, then every character outside ``a-z`` and ``0-9``
    (dashes, colons, asterisks, slashes, underscores, spaces, ...) is dropped.

    Args:
        allele (str): Raw allele name, e.g. 'BoLA-1:00901'.

    Returns:
        str: The normalized key, e.g. 'bola100901'.
    """
    lowered = allele.lower()
    return re.sub(r'[^a-z0-9]', '', lowered)


def create_allele_mro_dict(allele_dict):
    """Index every allele label in *allele_dict* by its normalized name.

    Args:
        allele_dict: Mapping of group name to a list of allele labels, for
            example the output of a ``create_proper_netmhciipan_*`` function.

    Returns:
        dict: ``{clean_allele_name(label): {'label': label, 'mro_id': ''}}``.
        If two labels normalize to the same key, the first one encountered
        (in dict order, then list order) wins. ``mro_id`` is always left as
        an empty string here.
    """
    allele_mro_dict = {}
    for labels in allele_dict.values():
        for label in labels:
            # setdefault only inserts when the key is absent, so the first label wins.
            allele_mro_dict.setdefault(
                clean_allele_name(label),
                {'label': label, 'mro_id': ''},
            )
    return allele_mro_dict


def get_valid_netmhciipan_alleles_df(allele_mro_dict, version):
    """Return the tools-mapping rows for one netmhciipan version whose label is a known allele.

    Reads TOOLS_MAPPING_FILE (tab-separated) and keeps the rows where
    ``Tool == 'netmhciipan'`` and ``Tool Version == version``. It then drops
    every row whose normalized ``Tool Label`` (via ``clean_allele_name``) is
    not a key of *allele_mro_dict*.

    Args:
        allele_mro_dict: Mapping keyed by normalized allele name, as built by
            ``create_allele_mro_dict``. Only its keys are consulted.
        version: Value compared with ``==`` against the ``Tool Version``
            column. The script passes floats such as ``4.3``.

    Returns:
        pandas.DataFrame: The surviving rows with all original columns, plus
        two helper columns: ``'Cleaned Label'`` and ``'Remove'`` (always False
        in the result). The script's ``__main__`` section drops both before
        writing the file.
    """
    tools_mapping_df = pd.read_csv(TOOLS_MAPPING_FILE, sep='\t')

    # Filter to netmhciipan rows for the requested version. `.copy()` makes an
    # independent frame, so adding columns below does not write into a slice of
    # the full table (which triggers pandas' SettingWithCopyWarning).
    is_requested = (
        (tools_mapping_df['Tool'] == 'netmhciipan')
        & (tools_mapping_df['Tool Version'] == version)
    )
    tools_mapping_df = tools_mapping_df[is_requested].copy()

    # Normalize labels the same way the allele-file names were normalized.
    tools_mapping_df['Cleaned Label'] = tools_mapping_df['Tool Label'].apply(clean_allele_name)

    # Flag rows whose label is not a known allele. This is a vectorized
    # membership test in place of a per-row iterrows() loop.
    known_alleles = set(allele_mro_dict)
    tools_mapping_df['Remove'] = ~tools_mapping_df['Cleaned Label'].isin(known_alleles)

    # Keep only the rows that were not flagged for removal.
    return tools_mapping_df[~tools_mapping_df['Remove']]


if __name__ == "__main__":
    # Per-version pipeline: (Tool Version value, allele-file reader,
    # alpha/beta pairing builder). These are processed newest first, which is
    # also the order in which their results are printed.
    version_pipelines = [
        (4.3, get_netmhciipan_4_3_alleles, create_proper_netmhciipan_4_3_alleles),
        (4.2, get_netmhciipan_4_2_alleles, create_proper_netmhciipan_4_2_alleles),
        (4.1, get_netmhciipan_4_1_alleles, create_proper_netmhciipan_4_1_alleles),
    ]

    valid_dfs_by_version = {}
    for version, read_alleles, build_alleles in version_pipelines:
        # Raw per-chain lists -> DR/Mouse(/BoLA) lists plus DQ and DP alpha/beta pairs.
        allele_dict = build_alleles(read_alleles())

        # Keyed by normalized name, e.g. for BoLA-1:00901:
        #   {'bola100901': {'label': 'BoLA-1:00901', 'mro_id': ''}}
        # (mro_id is left empty by create_allele_mro_dict).
        allele_mro_dict = create_allele_mro_dict(allele_dict)

        valid_df = get_valid_netmhciipan_alleles_df(allele_mro_dict, version)
        print(valid_df)
        print(len(valid_df))
        valid_dfs_by_version[version] = valid_df

    # Each per-version frame was filtered on its own Tool Version, so the
    # frames cannot overlap. drop_duplicates only removes rows that were
    # already repeated verbatim in the tools-mapping file.
    valid_netmhciipan_alleles_df = pd.concat([
        valid_dfs_by_version[4.1],
        valid_dfs_by_version[4.2],
        valid_dfs_by_version[4.3],
    ]).drop_duplicates()

    # Drop the helper columns added by get_valid_netmhciipan_alleles_df.
    valid_netmhciipan_alleles_df = valid_netmhciipan_alleles_df.drop(columns=['Cleaned Label', 'Remove'])

    print("Combined valid NetMHCIIpan alleles:")
    print(valid_netmhciipan_alleles_df)
    print(f"Total number of unique alleles: {len(valid_netmhciipan_alleles_df)}")

    # Rewrite tools-mapping in place. Every netmhciipan row (of any version) is
    # removed, and the validated rows above are appended at the end.
    # NOTE(review): the whole file round-trips through pandas dtype inference
    # (e.g. a numeric-looking 'Tool Version' like "1.10" would be written back
    # as 1.1). Confirm the file contains no such values.
    tools_mapping_df = pd.read_csv(TOOLS_MAPPING_FILE, sep='\t')
    other_tools_df = tools_mapping_df[tools_mapping_df['Tool'] != 'netmhciipan']
    tools_mapping_df = pd.concat([other_tools_df, valid_netmhciipan_alleles_df])
    tools_mapping_df.to_csv(TOOLS_MAPPING_FILE, sep='\t', index=False)

    print(f"New tools-mapping file saved to: {TOOLS_MAPPING_FILE}")