From da128db85a0cae4d3f64baf84a235aa11eb705a6 Mon Sep 17 00:00:00 2001 From: ziad hany Date: Wed, 22 Apr 2026 03:41:41 +0200 Subject: [PATCH 1/4] feat: create pipeline for symbol reachability and add a test Signed-off-by: ziad hany --- pyproject.toml | 1 + .../pipelines/collect_symbols_reachability.py | 35 + scanpipe/pipes/reachability.py | 730 ++++++++++++++++++ scanpipe/pipes/symbols.py | 155 ++++ scanpipe/tests/data/reachability/app.py | 35 + .../tests/data/reachability/diff-app.patch | 39 + scanpipe/tests/data/reachability/fixed-app.py | 41 + scanpipe/tests/data/reachability/vuln-app.py | 35 + .../tests/pipes/test_symbols_reachability.py | 293 +++++++ 9 files changed, 1364 insertions(+) create mode 100644 scanpipe/pipelines/collect_symbols_reachability.py create mode 100644 scanpipe/pipes/reachability.py create mode 100644 scanpipe/tests/data/reachability/app.py create mode 100644 scanpipe/tests/data/reachability/diff-app.patch create mode 100644 scanpipe/tests/data/reachability/fixed-app.py create mode 100644 scanpipe/tests/data/reachability/vuln-app.py create mode 100644 scanpipe/tests/pipes/test_symbols_reachability.py diff --git a/pyproject.toml b/pyproject.toml index 2a11faf48d..2436a4ae43 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -136,6 +136,7 @@ run = "scancodeio:combined_run" analyze_docker_image = "scanpipe.pipelines.analyze_docker:Docker" analyze_root_filesystem_or_vm_image = "scanpipe.pipelines.analyze_root_filesystem:RootFS" analyze_windows_docker_image = "scanpipe.pipelines.analyze_docker_windows:DockerWindows" +analyze_symbols_reachability = "scanpipe.pipelines.collect_symbols_reachability:SymbolReachability" benchmark_purls = "scanpipe.pipelines.benchmark_purls:BenchmarkPurls" collect_strings_gettext = "scanpipe.pipelines.collect_strings_gettext:CollectStringsGettext" collect_symbols_ctags = "scanpipe.pipelines.collect_symbols_ctags:CollectSymbolsCtags" diff --git a/scanpipe/pipelines/collect_symbols_reachability.py b/scanpipe/pipelines/collect_symbols_reachability.py new file mode 100644 index 0000000000..15519fc661 --- /dev/null +++ b/scanpipe/pipelines/collect_symbols_reachability.py @@ -0,0 +1,35 @@ +# +# Copyright (c) nexB Inc. and others. All rights reserved. +# VulnerableCode is a trademark of nexB Inc. +# SPDX-License-Identifier: Apache-2.0 +# See http://www.apache.org/licenses/LICENSE-2.0 for the license text. +# See https://github.com/aboutcode-org/vulnerablecode for support or download. +# See https://aboutcode.org for more information about nexB OSS projects. +# + +from scanpipe.pipelines import Pipeline +from scanpipe.pipes import reachability + + +class SymbolReachability(Pipeline): + """ + Patch reachability analysis, for given a vulnerability patches + """ + + download_inputs = False + is_addon = True + results_url = "/project/{slug}/resources/?extra_data=symbol_reachability" + + @classmethod + def steps(cls): + return (cls.analyze_and_store_symbol_reachability,) + + def analyze_and_store_symbol_reachability(self): + """ + Perform symbol-level reachability analysis for each patch. + This step compares the AST of patched/vulnerable files against the codebase resources. + Results are stored directly in the 'extra_data' of each CodebaseResource. + """ + reachability.collect_and_store_symbol_reachability_results( + project=self.project, logger=self.log + ) diff --git a/scanpipe/pipes/reachability.py b/scanpipe/pipes/reachability.py new file mode 100644 index 0000000000..5531a1f495 --- /dev/null +++ b/scanpipe/pipes/reachability.py @@ -0,0 +1,730 @@ +# SPDX-License-Identifier: Apache-2.0 +# +# http://nexb.com and https://github.com/aboutcode-org/scancode.io +# The ScanCode.io software is licensed under the Apache License version 2.0. +# Data generated with ScanCode.io is provided as-is without warranties. +# ScanCode is a trademark of nexB Inc. +# +# You may not use this software except in compliance with the License. +# You may obtain a copy of the License at: http://apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. +# +# Data Generated with ScanCode.io is provided on an "AS IS" BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, either express or implied. No content created from +# ScanCode.io should be considered or used as legal advice. Consult an Attorney +# for any legal advice. +# +# ScanCode.io is a free software code scanning tool from nexB Inc. and others. +# Visit https://github.com/aboutcode-org/scancode.io for support and download. + +import os +import shutil +import tempfile +from enum import Enum +from pathlib import Path + +from git import Repo +from git.diff import NULL_TREE +from git.exc import BadName +from matchcode_toolkit.fingerprinting import create_file_fingerprints +from scancode.api import get_file_info +from unidiff import PatchSet + +from scanpipe.pipes.symbols import TS_QUERIES +from scanpipe.pipes.symbols import _root_of +from scanpipe.pipes.symbols import collect_definitions +from scanpipe.pipes.symbols import extract_calls_in_node +from scanpipe.pipes.symbols import extract_definitions +from scanpipe.pipes.symbols import extract_symbols +from scanpipe.pipes.symbols import parse_code_to_ast +from scanpipe.pipes.symbols import qualified_name_from_index + +EMPTY_TREE_SHA = "4b825dc642cb6eb9a060e54bf8b8e6f9b79b4d2b" + + +class ReachabilityStatus(str, Enum): + REACHABLE = "REACHABLE" + POTENTIALLY_REACHABLE = "POTENTIALLY_REACHABLE" + NOT_REACHABLE = "NOT_REACHABLE" + + +def api_mocker(): + """ + TODO: Remove this once the API patch url is done + """ + return [ + { + "vcs_url": "https://github.com/pallets/flask", + "commit_hash": "089cb86dd22bff589a4eafb7ab8e42dc357623b4", + }, + ] + + +def clone_repo(vcs_url, commit_hash=None): + repo_path = tempfile.mkdtemp(prefix="symbol-reachability-") + + try: + repo = Repo.clone_from(vcs_url, repo_path) + + if commit_hash: + repo.git.checkout(commit_hash) + + return repo_path + + except BadName as exc: + cleanup_repo(repo_path) + raise ValueError(f"Commit {commit_hash} not found") from exc + + except Exception: + cleanup_repo(repo_path) + raise + + +def cleanup_repo(repo_path): + if repo_path and os.path.exists(repo_path): + shutil.rmtree(repo_path, ignore_errors=True) + + +def normalize_text(content): + if content is None: + return "" + + if isinstance(content, bytes): + return content.decode("utf-8", errors="replace") + + return str(content) + + +def is_supported_language(language): + """A language is supported if we have tree-sitter queries for it.""" + return bool(language) and language in TS_QUERIES + + +def detect_language_with_scancode(file_path, content): + """ + Write `content` to a temp file preserving `file_path`'s basename + so the extension is meaningful, then ask ScanCode's `get_file_info` + to return the programming language. + """ + content = normalize_text(content) + + if not content: + return None + + tmp_dir = tempfile.mkdtemp(prefix="patch-lang-") + + try: + target = Path(tmp_dir) / Path(file_path).name + target.write_text(content, encoding="utf-8", errors="replace") + + info = get_file_info(location=str(target)) or {} + return info.get("programming_language") or None + + finally: + shutil.rmtree(tmp_dir, ignore_errors=True) + + +def get_commit_and_parent(repo, commit_hash): + commit = repo.commit(commit_hash) + parent = commit.parents[0] if commit.parents else None + return commit, parent + + +def get_commit_diff_text(repo, parent_commit, commit): + """Whole-commit unified diff (used to extract changed line numbers).""" + base = parent_commit.hexsha if parent_commit else EMPTY_TREE_SHA + return repo.git.diff(base, commit.hexsha, unified=0) + + +def get_changed_files(parent_commit, commit): + """ + Return: + { + file_path: { + "vulnerable_text": "...", + "fixed_text": "...", + } + } + + """ + diffs = ( + parent_commit.diff(commit, create_patch=False) + if parent_commit + else commit.diff(NULL_TREE, create_patch=False) + ) + + files = {} + for diff in diffs: + change_type = diff.change_type + old_path = diff.a_path if change_type in ("D", "M", "R") else None + new_path = diff.b_path if change_type in ("A", "M", "R") else None + path_key = new_path or old_path + + if not path_key: + continue + + entry = files.setdefault( + path_key, + { + "vulnerable_text": "", + "fixed_text": "", + }, + ) + + if old_path and parent_commit: + entry["vulnerable_text"] = ( + (parent_commit.tree / old_path) + .data_stream.read() + .decode("utf-8", errors="replace") + ) + + if new_path: + entry["fixed_text"] = ( + (commit.tree / new_path) + .data_stream.read() + .decode("utf-8", errors="replace") + ) + + return files + + +def get_changed_lines(diff_text, file_path): + """Return `(removed_lines, added_lines)` for one file from a unified diff.""" + removed = [] + added = [] + + if not diff_text: + return removed, added + + for patched_file in PatchSet.from_string(diff_text): + candidates = { + patched_file.path, + (patched_file.source_file or "").removeprefix("a/"), + (patched_file.target_file or "").removeprefix("b/"), + } + + if file_path not in candidates: + continue + + for hunk in patched_file: + for line in hunk: + if line.is_removed and line.source_line_no: + removed.append(line.source_line_no) + elif line.is_added and line.target_line_no: + added.append(line.target_line_no) + + return removed, added + + +def query_captures(language, kind, node): + """ + Re-run a definition query on the root of `node`'s tree so ancestors can + be compared. Query caching is handled by `scanpipe.pipes.symbols`. + """ + from scanpipe.pipes.symbols import get_query + from scanpipe.pipes.symbols import run_query + + root = node + + while root.parent is not None: + root = root.parent + + query = get_query(language, kind) + return list(run_query(query, root)) + + +def is_nested_function(node, language): + function_nodes = { + captured_node + for captured_node, _ in query_captures(language, "functions", node) + } + class_nodes = { + captured_node for captured_node, _ in query_captures(language, "classes", node) + } + + if node not in function_nodes: + return False + + function_types = {captured_node.type for captured_node in function_nodes} + class_types = {captured_node.type for captured_node in class_nodes} + + parent = node.parent + + while parent is not None: + if parent.type in function_types: + return True + + if parent.type in class_types: + return False + + parent = parent.parent + + return False + + +def diff_changed_symbols(vuln_meta, fixed_meta): + """ + Keep only symbols whose body actually differs between vulnerable and fixed + versions. Pair by qualified name first. + """ + fixed_by_qn = { + metadata["qualified_name"]: metadata for metadata in fixed_meta.values() + } + + vuln_by_qn = { + metadata["qualified_name"]: metadata for metadata in vuln_meta.values() + } + + vuln_only = { + key: metadata + for key, metadata in vuln_meta.items() + if fixed_by_qn.get(metadata["qualified_name"], {}).get("text") + != metadata["text"] + } + + fixed_only = { + key: metadata + for key, metadata in fixed_meta.items() + if vuln_by_qn.get(metadata["qualified_name"], {}).get("text") + != metadata["text"] + } + + return vuln_only, fixed_only + + +def analyze_patched_file(vulnerable_text, fixed_text, diff_text, file_path): + """ + Return `(vuln_metadata, fixed_metadata, language)` for one changed file, + restricted to symbols actually touched by the patch. + """ + vulnerable_text = normalize_text(vulnerable_text) + fixed_text = normalize_text(fixed_text) + + language = detect_language_with_scancode( + file_path, fixed_text + ) or detect_language_with_scancode(file_path, vulnerable_text) + + if not is_supported_language(language): + return {}, {}, language + + vuln_tree, _ = ( + parse_code_to_ast(vulnerable_text, language) + if vulnerable_text + else (None, None) + ) + + fixed_tree, _ = ( + parse_code_to_ast(fixed_text, language) if fixed_text else (None, None) + ) + + if vuln_tree is None and fixed_tree is None: + return {}, {}, language + + removed_lines, added_lines = get_changed_lines(diff_text, file_path) + + vuln_nodes = ( + extract_symbols(vuln_tree, removed_lines, language) if vuln_tree else [] + ) + + fixed_nodes = ( + extract_symbols(fixed_tree, added_lines, language) if fixed_tree else [] + ) + + vuln_meta, fixed_meta = diff_changed_symbols( + build_symbol_metadata(vuln_nodes, language), + build_symbol_metadata(fixed_nodes, language), + ) + + return vuln_meta, fixed_meta, language + + +def collect_patch_symbols(repo, commit_hash): + """ + Return: + { + language: { + "vulnerable": { + "file_path::symbol_key": metadata, + ... + }, + "fixed": { + "file_path::symbol_key": metadata, + ... + }, + }, + ... + } + + Symbols are bucketed by language so resources are only matched against + patch symbols extracted from the same language. + + """ + commit, parent = get_commit_and_parent(repo, commit_hash) + diff_text = get_commit_diff_text(repo, parent, commit) + changed = get_changed_files(parent, commit) + + by_language = {} + for file_path, texts in changed.items(): + vulnerable_text = texts["vulnerable_text"] + fixed_text = texts["fixed_text"] + vuln_meta, fixed_meta, language = analyze_patched_file( + vulnerable_text=vulnerable_text, + fixed_text=fixed_text, + diff_text=diff_text, + file_path=file_path, + ) + + if not language or not (vuln_meta or fixed_meta): + continue + + language_bucket = by_language.setdefault( + language, + { + "vulnerable": {}, + "fixed": {}, + }, + ) + + language_bucket["vulnerable"].update( + {f"{file_path}::{key}": metadata for key, metadata in vuln_meta.items()} + ) + + language_bucket["fixed"].update( + {f"{file_path}::{key}": metadata for key, metadata in fixed_meta.items()} + ) + + return by_language + + +def append_symbol_reachability_result(resource, result): + """ + Append one symbol reachability result to the resource extra_data without + overwriting previous results. + """ + extra_data = resource.extra_data or {} + existing_results = extra_data.get("symbols_reachability", []) + + if not isinstance(existing_results, list): + existing_results = [existing_results] + + existing_results.append(result) + + resource.update_extra_data( + { + "symbols_reachability": existing_results, + } + ) + + +def collect_and_store_symbol_reachability_results(project, logger=None): + """ + For each known patch commit, determine whether each project codebase + resource is reachable to the vulnerable code by comparing tree-sitter ASTs + of the patch versus the resource. + + Result classification: + - REACHABLE + - POTENTIALLY_REACHABLE + - NOT_REACHABLE + """ + candidate_resources = project.codebaseresources.files().filter( + is_binary=False, + is_archive=False, + is_media=False, + ) + + for patch in api_mocker(): + vcs_url = patch["vcs_url"] + commit_hash = patch["commit_hash"] + repo_path = None + try: + repo_path = clone_repo(vcs_url, commit_hash) + repo = Repo(repo_path) + + patch_symbols_by_language = collect_patch_symbols(repo, commit_hash) + + if not patch_symbols_by_language: + continue + + for resource in candidate_resources: + resource_language = resource.programming_language + + if resource_language not in patch_symbols_by_language: + continue + + resource_text = normalize_text(resource.file_content) + + if not resource_text: + continue + + patch_symbols = patch_symbols_by_language[resource_language] + vuln_metadata = patch_symbols["vulnerable"] + fixed_metadata = patch_symbols["fixed"] + + resource_index = build_resource_index( + resource_text, + resource_language, + ) + + if not resource_index: + continue + + vuln_match_symbols = match_symbols_against_resource( + vuln_metadata, + resource_index, + ) + + fixed_match_symbols = match_symbols_against_resource( + fixed_metadata, + resource_index, + ) + + if not vuln_match_symbols and not fixed_match_symbols: + continue + + result = { + "reachability_status": classify_reachability(vuln_match_symbols), + "summary": { + "vulnerable_symbols": sorted(vuln_match_symbols), + "fixed_symbols": sorted(fixed_match_symbols), + "call_paths": { + qn: ev.get("reachable_from", []) + for qn, ev in vuln_match_symbols.items() + if ev.get("called") + }, + }, + "evidence": vuln_match_symbols, + "patch": { + "vcs_url": vcs_url, + "commit_hash": commit_hash, + }, + } + + append_symbol_reachability_result(resource, result) + + except Exception as e: + logger.exception( + "Failed to collect symbol reachability for " + f"{vcs_url}@{commit_hash}: {e}" + ) + finally: + cleanup_repo(repo_path) + + +def compute_reachable_symbols(call_graph, target_simple_names): + if not call_graph or not target_simple_names: + return set(), False + + edges = call_graph["edges"] + targets = set(target_simple_names) + + callers_of = {} + for caller_qn, callees in edges.items(): + for callee_simple in callees: + callers_of.setdefault(callee_simple, set()).add(caller_qn) + + direct_callers = set() + for target in targets: + direct_callers |= callers_of.get(target, set()) + + has_direct_call = bool(direct_callers) + + by_simple = call_graph["by_simple_name"] + qn_to_simple = {qn: meta["simple_name"] for qn, meta in call_graph["nodes"].items()} + + reachable = set(direct_callers) + frontier = list(direct_callers) + + while frontier: + current_qn = frontier.pop() + current_simple = qn_to_simple.get(current_qn) + if not current_simple: + continue + for parent_qn in callers_of.get(current_simple, ()): + if parent_qn not in reachable: + reachable.add(parent_qn) + frontier.append(parent_qn) + + return reachable, has_direct_call + + +def build_resource_index(resource_text, language): + resource_text = normalize_text(resource_text) + + if not is_supported_language(language) or not resource_text: + return None + + tree, _ = parse_code_to_ast(resource_text, language) + + if tree is None: + return None + + call_graph = build_call_graph(tree, language) + + meta = ( + call_graph["nodes"] + if call_graph + else build_symbol_metadata( + extract_definitions(tree, language), + language, + ) + ) + + return { + "definitions": {metadata["qualified_name"] for metadata in meta.values()}, + "fingerprints": { + metadata["fingerprint"] + for metadata in meta.values() + if metadata["fingerprint"] + }, + "call_graph": call_graph, + } + + +def match_symbols_against_resource(symbols, resource_index): + if not symbols or not resource_index: + return {} + + call_graph = resource_index.get("call_graph") + target_simple_names = {metadata["simple_name"] for metadata in symbols.values()} + + reachable_callers, _has_direct = compute_reachable_symbols( + call_graph, + target_simple_names, + ) + + called_simple_names = set() + if call_graph: + for callees in call_graph["edges"].values(): + called_simple_names |= callees + + matched = {} + for metadata in symbols.values(): + qualified_name = metadata["qualified_name"] + simple_name = metadata["simple_name"] + fingerprint = metadata["fingerprint"] + + defined = qualified_name in resource_index["definitions"] + fingerprint_hit = bool( + fingerprint and fingerprint in resource_index["fingerprints"] + ) + called = simple_name in called_simple_names + + if not (defined or fingerprint_hit or called): + continue + + entry = matched.setdefault( + qualified_name, + { + "defined": False, + "called": False, + "reachable_from": [], + }, + ) + + entry["defined"] = entry["defined"] or defined + entry["called"] = entry["called"] or called + + if fingerprint_hit: + entry["exact_match_fingerprint"] = fingerprint + + if called: + entry["reachable_from"] = sorted(reachable_callers) + + return matched + + +def classify_reachability(evidence): + if not evidence: + return ReachabilityStatus.NOT_REACHABLE + + SEVERITY_RANK = { + ReachabilityStatus.NOT_REACHABLE: 0, + ReachabilityStatus.POTENTIALLY_REACHABLE: 1, + ReachabilityStatus.REACHABLE: 2, + } + + highest_status = ReachabilityStatus.NOT_REACHABLE + + for item in evidence.values(): + is_called = bool(item.get("called")) + has_path = bool(item.get("reachable_from")) + is_exact = "exact_match_fingerprint" in item + is_defined = bool(item.get("defined")) + + if is_called or (has_path and is_exact): + return ReachabilityStatus.REACHABLE + + elif has_path or is_exact or is_defined: + current_item_status = ReachabilityStatus.POTENTIALLY_REACHABLE + + else: + current_item_status = ReachabilityStatus.NOT_REACHABLE + + if SEVERITY_RANK[current_item_status] > SEVERITY_RANK[highest_status]: + highest_status = current_item_status + return highest_status + + +def build_symbol_metadata(nodes, language, index=None): + if index is None and nodes: + index = collect_definitions(_root_of(nodes[0]), language) + + metadata = {} + for node in nodes: + if is_nested_function(node, language): + continue + + qualified_name = qualified_name_from_index(node, index) + if not qualified_name: + continue + + body_text = node.text.decode("utf-8", errors="replace") + fingerprints = create_file_fingerprints(content=body_text) or {} + + key = qualified_name + suffix = 1 + while key in metadata: + suffix += 1 + key = f"{qualified_name}#{suffix}" + + metadata[key] = { + "qualified_name": qualified_name, + "simple_name": qualified_name.rsplit(".", 1)[-1], + "text": body_text, + "fingerprint": fingerprints.get("halo1"), + "start_line": node.start_point[0] + 1, + "end_line": node.end_point[0] + 1, + "node_type": node.type, + } + return metadata + + +def build_call_graph(tree, language): + if tree is None or not is_supported_language(language): + return None + + index = collect_definitions(tree.root_node, language) + definition_nodes = [d["node"] for d in index.values()] + metadata = build_symbol_metadata(definition_nodes, language, index=index) + + qualified_name_to_node = {} + for node in definition_nodes: + qualified_name = qualified_name_from_index(node, index) + if qualified_name: + qualified_name_to_node.setdefault(qualified_name, node) + + edges = {} + by_simple_name = {} + for qualified_name, meta in metadata.items(): + canonical = meta["qualified_name"] + node = qualified_name_to_node.get(canonical) + if node is None: + continue + edges.setdefault(canonical, set()).update(extract_calls_in_node(node, language)) + by_simple_name.setdefault(meta["simple_name"], set()).add(canonical) + + return {"nodes": metadata, "edges": edges, "by_simple_name": by_simple_name} diff --git a/scanpipe/pipes/symbols.py b/scanpipe/pipes/symbols.py index 76493d8dac..c6247dbc44 100644 --- a/scanpipe/pipes/symbols.py +++ b/scanpipe/pipes/symbols.py @@ -20,8 +20,20 @@ # ScanCode.io is a free software code scanning tool from nexB Inc. and others. # Visit https://github.com/aboutcode-org/scancode.io for support and download. +import importlib +from functools import cache + from django.db.models import Q +from source_inspector import symbols_ctags +from source_inspector import symbols_pygments +from source_inspector import symbols_tree_sitter +from source_inspector.symbols_tree_sitter import TS_LANGUAGE_WHEELS +from source_inspector.symbols_tree_sitter import TreeSitterWheelNotInstalled +from tree_sitter import Language +from tree_sitter import Parser +from tree_sitter import Query + from aboutcode.pipeline import LoopProgress @@ -171,3 +183,146 @@ def _collect_and_store_tree_sitter_symbols_and_strings(resource): "source_strings": result.get("source_strings"), } ) + + +SYMBOLS_TYPE_SUPPORTED = { + "ctags": symbols_ctags.get_symbols, + "tree_sitter": symbols_tree_sitter.get_treesitter_symbols, + "pygments": symbols_pygments.get_pygments_symbols, +} + +# https://github.com/Aider-AI/aider/tree/5dc9490bb35f9729ef2c95d00a19ccd30c26339c/aider/queries/tree-sitter-language-pack +TS_QUERIES = { + "Python": { + "functions": """ + (function_definition name: (identifier) @name) @function + """, + "classes": """ + (class_definition name: (identifier) @name) @class + """, + "calls": """ + (call function: (identifier) @callee) + (call function: (attribute attribute: (identifier) @callee)) + """, + }, +} + +@cache +def load_language(language: str) -> Language: + if language not in TS_LANGUAGE_WHEELS: + raise ValueError(f"Unsupported language: {language}") + + wheel = TS_LANGUAGE_WHEELS[language]["wheel"] + try: + grammar = importlib.import_module(wheel) + except ModuleNotFoundError as exc: + raise TreeSitterWheelNotInstalled( + f"Grammar wheel '{wheel}' is not installed." + ) from exc + return Language(grammar.language()) + + +@cache +def get_query(language: str, kind: str) -> Query | None: + source = TS_QUERIES.get(language, {}).get(kind, "").strip() + if not source: + return None + return Query(load_language(language), source) + + +def parse_code_to_ast(code_text: str, language: str): + if not code_text or not language or language not in TS_LANGUAGE_WHEELS: + return None, None + + ts_language = load_language(language) + parser = Parser(language=ts_language) + return parser.parse(code_text.encode("utf-8")), TS_LANGUAGE_WHEELS[language] + + +def run_query(query: Query, root_node): + """Yield ``(definition_node, name)`` pairs for function/class queries.""" + if query is None: + return + + for _pattern_index, captures in query.matches(root_node): + def_nodes = captures.get("function") or captures.get("class") or [] + if not def_nodes: + continue + + name_nodes = captures.get("name") or [] + name = ( + name_nodes[0].text.decode("utf-8", errors="replace") if name_nodes else None + ) + yield def_nodes[0], name + + +def extract_calls_in_node(node, language: str) -> set[str]: + query = get_query(language, "calls") + if query is None or node is None: + return set() + + names = set() + for _pattern_index, captures in query.matches(node): + for callee_node in captures.get("callee", []): + name = callee_node.text.decode("utf-8", errors="replace") + if name: + names.add(name) + return names + + +def collect_definitions(root_node, language: str) -> dict[int, dict]: + index: dict[int, dict] = {} + for kind in ("functions", "classes"): + query = get_query(language, kind) + for node, name in run_query(query, root_node): + index[node.id] = {"node": node, "name": name, "kind": kind} + return index + + +def extract_definitions(tree, language: str, kinds=("functions", "classes")): + if tree is None: + return [] + index = collect_definitions(tree.root_node, language) + return [d["node"] for d in index.values() if d["kind"] in kinds] + + +def extract_symbols(tree, changed_lines: list[int], language: str): + if tree is None or not changed_lines: + return [] + + definition_ids = set(collect_definitions(tree.root_node, language).keys()) + if not definition_ids: + return [] + + seen = set() + enclosing = [] + + for line in changed_lines: + row = max(0, line - 1) + node = tree.root_node.descendant_for_point_range((row, 0), (row, 0)) + + while node is not None: + if node.id in definition_ids and node.id not in seen: + seen.add(node.id) + enclosing.append(node) + break + node = node.parent + + return enclosing + + +def _root_of(node): + while node.parent is not None: + node = node.parent + return node + + +def qualified_name_from_index(node, index: dict[int, dict]) -> str: + parts = [] + curr = node + while curr is not None: + definition = index.get(curr.id) + if definition is not None and definition["name"]: + parts.append(definition["name"]) + curr = curr.parent + return ".".join(reversed(parts)) diff --git a/scanpipe/tests/data/reachability/app.py b/scanpipe/tests/data/reachability/app.py new file mode 100644 index 0000000000..c64ae7d9d1 --- /dev/null +++ b/scanpipe/tests/data/reachability/app.py @@ -0,0 +1,35 @@ +import os + + +class ReportGenerator: + """A dummy class to test AST class method parsing.""" + + def __init__(self, base_dir): + self.base_dir = base_dir + + +def serve_report(request_payload): + """Top-level function handling a request.""" + generator = ReportGenerator("/var/reports") + requested_file = request_payload.get("file") + + # Helper function nested inside serve_report + def build_file_path(filename): + # VULNERABLE: Direct concatenation allows Path Traversal + # An attacker passing "../../etc/passwd" could read system files. + return os.path.join(generator.base_dir, filename) + + if not requested_file: + return "Error: No file specified" + + target_path = build_file_path(requested_file) + + if os.path.exists(target_path): + return f"Serving content of {target_path}" + + return "Error: File not found" + + +def unrelated_top_level_function(): + """An extra function to test AST node boundaries.""" + return "I am just here to add AST complexity." diff --git a/scanpipe/tests/data/reachability/diff-app.patch b/scanpipe/tests/data/reachability/diff-app.patch new file mode 100644 index 0000000000..ccb86953a8 --- /dev/null +++ b/scanpipe/tests/data/reachability/diff-app.patch @@ -0,0 +1,39 @@ +From 8f7b1c3d9a4e2b6f5d8c1a2e3f4b5c6d7e8f9a0b Mon Sep 17 00:00:00 2001 +From: Security Team +Date: Tue, 2 Jun 2026 10:00:00 +0000 +Subject: [PATCH] Fix path traversal vulnerability in report generator + +- Validates that target paths stay within the designated base_dir. +- Catches ValueError on invalid path resolution. +--- + app.py | 12 +++++++++--- + 1 file changed, 9 insertions(+), 3 deletions(-) + +diff --git a/app.py b/app.py +index a1b2c3d..e4f5g6h 100644 +--- a/app.py ++++ b/app.py +@@ -15,13 +15,19 @@ def serve_report(request_payload): + # Helper function nested inside serve_report + def build_file_path(filename): +- # VULNERABLE: Direct concatenation allows Path Traversal +- # An attacker passing "../../etc/passwd" could read system files. +- return os.path.join(generator.base_dir, filename) ++ # FIXED: Validate that the resolved path stays within the base_dir ++ base = os.path.abspath(generator.base_dir) ++ target = os.path.abspath(os.path.join(base, filename)) ++ if not target.startswith(base): ++ raise ValueError("Path Traversal Detected") ++ return target + + if not requested_file: + return "Error: No file specified" + +- target_path = build_file_path(requested_file) ++ try: ++ target_path = build_file_path(requested_file) ++ except ValueError: ++ return "Error: Invalid path" + + if os.path.exists(target_path): + return f"Serving content of {target_path}" \ No newline at end of file diff --git a/scanpipe/tests/data/reachability/fixed-app.py b/scanpipe/tests/data/reachability/fixed-app.py new file mode 100644 index 0000000000..3296bb843e --- /dev/null +++ b/scanpipe/tests/data/reachability/fixed-app.py @@ -0,0 +1,41 @@ +import os + + +class ReportGenerator: + """A dummy class to test AST class method parsing.""" + + def __init__(self, base_dir): + self.base_dir = base_dir + + +def serve_report(request_payload): + """Top-level function handling a request.""" + generator = ReportGenerator("/var/reports") + requested_file = request_payload.get("file") + + # Helper function nested inside serve_report + def build_file_path(filename): + # FIXED: Validate that the resolved path stays within the base_dir + base = os.path.abspath(generator.base_dir) + target = os.path.abspath(os.path.join(base, filename)) + if not target.startswith(base): + raise ValueError("Path Traversal Detected") + return target + + if not requested_file: + return "Error: No file specified" + + try: + target_path = build_file_path(requested_file) + except ValueError: + return "Error: Invalid path" + + if os.path.exists(target_path): + return f"Serving content of {target_path}" + + return "Error: File not found" + + +def unrelated_top_level_function(): + """An extra function to test AST node boundaries.""" + return "I am just here to add AST complexity." diff --git a/scanpipe/tests/data/reachability/vuln-app.py b/scanpipe/tests/data/reachability/vuln-app.py new file mode 100644 index 0000000000..c64ae7d9d1 --- /dev/null +++ b/scanpipe/tests/data/reachability/vuln-app.py @@ -0,0 +1,35 @@ +import os + + +class ReportGenerator: + """A dummy class to test AST class method parsing.""" + + def __init__(self, base_dir): + self.base_dir = base_dir + + +def serve_report(request_payload): + """Top-level function handling a request.""" + generator = ReportGenerator("/var/reports") + requested_file = request_payload.get("file") + + # Helper function nested inside serve_report + def build_file_path(filename): + # VULNERABLE: Direct concatenation allows Path Traversal + # An attacker passing "../../etc/passwd" could read system files. + return os.path.join(generator.base_dir, filename) + + if not requested_file: + return "Error: No file specified" + + target_path = build_file_path(requested_file) + + if os.path.exists(target_path): + return f"Serving content of {target_path}" + + return "Error: File not found" + + +def unrelated_top_level_function(): + """An extra function to test AST node boundaries.""" + return "I am just here to add AST complexity." diff --git a/scanpipe/tests/pipes/test_symbols_reachability.py b/scanpipe/tests/pipes/test_symbols_reachability.py new file mode 100644 index 0000000000..30d4c924dd --- /dev/null +++ b/scanpipe/tests/pipes/test_symbols_reachability.py @@ -0,0 +1,293 @@ +# SPDX-License-Identifier: Apache-2.0 +# +# http://nexb.com and https://github.com/nexB/scancode.io +# The ScanCode.io software is licensed under the Apache License version 2.0. +# Data generated with ScanCode.io is provided as-is without warranties. +# ScanCode is a trademark of nexB Inc. +# +# You may not use this software except in compliance with the License. +# You may obtain a copy of the License at: http://apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. +# +# Data Generated with ScanCode.io is provided on an "AS IS" BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, either express or implied. No content created from +# ScanCode.io should be considered or used as legal advice. Consult an Attorney +# for any legal advice. +# +# ScanCode.io is a free software code scanning tool from nexB Inc. and others. +# Visit https://github.com/nexB/scancode.io for support and download. + +from pathlib import Path +from unittest.mock import patch + +from django.test import TestCase + +from scanpipe.models import Project +from scanpipe.pipes import collect_and_create_codebase_resources +from scanpipe.pipes.reachability import ReachabilityStatus +from scanpipe.pipes.reachability import analyze_patched_file +from scanpipe.pipes.reachability import build_call_graph +from scanpipe.pipes.reachability import classify_reachability +from scanpipe.pipes.reachability import collect_and_store_symbol_reachability_results +from scanpipe.pipes.symbols import collect_definitions +from scanpipe.pipes.symbols import extract_definitions +from scanpipe.pipes.symbols import parse_code_to_ast +from scanpipe.pipes.symbols import qualified_name_from_index + + +class SymbolReachabilityPipesTest(TestCase): + data = Path(__file__).parent.parent / "data" / "reachability" + + def setUp(self): + self.project1 = Project.objects.create(name="Analysis") + self.project1.codebase_path.mkdir(parents=True, exist_ok=True) + + @patch("scanpipe.pipes.reachability.Repo") + @patch("scanpipe.pipes.reachability.clone_repo") + @patch("scanpipe.pipes.reachability.api_mocker") + @patch("scanpipe.pipes.reachability.collect_patch_symbols") + def test_collect_and_store_symbol_reachability_results( + self, mock_collect_symbols, mock_api, mock_clone_repo, mock_repo + ): + app_text = (self.data / "app.py").read_text() + vuln_text = (self.data / "vuln-app.py").read_text() + fixed_text = (self.data / "fixed-app.py").read_text() + diff_text = (self.data / "diff-app.patch").read_text() + + vuln_meta, fixed_meta, lang = analyze_patched_file( + vulnerable_text=vuln_text, + fixed_text=fixed_text, + diff_text=diff_text, + file_path="app.py", + ) + + self.assertTrue(lang) + self.assertTrue(vuln_meta or fixed_meta) + mock_api.return_value = [ + { + "vcs_url": "https://github.com/aboutcode-org/test", + "commit_hash": "07ec0de1964b14bf085a1c9a27ece2b61ab6105c", + } + ] + + mock_clone_repo.return_value = str(self.project1.codebase_path) + mock_collect_symbols.return_value = { + lang: { + "vulnerable": { + f"app.py::{key}": metadata for key, metadata in vuln_meta.items() + }, + "fixed": { + f"app.py::{key}": metadata for key, metadata in fixed_meta.items() + }, + } + } + + resource_file = self.project1.codebase_path / "app.py" + resource_file.write_text(app_text) + collect_and_create_codebase_resources(self.project1) + + resource = self.project1.codebaseresources.get(path="app.py") + resource.programming_language = lang + resource.save() + + collect_and_store_symbol_reachability_results(self.project1) + + resource.refresh_from_db() + results = resource.extra_data.get("symbols_reachability") + + assert results == [ + { + "patch": { + "vcs_url": "https://github.com/aboutcode-org/test", + "commit_hash": "07ec0de1964b14bf085a1c9a27ece2b61ab6105c", + }, + "summary": { + "call_paths": {}, + "fixed_symbols": ["serve_report"], + "vulnerable_symbols": ["serve_report"], + }, + "evidence": { + "serve_report": { + "called": False, + "defined": True, + "reachable_from": [], + "exact_match_fingerprint": "000000556d322a47595af353274b000aa324e014", + } + }, + "reachability_status": "POTENTIALLY_REACHABLE", + } + ] + + def test_build_call_graph(self): + source_code = """ +def calculate_total(price, tax): + return price + get_tax_amount(price, tax) + +def get_tax_amount(price, tax): + return price * tax + +def process_order(): + total = calculate_total(100, 0.05) + print("Done") +""" + tree, _ = parse_code_to_ast(source_code, "Python") + result = build_call_graph(tree, "Python") + + assert result == { + "nodes": { + "calculate_total": { + "qualified_name": "calculate_total", + "simple_name": "calculate_total", + "text": "def calculate_total(price, tax):\n return price + get_tax_amount(price, tax)", + "fingerprint": "00000008060105fd3624134884412006ce880936", + "start_line": 2, + "end_line": 3, + "node_type": "function_definition", + }, + "get_tax_amount": { + "qualified_name": "get_tax_amount", + "simple_name": "get_tax_amount", + "text": "def get_tax_amount(price, tax):\n return price * tax", + "fingerprint": "000000058f0ee87d9669f20b1f473137b665bb20", + "start_line": 5, + "end_line": 6, + "node_type": "function_definition", + }, + "process_order": { + "qualified_name": "process_order", + "simple_name": "process_order", + "text": 'def process_order():\n total = calculate_total(100, 0.05)\n print("Done")', + "fingerprint": "000000071c3e6902da5c2b322386eff29068e3e2", + "start_line": 8, + "end_line": 10, + "node_type": "function_definition", + }, + }, + "edges": { + "calculate_total": {"get_tax_amount"}, + "get_tax_amount": set(), + "process_order": {"print", "calculate_total"}, + }, + "by_simple_name": { + "calculate_total": {"calculate_total"}, + "get_tax_amount": {"get_tax_amount"}, + "process_order": {"process_order"}, + }, + } + + def test_extract_definitions(self): + source_code = """ +class OrderManager: + def __init__(self, order_id): + self.order_id = order_id + + def process_payment(self): + print("Processing...") + +def calculate_discount(price): + return price * 0.10 + +class InventoryItem: + pass +""" + tree, _ = parse_code_to_ast(source_code, "Python") + functions = extract_definitions(tree, "Python", kinds=("functions",)) + assert ( + len(functions) == 3 + ) # '__init__', 'process_payment', and 'calculate_discount' + + assert functions[0].type == "function_definition" + first_func_text = functions[0].text.decode("utf-8") + assert "def __init__" in first_func_text + + classes = extract_definitions(tree, "Python", kinds=("classes",)) + assert len(classes) == 2 # OrderManager, InventoryItem + second_class_text = classes[1].text.decode("utf-8") + assert "class InventoryItem" in second_class_text + + def test_extract_definitions_empty(self): + tree, _ = parse_code_to_ast("", "Python") + assert extract_definitions(tree, "Python", kinds=("functions",)) == [] + assert extract_definitions(tree, "Python", kinds=("functions",)) == [] + assert extract_definitions(None, "Python", kinds=("classes",)) == [] + assert extract_definitions(None, "Python", kinds=("classes",)) == [] + + def test_get_qualified_name_functions(self): + source_code = """ +class CoreService: + class Validator: + def validate_payload(self, data): + return True + +def global_utility(): + pass + """ + + tree, _ = parse_code_to_ast(source_code, "Python") + index = collect_definitions(tree.root_node, "Python") + + functions = extract_definitions(tree, "Python", kinds=("functions",)) + assert len(functions) == 2 + + outer_function_name = qualified_name_from_index(functions[0], index) + inner_function_name = qualified_name_from_index(functions[1], index) + + assert outer_function_name == "CoreService.Validator.validate_payload" + assert inner_function_name == "global_utility" + + def test_get_qualified_classes(self): + source_code = """ +class FleetManagement: + class DroneController: + pass + """ + tree, _ = parse_code_to_ast(source_code, "Python") + index = collect_definitions(tree.root_node, "Python") + + classes = extract_definitions(tree, "Python", kinds=("classes",)) + assert len(classes) == 2 + + outer_class_name = qualified_name_from_index(classes[0], index) + inner_class_name = qualified_name_from_index(classes[1], index) + + assert outer_class_name == "FleetManagement" + assert inner_class_name == "FleetManagement.DroneController" + + def test_classify_reachability(self): + assert classify_reachability(None) == ReachabilityStatus.NOT_REACHABLE + assert classify_reachability({}) == ReachabilityStatus.NOT_REACHABLE + assert ( + classify_reachability( + {"sym1": {"exact_match_fingerprint": "hash123", "called": True}} + ) + == ReachabilityStatus.REACHABLE + ) + + assert ( + classify_reachability( + { + "sym1": { + "called": True, + "reachable_from": ["main_function", "api_handler"], + } + } + ) + == ReachabilityStatus.REACHABLE + ) + assert ( + classify_reachability({"sym1": {"defined": True, "called": False}}) + == ReachabilityStatus.POTENTIALLY_REACHABLE + ) + assert ( + classify_reachability( + {"sym1": {"exact_match_fingerprint": "hash123", "called": False}} + ) + == ReachabilityStatus.POTENTIALLY_REACHABLE + ) + assert ( + classify_reachability({"sym1": {"file_path": "src/vulnerable.py"}}) + == ReachabilityStatus.NOT_REACHABLE + ) From 59b8eb9b473eb23b264177193d44cdffe8dbab3b Mon Sep 17 00:00:00 2001 From: ziad hany Date: Fri, 5 Jun 2026 11:03:43 +0300 Subject: [PATCH 2/4] Add more test for reachability and remove redundant code Signed-off-by: ziad hany --- scanpipe/pipes/reachability.py | 67 ++----- scanpipe/pipes/symbols.py | 54 +++++- .../tests/pipes/test_symbols_reachability.py | 175 +++++++++++++++++- 3 files changed, 236 insertions(+), 60 deletions(-) diff --git a/scanpipe/pipes/reachability.py b/scanpipe/pipes/reachability.py index 5531a1f495..5169eba94d 100644 --- a/scanpipe/pipes/reachability.py +++ b/scanpipe/pipes/reachability.py @@ -39,6 +39,7 @@ from scanpipe.pipes.symbols import extract_calls_in_node from scanpipe.pipes.symbols import extract_definitions from scanpipe.pipes.symbols import extract_symbols +from scanpipe.pipes.symbols import is_nested_function from scanpipe.pipes.symbols import parse_code_to_ast from scanpipe.pipes.symbols import qualified_name_from_index @@ -219,52 +220,6 @@ def get_changed_lines(diff_text, file_path): return removed, added -def query_captures(language, kind, node): - """ - Re-run a definition query on the root of `node`'s tree so ancestors can - be compared. Query caching is handled by `scanpipe.pipes.symbols`. - """ - from scanpipe.pipes.symbols import get_query - from scanpipe.pipes.symbols import run_query - - root = node - - while root.parent is not None: - root = root.parent - - query = get_query(language, kind) - return list(run_query(query, root)) - - -def is_nested_function(node, language): - function_nodes = { - captured_node - for captured_node, _ in query_captures(language, "functions", node) - } - class_nodes = { - captured_node for captured_node, _ in query_captures(language, "classes", node) - } - - if node not in function_nodes: - return False - - function_types = {captured_node.type for captured_node in function_nodes} - class_types = {captured_node.type for captured_node in class_nodes} - - parent = node.parent - - while parent is not None: - if parent.type in function_types: - return True - - if parent.type in class_types: - return False - - parent = parent.parent - - return False - - def diff_changed_symbols(vuln_meta, fixed_meta): """ Keep only symbols whose body actually differs between vulnerable and fixed @@ -506,7 +461,7 @@ def collect_and_store_symbol_reachability_results(project, logger=None): append_symbol_reachability_result(resource, result) except Exception as e: - logger.exception( + logger( "Failed to collect symbol reachability for " f"{vcs_url}@{commit_hash}: {e}" ) @@ -515,24 +470,36 @@ def collect_and_store_symbol_reachability_results(project, logger=None): def compute_reachable_symbols(call_graph, target_simple_names): + """ + Find all symbols that can transitively reach any of ``target_simple_names``. + + Reachability is matched on *simple* names (the call graph records callee + tokens, not fully-qualified names), so distinct symbols sharing a name are + treated as equivalent. This can over-approximate reachability. + + Returns: + (reachable_callers, has_direct_call) + reachable_callers: qualified names of all transitive callers + has_direct_call: whether any symbol calls a target directly + + """ if not call_graph or not target_simple_names: return set(), False edges = call_graph["edges"] targets = set(target_simple_names) - callers_of = {} + callers_of: dict[str, set[str]] = {} for caller_qn, callees in edges.items(): for callee_simple in callees: callers_of.setdefault(callee_simple, set()).add(caller_qn) - direct_callers = set() + direct_callers: set[str] = set() for target in targets: direct_callers |= callers_of.get(target, set()) has_direct_call = bool(direct_callers) - by_simple = call_graph["by_simple_name"] qn_to_simple = {qn: meta["simple_name"] for qn, meta in call_graph["nodes"].items()} reachable = set(direct_callers) diff --git a/scanpipe/pipes/symbols.py b/scanpipe/pipes/symbols.py index c6247dbc44..cac76562e3 100644 --- a/scanpipe/pipes/symbols.py +++ b/scanpipe/pipes/symbols.py @@ -207,6 +207,7 @@ def _collect_and_store_tree_sitter_symbols_and_strings(resource): }, } + @cache def load_language(language: str) -> Language: if language not in TS_LANGUAGE_WHEELS: @@ -256,7 +257,48 @@ def run_query(query: Query, root_node): yield def_nodes[0], name -def extract_calls_in_node(node, language: str) -> set[str]: +def query_captures(language, kind, node): + """Re-run a definition query on the root of node's tree.""" + query = get_query(language, kind) + return list(run_query(query, _root_of(node))) + + +def _root_of(node): + while node.parent is not None: + node = node.parent + return node + + +def is_nested_function(node, language): + function_nodes = { + captured_node + for captured_node, _ in query_captures(language, "functions", node) + } + class_nodes = { + captured_node for captured_node, _ in query_captures(language, "classes", node) + } + + if node not in function_nodes: + return False + + function_types = {captured_node.type for captured_node in function_nodes} + class_types = {captured_node.type for captured_node in class_nodes} + + parent = node.parent + + while parent is not None: + if parent.type in function_types: + return True + + if parent.type in class_types: + return False + + parent = parent.parent + + return False + + +def extract_calls_in_node(node, language: str): query = get_query(language, "calls") if query is None or node is None: return set() @@ -270,7 +312,7 @@ def extract_calls_in_node(node, language: str) -> set[str]: return names -def collect_definitions(root_node, language: str) -> dict[int, dict]: +def collect_definitions(root_node, language: str): index: dict[int, dict] = {} for kind in ("functions", "classes"): query = get_query(language, kind) @@ -311,13 +353,7 @@ def extract_symbols(tree, changed_lines: list[int], language: str): return enclosing -def _root_of(node): - while node.parent is not None: - node = node.parent - return node - - -def qualified_name_from_index(node, index: dict[int, dict]) -> str: +def qualified_name_from_index(node, index): parts = [] curr = node while curr is not None: diff --git a/scanpipe/tests/pipes/test_symbols_reachability.py b/scanpipe/tests/pipes/test_symbols_reachability.py index 30d4c924dd..1a5fd7b3ad 100644 --- a/scanpipe/tests/pipes/test_symbols_reachability.py +++ b/scanpipe/tests/pipes/test_symbols_reachability.py @@ -30,9 +30,12 @@ from scanpipe.pipes.reachability import ReachabilityStatus from scanpipe.pipes.reachability import analyze_patched_file from scanpipe.pipes.reachability import build_call_graph +from scanpipe.pipes.reachability import build_symbol_metadata from scanpipe.pipes.reachability import classify_reachability from scanpipe.pipes.reachability import collect_and_store_symbol_reachability_results -from scanpipe.pipes.symbols import collect_definitions +from scanpipe.pipes.reachability import diff_changed_symbols +from scanpipe.pipes.reachability import get_changed_lines +from scanpipe.pipes.symbols import collect_definitions, extract_symbols from scanpipe.pipes.symbols import extract_definitions from scanpipe.pipes.symbols import parse_code_to_ast from scanpipe.pipes.symbols import qualified_name_from_index @@ -291,3 +294,173 @@ def test_classify_reachability(self): classify_reachability({"sym1": {"file_path": "src/vulnerable.py"}}) == ReachabilityStatus.NOT_REACHABLE ) + + def test_get_changed_lines(self): + data = Path(__file__).parent.parent / "data" / "reachability" + diff_text = (data / "diff-app.patch").read_text(encoding="utf-8") + + removed, added = get_changed_lines(diff_text, "app.py") + assert removed == [17, 18, 19, 24] + assert added == [17, 18, 19, 20, 21, 22, 27, 28, 29, 30] + + def test_build_symbol_metadata_processing(self): + source_code = """ +class Controller: + def process_data(payload): + def inner_helper(): + return True + return payload.strip() + +if True: + def process_data(payload): + return payload +""" + tree, _ = parse_code_to_ast(source_code, "Python") + nodes = extract_definitions(tree, "Python", kinds=("functions",)) + + metadata = build_symbol_metadata(nodes, "Python") + assert metadata == { + "Controller.process_data": { + "qualified_name": "Controller.process_data", + "simple_name": "process_data", + "text": "def process_data(payload):\n def inner_helper():\n return True\n return payload.strip()", + "fingerprint": "0000000888014a04b037189a42b238a2c50f218c", + "start_line": 3, + "end_line": 6, + "node_type": "function_definition", + }, + "process_data": { + "qualified_name": "process_data", + "simple_name": "process_data", + "text": "def process_data(payload):\n return payload", + "fingerprint": "000000022020300e882a900807880d0300010000", + "start_line": 9, + "end_line": 10, + "node_type": "function_definition", + }, + } + + def test_diff_changed_symbols(self): + vuln_meta = { + "serve_report": { + "qualified_name": "app.serve_report", + "text": "def serve_report():\n return os.path.join(base, filename)", + }, + "sanitize_input": { + "qualified_name": "app.sanitize_input", + "text": "def sanitize_input(x):\n return x.strip()", + }, + "deprecated_logger": { + "qualified_name": "app.deprecated_logger", + "text": "def deprecated_logger():\n print('legacy')", + }, + } + + fixed_meta = { + "serve_report": { + "qualified_name": "app.serve_report", + "text": "def serve_report():\n if not target.startswith(base): raise ValueError\n return target", + }, + "sanitize_input": { + "qualified_name": "app.sanitize_input", + "text": "def sanitize_input(x):\n return x.strip()", + }, + "audit_trail": { + "qualified_name": "app.audit_trail", + "text": "def audit_trail():\n log.info('action')", + }, + } + + vuln_only, fixed_only = diff_changed_symbols(vuln_meta, fixed_meta) + + assert vuln_only == { + "serve_report": { + "qualified_name": "app.serve_report", + "text": "def serve_report():\n return os.path.join(base, filename)", + }, + "deprecated_logger": { + "qualified_name": "app.deprecated_logger", + "text": "def deprecated_logger():\n print('legacy')", + }, + } + assert fixed_only == { + "serve_report": { + "qualified_name": "app.serve_report", + "text": "def serve_report():\n if not target.startswith(base): raise ValueError\n return target", + }, + "audit_trail": { + "qualified_name": "app.audit_trail", + "text": "def audit_trail():\n log.info('action')", + }, + } + + def test_analyze_patched_file(self): + vuln_text = (self.data / "vuln-app.py").read_text(encoding="utf-8") + fixed_text = (self.data / "fixed-app.py").read_text(encoding="utf-8") + diff_text = (self.data / "diff-app.patch").read_text(encoding="utf-8") + + vuln_meta, fixed_meta, lang = analyze_patched_file( + vulnerable_text=vuln_text, + fixed_text=fixed_text, + diff_text=diff_text, + file_path="app.py", + ) + + assert vuln_meta == { + "serve_report": { + "qualified_name": "serve_report", + "simple_name": "serve_report", + "text": 'def serve_report(request_payload):\n """Top-level function handling a request."""\n generator = ReportGenerator("/var/reports")\n requested_file = request_payload.get("file")\n\n # Helper function nested inside serve_report\n def build_file_path(filename):\n # VULNERABLE: Direct concatenation allows Path Traversal\n # An attacker passing "../../etc/passwd" could read system files.\n return os.path.join(generator.base_dir, filename)\n\n if not requested_file:\n return "Error: No file specified"\n\n target_path = build_file_path(requested_file)\n\n if os.path.exists(target_path):\n return f"Serving content of {target_path}"\n\n return "Error: File not found"', + "fingerprint": "000000556d322a47595af353274b000aa324e014", + "start_line": 11, + "end_line": 30, + "node_type": "function_definition", + } + } + assert fixed_meta == { + "serve_report": { + "qualified_name": "serve_report", + "simple_name": "serve_report", + "text": 'def serve_report(request_payload):\n """Top-level function handling a request."""\n generator = ReportGenerator("/var/reports")\n requested_file = request_payload.get("file")\n\n # Helper function nested inside serve_report\n def build_file_path(filename):\n # FIXED: Validate that the resolved path stays within the base_dir\n base = os.path.abspath(generator.base_dir)\n target = os.path.abspath(os.path.join(base, filename))\n if not target.startswith(base):\n raise ValueError("Path Traversal Detected")\n return target\n\n if not requested_file:\n return "Error: No file specified"\n\n try:\n target_path = build_file_path(requested_file)\n except ValueError:\n return "Error: Invalid path"\n\n if os.path.exists(target_path):\n return f"Serving content of {target_path}"\n\n return "Error: File not found"', + "fingerprint": "0000006cceea8aedf1da91830f67b64927086d24", + "start_line": 11, + "end_line": 36, + "node_type": "function_definition", + } + } + + def test_extract_symbols(self): + source_code = ( + "def serve_report(request):\n" # Line 1 (Row 0) + " # Some processing here\n" # Line 2 (Row 1) + " def build_path(filename):\n" # Line 3 (Row 2) + " return filename.strip()\n" # Line 4 (Row 3) <- Targeted Change + " return build_path(request)\n" # Line 5 (Row 4) + ) + + tree, _ = parse_code_to_ast(source_code, "Python") + + changed_lines = [4] + enclosing_symbols = extract_symbols(tree, changed_lines, "Python") + + assert len(enclosing_symbols) == 1 + target_node = enclosing_symbols[0] + assert target_node.type == "function_definition" + + node_text = target_node.text.decode("utf-8") + assert "def build_path" in node_text + assert "def serve_report" not in node_text + + def test_extract_symbols_deduplication(self): + source_code = ( + "def calculate_total(price, tax):\n" + " amount = price * tax\n" # Line 2 -> Changed + " return price + amount\n" # Line 3 -> Changed + ) + + tree, _ = parse_code_to_ast(source_code, "Python") + changed_lines = [2, 3] + + enclosing_symbols = extract_symbols(tree, changed_lines, "Python") + assert len(enclosing_symbols) == 1 + assert enclosing_symbols[0].type == "function_definition" \ No newline at end of file From bfe703c9e824816e9fbcccad40741be77662b3e0 Mon Sep 17 00:00:00 2001 From: ziad hany Date: Wed, 10 Jun 2026 15:08:45 +0300 Subject: [PATCH 3/4] Fix the format bugs and refactor the code Signed-off-by: ziad hany --- .../pipelines/collect_symbols_reachability.py | 8 +- scanpipe/pipes/reachability.py | 394 +++++++++++------- scanpipe/pipes/symbols.py | 14 +- scanpipe/tests/data/reachability/app.py | 2 +- scanpipe/tests/data/reachability/fixed-app.py | 2 +- scanpipe/tests/data/reachability/vuln-app.py | 2 +- .../tests/pipes/test_symbols_reachability.py | 381 +++++++++-------- 7 files changed, 473 insertions(+), 330 deletions(-) diff --git a/scanpipe/pipelines/collect_symbols_reachability.py b/scanpipe/pipelines/collect_symbols_reachability.py index 15519fc661..c1d5fb11c4 100644 --- a/scanpipe/pipelines/collect_symbols_reachability.py +++ b/scanpipe/pipelines/collect_symbols_reachability.py @@ -12,9 +12,7 @@ class SymbolReachability(Pipeline): - """ - Patch reachability analysis, for given a vulnerability patches - """ + """Patch reachability analysis for given vulnerability patches.""" download_inputs = False is_addon = True @@ -26,8 +24,8 @@ def steps(cls): def analyze_and_store_symbol_reachability(self): """ - Perform symbol-level reachability analysis for each patch. - This step compares the AST of patched/vulnerable files against the codebase resources. + Perform symbol-level reachability analysis for each patch. This step compares + the AST of patched/vulnerable files against the codebase resources. Results are stored directly in the 'extra_data' of each CodebaseResource. """ reachability.collect_and_store_symbol_reachability_results( diff --git a/scanpipe/pipes/reachability.py b/scanpipe/pipes/reachability.py index 5169eba94d..e9ebd97378 100644 --- a/scanpipe/pipes/reachability.py +++ b/scanpipe/pipes/reachability.py @@ -29,16 +29,16 @@ from git import Repo from git.diff import NULL_TREE from git.exc import BadName -from matchcode_toolkit.fingerprinting import create_file_fingerprints from scancode.api import get_file_info from unidiff import PatchSet from scanpipe.pipes.symbols import TS_QUERIES from scanpipe.pipes.symbols import _root_of from scanpipe.pipes.symbols import collect_definitions -from scanpipe.pipes.symbols import extract_calls_in_node +from scanpipe.pipes.symbols import create_exact_symbol_fingerprint from scanpipe.pipes.symbols import extract_definitions from scanpipe.pipes.symbols import extract_symbols +from scanpipe.pipes.symbols import get_query from scanpipe.pipes.symbols import is_nested_function from scanpipe.pipes.symbols import parse_code_to_ast from scanpipe.pipes.symbols import qualified_name_from_index @@ -53,14 +53,16 @@ class ReachabilityStatus(str, Enum): def api_mocker(): - """ - TODO: Remove this once the API patch url is done - """ + """TODO: Remove this once the API patch url is done""" return [ { "vcs_url": "https://github.com/pallets/flask", "commit_hash": "089cb86dd22bff589a4eafb7ab8e42dc357623b4", }, + # { + # "vcs_url": "https://github.com/aio-libs/aiohttp", + # "commit_hash": "0c2e9da51126238a421568eb7c5b53e5b5d17b36", + # } ] @@ -100,7 +102,7 @@ def normalize_text(content): def is_supported_language(language): - """A language is supported if we have tree-sitter queries for it.""" + """Return True if the language is supported by tree-sitter queries.""" return bool(language) and language in TS_QUERIES @@ -223,28 +225,18 @@ def get_changed_lines(diff_text, file_path): def diff_changed_symbols(vuln_meta, fixed_meta): """ Keep only symbols whose body actually differs between vulnerable and fixed - versions. Pair by qualified name first. + versions. Utilizes the unique suffix keys generated by build_symbol_metadata. """ - fixed_by_qn = { - metadata["qualified_name"]: metadata for metadata in fixed_meta.values() - } - - vuln_by_qn = { - metadata["qualified_name"]: metadata for metadata in vuln_meta.values() - } - vuln_only = { key: metadata for key, metadata in vuln_meta.items() - if fixed_by_qn.get(metadata["qualified_name"], {}).get("text") - != metadata["text"] + if fixed_meta.get(key, {}).get("text") != metadata["text"] } fixed_only = { key: metadata for key, metadata in fixed_meta.items() - if vuln_by_qn.get(metadata["qualified_name"], {}).get("text") - != metadata["text"] + if vuln_meta.get(key, {}).get("text") != metadata["text"] } return vuln_only, fixed_only @@ -313,9 +305,6 @@ def collect_patch_symbols(repo, commit_hash): ... } - Symbols are bucketed by language so resources are only matched against - patch symbols extracted from the same language. - """ commit, parent = get_commit_and_parent(repo, commit_hash) diff_text = get_commit_diff_text(repo, parent, commit) @@ -379,11 +368,6 @@ def collect_and_store_symbol_reachability_results(project, logger=None): For each known patch commit, determine whether each project codebase resource is reachable to the vulnerable code by comparing tree-sitter ASTs of the patch versus the resource. - - Result classification: - - REACHABLE - - POTENTIALLY_REACHABLE - - NOT_REACHABLE """ candidate_resources = project.codebaseresources.files().filter( is_binary=False, @@ -394,10 +378,10 @@ def collect_and_store_symbol_reachability_results(project, logger=None): for patch in api_mocker(): vcs_url = patch["vcs_url"] commit_hash = patch["commit_hash"] - repo_path = None try: - repo_path = clone_repo(vcs_url, commit_hash) - repo = Repo(repo_path) + # repo_path = clone_repo(vcs_url, commit_hash) + # repo = Repo("/home/ziad-hany/PycharmProjects/flask/") + repo = Repo("/home/ziad-hany/PycharmProjects/aiohttp") patch_symbols_by_language = collect_patch_symbols(repo, commit_hash) @@ -406,18 +390,16 @@ def collect_and_store_symbol_reachability_results(project, logger=None): for resource in candidate_resources: resource_language = resource.programming_language - if resource_language not in patch_symbols_by_language: continue - resource_text = normalize_text(resource.file_content) - + resource_text = resource.file_content if not resource_text: continue patch_symbols = patch_symbols_by_language[resource_language] - vuln_metadata = patch_symbols["vulnerable"] - fixed_metadata = patch_symbols["fixed"] + vuln_patch_metadata = patch_symbols["vulnerable"] + fixed_patch_metadata = patch_symbols["fixed"] resource_index = build_resource_index( resource_text, @@ -428,12 +410,12 @@ def collect_and_store_symbol_reachability_results(project, logger=None): continue vuln_match_symbols = match_symbols_against_resource( - vuln_metadata, + vuln_patch_metadata, resource_index, ) fixed_match_symbols = match_symbols_against_resource( - fixed_metadata, + fixed_patch_metadata, resource_index, ) @@ -443,84 +425,36 @@ def collect_and_store_symbol_reachability_results(project, logger=None): result = { "reachability_status": classify_reachability(vuln_match_symbols), "summary": { - "vulnerable_symbols": sorted(vuln_match_symbols), - "fixed_symbols": sorted(fixed_match_symbols), "call_paths": { - qn: ev.get("reachable_from", []) - for qn, ev in vuln_match_symbols.items() + qualified_name: ev.get("reachable_from", []) + for qualified_name, ev in vuln_match_symbols.items() if ev.get("called") }, }, "evidence": vuln_match_symbols, + "vulnerable_symbols": sorted(vuln_match_symbols), + "fixed_symbols": sorted(fixed_match_symbols), "patch": { "vcs_url": vcs_url, "commit_hash": commit_hash, }, } - + print(result) append_symbol_reachability_result(resource, result) except Exception as e: logger( - "Failed to collect symbol reachability for " + f"Failed to collect symbol reachability for " f"{vcs_url}@{commit_hash}: {e}" ) finally: - cleanup_repo(repo_path) - - -def compute_reachable_symbols(call_graph, target_simple_names): - """ - Find all symbols that can transitively reach any of ``target_simple_names``. - - Reachability is matched on *simple* names (the call graph records callee - tokens, not fully-qualified names), so distinct symbols sharing a name are - treated as equivalent. This can over-approximate reachability. - - Returns: - (reachable_callers, has_direct_call) - reachable_callers: qualified names of all transitive callers - has_direct_call: whether any symbol calls a target directly - - """ - if not call_graph or not target_simple_names: - return set(), False - - edges = call_graph["edges"] - targets = set(target_simple_names) - - callers_of: dict[str, set[str]] = {} - for caller_qn, callees in edges.items(): - for callee_simple in callees: - callers_of.setdefault(callee_simple, set()).add(caller_qn) - - direct_callers: set[str] = set() - for target in targets: - direct_callers |= callers_of.get(target, set()) - - has_direct_call = bool(direct_callers) - - qn_to_simple = {qn: meta["simple_name"] for qn, meta in call_graph["nodes"].items()} - - reachable = set(direct_callers) - frontier = list(direct_callers) - - while frontier: - current_qn = frontier.pop() - current_simple = qn_to_simple.get(current_qn) - if not current_simple: - continue - for parent_qn in callers_of.get(current_simple, ()): - if parent_qn not in reachable: - reachable.add(parent_qn) - frontier.append(parent_qn) + if repo: + repo.close() - return reachable, has_direct_call + # cleanup_repo(repo_path) def build_resource_index(resource_text, language): - resource_text = normalize_text(resource_text) - if not is_supported_language(language) or not resource_text: return None @@ -551,34 +485,38 @@ def build_resource_index(resource_text, language): } -def match_symbols_against_resource(symbols, resource_index): - if not symbols or not resource_index: +def match_symbols_against_resource(patch_symbols_metadata, resource_index): + if not patch_symbols_metadata or not resource_index: return {} call_graph = resource_index.get("call_graph") - target_simple_names = {metadata["simple_name"] for metadata in symbols.values()} - reachable_callers, _has_direct = compute_reachable_symbols( + target_qualified_names = { + metadata["qualified_name"] for metadata in patch_symbols_metadata.values() + } + + reachable_callers, _ = compute_reachable_symbols( call_graph, - target_simple_names, + target_qualified_names, ) - called_simple_names = set() + called_qualified_names = set() + if call_graph: - for callees in call_graph["edges"].values(): - called_simple_names |= callees + for callees in call_graph.get("edges_qualified", {}).values(): + called_qualified_names |= set(callees) matched = {} - for metadata in symbols.values(): + + for metadata in patch_symbols_metadata.values(): qualified_name = metadata["qualified_name"] - simple_name = metadata["simple_name"] fingerprint = metadata["fingerprint"] - defined = qualified_name in resource_index["definitions"] + defined = qualified_name in resource_index.get("definitions", {}) fingerprint_hit = bool( - fingerprint and fingerprint in resource_index["fingerprints"] + fingerprint and fingerprint in resource_index.get("fingerprints", {}) ) - called = simple_name in called_simple_names + called = qualified_name in called_qualified_names if not (defined or fingerprint_hit or called): continue @@ -608,12 +546,6 @@ def classify_reachability(evidence): if not evidence: return ReachabilityStatus.NOT_REACHABLE - SEVERITY_RANK = { - ReachabilityStatus.NOT_REACHABLE: 0, - ReachabilityStatus.POTENTIALLY_REACHABLE: 1, - ReachabilityStatus.REACHABLE: 2, - } - highest_status = ReachabilityStatus.NOT_REACHABLE for item in evidence.values(): @@ -622,17 +554,12 @@ def classify_reachability(evidence): is_exact = "exact_match_fingerprint" in item is_defined = bool(item.get("defined")) - if is_called or (has_path and is_exact): + if is_called or has_path: return ReachabilityStatus.REACHABLE - elif has_path or is_exact or is_defined: - current_item_status = ReachabilityStatus.POTENTIALLY_REACHABLE - - else: - current_item_status = ReachabilityStatus.NOT_REACHABLE + if is_exact or is_defined: + highest_status = ReachabilityStatus.POTENTIALLY_REACHABLE - if SEVERITY_RANK[current_item_status] > SEVERITY_RANK[highest_status]: - highest_status = current_item_status return highest_status @@ -650,7 +577,7 @@ def build_symbol_metadata(nodes, language, index=None): continue body_text = node.text.decode("utf-8", errors="replace") - fingerprints = create_file_fingerprints(content=body_text) or {} + fingerprints = create_exact_symbol_fingerprint(body_text) key = qualified_name suffix = 1 @@ -660,9 +587,8 @@ def build_symbol_metadata(nodes, language, index=None): metadata[key] = { "qualified_name": qualified_name, - "simple_name": qualified_name.rsplit(".", 1)[-1], "text": body_text, - "fingerprint": fingerprints.get("halo1"), + "fingerprint": fingerprints, "start_line": node.start_point[0] + 1, "end_line": node.end_point[0] + 1, "node_type": node.type, @@ -675,23 +601,211 @@ def build_call_graph(tree, language): return None index = collect_definitions(tree.root_node, language) - definition_nodes = [d["node"] for d in index.values()] - metadata = build_symbol_metadata(definition_nodes, language, index=index) - qualified_name_to_node = {} - for node in definition_nodes: + graph_meta = {} + for definition in index.values(): + node = definition["node"] qualified_name = qualified_name_from_index(node, index) - if qualified_name: - qualified_name_to_node.setdefault(qualified_name, node) - - edges = {} - by_simple_name = {} - for qualified_name, meta in metadata.items(): - canonical = meta["qualified_name"] - node = qualified_name_to_node.get(canonical) - if node is None: + + if not qualified_name: continue - edges.setdefault(canonical, set()).update(extract_calls_in_node(node, language)) - by_simple_name.setdefault(meta["simple_name"], set()).add(canonical) - return {"nodes": metadata, "edges": edges, "by_simple_name": by_simple_name} + body_text = node.text.decode("utf-8", errors="replace") + fingerprints = create_exact_symbol_fingerprint(body_text) or {} + + graph_meta[qualified_name] = { + "qualified_name": qualified_name, + "node": node, + "node_type": node.type, + "fingerprint": fingerprints, + } + + definitions_by_name = {} + class_methods = set() + + for qualified_name, metadata in graph_meta.items(): + name = qualified_name.rsplit(".", 1)[-1] + definitions_by_name.setdefault(name, set()).add(qualified_name) + + if metadata["node_type"] == "function_definition" and "." in qualified_name: + class_methods.add(qualified_name) + + edges_qualified = {} + for qualified_name, metadata in graph_meta.items(): + direct_calls = extract_direct_calls(metadata["node"], language, index) + + resolved_callees = set() + + for receiver_name, callee_name in direct_calls: + resolved_callees |= resolve_callee( + receiver_name=receiver_name, + callee_name=callee_name, + owner_qn=qualified_name, + definitions_by_name=definitions_by_name, + class_methods=class_methods, + ) + + edges_qualified[qualified_name] = resolved_callees + + return { + "nodes": graph_meta, + "edges_qualified": edges_qualified, + } + + +def extract_direct_calls(node, language, definition_index): + """ + Return direct calls inside `node`, excluding calls inside nested definitions. + + Returns: + list of (receiver_name, callee_name) + + Examples: + foo() -> (None, "foo") + self.foo() -> ("self", "foo") + obj.foo() -> ("obj", "foo") + + """ + query = get_query(language, "calls") + if query is None or node is None: + return [] + + definition_ids = set(definition_index) + calls = [] + + for _, captures in query.matches(node): + for callee_node in captures.get("callee", []): + if is_inside_nested_definition( + node=callee_node, + owner_node=node, + definition_ids=definition_ids, + ): + continue + + receiver_name = get_call_receiver(callee_node) + callee_name = node_text(callee_node) + + if callee_name: + calls.append((receiver_name, callee_name)) + + return calls + + +def is_inside_nested_definition(node, owner_node, definition_ids): + """ + Return True if node is inside a nested function/class within `owner_node`. + + Example: + def outer(): + foo() # belongs to outer + + def inner(): + bar() # nested; should not count as outer's call + + """ + current = node.parent + + while current is not None and current is not owner_node: + if current.id in definition_ids: + return True + + current = current.parent + + return False + + +def node_text(node): + return node.text.decode("utf-8", errors="replace") + + +def get_call_receiver(callee_node): + """ + Return receiver name for attribute calls. + + Examples: + foo() -> None + self.foo() -> "self" + obj.foo() -> "obj" + + """ + parent = callee_node.parent + + if parent is None or parent.type != "attribute": + return None + + object_node = parent.child_by_field_name("object") + if object_node is None: + return None + + return node_text(object_node) + + +def resolve_callee( + receiver_name, callee_name, owner_qn, definitions_by_name, class_methods +): + """ + Resolve a call to candidate qualified names. + + Examples: + self.foo() from class A -> {"A.foo"} if A.foo exists + foo() -> definitions named "foo" + + """ + if receiver_name == "self": + owner_class = get_owner_class_name(owner_qn) + + if owner_class: + method_qn = f"{owner_class}.{callee_name}" + + if method_qn in class_methods: + return {method_qn} + + candidates = definitions_by_name.get(callee_name, set()) + return set(candidates) + + +def get_owner_class_name(owner_qn): + """ + Return enclosing class name from a qualified name. + + Examples: + "User.save" -> "User" + "User.Inner.save" -> "User.Inner" + "save" -> None + + """ + if "." not in owner_qn: + return None + + return owner_qn.rsplit(".", 1)[0] + + +def compute_reachable_symbols(call_graph, target_qualified_names): + """Transitive callers using resolved qualified-name edges.""" + if not call_graph or not target_qualified_names: + return set(), False + + edges = call_graph.get("edges_qualified") + if not edges: + return set(), False + + callers_of = {} + for caller, callees in edges.items(): + for callee in callees: + callers_of.setdefault(callee, set()).add(caller) + + targets = set(target_qualified_names) + direct = set() + for target in targets: + direct |= callers_of.get(target, set()) + + reachable = set(direct) + frontier = list(direct) + while frontier: + cur = frontier.pop() + for parent in callers_of.get(cur, ()): + if parent not in reachable: + reachable.add(parent) + frontier.append(parent) + + return reachable, bool(direct) diff --git a/scanpipe/pipes/symbols.py b/scanpipe/pipes/symbols.py index cac76562e3..da6f359f65 100644 --- a/scanpipe/pipes/symbols.py +++ b/scanpipe/pipes/symbols.py @@ -20,6 +20,7 @@ # ScanCode.io is a free software code scanning tool from nexB Inc. and others. # Visit https://github.com/aboutcode-org/scancode.io for support and download. +import hashlib import importlib from functools import cache @@ -191,7 +192,6 @@ def _collect_and_store_tree_sitter_symbols_and_strings(resource): "pygments": symbols_pygments.get_pygments_symbols, } -# https://github.com/Aider-AI/aider/tree/5dc9490bb35f9729ef2c95d00a19ccd30c26339c/aider/queries/tree-sitter-language-pack TS_QUERIES = { "Python": { "functions": """ @@ -202,7 +202,9 @@ def _collect_and_store_tree_sitter_symbols_and_strings(resource): """, "calls": """ (call function: (identifier) @callee) - (call function: (attribute attribute: (identifier) @callee)) + (call function: (attribute + object: (_) @receiver + attribute: (identifier) @callee)) """, }, } @@ -362,3 +364,11 @@ def qualified_name_from_index(node, index): parts.append(definition["name"]) curr = curr.parent return ".".join(reversed(parts)) + + +def create_exact_symbol_fingerprint(text): + if text is None: + return None + + text = text.encode("utf-8", errors="replace") + return hashlib.sha256(text).hexdigest() diff --git a/scanpipe/tests/data/reachability/app.py b/scanpipe/tests/data/reachability/app.py index c64ae7d9d1..b8c9eff5e0 100644 --- a/scanpipe/tests/data/reachability/app.py +++ b/scanpipe/tests/data/reachability/app.py @@ -31,5 +31,5 @@ def build_file_path(filename): def unrelated_top_level_function(): - """An extra function to test AST node boundaries.""" + """Test AST node boundaries.""" return "I am just here to add AST complexity." diff --git a/scanpipe/tests/data/reachability/fixed-app.py b/scanpipe/tests/data/reachability/fixed-app.py index 3296bb843e..ca5a6f4c8b 100644 --- a/scanpipe/tests/data/reachability/fixed-app.py +++ b/scanpipe/tests/data/reachability/fixed-app.py @@ -37,5 +37,5 @@ def build_file_path(filename): def unrelated_top_level_function(): - """An extra function to test AST node boundaries.""" + """Test AST node boundaries.""" return "I am just here to add AST complexity." diff --git a/scanpipe/tests/data/reachability/vuln-app.py b/scanpipe/tests/data/reachability/vuln-app.py index c64ae7d9d1..b8c9eff5e0 100644 --- a/scanpipe/tests/data/reachability/vuln-app.py +++ b/scanpipe/tests/data/reachability/vuln-app.py @@ -31,5 +31,5 @@ def build_file_path(filename): def unrelated_top_level_function(): - """An extra function to test AST node boundaries.""" + """Test AST node boundaries.""" return "I am just here to add AST complexity." diff --git a/scanpipe/tests/pipes/test_symbols_reachability.py b/scanpipe/tests/pipes/test_symbols_reachability.py index 1a5fd7b3ad..68bbe9b964 100644 --- a/scanpipe/tests/pipes/test_symbols_reachability.py +++ b/scanpipe/tests/pipes/test_symbols_reachability.py @@ -29,14 +29,15 @@ from scanpipe.pipes import collect_and_create_codebase_resources from scanpipe.pipes.reachability import ReachabilityStatus from scanpipe.pipes.reachability import analyze_patched_file -from scanpipe.pipes.reachability import build_call_graph from scanpipe.pipes.reachability import build_symbol_metadata from scanpipe.pipes.reachability import classify_reachability from scanpipe.pipes.reachability import collect_and_store_symbol_reachability_results +from scanpipe.pipes.reachability import compute_reachable_symbols from scanpipe.pipes.reachability import diff_changed_symbols from scanpipe.pipes.reachability import get_changed_lines -from scanpipe.pipes.symbols import collect_definitions, extract_symbols +from scanpipe.pipes.symbols import collect_definitions from scanpipe.pipes.symbols import extract_definitions +from scanpipe.pipes.symbols import extract_symbols from scanpipe.pipes.symbols import parse_code_to_ast from scanpipe.pipes.symbols import qualified_name_from_index @@ -101,85 +102,31 @@ def test_collect_and_store_symbol_reachability_results( resource.refresh_from_db() results = resource.extra_data.get("symbols_reachability") - assert results == [ - { - "patch": { - "vcs_url": "https://github.com/aboutcode-org/test", - "commit_hash": "07ec0de1964b14bf085a1c9a27ece2b61ab6105c", - }, - "summary": { - "call_paths": {}, + self.assertEqual( + results, + [ + { + "patch": { + "vcs_url": "https://github.com/aboutcode-org/test", + "commit_hash": "07ec0de1964b14bf085a1c9a27ece2b61ab6105c", + }, + "summary": {"call_paths": {}}, + "evidence": { + "serve_report": { + "called": False, + "defined": True, + "reachable_from": [], + "exact_match_fingerprint": ( + "e341b914f9823915e0685396a730d421ec9e3635" + ), + } + }, "fixed_symbols": ["serve_report"], "vulnerable_symbols": ["serve_report"], - }, - "evidence": { - "serve_report": { - "called": False, - "defined": True, - "reachable_from": [], - "exact_match_fingerprint": "000000556d322a47595af353274b000aa324e014", - } - }, - "reachability_status": "POTENTIALLY_REACHABLE", - } - ] - - def test_build_call_graph(self): - source_code = """ -def calculate_total(price, tax): - return price + get_tax_amount(price, tax) - -def get_tax_amount(price, tax): - return price * tax - -def process_order(): - total = calculate_total(100, 0.05) - print("Done") -""" - tree, _ = parse_code_to_ast(source_code, "Python") - result = build_call_graph(tree, "Python") - - assert result == { - "nodes": { - "calculate_total": { - "qualified_name": "calculate_total", - "simple_name": "calculate_total", - "text": "def calculate_total(price, tax):\n return price + get_tax_amount(price, tax)", - "fingerprint": "00000008060105fd3624134884412006ce880936", - "start_line": 2, - "end_line": 3, - "node_type": "function_definition", - }, - "get_tax_amount": { - "qualified_name": "get_tax_amount", - "simple_name": "get_tax_amount", - "text": "def get_tax_amount(price, tax):\n return price * tax", - "fingerprint": "000000058f0ee87d9669f20b1f473137b665bb20", - "start_line": 5, - "end_line": 6, - "node_type": "function_definition", - }, - "process_order": { - "qualified_name": "process_order", - "simple_name": "process_order", - "text": 'def process_order():\n total = calculate_total(100, 0.05)\n print("Done")', - "fingerprint": "000000071c3e6902da5c2b322386eff29068e3e2", - "start_line": 8, - "end_line": 10, - "node_type": "function_definition", - }, - }, - "edges": { - "calculate_total": {"get_tax_amount"}, - "get_tax_amount": set(), - "process_order": {"print", "calculate_total"}, - }, - "by_simple_name": { - "calculate_total": {"calculate_total"}, - "get_tax_amount": {"get_tax_amount"}, - "process_order": {"process_order"}, - }, - } + "reachability_status": "POTENTIALLY_REACHABLE", + } + ], + ) def test_extract_definitions(self): source_code = """ @@ -198,25 +145,25 @@ class InventoryItem: """ tree, _ = parse_code_to_ast(source_code, "Python") functions = extract_definitions(tree, "Python", kinds=("functions",)) - assert ( - len(functions) == 3 + self.assertEqual( + len(functions), 3 ) # '__init__', 'process_payment', and 'calculate_discount' - assert functions[0].type == "function_definition" + self.assertEqual(functions[0].type, "function_definition") first_func_text = functions[0].text.decode("utf-8") - assert "def __init__" in first_func_text + self.assertIn("def __init__", first_func_text) classes = extract_definitions(tree, "Python", kinds=("classes",)) - assert len(classes) == 2 # OrderManager, InventoryItem + self.assertEqual(len(classes), 2) second_class_text = classes[1].text.decode("utf-8") - assert "class InventoryItem" in second_class_text + self.assertIn("class InventoryItem", second_class_text) def test_extract_definitions_empty(self): tree, _ = parse_code_to_ast("", "Python") - assert extract_definitions(tree, "Python", kinds=("functions",)) == [] - assert extract_definitions(tree, "Python", kinds=("functions",)) == [] - assert extract_definitions(None, "Python", kinds=("classes",)) == [] - assert extract_definitions(None, "Python", kinds=("classes",)) == [] + self.assertEqual(extract_definitions(tree, "Python", kinds=("functions",)), []) + self.assertEqual(extract_definitions(tree, "Python", kinds=("functions",)), []) + self.assertEqual(extract_definitions(None, "Python", kinds=("classes",)), []) + self.assertEqual(extract_definitions(None, "Python", kinds=("classes",)), []) def test_get_qualified_name_functions(self): source_code = """ @@ -233,13 +180,13 @@ def global_utility(): index = collect_definitions(tree.root_node, "Python") functions = extract_definitions(tree, "Python", kinds=("functions",)) - assert len(functions) == 2 + self.assertEqual(len(functions), 2) outer_function_name = qualified_name_from_index(functions[0], index) inner_function_name = qualified_name_from_index(functions[1], index) - assert outer_function_name == "CoreService.Validator.validate_payload" - assert inner_function_name == "global_utility" + self.assertEqual(outer_function_name, "CoreService.Validator.validate_payload") + self.assertEqual(inner_function_name, "global_utility") def test_get_qualified_classes(self): source_code = """ @@ -251,25 +198,25 @@ class DroneController: index = collect_definitions(tree.root_node, "Python") classes = extract_definitions(tree, "Python", kinds=("classes",)) - assert len(classes) == 2 + self.assertEqual(len(classes), 2) outer_class_name = qualified_name_from_index(classes[0], index) inner_class_name = qualified_name_from_index(classes[1], index) - assert outer_class_name == "FleetManagement" - assert inner_class_name == "FleetManagement.DroneController" + self.assertEqual(outer_class_name, "FleetManagement") + self.assertEqual(inner_class_name, "FleetManagement.DroneController") def test_classify_reachability(self): - assert classify_reachability(None) == ReachabilityStatus.NOT_REACHABLE - assert classify_reachability({}) == ReachabilityStatus.NOT_REACHABLE - assert ( + self.assertEqual(classify_reachability(None), ReachabilityStatus.NOT_REACHABLE) + self.assertEqual(classify_reachability({}), ReachabilityStatus.NOT_REACHABLE) + self.assertEqual( classify_reachability( {"sym1": {"exact_match_fingerprint": "hash123", "called": True}} - ) - == ReachabilityStatus.REACHABLE + ), + ReachabilityStatus.REACHABLE, ) - assert ( + self.assertEqual( classify_reachability( { "sym1": { @@ -277,22 +224,22 @@ def test_classify_reachability(self): "reachable_from": ["main_function", "api_handler"], } } - ) - == ReachabilityStatus.REACHABLE + ), + ReachabilityStatus.REACHABLE, ) - assert ( - classify_reachability({"sym1": {"defined": True, "called": False}}) - == ReachabilityStatus.POTENTIALLY_REACHABLE + self.assertEqual( + classify_reachability({"sym1": {"defined": True, "called": False}}), + ReachabilityStatus.POTENTIALLY_REACHABLE, ) - assert ( + self.assertEqual( classify_reachability( {"sym1": {"exact_match_fingerprint": "hash123", "called": False}} - ) - == ReachabilityStatus.POTENTIALLY_REACHABLE + ), + ReachabilityStatus.POTENTIALLY_REACHABLE, ) - assert ( - classify_reachability({"sym1": {"file_path": "src/vulnerable.py"}}) - == ReachabilityStatus.NOT_REACHABLE + self.assertEqual( + classify_reachability({"sym1": {"file_path": "src/vulnerable.py"}}), + ReachabilityStatus.NOT_REACHABLE, ) def test_get_changed_lines(self): @@ -300,8 +247,8 @@ def test_get_changed_lines(self): diff_text = (data / "diff-app.patch").read_text(encoding="utf-8") removed, added = get_changed_lines(diff_text, "app.py") - assert removed == [17, 18, 19, 24] - assert added == [17, 18, 19, 20, 21, 22, 27, 28, 29, 30] + self.assertEqual(removed, [17, 18, 19, 24]) + self.assertEqual(added, [17, 18, 19, 20, 21, 22, 27, 28, 29, 30]) def test_build_symbol_metadata_processing(self): source_code = """ @@ -319,26 +266,30 @@ def process_data(payload): nodes = extract_definitions(tree, "Python", kinds=("functions",)) metadata = build_symbol_metadata(nodes, "Python") - assert metadata == { - "Controller.process_data": { - "qualified_name": "Controller.process_data", - "simple_name": "process_data", - "text": "def process_data(payload):\n def inner_helper():\n return True\n return payload.strip()", - "fingerprint": "0000000888014a04b037189a42b238a2c50f218c", - "start_line": 3, - "end_line": 6, - "node_type": "function_definition", - }, - "process_data": { - "qualified_name": "process_data", - "simple_name": "process_data", - "text": "def process_data(payload):\n return payload", - "fingerprint": "000000022020300e882a900807880d0300010000", - "start_line": 9, - "end_line": 10, - "node_type": "function_definition", + self.assertEqual( + metadata, + { + "Controller.process_data": { + "qualified_name": "Controller.process_data", + "text": "def process_data(payload):\n" + " def inner_helper():\n" + " return True\n" + " return payload.strip()", + "fingerprint": "0000000888014a04b037189a42b238a2c50f218c", + "start_line": 3, + "end_line": 6, + "node_type": "function_definition", + }, + "process_data": { + "qualified_name": "process_data", + "text": "def process_data(payload):\n return payload", + "fingerprint": "000000022020300e882a900807880d0300010000", + "start_line": 9, + "end_line": 10, + "node_type": "function_definition", + }, }, - } + ) def test_diff_changed_symbols(self): vuln_meta = { @@ -359,7 +310,10 @@ def test_diff_changed_symbols(self): fixed_meta = { "serve_report": { "qualified_name": "app.serve_report", - "text": "def serve_report():\n if not target.startswith(base): raise ValueError\n return target", + "text": "def serve_report():\n " + " if not target.startswith(base): " + "raise ValueError\n " + " return target", }, "sanitize_input": { "qualified_name": "app.sanitize_input", @@ -373,26 +327,34 @@ def test_diff_changed_symbols(self): vuln_only, fixed_only = diff_changed_symbols(vuln_meta, fixed_meta) - assert vuln_only == { - "serve_report": { - "qualified_name": "app.serve_report", - "text": "def serve_report():\n return os.path.join(base, filename)", - }, - "deprecated_logger": { - "qualified_name": "app.deprecated_logger", - "text": "def deprecated_logger():\n print('legacy')", - }, - } - assert fixed_only == { - "serve_report": { - "qualified_name": "app.serve_report", - "text": "def serve_report():\n if not target.startswith(base): raise ValueError\n return target", + self.assertEqual( + vuln_only, + { + "serve_report": { + "qualified_name": "app.serve_report", + "text": "def serve_report():\n " + " return os.path.join(base, filename)", + }, + "deprecated_logger": { + "qualified_name": "app.deprecated_logger", + "text": "def deprecated_logger():\n print('legacy')", + }, }, - "audit_trail": { - "qualified_name": "app.audit_trail", - "text": "def audit_trail():\n log.info('action')", + ) + self.assertEqual( + fixed_only, + { + "serve_report": { + "qualified_name": "app.serve_report", + "text": "def serve_report():\n if not target.startswith(base): " + "raise ValueError\n return target", + }, + "audit_trail": { + "qualified_name": "app.audit_trail", + "text": "def audit_trail():\n log.info('action')", + }, }, - } + ) def test_analyze_patched_file(self): vuln_text = (self.data / "vuln-app.py").read_text(encoding="utf-8") @@ -406,28 +368,70 @@ def test_analyze_patched_file(self): file_path="app.py", ) - assert vuln_meta == { - "serve_report": { - "qualified_name": "serve_report", - "simple_name": "serve_report", - "text": 'def serve_report(request_payload):\n """Top-level function handling a request."""\n generator = ReportGenerator("/var/reports")\n requested_file = request_payload.get("file")\n\n # Helper function nested inside serve_report\n def build_file_path(filename):\n # VULNERABLE: Direct concatenation allows Path Traversal\n # An attacker passing "../../etc/passwd" could read system files.\n return os.path.join(generator.base_dir, filename)\n\n if not requested_file:\n return "Error: No file specified"\n\n target_path = build_file_path(requested_file)\n\n if os.path.exists(target_path):\n return f"Serving content of {target_path}"\n\n return "Error: File not found"', - "fingerprint": "000000556d322a47595af353274b000aa324e014", - "start_line": 11, - "end_line": 30, - "node_type": "function_definition", - } - } - assert fixed_meta == { - "serve_report": { - "qualified_name": "serve_report", - "simple_name": "serve_report", - "text": 'def serve_report(request_payload):\n """Top-level function handling a request."""\n generator = ReportGenerator("/var/reports")\n requested_file = request_payload.get("file")\n\n # Helper function nested inside serve_report\n def build_file_path(filename):\n # FIXED: Validate that the resolved path stays within the base_dir\n base = os.path.abspath(generator.base_dir)\n target = os.path.abspath(os.path.join(base, filename))\n if not target.startswith(base):\n raise ValueError("Path Traversal Detected")\n return target\n\n if not requested_file:\n return "Error: No file specified"\n\n try:\n target_path = build_file_path(requested_file)\n except ValueError:\n return "Error: Invalid path"\n\n if os.path.exists(target_path):\n return f"Serving content of {target_path}"\n\n return "Error: File not found"', - "fingerprint": "0000006cceea8aedf1da91830f67b64927086d24", - "start_line": 11, - "end_line": 36, - "node_type": "function_definition", - } - } + self.assertEqual( + vuln_meta, + { + "serve_report": { + "qualified_name": "serve_report", + "text": "def serve_report(request_payload):\n " + ' """Top-level function handling a request."""\n ' + ' generator = ReportGenerator("/var/reports")\n ' + ' requested_file = request_payload.get("file")\n\n ' + "# Helper function nested inside serve_report\n " + "def build_file_path(filename):\n " + " # VULNERABLE: Direct concatenation allows Path Traversal\n " + ' # An attacker passing "../../etc/passwd" ' + "could read system files.\n " + " return os.path.join(generator.base_dir, filename)\n\n " + " if not requested_file:\n " + ' return "Error: No file specified"\n\n ' + " target_path = build_file_path(requested_file)\n\n" + " " + " " + "if os.path.exists(target_path):\n" + " " + ' return f"Serving content of {target_path}"\n\n ' + ' return "Error: File not found"', + "fingerprint": "000000556d322a47595af353274b000aa324e014", + "start_line": 11, + "end_line": 30, + "node_type": "function_definition", + } + }, + ) + + self.assertEqual( + fixed_meta, + { + "serve_report": { + "qualified_name": "serve_report", + "text": "def serve_report(request_payload):\n " + ' """Top-level function handling a request."""\n ' + ' generator = ReportGenerator("/var/reports")\n ' + ' requested_file = request_payload.get("file")\n\n ' + " # Helper function nested inside serve_report\n " + " def build_file_path(filename):\n " + " # FIXED: Validate that the resolved " + "path stays within the base_dir\n " + " base = os.path.abspath(generator.base_dir)\n " + " target = os.path.abspath(os.path.join(base, filename))\n " + " if not target.startswith(base):\n " + ' raise ValueError("Path Traversal Detected")\n ' + " return target\n\n if not requested_file:\n " + ' return "Error: No file specified"\n\n try:\n ' + " target_path = build_file_path(requested_file)\n " + " except ValueError:\n " + ' return "Error: Invalid path"\n\n' + " if os.path.exists(target_path):\n " + ' return f"Serving content of {target_path}"\n\n ' + ' return "Error: File not found"', + "fingerprint": "0000006cceea8aedf1da91830f67b64927086d24", + "start_line": 11, + "end_line": 36, + "node_type": "function_definition", + } + }, + ) def test_extract_symbols(self): source_code = ( @@ -443,13 +447,13 @@ def test_extract_symbols(self): changed_lines = [4] enclosing_symbols = extract_symbols(tree, changed_lines, "Python") - assert len(enclosing_symbols) == 1 + self.assertEqual(len(enclosing_symbols), 1) target_node = enclosing_symbols[0] - assert target_node.type == "function_definition" + self.assertEqual(target_node.type, "function_definition") node_text = target_node.text.decode("utf-8") - assert "def build_path" in node_text - assert "def serve_report" not in node_text + self.assertIn("def build_path", node_text) + self.assertNotIn("def serve_report", node_text) def test_extract_symbols_deduplication(self): source_code = ( @@ -462,5 +466,22 @@ def test_extract_symbols_deduplication(self): changed_lines = [2, 3] enclosing_symbols = extract_symbols(tree, changed_lines, "Python") - assert len(enclosing_symbols) == 1 - assert enclosing_symbols[0].type == "function_definition" \ No newline at end of file + self.assertEqual(len(enclosing_symbols), 1) + self.assertEqual(enclosing_symbols[0].type, "function_definition") + + def test_compute_reachable_symbols(self): + call_graph = { + "edges_qualified": { + "app.main": {"app.helper", "app.safe_func"}, + "app.helper": {"app.vuln_func"}, + "app.direct_caller": {"app.vuln_func"}, + "app.unrelated": {"app.safe_func"}, + } + } + + target_qns = ["app.vuln_func"] + reachable, has_direct = compute_reachable_symbols(call_graph, target_qns) + self.assertTrue(has_direct) + + expected_reachable = {"app.main", "app.helper", "app.direct_caller"} + self.assertEqual(reachable, expected_reachable) From 983afd1d29161838e691d0156371f378485df4ed Mon Sep 17 00:00:00 2001 From: ziad hany Date: Thu, 11 Jun 2026 18:47:37 +0300 Subject: [PATCH 4/4] Fix a bug in the import-catching logic and add a test. Signed-off-by: ziad hany --- scanpipe/pipes/reachability.py | 255 +++++++++++------- scanpipe/pipes/symbols.py | 15 ++ .../tests/pipes/test_symbols_reachability.py | 39 ++- 3 files changed, 215 insertions(+), 94 deletions(-) diff --git a/scanpipe/pipes/reachability.py b/scanpipe/pipes/reachability.py index e9ebd97378..ad8a347e8f 100644 --- a/scanpipe/pipes/reachability.py +++ b/scanpipe/pipes/reachability.py @@ -55,14 +55,14 @@ class ReachabilityStatus(str, Enum): def api_mocker(): """TODO: Remove this once the API patch url is done""" return [ - { - "vcs_url": "https://github.com/pallets/flask", - "commit_hash": "089cb86dd22bff589a4eafb7ab8e42dc357623b4", - }, # { - # "vcs_url": "https://github.com/aio-libs/aiohttp", - # "commit_hash": "0c2e9da51126238a421568eb7c5b53e5b5d17b36", - # } + # "vcs_url": "https://github.com/pallets/flask", + # "commit_hash": "089cb86dd22bff589a4eafb7ab8e42dc357623b4", + # }, + { + "vcs_url": "https://github.com/aio-libs/aiohttp", + "commit_hash": "0c2e9da51126238a421568eb7c5b53e5b5d17b36", + } ] @@ -139,7 +139,7 @@ def get_commit_and_parent(repo, commit_hash): def get_commit_diff_text(repo, parent_commit, commit): """Whole-commit unified diff (used to extract changed line numbers).""" base = parent_commit.hexsha if parent_commit else EMPTY_TREE_SHA - return repo.git.diff(base, commit.hexsha, unified=0) + return repo.git.diff(base, commit.hexsha, unified=3) def get_changed_files(parent_commit, commit): @@ -195,7 +195,13 @@ def get_changed_files(parent_commit, commit): def get_changed_lines(diff_text, file_path): - """Return `(removed_lines, added_lines)` for one file from a unified diff.""" + """ + Return `(removed_lines, added_lines)` for one file. + + For pure-insertion hunks (no removed lines) we anchor the vulnerable side + to the hunk's source location so the enclosing old symbol is still found. + For pure-deletion hunks we do the mirror image on the added side. + """ removed = [] added = [] @@ -208,20 +214,37 @@ def get_changed_lines(diff_text, file_path): (patched_file.source_file or "").removeprefix("a/"), (patched_file.target_file or "").removeprefix("b/"), } - if file_path not in candidates: continue for hunk in patched_file: - for line in hunk: - if line.is_removed and line.source_line_no: - removed.append(line.source_line_no) - elif line.is_added and line.target_line_no: - added.append(line.target_line_no) + hunk_removed = [ + line.source_line_no + for line in hunk + if line.is_removed and line.source_line_no + ] + hunk_added = [ + line.target_line_no + for line in hunk + if line.is_added and line.target_line_no + ] + + # Pure insertion: nothing removed -> anchor old side to the + # line just before the insertion point in the source file. + if hunk_added and not hunk_removed: + anchor = max(hunk.source_start, 1) + hunk_removed = [anchor] + + # Pure deletion: nothing added -> anchor new side similarly. + if hunk_removed and not hunk_added: + anchor = max(hunk.target_start, 1) + hunk_added = [anchor] + + removed.extend(hunk_removed) + added.extend(hunk_added) return removed, added - def diff_changed_symbols(vuln_meta, fixed_meta): """ Keep only symbols whose body actually differs between vulnerable and fixed @@ -343,26 +366,6 @@ def collect_patch_symbols(repo, commit_hash): return by_language -def append_symbol_reachability_result(resource, result): - """ - Append one symbol reachability result to the resource extra_data without - overwriting previous results. - """ - extra_data = resource.extra_data or {} - existing_results = extra_data.get("symbols_reachability", []) - - if not isinstance(existing_results, list): - existing_results = [existing_results] - - existing_results.append(result) - - resource.update_extra_data( - { - "symbols_reachability": existing_results, - } - ) - - def collect_and_store_symbol_reachability_results(project, logger=None): """ For each known patch commit, determine whether each project codebase @@ -398,9 +401,6 @@ def collect_and_store_symbol_reachability_results(project, logger=None): continue patch_symbols = patch_symbols_by_language[resource_language] - vuln_patch_metadata = patch_symbols["vulnerable"] - fixed_patch_metadata = patch_symbols["fixed"] - resource_index = build_resource_index( resource_text, resource_language, @@ -409,38 +409,34 @@ def collect_and_store_symbol_reachability_results(project, logger=None): if not resource_index: continue - vuln_match_symbols = match_symbols_against_resource( - vuln_patch_metadata, + vuln_evidence = match_symbols_against_resource( + patch_symbols["vulnerable"], resource_index, ) - - fixed_match_symbols = match_symbols_against_resource( - fixed_patch_metadata, + fixed_evidence = match_symbols_against_resource( + patch_symbols["fixed"], resource_index, ) - if not vuln_match_symbols and not fixed_match_symbols: + if not vuln_evidence and not fixed_evidence: continue result = { - "reachability_status": classify_reachability(vuln_match_symbols), - "summary": { - "call_paths": { - qualified_name: ev.get("reachable_from", []) - for qualified_name, ev in vuln_match_symbols.items() - if ev.get("called") - }, - }, - "evidence": vuln_match_symbols, - "vulnerable_symbols": sorted(vuln_match_symbols), - "fixed_symbols": sorted(fixed_match_symbols), + "reachability_status": classify_reachability(vuln_evidence).value, + "vulnerable_symbols": sorted(vuln_evidence), + "fixed_symbols": sorted(fixed_evidence), + "evidence": vuln_evidence, "patch": { "vcs_url": vcs_url, "commit_hash": commit_hash, }, } - print(result) - append_symbol_reachability_result(resource, result) + + resource.update_extra_data( + { + "symbols_reachability": result, + } + ) except Exception as e: logger( @@ -448,11 +444,8 @@ def collect_and_store_symbol_reachability_results(project, logger=None): f"{vcs_url}@{commit_hash}: {e}" ) finally: - if repo: - repo.close() - # cleanup_repo(repo_path) - + pass def build_resource_index(resource_text, language): if not is_supported_language(language) or not resource_text: @@ -489,7 +482,11 @@ def match_symbols_against_resource(patch_symbols_metadata, resource_index): if not patch_symbols_metadata or not resource_index: return {} - call_graph = resource_index.get("call_graph") + call_graph = resource_index.get("call_graph") or {} + imports = call_graph.get("imports", {}) + + # Set of fully-qualified names the resource imports, e.g. "aiohttp.ClientSession" + imported_fq_names = set(imports.values()) target_qualified_names = { metadata["qualified_name"] for metadata in patch_symbols_metadata.values() @@ -501,24 +498,37 @@ def match_symbols_against_resource(patch_symbols_metadata, resource_index): ) called_qualified_names = set() - - if call_graph: - for callees in call_graph.get("edges_qualified", {}).values(): - called_qualified_names |= set(callees) + for callees in call_graph.get("edges_qualified", {}).values(): + called_qualified_names |= set(callees) matched = {} - for metadata in patch_symbols_metadata.values(): qualified_name = metadata["qualified_name"] fingerprint = metadata["fingerprint"] - defined = qualified_name in resource_index.get("definitions", {}) + defined = qualified_name in resource_index.get("definitions", set()) fingerprint_hit = bool( - fingerprint and fingerprint in resource_index.get("fingerprints", {}) + fingerprint and fingerprint in resource_index.get("fingerprints", set()) ) - called = qualified_name in called_qualified_names - if not (defined or fingerprint_hit or called): + # Does the resource *import* this symbol? + # Match either the bare name (import key) or any fq import target + # that ends with ".". + imported = ( + qualified_name in imports + or qualified_name in imported_fq_names + or any( + fq == qualified_name or fq.endswith("." + qualified_name) + for fq in imported_fq_names + ) + ) + + called = any( + fq_name == qualified_name or fq_name.endswith("." + qualified_name) + for fq_name in called_qualified_names + ) + + if not (defined or fingerprint_hit or called or imported): continue entry = matched.setdefault( @@ -526,18 +536,29 @@ def match_symbols_against_resource(patch_symbols_metadata, resource_index): { "defined": False, "called": False, + "imported": False, + "fingerprint": None, "reachable_from": [], + "external": False, }, ) - entry["defined"] = entry["defined"] or defined - entry["called"] = entry["called"] or called + if defined: + entry["defined"] = True - if fingerprint_hit: - entry["exact_match_fingerprint"] = fingerprint + if imported: + entry["imported"] = True + if not defined: + entry["external"] = True if called: + entry["called"] = True entry["reachable_from"] = sorted(reachable_callers) + if not defined: + entry["external"] = True + + if fingerprint_hit: + entry["fingerprint"] = fingerprint return matched @@ -553,8 +574,9 @@ def classify_reachability(evidence): has_path = bool(item.get("reachable_from")) is_exact = "exact_match_fingerprint" in item is_defined = bool(item.get("defined")) + is_imported = bool(item.get("imported")) - if is_called or has_path: + if is_called or has_path or is_imported: return ReachabilityStatus.REACHABLE if is_exact or is_defined: @@ -601,6 +623,7 @@ def build_call_graph(tree, language): return None index = collect_definitions(tree.root_node, language) + import_map = collect_imports(tree.root_node, language) graph_meta = {} for definition in index.values(): @@ -611,13 +634,12 @@ def build_call_graph(tree, language): continue body_text = node.text.decode("utf-8", errors="replace") - fingerprints = create_exact_symbol_fingerprint(body_text) or {} - + fingerprint = create_exact_symbol_fingerprint(body_text) graph_meta[qualified_name] = { "qualified_name": qualified_name, "node": node, "node_type": node.type, - "fingerprint": fingerprints, + "fingerprint": fingerprint, } definitions_by_name = {} @@ -635,7 +657,6 @@ def build_call_graph(tree, language): direct_calls = extract_direct_calls(metadata["node"], language, index) resolved_callees = set() - for receiver_name, callee_name in direct_calls: resolved_callees |= resolve_callee( receiver_name=receiver_name, @@ -643,6 +664,7 @@ def build_call_graph(tree, language): owner_qn=qualified_name, definitions_by_name=definitions_by_name, class_methods=class_methods, + import_map=import_map, ) edges_qualified[qualified_name] = resolved_callees @@ -650,6 +672,7 @@ def build_call_graph(tree, language): return { "nodes": graph_meta, "edges_qualified": edges_qualified, + "imports": import_map, } @@ -741,27 +764,30 @@ def get_call_receiver(callee_node): def resolve_callee( - receiver_name, callee_name, owner_qn, definitions_by_name, class_methods + receiver_name, + callee_name, + owner_qn, + definitions_by_name, + class_methods, + import_map=None, ): - """ - Resolve a call to candidate qualified names. + import_map = import_map or {} - Examples: - self.foo() from class A -> {"A.foo"} if A.foo exists - foo() -> definitions named "foo" - - """ if receiver_name == "self": owner_class = get_owner_class_name(owner_qn) - if owner_class: method_qn = f"{owner_class}.{callee_name}" - if method_qn in class_methods: return {method_qn} - candidates = definitions_by_name.get(callee_name, set()) - return set(candidates) + if callee_name in import_map: + return {import_map[callee_name]} + + if receiver_name is not None and receiver_name in import_map: + base = import_map[receiver_name] + return {f"{base}.{callee_name}"} + + return set(definitions_by_name.get(callee_name, set())) def get_owner_class_name(owner_qn): @@ -809,3 +835,46 @@ def compute_reachable_symbols(call_graph, target_qualified_names): frontier.append(parent) return reachable, bool(direct) + + +def collect_imports(root_node, language: str): + """ + Returns a dict mapping local names/aliases to their absolute import path. + Examples: + 'from django.db import models' -> {'models': 'django.db.models'} + 'import os.path' -> {'os.path': 'os.path'} + 'import numpy as np' -> {'np': 'numpy'} + 'from a.b import c as d' -> {'d': 'a.b.c'} + """ + import_map = {} + query = get_query(language, "imports") + if not query or not root_node: + return import_map + + for _pattern_index, captures in query.matches(root_node): + module_name = None + import_name = None + alias = None + + for node_name, nodes in captures.items(): + if not nodes: + continue + + text = nodes[0].text.decode("utf-8", errors="replace") + if node_name == "module_name": + module_name = text + elif node_name == "import_name": + import_name = text + elif node_name == "alias": + alias = text + + if not import_name: + continue + + local_name = alias or import_name + if module_name: + import_map[local_name] = f"{module_name}.{import_name}" + else: + import_map[local_name] = import_name + + return import_map diff --git a/scanpipe/pipes/symbols.py b/scanpipe/pipes/symbols.py index da6f359f65..6637bbf1c2 100644 --- a/scanpipe/pipes/symbols.py +++ b/scanpipe/pipes/symbols.py @@ -206,6 +206,21 @@ def _collect_and_store_tree_sitter_symbols_and_strings(resource): object: (_) @receiver attribute: (identifier) @callee)) """, + "imports": """ + (import_statement name: (dotted_name) @import_name) + (import_statement + name: (aliased_import + name: (dotted_name) @import_name + alias: (identifier) @alias)) + (import_from_statement + module_name: (dotted_name) @module_name + name: (dotted_name) @import_name) + (import_from_statement + module_name: (dotted_name) @module_name + name: (aliased_import + name: (dotted_name) @import_name + alias: (identifier) @alias)) + """, }, } diff --git a/scanpipe/tests/pipes/test_symbols_reachability.py b/scanpipe/tests/pipes/test_symbols_reachability.py index 68bbe9b964..3d5f27ab28 100644 --- a/scanpipe/tests/pipes/test_symbols_reachability.py +++ b/scanpipe/tests/pipes/test_symbols_reachability.py @@ -27,7 +27,7 @@ from scanpipe.models import Project from scanpipe.pipes import collect_and_create_codebase_resources -from scanpipe.pipes.reachability import ReachabilityStatus +from scanpipe.pipes.reachability import ReachabilityStatus, collect_imports, extract_direct_calls from scanpipe.pipes.reachability import analyze_patched_file from scanpipe.pipes.reachability import build_symbol_metadata from scanpipe.pipes.reachability import classify_reachability @@ -485,3 +485,40 @@ def test_compute_reachable_symbols(self): expected_reachable = {"app.main", "app.helper", "app.direct_caller"} self.assertEqual(reachable, expected_reachable) + + def test_collect_imports(self): + source_code = """ +from django.db import models +import os.path +import numpy as np +from a.b import c as d + """.strip() + + tree, _ = parse_code_to_ast(source_code, "Python") + real_root_node = tree.root_node + result = collect_imports(real_root_node, language="Python") + + expected_map = { + "models": "django.db.models", + "os.path": "os.path", + "np": "numpy", + "d": "a.b.c", + } + self.assertEqual(result, expected_map) + + def test_extract_direct(self): + source_code = """ +def hello(): + return 10 + +def clean_function(): + x = 10 + y = 20 + return hello() + x + y + """.strip() + + tree, _ = parse_code_to_ast(source_code, "Python") + functions = extract_definitions(tree, "Python", kinds=("functions",)) + + result = extract_direct_calls(functions[1], "Python", []) + self.assertEqual(result, [(None, 'hello')]) \ No newline at end of file