Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion reflex/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,9 @@ class BaseConfig:
# The async database url used by rx.Model.
async_db_url: str | None = None

# The arguments to pass to the sqlalchemy engine.
connect_args: dict[str, Any] = dataclasses.field(default_factory=dict)

# The redis url
redis_url: str | None = None

Expand Down Expand Up @@ -312,7 +315,7 @@ class Config(BaseConfig):

- **App Settings**: `app_name`, `loglevel`, `telemetry_enabled`
- **Server**: `frontend_port`, `backend_port`, `api_url`, `cors_allowed_origins`
- **Database**: `db_url`, `async_db_url`, `redis_url`
- **Database**: `db_url`, `async_db_url`, `redis_url`, `connect_args`
- **Frontend**: `frontend_packages`, `react_strict_mode`
- **State Management**: `state_manager_mode`, `state_auto_setters`
- **Plugins**: `plugins`, `disable_plugins`
Expand Down
5 changes: 4 additions & 1 deletion reflex/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,12 @@ def get_engine_args(url: str | None = None) -> dict[str, Any]:
}
conf = get_config()
url = url or conf.db_url
connect_args = conf.connect_args.copy() if conf.connect_args else {}
if url is not None and url.startswith("sqlite"):
# Needed for the admin dash on sqlite.
kwargs["connect_args"] = {"check_same_thread": False}
connect_args["check_same_thread"] = False
if connect_args:
kwargs["connect_args"] = connect_args
return kwargs

def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
Expand Down
34 changes: 34 additions & 0 deletions tests/units/test_connect_args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""Test connecting to the database with custom arguments."""
from unittest import mock

from reflex.model import get_engine_args


def test_get_engine_args_connect_args():
"""Test that connect_args are correctly retrieved from config."""
with mock.patch("reflex.model.get_config") as mock_get_config:
# Case 1: Postgres with connect_args
mock_conf = mock.Mock()
mock_conf.db_url = "postgresql://user:pass@localhost/db"
mock_conf.connect_args = {"application_name": "test_app"}
mock_get_config.return_value = mock_conf

args = get_engine_args()
assert "connect_args" in args
assert args["connect_args"]["application_name"] == "test_app"

# Case 2: Sqlite with connect_args
mock_conf.db_url = "sqlite:///test.db"
mock_conf.connect_args = {"timeout": 10}

args = get_engine_args()
assert "connect_args" in args
assert args["connect_args"]["timeout"] == 10
# Ensure sqlite specific check_same_thread is set
assert args["connect_args"]["check_same_thread"] is False

# Case 3: Sqlite without connect_args
mock_conf.connect_args = {}
args = get_engine_args()
assert "connect_args" in args
assert args["connect_args"]["check_same_thread"] is False
Loading