import sys
import shutil
import string
import random
import re
import time
import textwrap
import pandas as pd
import numpy as np
import warnings
from pathlib import Path
from tqdm import tqdm
from datetime import date, datetime


# ===============================================================================================
# Global Vars.
# ===============================================================================================
# Directories, all located relative to the directory that contains this script.
LOG_DIR = Path(__file__).parent.joinpath("report")
DATA_DIR = Path(__file__).parent.joinpath("data")
BACKUP_DIR = DATA_DIR.joinpath("backup")

# Names of the TSV files; the functions in this file join them onto DATA_DIR.
TOOLS_MAPPING_FILE = "tools-mapping.tsv"
MOLECULE_FILE = "mro_molecules.tsv"
MHC_ALLELES_FILE = "mhc_alleles.tsv"
ALLELE_DATASOURCE_FILE = "allele_datasource.tsv"


def double_check_single_mhcii_alleles(mhcii_df):
    '''Find single chains of paired HLA-DQ/DP alleles that have no row of their own.

    Every 'IEDB Label' that names both chains of a DQ pair (DQA + DQB) or of a
    DP pair (DPA + DPB) is split into its alpha and beta chain. Each chain is
    then looked up among the 'IEDB Label' values of the same DataFrame.

    Parameters :
      - mhcii_df : DataFrame that has an 'IEDB Label' column.

    Return Value :
      - Tuple (unmapped_alpha, unmapped_beta). Both are lists of 'HLA-'
        prefixed chain names that are not present as an 'IEDB Label'.
        Each list is de-duplicated and keeps first-seen order, so the
        result is identical from run to run.
    '''
    prefix = 'HLA-'

    # Labels that are not strings (e.g. NaN from an empty cell) cannot be parsed.
    labels = [label for label in mhcii_df['IEDB Label'].to_list() if isinstance(label, str)]

    # Set for O(1) membership tests in the loop below.
    known_labels = set(labels)

    # Dicts are used as insertion-ordered sets (keys only, values unused).
    unmapped_alpha = {}
    unmapped_beta = {}

    for allele in labels:
        allele_upper = allele.upper()
        is_dq_pair = 'DQA' in allele_upper and 'DQB' in allele_upper
        is_dp_pair = 'DPA' in allele_upper and 'DPB' in allele_upper

        # Only alleles that carry both an alpha and a beta chain are of interest.
        if not (is_dq_pair or is_dp_pair):
            continue

        # Only 'HLA-' prefixed pairs are broken up; anything else is ignored.
        if not allele.startswith(prefix):
            continue

        # Break the pair up manually: 'HLA-<alpha>/<beta>' -> [<alpha>, <beta>]
        parts = re.split(r'[/\-]', allele.replace(prefix, ''))
        if len(parts) != 2:
            print(f"Skipping: {allele} (unexpected format)")
            continue

        alpha = prefix + parts[0]
        beta = prefix + parts[1]

        if alpha not in known_labels:
            unmapped_alpha[alpha] = None

        if beta not in known_labels:
            unmapped_beta[beta] = None

    return list(unmapped_alpha), list(unmapped_beta)

def _is_single_mhcii_chain(label):
    '''Return True when an MHC-II label is a lone DQ/DP chain, or a DRB chain without DRA.'''
    # A missing label (NaN) cannot be classified.
    if not isinstance(label, str):
        return False

    # Edge case allele: a DQA chain paired with a DRB chain is not a single chain.
    if label == 'HLA-DQA1*01:02/DRB1*15:01':
        return False

    label_upper = label.upper()
    has_dqa = 'DQA' in label_upper
    has_dqb = 'DQB' in label_upper
    has_dpa = 'DPA' in label_upper
    has_dpb = 'DPB' in label_upper
    has_drb = 'DRB' in label_upper
    has_dra = 'DRA' in label_upper

    # Exactly one chain of a DQ pair, or exactly one chain of a DP pair.
    if has_dqa != has_dqb or has_dpa != has_dpb:
        return True

    # NOTE: It seems there are DRA/DRB paired alleles.
    # Make sure only the DRB chain is accepted.
    return has_drb and not has_dra


def add_phbr_alleles():
    '''===============================================================================================
        Description :
          Add PHBR rows to the autocomplete datasource (allele_datasource.tsv).

          1. All MHCI alleles should be compatible with PHBR, so every row of
             allele_datasource.tsv whose 'Tool Group' is 'mhci' is duplicated
             with 'Tool Group' set to 'phbr'.
          2. Every single-chain MHCII allele (DQA/DQB/DPA/DPB/DRB) found in
             mhc_alleles.tsv is added as a label row, plus one row per synonym.
          3. Alpha/beta chains of paired DQ/DP alleles that have no row of their
             own in mhc_alleles.tsv are added as unobserved 'human' rows.

        Parameters :
          - None

        Return Value :
          - None. The TSV file (allele_datasource.tsv) is rewritten in place.
    ==============================================================================================='''
    aa_df = pd.read_csv(DATA_DIR / ALLELE_DATASOURCE_FILE, skipinitialspace=True, sep='\t', encoding='utf-8')
    mol_df = pd.read_csv(DATA_DIR / MHC_ALLELES_FILE, skipinitialspace=True, sep='\t', encoding='utf-8')

    # ================== Add MHCI alleles to PHBR ==================
    # Since all MHCI alleles should be compatible with PHBR, the MHCI rows are
    # duplicated with 'Tool Group' replaced by 'phbr'. The explicit copy() makes
    # this an independent frame instead of an assignment on a slice of aa_df.
    phbr_mhci_df = aa_df[aa_df['Tool Group'] == 'mhci'].copy()
    phbr_mhci_df['Tool Group'] = 'phbr'

    # ============ Add single alpha/beta chains to PHBR =============
    # Filter to only contain MHCII
    filtered_mhcii_df = mol_df[mol_df['Parent'] == 'MHC class II protein complex']
    mhcii_df_headers = list(filtered_mhcii_df.columns)

    # Column positions are looked up once instead of once per row.
    label_idx = mhcii_df_headers.index('IEDB Label')
    synonyms_idx = mhcii_df_headers.index('Synonyms')
    taxon_idx = mhcii_df_headers.index('In Taxon')
    mroid_idx = mhcii_df_headers.index('MRO ID')

    rows_to_add_ii = []

    # NOTE: We want to add all single chains (no pairs; those are skipped).
    for row in tqdm(filtered_mhcii_df.itertuples(name=None, index=False)):
        curr_iedb_allele = row[label_idx]
        if not _is_single_mhcii_chain(curr_iedb_allele):
            continue

        curr_synonyms = row[synonyms_idx]
        species = row[taxon_idx]

        # MRO IDs starting with 'NOMRO' are flagged as unobserved.
        is_unobserved = 1 if str(row[mroid_idx]).startswith('NOMRO') else 0

        rows_to_add_ii.append({
            "Alleles": curr_iedb_allele,
            "Complement": curr_synonyms,
            "Species": species,
            "Tool Group": 'phbr',
            "Is Label": 1,
            "Unobserved": is_unobserved
        })

        # For all alleles that have synonyms, we need to create separate rows
        # for the allele datasource file.
        if pd.isna(curr_synonyms):
            continue

        for synonym in str(curr_synonyms).split("|"):
            synonym = synonym.strip()
            # An empty fragment (e.g. from a trailing '|') is not a synonym.
            if not synonym:
                continue

            rows_to_add_ii.append({
                "Alleles": synonym,
                "Complement": curr_iedb_allele,
                "Species": species,
                "Tool Group": 'phbr',
                "Is Label": 0,
                "Unobserved": is_unobserved
            })

    # This will return any unmapped alpha and beta chains (meaning no row, hence NO MRO ID)
    alpha_chains, beta_chains = double_check_single_mhcii_alleles(filtered_mhcii_df)

    # Alpha chains first, then beta chains.
    for chain in [*alpha_chains, *beta_chains]:
        rows_to_add_ii.append({
            "Alleles": chain,
            "Complement": '',
            "Species": 'human',
            "Tool Group": 'phbr',
            "Is Label": 0,
            "Unobserved": 1
        })

    # Add the PHBR rows to the datasource: MHCI duplicates first, then MHCII chains.
    frames = [aa_df, phbr_mhci_df]
    if rows_to_add_ii:
        frames.append(pd.DataFrame(rows_to_add_ii))
    combined_df = pd.concat(frames, ignore_index=True)

    # Write to a file
    combined_df.to_csv(DATA_DIR / ALLELE_DATASOURCE_FILE, sep='\t', index=False)


def add_icerfire_alleles():
    '''===============================================================================================
        Description :
          Add the alleles of the 'icerfire' tool to the autocomplete datasource.
          For every row of tools-mapping.tsv whose 'Tool' is 'icerfire', the
          allele metadata is looked up in mhc_alleles.tsv by 'MRO ID'. One label
          row is appended to allele_datasource.tsv, plus one row per synonym.

        Parameters :
          - None

        Return Value :
          - None. The TSV file (allele_datasource.tsv) is rewritten in place.
    ==============================================================================================='''
    mol_df = pd.read_csv(DATA_DIR / MHC_ALLELES_FILE, skipinitialspace=True, sep='\t', encoding='utf-8')
    tm_df = pd.read_csv(DATA_DIR / TOOLS_MAPPING_FILE, skipinitialspace=True, sep='\t', encoding='utf-8')
    aa_df = pd.read_csv(DATA_DIR / ALLELE_DATASOURCE_FILE, skipinitialspace=True, sep='\t', encoding='utf-8')

    icerfire_df = tm_df[tm_df['Tool'] == 'icerfire']
    tm_df_headers = list(tm_df.columns)
    mroid_idx = tm_df_headers.index('MRO ID')
    tool_group_idx = tm_df_headers.index('Tool Group')

    # Rows are collected and concatenated once at the end; growing the
    # DataFrame with pd.concat inside the loop copies it on every iteration.
    new_rows = []

    for row in tqdm(icerfire_df.itertuples(name=None, index=False)):
        curr_mroid = row[mroid_idx]
        tool_group = row[tool_group_idx]

        # Currently it is safe to assume that 'mol_df' will have unique MRO ID.
        mroid_metadata_df = mol_df[mol_df['MRO ID'] == curr_mroid]
        if mroid_metadata_df.empty:
            # No metadata for this MRO ID, so no row can be built for it.
            print(f"Skipping: {curr_mroid} (not found in {MHC_ALLELES_FILE})")
            continue

        iedb_label = mroid_metadata_df['IEDB Label'].iloc[0]
        species = mroid_metadata_df['In Taxon'].iloc[0]
        curr_synonyms = mroid_metadata_df['Synonyms'].iloc[0]

        new_rows.append({
            "Alleles": iedb_label,
            "Complement": curr_synonyms,
            "Species": species,
            "Tool Group": tool_group,
            "Is Label": 1,
            "Unobserved": 0
        })

        # An empty 'Synonyms' cell is read as NaN, which is truthy and has no
        # split(); only a non-empty string carries synonyms.
        if not isinstance(curr_synonyms, str) or not curr_synonyms:
            continue

        # Add additional rows for each synonym
        for target_synonym in curr_synonyms.split("|"):
            target_synonym = target_synonym.strip()
            if not target_synonym:
                continue

            new_rows.append({
                "Alleles": target_synonym,
                "Complement": iedb_label,
                "Species": species,
                "Tool Group": tool_group,
                "Is Label": 0,
                "Unobserved": 0
            })

    if new_rows:
        aa_df = pd.concat([aa_df, pd.DataFrame(new_rows)], ignore_index=True)

    aa_df.to_csv(DATA_DIR / ALLELE_DATASOURCE_FILE, sep='\t', index=False)



def create_autocomplete_datasource():
    '''===============================================================================================
        Description :
          This function prepares the datasource file that is used for Allele Autocomplete.
          The 'Alleles' column holds every IEDB Label and, as separate rows, every synonym.
          Alleles whose 'Predictor Availability' is 0 are not included in the datasource.

          The tool group of an allele is derived from its 'Parent':
            - contains 'non-classical' : all tool groups listed for its MRO ID in tools-mapping.tsv
            - contains 'mhc class ii'  : 'mhcii'
            - contains 'mhc class i'   : 'mhci'
          An allele without any tool group produces no rows.

        Parameters :
          - None

        Return Value :
          - None. The TSV file (allele_datasource.tsv) is (re)created.
    ==============================================================================================='''
    tm_df = pd.read_csv(DATA_DIR / TOOLS_MAPPING_FILE, skipinitialspace=True, sep='\t', encoding='utf-8')
    mol_df = pd.read_csv(DATA_DIR / MHC_ALLELES_FILE, skipinitialspace=True, sep='\t', encoding='utf-8')
    mol_df["Synonyms"] = mol_df["Synonyms"].fillna("")

    # Rows are read positionally from itertuples(), so the column positions
    # are resolved once up front.
    mol_df_headers = list(mol_df.columns)
    availability_idx = mol_df_headers.index('Predictor Availability')
    synonyms_idx = mol_df_headers.index('Synonyms')
    mroid_idx = mol_df_headers.index('MRO ID')
    parent_idx = mol_df_headers.index('Parent')
    label_idx = mol_df_headers.index('IEDB Label')
    taxon_idx = mol_df_headers.index('In Taxon')

    datasource_columns = ["Alleles", "Complement", "Species", "Tool Group", "Is Label", "Unobserved"]

    # Collected as plain dicts and turned into a DataFrame once at the end;
    # pd.concat inside the loop would copy the whole frame for every row.
    datasource_rows = []

    for ref_row in tqdm(mol_df.itertuples(name=None, index=False)):
        # Only include those with Predictor Availability
        # NOTE: This also implies that they have entry in 'tools-mapping.tsv'
        if ref_row[availability_idx] == 0:
            continue

        ref_label = ref_row[label_idx]
        ref_species = ref_row[taxon_idx]
        ref_mroid = ref_row[mroid_idx]
        ref_row_synonyms = ref_row[synonyms_idx]

        # MRO IDs starting with 'NOMRO' are flagged as unobserved.
        is_unobserved = 1 if str(ref_mroid).startswith('NOMRO') else 0

        # 'mhc class i' is also a substring of 'mhc class ii', hence the order
        # of the checks: the most specific match wins.
        parent = str(ref_row[parent_idx]).lower()
        if 'non-classical' in parent:
            # Search tool_group from tools-mapping file
            tool_groups = tm_df.loc[tm_df['MRO ID'] == ref_mroid, 'Tool Group'].unique().tolist()
        elif 'mhc class ii' in parent:
            tool_groups = ['mhcii']
        elif 'mhc class i' in parent:
            tool_groups = ['mhci']
        else:
            tool_groups = []

        # One label row per tool group.
        for tool_group in tool_groups:
            datasource_rows.append({
                "Alleles": ref_label,
                "Complement": ref_row_synonyms,
                "Species": ref_species,
                "Tool Group": tool_group,
                "Is Label": 1,
                "Unobserved": is_unobserved
            })

        if not ref_row_synonyms:
            continue

        synonyms = [part.strip() for part in ref_row_synonyms.split("|") if part.strip()]

        # Add additional rows for each synonym, for every tool group of this
        # allele (and none at all when the allele has no tool group).
        for tool_group in tool_groups:
            for synonym in synonyms:
                datasource_rows.append({
                    "Alleles": synonym,
                    "Complement": ref_label,
                    "Species": ref_species,
                    "Tool Group": tool_group,
                    "Is Label": 0,
                    "Unobserved": is_unobserved
                })

    aa_df = pd.DataFrame(datasource_rows, columns=datasource_columns)
    aa_df.to_csv(DATA_DIR / ALLELE_DATASOURCE_FILE, sep='\t', index=False)

    

def write_to_log(mod_entries, new_entries):
    '''Write a human-readable log of modified and added rows.

    The log is written to LOG_DIR as 'data-update-<YYYYmmdd-HHMMSS>.log'.

    Parameters :
      - mod_entries : iterable of dicts with the keys 'columns', 'prev' and 'next'.
      - new_entries : iterable of added rows; each one is written as-is.

    Return Value :
      - None
    '''
    timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')

    # open(..., 'w') does not create missing directories.
    LOG_DIR.mkdir(parents=True, exist_ok=True)

    # Explicit encoding so that the log does not depend on the platform default.
    with open(LOG_DIR / f'data-update-{timestamp}.log', 'w', encoding='utf-8') as f:
        if mod_entries:
            f.write('--- MODIFIED ROWS ---\n')
            for entry in mod_entries:
                f.write(f"Changed columns: {entry['columns']}\n")
                f.write(f"Old row: {entry['prev']}\n")
                f.write(f"New row: {entry['next']}\n")
                f.write('-' * 40 + '\n')

        if new_entries:
            f.write('\n--- ADDED ROWS ---\n')
            for entry in new_entries:
                f.write(f"New row: {entry}\n")
                f.write('-' * 40 + '\n')

def diff_between_rows(df_a: pd.DataFrame, df_b: pd.DataFrame) -> dict:
    '''Compare the first row of two DataFrames over the columns they share.

    Parameters :
      - df_a : DataFrame whose first row is the reference.
      - df_b : DataFrame whose first row is compared against the reference.

    Return Value :
      - Dict mapping a column name to a tuple (value_a, value_b) for every
        shared column that differs.

        'Synonyms' is handled as a '|' separated set:
          * NaN in df_a                 -> never reported
          * NaN in df_b (but not df_a)  -> (value_a, NaN)
          * otherwise                   -> value_a is reduced to the synonyms
            that are missing from df_b, joined with '|' in the order in which
            they appear in df_a; nothing is reported when none is missing.

        For every other column a pair of NaN values is not a difference,
        except for 'IEDB Label'.
    '''
    row_a = df_a.iloc[0]
    row_b = df_b.iloc[0]

    differences = {}

    for col in row_a.index.intersection(row_b.index):
        val_a = row_a[col]
        val_b = row_b[col]

        # Special handling for 'Synonyms'
        if col == 'Synonyms':
            if pd.isna(val_a):
                continue  # Nothing to contribute from df_a

            if pd.isna(val_b):
                differences[col] = (val_a, val_b)
                continue

            known = {part.strip() for part in str(val_b).split('|')}

            # A list (not a set) keeps the order of df_a, so that the joined
            # string is the same on every run.
            missing = []
            for part in str(val_a).split('|'):
                part = part.strip()
                # Empty fragments (e.g. from a trailing '|') are not synonyms.
                if part and part not in known and part not in missing:
                    missing.append(part)

            if missing:
                differences[col] = ('|'.join(missing), val_b)

            continue  # Skip further checks for 'Synonyms'

        # Skip if both are NaN, except for IEDB Label
        if col != 'IEDB Label' and pd.isna(val_a) and pd.isna(val_b):
            continue

        if val_a != val_b:
            differences[col] = (val_a, val_b)

    return differences


def update_mhc_alleles():
    '''
    Bring mhc_alleles.tsv up to date with the newly pulled mro_molecules.tsv.

    NOTE:
    - tools-mapping.tsv doesn't really change.
    - mro_molecules.tsv will get updated regularly and pulled.
    - mhc_alleles.tsv will only contain unique rows.

    Thus, we have to compare mro_molecules.tsv against the existing mhc_alleles.tsv

    For every row of mro_molecules.tsv:
    - MRO ID already present : missing synonyms are appended, any other differing
      column is replaced by the value from mro_molecules.tsv.
    - MRO ID not present     : a new row is added with 'Predictor Availability' 1.

    When anything changed, a log file is written and the previous
    mhc_alleles.tsv is copied to the backup directory. mhc_alleles.tsv itself is
    rewritten in any case (with whitespace-stripped synonyms).

    NOTE(review): new rows always get 'Predictor Availability' 1; tools-mapping.tsv
    is not consulted here -- confirm this is intended.
    '''
    # mhc_alleles.tsv
    mhc_df = pd.read_csv(DATA_DIR / MHC_ALLELES_FILE, skipinitialspace=True, sep='\t')
    mhc_df_copy = mhc_df.copy(deep=True)
    # mro_molecules.tsv
    mol_df = pd.read_csv(DATA_DIR / MOLECULE_FILE, skipinitialspace=True, sep='\t')

    mol_df_header = list(mol_df.columns)
    mroid_idx = mol_df_header.index('MRO ID')

    # Set for O(1) membership tests instead of scanning the column per row.
    known_mroids = set(mhc_df['MRO ID'])

    # List of dictionaries
    additional_entries = []
    modified_entries = []

    # (index label in mhc_df, updated row) pairs to apply to mhc_df_copy.
    row_updates = []

    # Check each row of mro_molecules.tsv and see if anything new needs to be
    # added to the mhc_alleles file.
    for mol_row in tqdm(mol_df.itertuples(name=None, index=False)):
        mroid = mol_row[mroid_idx]

        # MRO ID not found in mhc_alleles.tsv -- need to add new row
        if mroid not in known_mroids:
            additional_entries.append({
                'MRO ID': mroid,
                'IEDB Label': mol_row[mol_df_header.index('IEDB Label')],
                'Synonyms': mol_row[mol_df_header.index('Synonyms')],
                'In Taxon': mol_row[mol_df_header.index('In Taxon')],
                'Parent': mol_row[mol_df_header.index('Parent')],
                'Predictor Availability': 1
            })
            continue

        # MRO ID found in mhc_alleles.tsv -- check what needs to be updated
        matching_rows = mhc_df[mhc_df['MRO ID'] == mroid]
        matching_rows_cpy = matching_rows.copy(deep=True)

        row_df = pd.DataFrame([mol_row], columns=mol_df_header)
        diffs = diff_between_rows(row_df, matching_rows)

        # move on if no difference was found
        if not diffs:
            continue

        for col, (val_a, val_b) in diffs.items():
            if col == 'Synonyms' and not pd.isna(val_b):
                # Need to update synonyms (appending the missing ones)
                matching_rows_cpy[col] = '|'.join([*val_b.split('|'), val_a])
            else:
                # val_a is data from mro_molecules.tsv, thus the ground truth.
                # This also covers 'Synonyms' when the existing row had none
                # (NaN), where there is nothing to append to.
                matching_rows_cpy[col] = val_a

        # keep track of all the columns that have been changed
        changed_columns = ','.join(diffs)

        # keep track of modified entries
        for idx in matching_rows.index:
            updated_row = matching_rows_cpy.loc[idx]
            modified_entries.append(
                {
                    'columns': changed_columns,
                    'prev': matching_rows.loc[idx].to_dict(),
                    'next': updated_row.to_dict()
                }
            )
            row_updates.append((idx, updated_row))

    # Replace old rows with their modified version in mhc_df_copy.
    # mhc_df_copy shares its index with mhc_df; the Series aligns on column names.
    for idx, updated_row in row_updates:
        mhc_df_copy.loc[idx] = updated_row

    # Add new rows to the end of mhc_df_copy
    final_mhc_df = pd.concat([mhc_df_copy, pd.DataFrame(additional_entries)], ignore_index=True)

    # Ensure all the synonyms don't have whitespace/newlines attached
    # If synonym is NaN, then leave it untouched
    final_mhc_df['Synonyms'] = final_mhc_df['Synonyms'].apply(
        lambda x: '|'.join(part.strip() for part in x.split('|')) if pd.notna(x) else x
    )

    # write a human-readable text log + back up the original file
    if modified_entries or additional_entries:
        write_to_log(modified_entries, additional_entries)

        # Create backup of the original MHC_ALLELES_FILE before overriding
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        backup_file = f'mhc_alleles_backup_{timestamp}.tsv'
        # copyfile() does not create missing directories.
        BACKUP_DIR.mkdir(parents=True, exist_ok=True)
        shutil.copyfile(DATA_DIR / MHC_ALLELES_FILE, BACKUP_DIR / backup_file)

    # write to file
    final_mhc_df.to_csv(DATA_DIR / MHC_ALLELES_FILE, sep='\t', index=False)


if __name__ == "__main__":
    started_at = time.time()

    # Step 1: compare the newly pulled 'mro_molecules.tsv' file against the
    # existing 'mhc_alleles.tsv', and update anything that needs to be updated.
    #
    # A logfile is created under the <report> directory only when changes are
    # detected, and backup files are created only in that case as well.
    update_mhc_alleles()

    # Step 2: create the initial autocomplete datasource file.
    create_autocomplete_datasource()

    # Step 3: add ICERFIRE (PVC) to the autocomplete datasource.
    add_icerfire_alleles()

    # Step 4: add PHBR to the autocomplete datasource.
    add_phbr_alleles()

    elapsed = time.time() - started_at
    print(f'Time taken: {elapsed}')