import json
import csv
import edlib
import os
import sys
import subprocess
import tempfile
import textwrap
import preprocess
import postprocess
import pandas as pd
import warnings
import utils
from io import StringIO
import core.set_pythonpath  # This automatically configures PYTHONPATH
from PhbrArgumentParser import PhbrArgumentParser
from validators import InputManager, MHCClass, validate_alleles
from pathlib import Path
from typing import Dict, List, Union, Optional, Any, Tuple
from itertools import product
from dotenv import load_dotenv
# Populate os.environ from a .env file (python-dotenv) before APP_ROOT is read below.
load_dotenv()

# Add the nxg-tools directory to the Python path so that nxg_common can be imported.
# NOTE(review): os.getenv('APP_ROOT') returns None when APP_ROOT is unset, and
# os.path.join(None, ...) then raises TypeError at import time. Confirm that APP_ROOT
# is always provided by the environment or by the .env file.
sys.path.insert(0, os.path.join(os.getenv('APP_ROOT'), 'libs', 'nxg-tools'))
from nxg_common import nxg_common as common


def create_fasta_string(sequence_table_df: pd.DataFrame) -> str:
    """
    Create a FASTA-formatted string from the sequence table.

    Column names are normalised first. Every occurrence of the substring
    'peptide' in a column name is replaced with 'sequence', so
    ['peptide name', 'peptide', 'mutpos'] becomes
    ['sequence name', 'sequence', 'mutpos'].

    The residues to emit are chosen as follows:

    * If a 'mut_sequence' column exists (mutant-vs-reference input), it is
      copied into 'sequence'.
    * Otherwise, if there is no 'sequence' column, the first column whose
      lower-cased name contains 'sequence' or 'peptide' is copied into
      'sequence'. Columns whose lower-cased name ends in 'name' (e.g.
      'sequence name') are skipped, because they hold record names rather
      than residues.

    Each row becomes one record. The FASTA header is the row's
    'sequence name' when that column exists. Otherwise it is
    'sequence <index + 1>', built from the row's index label.

    NOTE: the caller's DataFrame is modified in place: columns are renamed and
    a 'sequence' column may be added. This is kept for callers that may rely
    on it. For example, create_sequence_id_dict reads a 'sequence' column.

    :param sequence_table_df: input table; modified in place.
    :return: FASTA text with one '>header\\nsequence\\n' record per row
        ("" for a table without rows).
    :raises KeyError: if the table has rows but no usable sequence column.
    """
    print("Function 'create_fasta_string' called.")
    print(sequence_table_df)

    sequence_table_df.columns = [
        col.replace('peptide', 'sequence') for col in sequence_table_df.columns.tolist()
    ]

    print("Columns after replacement:", sequence_table_df.columns.tolist())

    # Handle mutation peptide case where we have mut_sequence and ref_sequence
    if 'mut_sequence' in sequence_table_df.columns:
        # For mutation peptide data, use mut_sequence as the sequence
        sequence_table_df['sequence'] = sequence_table_df['mut_sequence']
        print("Using mut_sequence as sequence")
    elif 'sequence' not in sequence_table_df.columns:
        print("No sequence column found, available columns:", sequence_table_df.columns.tolist())
        # Try to find any column that might contain the sequence
        for col in sequence_table_df.columns:
            lowered = col.lower()
            # 'sequence name' / 'peptide name' hold FASTA headers, not residues.
            if lowered.endswith('name'):
                continue
            if 'sequence' in lowered or 'peptide' in lowered:
                sequence_table_df['sequence'] = sequence_table_df[col]
                print(f"Using {col} as sequence")
                break

    if 'sequence' not in sequence_table_df.columns and len(sequence_table_df) > 0:
        raise KeyError(
            "No sequence column found; available columns: "
            f"{sequence_table_df.columns.tolist()}"
        )

    # Decide once whether names are available; every row has the same columns.
    has_names = 'sequence name' in sequence_table_df.columns
    records = []
    for index, row in sequence_table_df.iterrows():
        header = row['sequence name'] if has_names else f"sequence {index+1}"
        records.append(f">{header}\n{row['sequence']}\n")
    fasta_string = "".join(records)

    print("Fasta string: ", fasta_string)
    return fasta_string


def add_sequence_number_to_sequence_table(sequence_table_df: pd.DataFrame) -> pd.DataFrame:
    """
    Ensure the sequence table has a 1-based 'seq #' column as its first column.

    A table that already has 'seq #' is returned unchanged (the same object).
    Otherwise 'seq #' = 1..len(table) is added to the caller's frame in place,
    and a copy reordered with 'seq #' first is returned.
    """
    label = 'seq #'
    if label not in sequence_table_df.columns:
        sequence_table_df[label] = range(1, len(sequence_table_df) + 1)
        remaining = [name for name in sequence_table_df.columns if name != label]
        sequence_table_df = sequence_table_df[[label, *remaining]]
    return sequence_table_df


def save_json_to_temporary_file(json_data: dict, output_dir: str) -> str:
    """
    Serialise *json_data* (indent=4) into a new temporary '.json' file in
    *output_dir* and return the file's path.

    The file is created with delete=False. It stays on disk after this
    function returns, and the caller is responsible for removing it.
    """
    target = tempfile.NamedTemporaryFile(mode='w+', suffix='.json', delete=False, dir=output_dir)
    with target:
        json.dump(json_data, target, indent=4)
    return target.name

def save_df_to_temporary_file(df: pd.DataFrame, output_dir: str) -> str:
    """
    Write *df* as a tab-separated file (no index column) to a new temporary
    '.tsv' file in *output_dir* and return its path.

    The file is created with delete=False. It persists after this function
    returns, and the caller is responsible for removing it.
    """
    # Create the temp file and close it again right away, then let pandas
    # write to it by name. Opening a NamedTemporaryFile a second time by name
    # while it is still open is not supported on Windows.
    with tempfile.NamedTemporaryFile(mode='w+', suffix='.tsv', delete=False, dir=output_dir) as tmpfile:
        path = tmpfile.name
    df.to_csv(path, index=False, sep="\t")
    return path

def create_peptide_and_sequence_files(df: pd.DataFrame, output_dir: str) -> Tuple[str, str]:
    """
    Write the 'peptide_table' and 'input_sequence_table' entries of *df* to two
    new temporary TSV files in *output_dir*.

    In both tables the 'sequence_number' column, if present, is renamed to
    'seq #' before writing. The files are not deleted automatically; the
    caller is responsible for removing them.

    :param df: object indexable by 'peptide_table' and 'input_sequence_table',
        each yielding a DataFrame. It is annotated as pd.DataFrame, but it is
        presumably a dict of DataFrames (e.g. from uri_to_df); confirm with
        the callers.
    :param output_dir: directory in which the temporary files are created.
    :return: (peptide_table_path, sequence_table_path)
    """
    renamed_columns = {'sequence_number': 'seq #'}
    # Select and rename before touching the filesystem, so a missing table
    # raises without leaving an empty temp file behind.
    peptide_df = df['peptide_table'].rename(columns=renamed_columns)
    peptide_table_path = save_df_to_temporary_file(peptide_df, output_dir)

    input_seq_df = df['input_sequence_table'].rename(columns=renamed_columns)
    sequence_table_path = save_df_to_temporary_file(input_seq_df, output_dir)

    return peptide_table_path, sequence_table_path


def uri_to_df(uri: str) -> pd.DataFrame:
    """
    Fetch the resource behind *uri* and parse its 'peptide_table' and
    'input_sequence_table' tables.

    Both steps are delegated to nxg_common. save_file_from_URI presumably
    stores the resource locally and returns a path. api_results_json_to_df
    parses that file.

    NOTE(review): create_peptide_and_sequence_files indexes a similar object
    by these table names, which suggests the result is a mapping of table
    name to DataFrame rather than the annotated pd.DataFrame. Confirm this.
    """
    local_file = common.save_file_from_URI(uri)
    wanted_tables = ['peptide_table', 'input_sequence_table']
    return common.api_results_json_to_df(local_file, table_types=wanted_tables)


def create_sequence_id_dict(sequence_table_path: Union[str, pd.DataFrame]) -> dict:
    """
    Build a {sequence: sequence number} lookup from a sequence table.

    *sequence_table_path* may be a DataFrame or a path to a TSV file. The
    number comes from the 'sequence_number' column if present, otherwise
    from 'seq #', otherwise it is the 1-based row position. If a sequence
    appears more than once, its last occurrence wins.
    """
    if isinstance(sequence_table_path, pd.DataFrame):
        table = sequence_table_path
    else:
        table = pd.read_csv(sequence_table_path, sep="\t")

    sequences = table['sequence']
    for number_col in ('sequence_number', 'seq #'):
        if number_col in table.columns:
            return dict(zip(sequences, table[number_col]))

    # Neither numbering column exists: fall back to 1-based row positions.
    return dict(zip(sequences, range(1, len(sequences) + 1)))

def find_file_path(start_dir=None, filename=None):
    """
    Find the full path of *filename* by searching upwards, then downwards,
    from *start_dir*.

    The upward search checks *start_dir* and each ancestor directory, up to
    and including the filesystem root. If that finds nothing, the directory
    tree below *start_dir* is walked with os.walk.

    :param start_dir: directory to start from; defaults to the current working
        directory.
    :param filename: bare file name to look for (required).
    :return: absolute path of the first match, or None if not found.
    :raises ValueError: if *filename* is None.
    """
    if start_dir is None:
        start_dir = os.getcwd()  # Default to the current working directory

    if filename is None:
        raise ValueError("Filename must be provided")

    # Normalize to absolute path
    start_dir = os.path.abspath(start_dir)

    # Check upwards. Test the directory first and stop only after the root has
    # been checked; at the root, dirname() returns the root itself.
    current_dir = start_dir
    while True:
        target_path = os.path.join(current_dir, filename)
        if os.path.isfile(target_path):
            return target_path
        parent_dir = os.path.dirname(current_dir)
        if parent_dir == current_dir:
            break
        current_dir = parent_dir

    # Check downwards
    for root, _, files in os.walk(start_dir):
        if filename in files:
            return os.path.join(root, filename)

    return None  # Not found in either direction

def find_ranking_column_name(peptide_table_path: str, method: str) -> Optional[str]:
    """
    Return the name of the percentile column to rank predictions by.

    Only the header row of the TSV at *peptide_table_path* is read. The first
    column whose lower-cased name contains both *method* and 'percentile' is
    returned; *method* is compared case-insensitively. If there is none, the
    first column containing 'percentile' is used as a fallback (e.g. when
    '<method> percentile' is absent).

    :return: the column name, or None (after printing a warning) if the table
        has no percentile column at all.
    """
    header = pd.read_csv(peptide_table_path, sep='\t', nrows=0).columns.tolist()
    # Column names are lower-cased before matching, so the method must be too;
    # otherwise e.g. 'NetMHCpan' would never match and we would silently fall
    # back to some other method's percentile column.
    method_lower = method.lower()

    percentile_cols = [col for col in header if 'percentile' in col.lower()]
    matching_cols = [col for col in percentile_cols if method_lower in col.lower()]

    if matching_cols:
        return matching_cols[0]
    if percentile_cols:
        return percentile_cols[0]

    print(f"Warning: No percentile column found for method '{method}'. Available columns: {header}")
    print("Skipping processing as no suitable data is available.")
    return None


def remove_empty_rank_rows(mhc_pred_file: str) -> None:
    """
    Drop rows whose 'rank' value is the placeholder '-' and rewrite the TSV
    file at *mhc_pred_file* in place, without an index column.
    """
    predictions = pd.read_csv(mhc_pred_file, sep='\t')
    has_rank = predictions['rank'] != '-'
    cleaned = predictions.loc[has_rank].reset_index(drop=True)
    cleaned.to_csv(mhc_pred_file, sep='\t', index=False)

def format_message(message: str) -> str:
    """
    Normalise a (typically triple-quoted) message into a single line.

    Common indentation is removed and the text is stripped. textwrap.fill
    then expands tabs and turns embedded newlines and other whitespace
    characters into spaces.

    The wrapped lines are re-joined with a single space, so words on either
    side of a wrap point stay separated. Joining with '' would glue them
    together ("wordword").

    Breaks are only allowed at whitespace: long words are not split and there
    is no breaking at hyphens. This means re-joining with a space never
    inserts a space inside a word such as a long file path.
    """
    # Remove common leading whitespace from the string / indentation
    message = textwrap.dedent(message).strip()

    wrapped = textwrap.fill(
        message,
        width=80,
        break_long_words=False,
        break_on_hyphens=False,
    )
    # Every line break sits where whitespace was dropped, so restore one space.
    return wrapped.replace('\n', ' ')

def add_sequence_number_to_phbr_output(phbr_output_df: pd.DataFrame, sequence_to_seqnum_dict: Dict[str, int]) -> pd.DataFrame:
    """
    Attach each row's sequence number and rename 'peptide' to 'mutant peptide'.

    'seq #' is looked up from the 'peptide' column via
    *sequence_to_seqnum_dict*. Per Series.map, peptides missing from the dict
    get NaN. The 'seq #' column is also added to the caller's frame in place.
    The returned frame has 'seq #' as its first column.
    """
    seq_col = 'seq #'
    phbr_output_df[seq_col] = phbr_output_df['peptide'].map(sequence_to_seqnum_dict)
    renamed = phbr_output_df.rename(columns={'peptide': 'mutant peptide'})
    ordered = [seq_col, *(name for name in renamed.columns if name != seq_col)]
    return renamed[ordered]

# NOTE: add_original_reference_sequence_to_phbr_output is defined once, further
# down in this module, after map_mut_peptide_ref_peptide. A byte-identical
# earlier copy was removed from here: Python silently lets the later
# definition win, so that copy was dead code and a trap for whoever edited it.

def filter_homozygous_loci_columns_from_phbr_output(phbr_output_df: pd.DataFrame, mhc_class: MHCClass) -> pd.DataFrame:
    """
    Drop the per-locus '#<locus>' columns from the PHBR output.

    For MHCClass.MHCI this drops '#A', '#B' and '#C'. For MHCClass.MHCII it
    drops '#DP', '#DQ' and '#DR'. Any other *mhc_class* returns the frame
    unchanged. Per the function name, these columns presumably flag
    homozygous loci; confirm with the PHBR output producer.

    Raises KeyError (from DataFrame.drop) if an expected column is missing.
    """
    if mhc_class == MHCClass.MHCI:
        loci_columns = ['#A', '#B', '#C']
    elif mhc_class == MHCClass.MHCII:
        loci_columns = ['#DP', '#DQ', '#DR']
    else:
        return phbr_output_df

    return phbr_output_df.drop(columns=loci_columns)

def map_mut_peptide_ref_peptide(sequence_table_df: pd.DataFrame) -> Dict[Tuple[int, str], str]:
    """
    Map (1-based row number, mutant peptide) to the reference peptide.

    The mutant and reference columns are located with
    get_mut_peptide_column_name / get_ref_peptide_column_name (the first
    column containing 'mut'/'ref' and 'pep'). This is the same rule
    turn_mut_vs_ref_table_into_sequence_table uses, so both functions agree
    on which columns hold the peptides. Row numbers are positional
    (1..len(table)) and do not depend on the DataFrame index.

    :raises KeyError: if either peptide column cannot be found.
    """
    mut_peptide_colname = get_mut_peptide_column_name(sequence_table_df)
    ref_peptide_colname = get_ref_peptide_column_name(sequence_table_df)
    if mut_peptide_colname is None or ref_peptide_colname is None:
        raise KeyError(
            "Expected mutant and reference peptide columns (names containing "
            f"'mut'/'ref' and 'pep'); found: {sequence_table_df.columns.tolist()}"
        )

    peptide_pairs = zip(sequence_table_df[mut_peptide_colname], sequence_table_df[ref_peptide_colname])
    return {(row_number, mut): ref for row_number, (mut, ref) in enumerate(peptide_pairs, start=1)}

def add_original_reference_sequence_to_phbr_output(phbr_output_df: pd.DataFrame, mut_peptide_ref_peptide_dict: Dict[Tuple[int, str], str]) -> pd.DataFrame:
    """
    Add a 'ref_peptide' column to the PHBR output.

    Each row's reference peptide is looked up in
    *mut_peptide_ref_peptide_dict* by ('seq #', 'mutant peptide'). When no
    entry exists, the mutant peptide itself is used.

    The column is added to the caller's frame in place. The returned frame is
    reordered to 'seq #', 'mutant peptide', 'ref_peptide', followed by the
    remaining columns in their original order. An empty frame is handled;
    DataFrame.apply(axis=1) does not return a Series there, and assigning
    its result to a column fails.
    """
    phbr_output_df['ref_peptide'] = [
        mut_peptide_ref_peptide_dict.get((seq_num, mut_peptide), mut_peptide)
        for seq_num, mut_peptide in zip(phbr_output_df['seq #'], phbr_output_df['mutant peptide'])
    ]

    # Reorder columns so that 'ref_peptide' comes right after 'mutant peptide'
    leading = ['seq #', 'mutant peptide', 'ref_peptide']
    cols = leading + [col for col in phbr_output_df.columns if col not in leading]
    return phbr_output_df[cols]

def get_mut_peptide_column_name(sequence_table_df: pd.DataFrame) -> Optional[str]:
    """
    Return the first column name containing both 'mut' and 'pep'
    (case-sensitive), or None if there is no such column.

    Non-string column labels (e.g. integer positions) are skipped rather than
    raising TypeError on the substring test.
    """
    return next(
        (
            header
            for header in sequence_table_df.columns
            if isinstance(header, str) and "mut" in header and "pep" in header
        ),
        None,
    )

def get_ref_peptide_column_name(sequence_table_df: pd.DataFrame) -> Optional[str]:
    """
    Return the first column name containing both 'ref' and 'pep'
    (case-sensitive), or None if there is no such column.

    Non-string column labels (e.g. integer positions) are skipped rather than
    raising TypeError on the substring test.
    """
    return next(
        (
            header
            for header in sequence_table_df.columns
            if isinstance(header, str) and "ref" in header and "pep" in header
        ),
        None,
    )

def find_mutation_positions(mut_peptide: str, ref_peptide: str) -> str:
    """
    Return the 1-based positions where *mut_peptide* differs from *ref_peptide*.

    The peptides are globally aligned with edlib ('NW' mode). Every alignment
    column that edlib's nice alignment marks with '.' is reported.

    :return: a comma-separated string such as '3,8'. Returns 'all' when the
        mismatch columns are exactly 1..len(ref_peptide), and '' when nothing
        is marked. (The return value is a string; it was previously mis-annotated
        as List[int].)

    NOTE(review): positions are alignment-column indices. They equal positions
    in *mut_peptide* only when the alignment has no gaps. Indel columns
    (presumably marked '-' by edlib) are not reported; confirm this is intended.
    """
    alignment = edlib.getNiceAlignment(
        edlib.align(mut_peptide, ref_peptide, mode='NW', task="path"),
        mut_peptide,
        ref_peptide,
    )['matched_aligned']
    mismatch_columns = [i for i, symbol in enumerate(alignment) if symbol == '.']

    # The original length check is implied by comparing against the full range.
    if mismatch_columns == list(range(len(ref_peptide))):
        return 'all'

    return ','.join(str(i + 1) for i in mismatch_columns)


def turn_mut_vs_ref_table_into_sequence_table(sequence_table_df: pd.DataFrame) -> pd.DataFrame:
    '''
    Turn the mut_vs_ref_table into a sequence table.

    The input table has mutant and reference peptide columns (located with
    get_mut_peptide_column_name / get_ref_peptide_column_name, e.g.
    'mut_peptide' and 'ref_peptide'). The output table has 'sequence' (the
    mutant peptide) and 'mutation_position' (from find_mutation_positions)
    columns. The reference column is dropped. Other columns are kept, and
    'mutation_position' is appended last.
    From:
          mut_peptide ref_peptide
        0   FLYNPLTRV   FLYNLLTRV
        1   MLGERLFPL   MLGEQLFPL
        2   FLDEFMEAV   FLDEFMEGV
        3   VVLSWAPPV   VVMSWAPRV
    To:
            sequence mutation_position
        0  FLYNPLTRV                 5
        1  MLGERLFPL                 5
        2  FLDEFMEAV                 8
        3  VVLSWAPPV               3,8

    The input DataFrame is not modified; a new DataFrame is returned.

    :raises KeyError: if the mutant or reference peptide column is missing.
    '''
    mut_col_name = get_mut_peptide_column_name(sequence_table_df)
    ref_col_name = get_ref_peptide_column_name(sequence_table_df)
    if mut_col_name is None or ref_col_name is None:
        raise KeyError(
            "Expected mutant and reference peptide columns (names containing "
            f"'mut'/'ref' and 'pep'); found: {sequence_table_df.columns.tolist()}"
        )

    # A plain comprehension also works for an empty table. There,
    # DataFrame.apply(axis=1) does not return a Series, and assigning its
    # result to a column fails.
    mutation_positions = [
        find_mutation_positions(mut, ref)
        for mut, ref in zip(sequence_table_df[mut_col_name], sequence_table_df[ref_col_name])
    ]

    return (
        sequence_table_df
        .assign(mutation_position=mutation_positions)
        .drop(columns=[ref_col_name])
        .rename(columns={mut_col_name: 'sequence'})
    )

def format_phbr_values(phbr_output_df: pd.DataFrame) -> pd.DataFrame:
    """
    Render the PHBR scores as strings with four decimal places.

    The PHBR score is assumed to be the last column of *phbr_output_df*.
    Every value in that column is converted with float() and formatted as
    '%.4f'. The column is replaced in the caller's frame, and the same frame
    is returned.
    """
    def _four_decimals(value):
        return format(float(value), '.4f')

    score_col = phbr_output_df.columns[-1]
    phbr_output_df[score_col] = phbr_output_df[score_col].apply(_four_decimals)
    return phbr_output_df