Skip to content

Commit b78155f

Browse files
authored
Improve ext.asyncio.session (#52)
* Improve `ext.asyncio.session` * Add `AsyncSession.delete()` to stubs * Fix bind and get_bind types
1 parent 6ee174c commit b78155f

File tree

1 file changed

+109
-45
lines changed

1 file changed

+109
-45
lines changed
Lines changed: 109 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,72 @@
11
from typing import Any
22
from typing import Callable
3+
from typing import ContextManager
4+
from typing import Iterable
5+
from typing import Iterator
36
from typing import Mapping
47
from typing import Optional
8+
from typing import Sequence
59
from typing import TypeVar
10+
from typing import Union
611

7-
from . import engine as engine
8-
from .base import StartableContext as StartableContext
9-
from .engine import AsyncEngine as AsyncEngine
10-
from ... import util as util
11-
from ...engine import Result as Result
12-
from ...orm import Session as Session
13-
from ...sql import Executable as Executable
14-
from ...util.concurrency import greenlet_spawn as greenlet_spawn
12+
from .base import StartableContext
13+
from .engine import AsyncConnection
14+
from .engine import AsyncEngine
15+
from .result import AsyncResult
16+
from ... import util
17+
from ...engine import Result
18+
from ...engine.base import _ExecutionOptions
19+
from ...orm import Session
20+
from ...orm.session import _IdentityMap
21+
from ...sql import ClauseElement
22+
from ...sql import Executable
1523

1624
_T = TypeVar("_T")
25+
_TAsyncSession = TypeVar("_TAsyncSession", bound=AsyncSession)
26+
_TAsyncSessionTransaction = TypeVar(
27+
"_TAsyncSessionTransaction", bound=AsyncSessionTransaction
28+
)
1729

1830
class AsyncSession:
1931
dispatch: Any = ...
2032
bind: Any = ...
2133
binds: Any = ...
22-
sync_session: Any = ...
34+
sync_session: Session = ...
2335
def __init__(
2436
self,
25-
bind: AsyncEngine = ...,
26-
binds: Mapping[object, AsyncEngine] = ...,
37+
bind: Optional[Union[AsyncConnection, AsyncEngine]] = ...,
38+
binds: Optional[
39+
Mapping[object, Union[AsyncConnection, AsyncEngine]]
40+
] = ...,
2741
**kw: Any,
2842
) -> None: ...
2943
async def refresh(
3044
self,
3145
instance: Any,
3246
attribute_names: Optional[Any] = ...,
3347
with_for_update: Optional[Any] = ...,
34-
): ...
48+
) -> None: ...
3549
async def run_sync(
3650
self, fn: Callable[..., _T], *arg: Any, **kw: Any
3751
) -> _T: ...
3852
async def execute(
3953
self,
4054
statement: Executable,
41-
params: Optional[Mapping] = ...,
42-
execution_options: Mapping = ...,
43-
bind_arguments: Optional[Mapping] = ...,
55+
params: Optional[
56+
Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]
57+
] = ...,
58+
execution_options: Optional[_ExecutionOptions] = ...,
59+
bind_arguments: Optional[Mapping[str, Any]] = ...,
4460
**kw: Any,
4561
) -> Result: ...
4662
async def scalar(
4763
self,
4864
statement: Executable,
49-
params: Optional[Mapping] = ...,
50-
execution_options: Mapping = ...,
51-
bind_arguments: Optional[Mapping] = ...,
65+
params: Optional[
66+
Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]
67+
] = ...,
68+
execution_options: _ExecutionOptions = ...,
69+
bind_arguments: Optional[Mapping[str, Any]] = ...,
5270
**kw: Any,
5371
) -> Any: ...
5472
async def get(
@@ -59,47 +77,93 @@ class AsyncSession:
5977
populate_existing: bool = ...,
6078
with_for_update: Optional[Any] = ...,
6179
identity_token: Optional[Any] = ...,
62-
): ...
80+
) -> Any: ...
6381
async def stream(
6482
self,
6583
statement: Any,
66-
params: Optional[Any] = ...,
67-
execution_options: Any = ...,
68-
bind_arguments: Optional[Any] = ...,
84+
params: Optional[
85+
Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]
86+
] = ...,
87+
execution_options: _ExecutionOptions = ...,
88+
bind_arguments: Optional[Mapping[str, Any]] = ...,
6989
**kw: Any,
70-
): ...
71-
async def merge(self, instance: Any, load: bool = ...): ...
90+
) -> AsyncResult: ...
91+
async def delete(self, instance: Any) -> None: ...
92+
async def merge(self, instance: _T, load: bool = ...) -> _T: ...
7293
async def flush(self, objects: Optional[Any] = ...) -> None: ...
73-
async def connection(self): ...
74-
def begin(self, **kw: Any): ...
75-
def begin_nested(self, **kw: Any): ...
76-
async def rollback(self): ...
77-
async def commit(self): ...
78-
async def close(self): ...
94+
async def connection(self) -> AsyncConnection: ...
95+
def begin(self, **kw: Any) -> AsyncSessionTransaction: ...
96+
def begin_nested(self, **kw: Any) -> AsyncSessionTransaction: ...
97+
async def rollback(self) -> None: ...
98+
async def commit(self) -> None: ...
99+
async def close(self) -> None: ...
79100
@classmethod
80-
async def close_all(self): ...
81-
async def __aenter__(self): ...
101+
async def close_all(self) -> None: ...
102+
async def __aenter__(self: _TAsyncSession) -> _TAsyncSession: ...
82103
async def __aexit__(
83104
self, type_: Any, value: Any, traceback: Any
84105
) -> None: ...
106+
# copied via create_proxy_methods
107+
def __contains__(self, instance: Any) -> bool: ...
108+
def __iter__(self) -> Iterator[Any]: ...
109+
def add(self, instance: Any, _warn: bool = ...) -> None: ...
110+
def add_all(self, instances: Any) -> None: ...
111+
def expire(
112+
self, instance: Any, attribute_names: Optional[Iterable[str]] = ...
113+
) -> None: ...
114+
def expire_all(self) -> None: ...
115+
def expunge(self, instance: Any) -> None: ...
116+
def expunge_all(self) -> None: ...
117+
def get_bind(
118+
self,
119+
mapper: Optional[Any] = ...,
120+
clause: Optional[ClauseElement] = ...,
121+
bind: Optional[Union[AsyncConnection, AsyncEngine]] = ...,
122+
_sa_skip_events: Optional[Any] = ...,
123+
_sa_skip_for_implicit_returning: bool = ...,
124+
) -> Union[AsyncConnection, AsyncEngine]: ...
125+
def is_modified(
126+
self, instance: Any, include_collections: bool = ...
127+
) -> bool: ...
128+
def in_transaction(self) -> bool: ...
129+
@property
130+
def dirty(self) -> util.IdentitySet[Any]: ...
131+
@property
132+
def deleted(self) -> util.IdentitySet[Any]: ...
133+
@property
134+
def new(self) -> util.IdentitySet[Any]: ...
135+
identity_map: _IdentityMap
136+
@property
137+
def is_active(self) -> bool: ...
138+
autoflush: bool
139+
@property
140+
def no_autoflush(
141+
self: _TAsyncSession,
142+
) -> ContextManager[_TAsyncSession]: ...
143+
@util.memoized_property
144+
def info(self) -> Mapping[Any, Any]: ...
85145

86146
class _AsyncSessionContextManager:
87-
async_session: Any = ...
88-
def __init__(self, async_session: Any) -> None: ...
89-
trans: Any = ...
90-
async def __aenter__(self): ...
147+
async_session: AsyncSession = ...
148+
trans: AsyncSessionTransaction = ...
149+
def __init__(self, async_session: AsyncSession) -> None: ...
150+
async def __aenter__(self) -> AsyncSession: ...
91151
async def __aexit__(
92152
self, type_: Any, value: Any, traceback: Any
93153
) -> None: ...
94154

95155
class AsyncSessionTransaction(StartableContext):
96-
session: Any = ...
97-
nested: Any = ...
98-
sync_transaction: Any = ...
99-
def __init__(self, session: Any, nested: bool = ...) -> None: ...
156+
session: AsyncSession = ...
157+
nested: bool = ...
158+
sync_transaction: Optional[Any] = ...
159+
def __init__(self, session: AsyncSession, nested: bool = ...) -> None: ...
100160
@property
101-
def is_active(self): ...
102-
async def rollback(self) -> None: ...
103-
async def commit(self) -> None: ...
104-
async def start(self): ...
105-
async def __aexit__(self, type_: Any, value: Any, traceback: Any): ...
161+
def is_active(self) -> bool: ...
162+
async def rollback(self) -> Optional[AsyncSessionTransaction]: ...
163+
async def commit(self) -> Optional[AsyncSessionTransaction]: ...
164+
async def start(
165+
self: _TAsyncSessionTransaction,
166+
) -> _TAsyncSessionTransaction: ...
167+
async def __aexit__(
168+
self, type_: Any, value: Any, traceback: Any
169+
) -> None: ...

0 commit comments

Comments
 (0)