@@ -58,8 +58,25 @@ def get_transform(self, name):
58
58
return KeyTransformFactory (name , self )
59
59
60
60
61
+ class EMFArrayRHSMixin :
62
+ def process_rhs (self , compiler , connection ):
63
+ values = self .rhs
64
+ if not self .get_db_prep_lookup_value_is_iterable :
65
+ values = [values ]
66
+ # Compute how to serialize each value based on the query target.
67
+ # If querying a subfield inside the array (i.e., a nested KeyTransform), use the output
68
+ # field of the subfield. Otherwise, use the base field of the array itself.
69
+ if isinstance (self .lhs , KeyTransform ):
70
+ get_db_prep_value = self .lhs ._lhs .output_field .get_db_prep_value
71
+ else :
72
+ get_db_prep_value = self .lhs .output_field .base_field .get_db_prep_value
73
+ return None , [get_db_prep_value (v , connection , prepared = True ) for v in values ]
74
+
75
+
61
76
@EmbeddedModelArrayField .register_lookup
62
- class EMFArrayExact (lookups .Exact ):
77
+ class EMFArrayExact (EMFArrayRHSMixin , lookups .Exact ):
78
+ get_db_prep_lookup_value_is_iterable = False
79
+
63
80
def as_mql (self , compiler , connection ):
64
81
if not isinstance (self .lhs , KeyTransform ):
65
82
raise ValueError ("error" )
@@ -82,22 +99,9 @@ def as_mql(self, compiler, connection):
82
99
83
100
84
101
@EmbeddedModelArrayField .register_lookup
85
- class ArrayOverlap (Lookup ):
102
+ class ArrayOverlap (EMFArrayRHSMixin , Lookup ):
86
103
lookup_name = "overlap"
87
- get_db_prep_lookup_value_is_iterable = True
88
-
89
- def process_rhs (self , compiler , connection ):
90
- values = self .rhs
91
- if self .get_db_prep_lookup_value_is_iterable :
92
- values = [values ]
93
- # Compute how to serialize each value based on the query target.
94
- # If querying a subfield inside the array (i.e., a nested KeyTransform), use the output
95
- # field of the subfield. Otherwise, use the base field of the array itself.
96
- if isinstance (self .lhs , KeyTransform ):
97
- get_db_prep_value = self .lhs ._lhs .output_field .get_db_prep_value
98
- else :
99
- get_db_prep_value = self .lhs .output_field .base_field .get_db_prep_value
100
- return None , [get_db_prep_value (v , connection , prepared = True ) for v in values ]
104
+ get_db_prep_lookup_value_is_iterable = False
101
105
102
106
def as_mql (self , compiler , connection ):
103
107
# Querying a subfield within the array elements (via nested KeyTransform).
0 commit comments