Skip to content

Commit 9a231ee

Browse files
committed
Add dump/load unit test to CI
1 parent c9709ce commit 9a231ee

File tree

1 file changed

+34
-0
lines changed

1 file changed

+34
-0
lines changed

tests/unit.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22
from apotheosis import Apotheosis
33
from datalayer.hash_algorithm.tlsh_algorithm import TLSHHashAlgorithm
44
from datalayer.node.hash_node import HashNode
5+
from datalayer.node.winpage_hash_node import WinPageHashNode
6+
from unittest.mock import patch
7+
from datalayer.db_manager import DBManager
8+
9+
10+
511

612
APOTHEOSIS_HOST = "localhost:5000"
713
HASHES = [
@@ -72,5 +78,33 @@ def test_deletion(self):
7278

7379
self.assertEqual(actual_founds, expected_founds)
7480

81+
def test_dump_load(self):
82+
self.apo_model = Apotheosis(M=4, ef=4, Mmax=8, Mmax0=16,
83+
heuristic=False, extend_candidates=False, keep_pruned_conns=False,
84+
beer_factor=False,
85+
distance_algorithm=TLSHHashAlgorithm)
86+
87+
# Mock return value for get_winmodule_data_by_hash
88+
mock_winmodule = WinPageHashNode("T10381E956C26225F2DAD9D5C2C5C1A337FAF3708A25012B8A1EACDAC00B37D557E0E714", TLSHHashAlgorithm)
89+
90+
with patch.object(DBManager, 'get_winmodule_data_by_hash', return_value=mock_winmodule):
91+
# Create the nodes based on TLSH Fuzzy Hashes
92+
hash1 = "T10381E956C26225F2DAD9D5C2C5C1A337FAF3708A25012B8A1EACDAC00B37D557E0E714"
93+
node1 = WinPageHashNode(hash1, TLSHHashAlgorithm)
94+
hash2 = "T1458197A3C292D1EC8566C6A2C6516377FA743E0F8120BA49CFD1CF812B66B60D75E316"
95+
node2 = WinPageHashNode(hash2, TLSHHashAlgorithm)
96+
97+
self.apo_model.insert(node1)
98+
self.apo_model.insert(node2)
99+
100+
self.apo_model.dump("TestApo", compress=False)
101+
self.apo_model.load('TestApo', TLSHHashAlgorithm, WinPageHashNode)
102+
_, exact1, _ = self.apo_model.knn_search(HashNode(hash1, TLSHHashAlgorithm), k=1, ef=4)
103+
_, exact2, _ = self.apo_model.knn_search(HashNode(hash2, TLSHHashAlgorithm), k=1, ef=4)
104+
self.assertEqual(exact1.get_id(), hash1)
105+
self.assertEqual(exact2.get_id(), hash2)
106+
107+
108+
75109
if __name__ == '__main__':
76110
unittest.main()

0 commit comments

Comments
 (0)