|
22 | 22 | import logging |
23 | 23 | import pkgutil |
24 | 24 | import sys |
| 25 | +import traceback |
25 | 26 | import types |
26 | 27 | from typing import Any, Iterator, List, Optional, Tuple, Type, Union |
27 | 28 |
|
@@ -105,6 +106,36 @@ def __str__(self) -> str: |
105 | 106 | ) |
106 | 107 |
|
107 | 108 |
|
| 109 | +class DirectVolatilityImportUsage(CodeViolation): |
| 110 | + |
| 111 | + def __init__( |
| 112 | + self, |
| 113 | + module: types.ModuleType, |
| 114 | + node: ast.AST, |
| 115 | + importing_module: str, |
| 116 | + imported_item: object, |
| 117 | + imported_name: str, |
| 118 | + ) -> None: |
| 119 | + self.imported_item = imported_item |
| 120 | + self.imported_name = imported_name |
| 121 | + self.importing_module = importing_module |
| 122 | + super().__init__(module, node) |
| 123 | + |
| 124 | + def __str__(self) -> str: |
| 125 | + components = self.importing_module.split(".") |
| 126 | + return ( |
| 127 | + super().__str__() |
| 128 | + + ": " |
| 129 | + + ( |
| 130 | + f"Direct import of {self.imported_name} " |
| 131 | + f"({type(self.imported_item)}) " |
| 132 | + f"from module {self.importing_module} - " |
| 133 | + "change to " |
| 134 | + f"'from {'.'.join(components[:-1])} import {components[-1]} and using {components[-1]}.{self.imported_name}" |
| 135 | + ) |
| 136 | + ) |
| 137 | + |
| 138 | + |
108 | 139 | def is_versionable(var): |
109 | 140 | try: |
110 | 141 | return ( |
@@ -134,6 +165,46 @@ def __init__(self, module: types.ModuleType) -> None: |
134 | 165 | def violations(self): |
135 | 166 | return self._violations |
136 | 167 |
|
| 168 | + def _check_vol3_import_from(self, node: ast.ImportFrom): |
| 169 | + """ |
| 170 | + Ensure that the only thing imported from a volatility3 module (apart |
| 171 | + from the root volatility3 module) are functions and modules. This |
| 172 | + prevents re-exporting of classes and variables from modules that use |
| 173 | + them. |
| 174 | + """ |
| 175 | + if ( |
| 176 | + node.module |
| 177 | + and node.module.startswith("volatility3.") # Give a pass to volatility3 module |
| 178 | + and node.module != "volatility3.framework.constants._version" # make an exception for this |
| 179 | + ): |
| 180 | + for name in node.names: |
| 181 | + try: |
| 182 | + item = vars(self._module)[ |
| 183 | + name.asname if name.asname is not None else name.name |
| 184 | + ] |
| 185 | + except KeyError: |
| 186 | + logger.debug( |
| 187 | + "Couldn't find imported name %s in module %s", |
| 188 | + name.asname or name.name, |
| 189 | + self._module.__name__, |
| 190 | + ) |
| 191 | + continue |
| 192 | + |
| 193 | + if not (isinstance(item, types.ModuleType) or inspect.isfunction(item)): |
| 194 | + self._violations.append( |
| 195 | + DirectVolatilityImportUsage( |
| 196 | + self._module, |
| 197 | + node, |
| 198 | + node.module, |
| 199 | + item, |
| 200 | + name.asname or name.name, |
| 201 | + ) |
| 202 | + ) |
| 203 | + |
| 204 | + def enter_ImportFrom(self, node: ast.ImportFrom): |
| 205 | + self._check_vol3_import_from(node) |
| 206 | + |
| 207 | + |
137 | 208 | def enter_ClassDef(self, node: ast.ClassDef) -> Any: |
138 | 209 | logger.debug("Entering class %s", node.name) |
139 | 210 | clazz = None |
@@ -304,6 +375,7 @@ def report_missing_requirements() -> Iterator[Tuple[str, UnrequiredVersionableUs |
304 | 375 | modname, |
305 | 376 | str(exc), |
306 | 377 | ) |
| 378 | + traceback.print_exc() |
307 | 379 | continue |
308 | 380 |
|
309 | 381 | logger.info("Checking module %s", plugin_module.__name__) |
@@ -344,9 +416,7 @@ def perform_review(): |
344 | 416 | print(str(usage)) |
345 | 417 |
|
346 | 418 | if found: |
347 | | - print( |
348 | | - f"Found {found} issues" |
349 | | - ) |
| 419 | + print(f"Found {found} issues") |
350 | 420 | sys.exit(1) |
351 | 421 |
|
352 | 422 | print("All configurable classes passed validation!") |
|
0 commit comments