|
1 | | -from typing import Any, Mapping, Optional, Sequence, TypeVar, Union |
| 1 | +from typing import Any, Mapping, Optional, Sequence, TypeVar, Union, overload |
2 | 2 |
|
3 | 3 | from sqlalchemy import util |
4 | 4 | from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession |
5 | | -from sqlalchemy.ext.asyncio import engine |
6 | 5 | from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine |
7 | | -from sqlalchemy.util.concurrency import greenlet_spawn |
8 | 6 | from sqlmodel.sql.base import Executable |
9 | 7 |
|
10 | | -from ...engine.result import ScalarResult |
| 8 | +from ...engine.result import Result, ScalarResult |
11 | 9 | from ...orm.session import Session |
12 | | -from ...sql.expression import Select |
| 10 | +from ...sql.expression import Select, SelectOfScalar |
13 | 11 |
|
14 | 12 | _T = TypeVar("_T") |
15 | 13 |
|
16 | 14 |
|
17 | 15 | class AsyncSession(_AsyncSession): |
18 | | - sync_session: Session |
19 | | - |
20 | 16 | def __init__( |
21 | 17 | self, |
22 | 18 | bind: Optional[Union[AsyncConnection, AsyncEngine]] = None, |
23 | 19 | binds: Optional[Mapping[object, Union[AsyncConnection, AsyncEngine]]] = None, |
24 | 20 | **kw: Any, |
25 | 21 | ): |
26 | | - # All the same code of the original AsyncSession |
27 | | - kw["future"] = True |
28 | | - if bind: |
29 | | - self.bind = bind |
30 | | - bind = engine._get_sync_engine_or_connection(bind) # type: ignore |
31 | | - |
32 | | - if binds: |
33 | | - self.binds = binds |
34 | | - binds = { |
35 | | - key: engine._get_sync_engine_or_connection(b) # type: ignore |
36 | | - for key, b in binds.items() |
37 | | - } |
38 | | - |
39 | | - self.sync_session = self._proxied = self._assign_proxied( # type: ignore |
40 | | - Session(bind=bind, binds=binds, **kw) # type: ignore |
41 | | - ) |
| 22 | + opts = dict(expire_on_commit=False) |
| 23 | + super().__init__(bind, binds, sync_session_class=Session, **{**opts, **kw}) |
42 | 24 |
|
| 25 | + @overload |
43 | 26 | async def exec( |
44 | 27 | self, |
45 | | - statement: Union[Select[_T], Executable[_T]], |
| 28 | + statement: Select[_T], |
| 29 | + *, |
46 | 30 | params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, |
47 | | - execution_options: Mapping[Any, Any] = util.EMPTY_DICT, |
| 31 | + execution_options: Mapping[str, Any] = util.EMPTY_DICT, |
48 | 32 | bind_arguments: Optional[Mapping[str, Any]] = None, |
| 33 | + _parent_execute_state: Optional[Any] = None, |
| 34 | + _add_event: Optional[Any] = None, |
| 35 | + **kw: Any, |
| 36 | + ) -> Result[_T]: |
| 37 | + ... |
| 38 | + |
| 39 | + @overload |
| 40 | + async def exec( |
| 41 | + self, |
| 42 | + statement: SelectOfScalar[_T], |
| 43 | + *, |
| 44 | + params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, |
| 45 | + execution_options: Mapping[str, Any] = util.EMPTY_DICT, |
| 46 | + bind_arguments: Optional[Mapping[str, Any]] = None, |
| 47 | + _parent_execute_state: Optional[Any] = None, |
| 48 | + _add_event: Optional[Any] = None, |
49 | 49 | **kw: Any, |
50 | 50 | ) -> ScalarResult[_T]: |
51 | | - # TODO: the documentation says execution_options accepts a dict, but only |
52 | | - # util.immutabledict has the union() method. Is this a bug in SQLAlchemy? |
53 | | - execution_options = execution_options.union({"prebuffer_rows": True}) # type: ignore |
| 51 | + ... |
54 | 52 |
|
55 | | - return await greenlet_spawn( |
56 | | - self.sync_session.exec, |
| 53 | + async def exec( |
| 54 | + self, |
| 55 | + statement: Union[ |
| 56 | + Select[_T], |
| 57 | + SelectOfScalar[_T], |
| 58 | + Executable[_T], |
| 59 | + ], |
| 60 | + *, |
| 61 | + params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, |
| 62 | + execution_options: Mapping[str, Any] = util.EMPTY_DICT, |
| 63 | + bind_arguments: Optional[Mapping[str, Any]] = None, |
| 64 | + _parent_execute_state: Optional[Any] = None, |
| 65 | + _add_event: Optional[Any] = None, |
| 66 | + **kw: Any, |
| 67 | + ) -> Union[Result[_T], ScalarResult[_T]]: |
| 68 | + results = await super().execute( |
57 | 69 | statement, |
58 | 70 | params=params, |
59 | 71 | execution_options=execution_options, |
60 | 72 | bind_arguments=bind_arguments, |
| 73 | + _parent_execute_state=_parent_execute_state, |
| 74 | + _add_event=_add_event, |
61 | 75 | **kw, |
62 | 76 | ) |
| 77 | + if isinstance(statement, SelectOfScalar): |
| 78 | + return results.scalars() # type: ignore |
| 79 | + return results # type: ignore |
0 commit comments