Skip to content

Commit b515784

Browse files
committed
Improve typing for sql.type_api classes
1 parent b0afe35 commit b515784

File tree

2 files changed

+83
-69
lines changed

2 files changed

+83
-69
lines changed

sqlalchemy-stubs/sql/sqltypes.pyi

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ from typing import Any
77
from typing import List
88
from typing import Mapping
99
from typing import Optional
10+
from typing import Type
11+
from typing import TypeVar
1012
from typing import Union
1113

1214
from . import coercions as coercions
@@ -35,6 +37,8 @@ from ..util import compat as compat
3537
from ..util import langhelpers as langhelpers
3638
from ..util import pickle as pickle
3739

40+
_U = TypeVar("_U")
41+
3842
class _LookupExpressionAdapter:
3943
class Comparator(TypeEngine.Comparator): ...
4044
comparator_factory: Any = ...
@@ -160,7 +164,6 @@ class _Binary(TypeEngine[bytes]):
160164
def python_type(self): ...
161165
def bind_processor(self, dialect: Any): ...
162166
def result_processor(self, dialect: Any, coltype: Any): ...
163-
def result_processor(self, dialect: Any, coltype: Any): ...
164167
def coerce_compared_value(self, op: Any, value: Any): ...
165168
def get_dbapi_type(self, dbapi: Any): ...
166169

@@ -183,7 +186,7 @@ class SchemaType(SchemaEventTarget):
183186
_create_events: bool = ...,
184187
) -> None: ...
185188
def copy(self, **kw: Any): ...
186-
def adapt(self, impltype: Any, **kw: Any): ...
189+
def adapt(self, __impltype: Type[_U], **kw: Any) -> _U: ...
187190
@property
188191
def bind(self): ...
189192
def create(
@@ -204,7 +207,7 @@ class Enum(Emulated, String, SchemaType):
204207
comparator_factory: Any = ...
205208
def as_generic(self, allow_nulltype: bool = ...): ...
206209
def adapt_to_emulated(self, impltype: Any, **kw: Any): ...
207-
def adapt(self, impltype: Any, **kw: Any): ...
210+
def adapt(self, __impltype: Type[_U], **kw: Any) -> _U: ...
208211
def literal_processor(self, dialect: Any): ...
209212
def bind_processor(self, dialect: Any): ...
210213
def result_processor(self, dialect: Any, coltype: Any): ...
@@ -248,20 +251,21 @@ class Boolean(Emulated, TypeEngine[bool], SchemaType):
248251
class _AbstractInterval(_LookupExpressionAdapter, TypeEngine[Any]):
249252
def coerce_compared_value(self, op: Any, value: Any): ...
250253

251-
class Interval(Emulated, _AbstractInterval, TypeDecorator[timedelta]):
254+
# "comparator_factory" of "_LookupExpressionAdapter" and "TypeDecorator" are incompatible
255+
class Interval(Emulated, _AbstractInterval, TypeDecorator[timedelta]): # type: ignore[misc]
252256
impl: Any = ...
253-
epoch: Any = ...
254-
native: Any = ...
255-
second_precision: Any = ...
256-
day_precision: Any = ...
257+
epoch: datetime = ...
258+
native: bool = ...
259+
second_precision: Optional[float] = ...
260+
day_precision: Optional[float] = ...
257261
def __init__(
258262
self,
259263
native: bool = ...,
260-
second_precision: Optional[Any] = ...,
261-
day_precision: Optional[Any] = ...,
264+
second_precision: Optional[float] = ...,
265+
day_precision: Optional[float] = ...,
262266
) -> None: ...
263267
@property
264-
def python_type(self): ...
268+
def python_type(self) -> Type[timedelta]: ...
265269
def adapt_to_emulated(self, impltype: Any, **kw: Any): ...
266270
def bind_processor(self, dialect: Any): ...
267271
def result_processor(self, dialect: Any, coltype: Any): ...
@@ -270,7 +274,8 @@ class JSON(Indexable, TypeEngine[Union[str, Mapping, List]]):
270274
__visit_name__: str = ...
271275
hashable: bool = ...
272276
NULL: Any = ...
273-
none_as_null: Any = ...
277+
none_as_null: bool = ...
278+
should_evaluate_none: bool = ...
274279
def __init__(self, none_as_null: bool = ...) -> None: ...
275280
class JSONElementType(TypeEngine):
276281
def string_bind_processor(self, dialect: Any): ...
@@ -293,10 +298,6 @@ class JSON(Indexable, TypeEngine[Union[str, Mapping, List]]):
293298
comparator_factory: Any = ...
294299
@property
295300
def python_type(self): ...
296-
@property
297-
def should_evaluate_none(self): ...
298-
@should_evaluate_none.setter
299-
def should_evaluate_none(self, value: Any) -> None: ...
300301
def bind_processor(self, dialect: Any): ...
301302
def result_processor(self, dialect: Any, coltype: Any): ...
302303

sqlalchemy-stubs/sql/type_api.pyi

Lines changed: 66 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
1+
from _typeshed import SupportsLessThan
12
from typing import Any
3+
from typing import Callable
24
from typing import Generic
5+
from typing import Mapping
36
from typing import Optional
7+
from typing import Text
8+
from typing import Tuple
49
from typing import Type
510
from typing import TypeVar
11+
from typing import Union
12+
from typing_extensions import Protocol
613

714
from . import operators as operators
815
from .base import SchemaEventTarget as SchemaEventTarget
@@ -20,6 +27,14 @@ INDEXABLE: Any
2027
TABLEVALUE: Any
2128

2229
_T = TypeVar("_T")
30+
_U = TypeVar("_U")
31+
32+
_TE = TypeVar('_TE', bound=TypeEngine[Any])
33+
_NFE = TypeVar('_NFE', bound=NativeForEmulated)
34+
_TD = TypeVar('_TD', bound=TypeDecorator[Any])
35+
_VT = TypeVar('_VT', bound=Variant[Any])
36+
37+
_SortKeyFunction = Callable[[Any], SupportsLessThan]
2338

2439
class TypeEngine(Traversible, Generic[_T]):
2540
class Comparator(operators.ColumnOperators):
@@ -32,87 +47,85 @@ class TypeEngine(Traversible, Generic[_T]):
3247
def reverse_operate(self, op: Any, other: Any, **kwargs: Any): ...
3348
def __reduce__(self): ...
3449
hashable: bool = ...
35-
comparator_factory: Any = ...
36-
sort_key_function: Any = ...
50+
comparator_factory: Type[Any] = ...
51+
sort_key_function: Optional[_SortKeyFunction] = ...
3752
should_evaluate_none: bool = ...
38-
def evaluates_none(self): ...
39-
def copy(self, **kw: Any): ...
53+
def evaluates_none(self: _TE) -> _TE: ...
54+
def copy(self: _TE, **kw: Any) -> _TE: ...
4055
def compare_against_backend(
4156
self, dialect: Any, conn_type: Any
42-
) -> None: ...
43-
def copy_value(self, value: Any): ...
44-
def literal_processor(self, dialect: Any) -> None: ...
45-
def bind_processor(self, dialect: Any) -> None: ...
46-
def result_processor(self, dialect: Any, coltype: Any) -> None: ...
47-
def column_expression(self, colexpr: Any) -> None: ...
48-
def bind_expression(self, bindvalue: Any) -> None: ...
49-
def compare_values(self, x: Any, y: Any): ...
50-
def get_dbapi_type(self, dbapi: Any) -> None: ...
57+
) -> Any: ...
58+
def copy_value(self, value: _T) -> _T: ...
59+
def literal_processor(self, dialect: Any) -> Optional[Callable[..., Any]]: ...
60+
def bind_processor(self, dialect: Any) -> Optional[Callable[..., Any]]: ...
61+
def result_processor(self, dialect: Any, coltype: Any) -> Optional[Callable[..., Any]]: ...
62+
def column_expression(self, colexpr: Any) -> Any: ...
63+
def bind_expression(self, bindvalue: Any) -> Any: ...
64+
def compare_values(self, x: Any, y: Any) -> bool: ...
65+
def get_dbapi_type(self, dbapi: Any) -> Any: ...
5166
@property
5267
def python_type(self) -> Type[_T]: ...
53-
def with_variant(self, type_: Any, dialect_name: Any): ...
54-
def as_generic(self, allow_nulltype: bool = ...): ...
55-
def dialect_impl(self, dialect: Any): ...
56-
def adapt(self, cls: Any, **kw: Any): ...
57-
def coerce_compared_value(self, op: Any, value: Any): ...
58-
def compile(self, dialect: Optional[Any] = ...): ...
68+
def with_variant(self, type_: Type[TypeEngine[_U]], dialect_name: str) -> Variant[_U]: ...
69+
def as_generic(self, allow_nulltype: bool = ...) -> TypeEngine[Any]: ...
70+
def dialect_impl(self, dialect: Any) -> Type[Any]: ...
71+
def adapt(self, __cls: Type[_U], **kw: Any) -> _U: ...
72+
def coerce_compared_value(self, op: Any, value: Any) -> TypeEngine[Any]: ...
73+
def compile(self, dialect: Optional[Any] = ...) -> Any: ...
5974

6075
class VisitableCheckKWArg(util.EnsureKWArgType, TraversibleType): ...
6176

6277
class UserDefinedType:
6378
__visit_name__: str = ...
6479
ensure_kwarg: str = ...
65-
def coerce_compared_value(self, op: Any, value: Any): ...
80+
def coerce_compared_value(self, op: Any, value: Any) -> Any: ...
6681

6782
class Emulated:
68-
def adapt_to_emulated(self, impltype: Any, **kw: Any): ...
69-
def adapt(self, impltype: Any, **kw: Any): ...
83+
def adapt_to_emulated(self, impltype: Any, **kw: Any) -> Any: ...
84+
def adapt(self, __impltype: Any, **kw: Any) -> Any: ...
7085

7186
class NativeForEmulated:
7287
@classmethod
73-
def adapt_native_to_emulated(cls, impl: Any, **kw: Any): ...
88+
def adapt_native_to_emulated(cls, impl: Any, **kw: Any) -> Any: ...
7489
@classmethod
75-
def adapt_emulated_to_native(cls, impl: Any, **kw: Any): ...
90+
def adapt_emulated_to_native(cls: Type[_NFE], impl: Any, **kw: Any) -> _NFE: ...
7691

77-
_TD = TypeVar("_TD")
78-
79-
class TypeDecorator(SchemaEventTarget, TypeEngine[Any], Generic[_TD]):
92+
class TypeDecorator(SchemaEventTarget, TypeEngine[_T]):
8093
__visit_name__: str = ...
8194
impl: Any = ...
8295
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
83-
coerce_to_is_types: Any = ...
96+
coerce_to_is_types: Tuple[Type[Any], ...] = ...
8497
class Comparator(TypeEngine.Comparator):
8598
def operate(self, op: Any, *other: Any, **kwargs: Any): ...
8699
def reverse_operate(self, op: Any, other: Any, **kwargs: Any): ...
87100
@property
88-
def comparator_factory(self): ...
89-
def type_engine(self, dialect: Any): ...
90-
def load_dialect_impl(self, dialect: Any): ...
91-
def __getattr__(self, key: Any): ...
92-
def process_literal_param(self, value: Any, dialect: Any) -> None: ...
93-
def process_bind_param(self, value: Any, dialect: Any) -> None: ...
94-
def process_result_value(self, value: Any, dialect: Any) -> None: ...
95-
def literal_processor(self, dialect: Any): ...
96-
def bind_processor(self, dialect: Any): ...
97-
def result_processor(self, dialect: Any, coltype: Any): ...
98-
def bind_expression(self, bindparam: Any): ...
99-
def column_expression(self, column: Any): ...
100-
def coerce_compared_value(self, op: Any, value: Any): ...
101-
def copy(self, **kw: Any): ...
102-
def get_dbapi_type(self, dbapi: Any): ...
103-
def compare_values(self, x: Any, y: Any): ...
101+
def comparator_factory(self) -> Type[Any]: ... # type: ignore[override]
102+
def type_engine(self, dialect: Any) -> TypeEngine[Any]: ...
103+
def load_dialect_impl(self, dialect: Any) -> TypeEngine[Any]: ...
104+
def __getattr__(self, key: Any) -> Any: ...
105+
def process_literal_param(self, value: Any, dialect: Any) -> Optional[str]: ...
106+
def process_bind_param(self, value: Any, dialect: Any) -> Optional[Text]: ...
107+
def process_result_value(self, value: Any, dialect: Any) -> Optional[_T]: ...
108+
def literal_processor(self, dialect: Any) -> Callable[[Optional[_T]], Optional[str]]: ...
109+
def bind_processor(self, dialect: Any) -> Callable[[Optional[_T]], Optional[str]]: ...
110+
def result_processor(self, dialect: Any, coltype: Any) -> Callable[[Optional[Any]], Optional[_T]]: ...
111+
def bind_expression(self, bindparam: Any) -> Any: ...
112+
def column_expression(self, column: Any) -> Any: ...
113+
def coerce_compared_value(self, op: Any, value: Any) -> Any: ...
114+
def copy(self: _TD, **kw: Any) -> _TD: ...
115+
def get_dbapi_type(self, dbapi: Any) -> Any: ...
116+
def compare_values(self, x: Any, y: Any) -> bool: ...
104117
@property
105-
def sort_key_function(self): ...
118+
def sort_key_function(self) -> Optional[_SortKeyFunction]: ... # type: ignore[override]
106119

107-
class Variant(TypeDecorator[Any]):
108-
impl: Any = ...
109-
mapping: Any = ...
110-
def __init__(self, base: Any, mapping: Any) -> None: ...
111-
def coerce_compared_value(self, operator: Any, value: Any): ...
112-
def load_dialect_impl(self, dialect: Any): ...
113-
def with_variant(self, type_: Any, dialect_name: Any): ...
120+
class Variant(TypeDecorator[_T]):
121+
impl: Type[TypeEngine[Any]] = ...
122+
mapping: Mapping[str, TypeEngine[Any]] = ...
123+
def __init__(self, base: Any, mapping: Mapping[str, TypeEngine[Any]]) -> None: ...
124+
def coerce_compared_value(self: _VT, operator: Any, value: Any) -> Union[_VT, TypeEngine[Any]]: ...
125+
def load_dialect_impl(self, dialect: Any) -> TypeEngine[Any]: ...
126+
def with_variant(self, type_: Type[TypeEngine[_U]], dialect_name: str) -> Variant[_U]: ...
114127
@property
115-
def comparator_factory(self): ...
128+
def comparator_factory(self) -> Type[Any]: ... # type: ignore[override]
116129

117130
def to_instance(typeobj: Any, *arg: Any, **kw: Any): ...
118131
def adapt_type(typeobj: Any, colspecs: Any): ...

0 commit comments

Comments
 (0)