#!/usr/bin/env python
import os
import json
import shutil
import tempfile
import sys
import math

from sequences import Proteins
from allele_validator import Allele_Validator
# from nxg-tools package
from nxg_common.nxg_common import save_file_from_URI

# Methods that a predictor whose method is 'consensus' is expanded into by
# split_parameters (one job unit per method). When a consensus predictor is
# present, predictors that request one of these methods on their own are
# skipped there.
MHCI_CONSENSUS_METHOD_SET = ('smm', 'ann', 'comblib_sidney2008')

def read_json_file(file_path):
    """Parse the JSON document stored at ``file_path`` and return the result."""
    with open(file_path, 'r') as json_handle:
        parsed = json.load(json_handle)
    return parsed

def save_json(result, output_path):
    """Serialize ``result`` as 2-space indented JSON into ``output_path``.

    The parent directory of ``output_path`` is created first when the path
    has a directory component; an existing file is overwritten.
    """
    parent_dir = os.path.dirname(output_path)
    # A bare file name has no directory part, so there is nothing to create.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(output_path, 'w') as json_handle:
        json.dump(result, json_handle, indent=2)

def get_peptide_list_with_start_and_stop_and_sequencenumber(input_sequence_text, alleles, peptide_length_range='asis'):
    """Build the peptide index table for every allele / sequence / window.

    Args:
        input_sequence_text: sequence text handed to ``get_sequence_list``.
        alleles: iterable of allele names; the table is repeated per allele.
        peptide_length_range: ``[min_length, max_length]`` (inclusive) to
            enumerate every window of those lengths, or a falsy value /
            ``'asis'`` to use each input sequence as a single peptide.

    Returns:
        ``dict(warnings=[], results=[peptide_table])`` where each row of the
        table is ``(sequence_number, peptide, start, end, length, allele,
        peptide_index)``. Sequence numbers and positions are 1-based, ``end``
        is inclusive, and ``peptide_index`` restarts at 1 for every allele.
    """
    rows = []
    records = get_sequence_list(input_sequence_text)
    whole_sequence_mode = (not peptide_length_range) or peptide_length_range == 'asis'
    if whole_sequence_mode:
        window_sizes = ()
    else:
        window_sizes = range(peptide_length_range[0], peptide_length_range[1] + 1)

    for allele in alleles:
        # the running index is counted per allele, across all sequences
        running_index = 0
        for sequence_number, record in enumerate(records, start=1):
            residues = record['sequence']
            if whole_sequence_mode:
                # a single span covering the full sequence
                spans = [(0, len(residues))]
            else:
                # 0-based half-open spans, ordered by length and then offset
                spans = [
                    (offset, offset + size)
                    for size in window_sizes
                    for offset in range(len(residues) - size + 1)
                ]
            for begin, stop in spans:
                running_index += 1
                rows.append((
                    sequence_number,
                    residues[begin:stop],
                    begin + 1,
                    stop,
                    stop - begin,
                    allele,
                    running_index,
                ))

    column_names = [
        "sequence_number",
        "peptide",
        "start",
        "end",
        "length",
        "allele",
        "peptide_index",
    ]
    peptide_table = {
        'type': "peptide_table",
        'table_columns': column_names,
        'table_data': rows,
    }
    return {'warnings': [], 'results': [peptide_table]}

def get_sequence_list(input_sequence_text):
    """Return the parsed input as a list of ``{'name', 'sequence'}`` dicts.

    Parsing is delegated to ``Proteins``; its ``names`` and ``sequences``
    attributes are paired up positionally.
    """
    parsed = Proteins(input_sequence_text)
    return [
        {'name': label, 'sequence': residues}
        for label, residues in zip(parsed.names, parsed.sequences)
    ]

def split_peptides_with_diff_length(input_sequence_text, peptide_length_range):
    """Return the peptides of the input grouped into one list per length.

    With a falsy ``peptide_length_range`` the input sequences themselves are
    the peptides and are grouped by their own length. Otherwise every window
    of each length in the inclusive ``[min_length, max_length]`` range is
    enumerated, giving one list per length (in increasing length order, and
    possibly empty when no sequence is long enough).
    """
    records = get_sequence_list(input_sequence_text)
    if not peptide_length_range:
        # means lengths = ['asis',]
        return get_peptides_with_diff_length([record['sequence'] for record in records])

    shortest, longest = peptide_length_range[0], peptide_length_range[1]
    return [
        [
            record['sequence'][offset:offset + size]
            for record in records
            for offset in range(len(record['sequence']) - size + 1)
        ]
        for size in range(shortest, longest + 1)
    ]



def _batch_peptide_list(peptide_list, batch_size, max_batches):
    """Cut one list of same-length peptides into consecutive, non-empty batches.

    At most ``max_batches`` batches are produced: when ``batch_size`` would
    need more than that, the batch size is enlarged (for this list only) so
    that everything fits into ``max_batches`` batches.
    """
    num_peptides = len(peptide_list)
    num_batches = math.ceil(num_peptides / batch_size)

    # if the peptides do not fit into max_batches batches of batch_size, use
    # max_batches batches and ignore the requested batch size
    if num_batches > max_batches:
        print(f'number of batches ({num_batches}) exceeds the maximum batches per length ({max_batches})')
        num_batches = max_batches
        batch_size = math.ceil(num_peptides / num_batches)

    print(f'num batches: {num_batches}')
    print(f'batch size: {batch_size}')

    batches = []
    # stepping over the list (instead of counting batches) never yields an
    # empty trailing batch when the enlarged batch size over-covers the list
    for index_start in range(0, num_peptides, batch_size):
        index_end = min(index_start + batch_size, num_peptides)
        print(f'index start: {index_start}')
        print(f'index end: {index_end}')
        batches.append(peptide_list[index_start:index_end])
    return batches


def split_peptides_with_diff_length_batch(input_sequence_text,
                                          peptide_length_range,
                                          batch_size=8192,
                                          max_batches=32):
    """Split the input into batches of peptides that all share one length.

    Args:
        input_sequence_text (str): sequence text handed to
            ``get_sequence_list``.
        peptide_length_range (list | None): inclusive ``[min_length,
            max_length]``. When falsy, the input sequences themselves are the
            peptides and are simply grouped by length (no batching).
        batch_size (int, optional): preferred maximum number of peptides per
            batch. Defaults to 8192.
        max_batches (int, optional): budget for the total number of batches.
            It is shared evenly between the lengths; every length always gets
            at least one batch, so the total can exceed the budget when there
            are more lengths than ``max_batches``. Defaults to 32.

    Returns:
        list[list[str]]: non-empty peptide batches, ordered by increasing
        length and then by position in the input. An empty list is returned
        when the length range is empty.

    Raises:
        ValueError: if ``batch_size`` or ``max_batches`` is smaller than 1
            while a length range is given.
    """

    sequence_list = get_sequence_list(input_sequence_text)
    if not peptide_length_range:
        # means lengths = ['asis',]
        peptide_list = [seq['sequence'] for seq in sequence_list]
        return get_peptides_with_diff_length(peptide_list)

    if batch_size < 1 or max_batches < 1:
        raise ValueError('batch_size and max_batches must both be at least 1')

    lengths = range(peptide_length_range[0], peptide_length_range[1] + 1)
    num_lengths = len(lengths)
    if num_lengths == 0:
        # inverted range: nothing to enumerate (and nothing to divide by below)
        return []

    # enumerate every window of every requested length
    peptides_with_diff_length = {}
    total_peptides = 0
    for length in lengths:
        peptide_list = [
            sequence['sequence'][i:i + length]
            for sequence in sequence_list
            for i in range(len(sequence['sequence']) - length + 1)
        ]
        peptides_with_diff_length[length] = peptide_list
        total_peptides += len(peptide_list)

    print(f'total peptides: {total_peptides}')
    print(f'num lengths {num_lengths}')

    # this is the maximum number of peptides that will fit
    # into the max_batches with the given batch size
    max_batch_peptide_limit = max_batches * batch_size
    print(f'max batch peptide limit: {max_batch_peptide_limit}')

    # since we also batch by length, the batch budget is divided between the
    # lengths; never let it drop to 0, which would leave a length unbatchable
    max_batches_per_length = max(1, max_batches // num_lengths)

    batched_peptide_list = []
    for length in lengths:
        print(f'batching length: {length}')
        # each length starts again from the caller's batch_size; an
        # enlargement needed for one length must not affect the next one
        batched_peptide_list.extend(
            _batch_peptide_list(peptides_with_diff_length[length],
                                batch_size,
                                max_batches_per_length)
        )

    return batched_peptide_list


def transfer_fasta_to_peptide_file(input_sequence_text_file_path, input_length=''):
    """Write the peptides derived from a sequence file to a new temp file.

    Args:
        input_sequence_text_file_path: path of the sequence text file to read.
        input_length: comma-separated peptide lengths (e.g. ``'9,10'``). When
            empty, each input sequence is written as-is.

    Returns:
        Path of the temporary file holding one peptide per line. The file is
        not deleted automatically.
    """
    with open(input_sequence_text_file_path, 'r') as source:
        sequence_text = source.read()
    records = get_sequence_list(sequence_text)

    if input_length:
        window_sizes = [int(token) for token in input_length.split(',')]
        peptides = [
            record['sequence'][offset:offset + size]
            for size in window_sizes
            for record in records
            for offset in range(len(record['sequence']) - size + 1)
        ]
    else:
        # means lengths = ['asis',]
        peptides = [record['sequence'] for record in records]

    with tempfile.NamedTemporaryFile(mode='w', delete=False) as sink:
        sink.write('\n'.join(peptides))
        peptide_file_name = sink.name
    return peptide_file_name

def get_peptides_with_diff_length(peptide_list):
    """Group peptides by length.

    Returns a list of lists, one per distinct length, ordered by the first
    appearance of each length; peptides keep their input order in each group.
    """
    groups_by_length = {}
    for peptide in peptide_list:
        size = len(peptide)
        if size not in groups_by_length:
            groups_by_length[size] = []
        groups_by_length[size].append(peptide)
    return list(groups_by_length.values())

def split_parameters(input_data, split_inputs_dir=None):
    """Expand one request into single-allele / single-method / single-length job units.

    Args:
        input_data (dict): the request. Keys read here:
            ``input_sequence_fasta_uri``, ``input_sequence_text``,
            ``peptide_list`` or ``peptide_file_path`` (the input source,
            checked in that order), ``peptide_length_range`` (optional
            inclusive ``[min, max]``), ``alleles`` (comma-separated string)
            and ``predictors`` (list of dicts with a ``type`` and usually a
            ``method``). NOTE: the first three input-source keys are popped,
            i.e. removed from the caller's dict, so they are not copied into
            every job unit.
        split_inputs_dir (str | None): directory in which the generated
            sequence / peptide files are created (``None`` lets ``tempfile``
            pick its default location).

    Returns:
        tuple: ``(output_data, sequence_peptide_index, has_consensus)`` where
        ``output_data`` is the list of job-unit dicts,
        ``sequence_peptide_index`` is the table built by
        ``get_peptide_list_with_start_and_stop_and_sequencenumber`` and
        ``has_consensus`` tells whether a binding/consensus predictor was
        requested.

    Raises:
        ValueError: if no input source is given, or if a peptide list mixes
            peptides of different lengths.
    """
    output_data = []
    # TODO: how to deal with the  temp files, when should we remove them?
    to_delete = []
    peptide_file_path_length_pairs = []
    input_sequence_text = input_data.pop('input_sequence_text','')
    peptide_length_range = input_data.get('peptide_length_range', None)
    peptide_list = input_data.pop('peptide_list', '')
    input_sequence_fasta_uri = input_data.pop('input_sequence_fasta_uri', '')
    alleles = input_data.get('alleles').split(',')
    predictors = input_data.get('predictors')
    allele_validator = Allele_Validator()
    if input_sequence_fasta_uri:
        input_sequence_text_file_path = save_file_from_URI(input_sequence_fasta_uri, target_dir=split_inputs_dir)
        to_delete.append(input_sequence_text_file_path)
        with open(input_sequence_text_file_path, 'r') as r_file:
            input_sequence_text = r_file.read()
        # must have peptide_length_ranged
        peptide_lists_with_diff_len = split_peptides_with_diff_length_batch(input_sequence_text, peptide_length_range)

        sequence_peptide_index = get_peptide_list_with_start_and_stop_and_sequencenumber(input_sequence_text, alleles, peptide_length_range)
    elif input_sequence_text:
        with tempfile.NamedTemporaryFile(mode='w', dir=split_inputs_dir, delete=False) as tmp_peptides_file:
            input_sequence_text_file_path = tmp_peptides_file.name
            to_delete.append(input_sequence_text_file_path)
            tmp_peptides_file.write(input_sequence_text)
        # must have peptide_length_ranged
        peptide_lists_with_diff_len = split_peptides_with_diff_length_batch(input_sequence_text, peptide_length_range)

        sequence_peptide_index = get_peptide_list_with_start_and_stop_and_sequencenumber(input_sequence_text, alleles, peptide_length_range)

    elif peptide_list:
        peptide_lists_with_diff_len = get_peptides_with_diff_length(peptide_list)
        sequence_peptide_index = get_peptide_list_with_start_and_stop_and_sequencenumber('\n'.join(peptide_list), alleles, peptide_length_range='asis')
    elif input_data.get('peptide_file_path', None):
        # 'r' is the mode of open(), not a fallback value for the dict lookup
        with open(input_data.get('peptide_file_path'), 'r') as r_f:
            peptide_list = r_f.read().split()
        peptide_lists_with_diff_len = get_peptides_with_diff_length(peptide_list)
        sequence_peptide_index = get_peptide_list_with_start_and_stop_and_sequencenumber('\n'.join(peptide_list), alleles, peptide_length_range='asis')
    else:
        # without this the code below fails on unbound local variables
        raise ValueError('no input given: expected one of input_sequence_fasta_uri, input_sequence_text, peptide_list or peptide_file_path')

    # write one peptide file per same-length peptide list
    for same_length_peptides in peptide_lists_with_diff_len:
        peptide_lengths = set(map(len, same_length_peptides))
        if len(peptide_lengths) > 1:
            raise ValueError('peptides should be splitted into peptide list with same length')
        if not peptide_lengths:
            # TODO: add warnings?
            #warnings.add('one of the peptide lengths is not working with input sequences')
            # nothing to predict for this list, so do not create a file for it
            continue
        with tempfile.NamedTemporaryFile(mode='w', dir=split_inputs_dir, delete=False) as tmp_peptides_file:
            fname = tmp_peptides_file.name
            to_delete.append(fname)
            tmp_peptides_file.write('\n'.join(same_length_peptides))
        peptide_file_path_length_pairs.append((fname, peptide_lengths.pop()))

    # for consensus
    has_consensus = any(predictor['type']=='binding' and predictor['method']=='consensus' for predictor in predictors)
    # for basic processing, remove the binding predictor it require
    # (it is added again below, right before the basic_processing unit)
    basic_predictors = [predictor for predictor in predictors if predictor['type']=='processing' and predictor['method']=='basic_processing']
    if basic_predictors:
        binding_method = basic_predictors[0]['mhc_binding_method']
        predictors = [p for p in predictors if p['type']!='binding' or p['method']!=binding_method]
    for allele in alleles:
        for predictor in predictors:
            # for immunogenicity: skip invalid alleles if mask_choice == by_allele
            predictor_type = predictor.get('type','')
            if predictor_type == 'immunogenicity' and predictor.get('mask_choice','') == 'by_allele' and not allele_validator.validate_alleles(allele, 'immunogenicity'):
                # TODO: add warnings for this
                print("Warning: allele {} is not available for immunogenicity.".format(allele))
                continue

            method = predictor.get('method','')
            if not method:
                method = predictor_type
            # recommended_epitope and recommended_binding are netmhcpan method alias
            method = method.replace('recommended_epitope','netmhcpan_el').replace('recommended_binding','netmhcpan_ba')
            # for consensus
            if method == 'consensus':
                methods = MHCI_CONSENSUS_METHOD_SET
            else:
                # consensus already runs these methods, do not run them twice
                if has_consensus and method in MHCI_CONSENSUS_METHOD_SET:
                    continue
                else:
                    methods = (method,)
            for method in methods:
                if method and method not in [ 'netchop', 'immunogenicity']:
                    if method == 'basic_processing':
                        # TODO: run validation for its mhc_binding_method for basic_processing
                        binding_method = predictor['mhc_binding_method']
                        if not allele_validator.validate_alleles(allele, binding_method):
                            print('warning: invalid allele %s for binding_method %s' % (allele, binding_method))
                            continue
                    elif not allele_validator.validate_alleles(allele, method):
                        # TODO: add warnings e.g. allele Mamu-A2*05:10 is not available for method smmpmbec
                        print('warning: invalid allele %s for method %s' % (allele, method))
                        continue
                data_unit = input_data.copy()
                unit_predictor = predictor.copy()
                unit_predictor['method'] = method
                data_unit['alleles'] = allele
                data_unit['predictors'] = [unit_predictor]
                output_data.append(data_unit)
                # for basic processing, add the binding predictor it require
                if method == 'basic_processing':
                    binding_method = unit_predictor['mhc_binding_method'].replace('recommended_epitope','netmhcpan_el').replace('recommended_binding','netmhcpan_ba')
                    data_unit = input_data.copy()
                    data_unit['alleles'] = allele
                    data_unit['predictors'] = [{"type": "binding", "method": binding_method}]
                    # insert(-1) places the binding unit right before the
                    # basic_processing unit that was just appended
                    output_data.insert(-1, data_unit)

    if peptide_file_path_length_pairs:
        new_output_data =[]
        for peptide_file_path, peptide_length in peptide_file_path_length_pairs:
            for data_unit in output_data:
                # do nothing for netchop/netctl/netctlpan here
                if 'method' in data_unit['predictors'][0] and data_unit['predictors'][0]['method'] in ['netchop', 'netctl', 'netctlpan', 'basic_processing']:
                    continue
                # do nothing for mhcnp too
                if 'method' in data_unit['predictors'][0] and data_unit['predictors'][0]['type'] == 'mhcnp':
                    continue
                # TODO: update this after immunogenicity is added to allele_validator
                # TODO: discuss how to deal with basic_processing if some lengths work but some do not
                if data_unit['predictors'][0].get('type','') != 'immunogenicity':
                    valid_dict, invalid_dict = allele_validator.validate_allele_lengths(data_unit['alleles'], [str(peptide_length)], data_unit['predictors'][0]['method'])
                    if not valid_dict:
                        # TODO: add warnings e.g. length 12 is not available for allele H2-Db and method smmpmbec
                        print('warning: invalid length %s for allele %s and method %s' % (peptide_length ,data_unit['alleles'], data_unit['predictors'][0]['method']))
                        continue
                new_data_unit = data_unit.copy()
                new_data_unit['peptide_file_path'] = peptide_file_path
                new_data_unit['peptide_length_range'] = [peptide_length, peptide_length]
                new_output_data.append(new_data_unit)
        # to add predictors for netchop/netctl/netctlpan here
        # only peptide_list will not work for netchop/netctl/netctlpan
        if input_sequence_text:
            netchop_job_number = 0
            for data_unit in output_data:
                # not need to split for netchop so only 1 job for netchop should be generated
                if 'method' in data_unit['predictors'][0] and data_unit['predictors'][0]['method'] == 'netchop':
                    netchop_job_number += 1
                    if netchop_job_number > 1:
                        continue
                # skip basic_processing/netchop/netctl/netctlpan if peptide_length_range is None (length asis means the input sequences are peptide_list)
                # .get(): the key is optional in the request, a missing key means "asis" as well
                if (data_unit['predictors'][0]['type'] == 'mhcnp' or 'method' in data_unit['predictors'][0] and data_unit['predictors'][0]['method'] in ['netchop', 'netctl', 'netctlpan', 'basic_processing']) and data_unit.get('peptide_length_range'):
                    # split with single length for basic_processing
                    if data_unit['predictors'][0]['method'] == 'basic_processing':
                        for length in range(peptide_length_range[0], peptide_length_range[1]+1):
                            new_data_unit = data_unit.copy()
                            new_data_unit['peptide_length_range'] = [length, length]
                            new_data_unit['input_sequence_text_file_path'] = input_sequence_text_file_path
                            new_output_data.append(new_data_unit)
                    else:
                        new_data_unit = data_unit.copy()
                        new_data_unit['input_sequence_text_file_path'] = input_sequence_text_file_path
                        new_output_data.append(new_data_unit)
        output_data = new_output_data
    return output_data, sequence_peptide_index, has_consensus


def split_parameters_file(json_filename, parameters_output_dir=None, split_inputs_dir=None, assume_valid=False, keep_empty_row=False):
    """Split a request JSON file into per-job parameter files plus a job list.

    The following files are written (``<base>`` is the parent directory of
    ``parameters_output_dir``):

    * ``<parameters_output_dir>/<job_id>.json`` - parameters of one job
    * ``<base>/job_descriptions.json`` - every prediction job followed by one
      aggregate job that depends on all of them (an empty list when the
      request produces no job)
    * ``<base>/results/sequence_peptide_index.json`` - the peptide index table

    Args:
        json_filename (str): path of the request JSON file.
        parameters_output_dir (str | None): directory for the per-job
            parameter files. An existing directory is deleted and recreated.
            When omitted, a fresh ``job_<random>/splitted_parameters``
            directory next to this module's parent directory is used.
        split_inputs_dir (str | None): directory for the generated sequence /
            peptide files; defaults to ``parameters_output_dir``.
        assume_valid (bool): append ``--assume-valid`` to every prediction
            command.
        keep_empty_row (bool): append ``--keep-empty-row`` to the aggregate
            command.
    """
    # function-scope import: only needed for the fallback directory name
    import uuid

    input_data = read_json_file(json_filename)

    if parameters_output_dir:
        # recreate the directory
        if os.path.exists(parameters_output_dir):
            shutil.rmtree(parameters_output_dir)
    else:
        # if not given, create a folder for it; this has to be resolved before
        # any os.path call, which does not accept None
        parameters_output_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, 'job_%s/splitted_parameters' % uuid.uuid4().hex[:6]))
    # if not given, use parameters_output_dir for generated sequence files as well
    if not split_inputs_dir:
        split_inputs_dir = parameters_output_dir
    # create dir if not exist:
    os.makedirs(parameters_output_dir, exist_ok=True)
    os.makedirs(split_inputs_dir, exist_ok=True)

    # get splitted parameters, sequence_peptide_index table, and has_consensus flag
    output_data, sequence_peptide_index, has_consensus = split_parameters(input_data, split_inputs_dir)

    parameters_output_dir = os.path.abspath(parameters_output_dir)
    result_output_dir = os.path.abspath(os.path.join(parameters_output_dir, os.pardir, 'results'))
    aggregate_dir = os.path.abspath(os.path.join(parameters_output_dir, os.pardir, 'aggregate'))
    base_dir = os.path.abspath(os.path.join(parameters_output_dir, os.pardir))
    job_descriptions_path = os.path.abspath(os.path.join(base_dir, 'job_descriptions.json'))
    sequence_peptide_index_path = os.path.abspath(os.path.join(result_output_dir, 'sequence_peptide_index.json'))
    mhci_predict_executable_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'tcell_mhci.py'))

    os.makedirs(result_output_dir, exist_ok=True)
    job_descriptions = []
    aggregate_depends_on_job_ids = []
    job_id = -1
    for i, data_unit in enumerate(output_data):
        job_id = i
        data_unit_file_path = os.path.join(parameters_output_dir, '%d.json' % i)
        save_json(data_unit, data_unit_file_path)
        shell_cmd='%s -j %s/%s.json -o %s/%s -f json' % (mhci_predict_executable_path, parameters_output_dir, job_id, result_output_dir, job_id)
        if assume_valid:
            shell_cmd += ' --assume-valid'
        job_description = dict(
            shell_cmd=shell_cmd,
            job_id=job_id,
            job_type="prediction",
            depends_on_job_ids=[],
            expected_outputs=['%s/%s.json' % (result_output_dir, job_id)]
        )
        aggregate_depends_on_job_ids.append(job_id)
        job_descriptions.append(job_description)
    print('parameters_output_dir: %s' % os.path.abspath(parameters_output_dir))

    # add aggregate job
    # job_id == -1 means no job is required to run
    if job_id > -1:
        # the aggregate job takes the id right after the last prediction job
        job_id +=1
        shell_cmd='%s --aggregate --job-desc-file=%s --aggregate-input-dir=%s --aggregate-result-dir=%s' % (mhci_predict_executable_path, job_descriptions_path, result_output_dir, aggregate_dir)
        if has_consensus:
            shell_cmd += " --has-consensus"
        if keep_empty_row:
            shell_cmd += " --keep-empty-row"
        aggregate_job_description = dict(
            shell_cmd=shell_cmd,
            job_id=job_id,
            job_type="aggregate",
            depends_on_job_ids=aggregate_depends_on_job_ids,
            expected_outputs=['%s/aggregated_result.json' % aggregate_dir],
        )
        job_descriptions.append(aggregate_job_description)
    save_json(job_descriptions, job_descriptions_path)
    save_json(sequence_peptide_index, sequence_peptide_index_path)
    print('job_descriptions_path: %s' % os.path.abspath(job_descriptions_path))

if __name__ == '__main__':
    # Manual smoke run of the batch splitter with deliberately tiny limits,
    # so that the batching of a small test input is easy to inspect.
    length_range = [9, 10]
    batch_size, max_batches = 5, 4

    # get input sequence text from the test data
    with open('test_data/input_sequence_text.1fasta', 'r') as fasta_handle:
        ist = fasta_handle.read()

    # split into batches, by length as well as by batch size
    # (split_peptides_with_diff_length(ist, length_range) is the older
    # variant that splits on length only)
    sp = split_peptides_with_diff_length_batch(
        ist,
        length_range,
        batch_size=batch_size,
        max_batches=max_batches,
    )
