@@ -70,31 +70,53 @@ def as_mql(self, compiler, connection):
70
70
if isinstance (value , models .Model ):
71
71
value , emf_data = self .model_to_dict (value )
72
72
# Get conditions for any nested EmbeddedModelFields.
73
- conditions = self .get_conditions ({"$$item" : (value , emf_data )})
73
+ conditions = self .get_conditions ({lhs_mql [ 1 ] : (value , emf_data )})
74
74
return {
75
75
"$anyElementTrue" : {
76
- "$map" : {"input" : lhs_mql , "as" : "item" , "in" : {"$and" : conditions }}
76
+ "$ifNull" : [
77
+ {
78
+ "$map" : {
79
+ "input" : lhs_mql [0 ],
80
+ "as" : "item" ,
81
+ "in" : {"$and" : conditions },
82
+ }
83
+ },
84
+ [],
85
+ ]
77
86
}
78
87
}
79
- lhs_mql = process_lhs (self .lhs , compiler , connection )
80
88
return {
81
89
"$anyElementTrue" : {
82
- "$map" : {
83
- "input" : lhs_mql ,
84
- "as" : "item" ,
85
- "in" : {"$eq" : [f"$$item.{ self .lhs .key_name } " , value ]},
86
- }
90
+ "$ifNull" : [
91
+ {
92
+ "$map" : {
93
+ "input" : lhs_mql [0 ],
94
+ "as" : "item" ,
95
+ "in" : {"$eq" : [lhs_mql [1 ], value ]},
96
+ }
97
+ },
98
+ [],
99
+ ]
87
100
}
88
101
}
89
102
return connection .mongo_operators [self .lookup_name ](lhs_mql , value )
90
103
91
104
92
105
class KeyTransform (Transform ):
93
106
# it should be different class than EMF keytransform even most of the methods are equal.
94
- def __init__ (self , key_name , ref_field , * args , ** kwargs ):
107
+ def __init__ (self , key_name , base_field , * args , ** kwargs ):
95
108
super ().__init__ (* args , ** kwargs )
96
- self .key_name = str (key_name )
97
- self .ref_field = ref_field
109
+ self .base_field = base_field
110
+ # TODO: Need to create a column, will refactor this thing.
111
+ column_target = base_field .clone ()
112
+ column_target .db_column = f"$item.{ key_name } "
113
+ column_target .set_attributes_from_name (f"$item.{ key_name } " )
114
+ self ._lhs = Col (None , column_target )
115
+ self ._sub_transform = None
116
+
117
+ def __call__ (self , this , * args , ** kwargs ):
118
+ self ._lhs = self ._sub_transform (self ._lhs , * args , ** kwargs )
119
+ return self
98
120
99
121
def get_lookup (self , name ):
100
122
return self .output_field .get_lookup (name )
@@ -104,33 +126,42 @@ def get_transform(self, name):
104
126
Validate that `name` is either a field of an embedded model or a
105
127
lookup on an embedded model's field.
106
128
"""
107
- if transform := self .ref_field .get_transform (name ):
108
- return transform
109
- suggested_lookups = difflib .get_close_matches (name , self .ref_field .get_lookups ())
129
+ if isinstance (self ._lhs , Transform ):
130
+ transform = self ._lhs .get_transform (name )
131
+ else :
132
+ transform = self .base_field .get_transform (name )
133
+ if transform :
134
+ self ._sub_transform = transform
135
+ return self
136
+ suggested_lookups = difflib .get_close_matches (name , self .base_field .get_lookups ())
110
137
if suggested_lookups :
111
138
suggested_lookups = " or " .join (suggested_lookups )
112
139
suggestion = f", perhaps you meant { suggested_lookups } ?"
113
140
else :
114
141
suggestion = "."
115
142
raise FieldDoesNotExist (
116
143
f"Unsupported lookup '{ name } ' for "
117
- f"{ self .ref_field .__class__ .__name__ } '{ self .ref_field .name } '"
144
+ f"{ self .base_field .__class__ .__name__ } '{ self .base_field .name } '"
118
145
f"{ suggestion } "
119
146
)
120
147
121
148
def as_mql (self , compiler , connection ):
149
+ if isinstance (self ._lhs , Transform ):
150
+ inner_lhs_mql = self ._lhs .as_mql (compiler , connection )
151
+ else :
152
+ inner_lhs_mql = None
122
153
lhs_mql = process_lhs (self , compiler , connection )
123
- return f" { lhs_mql } . { self . key_name } "
154
+ return lhs_mql , inner_lhs_mql
124
155
125
156
@property
126
157
def output_field (self ):
127
- return EmbeddedModelArrayField (self .ref_field )
158
+ return EmbeddedModelArrayField (self .base_field )
128
159
129
160
130
161
class KeyTransformFactory :
131
- def __init__ (self , key_name , ref_field ):
162
+ def __init__ (self , key_name , base_field ):
132
163
self .key_name = key_name
133
- self .ref_field = ref_field
164
+ self .base_field = base_field
134
165
135
166
def __call__ (self , * args , ** kwargs ):
136
- return KeyTransform (self .key_name , self .ref_field , * args , ** kwargs )
167
+ return KeyTransform (self .key_name , self .base_field , * args , ** kwargs )
0 commit comments