diff --git a/django_mongodb_backend/fields/array.py b/django_mongodb_backend/fields/array.py index baf33fc04..e21afdcfa 100644 --- a/django_mongodb_backend/fields/array.py +++ b/django_mongodb_backend/fields/array.py @@ -1,6 +1,5 @@ import json -from django.contrib.postgres.validators import ArrayMaxLengthValidator from django.core import checks, exceptions from django.db.models import DecimalField, Field, Func, IntegerField, Transform, Value from django.db.models.fields.mixins import CheckFieldDefaultMixin @@ -10,6 +9,7 @@ from ..forms import SimpleArrayField from ..query_utils import process_lhs, process_rhs from ..utils import prefix_validation_error +from ..validators import ArrayMaxLengthValidator, LengthValidator __all__ = ["ArrayField"] @@ -27,14 +27,20 @@ class ArrayField(CheckFieldDefaultMixin, Field): } _default_hint = ("list", "[]") - def __init__(self, base_field, max_size=None, **kwargs): + def __init__(self, base_field, max_size=None, size=None, **kwargs): self.base_field = base_field self.max_size = max_size + self.size = size if self.max_size: self.default_validators = [ *self.default_validators, ArrayMaxLengthValidator(self.max_size), ] + if self.size: + self.default_validators = [ + *self.default_validators, + LengthValidator(self.size), + ] # For performance, only add a from_db_value() method if the base field # implements it. if hasattr(self.base_field, "from_db_value"): @@ -98,6 +104,14 @@ def check(self, **kwargs): id="django_mongodb_backend.array.W004", ) ) + if self.size and self.max_size: + errors.append( + checks.Error( + "ArrayField cannot specify both size and max_size.", + obj=self, + id="django_mongodb_backend.array.E003", + ) + ) return errors def set_attributes_from_name(self, name): @@ -127,6 +141,8 @@ def deconstruct(self): kwargs["base_field"] = self.base_field.clone() if self.max_size is not None: kwargs["max_size"] = self.max_size + if self.size is not None: + kwargs["size"] = self.size return name, path, args, kwargs def to_python(self, value): @@ -211,6 +227,7 @@ def formfield(self, **kwargs): "form_class": SimpleArrayField, "base_field": self.base_field.formfield(), "max_length": self.max_size, + "length": self.size, **kwargs, } ) diff --git a/django_mongodb_backend/forms/fields/array.py b/django_mongodb_backend/forms/fields/array.py index 0de48dff4..854508cc5 100644 --- a/django_mongodb_backend/forms/fields/array.py +++ b/django_mongodb_backend/forms/fields/array.py @@ -2,11 +2,11 @@ from itertools import chain from django import forms -from django.core.exceptions import ValidationError +from django.core.exceptions import ImproperlyConfigured, ValidationError from django.utils.translation import gettext_lazy as _ from ...utils import prefix_validation_error -from ...validators import ArrayMaxLengthValidator, ArrayMinLengthValidator +from ...validators import ArrayMaxLengthValidator, ArrayMinLengthValidator, LengthValidator class SimpleArrayField(forms.CharField): @@ -14,16 +14,26 @@ class SimpleArrayField(forms.CharField): "item_invalid": _("Item %(nth)s in the array did not validate:"), } - def __init__(self, base_field, *, delimiter=",", max_length=None, min_length=None, **kwargs): + def __init__( + self, base_field, *, delimiter=",", max_length=None, min_length=None, length=None, **kwargs + ): self.base_field = base_field self.delimiter = delimiter super().__init__(**kwargs) + if (min_length is not None or max_length is not None) and length is not None: + invalid_param = "max_length" if max_length is not None else "min_length" + raise ImproperlyConfigured( + f"The length and {invalid_param} parameters are mutually exclusive." + ) if min_length is not None: self.min_length = min_length self.validators.append(ArrayMinLengthValidator(int(min_length))) if max_length is not None: self.max_length = max_length self.validators.append(ArrayMaxLengthValidator(int(max_length))) + if length is not None: + self.length = length + self.validators.append(LengthValidator(int(length))) def clean(self, value): value = super().clean(value) diff --git a/django_mongodb_backend/validators.py b/django_mongodb_backend/validators.py index 6005152e8..5ca6cbe23 100644 --- a/django_mongodb_backend/validators.py +++ b/django_mongodb_backend/validators.py @@ -1,4 +1,5 @@ -from django.core.validators import MaxLengthValidator, MinLengthValidator +from django.core.validators import BaseValidator, MaxLengthValidator, MinLengthValidator +from django.utils.deconstruct import deconstructible from django.utils.translation import ngettext_lazy @@ -16,3 +17,19 @@ class ArrayMinLengthValidator(MinLengthValidator): "List contains %(show_value)d items, it should contain no fewer than %(limit_value)d.", "show_value", ) + + +@deconstructible +class LengthValidator(BaseValidator): + message = ngettext_lazy( + "List contains %(show_value)d item, it should contain %(limit_value)d.", + "List contains %(show_value)d items, it should contain %(limit_value)d.", + "show_value", + ) + code = "length" + + def compare(self, a, b): + return a != b + + def clean(self, x): + return len(x) diff --git a/docs/source/ref/forms.rst b/docs/source/ref/forms.rst index 64c42755d..934af20e3 100644 --- a/docs/source/ref/forms.rst +++ b/docs/source/ref/forms.rst @@ -33,7 +33,7 @@ Stores an :class:`~bson.objectid.ObjectId`. ``SimpleArrayField`` -------------------- -.. class:: SimpleArrayField(base_field, delimiter=',', max_length=None, min_length=None) +.. class:: SimpleArrayField(base_field, delimiter=',', length=None, max_length=None, min_length=None) A field which maps to an array. It is represented by an HTML ````. @@ -91,6 +91,14 @@ Stores an :class:`~bson.objectid.ObjectId`. in cases where the delimiter is a valid character in the underlying field. The delimiter does not need to be only one character. + .. attribute:: length + + This is an optional argument which validates that the array contains + the stated number of items. + + ``length`` may not be specified along with ``max_length`` or + ``min_length``. + .. attribute:: max_length This is an optional argument which validates that the array does not diff --git a/docs/source/ref/models/fields.rst b/docs/source/ref/models/fields.rst index fa0672dc1..47a3149c6 100644 --- a/docs/source/ref/models/fields.rst +++ b/docs/source/ref/models/fields.rst @@ -8,13 +8,12 @@ Some MongoDB-specific fields are available in ``django_mongodb_backend.fields``. ``ArrayField`` -------------- -.. class:: ArrayField(base_field, max_size=None, **options) +.. class:: ArrayField(base_field, max_size=None, size=None, **options) A field for storing lists of data. Most field types can be used, and you - pass another field instance as the :attr:`base_field - `. You may also specify a :attr:`max_size - `. ``ArrayField`` can be nested to store - multi-dimensional arrays. + pass another field instance as the :attr:`~ArrayField.base_field`. You may + also specify a :attr:`~ArrayField.size` or :attr:`~ArrayField.max_size`. + ``ArrayField`` can be nested to store multi-dimensional arrays. If you give the field a :attr:`~django.db.models.Field.default`, ensure it's a callable such as ``list`` (for an empty default) or a callable that @@ -50,9 +49,9 @@ Some MongoDB-specific fields are available in ``django_mongodb_backend.fields``. board = ArrayField( ArrayField( models.CharField(max_length=10, blank=True), - max_size=8, + size=8, ), - max_size=8, + size=8, ) Transformation of values between the database and the model, validation @@ -66,6 +65,15 @@ Some MongoDB-specific fields are available in ``django_mongodb_backend.fields``. If passed, the array will have a maximum size as specified, validated by forms and model validation, but not enforced by the database. + The ``max_size`` and ``size`` options are mutually exclusive. + + .. attribute:: size + + This is an optional argument. + + If passed, the array will have size as specified, validated by forms + and model validation, but not enforced by the database. + Querying ``ArrayField`` ~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/releases/5.1.x.rst b/docs/source/releases/5.1.x.rst index 9690943e2..2e36e6d8d 100644 --- a/docs/source/releases/5.1.x.rst +++ b/docs/source/releases/5.1.x.rst @@ -8,7 +8,9 @@ Django MongoDB Backend 5.1.x *Unreleased* - Backward-incompatible: :class:`~django_mongodb_backend.fields.ArrayField`\'s - ``size`` argument is renamed to ``max_size``. + :attr:`~.ArrayField.size` parameter is renamed to + :attr:`~.ArrayField.max_size`. The :attr:`~.ArrayField.size` parameter is now + used to enforce fixed-length arrays. - Added support for :doc:`database caching `. - Fixed ``QuerySet.raw_aggregate()`` field initialization when the document key order doesn't match the order of the model's fields. diff --git a/tests/forms_tests_/test_array.py b/tests/forms_tests_/test_array.py index e0107943a..1a496019e 100644 --- a/tests/forms_tests_/test_array.py +++ b/tests/forms_tests_/test_array.py @@ -115,12 +115,35 @@ def test_min_length_singular(self): with self.assertRaisesMessage(exceptions.ValidationError, msg): field.clean([1]) + def test_size_length(self): + field = SimpleArrayField(forms.CharField(max_length=27), length=4) + msg = "List contains 3 items, it should contain 4." + with self.assertRaisesMessage(exceptions.ValidationError, msg): + field.clean(["a", "b", "c"]) + msg = "List contains 5 items, it should contain 4." + with self.assertRaisesMessage(exceptions.ValidationError, msg): + field.clean(["a", "b", "c", "d", "e"]) + + def test_size_length_singular(self): + field = SimpleArrayField(forms.CharField(max_length=27), length=4) + msg = "List contains 1 item, it should contain 4." + with self.assertRaisesMessage(exceptions.ValidationError, msg): + field.clean(["a"]) + def test_required(self): field = SimpleArrayField(forms.CharField(), required=True) with self.assertRaises(exceptions.ValidationError) as cm: field.clean("") self.assertEqual(cm.exception.messages[0], "This field is required.") + def test_length_and_max_min_length(self): + msg = "The length and max_length parameters are mutually exclusive." + with self.assertRaisesMessage(exceptions.ImproperlyConfigured, msg): + SimpleArrayField(forms.CharField(), max_length=3, length=2) + msg = "The length and min_length parameters are mutually exclusive." + with self.assertRaisesMessage(exceptions.ImproperlyConfigured, msg): + SimpleArrayField(forms.CharField(), min_length=3, length=2) + def test_model_field_formfield(self): model_field = ArrayField(models.CharField(max_length=27)) form_field = model_field.formfield() @@ -134,6 +157,12 @@ def test_model_field_formfield_max_size(self): self.assertIsInstance(form_field, SimpleArrayField) self.assertEqual(form_field.max_length, 4) + def test_model_field_formfield_size(self): + model_field = ArrayField(models.CharField(max_length=27), size=4) + form_field = model_field.formfield() + self.assertIsInstance(form_field, SimpleArrayField) + self.assertEqual(form_field.length, 4) + def test_model_field_choices(self): model_field = ArrayField(models.IntegerField(choices=((1, "A"), (2, "B")))) form_field = model_field.formfield() diff --git a/tests/model_fields_/test_arrayfield.py b/tests/model_fields_/test_arrayfield.py index e1537b970..08d1d8eee 100644 --- a/tests/model_fields_/test_arrayfield.py +++ b/tests/model_fields_/test_arrayfield.py @@ -646,6 +646,15 @@ class MyModel(models.Model): self.assertEqual(len(errors), 1) self.assertEqual(errors[0].id, "django_mongodb_backend.array.E002") + def test_both_size_and_max_size(self): + class MyModel(models.Model): + field = ArrayField(models.CharField(max_length=3), size=3, max_size=4) + + model = MyModel() + errors = model.check() + self.assertEqual(len(errors), 1) + self.assertEqual(errors[0].id, "django_mongodb_backend.array.E003") + def test_invalid_default(self): class MyModel(models.Model): field = ArrayField(models.IntegerField(), default=[]) @@ -818,6 +827,20 @@ def test_with_max_size_singular(self): with self.assertRaisesMessage(exceptions.ValidationError, msg): field.clean([1, 2], None) + def test_with_size(self): + field = ArrayField(models.IntegerField(), size=3) + field.clean([1, 2, 3], None) + msg = "List contains 4 items, it should contain 3." + with self.assertRaisesMessage(exceptions.ValidationError, msg): + field.clean([1, 2, 3, 4], None) + + def test_with_size_singular(self): + field = ArrayField(models.IntegerField(), size=2) + field.clean([1, 2], None) + msg = "List contains 1 item, it should contain 2." + with self.assertRaisesMessage(exceptions.ValidationError, msg): + field.clean([1], None) + def test_nested_array_mismatch(self): field = ArrayField(ArrayField(models.IntegerField())) field.clean([[1, 2], [3, 4]], None) diff --git a/tests/validators_/__init__.py b/tests/validators_/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/validators_/tests.py b/tests/validators_/tests.py new file mode 100644 index 000000000..09a549490 --- /dev/null +++ b/tests/validators_/tests.py @@ -0,0 +1,31 @@ +from django.core.exceptions import ValidationError +from django.test import SimpleTestCase + +from django_mongodb_backend.validators import LengthValidator + + +class TestLengthValidator(SimpleTestCase): + validator = LengthValidator(10) + + def test_empty(self): + msg = "List contains 0 items, it should contain 10." + with self.assertRaisesMessage(ValidationError, msg): + self.validator([]) + + def test_singular(self): + msg = "List contains 1 item, it should contain 10." + with self.assertRaisesMessage(ValidationError, msg): + self.validator([1]) + + def test_too_short(self): + msg = "List contains 9 items, it should contain 10." + with self.assertRaisesMessage(ValidationError, msg): + self.validator([1, 2, 3, 4, 5, 6, 7, 8, 9]) + + def test_too_long(self): + msg = "List contains 11 items, it should contain 10." + with self.assertRaisesMessage(ValidationError, msg): + self.validator(list(range(11))) + + def test_valid(self): + self.assertEqual(self.validator(list(range(10))), None)