|
5 | 5 | import mysql.connector |
6 | 6 | from mysql.connector import errorcode |
7 | 7 |
|
8 | | -from datalayer.node.winpage_hash_node import WinPageHashNode |
9 | 8 | from datalayer.hash_algorithm.hash_algorithm import HashAlgorithm |
10 | 9 | from datalayer.hash_algorithm.tlsh_algorithm import TLSHHashAlgorithm |
11 | 10 | from datalayer.hash_algorithm.ssdeep_algorithm import SSDEEPHashAlgorithm |
|
35 | 34 | JOIN os o ON m.os_id = o.id |
36 | 35 | """ |
37 | 36 |
|
38 | | -SQL_GET_WINMODULE_BY_HASH = """ |
| 37 | +SQL_GET_ALL_PAGES_LAZY = """ |
| 38 | + SELECT p.{} |
| 39 | + FROM pages p |
| 40 | + JOIN modules m ON p.module_id = m.id |
| 41 | + JOIN os o ON m.os_id = o.id |
| 42 | + """ |
| 43 | + |
| 44 | +SQL_GET_MODULE_BY_HASH = """ |
39 | 45 | SELECT p.{}, m.id AS module_id, m.file_version, m.original_filename, |
40 | 46 | m.internal_filename, m.product_filename, m.company_name, |
41 | 47 | m.legal_copyright, m.classification, m.size, m.base_address, o.* |
@@ -107,61 +113,51 @@ def _row_to_module(self, row, os=None): |
107 | 113 | base_address=row["base_address"] |
108 | 114 | ) |
109 | 115 |
|
110 | | - |
111 | | - def get_winmodule_data_by_pageid(self, page_id=0, algorithm=HashAlgorithm): |
112 | | - logger.info(f"Getting results for \"{page_id}\" from DB ({algorithm.__name__}) ...") |
113 | | - hash_column = "hashTLSH" if algorithm == TLSHHashAlgorithm else "hashSSDEEP" |
114 | | - query = SQL_GET_WINMODULE_BY_PAGEID.format(hash_column) |
115 | | - self.cursor.execute(query, (page_id,)) |
116 | | - row = self.cursor.fetchone() |
117 | | - if not row: |
118 | | - logger.debug(f"Error! Page ID {page_id} not in DB (algorithm: {algorithm})") |
119 | | - raise PageIdValueNotInDBError |
120 | | - |
121 | | - module = self._row_to_module(row) |
122 | | - return WinPageHashNode(id=row[hash_column], hash_algorithm=algorithm, module=module) |
123 | | - |
124 | | - |
125 | | - def get_winmodule_data_by_hash(self, algorithm, hash_value: str = ""): |
| 116 | + def get_winpage_module_by_hash(self, algorithm : HashAlgorithm, hash_value: str = ""): |
126 | 117 | logger.info(f"Getting results for \"{hash_value}\" from DB ({algorithm})") |
127 | 118 | hash_column = "hashTLSH" if algorithm == TLSHHashAlgorithm else "hashSSDEEP" |
128 | | - query = SQL_GET_WINMODULE_BY_HASH.format(hash_column, hash_column) |
| 119 | + |
| 120 | + query = SQL_GET_MODULE_BY_HASH.format(hash_column, hash_column) |
129 | 121 | self.cursor.execute(query, (hash_value,)) |
130 | 122 | row = self.cursor.fetchone() |
| 123 | + |
131 | 124 | if not row: |
132 | 125 | logger.debug(f"Error! Hash value {hash_value} not in DB (algorithm: {algorithm})") |
133 | 126 | raise HashValueNotInDBError |
134 | 127 |
|
135 | 128 | module = self._row_to_module(row) |
136 | | - return WinPageHashNode(id=hash_value, hash_algorithm=algorithm, module=module) |
137 | | - |
| 129 | + return module |
138 | 130 |
|
139 | 131 | def get_organized_modules(self, algorithm: HashAlgorithm = TLSHHashAlgorithm) -> dict: |
140 | 132 | result = {} |
141 | 133 |
|
142 | 134 | self.cursor.execute(SQL_GET_ALL_OS) |
143 | 135 | db_operating_systems = self.cursor.fetchall() |
144 | 136 | for db_os in db_operating_systems: |
145 | | - os_name = db_os['name'] |
| 137 | + os_name = db_os['version'] |
146 | 138 | result[os_name] = {} |
147 | 139 |
|
148 | 140 | self.cursor.execute(SQL_GET_MODULES_BY_OS, (db_os['id'],)) |
149 | 141 | modules = self.cursor.fetchall() |
150 | 142 | for module in modules: |
151 | 143 | module_name = module['internal_filename'] |
152 | | - result[os_name][module_name] = set() |
| 144 | + original_name = module['original_filename'] |
| 145 | + result[os_name][original_name] = set() |
153 | 146 |
|
154 | 147 | pages = self.get_winmodules(algorithm, modules_of_interest={module_name}, os_id=db_os['id']) |
155 | 148 | for page in pages: |
156 | | - result[os_name][module_name].add(page) |
| 149 | + result[os_name][original_name].add(page) |
157 | 150 |
|
158 | 151 | return result |
159 | 152 |
|
160 | | - def get_winmodules(self, algorithm: HashAlgorithm = TLSHHashAlgorithm, limit: int = None, modules_of_interest: set = None, os_id: int = None) -> set: |
| 153 | + def get_winmodules(self, algorithm: HashAlgorithm = TLSHHashAlgorithm, limit: int = None, modules_of_interest: set = None, os_id: int = None, |
| 154 | + lazy: bool = True) -> set: |
| 155 | + from datalayer.node.winpage_hash_node import WinPageHashNode # Avoid circular deps |
161 | 156 | try: |
| 157 | + query = SQL_GET_ALL_PAGES_LAZY if lazy else SQL_GET_ALL_PAGES |
| 158 | + |
162 | 159 | winmodules = set() |
163 | | - operating_systems = set() |
164 | | - modules = set() |
| 160 | + |
165 | 161 |
|
166 | 162 | hash_column = "hashTLSH" if algorithm == TLSHHashAlgorithm else "hashSSDEEP" |
167 | 163 | query = SQL_GET_ALL_PAGES.format(hash_column) # Inject hash column |
@@ -192,20 +188,27 @@ def get_winmodules(self, algorithm: HashAlgorithm = TLSHHashAlgorithm, limit: in |
192 | 188 | self.cursor.execute(query, params) |
193 | 189 | results = self.cursor.fetchall() |
194 | 190 |
|
195 | | - for row in results: |
196 | | - os_id = row["id"] |
197 | | - os_version = row["version"] |
198 | | - os_name = row["name"] |
199 | | - hash_value = row[hash_column] |
200 | | - |
201 | | - current_os = OS(os_id, os_name, os_version) |
202 | | - if current_os in operating_systems: |
203 | | - current_os = next(os for os in operating_systems if os == current_os) |
204 | | - else: |
205 | | - operating_systems.add(current_os) |
206 | | - |
207 | | - current_module = self._row_to_module(row) |
208 | | - winmodules.add(WinPageHashNode(hash_value, algorithm, current_module)) |
| 191 | + if lazy: |
| 192 | + for row in results: |
| 193 | + hash_value = row[hash_column] |
| 194 | + winmodules.add(WinPageHashNode(hash_value, algorithm, None)) |
| 195 | + else: |
| 196 | + operating_systems = set() |
| 197 | + modules = set() |
| 198 | + for row in results: |
| 199 | + os_id = row["id"] |
| 200 | + os_version = row["version"] |
| 201 | + os_name = row["name"] |
| 202 | + hash_value = row[hash_column] |
| 203 | + |
| 204 | + current_os = OS(os_id, os_name, os_version) |
| 205 | + if current_os in operating_systems: |
| 206 | + current_os = next(os for os in operating_systems if os == current_os) |
| 207 | + else: |
| 208 | + operating_systems.add(current_os) |
| 209 | + |
| 210 | + current_module = self._row_to_module(row) |
| 211 | + winmodules.add(WinPageHashNode(hash_value, algorithm, current_module)) |
209 | 212 |
|
210 | 213 | return winmodules |
211 | 214 | except mysql.connector.Error as err: |
|
0 commit comments