diff --git a/django_mongodb/compiler.py b/django_mongodb/compiler.py index 55749ede4..0317f72f7 100644 --- a/django_mongodb/compiler.py +++ b/django_mongodb/compiler.py @@ -53,7 +53,7 @@ def _get_column_from_expression(self, expr, alias): Create a column named `alias` from the given expression to hold the aggregate value. """ - column_target = expr.output_field.__class__() + column_target = expr.output_field.clone() column_target.db_column = alias column_target.set_attributes_from_name(alias) return Col(self.collection_name, column_target) @@ -81,7 +81,7 @@ def _prepare_expressions_for_pipeline(self, expression, target, annotation_group alias = ( f"__aggregation{next(annotation_group_idx)}" if sub_expr != expression else target ) - column_target = sub_expr.output_field.__class__() + column_target = sub_expr.output_field.clone() column_target.db_column = alias column_target.set_attributes_from_name(alias) inner_column = Col(self.collection_name, column_target) @@ -743,7 +743,7 @@ def execute_sql(self, result_type): elif hasattr(value, "prepare_database_save"): if field.remote_field: value = value.prepare_database_save(field) - else: + elif not hasattr(field, "embedded_model"): raise TypeError( f"Tried to update field {field} with a model " f"instance, {value!r}. Use a value compatible with " diff --git a/django_mongodb/fields/__init__.py b/django_mongodb/fields/__init__.py index 9eb2518d6..3133c20eb 100644 --- a/django_mongodb/fields/__init__.py +++ b/django_mongodb/fields/__init__.py @@ -1,11 +1,13 @@ from .auto import ObjectIdAutoField from .duration import register_duration_field +from .embedded_model import EmbeddedModelField, register_embedded_model_field from .json import register_json_field from .objectid import ObjectIdField -__all__ = ["register_fields", "ObjectIdAutoField", "ObjectIdField"] +__all__ = ["register_fields", "EmbeddedModelField", "ObjectIdAutoField", "ObjectIdField"] def register_fields(): register_duration_field() + register_embedded_model_field() register_json_field() diff --git a/django_mongodb/fields/embedded_model.py b/django_mongodb/fields/embedded_model.py new file mode 100644 index 000000000..23d0443c8 --- /dev/null +++ b/django_mongodb/fields/embedded_model.py @@ -0,0 +1,220 @@ +from django.core.exceptions import FieldDoesNotExist +from django.db import models +from django.db.models.fields.related import lazy_related_operation +from django.db.models.lookups import Transform + +from .. import forms + + +class EmbeddedModelField(models.Field): + """Field that stores a model instance.""" + + def __init__(self, embedded_model, *args, **kwargs): + """ + `embedded_model` is the model class of the instance that will be + stored. Like other relational fields, it may also be passed as a + string. + """ + if not isinstance(embedded_model, str): + self._validate_embedded_field(self, embedded_model) + + self.embedded_model = embedded_model + super().__init__(*args, **kwargs) + + def deconstruct(self): + name, path, args, kwargs = super().deconstruct() + if path.startswith("django_mongodb.fields.embedded_model"): + path = path.replace("django_mongodb.fields.embedded_model", "django_mongodb.fields") + kwargs["embedded_model"] = self.embedded_model + return name, path, args, kwargs + + def get_internal_type(self): + return "EmbeddedModelField" + + @staticmethod + def _validate_embedded_field(_, model): + for field in model._meta.local_fields: + if isinstance(field, models.ForeignKey | models.OneToOneField): + raise TypeError( + f"Field of type {type(field)!r} is not supported within an EmbeddedModelField." + ) + + def _set_model(self, model): + """ + Resolve embedded model class once the field knows the model it belongs + to. + + If the model argument passed to __init__() was a string, resolve that + string to the corresponding model class, similar to relation fields. + However, we need to know our own model to generate a valid key + for the embedded model class lookup and EmbeddedModelFields are + not contributed_to_class if used in iterable fields. Thus the + collection field sets this field's "model" attribute in its + contribute_to_class(). + """ + self._model = model + if model is not None and isinstance(self.embedded_model, str): + + def _resolve_lookup(_, resolved_model): + self.embedded_model = resolved_model + + lazy_related_operation(_resolve_lookup, model, self.embedded_model) + lazy_related_operation(self._validate_embedded_field, model, self.embedded_model) + + model = property(lambda self: self._model, _set_model) + + def from_db_value(self, value, expression, connection): + return self.to_python(value) + + def to_python(self, value): + """ + Passes embedded model fields' values through embedded fields + to_python() and reinstiatates the embedded instance. + """ + if value is None: + return None + if not isinstance(value, dict): + return value + # Create the model instance. + instance = self.embedded_model( + **{ + # Pass values through respective fields' to_python(), leaving + # fields for which no value is specified uninitialized. + field.attname: field.to_python(value[field.attname]) + for field in self.embedded_model._meta.fields + if field.attname in value + } + ) + instance._state.adding = False + return instance + + def get_db_prep_save(self, embedded_instance, connection): + """ + Apply pre_save() and get_db_prep_save() of embedded instance + fields and passes a field => value mapping down to database + type conversions. + + The embedded instance will be saved as a column => value dict, but + because we need to apply database type conversions on embedded instance + fields' values and for these we need to know fields those values come + from, we need to entrust the database layer with creating the dict. + """ + if embedded_instance is None: + return None + if not isinstance(embedded_instance, self.embedded_model): + raise TypeError( + f"Expected instance of type {self.embedded_model!r}, not " + f"{type(embedded_instance)!r}." + ) + # Apply pre_save() and get_db_prep_save() of embedded instance + # fields, create the field => value mapping to be passed to + # storage preprocessing. + field_values = {} + add = embedded_instance._state.adding + for field in embedded_instance._meta.fields: + value = field.get_db_prep_save( + field.pre_save(embedded_instance, add), connection=connection + ) + # Exclude unset primary keys (e.g. {'id': None}). + if field.primary_key and value is None: + continue + field_values[field.attname] = value + # This instance will exist in the database soon. + # TODO.XXX: Ensure that this doesn't cause race conditions. + embedded_instance._state.adding = False + return field_values + + def get_transform(self, name): + transform = super().get_transform(name) + if transform: + return transform + field = self.embedded_model._meta.get_field(name) + return KeyTransformFactory(name, field) + + def validate(self, value, model_instance): + super().validate(value, model_instance) + if self.embedded_model is None: + return + for field in self.embedded_model._meta.fields: + attname = field.attname + field.validate(getattr(value, attname), model_instance) + + def formfield(self, **kwargs): + return super().formfield( + **{ + "form_class": forms.EmbeddedModelFormField, + "model": self.embedded_model, + "name": self.name, + **kwargs, + } + ) + + +class KeyTransform(Transform): + def __init__(self, key_name, ref_field, *args, **kwargs): + super().__init__(*args, **kwargs) + self.key_name = str(key_name) + self.ref_field = ref_field + + def get_transform(self, name): + result = None + if isinstance(self.ref_field, EmbeddedModelField): + opts = self.ref_field.embedded_model._meta + new_field = opts.get_field(name) + result = KeyTransformFactory(name, new_field) + else: + if self.ref_field.get_transform(name) is None: + raise FieldDoesNotExist( + f"{self.ref_field.model._meta.object_name}.{self.ref_field.name}" + f" has no field named '{name}'" + ) + result = KeyTransformFactory(name, self.ref_field) + return result + + def preprocess_lhs(self, compiler, connection): + previous = self + embedded_key_transforms = [] + json_key_transforms = [] + while isinstance(previous, KeyTransform): + if isinstance(previous.ref_field, EmbeddedModelField): + embedded_key_transforms.insert(0, previous.key_name) + else: + json_key_transforms.insert(0, previous.key_name) + previous = previous.lhs + mql = previous.as_mql(compiler, connection) + embedded_key_transforms.append(json_key_transforms.pop(0)) + return mql, embedded_key_transforms, json_key_transforms + + +def key_transform(self, compiler, connection): + mql, key_transforms, json_key_transforms = self.preprocess_lhs(compiler, connection) + transforms = ".".join(key_transforms) + result = f"{mql}.{transforms}" + for key in json_key_transforms: + get_field = {"$getField": {"input": result, "field": key}} + # Handle array indexing if the key is a digit. If key is something + # like '001', it's not an array index despite isdigit() returning True. + if key.isdigit() and str(int(key)) == key: + result = { + "$cond": { + "if": {"$isArray": result}, + "then": {"$arrayElemAt": [result, int(key)]}, + "else": get_field, + } + } + else: + result = get_field + return result + + +class KeyTransformFactory: + def __init__(self, key_name, ref_field): + self.key_name = key_name + self.ref_field = ref_field + + def __call__(self, *args, **kwargs): + return KeyTransform(self.key_name, self.ref_field, *args, **kwargs) + + +def register_embedded_model_field(): + KeyTransform.as_mql = key_transform diff --git a/django_mongodb/forms.py b/django_mongodb/forms.py new file mode 100644 index 000000000..0f78f8683 --- /dev/null +++ b/django_mongodb/forms.py @@ -0,0 +1,61 @@ +from django import forms +from django.forms.models import modelform_factory +from django.utils.safestring import mark_safe +from django.utils.translation import gettext_lazy as _ + + +class EmbeddedModelWidget(forms.MultiWidget): + def __init__(self, field_names, *args, **kwargs): + self.field_names = field_names + super().__init__(*args, **kwargs) + # The default widget names are "_0", "_1", etc. Use the field names + # instead since that's how they'll be rendered by the model form. + self.widgets_names = ["-" + name for name in field_names] + + def decompress(self, value): + if value is None: + return [] + # Get the data from `value` (a model) for each field. + return [getattr(value, name) for name in self.field_names] + + +class EmbeddedModelBoundField(forms.BoundField): + def __str__(self): + """Render the model form as the representation for this field.""" + form = self.field.model_form_cls(instance=self.value(), **self.field.form_kwargs) + return mark_safe(f"{form.as_div()}") # noqa: S308 + + +class EmbeddedModelFormField(forms.MultiValueField): + default_error_messages = { + "invalid": _("Enter a list of values."), + "incomplete": _("Enter all required values."), + } + + def __init__(self, model, name, *args, **kwargs): + form_kwargs = {} + # The field must be prefixed with the name of the field. + form_kwargs["prefix"] = name + self.form_kwargs = form_kwargs + self.model_form_cls = modelform_factory(model, fields="__all__") + self.model_form = self.model_form_cls(**form_kwargs) + self.field_names = list(self.model_form.fields.keys()) + fields = self.model_form.fields.values() + widgets = [field.widget for field in fields] + widget = EmbeddedModelWidget(self.field_names, widgets) + super().__init__(*args, fields=fields, widget=widget, require_all_fields=False, **kwargs) + + def compress(self, data_dict): + if not data_dict: + return None + values = dict(zip(self.field_names, data_dict, strict=False)) + return self.model_form._meta.model(**values) + + def get_bound_field(self, form, field_name): + return EmbeddedModelBoundField(form, self, field_name) + + def bound_data(self, data, initial): + if self.disabled: + return initial + # The bound data must be transformed into a model instance. + return self.compress(data) diff --git a/django_mongodb/schema.py b/django_mongodb/schema.py index 8fc18feae..e8338aa89 100644 --- a/django_mongodb/schema.py +++ b/django_mongodb/schema.py @@ -5,6 +5,7 @@ from pymongo import ASCENDING, DESCENDING from pymongo.operations import IndexModel +from .fields import EmbeddedModelField from .query import wrap_database_errors from .utils import OperationCollector @@ -29,31 +30,50 @@ def create_model(self, model): if field.remote_field.through._meta.auto_created: self.create_model(field.remote_field.through) - def _create_model_indexes(self, model): + def _create_model_indexes(self, model, column_prefix="", parent_model=None): """ Create all indexes (field indexes & uniques, Meta.index_together, Meta.unique_together, Meta.constraints, Meta.indexes) for the model. + + If this is a recursive call to due to an embedded model, `column_prefix` + tracks the path that must be prepended to the index's column, and + `parent_model` tracks the collection to add the index/constraint to. """ if not model._meta.managed or model._meta.proxy or model._meta.swapped: return # Field indexes and uniques for field in model._meta.local_fields: + if isinstance(field, EmbeddedModelField): + new_path = f"{column_prefix}{field.column}." + self._create_model_indexes( + field.embedded_model, parent_model=parent_model or model, column_prefix=new_path + ) if self._field_should_be_indexed(model, field): - self._add_field_index(model, field) + self._add_field_index(parent_model or model, field, column_prefix=column_prefix) elif self._field_should_have_unique(field): - self._add_field_unique(model, field) + self._add_field_unique(parent_model or model, field, column_prefix=column_prefix) # Meta.index_together (RemovedInDjango51Warning) for field_names in model._meta.index_together: - self._add_composed_index(model, field_names) + self._add_composed_index( + model, field_names, column_prefix=column_prefix, parent_model=parent_model + ) # Meta.unique_together if model._meta.unique_together: - self.alter_unique_together(model, [], model._meta.unique_together) + self.alter_unique_together( + model, + [], + model._meta.unique_together, + column_prefix=column_prefix, + parent_model=parent_model, + ) # Meta.constraints for constraint in model._meta.constraints: - self.add_constraint(model, constraint) + self.add_constraint( + model, constraint, column_prefix=column_prefix, parent_model=parent_model + ) # Meta.indexes for index in model._meta.indexes: - self.add_index(model, index) + self.add_index(model, index, column_prefix=column_prefix, parent_model=parent_model) def delete_model(self, model): # Delete implicit M2m tables. @@ -72,6 +92,11 @@ def add_field(self, model, field): self.get_collection(model._meta.db_table).update_many( {}, [{"$set": {column: self.effective_default(field)}}] ) + if isinstance(field, EmbeddedModelField): + new_path = f"{field.column}." + self._create_model_indexes( + field.embedded_model, parent_model=model, column_prefix=new_path + ) # Add an index or unique, if required. if self._field_should_be_indexed(model, field): self._add_field_index(model, field) @@ -136,18 +161,70 @@ def remove_field(self, model, field): self._remove_field_index(model, field) elif self._field_should_have_unique(field): self._remove_field_unique(model, field) + if isinstance(field, EmbeddedModelField): + new_path = f"{field.column}." + self._remove_model_indexes( + field.embedded_model, parent_model=model, column_prefix=new_path + ) - def alter_index_together(self, model, old_index_together, new_index_together): + def _remove_model_indexes(self, model, column_prefix="", parent_model=None): + """ + When removing an EmbeddedModelField, the indexes need to be removed + recursively. + """ + if not model._meta.managed or model._meta.proxy or model._meta.swapped: + return + # Field indexes and uniques + for field in model._meta.local_fields: + if isinstance(field, EmbeddedModelField): + new_path = f"{column_prefix}{field.column}." + self._remove_model_indexes( + field.embedded_model, parent_model=parent_model or model, column_prefix=new_path + ) + if self._field_should_be_indexed(model, field): + self._remove_field_index(parent_model or model, field, column_prefix=column_prefix) + elif self._field_should_have_unique(field): + self._remove_field_unique(parent_model or model, field, column_prefix=column_prefix) + # Meta.index_together (RemovedInDjango51Warning) + for field_names in model._meta.index_together: + self._remove_composed_index( + model, + field_names, + {"index": True, "unique": False}, + column_prefix=column_prefix, + parent_model=parent_model, + ) + # Meta.unique_together + if model._meta.unique_together: + self.alter_unique_together( + model, + model._meta.unique_together, + [], + column_prefix=column_prefix, + parent_model=parent_model, + ) + # Meta.constraints + for constraint in model._meta.constraints: + self.remove_constraint(parent_model or model, constraint) + # Meta.indexes + for index in model._meta.indexes: + self.remove_index(parent_model or model, index) + + def alter_index_together(self, model, old_index_together, new_index_together, column_prefix=""): olds = {tuple(fields) for fields in old_index_together} news = {tuple(fields) for fields in new_index_together} # Deleted indexes for field_names in olds.difference(news): - self._remove_composed_index(model, field_names, {"index": True, "unique": False}) + self._remove_composed_index( + model, field_names, {"index": True, "unique": False}, column_prefix="" + ) # Created indexes for field_names in news.difference(olds): - self._add_composed_index(model, field_names) + self._add_composed_index(model, field_names, column_prefix=column_prefix) - def alter_unique_together(self, model, old_unique_together, new_unique_together): + def alter_unique_together( + self, model, old_unique_together, new_unique_together, column_prefix="", parent_model=None + ): olds = {tuple(fields) for fields in old_unique_together} news = {tuple(fields) for fields in new_unique_together} # Deleted uniques @@ -156,15 +233,25 @@ def alter_unique_together(self, model, old_unique_together, new_unique_together) model, field_names, {"unique": True, "primary_key": False}, + column_prefix=column_prefix, + parent_model=parent_model, ) # Created uniques for field_names in news.difference(olds): columns = [model._meta.get_field(field).column for field in field_names] - name = str(self._unique_constraint_name(model._meta.db_table, columns)) + name = str( + self._unique_constraint_name( + model._meta.db_table, [column_prefix + col for col in columns] + ) + ) constraint = UniqueConstraint(fields=field_names, name=name) - self.add_constraint(model, constraint) + self.add_constraint( + model, constraint, parent_model=parent_model, column_prefix=column_prefix + ) - def add_index(self, model, index, field=None, unique=False): + def add_index( + self, model, index, *, field=None, unique=False, column_prefix="", parent_model=None + ): if index.contains_expressions: return kwargs = {} @@ -176,7 +263,8 @@ def add_index(self, model, index, field=None, unique=False): # Indexing on $type matches the value of most SQL databases by # allowing multiple null values for the unique constraint. if field: - filter_expression[field.column].update({"$type": field.db_type(self.connection)}) + column = column_prefix + field.column + filter_expression[column].update({"$type": field.db_type(self.connection)}) else: for field_name, _ in index.fields_orders: field_ = model._meta.get_field(field_name) @@ -186,45 +274,51 @@ def add_index(self, model, index, field=None, unique=False): if filter_expression: kwargs["partialFilterExpression"] = filter_expression index_orders = ( - [(field.column, ASCENDING)] + [(column_prefix + field.column, ASCENDING)] if field else [ # order is "" if ASCENDING or "DESC" if DESCENDING (see # django.db.models.indexes.Index.fields_orders). - (model._meta.get_field(field_name).column, ASCENDING if order == "" else DESCENDING) + ( + column_prefix + model._meta.get_field(field_name).column, + ASCENDING if order == "" else DESCENDING, + ) for field_name, order in index.fields_orders ] ) idx = IndexModel(index_orders, name=index.name, **kwargs) + model = parent_model or model self.get_collection(model._meta.db_table).create_indexes([idx]) - def _add_composed_index(self, model, field_names): + def _add_composed_index(self, model, field_names, column_prefix="", parent_model=None): """Add an index on the given list of field_names.""" idx = Index(fields=field_names) idx.set_name_with_model(model) - self.add_index(model, idx) + self.add_index(model, idx, column_prefix=column_prefix, parent_model=parent_model) - def _add_field_index(self, model, field): + def _add_field_index(self, model, field, *, column_prefix=""): """Add an index on a field with db_index=True.""" - index = Index(fields=[field.name]) - index.name = self._create_index_name(model._meta.db_table, [field.column]) - self.add_index(model, index, field=field) + index = Index(fields=[column_prefix + field.name]) + index.name = self._create_index_name(model._meta.db_table, [column_prefix + field.column]) + self.add_index(model, index, field=field, column_prefix=column_prefix) def remove_index(self, model, index): if index.contains_expressions: return self.get_collection(model._meta.db_table).drop_index(index.name) - def _remove_composed_index(self, model, field_names, constraint_kwargs): + def _remove_composed_index( + self, model, field_names, constraint_kwargs, column_prefix="", parent_model=None + ): """ Remove the index on the given list of field_names created by index/unique_together, depending on constraint_kwargs. """ meta_constraint_names = {constraint.name for constraint in model._meta.constraints} meta_index_names = {constraint.name for constraint in model._meta.indexes} - columns = [model._meta.get_field(field).column for field in field_names] + columns = [column_prefix + model._meta.get_field(field).column for field in field_names] constraint_names = self._constraint_names( - model, + parent_model or model, columns, exclude=meta_constraint_names | meta_index_names, **constraint_kwargs, @@ -236,16 +330,17 @@ def _remove_composed_index(self, model, field_names, constraint_kwargs): f"Found wrong number ({num_found}) of constraints for " f"{model._meta.db_table}({columns_str})." ) + model = parent_model or model collection = self.get_collection(model._meta.db_table) collection.drop_index(constraint_names[0]) - def _remove_field_index(self, model, field): + def _remove_field_index(self, model, field, column_prefix=""): """Remove a field's db_index=True index.""" collection = self.get_collection(model._meta.db_table) meta_index_names = {index.name for index in model._meta.indexes} index_names = self._constraint_names( model, - [field.column], + [column_prefix + field.column], index=True, # Retrieve only BTREE indexes since this is what's created with # db_index=True. @@ -260,7 +355,7 @@ def _remove_field_index(self, model, field): ) collection.drop_index(index_names[0]) - def add_constraint(self, model, constraint, field=None): + def add_constraint(self, model, constraint, field=None, column_prefix="", parent_model=None): if isinstance(constraint, UniqueConstraint) and self._unique_supported( condition=constraint.condition, deferrable=constraint.deferrable, @@ -273,12 +368,21 @@ def add_constraint(self, model, constraint, field=None): name=constraint.name, condition=constraint.condition, ) - self.add_index(model, idx, field=field, unique=True) + self.add_index( + model, + idx, + field=field, + unique=True, + column_prefix=column_prefix, + parent_model=parent_model, + ) - def _add_field_unique(self, model, field): - name = str(self._unique_constraint_name(model._meta.db_table, [field.column])) + def _add_field_unique(self, model, field, column_prefix=""): + name = str( + self._unique_constraint_name(model._meta.db_table, [column_prefix + field.column]) + ) constraint = UniqueConstraint(fields=[field.name], name=name) - self.add_constraint(model, constraint, field=field) + self.add_constraint(model, constraint, field=field, column_prefix=column_prefix) def remove_constraint(self, model, constraint): if isinstance(constraint, UniqueConstraint) and self._unique_supported( @@ -295,12 +399,12 @@ def remove_constraint(self, model, constraint): ) self.remove_index(model, idx) - def _remove_field_unique(self, model, field): + def _remove_field_unique(self, model, field, column_prefix=""): # Find the unique constraint for this field meta_constraint_names = {constraint.name for constraint in model._meta.constraints} constraint_names = self._constraint_names( model, - [field.column], + [column_prefix + field.column], unique=True, primary_key=False, exclude=meta_constraint_names, diff --git a/tests/model_fields_/models.py b/tests/model_fields_/models.py index 10f212587..41451b14c 100644 --- a/tests/model_fields_/models.py +++ b/tests/model_fields_/models.py @@ -1,8 +1,9 @@ from django.db import models -from django_mongodb.fields import ObjectIdField +from django_mongodb.fields import EmbeddedModelField, ObjectIdField +# ObjectIdField class ObjectIdModel(models.Model): field = ObjectIdField() @@ -13,3 +14,57 @@ class NullableObjectIdModel(models.Model): class PrimaryKeyObjectIdModel(models.Model): field = ObjectIdField(primary_key=True) + + +# EmbeddedModelField +class Target(models.Model): + index = models.IntegerField() + + +class DecimalModel(models.Model): + decimal = models.DecimalField(max_digits=9, decimal_places=2) + + +class DecimalKey(models.Model): + decimal = models.DecimalField(max_digits=9, decimal_places=2, primary_key=True) + + +class DecimalParent(models.Model): + child = models.ForeignKey(DecimalKey, models.CASCADE) + + +class EmbeddedModelFieldModel(models.Model): + simple = EmbeddedModelField("EmbeddedModel", null=True, blank=True) + decimal_parent = EmbeddedModelField(DecimalKey, null=True, blank=True) + + +class EmbeddedModel(models.Model): + json_value = models.JSONField() + decimal = EmbeddedModelField(DecimalModel, null=True, blank=True) + someint = models.IntegerField(db_column="custom_column") + auto_now = models.DateTimeField(auto_now=True) + auto_now_add = models.DateTimeField(auto_now_add=True) + + +class Address(models.Model): + city = models.CharField(max_length=20) + state = models.CharField(max_length=2) + zip_code = models.IntegerField(db_index=True) + + +class Author(models.Model): + name = models.CharField(max_length=10) + age = models.IntegerField() + address = EmbeddedModelField(Address) + + +class Book(models.Model): + name = models.CharField(max_length=100) + author = EmbeddedModelField(Author) + + +class Library(models.Model): + name = models.CharField(max_length=100) + books = models.ManyToManyField("Book", related_name="libraries") + location = models.CharField(max_length=100, null=True, blank=True) + best_seller = models.CharField(max_length=100, null=True, blank=True) diff --git a/tests/model_fields_/test_embedded_model.py b/tests/model_fields_/test_embedded_model.py new file mode 100644 index 000000000..c6a936a72 --- /dev/null +++ b/tests/model_fields_/test_embedded_model.py @@ -0,0 +1,303 @@ +import operator +from decimal import Decimal + +from django.core.exceptions import FieldDoesNotExist, ValidationError +from django.db.models import ( + Exists, + ExpressionWrapper, + F, + IntegerField, + Max, + Model, + OuterRef, + Subquery, + Sum, +) +from django.test import SimpleTestCase, TestCase + +from django_mongodb.fields import EmbeddedModelField + +from .models import ( + Address, + Author, + Book, + DecimalKey, + DecimalParent, + EmbeddedModel, + EmbeddedModelFieldModel, + Library, +) + + +class MethodTests(SimpleTestCase): + def test_deconstruct(self): + field = EmbeddedModelField("EmbeddedModel", null=True) + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django_mongodb.fields.EmbeddedModelField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {"embedded_model": "EmbeddedModel", "null": True}) + + def test_get_db_prep_save_invalid(self): + msg = ( + "Expected instance of type , " + "not ." + ) + with self.assertRaisesMessage(TypeError, msg): + EmbeddedModelFieldModel(simple=42).save() + + def test_validate(self): + obj = EmbeddedModelFieldModel(simple=EmbeddedModel(someint=None)) + # This isn't quite right because "someint" is the field that's non-null. + msg = "{'simple': ['This field cannot be null.']}" + with self.assertRaisesMessage(ValidationError, msg): + obj.full_clean() + + +class ModelTests(TestCase): + def truncate_ms(self, value): + """Truncate microsends to millisecond precision as supported by MongoDB.""" + return value.replace(microsecond=(value.microsecond // 1000) * 1000) + + def test_save_load(self): + EmbeddedModelFieldModel.objects.create(simple=EmbeddedModel(someint="5")) + obj = EmbeddedModelFieldModel.objects.get() + self.assertIsInstance(obj.simple, EmbeddedModel) + # Make sure get_prep_value is called. + self.assertEqual(obj.simple.someint, 5) + # Primary keys should not be populated... + self.assertEqual(obj.simple.id, None) + # ... unless set explicitly. + obj.simple.id = obj.id + obj.save() + obj = EmbeddedModelFieldModel.objects.get() + self.assertEqual(obj.simple.id, obj.id) + + def test_save_load_null(self): + EmbeddedModelFieldModel.objects.create(simple=None) + obj = EmbeddedModelFieldModel.objects.get() + self.assertIsNone(obj.simple) + + def test_pre_save(self): + """Field.pre_save() is called on embedded model fields.""" + obj = EmbeddedModelFieldModel.objects.create(simple=EmbeddedModel()) + auto_now = self.truncate_ms(obj.simple.auto_now) + auto_now_add = self.truncate_ms(obj.simple.auto_now_add) + self.assertEqual(auto_now, auto_now_add) + # save() updates auto_now but not auto_now_add. + obj.save() + self.assertEqual(self.truncate_ms(obj.simple.auto_now_add), auto_now_add) + auto_now_two = obj.simple.auto_now + self.assertGreater(auto_now_two, obj.simple.auto_now_add) + # And again, save() updates auto_now but not auto_now_add. + obj = EmbeddedModelFieldModel.objects.get() + obj.save() + self.assertEqual(obj.simple.auto_now_add, auto_now_add) + self.assertGreater(obj.simple.auto_now, auto_now_two) + + def test_foreign_key_in_embedded_object(self): + msg = ( + "Field of type " + "is not supported within an EmbeddedModelField." + ) + with self.assertRaisesMessage(TypeError, msg): + + class EmbeddedModelTest(Model): + decimal = EmbeddedModelField(DecimalParent, null=True, blank=True) + + def test_embedded_field_with_foreign_conversion(self): + decimal = DecimalKey.objects.create(decimal=Decimal("1.5")) + EmbeddedModelFieldModel.objects.create(decimal_parent=decimal) + + +class QueryingTests(TestCase): + @classmethod + def setUpTestData(cls): + cls.objs = [ + EmbeddedModelFieldModel.objects.create(simple=EmbeddedModel(someint=x)) + for x in range(6) + ] + + def test_exact(self): + self.assertCountEqual( + EmbeddedModelFieldModel.objects.filter(simple__someint=3), [self.objs[3]] + ) + + def test_lt(self): + self.assertCountEqual( + EmbeddedModelFieldModel.objects.filter(simple__someint__lt=3), self.objs[:3] + ) + + def test_lte(self): + self.assertCountEqual( + EmbeddedModelFieldModel.objects.filter(simple__someint__lte=3), self.objs[:4] + ) + + def test_gt(self): + self.assertCountEqual( + EmbeddedModelFieldModel.objects.filter(simple__someint__gt=3), self.objs[4:] + ) + + def test_gte(self): + self.assertCountEqual( + EmbeddedModelFieldModel.objects.filter(simple__someint__gte=3), self.objs[3:] + ) + + def test_nested(self): + obj = Book.objects.create( + author=Author(name="Shakespeare", age=55, address=Address(city="NYC", state="NY")) + ) + self.assertCountEqual(Book.objects.filter(author__address__city="NYC"), [obj]) + + def test_nested_not_exists(self): + msg = "Address.city has no field named 'president'" + with self.assertRaisesMessage(FieldDoesNotExist, msg): + Book.objects.filter(author__address__city__president="NYC") + + def test_not_exists_in_embedded(self): + msg = "Address has no field named 'floor'" + with self.assertRaisesMessage(FieldDoesNotExist, msg): + Book.objects.filter(author__address__floor="NYC") + + def test_embedded_with_json_field(self): + models = [] + for i in range(4): + m = EmbeddedModelFieldModel.objects.create( + simple=EmbeddedModel( + json_value={"field1": i * 5, "field2": {"0": {"value": list(range(i))}}} + ) + ) + models.append(m) + + all_models = EmbeddedModelFieldModel.objects.all() + + self.assertCountEqual( + EmbeddedModelFieldModel.objects.filter(simple__json_value__field2__0__value__0=0), + models[1:], + ) + self.assertCountEqual( + EmbeddedModelFieldModel.objects.filter(simple__json_value__field2__0__value__1=1), + models[2:], + ) + self.assertCountEqual( + EmbeddedModelFieldModel.objects.filter(simple__json_value__field2__0__value__1=5), [] + ) + + self.assertCountEqual( + EmbeddedModelFieldModel.objects.filter(simple__json_value__field1__lt=100), all_models + ) + self.assertCountEqual( + EmbeddedModelFieldModel.objects.filter(simple__json_value__field1__gt=100), [] + ) + self.assertCountEqual( + EmbeddedModelFieldModel.objects.filter( + simple__json_value__field1__gte=5, simple__json_value__field1__lte=10 + ), + models[1:3], + ) + + def truncate_ms(self, value): + """Truncate microsends to millisecond precision as supported by MongoDB.""" + return value.replace(microsecond=(value.microsecond // 1000) * 1000) + + ################ + def test_ordering_by_embedded_field(self): + query = ( + EmbeddedModelFieldModel.objects.filter(simple__someint__gt=3) + .order_by("-simple__someint") + .values("pk") + ) + expected = [{"pk": e.pk} for e in list(reversed(self.objs[4:]))] + self.assertSequenceEqual(query, expected) + + def test_ordering_grouping_by_embedded_field(self): + expected = sorted( + ( + EmbeddedModelFieldModel.objects.create(simple=EmbeddedModel(someint=x)) + for x in range(6) + ), + key=lambda x: x.simple.someint, + ) + query = ( + EmbeddedModelFieldModel.objects.annotate( + group=ExpressionWrapper(F("simple__someint") + 5, output_field=IntegerField()) + ) + .values("group") + .annotate(max_auto_now=Max("simple__auto_now")) + .order_by("simple__someint") + ) + query_response = [{**e, "max_auto_now": self.truncate_ms(e["max_auto_now"])} for e in query] + self.assertSequenceEqual( + query_response, + [ + {"group": e.simple.someint + 5, "max_auto_now": self.truncate_ms(e.simple.auto_now)} + for e in expected + ], + ) + + def test_ordering_grouping_by_sum(self): + [EmbeddedModelFieldModel.objects.create(simple=EmbeddedModel(someint=x)) for x in range(6)] + qs = ( + EmbeddedModelFieldModel.objects.values("simple__someint") + .annotate(sum=Sum("simple__someint")) + .order_by("sum") + ) + self.assertQuerySetEqual(qs, [0, 2, 4, 6, 8, 10], operator.itemgetter("sum")) + + +class SubqueryExistsTest(TestCase): + def setUp(self): + # Create test data + address1 = Address.objects.create(city="New York", state="NY", zip_code=10001) + address2 = Address.objects.create(city="Boston", state="MA", zip_code=20002) + author1 = Author.objects.create(name="Alice", age=30, address=address1) + author2 = Author.objects.create(name="Bob", age=40, address=address2) + book1 = Book.objects.create(name="Book A", author=author1) + book2 = Book.objects.create(name="Book B", author=author2) + Book.objects.create(name="Book C", author=author2) + Book.objects.create(name="Book D", author=author2) + Book.objects.create(name="Book E", author=author1) + + library1 = Library.objects.create( + name="Central Library", location="Downtown", best_seller="Book A" + ) + library2 = Library.objects.create( + name="Community Library", location="Suburbs", best_seller="Book A" + ) + + # Add books to libraries + library1.books.add(book1, book2) + library2.books.add(book2) + + def test_exists_subquery(self): + subquery = Book.objects.filter( + author__name=OuterRef("name"), author__address__city="Boston" + ) + queryset = Author.objects.filter(Exists(subquery)) + + self.assertEqual(queryset.count(), 1) + + def test_in_subquery(self): + subquery = Author.objects.filter(age__gt=35).values("name") + queryset = Book.objects.filter(author__name__in=Subquery(subquery)).order_by("name") + + self.assertEqual(queryset.count(), 3) + self.assertQuerySetEqual(queryset, ["Book B", "Book C", "Book D"], lambda book: book.name) + + def test_range_query(self): + queryset = Author.objects.filter(age__range=(25, 45)).order_by("name") + + self.assertEqual(queryset.count(), 2) + self.assertQuerySetEqual(queryset, ["Alice", "Bob"], lambda author: author.name) + + def test_exists_with_foreign_object(self): + subquery = Library.objects.filter(best_seller=OuterRef("name")) + queryset = Book.objects.filter(Exists(subquery)) + + self.assertEqual(queryset.count(), 1) + self.assertEqual(queryset.first().name, "Book A") + + def test_foreign_field_with_ranges(self): + queryset = Library.objects.filter(books__author__age__range=(25, 35)) + + self.assertEqual(queryset.count(), 1) + self.assertEqual(queryset.first().name, "Central Library") diff --git a/tests/model_forms_/__init__.py b/tests/model_forms_/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/model_forms_/forms.py b/tests/model_forms_/forms.py new file mode 100644 index 000000000..7bfed3fbb --- /dev/null +++ b/tests/model_forms_/forms.py @@ -0,0 +1,9 @@ +from django import forms + +from .models import Author + + +class AuthorForm(forms.ModelForm): + class Meta: + fields = "__all__" + model = Author diff --git a/tests/model_forms_/models.py b/tests/model_forms_/models.py new file mode 100644 index 000000000..ef1697562 --- /dev/null +++ b/tests/model_forms_/models.py @@ -0,0 +1,22 @@ +from django.db import models + +from django_mongodb.fields import EmbeddedModelField + + +class Address(models.Model): + po_box = models.CharField(max_length=50, blank=True, verbose_name="PO Box") + city = models.CharField(max_length=20) + state = models.CharField(max_length=2) + zip_code = models.IntegerField() + + +class Author(models.Model): + name = models.CharField(max_length=10) + age = models.IntegerField() + address = EmbeddedModelField(Address) + billing_address = EmbeddedModelField(Address, blank=True, null=True) + + +class Book(models.Model): + name = models.CharField(max_length=100) + author = EmbeddedModelField(Author) diff --git a/tests/model_forms_/test_embedded_model.py b/tests/model_forms_/test_embedded_model.py new file mode 100644 index 000000000..240f8c6d8 --- /dev/null +++ b/tests/model_forms_/test_embedded_model.py @@ -0,0 +1,130 @@ +from django.test import TestCase + +from .forms import AuthorForm +from .models import Address, Author + + +class ModelFormTests(TestCase): + def test_update(self): + author = Author.objects.create( + name="Bob", age=50, address=Address(city="NYC", state="NY", zip_code="10001") + ) + data = { + "name": "Bob", + "age": 51, + "address-po_box": "", + "address-city": "New York City", + "address-state": "NY", + "address-zip_code": "10001", + } + form = AuthorForm(data, instance=author) + self.assertTrue(form.is_valid()) + form.save() + author.refresh_from_db() + self.assertEqual(author.age, 51) + self.assertEqual(author.address.city, "New York City") + + def test_some_missing_data(self): + author = Author.objects.create( + name="Bob", age=50, address=Address(city="NYC", state="NY", zip_code="10001") + ) + data = { + "name": "Bob", + "age": 51, + "address-po_box": "", + "address-city": "New York City", + "address-state": "NY", + "address-zip_code": "", + } + form = AuthorForm(data, instance=author) + self.assertFalse(form.is_valid()) + self.assertEqual(form.errors["address"], ["Enter all required values."]) + + def test_invalid_field_data(self): + """A field's data (state) is too long.""" + author = Author.objects.create( + name="Bob", age=50, address=Address(city="NYC", state="NY", zip_code="10001") + ) + data = { + "name": "Bob", + "age": 51, + "address-po_box": "", + "address-city": "New York City", + "address-state": "TOO LONG", + "address-zip_code": "", + } + form = AuthorForm(data, instance=author) + self.assertFalse(form.is_valid()) + self.assertEqual( + form.errors["address"], + [ + "Ensure this value has at most 2 characters (it has 8).", + "Enter all required values.", + ], + ) + + def test_all_missing_data(self): + author = Author.objects.create( + name="Bob", age=50, address=Address(city="NYC", state="NY", zip_code="10001") + ) + data = { + "name": "Bob", + "age": 51, + "address-po_box": "", + "address-city": "", + "address-state": "", + "address-zip_code": "", + } + form = AuthorForm(data, instance=author) + self.assertFalse(form.is_valid()) + self.assertEqual(form.errors["address"], ["This field is required."]) + + def test_nullable_field(self): + """A nullable EmbeddedModelField is removed if all fields are empty.""" + author = Author.objects.create( + name="Bob", + age=50, + address=Address(city="NYC", state="NY", zip_code="10001"), + billing_address=Address(city="NYC", state="NY", zip_code="10001"), + ) + data = { + "name": "Bob", + "age": 51, + "address-po_box": "", + "address-city": "New York City", + "address-state": "NY", + "address-zip_code": "10001", + "billing_address-po_box": "", + "billing_address-city": "", + "billing_address-state": "", + "billing_address-zip_code": "", + } + form = AuthorForm(data, instance=author) + self.assertTrue(form.is_valid()) + form.save() + author.refresh_from_db() + self.assertIsNone(author.billing_address) + + def test_rendering(self): + form = AuthorForm() + self.assertHTMLEqual( + str(form.fields["address"].get_bound_field(form, "address")), + """ +
+ + +
+
+ + +
+
+ + +
+
+ + +
""", + ) diff --git a/tests/schema_/__init__.py b/tests/schema_/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/schema_/models.py b/tests/schema_/models.py new file mode 100644 index 000000000..73fbfcb76 --- /dev/null +++ b/tests/schema_/models.py @@ -0,0 +1,38 @@ +from django.apps.registry import Apps +from django.db import models + +from django_mongodb.fields import EmbeddedModelField + +# Because we want to test creation and deletion of these as separate things, +# these models are all inserted into a separate Apps so the main test +# runner doesn't migrate them. + +new_apps = Apps() + + +class Address(models.Model): + city = models.CharField(max_length=20) + state = models.CharField(max_length=2) + zip_code = models.IntegerField(db_index=True) + uid = models.IntegerField(unique=True) + + class Meta: + apps = new_apps + + +class Author(models.Model): + name = models.CharField(max_length=10) + age = models.IntegerField(db_index=True) + address = EmbeddedModelField(Address) + employee_id = models.IntegerField(unique=True) + + class Meta: + apps = new_apps + + +class Book(models.Model): + name = models.CharField(max_length=100) + author = EmbeddedModelField(Author) + + class Meta: + apps = new_apps diff --git a/tests/schema_/test_embedded_model.py b/tests/schema_/test_embedded_model.py new file mode 100644 index 000000000..591d369af --- /dev/null +++ b/tests/schema_/test_embedded_model.py @@ -0,0 +1,704 @@ +import itertools + +from django.db import connection, models +from django.test import TransactionTestCase, ignore_warnings +from django.test.utils import isolate_apps +from django.utils.deprecation import RemovedInDjango51Warning + +from django_mongodb.fields import EmbeddedModelField + +from .models import Address, Author, Book, new_apps + + +class SchemaTests(TransactionTestCase): + available_apps = [] + models = [Address, Author, Book] + + # Utility functions + + def setUp(self): + # local_models should contain test dependent model classes that will be + # automatically removed from the app cache on test tear down. + self.local_models = [] + # isolated_local_models contains models that are in test methods + # decorated with @isolate_apps. + self.isolated_local_models = [] + + def tearDown(self): + # Delete any tables made for our models + self.delete_tables() + new_apps.clear_cache() + for model in new_apps.get_models(): + model._meta._expire_cache() + if "schema" in new_apps.all_models: + for model in self.local_models: + for many_to_many in model._meta.many_to_many: + through = many_to_many.remote_field.through + if through and through._meta.auto_created: + del new_apps.all_models["schema"][through._meta.model_name] + del new_apps.all_models["schema"][model._meta.model_name] + if self.isolated_local_models: + with connection.schema_editor() as editor: + for model in self.isolated_local_models: + editor.delete_model(model) + + def delete_tables(self): + "Deletes all model tables for our models for a clean test environment" + converter = connection.introspection.identifier_converter + with connection.schema_editor() as editor: + connection.disable_constraint_checking() + table_names = connection.introspection.table_names() + if connection.features.ignores_table_name_case: + table_names = [table_name.lower() for table_name in table_names] + for model in itertools.chain(SchemaTests.models, self.local_models): + tbl = converter(model._meta.db_table) + if connection.features.ignores_table_name_case: + tbl = tbl.lower() + if tbl in table_names: + editor.delete_model(model) + table_names.remove(tbl) + connection.enable_constraint_checking() + + def get_indexes(self, table): + """ + Get the indexes on the table using a new cursor. + """ + with connection.cursor() as cursor: + return [ + c["columns"][0] + for c in connection.introspection.get_constraints(cursor, table).values() + if c["index"] and len(c["columns"]) == 1 + ] + + def get_uniques(self, table): + with connection.cursor() as cursor: + return [ + c["columns"][0] + for c in connection.introspection.get_constraints(cursor, table).values() + if c["unique"] and len(c["columns"]) == 1 + ] + + def get_constraints(self, table): + """ + Get the constraints on a table using a new cursor. + """ + with connection.cursor() as cursor: + return connection.introspection.get_constraints(cursor, table) + + def get_constraints_for_columns(self, model, columns): + constraints = self.get_constraints(model._meta.db_table) + constraints_for_column = [] + for name, details in constraints.items(): + if details["columns"] == columns: + constraints_for_column.append(name) + return sorted(constraints_for_column) + + def check_added_field_default( + self, + schema_editor, + model, + field, + field_name, + expected_default, + cast_function=None, + ): + schema_editor.add_field(model, field) + database_default = connection.database[model._meta.db_table].find_one().get(field_name) + if cast_function and type(database_default) is not type(expected_default): + database_default = cast_function(database_default) + self.assertEqual(database_default, expected_default) + + def get_constraints_count(self, table, column, fk_to): + """ + Return a dict with keys 'fks', 'uniques, and 'indexes' indicating the + number of foreign keys, unique constraints, and indexes on + `table`.`column`. The `fk_to` argument is a 2-tuple specifying the + expected foreign key relationship's (table, column). + """ + with connection.cursor() as cursor: + constraints = connection.introspection.get_constraints(cursor, table) + counts = {"fks": 0, "uniques": 0, "indexes": 0} + for c in constraints.values(): + if c["columns"] == [column]: + if c["foreign_key"] == fk_to: + counts["fks"] += 1 + if c["unique"]: + counts["uniques"] += 1 + elif c["index"]: + counts["indexes"] += 1 + return counts + + def assertIndexOrder(self, table, index, order): + constraints = self.get_constraints(table) + self.assertIn(index, constraints) + index_orders = constraints[index]["orders"] + self.assertTrue( + all(val == expected for val, expected in zip(index_orders, order, strict=True)) + ) + + def assertForeignKeyExists(self, model, column, expected_fk_table, field="id"): + """ + Fail if the FK constraint on `model.Meta.db_table`.`column` to + `expected_fk_table`.id doesn't exist. + """ + if not connection.features.can_introspect_foreign_keys: + return + constraints = self.get_constraints(model._meta.db_table) + constraint_fk = None + for details in constraints.values(): + if details["columns"] == [column] and details["foreign_key"]: + constraint_fk = details["foreign_key"] + break + self.assertEqual(constraint_fk, (expected_fk_table, field)) + + def assertForeignKeyNotExists(self, model, column, expected_fk_table): + if not connection.features.can_introspect_foreign_keys: + return + with self.assertRaises(AssertionError): + self.assertForeignKeyExists(model, column, expected_fk_table) + + def assertTableExists(self, model): + self.assertIn(model._meta.db_table, connection.introspection.table_names()) + + def assertTableNotExists(self, model): + self.assertNotIn(model._meta.db_table, connection.introspection.table_names()) + + # SchemaEditor.create_model() tests + def test_db_index(self): + """Field(db_index=True) on an embedded model.""" + with connection.schema_editor() as editor: + # Create the table + editor.create_model(Book) + # The table is there + self.assertTableExists(Book) + # Embedded indexes are created. + self.assertEqual( + self.get_constraints_for_columns(Book, ["author.age"]), + ["schema__book_author.age_dc08100b"], + ) + self.assertEqual( + self.get_constraints_for_columns(Book, ["author.address.zip_code"]), + ["schema__book_author.address.zip_code_7b9a9307"], + ) + # Clean up that table + editor.delete_model(Book) + # The table is gone + self.assertTableNotExists(Author) + + def test_unique(self): + """Field(unique=True) on an embedded model.""" + with connection.schema_editor() as editor: + editor.create_model(Book) + self.assertTableExists(Book) + # Embedded uniques are created. + self.assertEqual( + self.get_constraints_for_columns(Book, ["author.employee_id"]), + ["schema__book_author.employee_id_7d4d3eff_uniq"], + ) + self.assertEqual( + self.get_constraints_for_columns(Book, ["author.address.uid"]), + ["schema__book_author.address.uid_8124a01f_uniq"], + ) + # Clean up that table + editor.delete_model(Book) + self.assertTableNotExists(Author) + + @ignore_warnings(category=RemovedInDjango51Warning) + @isolate_apps("schema_") + def test_index_together(self): + """Meta.index_together on an embedded model.""" + + class Address(models.Model): + index_together_one = models.CharField(max_length=10) + index_together_two = models.CharField(max_length=10) + + class Meta: + app_label = "schema_" + index_together = [("index_together_one", "index_together_two")] + + class Author(models.Model): + address = EmbeddedModelField(Address) + index_together_three = models.CharField(max_length=10) + index_together_four = models.CharField(max_length=10) + + class Meta: + app_label = "schema_" + index_together = [("index_together_three", "index_together_four")] + + class Book(models.Model): + author = EmbeddedModelField(Author) + + class Meta: + app_label = "schema_" + + with connection.schema_editor() as editor: + editor.create_model(Book) + self.assertTableExists(Book) + # Embedded uniques are created. + self.assertEqual( + self.get_constraints_for_columns( + Book, ["author.address.index_together_one", "author.address.index_together_two"] + ), + ["schema__add_index_t_efa93e_idx"], + ) + self.assertEqual( + self.get_constraints_for_columns( + Book, + ["author.index_together_three", "author.index_together_four"], + ), + ["schema__aut_index_t_df32aa_idx"], + ) + editor.delete_model(Book) + self.assertTableNotExists(Book) + + @isolate_apps("schema_") + def test_unique_together(self): + """Meta.unique_together on an embedded model.""" + + class Address(models.Model): + unique_together_one = models.CharField(max_length=10) + unique_together_two = models.CharField(max_length=10) + + class Meta: + app_label = "schema_" + unique_together = [("unique_together_one", "unique_together_two")] + + class Author(models.Model): + address = EmbeddedModelField(Address) + unique_together_three = models.CharField(max_length=10) + unique_together_four = models.CharField(max_length=10) + + class Meta: + app_label = "schema_" + unique_together = [("unique_together_three", "unique_together_four")] + + class Book(models.Model): + author = EmbeddedModelField(Author) + + class Meta: + app_label = "schema_" + + with connection.schema_editor() as editor: + editor.create_model(Book) + self.assertTableExists(Book) + # Embedded uniques are created. + self.assertEqual( + self.get_constraints_for_columns( + Book, ["author.unique_together_three", "author.unique_together_four"] + ), + [ + "schema__author_author.unique_together_three_author.unique_together_four_39e1cb43_uniq" + ], + ) + self.assertEqual( + self.get_constraints_for_columns( + Book, + ["author.address.unique_together_one", "author.address.unique_together_two"], + ), + [ + "schema__address_author.address.unique_together_one_author.address.unique_together_two_de682e30_uniq" + ], + ) + editor.delete_model(Book) + self.assertTableNotExists(Book) + + @isolate_apps("schema_") + def test_indexes(self): + """Meta.indexes on an embedded model.""" + + class Address(models.Model): + indexed_one = models.CharField(max_length=10) + + class Meta: + app_label = "schema_" + indexes = [models.Index(fields=["indexed_one"])] + + class Author(models.Model): + address = EmbeddedModelField(Address) + indexed_two = models.CharField(max_length=10) + + class Meta: + app_label = "schema_" + indexes = [models.Index(fields=["indexed_two"])] + + class Book(models.Model): + author = EmbeddedModelField(Author) + + class Meta: + app_label = "schema_" + + with connection.schema_editor() as editor: + editor.create_model(Book) + self.assertTableExists(Book) + # Embedded uniques are created. + self.assertEqual( + self.get_constraints_for_columns(Book, ["author.indexed_two"]), + ["schema__aut_indexed_b19137_idx"], + ) + self.assertEqual( + self.get_constraints_for_columns( + Book, + ["author.address.indexed_one"], + ), + ["schema__add_indexed_b64972_idx"], + ) + editor.delete_model(Author) + self.assertTableNotExists(Author) + + @isolate_apps("schema_") + def test_constraints(self): + """Meta.constraints on an embedded model.""" + + class Address(models.Model): + unique_constraint_one = models.CharField(max_length=10) + + class Meta: + app_label = "schema_" + constraints = [ + models.UniqueConstraint(fields=["unique_constraint_one"], name="unique_one") + ] + + class Author(models.Model): + address = EmbeddedModelField(Address) + unique_constraint_two = models.CharField(max_length=10) + + class Meta: + app_label = "schema_" + constraints = [ + models.UniqueConstraint(fields=["unique_constraint_two"], name="unique_two") + ] + + class Book(models.Model): + author = EmbeddedModelField(Author) + + class Meta: + app_label = "schema_" + + with connection.schema_editor() as editor: + editor.create_model(Book) + self.assertTableExists(Book) + # Embedded uniques are created. + self.assertEqual( + self.get_constraints_for_columns(Book, ["author.unique_constraint_two"]), + ["unique_two"], + ) + self.assertEqual( + self.get_constraints_for_columns( + Book, + ["author.address.unique_constraint_one"], + ), + ["unique_one"], + ) + editor.delete_model(Author) + self.assertTableNotExists(Author) + + # SchemaEditor.add_field() / remove_field() tests + @isolate_apps("schema_") + def test_add_remove_field_db_index_and_unique(self): + """AddField/RemoveField + EmbeddedModelField + Field(db_index=True) & Field(unique=True).""" + + class Book(models.Model): + name = models.CharField(max_length=100) + + class Meta: + app_label = "schema_" + + new_field = EmbeddedModelField(Author) + new_field.set_attributes_from_name("author") + with connection.schema_editor() as editor: + # Create the table amd add the field. + editor.create_model(Book) + editor.add_field(Book, new_field) + # Embedded indexes are created. + self.assertEqual( + self.get_constraints_for_columns(Book, ["author.age"]), + ["schema__book_author.age_dc08100b"], + ) + self.assertEqual( + self.get_constraints_for_columns(Book, ["author.address.zip_code"]), + ["schema__book_author.address.zip_code_7b9a9307"], + ) + # Embedded uniques + self.assertEqual( + self.get_constraints_for_columns(Book, ["author.employee_id"]), + ["schema__book_author.employee_id_7d4d3eff_uniq"], + ) + self.assertEqual( + self.get_constraints_for_columns(Book, ["author.address.uid"]), + ["schema__book_author.address.uid_8124a01f_uniq"], + ) + editor.remove_field(Book, new_field) + # Embedded indexes are removed. + self.assertEqual( + self.get_constraints_for_columns(Book, ["author.age"]), + [], + ) + self.assertEqual( + self.get_constraints_for_columns(Book, ["author.address.zip_code"]), + [], + ) + self.assertEqual( + self.get_constraints_for_columns(Book, ["author.employee_id"]), + [], + ) + self.assertEqual( + self.get_constraints_for_columns(Book, ["author.address.uid"]), + [], + ) + editor.delete_model(Book) + self.assertTableNotExists(Author) + + @ignore_warnings(category=RemovedInDjango51Warning) + @isolate_apps("schema_") + def test_add_remove_field_index_together(self): + """AddField/RemoveField + EmbeddedModelField + Meta.index_together.""" + + class Address(models.Model): + index_together_one = models.CharField(max_length=10) + index_together_two = models.CharField(max_length=10) + + class Meta: + app_label = "schema_" + index_together = [("index_together_one", "index_together_two")] + + class Author(models.Model): + address = EmbeddedModelField(Address) + index_together_three = models.CharField(max_length=10) + index_together_four = models.CharField(max_length=10) + + class Meta: + app_label = "schema_" + index_together = [("index_together_three", "index_together_four")] + + class Book(models.Model): + class Meta: + app_label = "schema_" + + new_field = EmbeddedModelField(Author) + new_field.set_attributes_from_name("author") + with connection.schema_editor() as editor: + # Create the table amd add the field. + editor.create_model(Book) + editor.add_field(Book, new_field) + # Embedded index_togethers are created. + self.assertEqual( + self.get_constraints_for_columns( + Book, ["author.address.index_together_one", "author.address.index_together_two"] + ), + ["schema__add_index_t_efa93e_idx"], + ) + self.assertEqual( + self.get_constraints_for_columns( + Book, + ["author.index_together_three", "author.index_together_four"], + ), + ["schema__aut_index_t_df32aa_idx"], + ) + editor.remove_field(Book, new_field) + # Embedded indexes are removed. + self.assertEqual( + self.get_constraints_for_columns( + Book, ["author.address.index_together_one", "author.address.index_together_two"] + ), + [], + ) + self.assertEqual( + self.get_constraints_for_columns( + Book, + ["author.index_together_three", "author.index_together_four"], + ), + [], + ) + editor.delete_model(Book) + self.assertTableNotExists(Book) + + @isolate_apps("schema_") + def test_add_remove_field_unique_together(self): + """AddField/RemoveField + EmbeddedModelField + Meta.unique_together.""" + + class Address(models.Model): + unique_together_one = models.CharField(max_length=10) + unique_together_two = models.CharField(max_length=10) + + class Meta: + app_label = "schema_" + unique_together = [("unique_together_one", "unique_together_two")] + + class Author(models.Model): + address = EmbeddedModelField(Address) + unique_together_three = models.CharField(max_length=10) + unique_together_four = models.CharField(max_length=10) + + class Meta: + app_label = "schema_" + unique_together = [("unique_together_three", "unique_together_four")] + + class Book(models.Model): + author = EmbeddedModelField(Author) + + class Meta: + app_label = "schema_" + + new_field = EmbeddedModelField(Author) + new_field.set_attributes_from_name("author") + with connection.schema_editor() as editor: + # Create the table and add the field. + editor.create_model(Book) + editor.add_field(Book, new_field) + # Embedded uniques are created. + self.assertEqual( + self.get_constraints_for_columns( + Book, ["author.unique_together_three", "author.unique_together_four"] + ), + [ + "schema__author_author.unique_together_three_author.unique_together_four_39e1cb43_uniq" + ], + ) + self.assertEqual( + self.get_constraints_for_columns( + Book, + ["author.address.unique_together_one", "author.address.unique_together_two"], + ), + [ + "schema__address_author.address.unique_together_one_author.address.unique_together_two_de682e30_uniq" + ], + ) + editor.remove_field(Book, new_field) + # Embedded indexes are removed. + self.assertEqual( + self.get_constraints_for_columns( + Book, ["author.unique_together_three", "author.unique_together_four"] + ), + [], + ) + self.assertEqual( + self.get_constraints_for_columns( + Book, + ["author.address.unique_together_one", "author.address.unique_together_two"], + ), + [], + ) + editor.delete_model(Book) + self.assertTableNotExists(Book) + + @isolate_apps("schema_") + def test_add_remove_field_indexes(self): + """AddField/RemoveField + EmbeddedModelField + Meta.indexes.""" + + class Address(models.Model): + indexed_one = models.CharField(max_length=10) + + class Meta: + app_label = "schema_" + indexes = [models.Index(fields=["indexed_one"])] + + class Author(models.Model): + address = EmbeddedModelField(Address) + indexed_two = models.CharField(max_length=10) + + class Meta: + app_label = "schema_" + indexes = [models.Index(fields=["indexed_two"])] + + class Book(models.Model): + author = EmbeddedModelField(Author) + + class Meta: + app_label = "schema_" + + new_field = EmbeddedModelField(Author) + new_field.set_attributes_from_name("author") + with connection.schema_editor() as editor: + # Create the table and add the field. + editor.create_model(Book) + editor.add_field(Book, new_field) + # Embedded indexes are created. + self.assertEqual( + self.get_constraints_for_columns(Book, ["author.indexed_two"]), + ["schema__aut_indexed_b19137_idx"], + ) + self.assertEqual( + self.get_constraints_for_columns( + Book, + ["author.address.indexed_one"], + ), + ["schema__add_indexed_b64972_idx"], + ) + editor.remove_field(Book, new_field) + # Embedded indexes are removed. + self.assertEqual( + self.get_constraints_for_columns(Book, ["author.indexed_two"]), + [], + ) + self.assertEqual( + self.get_constraints_for_columns( + Book, + ["author.address.indexed_one"], + ), + [], + ) + editor.delete_model(Author) + self.assertTableNotExists(Author) + + @isolate_apps("schema_") + def test_add_remove_field_constraints(self): + """AddField/RemoveField + EmbeddedModelField + Meta.constraints.""" + + class Address(models.Model): + unique_constraint_one = models.CharField(max_length=10) + + class Meta: + app_label = "schema_" + constraints = [ + models.UniqueConstraint(fields=["unique_constraint_one"], name="unique_one") + ] + + class Author(models.Model): + address = EmbeddedModelField(Address) + unique_constraint_two = models.CharField(max_length=10) + + class Meta: + app_label = "schema_" + constraints = [ + models.UniqueConstraint(fields=["unique_constraint_two"], name="unique_two") + ] + + class Book(models.Model): + author = EmbeddedModelField(Author) + + class Meta: + app_label = "schema_" + + new_field = EmbeddedModelField(Author) + new_field.set_attributes_from_name("author") + with connection.schema_editor() as editor: + # Create the table and add the field. + editor.create_model(Book) + editor.add_field(Book, new_field) + # Embedded constraints are created. + self.assertEqual( + self.get_constraints_for_columns(Book, ["author.unique_constraint_two"]), + ["unique_two"], + ) + self.assertEqual( + self.get_constraints_for_columns( + Book, + ["author.address.unique_constraint_one"], + ), + ["unique_one"], + ) + editor.remove_field(Book, new_field) + # Embedded constraints are removed. + self.assertEqual( + self.get_constraints_for_columns(Book, ["author.unique_constraint_two"]), + [], + ) + self.assertEqual( + self.get_constraints_for_columns( + Book, + ["author.address.unique_constraint_one"], + ), + [], + ) + editor.delete_model(Author) + self.assertTableNotExists(Author)