44from dataclasses import dataclass
55from dataclasses import field as dataclass_field
66from enum import Enum , auto
7- from typing import (
8- TYPE_CHECKING ,
9- Any ,
10- Dict ,
11- Iterator ,
12- List ,
13- Optional ,
14- Tuple ,
15- Type ,
16- Union ,
17- cast ,
18- )
7+ from typing import TYPE_CHECKING , Any , Iterator , Type , cast
198
209from pypika import Case as PypikaCase
2110from pypika import Field as PypikaField
4837class ResolveContext :
4938 model : Type ["Model" ]
5039 table : Table
51- annotations : Dict [str , Any ]
52- custom_filters : Dict [str , FilterInfoDict ]
40+ annotations : dict [str , Any ]
41+ custom_filters : dict [str , FilterInfoDict ]
5342
5443
5544@dataclass
5645class ResolveResult :
5746 term : Term
58- joins : List [TableCriterionTuple ] = dataclass_field (default_factory = list )
59- output_field : Optional [ Field ] = None
47+ joins : list [TableCriterionTuple ] = dataclass_field (default_factory = list )
48+ output_field : Field | None = None
6049
6150
6251class Expression :
@@ -93,25 +82,25 @@ class CombinedExpression(Expression):
9382 def __init__ (self , left : Expression , connector : Connector , right : Any ) -> None :
9483 self .left = left
9584 self .connector = connector
96- self .right : Expression
97- if isinstance (right , Expression ):
98- self .right = right
99- else :
100- self .right = Value (right )
85+ self .right = right if isinstance (right , Expression ) else Value (right )
10186
10287 def resolve (self , resolve_context : ResolveContext ) -> ResolveResult :
10388 left = self .left .resolve (resolve_context )
10489 right = self .right .resolve (resolve_context )
90+ left_output_field , right_output_field = left .output_field , right .output_field # type: ignore
10591
106- if left .output_field and right .output_field : # type: ignore
107- if type (left .output_field ) is not type (right .output_field ): # type: ignore
108- raise FieldError ("Cannot use arithmetic expression between different field type" )
92+ if (
93+ left_output_field
94+ and right_output_field
95+ and type (left_output_field ) is not type (right_output_field )
96+ ):
97+ raise FieldError ("Cannot use arithmetic expression between different field type" )
10998
11099 operator_func = getattr (operator , self .connector .name )
111100 return ResolveResult (
112101 term = operator_func (left .term , right .term ),
113102 joins = list (set (left .joins + right .joins )), # dedup joins
114- output_field = right . output_field or left . output_field , # type: ignore
103+ output_field = right_output_field or left_output_field ,
115104 )
116105
117106
@@ -129,7 +118,7 @@ def __init__(self, name: str) -> None:
129118
130119 def resolve (self , resolve_context : ResolveContext ) -> ResolveResult :
131120 term : Term = PypikaField (self .name )
132- joins : List [TableCriterionTuple ] = []
121+ joins : list [TableCriterionTuple ] = []
133122 output_field = None
134123 if self .name .split ("__" )[0 ] in resolve_context .model ._meta .fetch_fields :
135124 # field in the format of "related_field__field" or "related_field__another_rel_field__field"
@@ -158,7 +147,7 @@ def resolve(self, resolve_context: ResolveContext) -> ResolveResult:
158147 except KeyError :
159148 raise FieldError (
160149 f"There is no non-virtual field { self .name } on Model { resolve_context .model .__name__ } "
161- )
150+ ) from None
162151 return ResolveResult (term = term , output_field = output_field , joins = joins )
163152
164153 def _combine (self , other : Any , connector : Connector , right_hand : bool ) -> CombinedExpression :
@@ -260,9 +249,9 @@ def __init__(self, *args: "Q", join_type: str = AND, **kwargs: Any) -> None:
260249 if not all (isinstance (node , Q ) for node in args ):
261250 raise OperationalError ("All ordered arguments must be Q nodes" )
262251 #: Contains the sub-Q's that this Q is made up of
263- self .children : Tuple [Q , ...] = args
252+ self .children : tuple [Q , ...] = args
264253 #: Contains the filters applied to this Q
265- self .filters : Dict [str , FilterInfoDict ] = kwargs
254+ self .filters : dict [str , FilterInfoDict ] = kwargs
266255 if join_type not in {self .AND , self .OR }:
267256 raise OperationalError ("join_type must be AND or OR" )
268257 #: Specifies if this Q does an AND or OR on its children
@@ -357,7 +346,7 @@ def _resolve_custom_kwarg(
357346
358347 def _process_filter_kwarg (
359348 self , model : "Type[Model]" , key : str , value : Any , table : Table
360- ) -> Tuple [Criterion , Optional [ Tuple [ Table , Criterion ]] ]:
349+ ) -> tuple [Criterion , tuple [ Table , Criterion ] | None ]:
361350 join = None
362351
363352 if value is None and f"{ key } __isnull" in model ._meta .filters :
@@ -408,7 +397,7 @@ def _resolve_regular_kwarg(
408397
409398 def _get_actual_filter_params (
410399 self , resolve_context : ResolveContext , key : str , value : Table | FilterInfoDict
411- ) -> Tuple [str , Any ]:
400+ ) -> tuple [str , Any ]:
412401 filter_key = key
413402 if (
414403 key in resolve_context .model ._meta .fk_fields
@@ -513,13 +502,13 @@ class Function(Expression):
513502 populate_field_object = False
514503
515504 def __init__ (
516- self , field : Union [ str , F , CombinedExpression , "Function" ] , * default_values : Any
505+ self , field : str | F | CombinedExpression | "Function" , * default_values : Any
517506 ) -> None :
518507 self .field = field
519- self .field_object : "Optional[ Field] " = None
508+ self .field_object : "Field | None " = None
520509 self .default_values = default_values
521510
522- def _get_function_field (self , field : Union [ Term , str ] , * default_values ) -> PypikaFunction :
511+ def _get_function_field (self , field : Term | str , * default_values ) -> PypikaFunction :
523512 return self .database_func (field , * default_values ) # type:ignore[arg-type]
524513
525514 def _resolve_nested_field (self , resolve_context : ResolveContext , field : str ) -> ResolveResult :
@@ -549,26 +538,22 @@ def resolve(self, resolve_context: ResolveContext) -> ResolveResult:
549538
550539 default_values = self ._resolve_default_values (resolve_context )
551540
552- res = None
553- if isinstance (self .field , str ):
554- function_arg = self ._resolve_nested_field (resolve_context , self .field )
555- term = self ._get_function_field (function_arg .term , * default_values )
556- res = ResolveResult (
557- term = term ,
558- joins = function_arg .joins ,
559- output_field = function_arg .output_field , # type: ignore
560- )
561- else :
562- function_arg = self .field .resolve (resolve_context )
563- term = self ._get_function_field (function_arg .term , * default_values )
564- res = ResolveResult (
565- term = term ,
566- joins = function_arg .joins ,
567- output_field = function_arg .output_field , # type: ignore
568- )
541+ function_arg = (
542+ self ._resolve_nested_field (resolve_context , self .field )
543+ if isinstance (self .field , str )
544+ else self .field .resolve (resolve_context )
545+ )
546+ term = self ._get_function_field (function_arg .term , * default_values )
547+ res = ResolveResult (
548+ term = term ,
549+ joins = function_arg .joins ,
550+ output_field = function_arg .output_field , # type:ignore[call-overload]
551+ )
569552
570- if self .populate_field_object and res .output_field : # type: ignore
571- self .field_object = res .output_field # type: ignore
553+ if self .populate_field_object and (
554+ res_output_field := res .output_field # type:ignore[call-overload]
555+ ):
556+ self .field_object = res_output_field
572557
573558 return res
574559
@@ -586,17 +571,17 @@ class Aggregate(Function):
586571
587572 def __init__ (
588573 self ,
589- field : Union [ str , F , CombinedExpression ] ,
574+ field : str | F | CombinedExpression ,
590575 * default_values : Any ,
591576 distinct : bool = False ,
592- _filter : Optional [ Q ] = None ,
577+ _filter : Q | None = None ,
593578 ) -> None :
594579 super ().__init__ (field , * default_values )
595580 self .distinct = distinct
596581 self .filter = _filter
597582
598583 def _get_function_field ( # type:ignore[override]
599- self , field : Union [ ArithmeticExpression , PypikaField , str ] , * default_values
584+ self , field : ArithmeticExpression | PypikaField | str , * default_values
600585 ) -> DistinctOptionFunction :
601586 function = cast (DistinctOptionFunction , self .database_func (field , * default_values ))
602587 if self .distinct :
@@ -634,7 +619,7 @@ class When(Expression):
634619 def __init__ (
635620 self ,
636621 * args : Q ,
637- then : Union [ str , F , CombinedExpression , Function ] ,
622+ then : str | F | CombinedExpression | Function ,
638623 negate : bool = False ,
639624 ** kwargs : Any ,
640625 ) -> None :
@@ -643,7 +628,7 @@ def __init__(
643628 self .negate = negate
644629 self .kwargs = kwargs
645630
646- def _resolve_q_objects (self ) -> List [Q ]:
631+ def _resolve_q_objects (self ) -> list [Q ]:
647632 q_objects = []
648633 for arg in self .args :
649634 if not isinstance (arg , Q ):
@@ -684,7 +669,9 @@ class Case(Expression):
684669 """
685670
686671 def __init__ (
687- self , * args : When , default : Union [str , F , CombinedExpression , Function , None ] = None
672+ self ,
673+ * args : When ,
674+ default : str | F | CombinedExpression | Function | None = None ,
688675 ) -> None :
689676 self .args = args
690677 self .default = default
0 commit comments