|
2 | 2 | from apotheosis import Apotheosis |
3 | 3 | from datalayer.hash_algorithm.tlsh_algorithm import TLSHHashAlgorithm |
4 | 4 | from datalayer.node.hash_node import HashNode |
| 5 | +from datalayer.node.winpage_hash_node import WinPageHashNode |
| 6 | +from unittest.mock import patch |
| 7 | +from datalayer.db_manager import DBManager |
| 8 | + |
| 9 | + |
| 10 | + |
5 | 11 |
|
6 | 12 | APOTHEOSIS_HOST = "localhost:5000" |
7 | 13 | HASHES = [ |
@@ -72,5 +78,33 @@ def test_deletion(self): |
72 | 78 |
|
73 | 79 | self.assertEqual(actual_founds, expected_founds) |
74 | 80 |
|
| 81 | + def test_dump_load(self): |
| 82 | + self.apo_model = Apotheosis(M=4, ef=4, Mmax=8, Mmax0=16, |
| 83 | + heuristic=False, extend_candidates=False, keep_pruned_conns=False, |
| 84 | + beer_factor=False, |
| 85 | + distance_algorithm=TLSHHashAlgorithm) |
| 86 | + |
| 87 | + # Mock return value for get_winmodule_data_by_hash |
| 88 | + mock_winmodule = WinPageHashNode("T10381E956C26225F2DAD9D5C2C5C1A337FAF3708A25012B8A1EACDAC00B37D557E0E714", TLSHHashAlgorithm) |
| 89 | + |
| 90 | + with patch.object(DBManager, 'get_winmodule_data_by_hash', return_value=mock_winmodule): |
| 91 | + # Create the nodes based on TLSH Fuzzy Hashes |
| 92 | + hash1 = "T10381E956C26225F2DAD9D5C2C5C1A337FAF3708A25012B8A1EACDAC00B37D557E0E714" |
| 93 | + node1 = WinPageHashNode(hash1, TLSHHashAlgorithm) |
| 94 | + hash2 = "T1458197A3C292D1EC8566C6A2C6516377FA743E0F8120BA49CFD1CF812B66B60D75E316" |
| 95 | + node2 = WinPageHashNode(hash2, TLSHHashAlgorithm) |
| 96 | + |
| 97 | + self.apo_model.insert(node1) |
| 98 | + self.apo_model.insert(node2) |
| 99 | + |
| 100 | + self.apo_model.dump("TestApo", compress=False) |
| 101 | + self.apo_model.load('TestApo', TLSHHashAlgorithm, WinPageHashNode) |
| 102 | + _, exact1, _ = self.apo_model.knn_search(HashNode(hash1, TLSHHashAlgorithm), k=1, ef=4) |
| 103 | + _, exact2, _ = self.apo_model.knn_search(HashNode(hash2, TLSHHashAlgorithm), k=1, ef=4) |
| 104 | + self.assertEqual(exact1.get_id(), hash1) |
| 105 | + self.assertEqual(exact2.get_id(), hash2) |
| 106 | + |
| 107 | + |
| 108 | + |
75 | 109 | if __name__ == '__main__': |
76 | 110 | unittest.main() |
0 commit comments