Skip to content

Commit 23cd151

Browse files
authored
Improve engine.base (#90)
* Improve `engine.base` * Updates based on PR feedback
1 parent 602cdf3 commit 23cd151

File tree

1 file changed

+173
-110
lines changed

1 file changed

+173
-110
lines changed

sqlalchemy-stubs/engine/base.pyi

Lines changed: 173 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -1,185 +1,248 @@
11
from typing import Any
2+
from typing import Dict
3+
from typing import List
24
from typing import Optional
5+
from typing import TypeVar
6+
from typing import Union
37

8+
from typing_extensions import Literal
9+
from typing_extensions import Protocol
10+
11+
from .cursor import CursorResult
12+
from .interfaces import _DBAPIConnection
13+
from .interfaces import _DBAPICursor
414
from .interfaces import Connectable as Connectable
5-
from .interfaces import ExceptionContext as ExceptionContext
6-
from .. import exc as exc
7-
from .. import inspection as inspection
8-
from .. import log as log
9-
from .. import util as util
10-
from ..sql import compiler as compiler
15+
from .interfaces import Dialect
16+
from .interfaces import ExceptionContext
17+
from .interfaces import ExecutionContext
18+
from .url import URL
19+
from .. import log
20+
from .. import util
21+
from ..exc import StatementError
22+
from ..pool import Pool
23+
24+
_T = TypeVar("_T")
25+
_T_co = TypeVar("_T_co", covariant=True)
26+
_T_contra = TypeVar("_T_contra", contravariant=True)
27+
_TConnection = TypeVar("_TConnection", bound=Connection)
28+
_TTransaction = TypeVar("_TTransaction", bound=Transaction)
29+
_TEngine = TypeVar("_TEngine", bound=Engine)
30+
31+
_ExecutionOptions: util.immutabledict[Any, Any]
1132

12-
_ExecutionOptions = Any
33+
class _ConnectionCallable(Protocol[_T_contra, _T_co]):
34+
def __call__(
35+
self, __connection: _T_contra, *args: Any, **kwargs: Any
36+
) -> _T_co: ...
1337

1438
class Connection(Connectable):
15-
engine: Any = ...
16-
dialect: Any = ...
39+
engine: Engine = ...
40+
dialect: Dialect = ...
1741
should_close_with_result: bool = ...
1842
dispatch: Any = ...
1943
def __init__(
2044
self,
21-
engine: Any,
22-
connection: Optional[Any] = ...,
45+
engine: Engine,
46+
connection: Optional[_DBAPIConnection] = ...,
2347
close_with_result: bool = ...,
2448
_branch_from: Optional[Any] = ...,
25-
_execution_options: Optional[Any] = ...,
49+
_execution_options: Optional[Dict[str, Any]] = ...,
2650
_dispatch: Optional[Any] = ...,
2751
_has_events: Optional[Any] = ...,
2852
) -> None: ...
29-
def schema_for_object(self, obj: Any): ...
30-
def __enter__(self): ...
53+
def schema_for_object(self, obj: Any) -> str: ...
54+
def __enter__(self: _TConnection) -> _TConnection: ...
3155
def __exit__(self, type_: Any, value: Any, traceback: Any) -> None: ...
32-
def execution_options(self, **opt: Any): ...
33-
def get_execution_options(self): ...
56+
def execution_options(self: _TConnection, **opt: Any) -> _TConnection: ...
57+
def get_execution_options(self) -> Dict[str, Any]: ...
3458
@property
35-
def closed(self): ...
59+
def closed(self) -> bool: ...
3660
@property
37-
def invalidated(self): ...
61+
def invalidated(self) -> bool: ...
3862
@property
39-
def connection(self): ...
40-
def get_isolation_level(self): ...
63+
def connection(self) -> _DBAPIConnection: ...
64+
def get_isolation_level(self) -> Any: ...
4165
@property
42-
def default_isolation_level(self): ...
66+
def default_isolation_level(self) -> Any: ...
4367
@property
44-
def info(self): ...
45-
def connect(self, close_with_result: bool = ...): ...
46-
def invalidate(self, exception: Optional[Any] = ...): ...
68+
def info(self) -> Dict[Any, Any]: ...
69+
def connect(self: _TConnection, close_with_result: bool = ...) -> _TConnection: ... # type: ignore[override]
70+
def invalidate(self, exception: Optional[Any] = ...) -> None: ...
4771
def detach(self) -> None: ...
48-
def begin(self): ...
49-
def begin_nested(self): ...
50-
def begin_twophase(self, xid: Optional[Any] = ...): ...
51-
def recover_twophase(self): ...
52-
def rollback_prepared(self, xid: Any, recover: bool = ...) -> None: ...
53-
def commit_prepared(self, xid: Any, recover: bool = ...) -> None: ...
54-
def in_transaction(self): ...
55-
def in_nested_transaction(self): ...
56-
def get_transaction(self): ...
57-
def get_nested_transaction(self): ...
72+
def begin(self) -> Optional[Transaction]: ...
73+
def begin_nested(self) -> NestedTransaction: ...
74+
def begin_twophase(
75+
self, xid: Optional[str] = ...
76+
) -> TwoPhaseTransaction: ...
77+
def recover_twophase(self) -> None: ...
78+
def rollback_prepared(self, xid: str, recover: bool = ...) -> None: ...
79+
def commit_prepared(self, xid: str, recover: bool = ...) -> None: ...
80+
def in_transaction(self) -> bool: ...
81+
def in_nested_transaction(self) -> bool: ...
82+
def get_transaction(self) -> Optional[Transaction]: ...
83+
def get_nested_transaction(self) -> Optional[NestedTransaction]: ...
5884
def close(self) -> None: ...
59-
def scalar(self, object_: Any, *multiparams: Any, **params: Any): ...
60-
def execute(self, statement: Any, *multiparams: Any, **params: Any): ...
85+
def scalar(
86+
self, object_: Any, *multiparams: Any, **params: Any
87+
) -> Any: ...
88+
def execute(self, statement: Any, *multiparams: Any, **params: Any) -> CursorResult: ... # type: ignore[override]
6189
def exec_driver_sql(
6290
self,
63-
statement: Any,
91+
statement: str,
6492
parameters: Optional[Any] = ...,
6593
execution_options: Optional[Any] = ...,
66-
): ...
67-
def transaction(self, callable_: Any, *args: Any, **kwargs: Any): ...
68-
def run_callable(self, callable_: Any, *args: Any, **kwargs: Any): ...
94+
) -> CursorResult: ...
95+
def transaction(
96+
self: _TConnection,
97+
callable_: _ConnectionCallable[_TConnection, _T],
98+
*args: Any,
99+
**kwargs: Any,
100+
) -> _T: ...
101+
def run_callable(
102+
self: _TConnection,
103+
callable_: _ConnectionCallable[_TConnection, _T],
104+
*args: Any,
105+
**kwargs: Any,
106+
) -> _T: ...
69107

70108
class ExceptionContextImpl(ExceptionContext):
71-
engine: Any = ...
72-
connection: Any = ...
73-
sqlalchemy_exception: Any = ...
74-
original_exception: Any = ...
75-
execution_context: Any = ...
76-
statement: Any = ...
77-
parameters: Any = ...
78-
is_disconnect: Any = ...
79-
invalidate_pool_on_disconnect: Any = ...
109+
engine: Engine = ...
110+
connection: Connection = ...
111+
sqlalchemy_exception: Optional[StatementError] = ...
112+
original_exception: BaseException = ...
113+
execution_context: Optional[ExecutionContext] = ...
114+
statement: Optional[str] = ...
115+
parameters: Optional[Any] = ...
116+
is_disconnect: bool = ...
117+
invalidate_pool_on_disconnect: bool = ...
80118
def __init__(
81119
self,
82-
exception: Any,
83-
sqlalchemy_exception: Any,
84-
engine: Any,
85-
connection: Any,
86-
cursor: Any,
87-
statement: Any,
88-
parameters: Any,
89-
context: Any,
90-
is_disconnect: Any,
91-
invalidate_pool_on_disconnect: Any,
120+
exception: BaseException,
121+
sqlalchemy_exception: Optional[StatementError],
122+
engine: Engine,
123+
connection: Optional[Connection],
124+
cursor: Optional[_DBAPICursor],
125+
statement: Optional[str],
126+
parameters: Optional[Any],
127+
context: Optional[ExecutionContext],
128+
is_disconnect: bool,
129+
invalidate_pool_on_disconnect: bool,
92130
) -> None: ...
93131

94132
class Transaction:
95-
def __init__(self, connection: Any) -> None: ...
133+
def __init__(self, connection: Connection) -> None: ...
96134
@property
97-
def is_valid(self): ...
135+
def is_valid(self) -> bool: ...
98136
def close(self) -> None: ...
99137
def rollback(self) -> None: ...
100138
def commit(self) -> None: ...
101-
def __enter__(self): ...
139+
def __enter__(self: _TTransaction) -> _TTransaction: ...
102140
def __exit__(self, type_: Any, value: Any, traceback: Any) -> None: ...
103141

104142
class MarkerTransaction(Transaction):
105-
connection: Any = ...
106-
def __init__(self, connection: Any) -> None: ...
143+
connection: Connection = ...
144+
def __init__(self, connection: Connection) -> None: ...
107145
@property
108-
def is_active(self): ...
146+
def is_active(self) -> bool: ...
109147

110148
class RootTransaction(Transaction):
111-
connection: Any = ...
149+
connection: Connection = ...
112150
is_active: bool = ...
113-
def __init__(self, connection: Any) -> None: ...
151+
def __init__(self, connection: Connection) -> None: ...
114152

115153
class NestedTransaction(Transaction):
116-
connection: Any = ...
154+
connection: Connection = ...
117155
is_active: bool = ...
118-
def __init__(self, connection: Any) -> None: ...
156+
def __init__(self, connection: Connection) -> None: ...
119157

120158
class TwoPhaseTransaction(RootTransaction):
121-
xid: Any = ...
122-
def __init__(self, connection: Any, xid: Any) -> None: ...
159+
xid: str = ...
160+
def __init__(self, connection: Connection, xid: str) -> None: ...
123161
def prepare(self) -> None: ...
124162

125163
class Engine(Connectable, log.Identified):
126-
pool: Any = ...
127-
url: Any = ...
128-
dialect: Any = ...
129-
logging_name: Any = ...
130-
echo: Any = ...
131-
hide_parameters: Any = ...
164+
pool: Pool = ...
165+
url: URL = ...
166+
dialect: Dialect = ...
167+
logging_name: Optional[str] = ...
168+
echo: Optional[Union[bool, Literal["debug"]]] = ...
169+
hide_parameters: bool = ...
132170
def __init__(
133171
self,
134-
pool: Any,
135-
dialect: Any,
136-
url: Any,
137-
logging_name: Optional[Any] = ...,
138-
echo: Optional[Any] = ...,
172+
pool: Pool,
173+
dialect: Dialect,
174+
url: URL,
175+
logging_name: Optional[str] = ...,
176+
echo: Optional[Union[bool, Literal["debug"]]] = ...,
139177
query_cache_size: int = ...,
140-
execution_options: Optional[Any] = ...,
178+
execution_options: Optional[Dict[str, Any]] = ...,
141179
hide_parameters: bool = ...,
142180
) -> None: ...
143181
@property
144-
def engine(self): ...
182+
def engine(self: _TEngine) -> _TEngine: ... # type: ignore[override]
145183
def clear_compiled_cache(self) -> None: ...
146184
def update_execution_options(self, **opt: Any) -> None: ...
147-
def execution_options(self, **opt: Any): ...
148-
def get_execution_options(self): ...
185+
def execution_options(self, **opt: Any) -> OptionEngine: ...
186+
def get_execution_options(self) -> Dict[str, Any]: ...
149187
@property
150-
def name(self): ...
188+
def name(self) -> str: ...
151189
@property
152-
def driver(self): ...
190+
def driver(self) -> str: ...
153191
def dispose(self) -> None: ...
154192
class _trans_ctx:
155-
conn: Any = ...
156-
transaction: Any = ...
157-
close_with_result: Any = ...
193+
conn: Connection = ...
194+
transaction: Transaction = ...
195+
close_with_result: bool = ...
158196
def __init__(
159-
self, conn: Any, transaction: Any, close_with_result: Any
197+
self,
198+
conn: Connection,
199+
transaction: Transaction,
200+
close_with_result: bool,
160201
) -> None: ...
161-
def __enter__(self): ...
202+
def __enter__(self) -> Connection: ...
162203
def __exit__(self, type_: Any, value: Any, traceback: Any) -> None: ...
163-
def begin(self, close_with_result: bool = ...): ...
164-
def transaction(self, callable_: Any, *args: Any, **kwargs: Any): ...
165-
def run_callable(self, callable_: Any, *args: Any, **kwargs: Any): ...
166-
def execute(self, statement: Any, *multiparams: Any, **params: Any): ...
167-
def scalar(self, statement: Any, *multiparams: Any, **params: Any): ...
168-
def connect(self, close_with_result: bool = ...): ...
204+
def begin(self, close_with_result: bool = ...) -> _trans_ctx: ...
205+
def transaction(
206+
self,
207+
callable_: _ConnectionCallable[_TConnection, _T],
208+
*args: Any,
209+
**kwargs: Any,
210+
) -> _T: ...
211+
def run_callable(
212+
self,
213+
callable_: _ConnectionCallable[_TConnection, _T],
214+
*args: Any,
215+
**kwargs: Any,
216+
) -> _T: ...
217+
def execute( # type: ignore[override]
218+
self, statement: Any, *multiparams: Any, **params: Any
219+
) -> CursorResult: ...
220+
def scalar( # type: ignore[override]
221+
self, statement: Any, *multiparams: Any, **params: Any
222+
) -> Any: ...
223+
def connect(self, close_with_result: bool = ...) -> Connection: ... # type: ignore[override]
169224
def table_names(
170-
self, schema: Optional[Any] = ..., connection: Optional[Any] = ...
171-
): ...
172-
def has_table(self, table_name: Any, schema: Optional[Any] = ...): ...
173-
def raw_connection(self, _connection: Optional[Any] = ...): ...
225+
self,
226+
schema: Optional[str] = ...,
227+
connection: Optional[Connection] = ...,
228+
) -> List[str]: ...
229+
def has_table(
230+
self, table_name: str, schema: Optional[str] = ...
231+
) -> bool: ...
232+
def raw_connection(
233+
self, _connection: Optional[Connection] = ...
234+
) -> _DBAPIConnection: ...
174235

175236
class OptionEngineMixin:
176-
url: Any = ...
177-
dialect: Any = ...
178-
logging_name: Any = ...
179-
echo: Any = ...
180-
hide_parameters: Any = ...
237+
url: URL = ...
238+
dialect: Dialect = ...
239+
logging_name: Optional[str] = ...
240+
echo: Optional[Union[bool, Literal["debug"]]] = ...
241+
hide_parameters: bool = ...
181242
dispatch: Any = ...
182-
def __init__(self, proxied: Any, execution_options: Any) -> None: ...
183-
pool: Any = ...
243+
pool: Pool = ...
244+
def __init__(
245+
self, proxied: Engine, execution_options: Dict[str, Any]
246+
) -> None: ...
184247

185248
class OptionEngine(OptionEngineMixin, Engine): ...

0 commit comments

Comments
 (0)