Skip to content

Commit 75d1f62

Browse files
committed
Fix rhs serialization
1 parent caea979 commit 75d1f62

File tree

1 file changed

+20
-16
lines changed

1 file changed

+20
-16
lines changed

django_mongodb_backend/fields/embedded_model_array.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,25 @@ def get_transform(self, name):
5858
return KeyTransformFactory(name, self)
5959

6060

61+
class EMFArrayRHSMixin:
62+
def process_rhs(self, compiler, connection):
63+
values = self.rhs
64+
if not self.get_db_prep_lookup_value_is_iterable:
65+
values = [values]
66+
# Compute how to serialize each value based on the query target.
67+
# If querying a subfield inside the array (i.e., a nested KeyTransform), use the output
68+
# field of the subfield. Otherwise, use the base field of the array itself.
69+
if isinstance(self.lhs, KeyTransform):
70+
get_db_prep_value = self.lhs._lhs.output_field.get_db_prep_value
71+
else:
72+
get_db_prep_value = self.lhs.output_field.base_field.get_db_prep_value
73+
return None, [get_db_prep_value(v, connection, prepared=True) for v in values]
74+
75+
6176
@EmbeddedModelArrayField.register_lookup
62-
class EMFArrayExact(lookups.Exact):
77+
class EMFArrayExact(EMFArrayRHSMixin, lookups.Exact):
78+
get_db_prep_lookup_value_is_iterable = False
79+
6380
def as_mql(self, compiler, connection):
6481
if not isinstance(self.lhs, KeyTransform):
6582
raise ValueError("error")
@@ -82,22 +99,9 @@ def as_mql(self, compiler, connection):
8299

83100

84101
@EmbeddedModelArrayField.register_lookup
85-
class ArrayOverlap(Lookup):
102+
class ArrayOverlap(EMFArrayRHSMixin, Lookup):
86103
lookup_name = "overlap"
87-
get_db_prep_lookup_value_is_iterable = True
88-
89-
def process_rhs(self, compiler, connection):
90-
values = self.rhs
91-
if self.get_db_prep_lookup_value_is_iterable:
92-
values = [values]
93-
# Compute how to serialize each value based on the query target.
94-
# If querying a subfield inside the array (i.e., a nested KeyTransform), use the output
95-
# field of the subfield. Otherwise, use the base field of the array itself.
96-
if isinstance(self.lhs, KeyTransform):
97-
get_db_prep_value = self.lhs._lhs.output_field.get_db_prep_value
98-
else:
99-
get_db_prep_value = self.lhs.output_field.base_field.get_db_prep_value
100-
return None, [get_db_prep_value(v, connection, prepared=True) for v in values]
104+
get_db_prep_lookup_value_is_iterable = False
101105

102106
def as_mql(self, compiler, connection):
103107
# Querying a subfield within the array elements (via nested KeyTransform).

0 commit comments

Comments
 (0)