From ef6f99bf68b577486811be56e8be3286c898c9de Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sergio=20S=C3=A1nchez?=
Date: Wed, 17 Sep 2025 16:21:39 +0200
Subject: [PATCH] Fix CTE support in SQLWrapper by implementing intelligent schema prefixing

- Add _extract_cte_names() method to identify CTE names in SQL queries
  (including WITH RECURSIVE and CTEs that declare a column list)
- Add _add_schema_prefix_smart() method to intelligently prefix table names while preserving CTE names
  (CTE names are matched case-insensitively; original whitespace is kept)
- Update run_sql() method to use smart schema prefixing instead of naive string replacement
- Add comprehensive tests for CTE functionality including:
  - CTE name extraction
  - Schema prefixing with CTEs
  - Multiple CTEs handling
  - Recursive, column-list and mixed-case CTE handling
  - Already schema-qualified table handling
  - End-to-end integration tests

Fixes #19889
---
 .../llama_index/core/utilities/sql_wrapper.py |  89 ++++++++-
 .../tests/utilities/test_sql_wrapper.py       | 188 ++++++++++++++++++
 2 files changed, 273 insertions(+), 4 deletions(-)

diff --git a/llama-index-core/llama_index/core/utilities/sql_wrapper.py b/llama-index-core/llama_index/core/utilities/sql_wrapper.py
index 09712ff73b..0caccaac01 100644
--- a/llama-index-core/llama_index/core/utilities/sql_wrapper.py
+++ b/llama-index-core/llama_index/core/utilities/sql_wrapper.py
@@ -1,5 +1,6 @@
 """SQL wrapper around SQLDatabase in langchain."""
 
+import re
 from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 from sqlalchemy import MetaData, create_engine, insert, inspect, text
@@ -212,6 +213,87 @@ def truncate_word(self, content: Any, *, length: int, suffix: str = "...") -> st
 
         return content[: length - len(suffix)].rsplit(" ", 1)[0] + suffix
 
+    def _extract_cte_names(self, command: str) -> set[str]:
+        """
+        Extract CTE names from a SQL command.
+
+        Recognizes ``WITH name AS (...)``, ``WITH RECURSIVE name AS (...)``,
+        comma-separated CTE lists, and CTEs that declare a column list
+        (``name (col1, col2) AS (...)``). Matching is regex-based and
+        case-insensitive; names are returned exactly as written.
+
+        Returns a set of CTE names that should not be schema-prefixed.
+        """
+        # Nothing can be a CTE unless the statement contains a WITH keyword
+        with_match = re.search(r"\bWITH\b", command, re.IGNORECASE)
+        if not with_match:
+            return set()
+
+        # A CTE definition is an identifier that follows WITH [RECURSIVE] or a
+        # comma, optionally carries a column list, and is followed by "AS ("
+        cte_def_pattern = (
+            r"(?:\bWITH\s+(?:RECURSIVE\s+)?|,\s*)"
+            r"([a-zA-Z_][a-zA-Z0-9_]*)"
+            r"(?:\s*\([^()]*\))?"
+            r"\s+AS\s*\("
+        )
+
+        # Scan from the first WITH so that text before the WITH clause is never
+        # mistaken for a CTE definition
+        remaining_sql = command[with_match.start() :]
+
+        return {
+            match.group(1)
+            for match in re.finditer(cte_def_pattern, remaining_sql, re.IGNORECASE)
+        }
+
+    def _add_schema_prefix_smart(self, command: str) -> str:
+        """
+        Add schema prefix to table names while preserving CTE names.
+
+        This method:
+        1. Identifies CTE names in the query
+        2. Only adds schema prefixes to actual table names, not CTE names
+        3. Handles both FROM and JOIN clauses properly
+
+        CTE names are compared case-insensitively, references that are already
+        schema-qualified are left untouched, and the whitespace between the
+        keyword and the table name is preserved.
+
+        The rewrite is regex-based: a FROM or JOIN keyword inside a string
+        literal or an expression such as ``EXTRACT(YEAR FROM col)`` is not
+        distinguished from a real clause.
+        """
+        if not self._schema:
+            return command
+
+        # Unquoted SQL identifiers are case-insensitive, so compare CTE names in
+        # lowercase: a CTE declared as "MyCte" may be referenced as "mycte"
+        cte_names = {name.lower() for name in self._extract_cte_names(command)}
+
+        # Captures: FROM/JOIN + whitespace + table_name (or schema.table)
+        from_join_pattern = (
+            r"\b(FROM|JOIN)(\s+)"
+            r"([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)?)\b"
+        )
+
+        def replace_table_reference(match: "re.Match[str]") -> str:
+            clause_type, spacing, table_ref = match.groups()
+
+            # Skip CTE names and already schema-qualified references
+            if table_ref.lower() in cte_names or "." in table_ref:
+                return match.group(0)  # Return original match unchanged
+
+            # Add schema prefix, keeping the original keyword and whitespace
+            return f"{clause_type}{spacing}{self._schema}.{table_ref}"
+
+        return re.sub(
+            from_join_pattern,
+            replace_table_reference,
+            command,
+            flags=re.IGNORECASE,
+        )
+
     def run_sql(self, command: str) -> Tuple[str, Dict]:
         """
         Execute a SQL statement and return a string representing the results.
@@ -221,10 +303,9 @@ def run_sql(self, command: str) -> Tuple[str, Dict]:
         """
         with self._engine.begin() as connection:
             try:
-                if self._schema:
-                    command = command.replace("FROM ", f"FROM {self._schema}.")
-                    command = command.replace("JOIN ", f"JOIN {self._schema}.")
-                cursor = connection.execute(text(command))
+                # Use smart schema prefixing that handles CTEs properly
+                modified_command = self._add_schema_prefix_smart(command)
+                cursor = connection.execute(text(modified_command))
             except (ProgrammingError, OperationalError) as exc:
                 raise NotImplementedError(
                     f"Statement {command!r} is invalid SQL.\nError: {exc.orig}"
diff --git a/llama-index-core/tests/utilities/test_sql_wrapper.py b/llama-index-core/tests/utilities/test_sql_wrapper.py
index 2b06004eea..f96ffac705 100644
--- a/llama-index-core/tests/utilities/test_sql_wrapper.py
+++ b/llama-index-core/tests/utilities/test_sql_wrapper.py
@@ -96,3 +96,191 @@ def test_long_string_no_truncation(sql_database: SQLDatabase) -> None:
 
     result_str, _ = sql_database.run_sql("SELECT * FROM test_table;")
     assert result_str == f"[(1, '{long_string}')]"
+
+
+# Test CTE functionality
+def test_cte_extraction(sql_database: SQLDatabase) -> None:
+    """Test that CTE names are correctly extracted from SQL queries."""
+    # Test single CTE
+    query1 = "WITH my_cte AS (SELECT * FROM test_table) SELECT * FROM my_cte"
+    cte_names = sql_database._extract_cte_names(query1)
+    assert cte_names == {"my_cte"}
+
+    # Test multiple CTEs
+    query2 = "WITH cte1 AS (SELECT * FROM test_table), cte2 AS (SELECT * FROM test_table) SELECT * FROM cte1 JOIN cte2"
+    cte_names = sql_database._extract_cte_names(query2)
+    assert cte_names == {"cte1", "cte2"}
+
+    # Test no CTE
+    query3 = "SELECT * FROM test_table"
+    cte_names = sql_database._extract_cte_names(query3)
+    assert cte_names == set()
+
+    # Test case insensitive
+    query4 = "with my_cte as (SELECT * FROM test_table) SELECT * FROM my_cte"
+    cte_names = sql_database._extract_cte_names(query4)
+    assert cte_names == {"my_cte"}
+
+    # Test recursive CTE
+    query5 = "WITH RECURSIVE tree AS (SELECT * FROM test_table) SELECT * FROM tree"
+    cte_names = sql_database._extract_cte_names(query5)
+    assert cte_names == {"tree"}
+
+    # Test CTE with an explicit column list
+    query6 = "WITH totals (id, n) AS (SELECT 1, 2) SELECT * FROM totals"
+    cte_names = sql_database._extract_cte_names(query6)
+    assert cte_names == {"totals"}
+
+
+def test_schema_prefixing_without_cte(sql_database: SQLDatabase) -> None:
+    """Test that schema prefixing works correctly for regular queries without CTEs."""
+    # Create a database with schema
+    engine = create_engine("sqlite:///:memory:")
+    metadata = MetaData()
+    table_name = "test_table"
+    Table(
+        table_name,
+        metadata,
+        Column("id", Integer, primary_key=True),
+        Column("name", String),
+    )
+    metadata.create_all(engine)
+
+    # Create SQLDatabase with schema
+    db_with_schema = SQLDatabase(
+        engine=engine,
+        schema="test_schema",
+        metadata=metadata,
+        sample_rows_in_table_info=1,
+    )
+
+    # Test simple query
+    query = "SELECT * FROM test_table"
+    modified = db_with_schema._add_schema_prefix_smart(query)
+    assert modified == "SELECT * FROM test_schema.test_table"
+
+    # Test JOIN query
+    query2 = "SELECT * FROM test_table t1 JOIN test_table t2 ON t1.id = t2.id"
+    modified2 = db_with_schema._add_schema_prefix_smart(query2)
+    assert "test_schema.test_table" in modified2
+    assert modified2.count("test_schema.test_table") == 2
+
+
+def test_schema_prefixing_with_cte(sql_database: SQLDatabase) -> None:
+    """Test that schema prefixing works correctly with CTEs."""
+    # Create a database with schema
+    engine = create_engine("sqlite:///:memory:")
+    metadata = MetaData()
+    table_name = "test_table"
+    Table(
+        table_name,
+        metadata,
+        Column("id", Integer, primary_key=True),
+        Column("name", String),
+    )
+    metadata.create_all(engine)
+
+    # Create SQLDatabase with schema
+    db_with_schema = SQLDatabase(
+        engine=engine,
+        schema="test_schema",
+        metadata=metadata,
+        sample_rows_in_table_info=1,
+    )
+
+    # Test CTE query - CTE name should not be prefixed, but table should be
+    query = "WITH my_cte AS (SELECT * FROM test_table) SELECT * FROM my_cte"
+    modified = db_with_schema._add_schema_prefix_smart(query)
+    assert "FROM test_schema.test_table" in modified
+    assert "FROM my_cte" in modified  # CTE name should not be prefixed
+
+    # Test multiple CTEs
+    query2 = "WITH cte1 AS (SELECT * FROM test_table), cte2 AS (SELECT * FROM test_table) SELECT * FROM cte1 JOIN cte2"
+    modified2 = db_with_schema._add_schema_prefix_smart(query2)
+    assert "FROM test_schema.test_table" in modified2
+    assert "FROM cte1" in modified2
+    assert "JOIN cte2" in modified2
+
+    # Test CTE with JOIN
+    query3 = "WITH my_cte AS (SELECT * FROM test_table) SELECT * FROM my_cte JOIN test_table ON my_cte.id = test_table.id"
+    modified3 = db_with_schema._add_schema_prefix_smart(query3)
+    assert "FROM my_cte" in modified3  # CTE should not be prefixed
+    assert "JOIN test_schema.test_table" in modified3  # Table should be prefixed
+
+    # Test CTE referenced with different casing
+    query4 = "WITH My_Cte AS (SELECT * FROM test_table) SELECT * FROM my_cte"
+    modified4 = db_with_schema._add_schema_prefix_smart(query4)
+    assert "FROM my_cte" in modified4  # CTE should not be prefixed
+    assert "test_schema.my_cte" not in modified4
+
+
+def test_schema_prefixing_already_qualified(sql_database: SQLDatabase) -> None:
+    """Test that already schema-qualified table names are not modified."""
+    # Create a database with schema
+    engine = create_engine("sqlite:///:memory:")
+    metadata = MetaData()
+    table_name = "test_table"
+    Table(
+        table_name,
+        metadata,
+        Column("id", Integer, primary_key=True),
+        Column("name", String),
+    )
+    metadata.create_all(engine)
+
+    # Create SQLDatabase with schema
+    db_with_schema = SQLDatabase(
+        engine=engine,
+        schema="test_schema",
+        metadata=metadata,
+        sample_rows_in_table_info=1,
+    )
+
+    # Test already qualified table name
+    query = "SELECT * FROM other_schema.test_table"
+    modified = db_with_schema._add_schema_prefix_smart(query)
+    assert modified == query  # Should remain unchanged
+
+    # Test mixed qualified and unqualified
+    query2 = (
+        "SELECT * FROM test_table t1 JOIN other_schema.test_table t2 ON t1.id = t2.id"
+    )
+    modified2 = db_with_schema._add_schema_prefix_smart(query2)
+    assert "FROM test_schema.test_table" in modified2
+    assert "JOIN other_schema.test_table" in modified2
+
+
+def test_run_sql_with_cte_and_schema(sql_database: SQLDatabase) -> None:
+    """Test that run_sql works correctly with CTEs and schema prefixing."""
+    # Create a database with schema
+    engine = create_engine("sqlite:///:memory:")
+    metadata = MetaData()
+    table_name = "test_table"
+    Table(
+        table_name,
+        metadata,
+        Column("id", Integer, primary_key=True),
+        Column("name", String),
+    )
+    metadata.create_all(engine)
+
+    # Create SQLDatabase with schema
+    db_with_schema = SQLDatabase(
+        engine=engine,
+        schema="test_schema",
+        metadata=metadata,
+        sample_rows_in_table_info=1,
+    )
+
+    # Insert test data
+    db_with_schema.insert_into_table("test_table", {"id": 1, "name": "Alice"})
+    db_with_schema.insert_into_table("test_table", {"id": 2, "name": "Bob"})
+
+    # Test CTE query
+    query = "WITH filtered_users AS (SELECT * FROM test_table WHERE id > 1) SELECT * FROM filtered_users"
+    result_str, result_dict = db_with_schema.run_sql(query)
+
+    # Should return Bob's record
+    assert "Bob" in result_str
+    assert "Alice" not in result_str
+    assert len(result_dict["result"]) == 1