Skip to content

Commit 369b421

Browse files
committed
Add functionality to DbManager + WinModuleHashNode refactor to WinPageHashNode
1 parent 722dc85 commit 369b421

File tree

6 files changed

+86
-40
lines changed

6 files changed

+86
-40
lines changed

apotheosis.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -147,14 +147,14 @@ def _create_empty(self, M=0, ef=0, Mmax=0, Mmax0=0,\
147147
heuristic=False, extend_candidates=True, keep_pruned_conns=True,\
148148
beer_factor: float=0):
149149
# check first if algorithm is supported
150-
if not issubclass(distance_algorithm, HashAlgorithm):
150+
if not isinstance(distance_algorithm, type) or not issubclass(distance_algorithm, HashAlgorithm):
151151
raise ApotheosisUnsupportedDistanceAlgorithmError
152-
# construct both data structures (a HNSW and a radix tree for all nodes -- will contain @WinModuleHashNode)
152+
# construct both data structures (a HNSW and a radix tree for all nodes -- will contain @WinPageHashNode)
153153
self._HNSW = HNSW(M=M, ef=ef, Mmax=Mmax, Mmax0=Mmax0, distance_algorithm=distance_algorithm,\
154154
heuristic=heuristic, extend_candidates=extend_candidates, keep_pruned_conns=keep_pruned_conns,\
155155
beer_factor=beer_factor)
156156
self._distance_algorithm = distance_algorithm
157-
# radix hash tree for all nodes (of @WinModuleHashNode)
157+
# radix hash tree for all nodes (of @WinPageHashNode)
158158
self._radix = RadixHash(distance_algorithm)
159159

160160
@classmethod
@@ -538,7 +538,7 @@ def knn_search(self, query=None, k:int=0, ef=0, hashid=0):
538538
"""
539539
if hashid != 0:
540540
# create node and make the search again...
541-
query = WinModuleHashNode(hashid, self.get_distance_algorithm())
541+
query = WinPageHashNode(hashid, self.get_distance_algorithm())
542542

543543
self._sanity_checks(query)
544544

@@ -628,7 +628,7 @@ def __eq__(self, other):
628628
# unit test
629629
import common.utilities as util
630630
from datalayer.node.hash_node import HashNode
631-
from datalayer.node.winmodule_hash_node import WinModuleHashNode
631+
from datalayer.node.winpage_hash_node import WinPageHashNode
632632
from datalayer.hash_algorithm.tlsh_algorithm import TLSHHashAlgorithm
633633
from datalayer.hash_algorithm.ssdeep_algorithm import SSDEEPHashAlgorithm
634634
from random import random
@@ -672,11 +672,11 @@ def search_knns(apo, query_node):
672672
hash6 = "T1DF8174A9C2A506FC122292D644816333FEF1B845C419121A0F91CF5359B5B21FA3A305" #fake
673673
hash7 = "T10381E956C26225F2DAD9D097B381202C62AC793B37082B8A1EACDAC00B37D557E0E714" #fake
674674

675-
node1 = WinModuleHashNode(hash1, TLSHHashAlgorithm)
676-
node2 = WinModuleHashNode(hash2, TLSHHashAlgorithm)
677-
node3 = WinModuleHashNode(hash3, TLSHHashAlgorithm)
678-
node4 = WinModuleHashNode(hash4, TLSHHashAlgorithm)
679-
node5 = WinModuleHashNode(hash5, TLSHHashAlgorithm)
675+
node1 = WinPageHashNode(hash1, TLSHHashAlgorithm)
676+
node2 = WinPageHashNode(hash2, TLSHHashAlgorithm)
677+
node3 = WinPageHashNode(hash3, TLSHHashAlgorithm)
678+
node4 = WinPageHashNode(hash4, TLSHHashAlgorithm)
679+
node5 = WinPageHashNode(hash5, TLSHHashAlgorithm)
680680
nodes = [node1, node2, node3]
681681

682682
print("Testing insert ...")

common/utilities.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,14 @@ def create_model(npages, nsearch_pages,\
2222
import time
2323
import datetime
2424
from common.errors import NodeAlreadyExistsError
25-
def load_DB_in_model(npages=0, nsearch_pages=None, algorithm=None, current_model=None, printlog=True):
25+
def load_DB_in_model(npages=0, nsearch_pages=None, algorithm=None, current_model=None, printlog=True, modules_of_interest=None, os_id=None):
2626
BATCH_PRINT=1e2
2727

2828
db_manager = DBManager()
2929

3030
print(f"[*] Getting modules from DB (with {algorithm.__name__}) ...")
3131
start = time.time_ns()
32-
all_pages = db_manager.get_winmodules(algorithm, npages + nsearch_pages if nsearch_pages else npages)
32+
all_pages = db_manager.get_winmodules(algorithm, npages + nsearch_pages if nsearch_pages else npages, modules_of_interest, os_id)
3333
end = time.time_ns() # in nanoseconds
3434
db_time = (end - start)/1e6 # ms
3535
print(f"[*] {len(all_pages)} pages recovered from DB in {db_time} ms.")

datalayer/db_manager.py

Lines changed: 70 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import mysql.connector
66
from mysql.connector import errorcode
77

8-
from datalayer.node.winmodule_hash_node import WinModuleHashNode
8+
from datalayer.node.winpage_hash_node import WinPageHashNode
99
from datalayer.hash_algorithm.hash_algorithm import HashAlgorithm
1010
from datalayer.hash_algorithm.tlsh_algorithm import TLSHHashAlgorithm
1111
from datalayer.hash_algorithm.ssdeep_algorithm import SSDEEPHashAlgorithm
@@ -14,6 +14,18 @@
1414

1515
from common.errors import HashValueNotInDBError, PageIdValueNotInDBError
1616

17+
SQL_GET_ALL_OS = """
18+
SELECT *
19+
FROM os
20+
"""
21+
SQL_GET_MODULES_BY_OS = """
22+
SELECT m.id AS module_id, m.file_version, m.original_filename,
23+
m.internal_filename, m.product_filename, m.company_name,
24+
m.legal_copyright, m.classification, m.size, m.base_address
25+
FROM modules m
26+
WHERE m.os_id = %s
27+
"""
28+
1729
SQL_GET_ALL_PAGES = """
1830
SELECT p.{}, m.id AS module_id, m.file_version, m.original_filename,
1931
m.internal_filename, m.product_filename, m.company_name,
@@ -78,9 +90,11 @@ def _clean_dict_keys(self, _dict: dict, keys: list):
7890
for key in keys:
7991
_dict.pop(key, None)
8092

81-
def _row_to_module(self, row):
93+
def _row_to_module(self, row, os=None):
94+
if not os:
95+
os = OS(row["id"], row["name"], row["version"])
8296
return Module(
83-
os=OS(row["id"], row["name"], row["version"]),
97+
os=os,
8498
id=row["module_id"],
8599
file_version=row["file_version"],
86100
original_filename=row["original_filename"],
@@ -105,7 +119,7 @@ def get_winmodule_data_by_pageid(self, page_id=0, algorithm=HashAlgorithm):
105119
raise PageIdValueNotInDBError
106120

107121
module = self._row_to_module(row)
108-
return WinModuleHashNode(id=row[hash_column], hash_algorithm=algorithm, module=module)
122+
return WinPageHashNode(id=row[hash_column], hash_algorithm=algorithm, module=module)
109123

110124

111125
def get_winmodule_data_by_hash(self, algorithm: str = "", hash_value: str = ""):
@@ -119,20 +133,63 @@ def get_winmodule_data_by_hash(self, algorithm: str = "", hash_value: str = ""):
119133
raise HashValueNotInDBError
120134

121135
module = self._row_to_module(row)
122-
return WinModuleHashNode(id=hash_value, hash_algorithm=algorithm, module=module)
123-
136+
return WinPageHashNode(id=hash_value, hash_algorithm=algorithm, module=module)
124137

125-
def get_winmodules(self, algorithm: HashAlgorithm = TLSHHashAlgorithm, limit: int = None, modules_of_interest: set = None) -> set:
138+
139+
def get_organized_modules(self, algorithm: HashAlgorithm = TLSHHashAlgorithm) -> dict:
    """Groups the pages stored in the database by OS name and module name.

    Arguments:
    algorithm -- hash algorithm class, forwarded as-is to get_winmodules

    Returns:
    dict of the form {os_name: {module_internal_filename: set of pages}},
    where the pages are the nodes returned by get_winmodules for that
    module name restricted to that OS id. An OS without modules maps to an
    empty dict; a module without pages maps to an empty set.
    """
    result = {}

    self.cursor.execute(SQL_GET_ALL_OS)
    # fetchall() drains the result set, so the shared cursor can be safely
    # reused for the per-OS and per-module queries below
    db_operating_systems = self.cursor.fetchall()
    for db_os in db_operating_systems:
        os_id = db_os['id']
        # OS rows carry a name *and* a version, so several rows may share the
        # same name: merge into the existing entry instead of overwriting it
        # (overwriting would silently drop the pages of the earlier rows)
        os_modules = result.setdefault(db_os['name'], {})

        self.cursor.execute(SQL_GET_MODULES_BY_OS, (os_id,))
        modules = self.cursor.fetchall()
        for module in modules:
            module_name = module['internal_filename']
            pages = self.get_winmodules(algorithm, modules_of_interest={module_name}, os_id=os_id)
            # same reasoning for modules sharing an internal filename:
            # accumulate pages rather than resetting the set
            os_modules.setdefault(module_name, set()).update(pages)

    return result
159+
160+
def get_winmodules(self, algorithm: HashAlgorithm = TLSHHashAlgorithm, limit: int = None, modules_of_interest: set = None, os_id: int = None) -> set:
126161
try:
127162
winmodules = set()
128163
operating_systems = set()
129164
modules = set()
130165

131166
hash_column = "hashTLSH" if algorithm == TLSHHashAlgorithm else "hashSSDEEP"
132-
query = SQL_GET_ALL_PAGES.format(hash_column)
167+
query = SQL_GET_ALL_PAGES.format(hash_column) # Inject hash column
168+
169+
conditions = []
170+
params = []
171+
172+
if modules_of_interest:
173+
placeholders = ', '.join(['%s'] * len(modules_of_interest))
174+
conditions.append(f"m.internal_filename IN ({placeholders})")
175+
params.extend(modules_of_interest)
176+
177+
if os_id is not None:
178+
conditions.append("o.id = %s")
179+
params.append(os_id)
180+
181+
if algorithm == TLSHHashAlgorithm:
182+
conditions.append("p.hashTLSH != '*'")
183+
conditions.append("p.hashTLSH != '-'")
184+
185+
if conditions:
186+
query += " WHERE " + " AND ".join(conditions)
187+
133188
if limit:
134-
query = query + f" LIMIT {limit}"
135-
self.cursor.execute(query)
189+
query += " LIMIT %s"
190+
params.append(limit)
191+
192+
self.cursor.execute(query, params)
136193
results = self.cursor.fetchall()
137194

138195
for row in results:
@@ -148,26 +205,15 @@ def get_winmodules(self, algorithm: HashAlgorithm = TLSHHashAlgorithm, limit: in
148205
operating_systems.add(current_os)
149206

150207
current_module = self._row_to_module(row)
151-
152-
# Supposedly more memory-efficient, but it slows down retrieval
153-
'''
154-
if current_module in modules:
155-
current_module = next(module for module in modules if module == current_module)
156-
else:
157-
modules.add(current_module)
158-
'''
159-
160-
if modules_of_interest and current_module.internal_filename not in modules_of_interest:
161-
continue
162-
163-
winmodules.add(WinModuleHashNode(hash_value, algorithm, current_module))
208+
winmodules.add(WinPageHashNode(hash_value, algorithm, current_module))
164209

165210
return winmodules
166211
except mysql.connector.Error as err:
167212
logger.error(f"Database query error: {err}")
168213
raise
169214
finally:
170-
self.cursor.close()
215+
pass
216+
#self.cursor.close()
171217

172218
def close(self):
173219
self.cursor.close()

datalayer/node/node.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def set_neighbors_at_layer(self, layer: int, neighbors: set):
4040
# only in HashNode
4141
def calculate_similarity(self, other_node):
4242
raise NotImplementedError
43-
# only in WinModuleHashNode
43+
# only in WinPageHashNode
4444
def get_pageids(self):
4545
raise NotImplementedError
4646

datalayer/node/winpage_hash_node.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from common.constants import *
88
from common.errors import NodeUnsupportedAlgorithm
99

10-
class WinModuleHashNode(HashNode):
10+
class WinPageHashNode(HashNode):
1111
def __init__(self, id, hash_algorithm: HashAlgorithm, module: Module=None):
1212
super().__init__(id, hash_algorithm)
1313
self._module = module
@@ -78,7 +78,7 @@ def create_node_from_DB(cls, db_manager, hash_id, hash_algorithm):
7878
@classmethod
7979
def internal_data_needs_DB(cls) -> bool:
8080
return True # we have some data necessary to retrieve from the DB
81-
# to load a WinModuleHashNode from an Apotheosis file
81+
# to load a WinPageHashNode from an Apotheosis file
8282

8383
def is_equal(self, other):
8484
if type(self) != type(other):

rest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def _extend_results_winmodule_data(hash_algorithm: str, results: dict) -> dict:
143143
"""Extends the results dict with Winmodule information (from the database).
144144
145145
Arguments:
146-
results -- dict of WinModuleHashNode
146+
results -- dict of WinPageHashNode
147147
"""
148148

149149
new_results = {}

0 commit comments

Comments
 (0)