Skip to content

Commit f5442e1

Browse files
committed
Support async sessions in strawberry_sqlalchemy_mapper and deprecate sync.
I left a note in the loader's __init__, but to duplicate that here: Making blocking database calls from within an async function (the resolver) has catastrophic performance implications. Not only will all resolvers be effectively serialized, any other coroutines waiting on the event loop (e.g. concurrent requests in a web server), will be blocked as well, grinding your entire service to a halt. There's no reason for us to support a foot bazooka, hence the deprecation.
1 parent 070ec9a commit f5442e1

File tree

6 files changed

+171
-15
lines changed

6 files changed

+171
-15
lines changed

RELEASE.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Release type: minor
2+
3+
Adds support for async sessions and deprecates sync sessions due to performance reasons.

poetry.lock

Lines changed: 55 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,13 @@ version_scheme = "no-guess-dev"
4040

4141
[tool.poetry.dependencies]
4242
python = "^3.8"
43-
sqlalchemy = ">=1.4"
43+
sqlalchemy = {extras = ["asyncio"], version = ">=1.4"}
4444
strawberry-graphql = ">=0.95"
4545
sentinel = ">=0.3,<1.1"
4646
greenlet = {version = ">=3.0.0rc1", python = ">=3.12"}
4747

4848
[tool.poetry.group.dev.dependencies]
49+
asyncpg = "^0.28.0"
4950
black = ">=22,<24"
5051
importlib-metadata = ">=4.11.1,<7.0.0"
5152
mypy = "1.5.1"
@@ -106,6 +107,8 @@ ignore = [
106107
# we'd want to have consistent docstrings in future
107108
"D",
108109
"ANN001", # missing annotation for function argument self.
110+
"ANN002", # missing annotation for *args.
111+
"ANN003", # missing annotatino for **kwargs.
109112
"ANN101", # missing annotation for self?
110113
# definitely enable these, maybe not in tests
111114
"ANN102",

src/strawberry_sqlalchemy_mapper/loader.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
import logging
12
from collections import defaultdict
2-
from typing import Any, Dict, List, Mapping, Tuple, Union
3+
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
34

45
from sqlalchemy import select, tuple_
56
from sqlalchemy.engine.base import Connection
7+
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession
68
from sqlalchemy.orm import RelationshipProperty, Session
79
from strawberry.dataloader import DataLoader
810

@@ -14,9 +16,37 @@ class StrawberrySQLAlchemyLoader:
1416

1517
_loaders: Dict[RelationshipProperty, DataLoader]
1618

17-
def __init__(self, bind: Union[Session, Connection]) -> None:
19+
def __init__(
20+
self,
21+
bind: Union[Session, Connection, None] = None,
22+
async_bind_factory: Optional[
23+
Callable[[], Union[AsyncSession, AsyncConnection]]
24+
] = None,
25+
) -> None:
1826
self._loaders = {}
19-
self.bind = bind
27+
self._bind = bind
28+
self._async_bind_factory = async_bind_factory
29+
self._logger = logging.getLogger("strawberry_sqlalchemy_mapper")
30+
if bind is None and async_bind_factory is None:
31+
self._logger.warning(
32+
"One of bind or async_bind_factory must be set for loader to function properly."
33+
)
34+
if bind is not None:
35+
# For anyone coming here because of this warning:
36+
# Making blocking database calls from within an async function (the resolver) has
37+
# catastrophic performance implications. Not only will all resolvers be effectively
38+
# serialized, any other coroutines waiting on the event loop (e.g. concurrent requests
39+
# in a web server), will be blocked as well, grinding your entire service to a halt.
40+
self._logger.warning(
41+
"`bind` parameter is deprecated due to performance issues. Use `async_bind_factory` instead."
42+
)
43+
44+
async def _scalars(self, *args, **kwargs):
45+
if self._async_bind_factory:
46+
return await self._async_bind_factory().scalars(*args, **kwargs)
47+
else:
48+
# Deprecated, but supported for now.
49+
return self._bind.scalars(*args, **kwargs)
2050

2151
def loader_for(self, relationship: RelationshipProperty) -> DataLoader:
2252
"""
@@ -35,7 +65,7 @@ async def load_fn(keys: List[Tuple]) -> List[Any]:
3565
)
3666
if relationship.order_by:
3767
query = query.order_by(*relationship.order_by)
38-
rows = self.bind.scalars(query).all()
68+
rows = (await self._scalars(query)).all()
3969

4070
def group_by_remote_key(row: Any) -> Tuple:
4171
return tuple(

tests/conftest.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
from packaging import version
1818
from sqlalchemy import orm
1919
from sqlalchemy.engine import Engine
20+
from sqlalchemy.ext import asyncio
21+
from sqlalchemy.ext.asyncio import AsyncAttrs, create_async_engine
22+
from sqlalchemy.ext.asyncio.engine import AsyncEngine
2023
from testing.postgresql import Postgresql, PostgresqlFactory
2124

2225
SQLA_VERSION = version.parse(sqlalchemy.__version__)
@@ -57,7 +60,11 @@ def postgresql(postgresql_factory) -> Postgresql:
5760
@pytest.fixture(params=SUPPORTED_DBS)
5861
def engine(request) -> Engine:
5962
if request.param == "postgresql":
60-
url = request.getfixturevalue("postgresql").url()
63+
url = (
64+
request.getfixturevalue("postgresql")
65+
.url()
66+
.replace("postgresql://", "postgresql+psycopg2://")
67+
)
6168
else:
6269
raise ValueError("Unsupported database: %s", request.param)
6370
kwargs = {}
@@ -72,6 +79,31 @@ def sessionmaker(engine) -> orm.sessionmaker:
7279
return orm.sessionmaker(autocommit=False, autoflush=False, bind=engine)
7380

7481

82+
@pytest.fixture(params=SUPPORTED_DBS)
83+
def async_engine(request) -> AsyncEngine:
84+
if request.param == "postgresql":
85+
url = (
86+
request.getfixturevalue("postgresql")
87+
.url()
88+
.replace("postgresql://", "postgresql+asyncpg://")
89+
)
90+
else:
91+
raise ValueError("Unsupported database: %s", request.param)
92+
kwargs = {}
93+
if not SQLA2:
94+
kwargs["future"] = True
95+
engine = create_async_engine(url, **kwargs)
96+
return engine
97+
98+
99+
@pytest.fixture
100+
def async_sessionmaker(async_engine) -> asyncio.async_sessionmaker:
101+
return asyncio.async_sessionmaker(async_engine)
102+
103+
75104
@pytest.fixture
76105
def base():
77-
return orm.declarative_base()
106+
class Base(AsyncAttrs, orm.DeclarativeBase):
107+
pass
108+
109+
return Base

tests/test_loader.py

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ class Department(base):
5858

5959
def test_loader_init():
6060
loader = StrawberrySQLAlchemyLoader(bind=None)
61-
assert loader.bind is None
61+
assert loader._bind is None
6262
assert loader._loaders == {}
6363

6464

@@ -99,14 +99,49 @@ async def test_loader_for(engine, base, sessionmaker, many_to_one_tables):
9999
assert department.name == "d2"
100100

101101
loader = base_loader.loader_for(Department.employees.property)
102-
key = tuple(
102+
103+
employees = await loader.load((d2.id,))
104+
assert {e.name for e in employees} == {"e1"}
105+
106+
107+
@pytest.mark.asyncio
108+
async def test_loader_with_async_session(
109+
async_engine, base, async_sessionmaker, many_to_one_tables
110+
):
111+
Employee, Department = many_to_one_tables
112+
async with async_engine.begin() as conn:
113+
await conn.run_sync(base.metadata.create_all)
114+
115+
async with async_sessionmaker() as session:
116+
e1 = Employee(name="e1")
117+
e2 = Employee(name="e2")
118+
d1 = Department(name="d1")
119+
d2 = Department(name="d2")
120+
session.add(e1)
121+
session.add(e2)
122+
session.add(d1)
123+
session.add(d2)
124+
await session.flush()
125+
126+
e1.department = d2
127+
e2.department = d1
128+
await session.commit()
129+
d2_id = await d2.awaitable_attrs.id
130+
department_loader_key = tuple(
103131
[
104-
getattr(d2, local.key)
105-
for local, _ in Department.employees.property.local_remote_pairs
132+
await getattr(e1.awaitable_attrs, local.key)
133+
for local, _ in Employee.department.property.local_remote_pairs
106134
]
107135
)
108-
employees = await loader.load((d2.id,))
109-
assert {e.name for e in employees} == {"e1"}
136+
base_loader = StrawberrySQLAlchemyLoader(async_bind_factory=async_sessionmaker)
137+
loader = base_loader.loader_for(Employee.department.property)
138+
139+
department = await loader.load(department_loader_key)
140+
assert department.name == "d2"
141+
142+
loader = base_loader.loader_for(Department.employees.property)
143+
employees = await loader.load((d2_id,))
144+
assert {e.name for e in employees} == {"e1"}
110145

111146

112147
@pytest.mark.xfail

0 commit comments

Comments
 (0)