Skip to content

Commit ae5cb8e

Browse files
authored
Merge pull request #25 from bryanforbes/improve-sql-base
Improve `sql.base` typings
2 parents e931c09 + b7f469c commit ae5cb8e

File tree

1 file changed

+122
-90
lines changed

1 file changed

+122
-90
lines changed

sqlalchemy-stubs/sql/base.pyi

Lines changed: 122 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -1,97 +1,120 @@
11
from types import ModuleType
22
from typing import Any
3+
from typing import Callable
4+
from typing import Dict
5+
from typing import Generic
6+
from typing import Iterable
7+
from typing import Iterator
8+
from typing import List
9+
from typing import Mapping
10+
from typing import NoReturn
311
from typing import Optional
4-
5-
from . import roles as roles
6-
from .traversals import HasCacheKey as HasCacheKey
7-
from .traversals import HasCopyInternals as HasCopyInternals
8-
from .traversals import MemoizedHasCacheKey as MemoizedHasCacheKey
9-
from .visitors import ClauseVisitor as ClauseVisitor
10-
from .visitors import ExtendedInternalTraversal as ExtendedInternalTraversal
11-
from .visitors import InternalTraversal as InternalTraversal
12-
from .. import exc as exc
13-
from .. import util as util
12+
from typing import overload
13+
from typing import Tuple
14+
from typing import Type
15+
from typing import TypeVar
16+
from typing import Union
17+
18+
from . import roles
19+
from .schema import Column
20+
from .traversals import HasCacheKey
21+
from .traversals import HasCopyInternals
22+
from .visitors import ClauseVisitor
23+
from .. import util
24+
from ..engine import Connection
25+
from ..engine import Engine
1426
from ..util import HasMemoized as HasMemoized
15-
from ..util import hybridmethod as hybridmethod
27+
from ..util import langhelpers
28+
29+
_T = TypeVar("_T")
30+
_SC = TypeVar("_SC", bound=SingletonConstant)
31+
_O = TypeVar("_O", bound=Options)
32+
_E = TypeVar("_E", bound=Executable)
33+
_C = TypeVar("_C", bound=Column[Any])
34+
_OC = TypeVar("_OC", bound=Column[Any])
1635

1736
coercions: ModuleType
1837
elements: ModuleType
1938
type_api: ModuleType
20-
PARSE_AUTOCOMMIT: Any
21-
NO_ARG: Any
39+
PARSE_AUTOCOMMIT: langhelpers._symbol
40+
NO_ARG: langhelpers._symbol
2241

2342
class Immutable:
24-
def unique_params(self, *optionaldict: Any, **kwargs: Any) -> None: ...
25-
def params(self, *optionaldict: Any, **kwargs: Any) -> None: ...
43+
def unique_params(self, *optionaldict: Any, **kwargs: Any) -> NoReturn: ...
44+
def params(self, *optionaldict: Any, **kwargs: Any) -> NoReturn: ...
2645

2746
class SingletonConstant(Immutable):
28-
def __new__(cls, *arg: Any, **kw: Any): ...
47+
def __new__(cls: Type[_SC], *arg: Any, **kw: Any) -> _SC: ...
2948

30-
class _DialectArgView(util.collections_abc.MutableMapping):
49+
class _DialectArgView(util.collections_abc.MutableMapping[str, Any]):
3150
obj: Any = ...
3251
def __init__(self, obj: Any) -> None: ...
33-
def __getitem__(self, key: Any): ...
34-
def __setitem__(self, key: Any, value: Any) -> None: ...
35-
def __delitem__(self, key: Any) -> None: ...
36-
def __len__(self): ...
37-
def __iter__(self) -> Any: ...
52+
def __getitem__(self, key: str) -> Any: ...
53+
def __setitem__(self, key: str, value: Any) -> None: ...
54+
def __delitem__(self, key: str) -> None: ...
55+
def __len__(self) -> int: ...
56+
def __iter__(self) -> Iterator[str]: ...
3857

39-
class _DialectArgDict(util.collections_abc.MutableMapping):
58+
class _DialectArgDict(util.collections_abc.MutableMapping[str, Any]):
4059
def __init__(self) -> None: ...
41-
def __len__(self): ...
42-
def __iter__(self) -> Any: ...
43-
def __getitem__(self, key: Any): ...
44-
def __setitem__(self, key: Any, value: Any) -> None: ...
45-
def __delitem__(self, key: Any) -> None: ...
60+
def __len__(self) -> int: ...
61+
def __iter__(self) -> Iterator[str]: ...
62+
def __getitem__(self, key: str) -> Any: ...
63+
def __setitem__(self, key: str, value: Any) -> None: ...
64+
def __delitem__(self, key: str) -> None: ...
4665

4766
class DialectKWArgs:
4867
@classmethod
4968
def argument_for(
5069
cls, dialect_name: Any, argument_name: Any, default: Any
5170
) -> None: ...
52-
def dialect_kwargs(self): ...
71+
@util.memoized_property
72+
def dialect_kwargs(self) -> _DialectArgView: ...
5373
@property
54-
def kwargs(self): ...
55-
def dialect_options(self): ...
74+
def kwargs(self) -> _DialectArgView: ...
75+
@util.memoized_property
76+
def dialect_options(self) -> util.PopulateDict[str, _DialectArgDict]: ...
5677

5778
class CompileState:
5879
plugins: Any = ...
5980
@classmethod
6081
def create_for_statement(
6182
cls, statement: Any, compiler: Any, **kw: Any
62-
): ...
83+
) -> Any: ...
6384
statement: Any = ...
6485
def __init__(self, statement: Any, compiler: Any, **kw: Any) -> None: ...
6586
@classmethod
66-
def get_plugin_class(cls, statement: Any): ...
87+
def get_plugin_class(cls, statement: Any) -> Optional[Any]: ...
6788
@classmethod
68-
def plugin_for(cls, plugin_name: Any, visit_name: Any): ...
89+
def plugin_for(
90+
cls, plugin_name: Any, visit_name: Any
91+
) -> Callable[[_T], _T]: ...
6992

7093
class Generative(HasMemoized): ...
7194
class InPlaceGenerative(HasMemoized): ...
7295
class HasCompileState(Generative): ...
7396

7497
class _MetaOptions(type):
7598
def __init__(cls, classname: Any, bases: Any, dict_: Any) -> None: ...
76-
def __add__(self, other: Any): ...
99+
def __add__(self: Type[_T], other: Any) -> _T: ...
77100

78-
class Options:
101+
class Options(metaclass=_MetaOptions):
79102
def __init__(self, **kw: Any) -> None: ...
80-
def __add__(self, other: Any): ...
81-
def __eq__(self, other: Any) -> Any: ...
103+
def __add__(self: _O, other: Any) -> _O: ...
104+
def __eq__(self, other: Any) -> bool: ...
82105
@classmethod
83-
def isinstance(cls, klass: Any): ...
84-
def add_to_element(self, name: Any, value: Any): ...
106+
def isinstance(cls, klass: Any) -> bool: ...
107+
def add_to_element(self: _O, name: Any, value: Any) -> _O: ...
85108
@classmethod
86-
def safe_merge(cls, other: Any): ...
109+
def safe_merge(cls: Type[_O], other: Any) -> _O: ...
87110
@classmethod
88111
def from_execution_options(
89112
cls,
90113
key: Any,
91114
attrs: Any,
92115
exec_options: Any,
93116
statement_exec_options: Any,
94-
): ...
117+
) -> Tuple[Any, Any]: ...
95118

96119
class CacheableOptions(Options, HasCacheKey): ...
97120

@@ -106,63 +129,72 @@ class Executable(roles.CoerceTextStatementRole, Generative):
106129
is_text: bool = ...
107130
is_delete: bool = ...
108131
is_dml: bool = ...
109-
def options(self, *options: Any) -> None: ...
110-
def execution_options(self, **kw: Any) -> None: ...
111-
def get_execution_options(self): ...
112-
def execute(self, *multiparams: Any, **params: Any): ...
113-
def scalar(self, *multiparams: Any, **params: Any): ...
132+
def options(self: _E, *options: Any) -> _E: ...
133+
def execution_options(self: _E, **kw: Any) -> _E: ...
134+
def get_execution_options(self) -> Any: ...
135+
def execute(self, *multiparams: Any, **params: Any) -> Any: ...
136+
def scalar(self, *multiparams: Any, **params: Any) -> Any: ...
114137
@property
115-
def bind(self): ...
138+
def bind(self) -> Optional[Union[Engine, Connection]]: ...
116139

117-
class prefix_anon_map(dict):
118-
def __missing__(self, key: Any): ...
140+
class prefix_anon_map(Dict[str, str]):
141+
def __missing__(self, key: str) -> str: ...
119142

120143
class SchemaEventTarget: ...
121144

122145
class SchemaVisitor(ClauseVisitor):
123146
__traverse_options__: Any = ...
124147

125-
class ColumnCollection:
126-
def __init__(self, columns: Optional[Any] = ...) -> None: ...
127-
def keys(self): ...
128-
def __bool__(self): ...
129-
def __len__(self): ...
130-
def __iter__(self) -> Any: ...
131-
def __getitem__(self, key: Any): ...
132-
def __getattr__(self, key: Any): ...
133-
def __contains__(self, key: Any): ...
134-
def compare(self, other: Any): ...
135-
def __eq__(self, other: Any) -> Any: ...
136-
def get(self, key: Any, default: Optional[Any] = ...): ...
137-
def __setitem__(self, key: Any, value: Any) -> None: ...
138-
def __delitem__(self, key: Any) -> None: ...
139-
def __setattr__(self, key: Any, obj: Any) -> None: ...
148+
class ColumnCollection(Generic[_C]):
149+
def __init__(
150+
self, columns: Optional[Iterable[Tuple[str, _C]]] = ...
151+
) -> None: ...
152+
def keys(self) -> List[_C]: ...
153+
def values(self) -> List[str]: ...
154+
def items(self) -> List[Tuple[str, _C]]: ...
155+
def __bool__(self) -> bool: ...
156+
def __len__(self) -> int: ...
157+
def __iter__(self) -> Iterator[_C]: ...
158+
def __getitem__(self, key: str) -> _C: ...
159+
def __getattr__(self, key: str) -> _C: ...
160+
def __contains__(self, key: str) -> bool: ...
161+
def compare(self, other: Any) -> bool: ...
162+
def __eq__(self, other: Any) -> bool: ...
163+
@overload
164+
def get(self, key: str) -> Optional[_C]: ...
165+
@overload
166+
def get(self, key: str, default: _T) -> Union[_C, _T]: ...
167+
def __setitem__(self, key: str, value: _C) -> None: ...
168+
def __delitem__(self, key: str) -> None: ...
169+
def __setattr__(self, key: str, obj: _C) -> None: ...
140170
def clear(self) -> None: ...
141-
def remove(self, column: Any) -> None: ...
142-
def update(self, iter_: Any) -> None: ...
171+
def remove(self, column: _C) -> None: ...
172+
def update(
173+
self, iter_: Union[Mapping[str, _C], Iterable[Tuple[str, _C]]]
174+
) -> None: ...
143175
__hash__: Any = ...
144-
def add(self, column: Any, key: Optional[Any] = ...) -> None: ...
145-
def contains_column(self, col: Any): ...
146-
def as_immutable(self): ...
176+
def add(self, column: _C, key: Optional[str] = ...) -> None: ...
177+
def contains_column(self, col: Column[Any]) -> bool: ...
178+
def as_immutable(self) -> ImmutableColumnCollection[_C]: ...
147179
def corresponding_column(
148-
self, column: Any, require_embedded: bool = ...
149-
): ...
150-
151-
class DedupeColumnCollection(ColumnCollection):
152-
def add(self, column: Any, key: Optional[Any] = ...) -> None: ...
153-
def extend(self, iter_: Any) -> None: ...
154-
def remove(self, column: Any) -> None: ...
155-
def replace(self, column: Any) -> None: ...
156-
157-
class ImmutableColumnCollection(util.ImmutableContainer, ColumnCollection):
158-
def __init__(self, collection: Any) -> None: ...
159-
add: Any = ...
160-
extend: Any = ...
161-
remove: Any = ...
162-
163-
class ColumnSet(util.ordered_column_set):
164-
def contains_column(self, col: Any): ...
165-
def extend(self, cols: Any) -> None: ...
166-
def __add__(self, other: Any): ...
180+
self, column: Column[Any], require_embedded: bool = ...
181+
) -> Optional[_C]: ...
182+
183+
class DedupeColumnCollection(ColumnCollection[_C]):
184+
def add(self, column: _C, key: Optional[str] = ...) -> None: ...
185+
def extend(self, iter_: Iterable[_C]) -> None: ...
186+
def remove(self, column: _C) -> None: ...
187+
def replace(self, column: _C) -> None: ...
188+
189+
class ImmutableColumnCollection(util.ImmutableContainer, ColumnCollection[_C]):
190+
def __init__(self, collection: ColumnCollection[_C]) -> None: ...
191+
def add(self, column: _C, key: Optional[str] = ...) -> NoReturn: ...
192+
def extend(self, iter_: Iterable[_C]) -> NoReturn: ...
193+
def remove(self, column: _C) -> NoReturn: ...
194+
195+
class ColumnSet(util.ordered_column_set[_C]):
196+
def contains_column(self, col: Column[Any]) -> bool: ...
197+
def extend(self, cols: Iterable[_C]) -> None: ...
198+
def __add__(self, other: Iterable[_OC]) -> List[Union[_C, _OC]]: ... # type: ignore[override]
167199
def __eq__(self, other: Any) -> Any: ...
168-
def __hash__(self) -> Any: ...
200+
def __hash__(self) -> int: ... # type: ignore[override]

0 commit comments

Comments
 (0)