diff --git a/django_mongodb_backend/fields/array.py b/django_mongodb_backend/fields/array.py index 4f9515146..c4b23cb9e 100644 --- a/django_mongodb_backend/fields/array.py +++ b/django_mongodb_backend/fields/array.py @@ -326,9 +326,7 @@ def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr) def as_mql(self, compiler, connection): lhs_mql = process_lhs(self, compiler, connection) value = process_rhs(self, compiler, connection) - return { - "$and": [{"$ne": [lhs_mql, None]}, {"$size": {"$setIntersection": [value, lhs_mql]}}] - } + return {"$and": [{"$isArray": lhs_mql}, {"$size": {"$setIntersection": [value, lhs_mql]}}]} @ArrayField.register_lookup @@ -338,7 +336,7 @@ class ArrayLenTransform(Transform): def as_mql(self, compiler, connection): lhs_mql = process_lhs(self, compiler, connection) - return {"$cond": {"if": {"$eq": [lhs_mql, None]}, "then": None, "else": {"$size": lhs_mql}}} + return {"$cond": {"if": {"$isArray": lhs_mql}, "then": {"$size": lhs_mql}, "else": None}} @ArrayField.register_lookup diff --git a/django_mongodb_backend/fields/embedded_model.py b/django_mongodb_backend/fields/embedded_model.py index 57bbd3f50..4a5ff09d9 100644 --- a/django_mongodb_backend/fields/embedded_model.py +++ b/django_mongodb_backend/fields/embedded_model.py @@ -186,8 +186,11 @@ def as_mql(self, compiler, connection): key_transforms.insert(0, previous.key_name) previous = previous.lhs mql = previous.as_mql(compiler, connection) - transforms = ".".join(key_transforms) - return f"{mql}.{transforms}" + # transform = ".".join(key_transforms) + for key in key_transforms: + mql = {"$getField": {"input": mql, "field": key}} + return mql + # return f"{mql}.{transform}" @property def output_field(self): diff --git a/django_mongodb_backend/fields/embedded_model_array.py b/django_mongodb_backend/fields/embedded_model_array.py index d894e598a..85f29c93b 100644 --- a/django_mongodb_backend/fields/embedded_model_array.py +++ b/django_mongodb_backend/fields/embedded_model_array.py @@ -1,8 +1,16 @@ +import difflib + +from django.core.exceptions import FieldDoesNotExist +from django.db import models from django.db.models import Field +from django.db.models.expressions import Col +from django.db.models.lookups import Lookup, Transform from .. import forms +from ..query_utils import process_lhs, process_rhs from . import EmbeddedModelField from .array import ArrayField +from .embedded_model import EMFExact, EMFMixin class EmbeddedModelArrayField(ArrayField): @@ -44,3 +52,190 @@ def formfield(self, **kwargs): **kwargs, }, ) + + def get_transform(self, name): + transform = super().get_transform(name) + if transform: + return transform + return KeyTransformFactory(name, self) + + +@EmbeddedModelArrayField.register_lookup +class EMFArrayExact(EMFExact): + def as_mql(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection) + value = process_rhs(self, compiler, connection) + if isinstance(self.lhs, KeyTransform): + lhs_mql, inner_lhs_mql = lhs_mql + else: + inner_lhs_mql = "$$item" + if isinstance(value, models.Model): + value, emf_data = self.model_to_dict(value) + # Get conditions for any nested EmbeddedModelFields. + conditions = self.get_conditions({inner_lhs_mql: (value, emf_data)}) + return { + "$anyElementTrue": { + "$ifNull": [ + { + "$map": { + "input": lhs_mql, + "as": "item", + "in": {"$and": conditions}, + } + }, + [], + ] + } + } + return { + "$anyElementTrue": { + "$ifNull": [ + { + "$map": { + "input": lhs_mql, + "as": "item", + "in": {"$eq": [inner_lhs_mql, value]}, + } + }, + [], + ] + } + } + + +@EmbeddedModelArrayField.register_lookup +class ArrayOverlap(EMFMixin, Lookup): + lookup_name = "overlap" + get_db_prep_lookup_value_is_iterable = True + + def process_rhs(self, compiler, connection): + values = self.rhs + if self.get_db_prep_lookup_value_is_iterable: + values = [values] + # Compute how to serialize each value based on the query target. + # If querying a subfield inside the array (i.e., a nested KeyTransform), use the output + # field of the subfield. Otherwise, use the base field of the array itself. + if isinstance(self.lhs, KeyTransform): + get_db_prep_value = self.lhs._lhs.output_field.get_db_prep_value + else: + get_db_prep_value = self.lhs.output_field.base_field.get_db_prep_value + return None, [get_db_prep_value(v, connection, prepared=True) for v in values] + + def as_mql(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection) + values = process_rhs(self, compiler, connection) + # Querying a subfield within the array elements (via nested KeyTransform). + # Replicates MongoDB's implicit ANY-match by mapping over the array and applying + # `$in` on the subfield. + if isinstance(self.lhs, KeyTransform): + lhs_mql, inner_lhs_mql = lhs_mql + return { + "$anyElementTrue": { + "$ifNull": [ + { + "$map": { + "input": lhs_mql, + "as": "item", + "in": {"$in": [inner_lhs_mql, values]}, + } + }, + [], + ] + } + } + conditions = [] + inner_lhs_mql = "$$item" + # Querying full embedded documents in the array. + # Builds `$or` conditions and maps them over the array to match any full document. + for value in values: + value, emf_data = self.model_to_dict(value) + # Get conditions for any nested EmbeddedModelFields. + conditions.append({"$and": self.get_conditions({inner_lhs_mql: (value, emf_data)})}) + return { + "$anyElementTrue": { + "$ifNull": [ + { + "$map": { + "input": lhs_mql, + "as": "item", + "in": {"$or": conditions}, + } + }, + [], + ] + } + } + + +class KeyTransform(Transform): + # it should be different class than EMF keytransform even most of the methods are equal. + def __init__(self, key_name, array_field, *args, **kwargs): + super().__init__(*args, **kwargs) + self.array_field = array_field + self.key_name = key_name + # The iteration items begins from the base_field, a virtual column with + # base field output type is created. + column_target = array_field.base_field.embedded_model._meta.get_field(key_name).clone() + column_name = f"$item.{key_name}" + column_target.db_column = column_name + column_target.set_attributes_from_name(column_name) + self._lhs = Col(None, column_target) + self._sub_transform = None + + def __call__(self, this, *args, **kwargs): + self._lhs = self._sub_transform(self._lhs, *args, **kwargs) + return self + + def get_lookup(self, name): + return self.output_field.get_lookup(name) + + def _get_missing_field_or_lookup_exception(self, lhs, name): + suggested_lookups = difflib.get_close_matches(name, lhs.get_lookups()) + if suggested_lookups: + suggested_lookups = " or ".join(suggested_lookups) + suggestion = f", perhaps you meant {suggested_lookups}?" + else: + suggestion = "." + raise FieldDoesNotExist( + f"Unsupported lookup '{name}' for " + f"{self.array_field.base_field.__class__.__name__} '{self.array_field.base_field.name}'" + f"{suggestion}" + ) + + def get_transform(self, name): + """ + Validate that `name` is either a field of an embedded model or a + lookup on an embedded model's field. + """ + # Once the sub lhs is a transform, all the filter are applied over it. + transform = ( + self._lhs.get_transform(name) + if isinstance(self._lhs, Transform) + else self.array_field.base_field.embedded_model._meta.get_field( + self.key_name + ).get_transform(name) + ) + if transform: + self._sub_transform = transform + return self + raise self._get_missing_field_or_lookup_exception( + self._lhs if isinstance(self._lhs, Transform) else self.base_field, name + ) + + def as_mql(self, compiler, connection): + inner_lhs_mql = self._lhs.as_mql(compiler, connection) + lhs_mql = process_lhs(self, compiler, connection) + return lhs_mql, inner_lhs_mql + + @property + def output_field(self): + return self.array_field + + +class KeyTransformFactory: + def __init__(self, key_name, base_field): + self.key_name = key_name + self.base_field = base_field + + def __call__(self, *args, **kwargs): + return KeyTransform(self.key_name, self.base_field, *args, **kwargs) diff --git a/tests/model_fields_/models.py b/tests/model_fields_/models.py index 3cf074a23..ed2051792 100644 --- a/tests/model_fields_/models.py +++ b/tests/model_fields_/models.py @@ -165,3 +165,37 @@ class Movie(models.Model): def __str__(self): return self.title + + +class RestorationRecord(EmbeddedModel): + date = models.DateField() + description = models.TextField() + restored_by = models.CharField(max_length=255) + + +class ArtifactDetail(EmbeddedModel): + """Details about a specific artifact.""" + + name = models.CharField(max_length=255) + description = models.CharField(max_length=255) + metadata = models.JSONField() + restorations = EmbeddedModelArrayField(RestorationRecord, null=True) + last_restoration = EmbeddedModelField(RestorationRecord, null=True) + + +class ExhibitSection(EmbeddedModel): + """A section within an exhibit, containing multiple artifacts.""" + + section_number = models.IntegerField() + artifacts = EmbeddedModelArrayField(ArtifactDetail, null=True) + + +class MuseumExhibit(models.Model): + """An exhibit in the museum, composed of multiple sections.""" + + exhibit_name = models.CharField(max_length=255) + sections = EmbeddedModelArrayField(ExhibitSection, null=True) + main_section = EmbeddedModelField(ExhibitSection, null=True) + + def __str__(self): + return self.exhibit_name diff --git a/tests/model_fields_/test_embedded_model.py b/tests/model_fields_/test_embedded_model.py index ec9f9dfc4..6e954d747 100644 --- a/tests/model_fields_/test_embedded_model.py +++ b/tests/model_fields_/test_embedded_model.py @@ -1,5 +1,5 @@ import operator -from datetime import timedelta +from datetime import date, timedelta from django.core.exceptions import FieldDoesNotExist, ValidationError from django.db import models @@ -18,13 +18,24 @@ from django_mongodb_backend.models import EmbeddedModel from .models import ( + A, Address, + ArtifactDetail, Author, + B, Book, + C, + D, Data, + E, + ExhibitSection, Holder, Library, + Movie, + MuseumExhibit, NestedData, + RestorationRecord, + Review, ) from .utils import truncate_ms @@ -96,6 +107,215 @@ def test_pre_save(self): self.assertGreater(obj.data.auto_now, auto_now_two) +class EmbeddedArrayTests(TestCase): + def test_save_load(self): + reviews = [ + Review(title="The best", rating=10), + Review(title="Mediocre", rating=5), + Review(title="Horrible", rating=1), + ] + Movie.objects.create(title="Lion King", reviews=reviews) + movie = Movie.objects.get(title="Lion King") + self.assertEqual(movie.reviews[0].title, "The best") + self.assertEqual(movie.reviews[0].rating, 10) + self.assertEqual(movie.reviews[1].title, "Mediocre") + self.assertEqual(movie.reviews[1].rating, 5) + self.assertEqual(movie.reviews[2].title, "Horrible") + self.assertEqual(movie.reviews[2].rating, 1) + self.assertEqual(len(movie.reviews), 3) + + def test_save_load_null(self): + movie = Movie.objects.create(title="Lion King") + movie = Movie.objects.get(title="Lion King") + self.assertIsNone(movie.reviews) + + +class EmbeddedArrayQueryingTests(TestCase): + @classmethod + def setUpTestData(cls): + reviews = [ + Review(title="The best", rating=10), + Review(title="Mediocre", rating=5), + Review(title="Horrible", rating=1), + ] + cls.clouds = Movie.objects.create(title="Clouds", reviews=reviews) + reviews = [ + Review(title="Super", rating=9), + Review(title="Meh", rating=5), + Review(title="Horrible", rating=2), + ] + cls.frozen = Movie.objects.create(title="Frozen", reviews=reviews) + reviews = [ + Review(title="Excellent", rating=9), + Review(title="Wow", rating=8), + Review(title="Classic", rating=7), + ] + cls.bears = Movie.objects.create(title="Bears", reviews=reviews) + cls.egypt = MuseumExhibit.objects.create( + exhibit_name="Ancient Egypt", + sections=[ + ExhibitSection( + section_number=1, + artifacts=[ + ArtifactDetail( + name="Ptolemaic Crown", + description="Royal headpiece worn by Ptolemy kings.", + metadata={ + "material": "gold", + "origin": "Egypt", + "era": "Ptolemaic Period", + }, + ) + ], + ) + ], + ) + cls.wonders = MuseumExhibit.objects.create( + exhibit_name="Wonders of the Ancient World", + sections=[ + ExhibitSection( + section_number=1, + artifacts=[ + ArtifactDetail( + name="Statue of Zeus", + description="One of the Seven Wonders, created by Phidias.", + metadata={"location": "Olympia", "height_m": 12}, + ), + ArtifactDetail( + name="Hanging Gardens", + description="Legendary gardens of Babylon.", + metadata={"debated_existence": True}, + ), + ], + ), + ExhibitSection( + section_number=2, + artifacts=[ + ArtifactDetail( + name="Lighthouse of Alexandria", + description="Guided sailors safely to port.", + metadata={"height_m": 100, "built": "3rd century BC"}, + ) + ], + ), + ], + ) + cls.new_descoveries = MuseumExhibit.objects.create( + exhibit_name="New Discoveries", + sections=[ExhibitSection(section_number=1, artifacts=[])], + ) + cls.lost_empires = MuseumExhibit.objects.create( + exhibit_name="Lost Empires", + main_section=ExhibitSection( + section_number=3, + artifacts=[ + ArtifactDetail( + name="Bronze Statue", + description="Statue from the Hellenistic period.", + metadata={"origin": "Pergamon", "material": "bronze"}, + restorations=[ + RestorationRecord( + date=date(1998, 4, 15), + description="Removed oxidized layer.", + restored_by="Restoration Lab A", + ), + RestorationRecord( + date=date(2010, 7, 22), + description="Reinforced the base structure.", + restored_by="Dr. Liu Cheng", + ), + ], + last_restoration=RestorationRecord( + date=date(2010, 7, 22), + description="Reinforced the base structure.", + restored_by="Dr. Liu Cheng", + ), + ) + ], + ), + ) + + def test_filter_with_field(self): + self.assertCountEqual( + Movie.objects.filter(reviews__title="Horrible"), [self.clouds, self.frozen] + ) + + def test_filter_with_model(self): + self.assertCountEqual( + Movie.objects.filter(reviews=Review(title="Horrible", rating=2)), + [self.frozen], + ) + + def test_filter_with_embeddedfield_path(self): + self.assertCountEqual( + MuseumExhibit.objects.filter(sections__0__section_number=1), + [self.egypt, self.wonders, self.new_descoveries], + ) + + def test_filter_with_embeddedfield_array_path(self): + self.assertCountEqual( + MuseumExhibit.objects.filter( + main_section__artifacts__restorations__0__restored_by="Restoration Lab A" + ), + [self.lost_empires], + ) + + def test_len(self): + self.assertCountEqual(MuseumExhibit.objects.filter(sections__len=10), []) + self.assertCountEqual( + MuseumExhibit.objects.filter(sections__len=1), [self.egypt, self.new_descoveries] + ) + # Nested EMF + self.assertCountEqual( + MuseumExhibit.objects.filter(main_section__artifacts__len=1), [self.lost_empires] + ) + self.assertCountEqual(MuseumExhibit.objects.filter(main_section__artifacts__len=2), []) + self.assertCountEqual(MuseumExhibit.objects.filter(main_section__artifacts__len=2), []) + # Nested Indexed Array + self.assertCountEqual( + MuseumExhibit.objects.filter(sections__0__artifacts__len=2), [self.wonders] + ) + self.assertCountEqual( + MuseumExhibit.objects.filter(sections__0__artifacts__len=0), [self.new_descoveries] + ) + self.assertCountEqual( + MuseumExhibit.objects.filter(sections__1__artifacts__len=1), [self.wonders] + ) + + def test_overlap_simplefield(self): + self.assertSequenceEqual( + MuseumExhibit.objects.filter(sections__section_number__overlap=[10]), [] + ) + self.assertSequenceEqual( + MuseumExhibit.objects.filter(sections__section_number__overlap=[1]), + [self.egypt, self.wonders, self.new_descoveries], + ) + self.assertSequenceEqual( + MuseumExhibit.objects.filter(sections__section_number__overlap=[2]), [self.wonders] + ) + + def test_overlap_emf(self): + self.assertSequenceEqual( + Movie.objects.filter(reviews__overlap=[Review(title="The best", rating=10)]), + [self.clouds], + ) + + def test_overlap_values(self): + qs = Movie.objects.filter(title__in=["Clouds", "Frozen"]) + self.assertCountEqual( + Movie.objects.filter( + reviews__overlap=qs.values_list("reviews"), + ), + [self.clouds, self.frozen], + ) + self.assertCountEqual( + Movie.objects.filter( + reviews__overlap=qs.values("reviews"), + ), + [self.clouds, self.frozen], + ) + + class QueryingTests(TestCase): @classmethod def setUpTestData(cls):