Skip to content

Commit cadf25e

Browse files
committed
Lazy load implementation
1 parent a64c539 commit cadf25e

File tree

5 files changed

+70
-96
lines changed

5 files changed

+70
-96
lines changed

apotheosis.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ def _load_node_from_fp(cls, f, data_to_node: dict,
224224
data_neighs = {}
225225
if db_manager:
226226
if data_to_node.get(data) is None:
227-
new_node = hash_node_class.create_node_from_DB(db_manager, data, algorithm)
227+
new_node = hash_node_class.create_node_from_DB(db_manager, data, algorithm, lazy=True)
228228
if with_layer:
229229
new_node.set_max_layer(max_layer)
230230
# store it for next iterations

datalayer/db_manager.py

Lines changed: 44 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import mysql.connector
66
from mysql.connector import errorcode
77

8-
from datalayer.node.winpage_hash_node import WinPageHashNode
98
from datalayer.hash_algorithm.hash_algorithm import HashAlgorithm
109
from datalayer.hash_algorithm.tlsh_algorithm import TLSHHashAlgorithm
1110
from datalayer.hash_algorithm.ssdeep_algorithm import SSDEEPHashAlgorithm
@@ -35,7 +34,14 @@
3534
JOIN os o ON m.os_id = o.id
3635
"""
3736

38-
SQL_GET_WINMODULE_BY_HASH = """
37+
SQL_GET_ALL_PAGES_LAZY = """
38+
SELECT p.{}
39+
FROM pages p
40+
JOIN modules m ON p.module_id = m.id
41+
JOIN os o ON m.os_id = o.id
42+
"""
43+
44+
SQL_GET_MODULE_BY_HASH = """
3945
SELECT p.{}, m.id AS module_id, m.file_version, m.original_filename,
4046
m.internal_filename, m.product_filename, m.company_name,
4147
m.legal_copyright, m.classification, m.size, m.base_address, o.*
@@ -107,61 +113,51 @@ def _row_to_module(self, row, os=None):
107113
base_address=row["base_address"]
108114
)
109115

110-
111-
def get_winmodule_data_by_pageid(self, page_id=0, algorithm=HashAlgorithm):
112-
logger.info(f"Getting results for \"{page_id}\" from DB ({algorithm.__name__}) ...")
113-
hash_column = "hashTLSH" if algorithm == TLSHHashAlgorithm else "hashSSDEEP"
114-
query = SQL_GET_WINMODULE_BY_PAGEID.format(hash_column)
115-
self.cursor.execute(query, (page_id,))
116-
row = self.cursor.fetchone()
117-
if not row:
118-
logger.debug(f"Error! Page ID {page_id} not in DB (algorithm: {algorithm})")
119-
raise PageIdValueNotInDBError
120-
121-
module = self._row_to_module(row)
122-
return WinPageHashNode(id=row[hash_column], hash_algorithm=algorithm, module=module)
123-
124-
125-
def get_winmodule_data_by_hash(self, algorithm, hash_value: str = ""):
116+
def get_winpage_module_by_hash(self, algorithm : HashAlgorithm, hash_value: str = ""):
126117
logger.info(f"Getting results for \"{hash_value}\" from DB ({algorithm})")
127118
hash_column = "hashTLSH" if algorithm == TLSHHashAlgorithm else "hashSSDEEP"
128-
query = SQL_GET_WINMODULE_BY_HASH.format(hash_column, hash_column)
119+
120+
query = SQL_GET_MODULE_BY_HASH.format(hash_column, hash_column)
129121
self.cursor.execute(query, (hash_value,))
130122
row = self.cursor.fetchone()
123+
131124
if not row:
132125
logger.debug(f"Error! Hash value {hash_value} not in DB (algorithm: {algorithm})")
133126
raise HashValueNotInDBError
134127

135128
module = self._row_to_module(row)
136-
return WinPageHashNode(id=hash_value, hash_algorithm=algorithm, module=module)
137-
129+
return module
138130

139131
def get_organized_modules(self, algorithm: HashAlgorithm = TLSHHashAlgorithm) -> dict:
140132
result = {}
141133

142134
self.cursor.execute(SQL_GET_ALL_OS)
143135
db_operating_systems = self.cursor.fetchall()
144136
for db_os in db_operating_systems:
145-
os_name = db_os['name']
137+
os_name = db_os['version']
146138
result[os_name] = {}
147139

148140
self.cursor.execute(SQL_GET_MODULES_BY_OS, (db_os['id'],))
149141
modules = self.cursor.fetchall()
150142
for module in modules:
151143
module_name = module['internal_filename']
152-
result[os_name][module_name] = set()
144+
original_name = module['original_filename']
145+
result[os_name][original_name] = set()
153146

154147
pages = self.get_winmodules(algorithm, modules_of_interest={module_name}, os_id=db_os['id'])
155148
for page in pages:
156-
result[os_name][module_name].add(page)
149+
result[os_name][original_name].add(page)
157150

158151
return result
159152

160-
def get_winmodules(self, algorithm: HashAlgorithm = TLSHHashAlgorithm, limit: int = None, modules_of_interest: set = None, os_id: int = None) -> set:
153+
def get_winmodules(self, algorithm: HashAlgorithm = TLSHHashAlgorithm, limit: int = None, modules_of_interest: set = None, os_id: int = None,
154+
lazy: bool = True) -> set:
155+
from datalayer.node.winpage_hash_node import WinPageHashNode # Avoid circular deps
161156
try:
157+
query = SQL_GET_ALL_PAGES_LAZY if lazy else SQL_GET_ALL_PAGES
158+
162159
winmodules = set()
163-
operating_systems = set()
164-
modules = set()
160+
165161

166162
hash_column = "hashTLSH" if algorithm == TLSHHashAlgorithm else "hashSSDEEP"
167163
query = SQL_GET_ALL_PAGES.format(hash_column) # Inject hash column
@@ -192,20 +188,27 @@ def get_winmodules(self, algorithm: HashAlgorithm = TLSHHashAlgorithm, limit: in
192188
self.cursor.execute(query, params)
193189
results = self.cursor.fetchall()
194190

195-
for row in results:
196-
os_id = row["id"]
197-
os_version = row["version"]
198-
os_name = row["name"]
199-
hash_value = row[hash_column]
200-
201-
current_os = OS(os_id, os_name, os_version)
202-
if current_os in operating_systems:
203-
current_os = next(os for os in operating_systems if os == current_os)
204-
else:
205-
operating_systems.add(current_os)
206-
207-
current_module = self._row_to_module(row)
208-
winmodules.add(WinPageHashNode(hash_value, algorithm, current_module))
191+
if lazy:
192+
for row in results:
193+
hash_value = row[hash_column]
194+
winmodules.add(WinPageHashNode(hash_value, algorithm, None))
195+
else:
196+
operating_systems = set()
197+
modules = set()
198+
for row in results:
199+
os_id = row["id"]
200+
os_version = row["version"]
201+
os_name = row["name"]
202+
hash_value = row[hash_column]
203+
204+
current_os = OS(os_id, os_name, os_version)
205+
if current_os in operating_systems:
206+
current_os = next(os for os in operating_systems if os == current_os)
207+
else:
208+
operating_systems.add(current_os)
209+
210+
current_module = self._row_to_module(row)
211+
winmodules.add(WinPageHashNode(hash_value, algorithm, current_module))
209212

210213
return winmodules
211214
except mysql.connector.Error as err:

datalayer/node/node.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def internal_load(cls, f):
5353
raise NotImplementedError
5454
# to be implemented in final classes
5555
@classmethod
56-
def create_node_from_DB(cls, db_manager, _id, hash_algoritmh):
56+
def create_node_from_DB(cls, db_manager, _id, hash_algoritmh, lazy):
5757
raise NotImplementedError
5858
# to be implemented in final classes
5959
@classmethod
Lines changed: 15 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,35 @@
11
#TODO docstring
22
from datalayer.node.hash_node import HashNode
33
from datalayer.hash_algorithm.hash_algorithm import HashAlgorithm
4-
from datalayer.hash_algorithm.tlsh_algorithm import TLSHHashAlgorithm
5-
from datalayer.hash_algorithm.ssdeep_algorithm import SSDEEPHashAlgorithm
64
from datalayer.database.module import Module
5+
from datalayer.db_manager import DBManager
76
from common.constants import *
87
from common.errors import NodeUnsupportedAlgorithm
98

109
class WinPageHashNode(HashNode):
11-
def __init__(self, id, hash_algorithm: HashAlgorithm, module: Module=None):
10+
def __init__(self, id, hash_algorithm: HashAlgorithm, module: Module=None, db_manager: DBManager=None):
1211
super().__init__(id, hash_algorithm)
13-
self._module = module
12+
self._real_module = module
13+
self._db_manager = db_manager
14+
15+
@property
16+
def _module(self):
17+
if not self._real_module and self._db_manager is not None:
18+
self._real_module = self._db_manager.get_winpage_module_by_id(self._hash_algorithm, self._id)
19+
return self._real_module
1420

1521
def __lt__(self, other): # Hack for priority queue. TODO: not needed here?
1622
return False
1723

1824
def get_module(self):
1925
return self._module
2026

21-
def get_page(self):
22-
return self._page
23-
24-
def get_internal_page_id(self):
25-
return self._page.id if self._page else 0
26-
2727
def get_draw_features(self):
2828
return {"module_names": { self._id: self._module.original_filename + " " + self._module.file_version},
2929
"module_version": {self._id: self._module.file_version},
3030
"os_version": {self._id: self._module.os.version}
3131
}
3232

33-
3433
def as_dict(self):
3534
node_dict = super().as_dict()
3635
if self._module:
@@ -64,8 +63,10 @@ def internal_load(cls, f):
6463
return bpage_id, bpage_id.decode('utf-8').rstrip('\x00')
6564

6665
@classmethod
67-
def create_node_from_DB(cls, db_manager, hash_id, hash_algorithm):
68-
new_node = db_manager.get_winmodule_data_by_hash(hash_value=hash_id, algorithm=hash_algorithm)
66+
def create_node_from_DB(cls, db_manager, hash_id, hash_algorithm, lazy=True):
67+
new_node = WinPageHashNode(hash_id, hash_algorithm, None, db_manager)
68+
if not lazy:
69+
new_node._module # Force load module from database
6970
return new_node
7071

7172
@classmethod
@@ -74,28 +75,5 @@ def internal_data_needs_DB(cls) -> bool:
7475
# to load a WinPageHashNode from an Apotheosis file
7576

7677
def is_equal(self, other):
77-
if type(self) != type(other):
78-
return False
79-
try:
80-
same_module = self._module == other._module
81-
same_page = self._page == other._page
82-
if not same_module or not same_page:
83-
return False
84-
if type(self._hash_algorithm) != type(other._hash_algorithm):
85-
return False
86-
# check now the id and the hash, both modules and pages are the same
87-
equal = self._id == other._id and self._max_layer == other._max_layer and\
88-
len(self._neighbors) == len(other._neighbors)
89-
if not equal:
90-
return False
91-
# now, check the neighbors
92-
for idx, neighs in enumerate(self._neighbors):
93-
other_pageid = set([node._page.id for node in other._neighbors[idx]])
94-
self_pageid = set([node._page.id for node in self._neighbors[idx]])
95-
if other_pageid != self_pageid:
96-
return False
97-
98-
return True
99-
except:
100-
return False
78+
return other._id == self._id
10179

tests/unit.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,6 @@ def test_deletion(self):
7878

7979
self.assertEqual(actual_founds, expected_founds)
8080

81-
82-
def mock_get_winmodule_data_by_hash(self, algorithm, hash_value):
83-
"""Mock function for get_winmodule_data_by_hash"""
84-
return WinPageHashNode(hash_value, algorithm)
85-
8681
def test_dump_load(self):
8782
self.apo_model = Apotheosis(
8883
M=4, ef=4, Mmax=8, Mmax0=16,
@@ -97,20 +92,18 @@ def test_dump_load(self):
9792
node1 = WinPageHashNode(hash1, TLSHHashAlgorithm)
9893
node2 = WinPageHashNode(hash2, TLSHHashAlgorithm)
9994

100-
# Using self to reference the mock function
101-
with patch.object(DBManager, 'connect', return_value=None), \
102-
patch.object(DBManager, 'get_winmodule_data_by_hash', side_effect=self.mock_get_winmodule_data_by_hash) as mock_method:
103-
self.apo_model.insert(node1)
104-
self.apo_model.insert(node2)
95+
96+
self.apo_model.insert(node1)
97+
self.apo_model.insert(node2)
10598

106-
self.apo_model.dump("TestApo", compress=False)
107-
self.apo_model.load('TestApo', TLSHHashAlgorithm, WinPageHashNode)
99+
self.apo_model.dump("TestApo", compress=False)
100+
self.apo_model.load('TestApo', TLSHHashAlgorithm, WinPageHashNode)
108101

109-
_, exact1, _ = self.apo_model.knn_search(HashNode(hash1, TLSHHashAlgorithm), k=1, ef=4)
110-
_, exact2, _ = self.apo_model.knn_search(HashNode(hash2, TLSHHashAlgorithm), k=1, ef=4)
102+
_, exact1, _ = self.apo_model.knn_search(HashNode(hash1, TLSHHashAlgorithm), k=1, ef=4)
103+
_, exact2, _ = self.apo_model.knn_search(HashNode(hash2, TLSHHashAlgorithm), k=1, ef=4)
111104

112-
self.assertEqual(exact1.get_id(), hash1)
113-
self.assertEqual(exact2.get_id(), hash2)
105+
self.assertEqual(exact1.get_id(), hash1)
106+
self.assertEqual(exact2.get_id(), hash2)
114107

115108

116109
if __name__ == '__main__':

0 commit comments

Comments
 (0)