4
4
from django .db import models
5
5
from django .db .models import Field
6
6
from django .db .models .expressions import Col
7
- from django .db .models .lookups import Transform
7
+ from django .db .models .lookups import Lookup , Transform
8
8
9
9
from .. import forms
10
10
from ..query_utils import process_lhs , process_rhs
11
11
from . import EmbeddedModelField
12
12
from .array import ArrayField
13
- from .embedded_model import EMFExact
13
+ from .embedded_model import EMFExact , EMFMixin
14
14
15
15
16
16
class EmbeddedModelArrayField (ArrayField ):
@@ -60,17 +60,8 @@ def get_transform(self, name):
60
60
return KeyTransformFactory (name , self )
61
61
62
62
63
- class ProcessRHSMixin :
64
- def process_rhs (self , compiler , connection ):
65
- if isinstance (self .lhs , KeyTransform ):
66
- get_db_prep_value = self .lhs ._lhs .output_field .get_db_prep_value
67
- else :
68
- get_db_prep_value = self .lhs .output_field .get_db_prep_value
69
- return None , [get_db_prep_value (v , connection , prepared = True ) for v in self .rhs ]
70
-
71
-
72
63
@EmbeddedModelArrayField .register_lookup
73
- class EMFArrayExact (EMFExact , ProcessRHSMixin ):
64
+ class EMFArrayExact (EMFExact ):
74
65
def as_mql (self , compiler , connection ):
75
66
lhs_mql = process_lhs (self , compiler , connection )
76
67
value = process_rhs (self , compiler , connection )
@@ -113,12 +104,29 @@ def as_mql(self, compiler, connection):
113
104
114
105
115
106
@EmbeddedModelArrayField .register_lookup
116
- class ArrayOverlap (EMFExact , ProcessRHSMixin ):
107
+ class ArrayOverlap (EMFMixin , Lookup ):
117
108
lookup_name = "overlap"
109
+ get_db_prep_lookup_value_is_iterable = True
110
+
111
+ def process_rhs (self , compiler , connection ):
112
+ values = self .rhs
113
+ if self .get_db_prep_lookup_value_is_iterable :
114
+ values = [values ]
115
+ # Compute how to serialize each value based on the query target.
116
+ # If querying a subfield inside the array (i.e., a nested KeyTransform), use the output
117
+ # field of the subfield. Otherwise, use the base field of the array itself.
118
+ if isinstance (self .lhs , KeyTransform ):
119
+ get_db_prep_value = self .lhs ._lhs .output_field .get_db_prep_value
120
+ else :
121
+ get_db_prep_value = self .lhs .output_field .base_field .get_db_prep_value
122
+ return None , [get_db_prep_value (v , connection , prepared = True ) for v in values ]
118
123
119
124
def as_mql (self , compiler , connection ):
120
125
lhs_mql = process_lhs (self , compiler , connection )
121
126
values = process_rhs (self , compiler , connection )
127
+ # Querying a subfield within the array elements (via nested KeyTransform).
128
+ # Replicates MongoDB's implicit ANY-match by mapping over the array and applying
129
+ # `$in` on the subfield.
122
130
if isinstance (self .lhs , KeyTransform ):
123
131
lhs_mql , inner_lhs_mql = lhs_mql
124
132
return {
@@ -137,11 +145,12 @@ def as_mql(self, compiler, connection):
137
145
}
138
146
conditions = []
139
147
inner_lhs_mql = "$$item"
148
+ # Querying full embedded documents in the array.
149
+ # Builds `$or` conditions and maps them over the array to match any full document.
140
150
for value in values :
141
- if isinstance (value , models .Model ):
142
- value , emf_data = self .model_to_dict (value )
143
- # Get conditions for any nested EmbeddedModelFields.
144
- conditions .append ({"$and" : self .get_conditions ({inner_lhs_mql : (value , emf_data )})})
151
+ value , emf_data = self .model_to_dict (value )
152
+ # Get conditions for any nested EmbeddedModelFields.
153
+ conditions .append ({"$and" : self .get_conditions ({inner_lhs_mql : (value , emf_data )})})
145
154
return {
146
155
"$anyElementTrue" : {
147
156
"$ifNull" : [
0 commit comments