Skip to content

Commit 3b229ff

Browse files
authored
Make the SQLAlchemy type hint work better in Pycharm. (#223)
* Fix `pycharm` unable to import protected variables in `.pyi` file. * Fix `pycharm` unable to import protected variables in `.pyi` file. * fix tests.(I don't know why here use `traversals. MemoizedHasCacheKey` lead to test failure) * Remove duplicate variable declarations.
1 parent 49f55e2 commit 3b229ff

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+345
-406
lines changed

sqlalchemy-stubs/_typing.pyi

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,7 @@ class _TypeToInstance(Generic[_T]):
2121

2222
_ExecuteParams = Union[Mapping[Any, Any], Sequence[Mapping[Any, Any]]]
2323
_ExecuteOptions = Mapping[Any, Any]
24+
25+
TypingExecuteOptions = _ExecuteOptions
26+
TypingExecuteParams = _ExecuteParams
27+
TypingTypeToInstance = _TypeToInstance

sqlalchemy-stubs/dialects/mssql/base.pyi

Lines changed: 17 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ from ... import exc as exc
88
from ... import Identity as Identity
99
from ... import Sequence as Sequence
1010
from ... import sql as sql
11-
from ... import types as sqltypes
11+
from ...sql import sqltypes as sqltypes
12+
from ...sql import type_api as type_api
1213
from ... import util as util
1314
from ...engine import default as default
1415
from ...engine import reflection as reflection
@@ -47,11 +48,9 @@ MS_2000_VERSION: Any
4748
RESERVED_WORDS: Any
4849

4950
class REAL(sqltypes.REAL):
50-
__visit_name__: str = ...
5151
def __init__(self, **kw: Any) -> None: ...
5252

53-
class TINYINT(sqltypes.Integer):
54-
__visit_name__: str = ...
53+
class TINYINT(sqltypes.Integer):...
5554

5655
class _MSDate(sqltypes.Date):
5756
def bind_processor(self, dialect: Any): ...
@@ -70,16 +69,13 @@ class _DateTimeBase:
7069

7170
class _MSDateTime(_DateTimeBase, sqltypes.DateTime): ...
7271

73-
class SMALLDATETIME(_DateTimeBase, sqltypes.DateTime):
74-
__visit_name__: str = ...
72+
class SMALLDATETIME(_DateTimeBase, sqltypes.DateTime):...
7573

7674
class DATETIME2(_DateTimeBase, sqltypes.DateTime):
77-
__visit_name__: str = ...
7875
precision: Any = ...
7976
def __init__(self, precision: Optional[Any] = ..., **kw: Any) -> None: ...
8077

8178
class DATETIMEOFFSET(_DateTimeBase, sqltypes.DateTime):
82-
__visit_name__: str = ...
8379
precision: Any = ...
8480
def __init__(self, precision: Optional[Any] = ..., **kw: Any) -> None: ...
8581

@@ -89,42 +85,31 @@ class _UnicodeLiteral:
8985
class _MSUnicode(_UnicodeLiteral, sqltypes.Unicode): ...
9086
class _MSUnicodeText(_UnicodeLiteral, sqltypes.UnicodeText): ...
9187

92-
class TIMESTAMP(sqltypes._Binary):
93-
__visit_name__: str = ...
88+
class TIMESTAMP(sqltypes.TypingBinary):
9489
length: Any = ...
9590
convert_int: Any = ...
9691
def __init__(self, convert_int: bool = ...) -> None: ...
9792
def result_processor(self, dialect: Any, coltype: Any): ...
9893

99-
class ROWVERSION(TIMESTAMP):
100-
__visit_name__: str = ...
94+
class ROWVERSION(TIMESTAMP):...
10195

102-
class NTEXT(sqltypes.UnicodeText):
103-
__visit_name__: str = ...
96+
class NTEXT(sqltypes.UnicodeText):...
10497

105-
class VARBINARY(sqltypes.VARBINARY, sqltypes.LargeBinary):
106-
__visit_name__: str = ...
98+
class VARBINARY(sqltypes.VARBINARY, sqltypes.LargeBinary):...
10799

108-
class IMAGE(sqltypes.LargeBinary):
109-
__visit_name__: str = ...
100+
class IMAGE(sqltypes.LargeBinary):...
110101

111-
class XML(sqltypes.Text):
112-
__visit_name__: str = ...
102+
class XML(sqltypes.Text):...
113103

114-
class BIT(sqltypes.TypeEngine):
115-
__visit_name__: str = ...
104+
class BIT(type_api.TypeEngine):...
116105

117-
class MONEY(sqltypes.TypeEngine):
118-
__visit_name__: str = ...
106+
class MONEY(type_api.TypeEngine):...
119107

120-
class SMALLMONEY(sqltypes.TypeEngine):
121-
__visit_name__: str = ...
108+
class SMALLMONEY(type_api.TypeEngine):...
122109

123-
class UNIQUEIDENTIFIER(sqltypes.TypeEngine):
124-
__visit_name__: str = ...
110+
class UNIQUEIDENTIFIER(type_api.TypeEngine):...
125111

126-
class SQL_VARIANT(sqltypes.TypeEngine):
127-
__visit_name__: str = ...
112+
class SQL_VARIANT(type_api.TypeEngine):...
128113

129114
class TryCast(sql.elements.Cast):
130115
__visit_name__: str = ...
@@ -418,3 +403,5 @@ class MSDialect(default.DefaultDialect):
418403
schema: Any,
419404
**kw: Any,
420405
): ...
406+
407+
TypingMSDate = _MSDate

sqlalchemy-stubs/dialects/mssql/mxodbc.pyi

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from typing import Any
22
from typing import Optional
33

4-
from .base import _MSDate
5-
from .base import _MSTime
4+
from .base import TypingMSDate as _MSDate
5+
from .base import MSTime as _MSTime
66
from .base import MSDialect as MSDialect
77
from .base import VARBINARY as VARBINARY
8-
from .pyodbc import _MSNumeric_pyodbc
8+
from .pyodbc import TypingMSNumeric_pyodbc as _MSNumeric_pyodbc
99
from .pyodbc import MSExecutionContext_pyodbc as MSExecutionContext_pyodbc
1010
from ...connectors.mxodbc import MxODBCConnector as MxODBCConnector
1111

sqlalchemy-stubs/dialects/mssql/pyodbc.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,4 @@ class MSDialect_pyodbc(PyODBCConnector, MSDialect):
5454
def is_disconnect(self, e: Any, connection: Any, cursor: Any): ...
5555

5656
dialect = MSDialect_pyodbc
57+
TypingMSNumeric_pyodbc=_MSNumeric_pyodbc

sqlalchemy-stubs/dialects/mysql/enumerated.pyi

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,20 @@
11
from typing import Any
22

3-
from .types import _StringType
3+
from .types import TypingStringType as _StringType
44
from ... import exc as exc
55
from ... import sql as sql
66
from ... import util as util
77
from ...sql import sqltypes as sqltypes
88
from ...sql.base import NO_ARG as NO_ARG
9+
from ...sql import type_api as type_api
910

10-
class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum, _StringType):
11-
__visit_name__: str = ...
11+
class ENUM(type_api.NativeForEmulated, sqltypes.Enum, _StringType):
1212
native_enum: bool = ...
1313
def __init__(self, *enums: Any, **kw: Any) -> None: ...
1414
@classmethod
1515
def adapt_emulated_to_native(cls, impl: Any, **kw: Any): ...
1616

1717
class SET(_StringType):
18-
__visit_name__: str = ...
1918
retrieve_as_bitwise: Any = ...
2019
values: Any = ...
2120
def __init__(self, *values: Any, **kw: Any) -> None: ...

sqlalchemy-stubs/dialects/mysql/types.pyi

Lines changed: 9 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ from typing import Any
22
from typing import Optional
33

44
from ... import exc as exc
5-
from ... import types as sqltypes
5+
from ...sql import sqltypes as sqltypes
6+
from ...sql import type_api as type_api
67
from ... import util as util
78

89
class _NumericType:
@@ -49,7 +50,6 @@ class _MatchType(sqltypes.Float, sqltypes.MatchType):
4950
def __init__(self, **kw: Any) -> None: ...
5051

5152
class NUMERIC(_NumericType, sqltypes.NUMERIC):
52-
__visit_name__: str = ...
5353
def __init__(
5454
self,
5555
precision: Optional[Any] = ...,
@@ -59,7 +59,6 @@ class NUMERIC(_NumericType, sqltypes.NUMERIC):
5959
) -> None: ...
6060

6161
class DECIMAL(_NumericType, sqltypes.DECIMAL):
62-
__visit_name__: str = ...
6362
def __init__(
6463
self,
6564
precision: Optional[Any] = ...,
@@ -69,7 +68,6 @@ class DECIMAL(_NumericType, sqltypes.DECIMAL):
6968
) -> None: ...
7069

7170
class DOUBLE(_FloatType):
72-
__visit_name__: str = ...
7371
def __init__(
7472
self,
7573
precision: Optional[Any] = ...,
@@ -79,7 +77,6 @@ class DOUBLE(_FloatType):
7977
) -> None: ...
8078

8179
class REAL(_FloatType, sqltypes.REAL):
82-
__visit_name__: str = ...
8380
def __init__(
8481
self,
8582
precision: Optional[Any] = ...,
@@ -89,7 +86,6 @@ class REAL(_FloatType, sqltypes.REAL):
8986
) -> None: ...
9087

9188
class FLOAT(_FloatType, sqltypes.FLOAT):
92-
__visit_name__: str = ...
9389
def __init__(
9490
self,
9591
precision: Optional[Any] = ...,
@@ -100,105 +96,86 @@ class FLOAT(_FloatType, sqltypes.FLOAT):
10096
def bind_processor(self, dialect: Any) -> None: ...
10197

10298
class INTEGER(_IntegerType, sqltypes.INTEGER):
103-
__visit_name__: str = ...
10499
def __init__(
105100
self, display_width: Optional[Any] = ..., **kw: Any
106101
) -> None: ...
107102

108103
class BIGINT(_IntegerType, sqltypes.BIGINT):
109-
__visit_name__: str = ...
110104
def __init__(
111105
self, display_width: Optional[Any] = ..., **kw: Any
112106
) -> None: ...
113107

114108
class MEDIUMINT(_IntegerType):
115-
__visit_name__: str = ...
116109
def __init__(
117110
self, display_width: Optional[Any] = ..., **kw: Any
118111
) -> None: ...
119112

120113
class TINYINT(_IntegerType):
121-
__visit_name__: str = ...
122114
def __init__(
123115
self, display_width: Optional[Any] = ..., **kw: Any
124116
) -> None: ...
125117

126118
class SMALLINT(_IntegerType, sqltypes.SMALLINT):
127-
__visit_name__: str = ...
128119
def __init__(
129120
self, display_width: Optional[Any] = ..., **kw: Any
130121
) -> None: ...
131122

132-
class BIT(sqltypes.TypeEngine):
133-
__visit_name__: str = ...
123+
class BIT(type_api.TypeEngine):
134124
length: Any = ...
135125
def __init__(self, length: Optional[Any] = ...) -> None: ...
136126
def result_processor(self, dialect: Any, coltype: Any): ...
137127

138128
class TIME(sqltypes.TIME):
139-
__visit_name__: str = ...
140129
fsp: Any = ...
141130
def __init__(
142131
self, timezone: bool = ..., fsp: Optional[Any] = ...
143132
) -> None: ...
144133
def result_processor(self, dialect: Any, coltype: Any): ...
145134

146135
class TIMESTAMP(sqltypes.TIMESTAMP):
147-
__visit_name__: str = ...
148136
fsp: Any = ...
149137
def __init__(
150138
self, timezone: bool = ..., fsp: Optional[Any] = ...
151139
) -> None: ...
152140

153141
class DATETIME(sqltypes.DATETIME):
154-
__visit_name__: str = ...
155142
fsp: Any = ...
156143
def __init__(
157144
self, timezone: bool = ..., fsp: Optional[Any] = ...
158145
) -> None: ...
159146

160-
class YEAR(sqltypes.TypeEngine):
161-
__visit_name__: str = ...
147+
class YEAR(type_api.TypeEngine):
162148
display_width: Any = ...
163149
def __init__(self, display_width: Optional[Any] = ...) -> None: ...
164150

165151
class TEXT(_StringType, sqltypes.TEXT):
166-
__visit_name__: str = ...
167152
def __init__(self, length: Optional[Any] = ..., **kw: Any) -> None: ...
168153

169154
class TINYTEXT(_StringType):
170-
__visit_name__: str = ...
171155
def __init__(self, **kwargs: Any) -> None: ...
172156

173157
class MEDIUMTEXT(_StringType):
174-
__visit_name__: str = ...
175158
def __init__(self, **kwargs: Any) -> None: ...
176159

177160
class LONGTEXT(_StringType):
178-
__visit_name__: str = ...
179161
def __init__(self, **kwargs: Any) -> None: ...
180162

181163
class VARCHAR(_StringType, sqltypes.VARCHAR):
182-
__visit_name__: str = ...
183164
def __init__(self, length: Optional[Any] = ..., **kwargs: Any) -> None: ...
184165

185166
class CHAR(_StringType, sqltypes.CHAR):
186-
__visit_name__: str = ...
187167
def __init__(self, length: Optional[Any] = ..., **kwargs: Any) -> None: ...
188168

189169
class NVARCHAR(_StringType, sqltypes.NVARCHAR):
190-
__visit_name__: str = ...
191170
def __init__(self, length: Optional[Any] = ..., **kwargs: Any) -> None: ...
192171

193172
class NCHAR(_StringType, sqltypes.NCHAR):
194-
__visit_name__: str = ...
195173
def __init__(self, length: Optional[Any] = ..., **kwargs: Any) -> None: ...
196174

197-
class TINYBLOB(sqltypes._Binary):
198-
__visit_name__: str = ...
175+
class TINYBLOB(sqltypes.TypingBinary):...
199176

200-
class MEDIUMBLOB(sqltypes._Binary):
201-
__visit_name__: str = ...
177+
class MEDIUMBLOB(sqltypes.TypingBinary):...
202178

203-
class LONGBLOB(sqltypes._Binary):
204-
__visit_name__: str = ...
179+
class LONGBLOB(sqltypes.TypingBinary):...
180+
181+
TypingStringType=_StringType

sqlalchemy-stubs/dialects/oracle/base.pyi

Lines changed: 12 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ from ...engine import reflection as reflection
1010
from ...sql import compiler as compiler
1111
from ...sql import expression as expression
1212
from ...sql import sqltypes as sqltypes
13+
from ...sql import type_api as type_api
1314
from ...sql import visitors as visitors
1415
from ...types import BLOB as BLOB
1516
from ...types import CHAR as CHAR
@@ -26,21 +27,17 @@ from ...util import compat as compat
2627
RESERVED_WORDS: Any
2728
NO_ARG_FNS: Any
2829

29-
class RAW(sqltypes._Binary):
30-
__visit_name__: str = ...
30+
class RAW(sqltypes.TypingBinary):...
3131

3232
OracleRaw = RAW
3333

34-
class NCLOB(sqltypes.Text):
35-
__visit_name__: str = ...
34+
class NCLOB(sqltypes.Text):...
3635

37-
class VARCHAR2(VARCHAR):
38-
__visit_name__: str = ...
36+
class VARCHAR2(VARCHAR):...
3937

4038
NVARCHAR2 = NVARCHAR
4139

4240
class NUMBER(sqltypes.Numeric, sqltypes.Integer):
43-
__visit_name__: str = ...
4441
def __init__(
4542
self,
4643
precision: Optional[Any] = ...,
@@ -49,26 +46,19 @@ class NUMBER(sqltypes.Numeric, sqltypes.Integer):
4946
) -> None: ...
5047
def adapt(self, impltype: Any): ...
5148

52-
class DOUBLE_PRECISION(sqltypes.Float):
53-
__visit_name__: str = ...
49+
class DOUBLE_PRECISION(sqltypes.Float):...
5450

55-
class BINARY_DOUBLE(sqltypes.Float):
56-
__visit_name__: str = ...
51+
class BINARY_DOUBLE(sqltypes.Float):...
5752

58-
class BINARY_FLOAT(sqltypes.Float):
59-
__visit_name__: str = ...
53+
class BINARY_FLOAT(sqltypes.Float):...
6054

61-
class BFILE(sqltypes.LargeBinary):
62-
__visit_name__: str = ...
55+
class BFILE(sqltypes.LargeBinary):...
6356

64-
class LONG(sqltypes.Text):
65-
__visit_name__: str = ...
57+
class LONG(sqltypes.Text):...
6658

67-
class DATE(sqltypes.DateTime):
68-
__visit_name__: str = ...
59+
class DATE(sqltypes.DateTime):...
6960

70-
class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval):
71-
__visit_name__: str = ...
61+
class INTERVAL(type_api.NativeForEmulated, sqltypes.TypingAbstractInterval):
7262
day_precision: Any = ...
7363
second_precision: Any = ...
7464
def __init__(
@@ -78,8 +68,7 @@ class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval):
7868
) -> None: ...
7969
def as_generic(self, allow_nulltype: bool = ...): ...
8070

81-
class ROWID(sqltypes.TypeEngine):
82-
__visit_name__: str = ...
71+
class ROWID(type_api.TypeEngine):...
8372

8473
class _OracleBoolean(sqltypes.Boolean):
8574
def get_dbapi_type(self, dbapi: Any): ...

0 commit comments

Comments
 (0)