Skip to content

Commit 14ecd6d

Browse files
authored
Merge pull request #6 from bryanforbes/improve-sql-type-api
Improve typing for sql.type_api classes
2 parents b0afe35 + 31c0630 commit 14ecd6d

File tree

2 files changed

+121
-72
lines changed

2 files changed

+121
-72
lines changed

sqlalchemy-stubs/sql/sqltypes.pyi

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ from typing import Any
77
from typing import List
88
from typing import Mapping
99
from typing import Optional
10+
from typing import Type
11+
from typing import TypeVar
1012
from typing import Union
1113

1214
from . import coercions as coercions
@@ -35,6 +37,8 @@ from ..util import compat as compat
3537
from ..util import langhelpers as langhelpers
3638
from ..util import pickle as pickle
3739

40+
_T = TypeVar("_T")
41+
3842
class _LookupExpressionAdapter:
3943
class Comparator(TypeEngine.Comparator): ...
4044
comparator_factory: Any = ...
@@ -160,7 +164,6 @@ class _Binary(TypeEngine[bytes]):
160164
def python_type(self): ...
161165
def bind_processor(self, dialect: Any): ...
162166
def result_processor(self, dialect: Any, coltype: Any): ...
163-
def result_processor(self, dialect: Any, coltype: Any): ...
164167
def coerce_compared_value(self, op: Any, value: Any): ...
165168
def get_dbapi_type(self, dbapi: Any): ...
166169

@@ -183,7 +186,7 @@ class SchemaType(SchemaEventTarget):
183186
_create_events: bool = ...,
184187
) -> None: ...
185188
def copy(self, **kw: Any): ...
186-
def adapt(self, impltype: Any, **kw: Any): ...
189+
def adapt(self, __impltype: Type[_T], **kw: Any) -> _T: ...
187190
@property
188191
def bind(self): ...
189192
def create(
@@ -204,7 +207,7 @@ class Enum(Emulated, String, SchemaType):
204207
comparator_factory: Any = ...
205208
def as_generic(self, allow_nulltype: bool = ...): ...
206209
def adapt_to_emulated(self, impltype: Any, **kw: Any): ...
207-
def adapt(self, impltype: Any, **kw: Any): ...
210+
def adapt(self, __impltype: Type[_T], **kw: Any) -> _T: ...
208211
def literal_processor(self, dialect: Any): ...
209212
def bind_processor(self, dialect: Any): ...
210213
def result_processor(self, dialect: Any, coltype: Any): ...
@@ -248,20 +251,21 @@ class Boolean(Emulated, TypeEngine[bool], SchemaType):
248251
class _AbstractInterval(_LookupExpressionAdapter, TypeEngine[Any]):
249252
def coerce_compared_value(self, op: Any, value: Any): ...
250253

251-
class Interval(Emulated, _AbstractInterval, TypeDecorator[timedelta]):
254+
# "comparator_factory" of "_LookupExpressionAdapter" and "TypeDecorator" are incompatible
255+
class Interval(Emulated, _AbstractInterval, TypeDecorator[timedelta]): # type: ignore[misc]
252256
impl: Any = ...
253-
epoch: Any = ...
254-
native: Any = ...
255-
second_precision: Any = ...
256-
day_precision: Any = ...
257+
epoch: datetime = ...
258+
native: bool = ...
259+
second_precision: Optional[float] = ...
260+
day_precision: Optional[float] = ...
257261
def __init__(
258262
self,
259263
native: bool = ...,
260-
second_precision: Optional[Any] = ...,
261-
day_precision: Optional[Any] = ...,
264+
second_precision: Optional[float] = ...,
265+
day_precision: Optional[float] = ...,
262266
) -> None: ...
263267
@property
264-
def python_type(self): ...
268+
def python_type(self) -> Type[timedelta]: ...
265269
def adapt_to_emulated(self, impltype: Any, **kw: Any): ...
266270
def bind_processor(self, dialect: Any): ...
267271
def result_processor(self, dialect: Any, coltype: Any): ...
@@ -270,7 +274,8 @@ class JSON(Indexable, TypeEngine[Union[str, Mapping, List]]):
270274
__visit_name__: str = ...
271275
hashable: bool = ...
272276
NULL: Any = ...
273-
none_as_null: Any = ...
277+
none_as_null: bool = ...
278+
should_evaluate_none: bool = ...
274279
def __init__(self, none_as_null: bool = ...) -> None: ...
275280
class JSONElementType(TypeEngine):
276281
def string_bind_processor(self, dialect: Any): ...
@@ -293,10 +298,6 @@ class JSON(Indexable, TypeEngine[Union[str, Mapping, List]]):
293298
comparator_factory: Any = ...
294299
@property
295300
def python_type(self): ...
296-
@property
297-
def should_evaluate_none(self): ...
298-
@should_evaluate_none.setter
299-
def should_evaluate_none(self, value: Any) -> None: ...
300301
def bind_processor(self, dialect: Any): ...
301302
def result_processor(self, dialect: Any, coltype: Any): ...
302303

@@ -327,7 +328,6 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine[List]):
327328
class TupleType(TypeEngine[TupleType]):
328329
types: Any = ...
329330
def __init__(self, *types: Any) -> None: ...
330-
def result_processor(self, dialect: Any, coltype: Any) -> None: ...
331331

332332
class REAL(Float):
333333
__visit_name__: str = ...

sqlalchemy-stubs/sql/type_api.pyi

Lines changed: 104 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
1+
import sys
12
from typing import Any
3+
from typing import Callable
24
from typing import Generic
5+
from typing import Mapping
36
from typing import Optional
7+
from typing import Tuple
48
from typing import Type
59
from typing import TypeVar
10+
from typing import Union
11+
12+
from typing_extensions import Protocol
613

714
from . import operators as operators
815
from .base import SchemaEventTarget as SchemaEventTarget
@@ -20,6 +27,30 @@ INDEXABLE: Any
2027
TABLEVALUE: Any
2128

2229
_T = TypeVar("_T")
30+
_T_co = TypeVar("_T_co", covariant=True)
31+
_T_contra = TypeVar("_T_contra", contravariant=True)
32+
_U = TypeVar("_U")
33+
34+
_TE = TypeVar("_TE", bound=TypeEngine[Any])
35+
_NFE = TypeVar("_NFE", bound=NativeForEmulated)
36+
_TD = TypeVar("_TD", bound=TypeDecorator[Any])
37+
_VT = TypeVar("_VT", bound=Variant[Any])
38+
39+
if sys.version_info[0] < 3:
40+
from _typeshed import SupportsLessThan
41+
42+
_SortKeyFunction = Callable[[Any], SupportsLessThan]
43+
else:
44+
_SortKeyFunction = Callable[[Any], Any]
45+
46+
class _LiteralProcessor(Protocol[_T_contra]):
47+
def __call__(self, __value: Optional[_T_contra]) -> str: ...
48+
49+
class _BindProcessor(Protocol[_T_contra]):
50+
def __call__(self, __value: Optional[_T_contra]) -> Optional[Any]: ...
51+
52+
class _ResultProcessor(Protocol[_T_co]):
53+
def __call__(self, __value: Optional[Any]) -> Optional[_T_co]: ...
2354

2455
class TypeEngine(Traversible, Generic[_T]):
2556
class Comparator(operators.ColumnOperators):
@@ -32,87 +63,105 @@ class TypeEngine(Traversible, Generic[_T]):
3263
def reverse_operate(self, op: Any, other: Any, **kwargs: Any): ...
3364
def __reduce__(self): ...
3465
hashable: bool = ...
35-
comparator_factory: Any = ...
36-
sort_key_function: Any = ...
66+
comparator_factory: Type[Any] = ...
67+
sort_key_function: Optional[_SortKeyFunction] = ...
3768
should_evaluate_none: bool = ...
38-
def evaluates_none(self): ...
39-
def copy(self, **kw: Any): ...
40-
def compare_against_backend(
41-
self, dialect: Any, conn_type: Any
42-
) -> None: ...
43-
def copy_value(self, value: Any): ...
44-
def literal_processor(self, dialect: Any) -> None: ...
45-
def bind_processor(self, dialect: Any) -> None: ...
46-
def result_processor(self, dialect: Any, coltype: Any) -> None: ...
47-
def column_expression(self, colexpr: Any) -> None: ...
48-
def bind_expression(self, bindvalue: Any) -> None: ...
49-
def compare_values(self, x: Any, y: Any): ...
50-
def get_dbapi_type(self, dbapi: Any) -> None: ...
69+
def evaluates_none(self: _TE) -> _TE: ...
70+
def copy(self: _TE, **kw: Any) -> _TE: ...
71+
def compare_against_backend(self, dialect: Any, conn_type: Any) -> Any: ...
72+
def copy_value(self, value: _T) -> _T: ...
73+
def literal_processor(
74+
self, dialect: Any
75+
) -> Optional[_LiteralProcessor[_T]]: ...
76+
def bind_processor(self, dialect: Any) -> Optional[_BindProcessor[_T]]: ...
77+
def result_processor(
78+
self, dialect: Any, coltype: Any
79+
) -> Optional[_ResultProcessor[_T]]: ...
80+
def column_expression(self, colexpr: Any) -> Any: ...
81+
def bind_expression(self, bindvalue: Any) -> Any: ...
82+
def compare_values(self, x: Any, y: Any) -> bool: ...
83+
def get_dbapi_type(self, dbapi: Any) -> Any: ...
5184
@property
5285
def python_type(self) -> Type[_T]: ...
53-
def with_variant(self, type_: Any, dialect_name: Any): ...
54-
def as_generic(self, allow_nulltype: bool = ...): ...
55-
def dialect_impl(self, dialect: Any): ...
56-
def adapt(self, cls: Any, **kw: Any): ...
57-
def coerce_compared_value(self, op: Any, value: Any): ...
58-
def compile(self, dialect: Optional[Any] = ...): ...
86+
def with_variant(
87+
self, type_: Type[TypeEngine[_U]], dialect_name: str
88+
) -> Variant[_U]: ...
89+
def as_generic(self, allow_nulltype: bool = ...) -> TypeEngine[Any]: ...
90+
def dialect_impl(self, dialect: Any) -> Type[Any]: ...
91+
def adapt(self, __cls: Type[_U], **kw: Any) -> _U: ...
92+
def coerce_compared_value(
93+
self, op: Any, value: Any
94+
) -> TypeEngine[Any]: ...
95+
def compile(self, dialect: Optional[Any] = ...) -> Any: ...
5996

6097
class VisitableCheckKWArg(util.EnsureKWArgType, TraversibleType): ...
6198

6299
class UserDefinedType:
63100
__visit_name__: str = ...
64101
ensure_kwarg: str = ...
65-
def coerce_compared_value(self, op: Any, value: Any): ...
102+
def coerce_compared_value(self, op: Any, value: Any) -> Any: ...
66103

67104
class Emulated:
68-
def adapt_to_emulated(self, impltype: Any, **kw: Any): ...
69-
def adapt(self, impltype: Any, **kw: Any): ...
105+
def adapt_to_emulated(self, impltype: Any, **kw: Any) -> Any: ...
106+
def adapt(self, __impltype: Any, **kw: Any) -> Any: ...
70107

71108
class NativeForEmulated:
72109
@classmethod
73-
def adapt_native_to_emulated(cls, impl: Any, **kw: Any): ...
110+
def adapt_native_to_emulated(cls, impl: Any, **kw: Any) -> Any: ...
74111
@classmethod
75-
def adapt_emulated_to_native(cls, impl: Any, **kw: Any): ...
76-
77-
_TD = TypeVar("_TD")
112+
def adapt_emulated_to_native(
113+
cls: Type[_NFE], impl: Any, **kw: Any
114+
) -> _NFE: ...
78115

79-
class TypeDecorator(SchemaEventTarget, TypeEngine[Any], Generic[_TD]):
116+
class TypeDecorator(SchemaEventTarget, TypeEngine[_T]):
80117
__visit_name__: str = ...
81118
impl: Any = ...
82119
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
83-
coerce_to_is_types: Any = ...
120+
coerce_to_is_types: Tuple[Type[Any], ...] = ...
84121
class Comparator(TypeEngine.Comparator):
85122
def operate(self, op: Any, *other: Any, **kwargs: Any): ...
86123
def reverse_operate(self, op: Any, other: Any, **kwargs: Any): ...
87124
@property
88-
def comparator_factory(self): ...
89-
def type_engine(self, dialect: Any): ...
90-
def load_dialect_impl(self, dialect: Any): ...
91-
def __getattr__(self, key: Any): ...
92-
def process_literal_param(self, value: Any, dialect: Any) -> None: ...
93-
def process_bind_param(self, value: Any, dialect: Any) -> None: ...
94-
def process_result_value(self, value: Any, dialect: Any) -> None: ...
95-
def literal_processor(self, dialect: Any): ...
96-
def bind_processor(self, dialect: Any): ...
97-
def result_processor(self, dialect: Any, coltype: Any): ...
98-
def bind_expression(self, bindparam: Any): ...
99-
def column_expression(self, column: Any): ...
100-
def coerce_compared_value(self, op: Any, value: Any): ...
101-
def copy(self, **kw: Any): ...
102-
def get_dbapi_type(self, dbapi: Any): ...
103-
def compare_values(self, x: Any, y: Any): ...
125+
def comparator_factory(self) -> Type[Any]: ... # type: ignore[override]
126+
def type_engine(self, dialect: Any) -> TypeEngine[Any]: ...
127+
def load_dialect_impl(self, dialect: Any) -> TypeEngine[Any]: ...
128+
def __getattr__(self, key: Any) -> Any: ...
129+
def process_literal_param(
130+
self, value: Optional[_T], dialect: Any
131+
) -> str: ...
132+
def process_bind_param(self, value: Optional[_T], dialect: Any) -> Any: ...
133+
def process_result_value(
134+
self, value: Any, dialect: Any
135+
) -> Optional[_T]: ...
136+
def literal_processor(self, dialect: Any) -> _LiteralProcessor[_T]: ...
137+
def bind_processor(self, dialect: Any) -> _BindProcessor[_T]: ...
138+
def result_processor(
139+
self, dialect: Any, coltype: Any
140+
) -> _ResultProcessor[_T]: ...
141+
def bind_expression(self, bindparam: Any) -> Any: ...
142+
def column_expression(self, column: Any) -> Any: ...
143+
def coerce_compared_value(self, op: Any, value: Any) -> Any: ...
144+
def copy(self: _TD, **kw: Any) -> _TD: ...
145+
def get_dbapi_type(self, dbapi: Any) -> Any: ...
146+
def compare_values(self, x: Any, y: Any) -> bool: ...
104147
@property
105-
def sort_key_function(self): ...
148+
def sort_key_function(self) -> Optional[_SortKeyFunction]: ... # type: ignore[override]
106149

107-
class Variant(TypeDecorator[Any]):
108-
impl: Any = ...
109-
mapping: Any = ...
110-
def __init__(self, base: Any, mapping: Any) -> None: ...
111-
def coerce_compared_value(self, operator: Any, value: Any): ...
112-
def load_dialect_impl(self, dialect: Any): ...
113-
def with_variant(self, type_: Any, dialect_name: Any): ...
150+
class Variant(TypeDecorator[_T]):
151+
impl: TypeEngine[Any] = ...
152+
mapping: Mapping[str, TypeEngine[Any]] = ...
153+
def __init__(
154+
self, base: Any, mapping: Mapping[str, TypeEngine[Any]]
155+
) -> None: ...
156+
def coerce_compared_value(
157+
self: _VT, operator: Any, value: Any
158+
) -> Union[_VT, TypeEngine[Any]]: ...
159+
def load_dialect_impl(self, dialect: Any) -> TypeEngine[Any]: ...
160+
def with_variant(
161+
self, type_: Type[TypeEngine[_U]], dialect_name: str
162+
) -> Variant[_U]: ...
114163
@property
115-
def comparator_factory(self): ...
164+
def comparator_factory(self) -> Type[Any]: ... # type: ignore[override]
116165

117166
def to_instance(typeobj: Any, *arg: Any, **kw: Any): ...
118167
def adapt_type(typeobj: Any, colspecs: Any): ...

0 commit comments

Comments
 (0)