Skip to content

Commit 160e768

Browse files
committed
feat: ✨ Support recursive validation (maybe, untested)
1 parent 82c4942 commit 160e768

File tree

4 files changed

+107
-42
lines changed

4 files changed

+107
-42
lines changed

pydantic_async_validation/metaclasses.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
)
1111

1212
if TYPE_CHECKING:
13-
from pydantic_async_validation.validators import Validator
13+
from pydantic_async_validation.validators import ValidationInfo
1414

1515

1616
class AsyncValidationModelMetaclass(ModelMetaclass):
@@ -21,12 +21,12 @@ def __new__(
2121
namespace: dict[str, Any],
2222
**kwargs: Any,
2323
) -> Any:
24-
async_field_validators: List[Tuple[List[str], Validator]] = []
25-
async_model_validators: List[Validator] = []
24+
async_field_validators: List[Tuple[List[str], ValidationInfo]] = []
25+
async_model_validators: List[ValidationInfo] = []
2626

2727
async_field_validator_fields: Optional[List[str]]
28-
async_field_validator_config: "Optional[Validator]"
29-
async_model_validator_config: "Optional[Validator]"
28+
async_field_validator_config: "Optional[ValidationInfo]"
29+
async_model_validator_config: "Optional[ValidationInfo]"
3030

3131
for base in bases:
3232
async_field_validators += getattr(

pydantic_async_validation/mixins.py

Lines changed: 68 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from typing import ClassVar, List, Tuple
1+
from typing import ClassVar, List, Tuple, Union, cast
22

33
import pydantic
4-
from pydantic_core import ErrorDetails, PydanticCustomError, ValidationError
4+
from pydantic_core import InitErrorDetails, PydanticCustomError, ValidationError
55

66
from pydantic_async_validation.constants import (
77
ASYNC_FIELD_VALIDATOR_CONFIG_KEY,
@@ -10,16 +10,17 @@
1010
ASYNC_MODEL_VALIDATORS_KEY,
1111
)
1212
from pydantic_async_validation.metaclasses import AsyncValidationModelMetaclass
13-
from pydantic_async_validation.validators import Validator
13+
from pydantic_async_validation.utils import prefix_errors
14+
from pydantic_async_validation.validators import ValidationInfo
1415

1516

1617
class AsyncValidationModelMixin(
1718
pydantic.BaseModel,
1819
metaclass=AsyncValidationModelMetaclass,
1920
):
2021
# MUST match names defined in constants.py!
21-
pydantic_model_async_field_validators: ClassVar[List[Tuple[List[str], Validator]]]
22-
pydantic_model_async_model_validators: ClassVar[List[Validator]]
22+
pydantic_model_async_field_validators: ClassVar[List[Tuple[List[str], ValidationInfo]]]
23+
pydantic_model_async_model_validators: ClassVar[List[ValidationInfo]]
2324

2425
async def model_async_validate(self) -> None:
2526
"""
@@ -29,60 +30,99 @@ async def model_async_validate(self) -> None:
2930
collected and raised as a `ValidationError` exception.
3031
"""
3132
field_names: list[str]
32-
validator: Validator
33+
field_validator: ValidationInfo
34+
model_validator: ValidationInfo
3335

3436
validation_errors = []
35-
validators = getattr(self, ASYNC_FIELD_VALIDATORS_KEY, [])
36-
root_validators = getattr(self, ASYNC_MODEL_VALIDATORS_KEY, [])
37+
field_validators = getattr(self, ASYNC_FIELD_VALIDATORS_KEY, [])
38+
model_validators = getattr(self, ASYNC_MODEL_VALIDATORS_KEY, [])
3739

38-
for validator_attr in validators:
39-
field_names, validator = getattr(
40-
validator_attr,
40+
# Call all field validators
41+
for field_validator_attr in field_validators:
42+
field_names, field_validator = getattr(
43+
field_validator_attr,
4144
ASYNC_FIELD_VALIDATOR_CONFIG_KEY,
4245
)
4346
for field_name in field_names:
4447
try:
45-
await validator.func(
48+
await field_validator.func(
4649
self,
4750
getattr(self, field_name, None),
4851
field_name,
49-
validator,
52+
field_validator,
5053
)
5154
except (ValueError, TypeError, AssertionError) as o_O:
5255
validation_errors.append(
53-
ErrorDetails(
54-
type=PydanticCustomError('value_error', str(o_O)),
55-
msg=str(o_O),
56+
InitErrorDetails(
57+
type=PydanticCustomError('value_error', str(o_O)), # type: ignore
5658
loc=(field_name,),
5759
input=getattr(self, field_name, None),
5860
),
5961
)
6062

61-
for validator_attr in root_validators:
62-
validator = getattr(
63-
validator_attr,
63+
# Call all model validators
64+
for model_validator_attr in model_validators:
65+
model_validator = getattr(
66+
model_validator_attr,
6467
ASYNC_MODEL_VALIDATOR_CONFIG_KEY,
6568
)
6669
try:
67-
await validator.func(
70+
await model_validator.func(
6871
self,
69-
validator,
72+
model_validator,
7073
)
7174
except (ValueError, TypeError, AssertionError) as o_O:
7275
validation_errors.append(
73-
ErrorDetails(
74-
type=PydanticCustomError('value_error', str(o_O)),
75-
msg=str(o_O),
76+
InitErrorDetails(
77+
type=PydanticCustomError('value_error', str(o_O)), # type: ignore
7678
loc=('__root__',),
7779
input=self.__dict__,
7880
),
7981
)
8082

81-
# TODO:
82-
# for attribute_name, attribute_value in self.__dict__.items():
83-
# if isinstance(attribute_value, AsyncValidationModelMixin):
84-
# await attribute_value.model_async_validate()
83+
# Also call async validation on attribute values
84+
async def extend_with_validation_errors_by(
85+
prefix: Tuple[Union[int, str], ...],
86+
instance: AsyncValidationModelMixin,
87+
) -> None:
88+
try:
89+
await instance.model_async_validate()
90+
except ValidationError as O_o:
91+
validation_errors.extend(
92+
prefix_errors(
93+
prefix,
94+
cast(
95+
List[InitErrorDetails],
96+
O_o.errors(),
97+
),
98+
),
99+
)
100+
101+
for attribute_name, attribute_value in self.__dict__.items():
102+
# Direct child instance
103+
if isinstance(attribute_value, AsyncValidationModelMixin):
104+
await extend_with_validation_errors_by(
105+
(attribute_name,),
106+
attribute_value,
107+
)
108+
# List of child instances
109+
if isinstance(attribute_value, list):
110+
for index, item in enumerate(attribute_value):
111+
if isinstance(item, AsyncValidationModelMixin):
112+
await extend_with_validation_errors_by(
113+
(attribute_name, index),
114+
item,
115+
)
116+
# Dict of child instances
117+
if isinstance(attribute_value, dict):
118+
for key, item in attribute_value.items():
119+
if isinstance(item, AsyncValidationModelMixin):
120+
await extend_with_validation_errors_by(
121+
(attribute_name, key),
122+
item,
123+
)
85124

125+
# If some errors did occur, raise them as a ValidationError
86126
if len(validation_errors) > 0:
87127
raise ValidationError.from_exception_data(
88128
self.__class__.__name__,

pydantic_async_validation/utils.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from functools import wraps
22
from inspect import Signature, signature
3-
from typing import Callable
3+
from typing import Callable, List, Tuple, Union, cast
44

55
from pydantic import PydanticUserError
6+
from pydantic_core import InitErrorDetails
67

78

89
def make_generic_field_validator(validator_func: Callable) -> Callable:
@@ -143,3 +144,27 @@ def generic_model_validator_wrapper(
143144
return lambda self, config: validator_func(
144145
self, config=config,
145146
)
147+
148+
149+
def prefix_errors(
150+
prefix: Tuple[Union[int, str], ...],
151+
errors: List[InitErrorDetails],
152+
) -> List[InitErrorDetails]:
153+
"""
154+
Extend all errors passed as list to include an additional prefix.
155+
156+
This is used to prefix errors occuring in child classes to include the parents
157+
field details in the error locations.
158+
"""
159+
160+
return [
161+
cast(
162+
InitErrorDetails,
163+
{
164+
**error,
165+
'loc': (*prefix, *error['loc']),
166+
},
167+
)
168+
for error
169+
in errors
170+
]

pydantic_async_validation/validators.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from types import FunctionType
2-
from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Union
2+
from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple
33

44
from pydantic.errors import PydanticUserError
55

@@ -27,7 +27,7 @@
2727
ValidatorListDict = "dict[str, list[Validator]]"
2828

2929

30-
class Validator:
30+
class ValidationInfo:
3131
"""Helper / data class to store validator information."""
3232

3333
__slots__ = ('func', 'extra')
@@ -48,7 +48,7 @@ def async_field_validator(
4848
/,
4949
*additional_field_names: str,
5050
**extra: Any,
51-
) -> Callable[[Callable], classmethod]:
51+
) -> Callable[[Callable], Callable]:
5252
"""
5353
Decorate methods on a model indicating that they should be used to validate data.
5454
@@ -66,13 +66,13 @@ def async_field_validator(
6666

6767
field_names: Tuple[str, ...] = __field_name, *additional_field_names
6868

69-
def dec(func: Callable) -> classmethod:
69+
def dec(func: Callable) -> Callable:
7070
setattr(
7171
func,
7272
ASYNC_FIELD_VALIDATOR_CONFIG_KEY,
7373
(
7474
field_names,
75-
Validator(
75+
ValidationInfo(
7676
func=make_generic_field_validator(func),
7777
extra=extra,
7878
),
@@ -85,19 +85,19 @@ def dec(func: Callable) -> classmethod:
8585

8686
def async_model_validator(
8787
**extra: Any,
88-
) -> Union[classmethod, Callable[[Callable], classmethod]]:
88+
) -> Callable[[Callable], Callable]:
8989
"""
9090
Decorate methods on a model indicating that they should be used to validate data.
9191
9292
This decorator allows you to assign your validation
9393
function to the whole model (root validator).
9494
"""
9595

96-
def dec(func: Callable) -> classmethod:
96+
def dec(func: Callable) -> Callable:
9797
setattr(
9898
func,
9999
ASYNC_MODEL_VALIDATOR_CONFIG_KEY,
100-
Validator(
100+
ValidationInfo(
101101
func=make_generic_model_validator(func),
102102
extra=extra,
103103
),

0 commit comments

Comments
 (0)