Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 86 additions & 4 deletions llama-index-core/llama_index/core/utilities/sql_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""SQL wrapper around SQLDatabase in langchain."""

import re
from typing import Any, Dict, Iterable, List, Optional, Tuple

from sqlalchemy import MetaData, create_engine, insert, inspect, text
Expand Down Expand Up @@ -212,6 +213,88 @@ def truncate_word(self, content: Any, *, length: int, suffix: str = "...") -> st

return content[: length - len(suffix)].rsplit(" ", 1)[0] + suffix

def _extract_cte_names(self, command: str) -> set[str]:
    """
    Extract CTE names from a SQL command.

    The following definition forms are recognised, case-insensitively::

        WITH name AS (...)
        WITH RECURSIVE name AS (...)
        WITH name (col1, col2) AS (...), other AS [NOT] MATERIALIZED (...)

    This is a regex scan, not a SQL parse: only unquoted identifiers made
    of letters, digits and underscores are picked up.

    Args:
        command: The SQL statement to inspect.

    Returns:
        A set of CTE names, spelled exactly as in ``command``, that should
        not be schema-prefixed. Empty when ``command`` has no WITH keyword.
    """
    # Nothing that precedes the first WITH keyword can be a CTE definition.
    with_match = re.search(r"\bWITH\b", command, re.IGNORECASE)
    if with_match is None:
        return set()

    cte_def_pattern = (
        # A definition follows either WITH [RECURSIVE] or a separating comma.
        r"(?:\bWITH\s+(?:RECURSIVE\s+)?|,\s*)"
        # The CTE name itself.
        r"([a-zA-Z_][a-zA-Z0-9_]*)"
        # Optional column list, e.g. ``name (a, b)``.
        r"\s*(?:\([^()]*\)\s*)?"
        # ``\b`` keeps an identifier that merely ends in "as" from matching.
        r"\bAS\s*(?:(?:NOT\s+)?MATERIALIZED\s*)?\("
    )
    with_clause = command[with_match.start() :]
    return {
        match.group(1)
        for match in re.finditer(cte_def_pattern, with_clause, re.IGNORECASE)
    }

def _add_schema_prefix_smart(self, command: str) -> str:
    """
    Add schema prefix to table names while preserving CTE names.

    Every unqualified identifier that directly follows a FROM or JOIN
    keyword is rewritten to ``<schema>.<identifier>``, except when:

    1. the identifier is a CTE name defined in the query (compared
       case-insensitively),
    2. the identifier is already qualified (contains a dot), or
    3. the FROM keyword is part of an expression rather than a clause:
       ``EXTRACT(field FROM x)``, ``TRIM(... FROM x)``,
       ``SUBSTRING(x FROM ...)``, ``OVERLAY(... FROM ...)`` or
       ``IS [NOT] DISTINCT FROM x``.

    The original keyword casing and whitespace are kept as written.

    This is a regex rewrite, not a SQL parse: only the first table of a
    comma-separated FROM list is prefixed, and text inside string literals
    is not treated specially.

    Args:
        command: The SQL statement to rewrite.

    Returns:
        The rewritten statement, or ``command`` unchanged when no schema
        is configured.
    """
    if not self._schema:
        return command

    # Unquoted SQL identifiers are case-insensitive, so compare lowercased.
    cte_names = {name.lower() for name in self._extract_cte_names(command)}

    # Group 1: keyword plus its trailing whitespace; group 2: table or
    # schema.table reference.
    from_join_pattern = (
        r"(\b(?:FROM|JOIN)\s+)"
        r"([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)?)\b"
    )
    # Matches text that ends inside a still-open FROM-taking function call,
    # or right after IS [NOT] DISTINCT; a FROM there is not a table clause.
    non_clause_from_pattern = (
        r"(?:\b(?:EXTRACT|TRIM|SUBSTRING|OVERLAY)\s*\([^()]*"
        r"|\bIS\s+(?:NOT\s+)?DISTINCT\s+)$"
    )

    def replace_table_reference(match: "re.Match[str]") -> str:
        keyword, table_ref = match.group(1), match.group(2)

        # Leave qualified names and CTE references untouched.
        if "." in table_ref or table_ref.lower() in cte_names:
            return match.group(0)

        # Leave expression-level FROM (EXTRACT, IS DISTINCT FROM, ...) alone.
        if keyword[:4].upper() == "FROM" and re.search(
            non_clause_from_pattern, command[: match.start()], re.IGNORECASE
        ):
            return match.group(0)

        return f"{keyword}{self._schema}.{table_ref}"

    return re.sub(
        from_join_pattern,
        replace_table_reference,
        command,
        flags=re.IGNORECASE,
    )

def run_sql(self, command: str) -> Tuple[str, Dict]:
"""
Execute a SQL statement and return a string representing the results.
Expand All @@ -221,10 +304,9 @@ def run_sql(self, command: str) -> Tuple[str, Dict]:
"""
with self._engine.begin() as connection:
try:
if self._schema:
command = command.replace("FROM ", f"FROM {self._schema}.")
command = command.replace("JOIN ", f"JOIN {self._schema}.")
cursor = connection.execute(text(command))
# Use smart schema prefixing that handles CTEs properly
modified_command = self._add_schema_prefix_smart(command)
cursor = connection.execute(text(modified_command))
except (ProgrammingError, OperationalError) as exc:
raise NotImplementedError(
f"Statement {command!r} is invalid SQL.\nError: {exc.orig}"
Expand Down
172 changes: 172 additions & 0 deletions llama-index-core/tests/utilities/test_sql_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,175 @@ def test_long_string_no_truncation(sql_database: SQLDatabase) -> None:
result_str, _ = sql_database.run_sql("SELECT * FROM test_table;")

assert result_str == f"[(1, '{long_string}')]"


# Test CTE functionality
def test_cte_extraction(sql_database: SQLDatabase) -> None:
    """Check ``_extract_cte_names`` on single, multiple, absent and lowercase CTEs."""
    # Each entry pairs a query with the CTE names expected from it.
    cases = [
        # One CTE.
        (
            "WITH my_cte AS (SELECT * FROM test_table) SELECT * FROM my_cte",
            {"my_cte"},
        ),
        # Two comma-separated CTEs.
        (
            "WITH cte1 AS (SELECT * FROM test_table), cte2 AS (SELECT * FROM test_table) SELECT * FROM cte1 JOIN cte2",
            {"cte1", "cte2"},
        ),
        # No WITH clause at all.
        ("SELECT * FROM test_table", set()),
        # Lowercase keywords.
        (
            "with my_cte as (SELECT * FROM test_table) SELECT * FROM my_cte",
            {"my_cte"},
        ),
    ]

    for sql, expected_names in cases:
        assert sql_database._extract_cte_names(sql) == expected_names


def test_schema_prefixing_without_cte(sql_database: SQLDatabase) -> None:
    """Plain FROM and JOIN targets receive the configured schema prefix."""
    # In-memory SQLite engine holding a single (id, name) table.
    memory_engine = create_engine("sqlite:///:memory:")
    meta = MetaData()
    Table(
        "test_table",
        meta,
        Column("id", Integer, primary_key=True),
        Column("name", String),
    )
    meta.create_all(memory_engine)

    # Wrapper configured with a schema so that prefixing is active.
    schema_db = SQLDatabase(
        engine=memory_engine,
        schema="test_schema",
        metadata=meta,
        sample_rows_in_table_info=1,
    )

    # A lone FROM target is qualified.
    simple = schema_db._add_schema_prefix_smart("SELECT * FROM test_table")
    assert simple == "SELECT * FROM test_schema.test_table"

    # Both the FROM target and the JOIN target are qualified.
    joined = schema_db._add_schema_prefix_smart(
        "SELECT * FROM test_table t1 JOIN test_table t2 ON t1.id = t2.id"
    )
    qualified = "test_schema.test_table"
    assert qualified in joined
    assert joined.count(qualified) == 2


def test_schema_prefixing_with_cte(sql_database: SQLDatabase) -> None:
    """CTE references stay bare while real tables get the schema prefix."""
    # In-memory SQLite engine holding a single (id, name) table.
    memory_engine = create_engine("sqlite:///:memory:")
    meta = MetaData()
    Table(
        "test_table",
        meta,
        Column("id", Integer, primary_key=True),
        Column("name", String),
    )
    meta.create_all(memory_engine)

    # Wrapper configured with a schema so that prefixing is active.
    schema_db = SQLDatabase(
        engine=memory_engine,
        schema="test_schema",
        metadata=meta,
        sample_rows_in_table_info=1,
    )

    # Each entry: a query and the fragments its rewritten form must contain.
    scenarios = [
        # Single CTE: the inner table is prefixed, the CTE reference is not.
        (
            "WITH my_cte AS (SELECT * FROM test_table) SELECT * FROM my_cte",
            ("FROM test_schema.test_table", "FROM my_cte"),
        ),
        # Two CTEs, referenced through both FROM and JOIN.
        (
            "WITH cte1 AS (SELECT * FROM test_table), cte2 AS (SELECT * FROM test_table) SELECT * FROM cte1 JOIN cte2",
            ("FROM test_schema.test_table", "FROM cte1", "JOIN cte2"),
        ),
        # A CTE joined against a real table.
        (
            "WITH my_cte AS (SELECT * FROM test_table) SELECT * FROM my_cte JOIN test_table ON my_cte.id = test_table.id",
            ("FROM my_cte", "JOIN test_schema.test_table"),
        ),
    ]

    for sql, expected_fragments in scenarios:
        rewritten = schema_db._add_schema_prefix_smart(sql)
        for fragment in expected_fragments:
            assert fragment in rewritten


def test_schema_prefixing_already_qualified(sql_database: SQLDatabase) -> None:
    """Table references that already carry a schema are left as written."""
    # In-memory SQLite engine holding a single (id, name) table.
    memory_engine = create_engine("sqlite:///:memory:")
    meta = MetaData()
    Table(
        "test_table",
        meta,
        Column("id", Integer, primary_key=True),
        Column("name", String),
    )
    meta.create_all(memory_engine)

    # Wrapper configured with a schema so that prefixing is active.
    schema_db = SQLDatabase(
        engine=memory_engine,
        schema="test_schema",
        metadata=meta,
        sample_rows_in_table_info=1,
    )

    # A fully qualified reference comes back untouched.
    qualified_sql = "SELECT * FROM other_schema.test_table"
    assert schema_db._add_schema_prefix_smart(qualified_sql) == qualified_sql

    # Mixed query: only the unqualified FROM target gains the prefix.
    mixed_sql = "SELECT * FROM test_table t1 JOIN other_schema.test_table t2 ON t1.id = t2.id"
    rewritten = schema_db._add_schema_prefix_smart(mixed_sql)
    for fragment in ("FROM test_schema.test_table", "JOIN other_schema.test_table"):
        assert fragment in rewritten


def test_run_sql_with_cte_and_schema(sql_database: SQLDatabase) -> None:
    """Run a CTE query end to end through ``run_sql`` on a schema-configured wrapper."""
    # In-memory SQLite engine holding a single (id, name) table.
    memory_engine = create_engine("sqlite:///:memory:")
    meta = MetaData()
    Table(
        "test_table",
        meta,
        Column("id", Integer, primary_key=True),
        Column("name", String),
    )
    meta.create_all(memory_engine)

    # Wrapper configured with a schema so that prefixing is active.
    schema_db = SQLDatabase(
        engine=memory_engine,
        schema="test_schema",
        metadata=meta,
        sample_rows_in_table_info=1,
    )

    # Seed two rows; only the second satisfies the CTE's ``id > 1`` filter.
    for row in ({"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}):
        schema_db.insert_into_table("test_table", row)

    rendered, payload = schema_db.run_sql(
        "WITH filtered_users AS (SELECT * FROM test_table WHERE id > 1) SELECT * FROM filtered_users"
    )

    # Exactly one record, Bob's, is expected back.
    assert "Bob" in rendered
    assert "Alice" not in rendered
    assert len(payload["result"]) == 1
Loading