Skip to content

Commit 397b488

Browse files
committed
Add functionality to DbManager + WinModuleHashNode refactor to WinPageHashNode
1 parent 722dc85 commit 397b488

File tree

5 files changed

+76
-30
lines changed

5 files changed

+76
-30
lines changed

common/utilities.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,14 @@ def create_model(npages, nsearch_pages,\
2222
import time
2323
import datetime
2424
from common.errors import NodeAlreadyExistsError
25-
def load_DB_in_model(npages=0, nsearch_pages=None, algorithm=None, current_model=None, printlog=True):
25+
def load_DB_in_model(npages=0, nsearch_pages=None, algorithm=None, current_model=None, printlog=True, modules_of_interest=None, os_id=None):
2626
BATCH_PRINT=1e2
2727

2828
db_manager = DBManager()
2929

3030
print(f"[*] Getting modules from DB (with {algorithm.__name__}) ...")
3131
start = time.time_ns()
32-
all_pages = db_manager.get_winmodules(algorithm, npages + nsearch_pages if nsearch_pages else npages)
32+
all_pages = db_manager.get_winmodules(algorithm, npages + nsearch_pages if nsearch_pages else npages, modules_of_interest, os_id)
3333
end = time.time_ns() # in nanoseconds
3434
db_time = (end - start)/1e6 # ms
3535
print(f"[*] {len(all_pages)} pages recovered from DB in {db_time} ms.")

datalayer/db_manager.py

Lines changed: 70 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import mysql.connector
66
from mysql.connector import errorcode
77

8-
from datalayer.node.winmodule_hash_node import WinModuleHashNode
8+
from datalayer.node.winpage_hash_node import WinPageHashNode
99
from datalayer.hash_algorithm.hash_algorithm import HashAlgorithm
1010
from datalayer.hash_algorithm.tlsh_algorithm import TLSHHashAlgorithm
1111
from datalayer.hash_algorithm.ssdeep_algorithm import SSDEEPHashAlgorithm
@@ -14,6 +14,18 @@
1414

1515
from common.errors import HashValueNotInDBError, PageIdValueNotInDBError
1616

17+
SQL_GET_ALL_OS = """
18+
SELECT *
19+
FROM os
20+
"""
21+
SQL_GET_MODULES_BY_OS = """
22+
SELECT m.id AS module_id, m.file_version, m.original_filename,
23+
m.internal_filename, m.product_filename, m.company_name,
24+
m.legal_copyright, m.classification, m.size, m.base_address
25+
FROM modules m
26+
WHERE m.os_id = %s
27+
"""
28+
1729
SQL_GET_ALL_PAGES = """
1830
SELECT p.{}, m.id AS module_id, m.file_version, m.original_filename,
1931
m.internal_filename, m.product_filename, m.company_name,
@@ -78,9 +90,11 @@ def _clean_dict_keys(self, _dict: dict, keys: list):
7890
for key in keys:
7991
_dict.pop(key, None)
8092

81-
def _row_to_module(self, row):
93+
def _row_to_module(self, row, os=None):
94+
if not os:
95+
os = OS(row["id"], row["name"], row["version"])
8296
return Module(
83-
os=OS(row["id"], row["name"], row["version"]),
97+
os=os,
8498
id=row["module_id"],
8599
file_version=row["file_version"],
86100
original_filename=row["original_filename"],
@@ -105,7 +119,7 @@ def get_winmodule_data_by_pageid(self, page_id=0, algorithm=HashAlgorithm):
105119
raise PageIdValueNotInDBError
106120

107121
module = self._row_to_module(row)
108-
return WinModuleHashNode(id=row[hash_column], hash_algorithm=algorithm, module=module)
122+
return WinPageHashNode(id=row[hash_column], hash_algorithm=algorithm, module=module)
109123

110124

111125
def get_winmodule_data_by_hash(self, algorithm: str = "", hash_value: str = ""):
@@ -119,20 +133,63 @@ def get_winmodule_data_by_hash(self, algorithm: str = "", hash_value: str = ""):
119133
raise HashValueNotInDBError
120134

121135
module = self._row_to_module(row)
122-
return WinModuleHashNode(id=hash_value, hash_algorithm=algorithm, module=module)
123-
136+
return WinPageHashNode(id=hash_value, hash_algorithm=algorithm, module=module)
124137

125-
def get_winmodules(self, algorithm: HashAlgorithm = TLSHHashAlgorithm, limit: int = None, modules_of_interest: set = None) -> set:
138+
139+
def get_organized_modules(self, algorithm: HashAlgorithm = TLSHHashAlgorithm) -> dict:
    """Group the page hash nodes stored in the DB by OS name and module name.

    Runs SQL_GET_ALL_OS, then SQL_GET_MODULES_BY_OS for each OS row, and
    finally calls get_winmodules() once per distinct internal filename of
    that OS to collect its pages.

    Arguments:
    algorithm -- hash algorithm class forwarded to get_winmodules()

    Returns:
    dict of the form {os_name: {module_internal_filename: set_of_pages}},
    where each set holds whatever get_winmodules() returned for that module
    name restricted to that OS id.

    Any exception raised by the cursor or by get_winmodules() propagates
    to the caller unchanged.
    """
    result = {}

    self.cursor.execute(SQL_GET_ALL_OS)
    # fetchall() drains the cursor before it is reused for the next query.
    for db_os in self.cursor.fetchall():
        os_id = db_os['id']

        # NOTE(review): OS rows also carry a version, so the name alone is
        # presumably not unique. setdefault() merges rows that share a name
        # instead of letting the later row wipe the earlier one's modules.
        os_modules = result.setdefault(db_os['name'], {})

        self.cursor.execute(SQL_GET_MODULES_BY_OS, (os_id,))
        # Pages are looked up by internal filename, so repeating the lookup
        # for a duplicated filename would re-run the same query. The
        # dict.fromkeys() call de-duplicates while keeping row order.
        module_names = dict.fromkeys(
            module['internal_filename'] for module in self.cursor.fetchall()
        )

        for module_name in module_names:
            pages = self.get_winmodules(
                algorithm, modules_of_interest={module_name}, os_id=os_id
            )
            os_modules.setdefault(module_name, set()).update(pages)

    return result
159+
160+
def get_winmodules(self, algorithm: HashAlgorithm = TLSHHashAlgorithm, limit: int = None, modules_of_interest: set = None, os_id: int = None) -> set:
126161
try:
127162
winmodules = set()
128163
operating_systems = set()
129164
modules = set()
130165

131166
hash_column = "hashTLSH" if algorithm == TLSHHashAlgorithm else "hashSSDEEP"
132-
query = SQL_GET_ALL_PAGES.format(hash_column)
167+
query = SQL_GET_ALL_PAGES.format(hash_column) # Inject hash column
168+
169+
conditions = []
170+
params = []
171+
172+
if modules_of_interest:
173+
placeholders = ', '.join(['%s'] * len(modules_of_interest))
174+
conditions.append(f"m.internal_filename IN ({placeholders})")
175+
params.extend(modules_of_interest)
176+
177+
if os_id is not None:
178+
conditions.append("o.id = %s")
179+
params.append(os_id)
180+
181+
if algorithm == TLSHHashAlgorithm:
182+
conditions.append("p.hashTLSH != '*'")
183+
conditions.append("p.hashTLSH != '-'")
184+
185+
if conditions:
186+
query += " WHERE " + " AND ".join(conditions)
187+
133188
if limit:
134-
query = query + f" LIMIT {limit}"
135-
self.cursor.execute(query)
189+
query += " LIMIT %s"
190+
params.append(limit)
191+
192+
self.cursor.execute(query, params)
136193
results = self.cursor.fetchall()
137194

138195
for row in results:
@@ -148,26 +205,15 @@ def get_winmodules(self, algorithm: HashAlgorithm = TLSHHashAlgorithm, limit: in
148205
operating_systems.add(current_os)
149206

150207
current_module = self._row_to_module(row)
151-
152-
# Supposedly more memory-efficient, but it slows down retrieval
153-
'''
154-
if current_module in modules:
155-
current_module = next(module for module in modules if module == current_module)
156-
else:
157-
modules.add(current_module)
158-
'''
159-
160-
if modules_of_interest and current_module.internal_filename not in modules_of_interest:
161-
continue
162-
163-
winmodules.add(WinModuleHashNode(hash_value, algorithm, current_module))
208+
winmodules.add(WinPageHashNode(hash_value, algorithm, current_module))
164209

165210
return winmodules
166211
except mysql.connector.Error as err:
167212
logger.error(f"Database query error: {err}")
168213
raise
169214
finally:
170-
self.cursor.close()
215+
pass
216+
#self.cursor.close()
171217

172218
def close(self):
173219
self.cursor.close()

datalayer/node/node.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def set_neighbors_at_layer(self, layer: int, neighbors: set):
4040
# only in HashNode
4141
def calculate_similarity(self, other_node):
4242
raise NotImplementedError
43-
# only in WinModuleHashNode
43+
# only in WinPageHashNode
4444
def get_pageids(self):
4545
raise NotImplementedError
4646

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from common.constants import *
88
from common.errors import NodeUnsupportedAlgorithm
99

10-
class WinModuleHashNode(HashNode):
10+
class WinPageHashNode(HashNode):
1111
def __init__(self, id, hash_algorithm: HashAlgorithm, module: Module=None):
1212
super().__init__(id, hash_algorithm)
1313
self._module = module
@@ -78,7 +78,7 @@ def create_node_from_DB(cls, db_manager, hash_id, hash_algorithm):
7878
@classmethod
7979
def internal_data_needs_DB(cls) -> bool:
8080
return True # we have some data necessary to retrieve from the DB
81-
# to load a WinModuleHashNode from an Apotheosis file
81+
# to load a WinPageHashNode from an Apotheosis file
8282

8383
def is_equal(self, other):
8484
if type(self) != type(other):

rest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def _extend_results_winmodule_data(hash_algorithm: str, results: dict) -> dict:
143143
"""Extends the results dict with Winmodule information (from the database).
144144
145145
Arguments:
146-
results -- dict of WinModuleHashNode
146+
results -- dict of WinPageHashNode
147147
"""
148148

149149
new_results = {}

0 commit comments

Comments
 (0)