Skip to content

Commit 916d6cb

Browse files
authored
refactor: improve type hints (#1779)
1 parent 9758bd9 commit 916d6cb

File tree

5 files changed

+61
-71
lines changed

5 files changed

+61
-71
lines changed

tortoise/__init__.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -488,25 +488,26 @@ async def init(
488488

489489
if config_file:
490490
config = cls._get_config_from_config_file(config_file)
491-
492-
if db_url:
491+
elif db_url:
493492
if not modules:
494493
raise ConfigurationError('You must specify "db_url" and "modules" together')
495494
config = generate_config(db_url, modules)
495+
else:
496+
assert config is not None # To improve type hints
496497

497498
try:
498-
connections_config = config["connections"] # type: ignore
499+
connections_config = config["connections"]
499500
except KeyError:
500501
raise ConfigurationError('Config must define "connections" section')
501502

502503
try:
503-
apps_config = config["apps"] # type: ignore
504+
apps_config = config["apps"]
504505
except KeyError:
505506
raise ConfigurationError('Config must define "apps" section')
506507

507-
use_tz = config.get("use_tz", use_tz) # type: ignore
508-
timezone = config.get("timezone", timezone) # type: ignore
509-
routers = config.get("routers", routers) # type: ignore
508+
use_tz = config.get("use_tz", use_tz)
509+
timezone = config.get("timezone", timezone)
510+
routers = config.get("routers", routers)
510511

511512
cls.table_name_generator = table_name_generator
512513

tortoise/contrib/mysql/indexes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def __init__(
1414
fields: Optional[Tuple[str, ...]] = None,
1515
name: Optional[str] = None,
1616
parser_name: Optional[str] = None,
17-
):
17+
) -> None:
1818
super().__init__(*expressions, fields=fields, name=name)
1919
if parser_name:
2020
self.extra = f" WITH PARSER {parser_name}"

tortoise/contrib/postgres/indexes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def __init__(
1616
fields: Optional[Tuple[str, ...]] = None,
1717
name: Optional[str] = None,
1818
condition: Optional[dict] = None,
19-
):
19+
) -> None:
2020
super().__init__(*expressions, fields=fields, name=name)
2121
if condition:
2222
cond = " WHERE "

tortoise/expressions.py

Lines changed: 46 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,7 @@
44
from dataclasses import dataclass
55
from dataclasses import field as dataclass_field
66
from enum import Enum, auto
7-
from typing import (
8-
TYPE_CHECKING,
9-
Any,
10-
Dict,
11-
Iterator,
12-
List,
13-
Optional,
14-
Tuple,
15-
Type,
16-
Union,
17-
cast,
18-
)
7+
from typing import TYPE_CHECKING, Any, Iterator, Type, cast
198

209
from pypika import Case as PypikaCase
2110
from pypika import Field as PypikaField
@@ -48,15 +37,15 @@
4837
class ResolveContext:
4938
model: Type["Model"]
5039
table: Table
51-
annotations: Dict[str, Any]
52-
custom_filters: Dict[str, FilterInfoDict]
40+
annotations: dict[str, Any]
41+
custom_filters: dict[str, FilterInfoDict]
5342

5443

5544
@dataclass
5645
class ResolveResult:
5746
term: Term
58-
joins: List[TableCriterionTuple] = dataclass_field(default_factory=list)
59-
output_field: Optional[Field] = None
47+
joins: list[TableCriterionTuple] = dataclass_field(default_factory=list)
48+
output_field: Field | None = None
6049

6150

6251
class Expression:
@@ -93,25 +82,25 @@ class CombinedExpression(Expression):
9382
def __init__(self, left: Expression, connector: Connector, right: Any) -> None:
9483
self.left = left
9584
self.connector = connector
96-
self.right: Expression
97-
if isinstance(right, Expression):
98-
self.right = right
99-
else:
100-
self.right = Value(right)
85+
self.right = right if isinstance(right, Expression) else Value(right)
10186

10287
def resolve(self, resolve_context: ResolveContext) -> ResolveResult:
10388
left = self.left.resolve(resolve_context)
10489
right = self.right.resolve(resolve_context)
90+
left_output_field, right_output_field = left.output_field, right.output_field # type: ignore
10591

106-
if left.output_field and right.output_field: # type: ignore
107-
if type(left.output_field) is not type(right.output_field): # type: ignore
108-
raise FieldError("Cannot use arithmetic expression between different field type")
92+
if (
93+
left_output_field
94+
and right_output_field
95+
and type(left_output_field) is not type(right_output_field)
96+
):
97+
raise FieldError("Cannot use arithmetic expression between different field type")
10998

11099
operator_func = getattr(operator, self.connector.name)
111100
return ResolveResult(
112101
term=operator_func(left.term, right.term),
113102
joins=list(set(left.joins + right.joins)), # dedup joins
114-
output_field=right.output_field or left.output_field, # type: ignore
103+
output_field=right_output_field or left_output_field,
115104
)
116105

117106

@@ -129,7 +118,7 @@ def __init__(self, name: str) -> None:
129118

130119
def resolve(self, resolve_context: ResolveContext) -> ResolveResult:
131120
term: Term = PypikaField(self.name)
132-
joins: List[TableCriterionTuple] = []
121+
joins: list[TableCriterionTuple] = []
133122
output_field = None
134123
if self.name.split("__")[0] in resolve_context.model._meta.fetch_fields:
135124
# field in the format of "related_field__field" or "related_field__another_rel_field__field"
@@ -158,7 +147,7 @@ def resolve(self, resolve_context: ResolveContext) -> ResolveResult:
158147
except KeyError:
159148
raise FieldError(
160149
f"There is no non-virtual field {self.name} on Model {resolve_context.model.__name__}"
161-
)
150+
) from None
162151
return ResolveResult(term=term, output_field=output_field, joins=joins)
163152

164153
def _combine(self, other: Any, connector: Connector, right_hand: bool) -> CombinedExpression:
@@ -260,9 +249,9 @@ def __init__(self, *args: "Q", join_type: str = AND, **kwargs: Any) -> None:
260249
if not all(isinstance(node, Q) for node in args):
261250
raise OperationalError("All ordered arguments must be Q nodes")
262251
#: Contains the sub-Q's that this Q is made up of
263-
self.children: Tuple[Q, ...] = args
252+
self.children: tuple[Q, ...] = args
264253
#: Contains the filters applied to this Q
265-
self.filters: Dict[str, FilterInfoDict] = kwargs
254+
self.filters: dict[str, FilterInfoDict] = kwargs
266255
if join_type not in {self.AND, self.OR}:
267256
raise OperationalError("join_type must be AND or OR")
268257
#: Specifies if this Q does an AND or OR on its children
@@ -357,7 +346,7 @@ def _resolve_custom_kwarg(
357346

358347
def _process_filter_kwarg(
359348
self, model: "Type[Model]", key: str, value: Any, table: Table
360-
) -> Tuple[Criterion, Optional[Tuple[Table, Criterion]]]:
349+
) -> tuple[Criterion, tuple[Table, Criterion] | None]:
361350
join = None
362351

363352
if value is None and f"{key}__isnull" in model._meta.filters:
@@ -408,7 +397,7 @@ def _resolve_regular_kwarg(
408397

409398
def _get_actual_filter_params(
410399
self, resolve_context: ResolveContext, key: str, value: Table | FilterInfoDict
411-
) -> Tuple[str, Any]:
400+
) -> tuple[str, Any]:
412401
filter_key = key
413402
if (
414403
key in resolve_context.model._meta.fk_fields
@@ -513,13 +502,13 @@ class Function(Expression):
513502
populate_field_object = False
514503

515504
def __init__(
516-
self, field: Union[str, F, CombinedExpression, "Function"], *default_values: Any
505+
self, field: str | F | CombinedExpression | "Function", *default_values: Any
517506
) -> None:
518507
self.field = field
519-
self.field_object: "Optional[Field]" = None
508+
self.field_object: "Field | None" = None
520509
self.default_values = default_values
521510

522-
def _get_function_field(self, field: Union[Term, str], *default_values) -> PypikaFunction:
511+
def _get_function_field(self, field: Term | str, *default_values) -> PypikaFunction:
523512
return self.database_func(field, *default_values) # type:ignore[arg-type]
524513

525514
def _resolve_nested_field(self, resolve_context: ResolveContext, field: str) -> ResolveResult:
@@ -549,26 +538,22 @@ def resolve(self, resolve_context: ResolveContext) -> ResolveResult:
549538

550539
default_values = self._resolve_default_values(resolve_context)
551540

552-
res = None
553-
if isinstance(self.field, str):
554-
function_arg = self._resolve_nested_field(resolve_context, self.field)
555-
term = self._get_function_field(function_arg.term, *default_values)
556-
res = ResolveResult(
557-
term=term,
558-
joins=function_arg.joins,
559-
output_field=function_arg.output_field, # type: ignore
560-
)
561-
else:
562-
function_arg = self.field.resolve(resolve_context)
563-
term = self._get_function_field(function_arg.term, *default_values)
564-
res = ResolveResult(
565-
term=term,
566-
joins=function_arg.joins,
567-
output_field=function_arg.output_field, # type: ignore
568-
)
541+
function_arg = (
542+
self._resolve_nested_field(resolve_context, self.field)
543+
if isinstance(self.field, str)
544+
else self.field.resolve(resolve_context)
545+
)
546+
term = self._get_function_field(function_arg.term, *default_values)
547+
res = ResolveResult(
548+
term=term,
549+
joins=function_arg.joins,
550+
output_field=function_arg.output_field, # type:ignore[call-overload]
551+
)
569552

570-
if self.populate_field_object and res.output_field: # type: ignore
571-
self.field_object = res.output_field # type: ignore
553+
if self.populate_field_object and (
554+
res_output_field := res.output_field # type:ignore[call-overload]
555+
):
556+
self.field_object = res_output_field
572557

573558
return res
574559

@@ -586,17 +571,17 @@ class Aggregate(Function):
586571

587572
def __init__(
588573
self,
589-
field: Union[str, F, CombinedExpression],
574+
field: str | F | CombinedExpression,
590575
*default_values: Any,
591576
distinct: bool = False,
592-
_filter: Optional[Q] = None,
577+
_filter: Q | None = None,
593578
) -> None:
594579
super().__init__(field, *default_values)
595580
self.distinct = distinct
596581
self.filter = _filter
597582

598583
def _get_function_field( # type:ignore[override]
599-
self, field: Union[ArithmeticExpression, PypikaField, str], *default_values
584+
self, field: ArithmeticExpression | PypikaField | str, *default_values
600585
) -> DistinctOptionFunction:
601586
function = cast(DistinctOptionFunction, self.database_func(field, *default_values))
602587
if self.distinct:
@@ -634,7 +619,7 @@ class When(Expression):
634619
def __init__(
635620
self,
636621
*args: Q,
637-
then: Union[str, F, CombinedExpression, Function],
622+
then: str | F | CombinedExpression | Function,
638623
negate: bool = False,
639624
**kwargs: Any,
640625
) -> None:
@@ -643,7 +628,7 @@ def __init__(
643628
self.negate = negate
644629
self.kwargs = kwargs
645630

646-
def _resolve_q_objects(self) -> List[Q]:
631+
def _resolve_q_objects(self) -> list[Q]:
647632
q_objects = []
648633
for arg in self.args:
649634
if not isinstance(arg, Q):
@@ -684,7 +669,9 @@ class Case(Expression):
684669
"""
685670

686671
def __init__(
687-
self, *args: When, default: Union[str, F, CombinedExpression, Function, None] = None
672+
self,
673+
*args: When,
674+
default: str | F | CombinedExpression | Function | None = None,
688675
) -> None:
689676
self.args = args
690677
self.default = default

tortoise/indexes.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def __init__(
1818
*expressions: Term,
1919
fields: Optional[Tuple[str, ...]] = None,
2020
name: Optional[str] = None,
21-
):
21+
) -> None:
2222
"""
2323
All kinds of index parent class, default is BTreeIndex.
2424
@@ -38,7 +38,9 @@ def __init__(
3838
self.expressions = expressions
3939
self.extra = ""
4040

41-
def get_sql(self, schema_generator: "BaseSchemaGenerator", model: "Type[Model]", safe: bool) -> str:
41+
def get_sql(
42+
self, schema_generator: "BaseSchemaGenerator", model: "Type[Model]", safe: bool
43+
) -> str:
4244
if self.fields:
4345
fields = ", ".join(schema_generator.quote(f) for f in self.fields)
4446
else:
@@ -65,7 +67,7 @@ def __init__(
6567
fields: Optional[Tuple[str, ...]] = None,
6668
name: Optional[str] = None,
6769
condition: Optional[dict] = None,
68-
):
70+
) -> None:
6971
super().__init__(*expressions, fields=fields, name=name)
7072
if condition:
7173
cond = " WHERE "

0 commit comments

Comments
 (0)