Skip to content

Commit 31f4ab2

Browse files
authored
Merge pull request #5 from smaximoff/async
Use `sqlalchemy.ext.asyncio.AsyncSession` directly
2 parents 5c9b9cc + c730058 commit 31f4ab2

File tree

2 files changed

+52
-30
lines changed

2 files changed

+52
-30
lines changed

sqlmodel/ext/asyncio/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from .session import AsyncSession
2+
3+
__all__ = [
4+
"AsyncSession",
5+
]

sqlmodel/ext/asyncio/session.py

Lines changed: 47 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,62 +1,79 @@
1-
from typing import Any, Mapping, Optional, Sequence, TypeVar, Union
1+
from typing import Any, Mapping, Optional, Sequence, TypeVar, Union, overload
22

33
from sqlalchemy import util
44
from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession
5-
from sqlalchemy.ext.asyncio import engine
65
from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine
7-
from sqlalchemy.util.concurrency import greenlet_spawn
86
from sqlmodel.sql.base import Executable
97

10-
from ...engine.result import ScalarResult
8+
from ...engine.result import Result, ScalarResult
119
from ...orm.session import Session
12-
from ...sql.expression import Select
10+
from ...sql.expression import Select, SelectOfScalar
1311

1412
_T = TypeVar("_T")
1513

1614

1715
class AsyncSession(_AsyncSession):
18-
sync_session: Session
19-
2016
def __init__(
2117
self,
2218
bind: Optional[Union[AsyncConnection, AsyncEngine]] = None,
2319
binds: Optional[Mapping[object, Union[AsyncConnection, AsyncEngine]]] = None,
2420
**kw: Any,
2521
):
26-
# All the same code of the original AsyncSession
27-
kw["future"] = True
28-
if bind:
29-
self.bind = bind
30-
bind = engine._get_sync_engine_or_connection(bind) # type: ignore
31-
32-
if binds:
33-
self.binds = binds
34-
binds = {
35-
key: engine._get_sync_engine_or_connection(b) # type: ignore
36-
for key, b in binds.items()
37-
}
38-
39-
self.sync_session = self._proxied = self._assign_proxied( # type: ignore
40-
Session(bind=bind, binds=binds, **kw) # type: ignore
41-
)
22+
opts = dict(expire_on_commit=False)
23+
super().__init__(bind, binds, sync_session_class=Session, **{**opts, **kw})
4224

25+
@overload
4326
async def exec(
4427
self,
45-
statement: Union[Select[_T], Executable[_T]],
28+
statement: Select[_T],
29+
*,
4630
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
47-
execution_options: Mapping[Any, Any] = util.EMPTY_DICT,
31+
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
4832
bind_arguments: Optional[Mapping[str, Any]] = None,
33+
_parent_execute_state: Optional[Any] = None,
34+
_add_event: Optional[Any] = None,
35+
**kw: Any,
36+
) -> Result[_T]:
37+
...
38+
39+
@overload
40+
async def exec(
41+
self,
42+
statement: SelectOfScalar[_T],
43+
*,
44+
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
45+
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
46+
bind_arguments: Optional[Mapping[str, Any]] = None,
47+
_parent_execute_state: Optional[Any] = None,
48+
_add_event: Optional[Any] = None,
4949
**kw: Any,
5050
) -> ScalarResult[_T]:
51-
# TODO: the documentation says execution_options accepts a dict, but only
52-
# util.immutabledict has the union() method. Is this a bug in SQLAlchemy?
53-
execution_options = execution_options.union({"prebuffer_rows": True}) # type: ignore
51+
...
5452

55-
return await greenlet_spawn(
56-
self.sync_session.exec,
53+
async def exec(
54+
self,
55+
statement: Union[
56+
Select[_T],
57+
SelectOfScalar[_T],
58+
Executable[_T],
59+
],
60+
*,
61+
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
62+
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
63+
bind_arguments: Optional[Mapping[str, Any]] = None,
64+
_parent_execute_state: Optional[Any] = None,
65+
_add_event: Optional[Any] = None,
66+
**kw: Any,
67+
) -> Union[Result[_T], ScalarResult[_T]]:
68+
results = await super().execute(
5769
statement,
5870
params=params,
5971
execution_options=execution_options,
6072
bind_arguments=bind_arguments,
73+
_parent_execute_state=_parent_execute_state,
74+
_add_event=_add_event,
6175
**kw,
6276
)
77+
if isinstance(statement, SelectOfScalar):
78+
return results.scalars() # type: ignore
79+
return results # type: ignore

0 commit comments

Comments
 (0)