Skip to content
21 changes: 19 additions & 2 deletions django_mongodb_backend/fields/array.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import json

from django.contrib.postgres.validators import ArrayMaxLengthValidator
from django.core import checks, exceptions
from django.db.models import DecimalField, Field, Func, IntegerField, Transform, Value
from django.db.models.fields.mixins import CheckFieldDefaultMixin
Expand All @@ -10,6 +9,7 @@
from ..forms import SimpleArrayField
from ..query_utils import process_lhs, process_rhs
from ..utils import prefix_validation_error
from ..validators import ArrayMaxLengthValidator, LengthValidator

__all__ = ["ArrayField"]

Expand All @@ -27,14 +27,20 @@ class ArrayField(CheckFieldDefaultMixin, Field):
}
_default_hint = ("list", "[]")

def __init__(self, base_field, max_size=None, **kwargs):
def __init__(self, base_field, max_size=None, size=None, **kwargs):
self.base_field = base_field
self.max_size = max_size
self.size = size
if self.max_size:
self.default_validators = [
*self.default_validators,
ArrayMaxLengthValidator(self.max_size),
]
if self.size:
self.default_validators = [
*self.default_validators,
LengthValidator(self.size),
]
# For performance, only add a from_db_value() method if the base field
# implements it.
if hasattr(self.base_field, "from_db_value"):
Expand Down Expand Up @@ -98,6 +104,14 @@ def check(self, **kwargs):
id="django_mongodb_backend.array.W004",
)
)
if self.size and self.max_size:
errors.append(
checks.Error(
"ArrayField cannot specify both size and max_size.",
obj=self,
id="django_mongodb_backend.array.E003",
)
)
return errors

def set_attributes_from_name(self, name):
Expand Down Expand Up @@ -127,6 +141,8 @@ def deconstruct(self):
kwargs["base_field"] = self.base_field.clone()
if self.max_size is not None:
kwargs["max_size"] = self.max_size
if self.size is not None:
kwargs["size"] = self.size
return name, path, args, kwargs

def to_python(self, value):
Expand Down Expand Up @@ -211,6 +227,7 @@ def formfield(self, **kwargs):
"form_class": SimpleArrayField,
"base_field": self.base_field.formfield(),
"max_length": self.max_size,
"size": self.size,
**kwargs,
}
)
Expand Down
16 changes: 13 additions & 3 deletions django_mongodb_backend/forms/fields/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,38 @@
from itertools import chain

from django import forms
from django.core.exceptions import ValidationError
from django.core.exceptions import ImproperlyConfigured, ValidationError
from django.utils.translation import gettext_lazy as _

from ...utils import prefix_validation_error
from ...validators import ArrayMaxLengthValidator, ArrayMinLengthValidator
from ...validators import ArrayMaxLengthValidator, ArrayMinLengthValidator, LengthValidator


class SimpleArrayField(forms.CharField):
default_error_messages = {
"item_invalid": _("Item %(nth)s in the array did not validate:"),
}

def __init__(self, base_field, *, delimiter=",", max_length=None, min_length=None, **kwargs):
def __init__(
self, base_field, *, delimiter=",", max_length=None, min_length=None, size=None, **kwargs
):
self.base_field = base_field
self.delimiter = delimiter
super().__init__(**kwargs)
if (min_length is not None or max_length is not None) and size is not None:
raise ImproperlyConfigured(
"SimpleArrayField param 'size' cannot be "
"specified with 'max_length' or 'min_length'."
)
if min_length is not None:
self.min_length = min_length
self.validators.append(ArrayMinLengthValidator(int(min_length)))
if max_length is not None:
self.max_length = max_length
self.validators.append(ArrayMaxLengthValidator(int(max_length)))
if size is not None:
self.size = size
self.validators.append(LengthValidator(int(size)))

def clean(self, value):
value = super().clean(value)
Expand Down
19 changes: 18 additions & 1 deletion django_mongodb_backend/validators.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from django.core.validators import MaxLengthValidator, MinLengthValidator
from django.core.validators import BaseValidator, MaxLengthValidator, MinLengthValidator
from django.utils.deconstruct import deconstructible
from django.utils.translation import ngettext_lazy


Expand All @@ -16,3 +17,19 @@ class ArrayMinLengthValidator(MinLengthValidator):
"List contains %(show_value)d items, it should contain no fewer than %(limit_value)d.",
"show_value",
)


@deconstructible
class LengthValidator(BaseValidator):
message = ngettext_lazy(
"List contains %(show_value)d item, it should contain %(limit_value)d.",
"List contains %(show_value)d items, it should contain %(limit_value)d.",
"show_value",
)
code = "length"

def compare(self, a, b):
return a != b

def clean(self, x):
return len(x)
9 changes: 9 additions & 0 deletions docs/source/ref/forms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,15 @@ Stores an :class:`~bson.objectid.ObjectId`.
This is an optional argument which validates that the array reaches at
least the stated length.

.. attribute:: size

This is an optional argument which validates that the array reaches at
exactly the stated length.

.. note::
Defining ``size`` along with ``max_length`` or ``min_length`` will raise an exception.
Use ``size`` for fixed-length arrays and ``max_length`` / ``min_length`` for variable-length arrays with an upper or lower limit.

.. admonition:: User friendly forms

``SimpleArrayField`` is not particularly user friendly in most cases,
Expand Down
24 changes: 20 additions & 4 deletions docs/source/ref/models/fields.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@ Some MongoDB-specific fields are available in ``django_mongodb_backend.fields``.
``ArrayField``
--------------

.. class:: ArrayField(base_field, max_size=None, **options)
.. class:: ArrayField(base_field, max_size=None, size=None, **options)

A field for storing lists of data. Most field types can be used, and you
pass another field instance as the :attr:`base_field
<ArrayField.base_field>`. You may also specify a :attr:`max_size
<ArrayField.max_size>`. ``ArrayField`` can be nested to store
<ArrayField.max_size>` or :attr:`size
<ArrayField.size>`. ``ArrayField`` can be nested to store
multi-dimensional arrays.

If you give the field a :attr:`~django.db.models.Field.default`, ensure
Expand Down Expand Up @@ -50,9 +51,13 @@ Some MongoDB-specific fields are available in ``django_mongodb_backend.fields``.
board = ArrayField(
ArrayField(
models.CharField(max_length=10, blank=True),
max_size=8,
size=8,
),
max_size=8,
size=8,
)
active_pieces = ArrayField(
models.CharField(max_length=10, blank=True),
max_size=32
)

Transformation of values between the database and the model, validation
Expand All @@ -66,6 +71,17 @@ Some MongoDB-specific fields are available in ``django_mongodb_backend.fields``.
If passed, the array will have a maximum size as specified, validated
by forms and model validation, but not enforced by the database.

.. attribute:: size

This is an optional argument.

If passed, the array will have size as specified, validated by forms and model validation, but not enforced by the database.

.. note::

Defining both ``size`` and ``max_size`` will raise an exception.
Use ``size`` for fixed-length arrays and ``max_size`` for variable-length arrays with an upper limit.

Querying ``ArrayField``
~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
2 changes: 2 additions & 0 deletions docs/source/releases/5.1.x.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ Django MongoDB Backend 5.1.x

- Backward-incompatible: :class:`~django_mongodb_backend.fields.ArrayField`\'s
``size`` argument is renamed to ``max_size``.
- Added the ``size`` parameter to :class:`~django_mongodb_backend.fields.ArrayField`
for enforcing fixed-length arrays.
- Added support for :doc:`database caching </topics/cache>`.
- Fixed ``QuerySet.raw_aggregate()`` field initialization when the document key
order doesn't match the order of the model's fields.
Expand Down
39 changes: 39 additions & 0 deletions tests/forms_tests_/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,45 @@ def test_min_length_singular(self):
with self.assertRaisesMessage(exceptions.ValidationError, msg):
field.clean([1])

def test_size_length(self):
field = SimpleArrayField(forms.CharField(max_length=27), size=4)
with self.assertRaises(exceptions.ValidationError) as cm:
field.clean(["a", "b", "c"])
self.assertEqual(
cm.exception.messages[0],
"List contains 3 items, it should contain 4.",
)
with self.assertRaises(exceptions.ValidationError) as cm:
field.clean(["a", "b", "c", "d", "e"])
self.assertEqual(
cm.exception.messages[0],
"List contains 5 items, it should contain 4.",
)

def test_size_length_singular(self):
field = SimpleArrayField(forms.CharField(max_length=27), size=4)
with self.assertRaises(exceptions.ValidationError) as cm:
field.clean(["a"])
self.assertEqual(
cm.exception.messages[0],
"List contains 1 item, it should contain 4.",
)

def test_required(self):
field = SimpleArrayField(forms.CharField(), required=True)
with self.assertRaises(exceptions.ValidationError) as cm:
field.clean("")
self.assertEqual(cm.exception.messages[0], "This field is required.")

def test_misconfigured(self):
msg = "SimpleArrayField param 'size' cannot be specified with 'max_length' or 'min_length'."
with self.assertRaises(exceptions.ImproperlyConfigured) as cm:
SimpleArrayField(forms.CharField(), max_length=3, size=2)
self.assertEqual(cm.exception.args[0], msg)
with self.assertRaises(exceptions.ImproperlyConfigured) as cm:
SimpleArrayField(forms.CharField(), min_length=3, size=2)
self.assertEqual(cm.exception.args[0], msg)

def test_model_field_formfield(self):
model_field = ArrayField(models.CharField(max_length=27))
form_field = model_field.formfield()
Expand All @@ -134,6 +167,12 @@ def test_model_field_formfield_max_size(self):
self.assertIsInstance(form_field, SimpleArrayField)
self.assertEqual(form_field.max_length, 4)

def test_model_field_formfield_size(self):
model_field = ArrayField(models.CharField(max_length=27), size=4)
form_field = model_field.formfield()
self.assertIsInstance(form_field, SimpleArrayField)
self.assertEqual(form_field.size, 4)

def test_model_field_choices(self):
model_field = ArrayField(models.IntegerField(choices=((1, "A"), (2, "B"))))
form_field = model_field.formfield()
Expand Down
26 changes: 26 additions & 0 deletions tests/model_fields_/test_arrayfield.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,15 @@ class MyModel(models.Model):
self.assertEqual(len(errors), 1)
self.assertEqual(errors[0].id, "django_mongodb_backend.array.E002")

def test_both_size_and_max_size(self):
class MyModel(models.Model):
field = ArrayField(models.CharField(max_length=3), size=3, max_size=4)

model = MyModel()
errors = model.check()
self.assertEqual(len(errors), 1)
self.assertEqual(errors[0].id, "django_mongodb_backend.array.E003")

def test_invalid_default(self):
class MyModel(models.Model):
field = ArrayField(models.IntegerField(), default=[])
Expand Down Expand Up @@ -818,6 +827,23 @@ def test_with_max_size_singular(self):
with self.assertRaisesMessage(exceptions.ValidationError, msg):
field.clean([1, 2], None)

def test_with_size(self):
field = ArrayField(models.IntegerField(), size=3)
field.clean([1, 2, 3], None)
with self.assertRaises(exceptions.ValidationError) as cm:
field.clean([1, 2, 3, 4], None)
self.assertEqual(
cm.exception.messages[0],
"List contains 4 items, it should contain 3.",
)

def test_with_size_singular(self):
field = ArrayField(models.IntegerField(), size=2)
field.clean([1, 2], None)
msg = "List contains 1 item, it should contain 2."
with self.assertRaisesMessage(exceptions.ValidationError, msg):
field.clean([1], None)

def test_nested_array_mismatch(self):
field = ArrayField(ArrayField(models.IntegerField()))
field.clean([[1, 2], [3, 4]], None)
Expand Down
Empty file added tests/validators_/__init__.py
Empty file.
14 changes: 14 additions & 0 deletions tests/validators_/tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from django.core.exceptions import ValidationError
from django.test import SimpleTestCase

from django_mongodb_backend.fields.validators import LengthValidator


class TestValidators(SimpleTestCase):
def test_validators(self):
validator = LengthValidator(10)
with self.assertRaises(ValidationError):
validator([])
with self.assertRaises(ValidationError):
validator(list(range(11)))
self.assertEqual(validator(list(range(10))), None)
Loading