Skip to content

Commit 1e6089d

Browse files
authored
chore: improve type hints (#1784)
1 parent 7f077c1 commit 1e6089d

File tree

1 file changed

+47
-64
lines changed

1 file changed

+47
-64
lines changed

tortoise/__init__.py

Lines changed: 47 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import asyncio
24
import importlib
35
import importlib.metadata as importlib_metadata
@@ -7,20 +9,9 @@
79
from copy import deepcopy
810
from inspect import isclass
911
from types import ModuleType
10-
from typing import (
11-
Callable,
12-
Coroutine,
13-
Dict,
14-
Iterable,
15-
List,
16-
Optional,
17-
Tuple,
18-
Type,
19-
Union,
20-
cast,
21-
)
12+
from typing import Any, Callable, Coroutine, Iterable, Type, cast
2213

23-
from pypika import Table
14+
from pypika import Query, Table
2415

2516
from tortoise.backends.base.client import BaseDBAsyncClient
2617
from tortoise.backends.base.config_generator import expand_db_url, generate_config
@@ -40,8 +31,8 @@
4031

4132

4233
class Tortoise:
43-
apps: Dict[str, Dict[str, Type["Model"]]] = {}
44-
table_name_generator: Optional[Callable[[Type["Model"]], str]] = None
34+
apps: dict[str, dict[str, Type["Model"]]] = {}
35+
table_name_generator: Callable[[Type["Model"]], str] | None = None
4536
_inited: bool = False
4637

4738
@classmethod
@@ -60,7 +51,7 @@ def get_connection(cls, connection_name: str) -> BaseDBAsyncClient:
6051
@classmethod
6152
def describe_model(
6253
cls, model: Type["Model"], serializable: bool = True
63-
) -> dict: # pragma: nocoverage
54+
) -> dict[str, Any]: # pragma: nocoverage
6455
"""
6556
Describes the given list of models or ALL registered models.
6657
@@ -85,8 +76,8 @@ def describe_model(
8576

8677
@classmethod
8778
def describe_models(
88-
cls, models: Optional[List[Type["Model"]]] = None, serializable: bool = True
89-
) -> Dict[str, dict]:
79+
cls, models: list[Type["Model"]] | None = None, serializable: bool = True
80+
) -> dict[str, dict[str, Any]]:
9081
"""
9182
Describes the given list of models or ALL registered models.
9283
@@ -142,7 +133,7 @@ def get_related_model(related_app_name: str, related_model_name: str) -> Type["M
142133
f" app '{related_app_name}'."
143134
)
144135

145-
def split_reference(reference: str) -> Tuple[str, str]:
136+
def split_reference(reference: str) -> tuple[str, str]:
146137
"""
147138
Validate, if reference follow the official naming conventions. Throws a
148139
ConfigurationError with a hopefully helpful message. If successful,
@@ -158,12 +149,9 @@ def split_reference(reference: str) -> Tuple[str, str]:
158149
return items[0], items[1]
159150

160151
def init_fk_o2o_field(model: Type["Model"], field: str, is_o2o=False) -> None:
161-
if is_o2o:
162-
fk_object: Union[OneToOneFieldInstance, ForeignKeyFieldInstance] = cast(
163-
OneToOneFieldInstance, model._meta.fields_map[field]
164-
)
165-
else:
166-
fk_object = cast(ForeignKeyFieldInstance, model._meta.fields_map[field])
152+
fk_object = cast(
153+
"OneToOneFieldInstance | ForeignKeyFieldInstance", model._meta.fields_map[field]
154+
)
167155
related_app_name, related_model_name = split_reference(fk_object.model_name)
168156
related_model = get_related_model(related_app_name, related_model_name)
169157

@@ -206,24 +194,24 @@ def init_fk_o2o_field(model: Type["Model"], field: str, is_o2o=False) -> None:
206194
f'backward relation "{backward_relation_name}" duplicates in'
207195
f" model {related_model_name}"
208196
)
209-
if is_o2o:
210-
fk_relation: Union[BackwardOneToOneRelation, BackwardFKRelation] = (
211-
BackwardOneToOneRelation(
212-
model,
213-
key_field,
214-
key_fk_object.source_field,
215-
null=True,
216-
description=fk_object.description,
217-
)
197+
198+
fk_relation = (
199+
BackwardOneToOneRelation(
200+
model,
201+
key_field,
202+
key_fk_object.source_field,
203+
null=True,
204+
description=fk_object.description,
218205
)
219-
else:
220-
fk_relation = BackwardFKRelation(
206+
if is_o2o
207+
else BackwardFKRelation(
221208
model,
222209
key_field,
223210
key_fk_object.source_field,
224211
null=fk_object.null,
225212
description=fk_object.description,
226213
)
214+
)
227215
fk_relation.to_field_instance = fk_object.to_field_instance # type:ignore
228216
related_model._meta.add_field(backward_relation_name, fk_relation)
229217
if is_o2o and fk_object.pk:
@@ -251,8 +239,7 @@ def init_fk_o2o_field(model: Type["Model"], field: str, is_o2o=False) -> None:
251239
m2m_object = cast(ManyToManyFieldInstance, model._meta.fields_map[field])
252240
if m2m_object._generated:
253241
continue
254-
backward_key = m2m_object.backward_key
255-
if not backward_key:
242+
if not (backward_key := m2m_object.backward_key):
256243
backward_key = f"{model._meta.db_table}_id"
257244
if backward_key == m2m_object.forward_key:
258245
backward_key = f"{model._meta.db_table}_rel_id"
@@ -264,8 +251,7 @@ def init_fk_o2o_field(model: Type["Model"], field: str, is_o2o=False) -> None:
264251

265252
m2m_object.related_model = related_model
266253

267-
backward_relation_name = m2m_object.related_name
268-
if not backward_relation_name:
254+
if not (backward_relation_name := m2m_object.related_name):
269255
backward_relation_name = m2m_object.related_name = (
270256
f"{model._meta.db_table}s"
271257
)
@@ -295,9 +281,7 @@ def init_fk_o2o_field(model: Type["Model"], field: str, is_o2o=False) -> None:
295281
related_model._meta.add_field(backward_relation_name, m2m_relation)
296282

297283
@classmethod
298-
def _discover_models(
299-
cls, models_path: Union[ModuleType, str], app_label: str
300-
) -> List[Type["Model"]]:
284+
def _discover_models(cls, models_path: ModuleType | str, app_label: str) -> list[Type["Model"]]:
301285
if isinstance(models_path, ModuleType):
302286
module = models_path
303287
else:
@@ -306,11 +290,11 @@ def _discover_models(
306290
except ImportError:
307291
raise ConfigurationError(f'Module "{models_path}" not found')
308292
discovered_models = []
309-
possible_models = getattr(module, "__models__", None)
310-
try:
311-
possible_models = [*possible_models] # type:ignore
312-
except TypeError:
313-
possible_models = None
293+
if possible_models := getattr(module, "__models__", None):
294+
try:
295+
possible_models = [*possible_models]
296+
except TypeError:
297+
possible_models = None
314298
if not possible_models:
315299
possible_models = [getattr(module, attr_name) for attr_name in dir(module)]
316300
for attr in possible_models:
@@ -326,7 +310,7 @@ def _discover_models(
326310
@classmethod
327311
def init_models(
328312
cls,
329-
models_paths: Iterable[Union[ModuleType, str]],
313+
models_paths: Iterable[ModuleType | str],
330314
app_label: str,
331315
_init_relations: bool = True,
332316
) -> None:
@@ -342,7 +326,7 @@ def init_models(
342326
343327
:raises ConfigurationError: If models are invalid.
344328
"""
345-
app_models: List[Type[Model]] = []
329+
app_models: list[Type[Model]] = []
346330
for models_path in models_paths:
347331
app_models += cls._discover_models(models_path, app_label)
348332

@@ -352,7 +336,7 @@ def init_models(
352336
cls._init_relations()
353337

354338
@classmethod
355-
def _init_apps(cls, apps_config: dict) -> None:
339+
def _init_apps(cls, apps_config: dict[str, dict[str, Any]]) -> None:
356340
for name, info in apps_config.items():
357341
try:
358342
connections.get(info.get("default_connection", "default"))
@@ -396,23 +380,23 @@ def _build_initial_querysets(cls) -> None:
396380
model._meta.finalise_model()
397381
model._meta.basetable = Table(name=model._meta.db_table, schema=model._meta.schema)
398382
basequery = model._meta.db.query_class.from_(model._meta.basetable)
399-
model._meta.basequery = basequery # type:ignore[assignment]
400-
model._meta.basequery_all_fields = basequery.select(
401-
*model._meta.db_fields
402-
) # type:ignore[assignment]
383+
model._meta.basequery = cast(Query, basequery)
384+
model._meta.basequery_all_fields = cast(
385+
Query, basequery.select(*model._meta.db_fields)
386+
)
403387

404388
@classmethod
405389
async def init(
406390
cls,
407-
config: Optional[dict] = None,
408-
config_file: Optional[str] = None,
391+
config: dict[str, Any] | None = None,
392+
config_file: str | None = None,
409393
_create_db: bool = False,
410-
db_url: Optional[str] = None,
411-
modules: Optional[Dict[str, Iterable[Union[str, ModuleType]]]] = None,
394+
db_url: str | None = None,
395+
modules: dict[str, Iterable[str | ModuleType]] | None = None,
412396
use_tz: bool = False,
413397
timezone: str = "UTC",
414-
routers: Optional[List[Union[str, Type]]] = None,
415-
table_name_generator: Optional[Callable[[Type["Model"]], str]] = None,
398+
routers: list[str | type] | None = None,
399+
table_name_generator: Callable[[Type["Model"]], str] | None = None,
416400
) -> None:
417401
"""
418402
Sets up Tortoise-ORM.
@@ -516,8 +500,7 @@ async def init(
516500
for name, info in connections_config.items():
517501
if isinstance(info, str):
518502
info = expand_db_url(info)
519-
password = info.get("credentials", {}).get("password")
520-
if password:
503+
if password := info.get("credentials", {}).get("password"):
521504
passwords.append(password)
522505

523506
str_connection_config = str(connections_config)
@@ -542,7 +525,7 @@ async def init(
542525
cls._inited = True
543526

544527
@classmethod
545-
def _init_routers(cls, routers: Optional[List[Union[str, type]]] = None) -> None:
528+
def _init_routers(cls, routers: list[str | type] | None = None) -> None:
546529
from tortoise.router import router
547530

548531
routers = routers or []

0 commit comments

Comments
 (0)