diff --git a/reflex/config.py b/reflex/config.py index 6977d74563..a800895154 100644 --- a/reflex/config.py +++ b/reflex/config.py @@ -184,6 +184,9 @@ class BaseConfig: # The async database url used by rx.Model. async_db_url: str | None = None + # The arguments to pass to the sqlalchemy engine. + connect_args: dict[str, Any] = dataclasses.field(default_factory=dict) + # The redis url redis_url: str | None = None @@ -312,7 +315,7 @@ class Config(BaseConfig): - **App Settings**: `app_name`, `loglevel`, `telemetry_enabled` - **Server**: `frontend_port`, `backend_port`, `api_url`, `cors_allowed_origins` - - **Database**: `db_url`, `async_db_url`, `redis_url` + - **Database**: `db_url`, `async_db_url`, `redis_url`, `connect_args` - **Frontend**: `frontend_packages`, `react_strict_mode` - **State Management**: `state_manager_mode`, `state_auto_setters` - **Plugins**: `plugins`, `disable_plugins` diff --git a/reflex/model.py b/reflex/model.py index acfa877988..b7e6bf620b 100644 --- a/reflex/model.py +++ b/reflex/model.py @@ -82,9 +82,12 @@ def get_engine_args(url: str | None = None) -> dict[str, Any]: } conf = get_config() url = url or conf.db_url + connect_args = conf.connect_args.copy() if conf.connect_args else {} if url is not None and url.startswith("sqlite"): # Needed for the admin dash on sqlite. - kwargs["connect_args"] = {"check_same_thread": False} + connect_args["check_same_thread"] = False + if connect_args: + kwargs["connect_args"] = connect_args return kwargs def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine: diff --git a/tests/units/test_connect_args.py b/tests/units/test_connect_args.py new file mode 100644 index 0000000000..90755f3da3 --- /dev/null +++ b/tests/units/test_connect_args.py @@ -0,0 +1,34 @@ +"""Test connecting to the database with custom arguments.""" +from unittest import mock + +from reflex.model import get_engine_args + + +def test_get_engine_args_connect_args(): + """Test that connect_args are correctly retrieved from config.""" + with mock.patch("reflex.model.get_config") as mock_get_config: + # Case 1: Postgres with connect_args + mock_conf = mock.Mock() + mock_conf.db_url = "postgresql://user:pass@localhost/db" + mock_conf.connect_args = {"application_name": "test_app"} + mock_get_config.return_value = mock_conf + + args = get_engine_args() + assert "connect_args" in args + assert args["connect_args"]["application_name"] == "test_app" + + # Case 2: Sqlite with connect_args + mock_conf.db_url = "sqlite:///test.db" + mock_conf.connect_args = {"timeout": 10} + + args = get_engine_args() + assert "connect_args" in args + assert args["connect_args"]["timeout"] == 10 + # Ensure sqlite specific check_same_thread is set + assert args["connect_args"]["check_same_thread"] is False + + # Case 3: Sqlite without connect_args + mock_conf.connect_args = {} + args = get_engine_args() + assert "connect_args" in args + assert args["connect_args"]["check_same_thread"] is False