diff --git a/tests/contrib/test_pydantic.py b/tests/contrib/test_pydantic.py index 56ebeaf37..f75ce2a44 100644 --- a/tests/contrib/test_pydantic.py +++ b/tests/contrib/test_pydantic.py @@ -5,6 +5,7 @@ from tests.testmodels import ( Address, + Author, CamelCaseAliasPerson, Employee, EnumFields, @@ -13,6 +14,7 @@ JSONFields, ModelTestPydanticMetaBackwardRelations1, ModelTestPydanticMetaBackwardRelations2, + Node, Reporter, Team, Tournament, @@ -48,6 +50,10 @@ class PydanticMetaOverride: self.Event_Pydantic_non_backward_from_override = pydantic_model_creator( Event, meta_override=PydanticMetaOverride, name="Event_non_backward" ) + self.Author_Pydantic = pydantic_model_creator(Author, meta_override=PydanticMetaOverride) + self.Node_Pydantic = pydantic_model_creator( + Node, meta_override=PydanticMetaOverride, exclude=("o2opkmodelwithm2ms",) + ) self.tournament = await Tournament.create(name="New Tournament") self.reporter = await Reporter.create(name="The Reporter") @@ -62,6 +68,24 @@ class PydanticMetaOverride: await self.event2.participants.add(self.team1, self.team2) self.maxDiff = None + async def test_with_default_but_not_null(self): + author_data = self.Author_Pydantic.model_validate({"id": 1}).model_dump() + assert author_data == {"id": 1, "name": ""} + node_data = self.Node_Pydantic.model_validate({"id": 1}).model_dump() + assert node_data["id"] == 1 and node_data["name"] and isinstance(node_data["name"], str) + info = { + "input": None, + "type": "string_type", + "loc": ("name",), + "msg": "Input should be a valid string", + } + with self.assertRaises(ValidationError) as cm: + self.Author_Pydantic.model_validate({"id": 1, "name": None}) + self.assertEqual([info], cm.exception.errors(include_url=False)) + with self.assertRaises(ValidationError) as cm: + self.Node_Pydantic.model_validate({"id": 1, "name": None}) + self.assertEqual([info], cm.exception.errors(include_url=False)) + async def test_backward_relations_with_meta_override(self): event_schema = copy.deepcopy(dict(self.Event_Pydantic.model_json_schema())) event_non_backward_schema_by_override = copy.deepcopy( diff --git a/tests/testmodels.py b/tests/testmodels.py index c0b95f352..d02a66e86 100644 --- a/tests/testmodels.py +++ b/tests/testmodels.py @@ -14,6 +14,7 @@ from typing import Union import pytz +from anyio.lowlevel import checkpoint from pydantic import BaseModel, ConfigDict from tortoise import fields @@ -33,10 +34,15 @@ ) -def generate_token(): +def generate_token() -> str: return binascii.hexlify(os.urandom(16)).decode("ascii") +async def generate_unique_string() -> str: + await checkpoint() + return uuid.uuid4().hex[:10] + + class TestSchemaForJSONField(BaseModel): foo: int bar: str @@ -47,7 +53,7 @@ class TestSchemaForJSONField(BaseModel): class Author(Model): - name = fields.CharField(max_length=255) + name = fields.CharField(max_length=255, default="", null=False) class Book(Model): @@ -151,7 +157,7 @@ class ModelTestPydanticMetaBackwardRelations3(Model): class Node(Model): - name = fields.CharField(max_length=10) + name = fields.CharField(max_length=10, default=generate_unique_string, null=False) class Tree(Model): @@ -333,7 +339,7 @@ class FloatFields(Model): floatnum_null = fields.FloatField(null=True) -def raise_if_not_dict_or_list(value: dict | list): +def raise_if_not_dict_or_list(value: dict | list) -> None: if not isinstance(value, (dict, list)): raise ValidationError("Value must be a dict or list.") @@ -570,7 +576,7 @@ class Employee(Model): def __str__(self): return self.name - async def full_hierarchy__async_for(self, level=0): + async def full_hierarchy__async_for(self, level=0) -> str: """ Demonstrates ``async for` to fetch relations @@ -588,7 +594,7 @@ async def full_hierarchy__async_for(self, level=0): text.append(await member.full_hierarchy__async_for(level + 1)) return "\n".join(text) - async def full_hierarchy__fetch_related(self, level=0): + async def full_hierarchy__fetch_related(self, level=0) -> str: """ Demonstrates ``await .fetch_related`` to fetch relations @@ -883,16 +889,18 @@ class NumberSourceField(Model): class StatusQuerySet(QuerySet): - def active(self): + def active(self) -> QuerySet: return self.filter(status=1) class StatusManager(Manager): - def __init__(self, model=None, queryset_cls=None) -> None: + def __init__( + self, model: type[Model] | None = None, queryset_cls: type[QuerySet] | None = None + ) -> None: super().__init__(model=model) self.queryset_cls = queryset_cls or QuerySet - def get_queryset(self): + def get_queryset(self) -> QuerySet: return self.queryset_cls(self._model) @@ -961,7 +969,7 @@ class OldStyleModel(Model): external_id = fields.IntField(index=True) -def camelize_var(var_name: str): +def camelize_var(var_name: str) -> str: var_parts: list[str] = var_name.split("_") return var_parts[0] + "".join([part.title() for part in var_parts[1:]]) diff --git a/tortoise/contrib/pydantic/creator.py b/tortoise/contrib/pydantic/creator.py index d67f49d38..9d6859e07 100644 --- a/tortoise/contrib/pydantic/creator.py +++ b/tortoise/contrib/pydantic/creator.py @@ -1,13 +1,15 @@ from __future__ import annotations +import functools import inspect from base64 import b32encode -from collections.abc import MutableMapping +from collections.abc import Awaitable, MutableMapping from copy import copy from enum import Enum, IntEnum from hashlib import sha3_224 -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union +from anyio import from_thread from pydantic import ConfigDict, computed_field, create_model from pydantic import Field as PydanticField from pydantic.fields import ComputedFieldInfo @@ -100,6 +102,27 @@ def _pydantic_recursion_protector( return pmc.create_pydantic_model() +T_Retval = TypeVar("T_Retval") + + +async def async_to_sync(func: Callable[[], Awaitable[T_Retval]]) -> Callable[[], T_Retval]: + """Wrap the async function to be sync(will be run in worker thread)""" + + @functools.wraps(func) + def wrapped() -> T_Retval: + result: list[T_Retval] = [] + + async def runner() -> None: + res = await func() + result.append(res) + + with from_thread.start_blocking_portal() as portal: + portal.call(runner) + return result[0] + + return wrapped + + class FieldMap(MutableMapping[str, Union[Field, ComputedFieldDescription]]): def __init__(self, meta: PydanticMetaData, pk_field: Field | None = None): self._field_map: dict[str, Field | ComputedFieldDescription] = {} @@ -432,20 +455,24 @@ def _process_field( description = _br_it(field.docstring or field.description or "") if description: fconfig["description"] = description - if field_name in self._optional or ( - field.default is not None and not callable(field.default) - ): - self._properties[field_name] = ( - field_property, - PydanticField(default=field.default, **fconfig), - ) + field_default = field.default + if field_name in self._optional: + fconfig["default"] = field_default + elif field_default is not None: + if callable(field_default): + if inspect.iscoroutinefunction(field_default): + fconfig["default_factory"] = async_to_sync(field_default) + else: + fconfig["default_factory"] = field_default + else: + fconfig["default"] = field_default else: if (json_schema_extra.get("nullable") and not is_to_one_relation) or ( self._exclude_read_only and json_schema_extra.get("readOnly") ): # see: https://docs.pydantic.dev/latest/migration/#required-optional-and-nullable-fields fconfig["default"] = None - self._properties[field_name] = (field_property, PydanticField(**fconfig)) + self._properties[field_name] = (field_property, PydanticField(**fconfig)) elif isinstance(field, ComputedFieldDescription): field_property, is_to_one_relation = self._process_computed_field(field), False if field_property: @@ -525,7 +552,8 @@ def _process_data_field( if field.null: json_schema_extra["nullable"] = True if not field.pk and ( - field_name in self._optional or field.default is not None or field.null + # field_name in self._optional or field.default is not None or field.null + field_name in self._optional or field.null ): ptype = Optional[ptype] if not (self._exclude_read_only and json_schema_extra.get("readOnly") is True):