diff --git a/llama-index-core/llama_index/core/storage/chat_store/sql.py b/llama-index-core/llama_index/core/storage/chat_store/sql.py index c692373b93..28875824fc 100644 --- a/llama-index-core/llama_index/core/storage/chat_store/sql.py +++ b/llama-index-core/llama_index/core/storage/chat_store/sql.py @@ -1,3 +1,4 @@ +import re import time from typing import Any, Dict, List, Optional, Tuple @@ -13,7 +14,6 @@ select, insert, update, - text, ) from sqlalchemy.ext.asyncio import ( AsyncEngine, @@ -21,6 +21,7 @@ create_async_engine, ) from sqlalchemy.orm import declarative_base, sessionmaker +from sqlalchemy.schema import CreateSchema from llama_index.core.async_utils import asyncio_run from llama_index.core.bridge.pydantic import Field, PrivateAttr, model_serializer @@ -70,6 +71,8 @@ def __init__( db_schema: Optional[str] = None, ): """Initialize the SQLAlchemy chat store.""" + if db_schema is not None: + self._validate_schema_name(db_schema) super().__init__( table_name=table_name, async_database_uri=async_database_uri or DEFAULT_ASYNC_DATABASE_URI, @@ -78,6 +81,15 @@ def __init__( self._async_engine = async_engine self._db_data = db_data + @staticmethod + def _validate_schema_name(schema_name: str) -> None: + """Validate schema name to prevent SQL injection.""" + if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]{0,62}$", schema_name): + raise ValueError( + f"Invalid schema name: {schema_name}. Schema names must start with a letter or underscore, " + "contain only alphanumeric characters and underscores, and be at most 63 characters long." + ) + @staticmethod def _is_in_memory_uri(uri: Optional[str]) -> bool: """Check if the URI points to an in-memory SQLite database.""" @@ -140,9 +152,7 @@ async def _setup_tables(self, async_engine: AsyncEngine) -> Table: # Create schema if it doesn't exist (PostgreSQL, SQL Server, etc.) async with async_engine.begin() as conn: - await conn.execute( - text(f'CREATE SCHEMA IF NOT EXISTS "{self.db_schema}"') - ) + await conn.execute(CreateSchema(self.db_schema, if_not_exists=True)) # Create messages table with status column self._table = Table( diff --git a/llama-index-core/tests/storage/chat_store/test_sql_schema.py b/llama-index-core/tests/storage/chat_store/test_sql_schema.py index 54e9adc1e7..7b1985a958 100644 --- a/llama-index-core/tests/storage/chat_store/test_sql_schema.py +++ b/llama-index-core/tests/storage/chat_store/test_sql_schema.py @@ -5,6 +5,7 @@ from llama_index.core.base.llms.types import ChatMessage from llama_index.core.storage.chat_store.sql import SQLAlchemyChatStore +from sqlalchemy.schema import CreateSchema class TestSQLAlchemyChatStoreSchema: @@ -47,7 +48,7 @@ def test_schema_serialization(self): @pytest.mark.asyncio async def test_postgresql_schema_creation(self): - """Test that CREATE SCHEMA SQL is called for PostgreSQL.""" + """Test that CREATE SCHEMA SQL is called for PostgreSQL using CreateSchema.""" store = SQLAlchemyChatStore( table_name="test_messages", async_database_uri="postgresql+asyncpg://user:pass@host/db", @@ -69,10 +70,12 @@ async def test_postgresql_schema_creation(self): # Call _setup_tables await store._setup_tables(async_engine) - # Verify schema creation was called + # Verify schema creation was called with CreateSchema mock_conn.execute.assert_called() call_args = mock_conn.execute.call_args_list[0][0][0] - assert 'CREATE SCHEMA IF NOT EXISTS "test_schema"' in str(call_args) + assert isinstance(call_args, CreateSchema) + assert call_args.element == "test_schema" + assert call_args.if_not_exists is True # Verify MetaData has schema assert store._metadata.schema == "test_schema" @@ -120,3 +123,63 @@ async def test_basic_operations_with_schema(self): # Verify schema is preserved assert store.db_schema == "test_schema" + + def test_schema_name_validation_valid(self): + """Test that valid schema names are accepted.""" + valid_names = [ + "test_schema", + "TestSchema", + "_private_schema", + "schema123", + "a" * 63, + ] + for name in valid_names: + store = SQLAlchemyChatStore( + table_name="test_messages", + async_database_uri="postgresql+asyncpg://user:pass@host/db", + db_schema=name, + ) + assert store.db_schema == name + + def test_schema_name_validation_invalid(self): + """Test that invalid schema names are rejected to prevent SQL injection.""" + invalid_names = [ + "test-schema", + "test schema", + "test;DROP TABLE users;--", + "test' OR '1'='1", + '"; DROP SCHEMA public; --', + "123invalid", + "a" * 64, + "test\nschema", + "test\tschema", + "test'schema", + 'test"schema', + "test;schema", + "test--schema", + "test/*schema", + ] + for name in invalid_names: + with pytest.raises(ValueError, match="Invalid schema name"): + SQLAlchemyChatStore( + table_name="test_messages", + async_database_uri="postgresql+asyncpg://user:pass@host/db", + db_schema=name, + ) + + def test_schema_name_validation_sql_injection_attempts(self): + """Test that SQL injection attempts in schema names are blocked.""" + injection_attempts = [ + "test'; DROP TABLE messages; --", + 'test"; DROP SCHEMA public CASCADE; --', + "test' UNION SELECT * FROM users --", + "test'; CREATE USER hacker WITH PASSWORD 'pass'; --", + 'test\\"; GRANT ALL PRIVILEGES ON DATABASE db TO hacker; --', + ] + for attempt in injection_attempts: + with pytest.raises(ValueError, match="Invalid schema name"): + SQLAlchemyChatStore( + table_name="test_messages", + async_database_uri="postgresql+asyncpg://user:pass@host/db", + db_schema=attempt, + )