Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions llama-index-core/llama_index/core/storage/chat_store/sql.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
import time
from typing import Any, Dict, List, Optional, Tuple

Expand All @@ -13,14 +14,14 @@
select,
insert,
update,
text,
)
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
create_async_engine,
)
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.schema import CreateSchema

from llama_index.core.async_utils import asyncio_run
from llama_index.core.bridge.pydantic import Field, PrivateAttr, model_serializer
Expand Down Expand Up @@ -70,6 +71,8 @@ def __init__(
db_schema: Optional[str] = None,
):
"""Initialize the SQLAlchemy chat store."""
if db_schema is not None:
self._validate_schema_name(db_schema)
super().__init__(
table_name=table_name,
async_database_uri=async_database_uri or DEFAULT_ASYNC_DATABASE_URI,
Expand All @@ -78,6 +81,15 @@ def __init__(
self._async_engine = async_engine
self._db_data = db_data

@staticmethod
def _validate_schema_name(schema_name: str) -> None:
    """
    Validate a schema name to prevent SQL injection.

    A name is accepted only if it starts with an ASCII letter or an
    underscore, continues with ASCII letters, digits or underscores, and
    is at most 63 characters long in total.

    Args:
        schema_name: The candidate schema name.

    Raises:
        ValueError: If ``schema_name`` does not satisfy the rules above.

    """
    # fullmatch() instead of match() with a "$" anchor: "$" also matches
    # just before a trailing newline, which would let a name such as
    # "public\n" slip through the check.
    if not re.fullmatch(r"[a-zA-Z_][a-zA-Z0-9_]{0,62}", schema_name):
        raise ValueError(
            f"Invalid schema name: {schema_name}. Schema names must start with a letter or underscore, "
            "contain only alphanumeric characters and underscores, and be at most 63 characters long."
        )

@staticmethod
def _is_in_memory_uri(uri: Optional[str]) -> bool:
"""Check if the URI points to an in-memory SQLite database."""
Expand Down Expand Up @@ -140,9 +152,7 @@ async def _setup_tables(self, async_engine: AsyncEngine) -> Table:

# Create schema if it doesn't exist (PostgreSQL, SQL Server, etc.)
async with async_engine.begin() as conn:
await conn.execute(
text(f'CREATE SCHEMA IF NOT EXISTS "{self.db_schema}"')
)
await conn.execute(CreateSchema(self.db_schema, if_not_exists=True))

# Create messages table with status column
self._table = Table(
Expand Down
69 changes: 66 additions & 3 deletions llama-index-core/tests/storage/chat_store/test_sql_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from llama_index.core.base.llms.types import ChatMessage
from llama_index.core.storage.chat_store.sql import SQLAlchemyChatStore
from sqlalchemy.schema import CreateSchema


class TestSQLAlchemyChatStoreSchema:
Expand Down Expand Up @@ -47,7 +48,7 @@ def test_schema_serialization(self):

@pytest.mark.asyncio
async def test_postgresql_schema_creation(self):
"""Test that CREATE SCHEMA SQL is called for PostgreSQL."""
"""Test that CREATE SCHEMA SQL is called for PostgreSQL using CreateSchema."""
store = SQLAlchemyChatStore(
table_name="test_messages",
async_database_uri="postgresql+asyncpg://user:pass@host/db",
Expand All @@ -69,10 +70,12 @@ async def test_postgresql_schema_creation(self):
# Call _setup_tables
await store._setup_tables(async_engine)

# Verify schema creation was called
# Verify schema creation was called with CreateSchema
mock_conn.execute.assert_called()
call_args = mock_conn.execute.call_args_list[0][0][0]
assert 'CREATE SCHEMA IF NOT EXISTS "test_schema"' in str(call_args)
assert isinstance(call_args, CreateSchema)
assert call_args.element == "test_schema"
assert call_args.if_not_exists is True

# Verify MetaData has schema
assert store._metadata.schema == "test_schema"
Expand Down Expand Up @@ -120,3 +123,63 @@ async def test_basic_operations_with_schema(self):

# Verify schema is preserved
assert store.db_schema == "test_schema"

def test_schema_name_validation_valid(self):
    """Well-formed schema names are accepted and kept unchanged on the store."""
    common_kwargs = {
        "table_name": "test_messages",
        "async_database_uri": "postgresql+asyncpg://user:pass@host/db",
    }
    accepted_names = (
        "test_schema",
        "TestSchema",
        "_private_schema",
        "schema123",
        "a" * 63,  # longest name the validator allows
    )
    for schema in accepted_names:
        chat_store = SQLAlchemyChatStore(db_schema=schema, **common_kwargs)
        assert chat_store.db_schema == schema

def test_schema_name_validation_invalid(self):
    """Malformed schema names make the constructor raise ``ValueError``."""
    common_kwargs = {
        "table_name": "test_messages",
        "async_database_uri": "postgresql+asyncpg://user:pass@host/db",
    }
    rejected_names = (
        "test-schema",
        "test schema",
        "test;DROP TABLE users;--",
        "test' OR '1'='1",
        '"; DROP SCHEMA public; --',
        "123invalid",
        "a" * 64,  # one character over the allowed length
        "test\nschema",
        "test\tschema",
        "test'schema",
        'test"schema',
        "test;schema",
        "test--schema",
        "test/*schema",
    )
    for schema in rejected_names:
        with pytest.raises(ValueError, match="Invalid schema name"):
            SQLAlchemyChatStore(db_schema=schema, **common_kwargs)

def test_schema_name_validation_sql_injection_attempts(self):
    """Schema names carrying SQL injection payloads are rejected with ``ValueError``."""
    common_kwargs = {
        "table_name": "test_messages",
        "async_database_uri": "postgresql+asyncpg://user:pass@host/db",
    }
    payloads = (
        "test'; DROP TABLE messages; --",
        'test"; DROP SCHEMA public CASCADE; --',
        "test' UNION SELECT * FROM users --",
        "test'; CREATE USER hacker WITH PASSWORD 'pass'; --",
        'test\\"; GRANT ALL PRIVILEGES ON DATABASE db TO hacker; --',
    )
    for payload in payloads:
        with pytest.raises(ValueError, match="Invalid schema name"):
            SQLAlchemyChatStore(db_schema=payload, **common_kwargs)