55import mysql .connector
66from mysql .connector import errorcode
77
8- from datalayer .node .winmodule_hash_node import WinModuleHashNode
8+ from datalayer .node .winpage_hash_node import WinPageHashNode
99from datalayer .hash_algorithm .hash_algorithm import HashAlgorithm
1010from datalayer .hash_algorithm .tlsh_algorithm import TLSHHashAlgorithm
1111from datalayer .hash_algorithm .ssdeep_algorithm import SSDEEPHashAlgorithm
1414
1515from common .errors import HashValueNotInDBError , PageIdValueNotInDBError
1616
17+ SQL_GET_ALL_OS = """
18+ SELECT *
19+ FROM os
20+ """
21+ SQL_GET_MODULES_BY_OS = """
22+ SELECT m.id AS module_id, m.file_version, m.original_filename,
23+ m.internal_filename, m.product_filename, m.company_name,
24+ m.legal_copyright, m.classification, m.size, m.base_address
25+ FROM modules m
26+ WHERE m.os_id = %s
27+ """
28+
1729SQL_GET_ALL_PAGES = """
1830 SELECT p.{}, m.id AS module_id, m.file_version, m.original_filename,
1931 m.internal_filename, m.product_filename, m.company_name,
@@ -78,9 +90,11 @@ def _clean_dict_keys(self, _dict: dict, keys: list):
7890 for key in keys :
7991 _dict .pop (key , None )
8092
81- def _row_to_module (self , row ):
93+ def _row_to_module (self , row , os = None ):
94+ if not os :
95+ os = OS (row ["id" ], row ["name" ], row ["version" ])
8296 return Module (
83- os = OS ( row [ "id" ], row [ "name" ], row [ "version" ]) ,
97+ os = os ,
8498 id = row ["module_id" ],
8599 file_version = row ["file_version" ],
86100 original_filename = row ["original_filename" ],
@@ -105,7 +119,7 @@ def get_winmodule_data_by_pageid(self, page_id=0, algorithm=HashAlgorithm):
105119 raise PageIdValueNotInDBError
106120
107121 module = self ._row_to_module (row )
108- return WinModuleHashNode (id = row [hash_column ], hash_algorithm = algorithm , module = module )
122+ return WinPageHashNode (id = row [hash_column ], hash_algorithm = algorithm , module = module )
109123
110124
111125 def get_winmodule_data_by_hash (self , algorithm : str = "" , hash_value : str = "" ):
@@ -119,20 +133,63 @@ def get_winmodule_data_by_hash(self, algorithm: str = "", hash_value: str = ""):
119133 raise HashValueNotInDBError
120134
121135 module = self ._row_to_module (row )
122- return WinModuleHashNode (id = hash_value , hash_algorithm = algorithm , module = module )
123-
136+ return WinPageHashNode (id = hash_value , hash_algorithm = algorithm , module = module )
124137
125- def get_winmodules (self , algorithm : HashAlgorithm = TLSHHashAlgorithm , limit : int = None , modules_of_interest : set = None ) -> set :
138+
139+ def get_organized_modules (self , algorithm : HashAlgorithm = TLSHHashAlgorithm ) -> dict :
140+ result = {}
141+
142+ self .cursor .execute (SQL_GET_ALL_OS )
143+ db_operating_systems = self .cursor .fetchall ()
144+ for db_os in db_operating_systems :
145+ os_name = db_os ['name' ]
146+ result [os_name ] = {}
147+
148+ self .cursor .execute (SQL_GET_MODULES_BY_OS , (db_os ['id' ],))
149+ modules = self .cursor .fetchall ()
150+ for module in modules :
151+ module_name = module ['internal_filename' ]
152+ result [os_name ][module_name ] = set ()
153+
154+ pages = self .get_winmodules (algorithm , modules_of_interest = {module_name }, os_id = db_os ['id' ])
155+ for page in pages :
156+ result [os_name ][module_name ].add (page )
157+
158+ return result
159+
160+ def get_winmodules (self , algorithm : HashAlgorithm = TLSHHashAlgorithm , limit : int = None , modules_of_interest : set = None , os_id : int = None ) -> set :
126161 try :
127162 winmodules = set ()
128163 operating_systems = set ()
129164 modules = set ()
130165
131166 hash_column = "hashTLSH" if algorithm == TLSHHashAlgorithm else "hashSSDEEP"
132- query = SQL_GET_ALL_PAGES .format (hash_column )
167+ query = SQL_GET_ALL_PAGES .format (hash_column ) # Inject hash column
168+
169+ conditions = []
170+ params = []
171+
172+ if modules_of_interest :
173+ placeholders = ', ' .join (['%s' ] * len (modules_of_interest ))
174+ conditions .append (f"m.internal_filename IN ({ placeholders } )" )
175+ params .extend (modules_of_interest )
176+
177+ if os_id is not None :
178+ conditions .append ("o.id = %s" )
179+ params .append (os_id )
180+
181+ if algorithm == TLSHHashAlgorithm :
182+ conditions .append ("p.hashTLSH != '*'" )
183+ conditions .append ("p.hashTLSH != '-'" )
184+
185+ if conditions :
186+ query += " WHERE " + " AND " .join (conditions )
187+
133188 if limit :
134- query = query + f" LIMIT { limit } "
135- self .cursor .execute (query )
189+ query += " LIMIT %s"
190+ params .append (limit )
191+
192+ self .cursor .execute (query , params )
136193 results = self .cursor .fetchall ()
137194
138195 for row in results :
@@ -148,26 +205,15 @@ def get_winmodules(self, algorithm: HashAlgorithm = TLSHHashAlgorithm, limit: in
148205 operating_systems .add (current_os )
149206
150207 current_module = self ._row_to_module (row )
151-
152- # Supposedly more memory-efficient, but it slows down retrieval
153- '''
154- if current_module in modules:
155- current_module = next(module for module in modules if module == current_module)
156- else:
157- modules.add(current_module)
158- '''
159-
160- if modules_of_interest and current_module .internal_filename not in modules_of_interest :
161- continue
162-
163- winmodules .add (WinModuleHashNode (hash_value , algorithm , current_module ))
208+ winmodules .add (WinPageHashNode (hash_value , algorithm , current_module ))
164209
165210 return winmodules
166211 except mysql .connector .Error as err :
167212 logger .error (f"Database query error: { err } " )
168213 raise
169214 finally :
170- self .cursor .close ()
215+ pass
216+ #self.cursor.close()
171217
172218 def close (self ):
173219 self .cursor .close ()
0 commit comments