diff --git a/fickling/analysis.py b/fickling/analysis.py index d24afc8..440b805 100644 --- a/fickling/analysis.py +++ b/fickling/analysis.py @@ -1,5 +1,6 @@ from __future__ import annotations +import ast import json from abc import ABC, abstractmethod from ast import unparse @@ -325,45 +326,55 @@ class UnsafeImportsML(Analysis): def analyze(self, context: AnalysisContext) -> Iterator[AnalysisResult]: for node in context.pickled.properties.imports: shortened, _ = context.shorten_code(node) - all_modules = [ - node.module.rsplit(".", i)[0] for i in range(0, node.module.count(".") + 1) - ] - for module_name in all_modules: - if module_name in self.UNSAFE_MODULES: - # Special handling for builtins - check specific function names - if module_name in BUILTIN_MODULE_NAMES: - for n in node.names: - if n.name not in SAFE_BUILTINS: - risk_info = self.UNSAFE_MODULES[module_name] - yield AnalysisResult( - Severity.LIKELY_OVERTLY_MALICIOUS, - f"`{shortened}` imports `{n.name}` from `{module_name}` " - f"which can execute arbitrary code. {risk_info}", - "UnsafeImportsML", - trigger=shortened, - ) - else: - # All other unsafe modules are fully blocked - risk_info = self.UNSAFE_MODULES[module_name] - yield AnalysisResult( - Severity.LIKELY_OVERTLY_MALICIOUS, - f"`{shortened}` uses `{module_name}` that is indicative of a malicious pickle file. {risk_info}", - "UnsafeImportsML", - trigger=shortened, - ) - if node.module in self.UNSAFE_IMPORTS: - for n in node.names: - if n.name in self.UNSAFE_IMPORTS[node.module]: - risk_info = self.UNSAFE_IMPORTS[node.module][n.name] - yield AnalysisResult( - Severity.LIKELY_OVERTLY_MALICIOUS, - f"`{shortened}` imports `{n.name}` that is indicative of a malicious pickle file. {risk_info}", - "UnsafeImportsML", - trigger=shortened, - ) + + match node: + case ast.ImportFrom(module=module, names=names) if module: + modules_to_check = [module] + imported_names = names + case ast.Import(names=names): + modules_to_check = [alias.name for alias in names] + imported_names = [] + case _: + continue + + for module in modules_to_check: + all_modules = [module.rsplit(".", i)[0] for i in range(0, module.count(".") + 1)] + for module_name in all_modules: + if module_name in self.UNSAFE_MODULES: + # Special handling for builtins - check specific function names + if module_name in BUILTIN_MODULE_NAMES: + for n in imported_names: + if n.name not in SAFE_BUILTINS: + risk_info = self.UNSAFE_MODULES[module_name] + yield AnalysisResult( + Severity.LIKELY_OVERTLY_MALICIOUS, + f"`{shortened}` imports `{n.name}` from `{module_name}` " + f"which can execute arbitrary code. {risk_info}", + "UnsafeImportsML", + trigger=shortened, + ) + else: + # All other unsafe modules are fully blocked + risk_info = self.UNSAFE_MODULES[module_name] + yield AnalysisResult( + Severity.LIKELY_OVERTLY_MALICIOUS, + f"`{shortened}` uses `{module_name}` that is indicative of a malicious pickle file. {risk_info}", + "UnsafeImportsML", + trigger=shortened, + ) + if module in self.UNSAFE_IMPORTS: + for n in imported_names: + if n.name in self.UNSAFE_IMPORTS[module]: + risk_info = self.UNSAFE_IMPORTS[module][n.name] + yield AnalysisResult( + Severity.LIKELY_OVERTLY_MALICIOUS, + f"`{shortened}` imports `{n.name}` that is indicative of a malicious pickle file. {risk_info}", + "UnsafeImportsML", + trigger=shortened, + ) # NOTE(boyan): Special case with eval? # Copy pasted from pickled.unsafe_imports() original implementation - elif "eval" in (n.name for n in node.names): + if "eval" in (n.name for n in imported_names): yield AnalysisResult( Severity.LIKELY_OVERTLY_MALICIOUS, f"`{shortened}` imports `eval` which can execute arbitrary code", @@ -425,10 +436,11 @@ def analyze(self, context: AnalysisContext) -> Iterator[AnalysisResult]: class UnsafeImports(Analysis): def analyze(self, context: AnalysisContext) -> Iterator[AnalysisResult]: for node in context.pickled.unsafe_imports(): - if node.module in BUILTIN_MODULE_NAMES and all( - n.name in SAFE_BUILTINS for n in node.names - ): - continue + if isinstance(node, ast.ImportFrom): + if node.module in BUILTIN_MODULE_NAMES and all( + n.name in SAFE_BUILTINS for n in node.names + ): + continue shortened, _ = context.shorten_code(node) yield AnalysisResult( Severity.LIKELY_OVERTLY_MALICIOUS, diff --git a/fickling/fickle.py b/fickling/fickle.py index 50173a4..0c4c9eb 100644 --- a/fickling/fickle.py +++ b/fickling/fickle.py @@ -256,7 +256,7 @@ def __post_init__(self) -> None: def is_std_module(module_name: str) -> bool: - return module_name in BUILTIN_STDLIB_MODULE_NAMES + return module_name.split(".")[0] in BUILTIN_STDLIB_MODULE_NAMES def extract_identifier_from_ast_node( diff --git a/test/test_analysis.py b/test/test_analysis.py new file mode 100644 index 0000000..2124edd --- /dev/null +++ b/test/test_analysis.py @@ -0,0 +1,76 @@ +from unittest import TestCase + +import fickling.fickle as op +from fickling.analysis import ( + Analyzer, + UnsafeImports, + UnsafeImportsML, + check_safety, +) +from fickling.fickle import Pickled + + +class TestImportMatchingGaps(TestCase): + """Regression tests for gaps in how analysis passes handle imports.""" + + def test_stdlib_submodule_not_flagged_as_nonstandard(self): + """Stdlib submodules like collections.abc should not trigger NonStandardImports.""" + pickled = Pickled( + [ + op.Proto.create(4), + op.Global.create("collections.abc", "Mapping"), + op.EmptyTuple(), + op.Reduce(), + op.Stop(), + ] + ) + result = check_safety(pickled) + non_std_results = [r for r in result.results if r.analysis_name == "NonStandardImports"] + self.assertEqual( + len(non_std_results), + 0, + "collections.abc should not be flagged as non-standard", + ) + + def test_eval_import_from_unsafe_imports_ml_module(self): + """Eval check must not be skipped when module is in UNSAFE_IMPORTS. + + Not a real payload (_io.eval doesn't exist), just a regression trigger. + """ + pickled = Pickled( + [ + op.Proto.create(4), + op.ShortBinUnicode("_io"), + op.ShortBinUnicode("eval"), + op.StackGlobal(), + op.EmptyTuple(), + op.Reduce(), + op.Stop(), + ] + ) + result = check_safety(pickled) + eval_results = [ + r + for r in result.results + if r.analysis_name == "UnsafeImportsML" and "eval" in (r.message or "") + ] + self.assertGreater( + len(eval_results), + 0, + "UnsafeImportsML should flag 'from _io import eval'", + ) + + def test_ext1_ast_import_does_not_crash_analysis(self): + """Ext1 generates ast.Import nodes; both analysis passes must handle them.""" + pickled = Pickled( + [ + op.Proto.create(2), + op.Ext1(1), + op.Stop(), + ] + ) + # Must not raise AttributeError: 'Import' object has no attribute 'module' + for analysis in [UnsafeImportsML(), UnsafeImports()]: + with self.subTest(analysis=type(analysis).__name__): + result = Analyzer([analysis]).analyze(pickled) + self.assertIsNotNone(result)