|
2 | 2 | from apotheosis import Apotheosis |
3 | 3 | from datalayer.hash_algorithm.tlsh_algorithm import TLSHHashAlgorithm |
4 | 4 | from datalayer.node.hash_node import HashNode |
| 5 | +from datalayer.node.winpage_hash_node import WinPageHashNode |
| 6 | +from unittest.mock import patch |
| 7 | +from datalayer.db_manager import DBManager |
| 8 | + |
| 9 | + |
| 10 | + |
5 | 11 |
|
6 | 12 | APOTHEOSIS_HOST = "localhost:5000" |
7 | 13 | HASHES = [ |
@@ -72,5 +78,39 @@ def test_deletion(self): |
72 | 78 |
|
73 | 79 | self.assertEqual(actual_founds, expected_founds) |
74 | 80 |
|
| 81 | + def mock_get_winmodule_data_by_hash(self, algorithm, hash_value): |
| 82 | + """Mock function for get_winmodule_data_by_hash""" |
| 83 | + return WinPageHashNode(hash_value, algorithm) |
| 84 | + |
| 85 | + def test_dump_load(self): |
| 86 | + self.apo_model = Apotheosis(M=4, ef=4, Mmax=8, Mmax0=16, |
| 87 | + heuristic=False, extend_candidates=False, |
| 88 | + keep_pruned_conns=False, beer_factor=False, |
| 89 | + distance_algorithm=TLSHHashAlgorithm) |
| 90 | + |
| 91 | + hash1 = "T10381E956C26225F2DAD9D5C2C5C1A337FAF3708A25012B8A1EACDAC00B37D557E0E714" |
| 92 | + hash2 = "T1458197A3C292D1EC8566C6A2C6516377FA743E0F8120BA49CFD1CF812B66B60D75E316" |
| 93 | + |
| 94 | + node1 = WinPageHashNode(hash1, TLSHHashAlgorithm) |
| 95 | + node2 = WinPageHashNode(hash2, TLSHHashAlgorithm) |
| 96 | + |
| 97 | + with patch.object(DBManager, 'get_winmodule_data_by_hash', side_effect=self.mock_get_winmodule_data_by_hash) as mock_method: |
| 98 | + self.apo_model.insert(node1) |
| 99 | + self.apo_model.insert(node2) |
| 100 | + |
| 101 | + self.apo_model.dump("TestApo", compress=False) |
| 102 | + self.apo_model.load('TestApo', TLSHHashAlgorithm, WinPageHashNode) |
| 103 | + |
| 104 | + _, exact1, _ = self.apo_model.knn_search(HashNode(hash1, TLSHHashAlgorithm), k=1, ef=4) |
| 105 | + _, exact2, _ = self.apo_model.knn_search(HashNode(hash2, TLSHHashAlgorithm), k=1, ef=4) |
| 106 | + |
| 107 | + self.assertEqual(exact1.get_id(), hash1) |
| 108 | + self.assertEqual(exact2.get_id(), hash2) |
| 109 | + |
| 110 | + mock_method.assert_any_call(TLSHHashAlgorithm, hash1) |
| 111 | + mock_method.assert_any_call(TLSHHashAlgorithm, hash2) |
| 112 | + |
| 113 | + |
| 114 | + |
75 | 115 | if __name__ == '__main__': |
76 | 116 | unittest.main() |
0 commit comments