'''
Unit tests for smm_predictor
'''
import unittest
from smm_predictor import SMMMatrix, single_prediction_smm, single_prediction_smmpmbec

class SMMTests(unittest.TestCase):
    '''
    Regression tests for the 'smm' prediction method.

    Both tests score the same five 9-residue peptides for HLA-A*01:01 and
    check the output against the same reference values: once through
    single_prediction_smm and once through an SMMMatrix instance.
    '''

    ALLELE_NAME = 'HLA-A*01:01'
    # The predictor is called with the length as a string in these tests.
    LENGTH = '9'
    # Kept as a tuple so the shared fixture cannot be altered by a test;
    # every test hands the predictor its own fresh list.
    PEPTIDES = ('CSANNSHHY', 'LTDLGLLYT', 'FSDQIEQEA', 'QSSINISGY', 'LRDLMGVPY')
    EXPECTED_SCORES = (
        173.60011158813074,
        242.9685153759396,
        360.20521816589303,
        380.6711814678423,
        421.25978328719043,
    )
    # Number of decimal places that must agree. Bit-exact float equality is
    # brittle across platforms and math-library versions.
    PLACES = 6

    def _assert_scores_close(self, actual, expected):
        '''
        Assert that two score sequences agree element-wise.

        Fails if the sequences differ in length or if any pair of values
        differs once rounded to PLACES decimal places.
        '''
        self.assertEqual(len(actual), len(expected))
        for index, (got, want) in enumerate(zip(actual, expected)):
            self.assertAlmostEqual(
                got, want, places=self.PLACES,
                msg='score %d differs: %r != %r' % (index, got, want))

    def test_basic(self):
        '''single_prediction_smm returns the reference scores.'''
        result = single_prediction_smm(
            self.ALLELE_NAME, self.LENGTH, list(self.PEPTIDES))

        self._assert_scores_close(result, self.EXPECTED_SCORES)

    def test_matrix(self):
        '''An initialized SMMMatrix returns the same reference scores.'''
        predictor = SMMMatrix(method_name='smm')
        predictor.initialize(self.ALLELE_NAME, self.LENGTH)
        result = predictor.predict_peptide_list(list(self.PEPTIDES))

        self._assert_scores_close(result, self.EXPECTED_SCORES)

class SMMPMBECTests(unittest.TestCase):
    '''
    Regression test for single_prediction_smmpmbec.

    Scores five 9-residue peptides for HLA-A*01:01 and checks the output
    against stored reference values.
    '''

    def test_basic(self):
        '''single_prediction_smmpmbec returns the reference scores.'''
        allele_name = 'HLA-A*01:01'
        # The predictor is called with the length as a string.
        length = '9'
        peptide_list = ['CSANNSHHY', 'LTDLGLLYT', 'FSDQIEQEA', 'QSSINISGY', 'LRDLMGVPY']

        result = single_prediction_smmpmbec(allele_name, length, peptide_list)
        expected_result = (
            197.888246404065,
            237.91400068041455,
            380.5572499081934,
            481.3045862525536,
            428.9631640256637,
        )

        # Compare element-wise with a tolerance: bit-exact float equality is
        # brittle across platforms and math-library versions.
        self.assertEqual(len(result), len(expected_result))
        for index, (got, want) in enumerate(zip(result, expected_result)):
            self.assertAlmostEqual(
                got, want, places=6,
                msg='score %d differs: %r != %r' % (index, got, want))

# Run every test case in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
