@@ -57,11 +57,20 @@ def get_transform(self, name):
57
57
transform = super ().get_transform (name )
58
58
if transform :
59
59
return transform
60
- return KeyTransformFactory (name , self .base_field )
60
+ return KeyTransformFactory (name , self )
61
+
62
+
63
+ class ProcessRHSMixin :
64
+ def process_rhs (self , compiler , connection ):
65
+ if isinstance (self .lhs , KeyTransform ):
66
+ get_db_prep_value = self .lhs ._lhs .output_field .get_db_prep_value
67
+ else :
68
+ get_db_prep_value = self .lhs .output_field .get_db_prep_value
69
+ return None , [get_db_prep_value (v , connection , prepared = True ) for v in self .rhs ]
61
70
62
71
63
72
@EmbeddedModelArrayField .register_lookup
64
- class EMFArrayExact (EMFExact ):
73
+ class EMFArrayExact (EMFExact , ProcessRHSMixin ):
65
74
def as_mql (self , compiler , connection ):
66
75
lhs_mql = process_lhs (self , compiler , connection )
67
76
value = process_rhs (self , compiler , connection )
@@ -103,15 +112,61 @@ def as_mql(self, compiler, connection):
103
112
}
104
113
105
114
115
+ @EmbeddedModelArrayField .register_lookup
116
+ class ArrayOverlap (EMFExact , ProcessRHSMixin ):
117
+ lookup_name = "overlap"
118
+
119
+ def as_mql (self , compiler , connection ):
120
+ lhs_mql = process_lhs (self , compiler , connection )
121
+ values = process_rhs (self , compiler , connection )
122
+ if isinstance (self .lhs , KeyTransform ):
123
+ lhs_mql , inner_lhs_mql = lhs_mql
124
+ return {
125
+ "$anyElementTrue" : {
126
+ "$ifNull" : [
127
+ {
128
+ "$map" : {
129
+ "input" : lhs_mql ,
130
+ "as" : "item" ,
131
+ "in" : {"$in" : [inner_lhs_mql , values ]},
132
+ }
133
+ },
134
+ [],
135
+ ]
136
+ }
137
+ }
138
+ conditions = []
139
+ inner_lhs_mql = "$$item"
140
+ for value in values :
141
+ if isinstance (value , models .Model ):
142
+ value , emf_data = self .model_to_dict (value )
143
+ # Get conditions for any nested EmbeddedModelFields.
144
+ conditions .append ({"$and" : self .get_conditions ({inner_lhs_mql : (value , emf_data )})})
145
+ return {
146
+ "$anyElementTrue" : {
147
+ "$ifNull" : [
148
+ {
149
+ "$map" : {
150
+ "input" : lhs_mql ,
151
+ "as" : "item" ,
152
+ "in" : {"$or" : conditions },
153
+ }
154
+ },
155
+ [],
156
+ ]
157
+ }
158
+ }
159
+
160
+
106
161
class KeyTransform (Transform ):
107
162
# it should be different class than EMF keytransform even most of the methods are equal.
108
- def __init__ (self , key_name , base_field , * args , ** kwargs ):
163
+ def __init__ (self , key_name , array_field , * args , ** kwargs ):
109
164
super ().__init__ (* args , ** kwargs )
110
- self .base_field = base_field
165
+ self .array_field = array_field
111
166
self .key_name = key_name
112
167
# The iteration items begins from the base_field, a virtual column with
113
168
# base field output type is created.
114
- column_target = base_field .clone ()
169
+ column_target = array_field . base_field . embedded_model . _meta . get_field ( key_name ) .clone ()
115
170
column_name = f"$item.{ key_name } "
116
171
column_target .db_column = column_name
117
172
column_target .set_attributes_from_name (column_name )
@@ -134,7 +189,7 @@ def _get_missing_field_or_lookup_exception(self, lhs, name):
134
189
suggestion = "."
135
190
raise FieldDoesNotExist (
136
191
f"Unsupported lookup '{ name } ' for "
137
- f"{ self .base_field .__class__ .__name__ } '{ self .base_field .name } '"
192
+ f"{ self .array_field . base_field .__class__ .__name__ } '{ self . array_field .base_field .name } '"
138
193
f"{ suggestion } "
139
194
)
140
195
@@ -147,7 +202,9 @@ def get_transform(self, name):
147
202
transform = (
148
203
self ._lhs .get_transform (name )
149
204
if isinstance (self ._lhs , Transform )
150
- else self .base_field .embedded_model ._meta .get_field (self .key_name ).get_transform (name )
205
+ else self .array_field .base_field .embedded_model ._meta .get_field (
206
+ self .key_name
207
+ ).get_transform (name )
151
208
)
152
209
if transform :
153
210
self ._sub_transform = transform
@@ -163,7 +220,7 @@ def as_mql(self, compiler, connection):
163
220
164
221
@property
165
222
def output_field (self ):
166
- return EmbeddedModelArrayField ( self .base_field )
223
+ return self .array_field
167
224
168
225
169
226
class KeyTransformFactory :
0 commit comments