Skip to content

Commit 326c37e

Browse files
authored
Improve proxies. (#149)
* improve proxies. * improve names of type only classes.
1 parent fe8a8a1 commit 326c37e

File tree

11 files changed

+485
-310
lines changed

11 files changed

+485
-310
lines changed

setup.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ incremental = True
3737
strict = True
3838
warn_unused_ignores = False
3939

40+
plugins = sqlalchemy.ext.mypy.plugin
4041

4142

4243
[flake8]

sqlalchemy-stubs/engine/base.pyi

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,17 @@ class _ConnectionCallable(Protocol[_T_contra, _T_co]):
3636
self, __connection: _T_contra, *args: Any, **kwargs: Any
3737
) -> _T_co: ...
3838

39-
class Connection(Connectable):
40-
engine: Engine = ...
39+
class _ConnectionTypingCommon:
40+
@property
41+
def closed(self) -> bool: ...
42+
@property
43+
def invalidated(self) -> bool: ...
4144
dialect: Dialect = ...
45+
@property
46+
def default_isolation_level(self) -> Any: ...
47+
48+
class Connection(_ConnectionTypingCommon, Connectable):
49+
engine: Engine = ...
4250
should_close_with_result: bool = ...
4351
dispatch: Any = ...
4452
def __init__(
@@ -57,15 +65,9 @@ class Connection(Connectable):
5765
def execution_options(self: _TConnection, **opt: Any) -> _TConnection: ...
5866
def get_execution_options(self) -> Dict[str, Any]: ...
5967
@property
60-
def closed(self) -> bool: ...
61-
@property
62-
def invalidated(self) -> bool: ...
63-
@property
6468
def connection(self) -> _DBAPIConnection: ...
6569
def get_isolation_level(self) -> Any: ...
6670
@property
67-
def default_isolation_level(self) -> Any: ...
68-
@property
6971
def info(self) -> MutableMapping[Any, Any]: ...
7072
def connect(self: _TConnection, close_with_result: bool = ...) -> _TConnection: ... # type: ignore[override]
7173
def invalidate(self, exception: Optional[Any] = ...) -> None: ...
@@ -161,13 +163,25 @@ class TwoPhaseTransaction(RootTransaction):
161163
def __init__(self, connection: Connection, xid: str) -> None: ...
162164
def prepare(self) -> None: ...
163165

164-
class Engine(Connectable, log.Identified):
166+
class _EngineTypingCommon:
165167
pool: Pool = ...
166168
url: URL = ...
167169
dialect: Dialect = ...
168170
logging_name: Optional[str] = ...
169171
echo: Optional[Union[bool, Literal["debug"]]] = ...
170172
hide_parameters: bool = ...
173+
@property
174+
def name(self) -> str: ...
175+
@property
176+
def driver(self) -> str: ...
177+
def clear_compiled_cache(self) -> None: ...
178+
def update_execution_options(self, **opt: Any) -> None: ...
179+
def get_execution_options(self) -> Dict[str, Any]: ...
180+
181+
class Engine(_EngineTypingCommon, Connectable, log.Identified):
182+
@property
183+
def engine(self: _TEngine) -> _TEngine: ...
184+
hide_parameters: bool = ...
171185
def __init__(
172186
self,
173187
pool: Pool,
@@ -179,16 +193,7 @@ class Engine(Connectable, log.Identified):
179193
execution_options: Optional[Dict[str, Any]] = ...,
180194
hide_parameters: bool = ...,
181195
) -> None: ...
182-
@property
183-
def engine(self: _TEngine) -> _TEngine: ... # type: ignore[override]
184-
def clear_compiled_cache(self) -> None: ...
185-
def update_execution_options(self, **opt: Any) -> None: ...
186196
def execution_options(self, **opt: Any) -> OptionEngine: ...
187-
def get_execution_options(self) -> Dict[str, Any]: ...
188-
@property
189-
def name(self) -> str: ...
190-
@property
191-
def driver(self) -> str: ...
192197
def dispose(self) -> None: ...
193198
class _trans_ctx:
194199
conn: Connection = ...

sqlalchemy-stubs/ext/asyncio/__init__.pyi

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,8 @@ from .events import AsyncSessionEvents as AsyncSessionEvents
77
from .result import AsyncMappingResult as AsyncMappingResult
88
from .result import AsyncResult as AsyncResult
99
from .result import AsyncScalarResult as AsyncScalarResult
10+
from .scoping import async_scoped_session as async_scoped_session
11+
from .session import async_object_session as async_object_session
12+
from .session import async_session as async_session
1013
from .session import AsyncSession as AsyncSession
1114
from .session import AsyncSessionTransaction as AsyncSessionTransaction

sqlalchemy-stubs/ext/asyncio/engine.pyi

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,28 +10,35 @@ from typing import TypeVar
1010
from .base import ProxyComparable
1111
from .base import StartableContext
1212
from .result import AsyncResult
13+
from ...engine import Dialect
1314
from ...engine import Result
1415
from ...engine import Transaction
16+
from ...engine.base import _ConnectionTypingCommon
17+
from ...engine.base import _EngineTypingCommon
1518
from ...future import Connection
1619
from ...future import Engine
1720
from ...sql import Executable
1821

1922
_TAsyncConnection = TypeVar("_TAsyncConnection", bound=AsyncConnection)
2023
_TAsyncTransaction = TypeVar("_TAsyncTransaction", bound=AsyncTransaction)
24+
_TEngine = TypeVar("_TEngine", bound=AsyncEngine)
2125

2226
def create_async_engine(*arg: Any, **kw: Any) -> AsyncEngine: ...
2327

2428
class AsyncConnectable: ...
2529

2630
class AsyncConnection(
27-
ProxyComparable, StartableContext["AsyncConnection"], AsyncConnectable
31+
_ConnectionTypingCommon,
32+
ProxyComparable,
33+
StartableContext["AsyncConnection"],
34+
AsyncConnectable,
2835
):
2936
# copied from future.Connection via create_proxy_methods
3037
@property
3138
def closed(self) -> bool: ...
3239
@property
3340
def invalidated(self) -> bool: ...
34-
dialect: Any
41+
dialect: Dialect
3542
@property
3643
def default_isolation_level(self) -> Any: ...
3744
# end copied
@@ -99,12 +106,9 @@ class AsyncConnection(
99106
self, type_: Any, value: Any, traceback: Any
100107
) -> None: ...
101108

102-
class AsyncEngine(ProxyComparable, AsyncConnectable):
103-
# copied from future.Engine by create_proxy_methods
104-
def clear_compiled_cache(self) -> None: ...
105-
def update_execution_options(self, **opt: Any) -> None: ...
106-
def get_execution_options(self) -> Mapping[Any, Any]: ...
107-
# end copied
109+
class AsyncEngine(_EngineTypingCommon, ProxyComparable, AsyncConnectable):
110+
@property
111+
def engine(self: _TEngine) -> _TEngine: ...
108112
class _trans_ctx(StartableContext[AsyncConnection]):
109113
conn: AsyncConnection = ...
110114
def __init__(self, conn: AsyncConnection) -> None: ...
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from typing import Any
2+
from typing import Callable
3+
4+
from .session import _AsyncSessionTypingCommon
5+
from .session import AsyncSession
6+
from ...orm.scoping import ScopedSessionMixin
7+
from ...util import ScopedRegistry
8+
9+
class async_scoped_session(
10+
_AsyncSessionTypingCommon, ScopedSessionMixin[AsyncSession]
11+
):
12+
session_factory: Callable[..., AsyncSession] = ...
13+
registry: ScopedRegistry = ...
14+
def __init__(
15+
self,
16+
session_factory: Callable[..., AsyncSession],
17+
scopefunc: Callable[..., Any] = ...,
18+
) -> None: ...
19+
async def remove(self) -> None: ...

0 commit comments

Comments
 (0)