Skip to content

Commit 317b0f3

Browse files
authored
Support async sessions in strawberry_sqlalchemy_mapper (#53)
I left a note in the loader's init, but to duplicate that here: Making blocking database calls from within an async function (the resolver) has potentially catastrophic performance implications. Not only will all field resolution be effectively serialized, any other coroutines waiting on the event loop (e.g. concurrent requests in a web server), will be blocked as well, grinding your entire service to a halt. More discussion needs to happen about what we do with the sync bind.
1 parent 9746996 commit 317b0f3

File tree

6 files changed

+187
-15
lines changed

6 files changed

+187
-15
lines changed

RELEASE.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
Release type: minor
2+
3+
Adds support for async sessions. To use:
4+
5+
```python
6+
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
7+
from strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyLoader
8+
9+
url = "postgresql://..."
10+
engine = create_async_engine(url)
11+
sessionmaker = async_sessionmaker(engine)
12+
13+
loader = StrawberrySQLAlchemyLoader(async_bind_factory=sessionmaker)
14+
```

poetry.lock

Lines changed: 55 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,13 @@ version_scheme = "no-guess-dev"
4040

4141
[tool.poetry.dependencies]
4242
python = "^3.8"
43-
sqlalchemy = ">=1.4"
43+
sqlalchemy = {extras = ["asyncio"], version = ">=1.4"}
4444
strawberry-graphql = ">=0.95"
4545
sentinel = ">=0.3,<1.1"
4646
greenlet = {version = ">=3.0.0rc1", python = ">=3.12"}
4747

4848
[tool.poetry.group.dev.dependencies]
49+
asyncpg = "^0.28.0"
4950
black = ">=22,<24"
5051
importlib-metadata = ">=4.11.1,<7.0.0"
5152
mypy = "1.5.1"
@@ -106,6 +107,8 @@ ignore = [
106107
# we'd want to have consistent docstrings in future
107108
"D",
108109
"ANN001", # missing annotation for function argument self.
110+
"ANN002", # missing annotation for *args.
111+
"ANN003", # missing annotatino for **kwargs.
109112
"ANN101", # missing annotation for self?
110113
# definitely enable these, maybe not in tests
111114
"ANN102",

src/strawberry_sqlalchemy_mapper/loader.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,20 @@
1+
import logging
12
from collections import defaultdict
2-
from typing import Any, Dict, List, Mapping, Tuple, Union
3+
from typing import (
4+
Any,
5+
AsyncContextManager,
6+
Callable,
7+
Dict,
8+
List,
9+
Mapping,
10+
Optional,
11+
Tuple,
12+
Union,
13+
)
314

415
from sqlalchemy import select, tuple_
516
from sqlalchemy.engine.base import Connection
17+
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession
618
from sqlalchemy.orm import RelationshipProperty, Session
719
from strawberry.dataloader import DataLoader
820

@@ -14,9 +26,32 @@ class StrawberrySQLAlchemyLoader:
1426

1527
_loaders: Dict[RelationshipProperty, DataLoader]
1628

17-
def __init__(self, bind: Union[Session, Connection]) -> None:
29+
def __init__(
30+
self,
31+
bind: Union[Session, Connection, None] = None,
32+
async_bind_factory: Optional[
33+
Union[
34+
Callable[[], AsyncContextManager[AsyncSession]],
35+
Callable[[], AsyncContextManager[AsyncConnection]],
36+
]
37+
] = None,
38+
) -> None:
1839
self._loaders = {}
19-
self.bind = bind
40+
self._bind = bind
41+
self._async_bind_factory = async_bind_factory
42+
self._logger = logging.getLogger("strawberry_sqlalchemy_mapper")
43+
if bind is None and async_bind_factory is None:
44+
self._logger.warning(
45+
"One of bind or async_bind_factory must be set for loader to function properly."
46+
)
47+
48+
async def _scalars(self, *args, **kwargs):
49+
if self._async_bind_factory:
50+
async with self._async_bind_factory() as bind:
51+
return await bind.scalars(*args, **kwargs)
52+
else:
53+
assert self._bind is not None
54+
return self._bind.scalars(*args, **kwargs)
2055

2156
def loader_for(self, relationship: RelationshipProperty) -> DataLoader:
2257
"""
@@ -35,7 +70,7 @@ async def load_fn(keys: List[Tuple]) -> List[Any]:
3570
)
3671
if relationship.order_by:
3772
query = query.order_by(*relationship.order_by)
38-
rows = self.bind.scalars(query).all()
73+
rows = (await self._scalars(query)).all()
3974

4075
def group_by_remote_key(row: Any) -> Tuple:
4176
return tuple(

tests/conftest.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
from packaging import version
1818
from sqlalchemy import orm
1919
from sqlalchemy.engine import Engine
20+
from sqlalchemy.ext import asyncio
21+
from sqlalchemy.ext.asyncio import AsyncAttrs, create_async_engine
22+
from sqlalchemy.ext.asyncio.engine import AsyncEngine
2023
from testing.postgresql import Postgresql, PostgresqlFactory
2124

2225
SQLA_VERSION = version.parse(sqlalchemy.__version__)
@@ -57,7 +60,11 @@ def postgresql(postgresql_factory) -> Postgresql:
5760
@pytest.fixture(params=SUPPORTED_DBS)
5861
def engine(request) -> Engine:
5962
if request.param == "postgresql":
60-
url = request.getfixturevalue("postgresql").url()
63+
url = (
64+
request.getfixturevalue("postgresql")
65+
.url()
66+
.replace("postgresql://", "postgresql+psycopg2://")
67+
)
6168
else:
6269
raise ValueError("Unsupported database: %s", request.param)
6370
kwargs = {}
@@ -72,6 +79,31 @@ def sessionmaker(engine) -> orm.sessionmaker:
7279
return orm.sessionmaker(autocommit=False, autoflush=False, bind=engine)
7380

7481

82+
@pytest.fixture(params=SUPPORTED_DBS)
83+
def async_engine(request) -> AsyncEngine:
84+
if request.param == "postgresql":
85+
url = (
86+
request.getfixturevalue("postgresql")
87+
.url()
88+
.replace("postgresql://", "postgresql+asyncpg://")
89+
)
90+
else:
91+
raise ValueError("Unsupported database: %s", request.param)
92+
kwargs = {}
93+
if not SQLA2:
94+
kwargs["future"] = True
95+
engine = create_async_engine(url, **kwargs)
96+
return engine
97+
98+
99+
@pytest.fixture
100+
def async_sessionmaker(async_engine) -> asyncio.async_sessionmaker:
101+
return asyncio.async_sessionmaker(async_engine)
102+
103+
75104
@pytest.fixture
76105
def base():
77-
return orm.declarative_base()
106+
class Base(AsyncAttrs, orm.DeclarativeBase):
107+
pass
108+
109+
return Base

tests/test_loader.py

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ class Department(base):
5858

5959
def test_loader_init():
6060
loader = StrawberrySQLAlchemyLoader(bind=None)
61-
assert loader.bind is None
61+
assert loader._bind is None
6262
assert loader._loaders == {}
6363

6464

@@ -99,14 +99,49 @@ async def test_loader_for(engine, base, sessionmaker, many_to_one_tables):
9999
assert department.name == "d2"
100100

101101
loader = base_loader.loader_for(Department.employees.property)
102-
key = tuple(
102+
103+
employees = await loader.load((d2.id,))
104+
assert {e.name for e in employees} == {"e1"}
105+
106+
107+
@pytest.mark.asyncio
108+
async def test_loader_with_async_session(
109+
async_engine, base, async_sessionmaker, many_to_one_tables
110+
):
111+
Employee, Department = many_to_one_tables
112+
async with async_engine.begin() as conn:
113+
await conn.run_sync(base.metadata.create_all)
114+
115+
async with async_sessionmaker() as session:
116+
e1 = Employee(name="e1")
117+
e2 = Employee(name="e2")
118+
d1 = Department(name="d1")
119+
d2 = Department(name="d2")
120+
session.add(e1)
121+
session.add(e2)
122+
session.add(d1)
123+
session.add(d2)
124+
await session.flush()
125+
126+
e1.department = d2
127+
e2.department = d1
128+
await session.commit()
129+
d2_id = await d2.awaitable_attrs.id
130+
department_loader_key = tuple(
103131
[
104-
getattr(d2, local.key)
105-
for local, _ in Department.employees.property.local_remote_pairs
132+
await getattr(e1.awaitable_attrs, local.key)
133+
for local, _ in Employee.department.property.local_remote_pairs
106134
]
107135
)
108-
employees = await loader.load((d2.id,))
109-
assert {e.name for e in employees} == {"e1"}
136+
base_loader = StrawberrySQLAlchemyLoader(async_bind_factory=async_sessionmaker)
137+
loader = base_loader.loader_for(Employee.department.property)
138+
139+
department = await loader.load(department_loader_key)
140+
assert department.name == "d2"
141+
142+
loader = base_loader.loader_for(Department.employees.property)
143+
employees = await loader.load((d2_id,))
144+
assert {e.name for e in employees} == {"e1"}
110145

111146

112147
@pytest.mark.xfail

0 commit comments

Comments
 (0)