Skip to content

Commit a9e5106

Browse files
committed
Add size parameter
1 parent a345548 commit a9e5106

File tree

5 files changed

+63
-1
lines changed

5 files changed

+63
-1
lines changed

django_mongodb_backend/fields/array.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from ..forms import SimpleArrayField
1111
from ..query_utils import process_lhs, process_rhs
1212
from ..utils import prefix_validation_error
13+
from .validators import LengthValidator
1314

1415
__all__ = ["ArrayField"]
1516

@@ -27,14 +28,23 @@ class ArrayField(CheckFieldDefaultMixin, Field):
2728
}
2829
_default_hint = ("list", "[]")
2930

30-
def __init__(self, base_field, max_size=None, **kwargs):
31+
def __init__(self, base_field, max_size=None, size=None, **kwargs):
3132
self.base_field = base_field
3233
self.max_size = max_size
34+
self.size = size
35+
if size and max_size:
36+
raise ValueError("Cannot define both, size and max_size")
3337
if self.max_size:
3438
self.default_validators = [
3539
*self.default_validators,
3640
ArrayMaxLengthValidator(self.max_size),
3741
]
42+
if self.size:
43+
self.default_validators = [
44+
*self.default_validators,
45+
ArrayMaxLengthValidator(self.size),
46+
LengthValidator(self.size),
47+
]
3848
# For performance, only add a from_db_value() method if the base field
3949
# implements it.
4050
if hasattr(self.base_field, "from_db_value"):
@@ -127,6 +137,8 @@ def deconstruct(self):
127137
kwargs["base_field"] = self.base_field.clone()
128138
if self.max_size is not None:
129139
kwargs["max_size"] = self.max_size
140+
if self.size is not None:
141+
kwargs["size"] = self.size
130142
return name, path, args, kwargs
131143

132144
def to_python(self, value):
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from django.core.validators import BaseValidator
2+
from django.utils.deconstruct import deconstructible
3+
from django.utils.translation import ngettext_lazy
4+
5+
6+
@deconstructible
7+
class LengthValidator(BaseValidator):
8+
message = ngettext_lazy(
9+
"List contains %(show_value)d item, it should contain %(limit_value)d.",
10+
"List contains %(show_value)d items, it should contain %(limit_value)d.",
11+
"show_value",
12+
)
13+
code = "length"
14+
15+
def compare(self, a, b):
16+
return a != b
17+
18+
def clean(self, x):
19+
return len(x)

tests/model_fields_/test_arrayfield.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -818,6 +818,23 @@ def test_with_max_size_singular(self):
818818
with self.assertRaisesMessage(exceptions.ValidationError, msg):
819819
field.clean([1, 2], None)
820820

821+
def test_with_size(self):
822+
field = ArrayField(models.IntegerField(), size=3)
823+
field.clean([1, 2, 3], None)
824+
with self.assertRaises(exceptions.ValidationError) as cm:
825+
field.clean([1, 2, 3, 4], None)
826+
self.assertEqual(
827+
cm.exception.messages[0],
828+
"List contains 4 items, it should contain 3.",
829+
)
830+
831+
def test_with_size_singular(self):
832+
field = ArrayField(models.IntegerField(), size=2)
833+
field.clean([1, 2], None)
834+
msg = "List contains 1 item, it should contain 2."
835+
with self.assertRaisesMessage(exceptions.ValidationError, msg):
836+
field.clean([1], None)
837+
821838
def test_nested_array_mismatch(self):
822839
field = ArrayField(ArrayField(models.IntegerField()))
823840
field.clean([[1, 2], [3, 4]], None)

tests/validators_/__init__.py

Whitespace-only changes.

tests/validators_/tests.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from django.core.exceptions import ValidationError
2+
from django.test import SimpleTestCase
3+
4+
from django_mongodb_backend.fields.validators import LengthValidator
5+
6+
7+
class TestValidators(SimpleTestCase):
8+
def test_validators(self):
9+
validator = LengthValidator(10)
10+
with self.assertRaises(ValidationError):
11+
validator([])
12+
with self.assertRaises(ValidationError):
13+
validator(list(range(11)))
14+
self.assertEqual(validator(list(range(10))), None)

0 commit comments

Comments
 (0)