diff --git a/pydantic_duality/__init__.py b/pydantic_duality/__init__.py index 2d808c0..32fc35f 100644 --- a/pydantic_duality/__init__.py +++ b/pydantic_duality/__init__.py @@ -36,7 +36,9 @@ def _resolve_annotation(annotation, attr: str) -> Any: tuple(_resolve_annotation(a, attr) for a in get_args(annotation)), ) elif isinstance(annotation, UnionType): - return Union.__getitem__(tuple(_resolve_annotation(a, attr) for a in get_args(annotation))) + return Union.__getitem__( + tuple(_resolve_annotation(a, attr) for a in get_args(annotation)) + ) elif get_origin(annotation) is Annotated: return Annotated.__class_getitem__( tuple(_resolve_annotation(a, attr) for a in get_args(annotation)), @@ -58,7 +60,9 @@ def _alter_attrs(attrs: dict[str, object], name: str, attr: str): if attr == PATCH_REQUEST_ATTR: if get_origin(annotations[key]) is Annotated: args = get_args(annotations[key]) - annotations[key] = Annotated.__class_getitem__(tuple([args[0] | None, *args[1:]])) + annotations[key] = Annotated.__class_getitem__( + tuple([args[0] | None, *args[1:]]) + ) elif isinstance(annotations[key], str): annotations[key] += " | None" else: @@ -67,12 +71,18 @@ def _alter_attrs(attrs: dict[str, object], name: str, attr: str): return attrs -def _lazily_initalize_models(request_cls: type, own_attr_name: str, constructor: Callable[[], Any]): +def _lazily_initalize_models( + request_cls: type, own_attr_name: str, constructor: Callable[[], Any] +): def constructor_wrapper(*a, **kw) -> object: obj = constructor() obj.__request__ = request_cls - obj.__response__ = cached_classproperty(lambda cls: request_cls.__response__, RESPONSE_ATTR) - obj.__patch_request__ = cached_classproperty(lambda cls: request_cls.__patch_request__, PATCH_REQUEST_ATTR) + obj.__response__ = cached_classproperty( + lambda cls: request_cls.__response__, RESPONSE_ATTR + ) + obj.__patch_request__ = cached_classproperty( + lambda cls: request_cls.__patch_request__, PATCH_REQUEST_ATTR + ) return obj return cached_classproperty(constructor_wrapper, own_attr_name) @@ -96,7 +106,9 @@ def __new__( **kwargs, ) -> Self: new_class = type.__new__(cls, name, bases, attrs) - if not bases or not any(isinstance(b, (ModelMetaclass, DualBaseModelMeta)) for b in bases): + if not bases or not any( + isinstance(b, (ModelMetaclass, DualBaseModelMeta)) for b in bases + ): raise TypeError( f"ModelDuplicatorMeta's instances must be created with a DualBaseModel base class or a BaseModel base class." ) @@ -108,11 +120,17 @@ def __new__( ) elif not inspect.isclass(kwargs["__config__"]): raise TypeError("The __config__ argument must be a class.") - elif request_suffix is None or response_suffix is None or patch_request_suffix is None: + elif ( + request_suffix is None + or response_suffix is None + or patch_request_suffix is None + ): raise TypeError( "The first instance of DualBaseModel must pass suffixes for the request, response, and patch request models." ) - new_class._generate_base_alternative_classes(request_suffix, response_suffix, kwargs) + new_class._generate_base_alternative_classes( + request_suffix, response_suffix, kwargs + ) else: request_suffix, response_suffix, patch_request_suffix = ( request_suffix or new_class.request_suffix, @@ -120,7 +138,13 @@ def __new__( patch_request_suffix or new_class.patch_request_suffix, ) new_class._generate_alternative_classes( - name, bases, attrs, request_suffix, response_suffix, patch_request_suffix, kwargs + name, + bases, + attrs, + request_suffix, + response_suffix, + patch_request_suffix, + kwargs, ) new_class.__request__.request_suffix = request_suffix # type: ignore @@ -129,16 +153,22 @@ def __new__( return new_class - def _generate_base_alternative_classes(self, request_suffix, response_suffix, kwargs): + def _generate_base_alternative_classes( + self, request_suffix, response_suffix, kwargs + ): class Config(kwargs["__config__"]): # type: ignore extra = Extra.forbid - BaseRequest = ModelMetaclass(f"Base{request_suffix}", (BaseModel,), {"Config": Config}) + BaseRequest = ModelMetaclass( + f"Base{request_suffix}", (BaseModel,), {"Config": Config} + ) class Config(kwargs["__config__"]): extra = Extra.ignore - BaseResponse = ModelMetaclass(f"Base{response_suffix}", (BaseModel,), {"Config": Config}) + BaseResponse = ModelMetaclass( + f"Base{response_suffix}", (BaseModel,), {"Config": Config} + ) type.__setattr__(self, "__request__", BaseRequest) BaseRequest.__request__ = BaseRequest # type: ignore @@ -146,13 +176,22 @@ class Config(kwargs["__config__"]): BaseRequest.__patch_request__ = BaseRequest # type: ignore def _generate_alternative_classes( - self, name, bases, attrs, request_suffix, response_suffix, patch_request_suffix, kwargs + self, + name, + bases, + attrs, + request_suffix, + response_suffix, + patch_request_suffix, + kwargs, ): + anonymized_attrs = attrs.copy() + anonymized_attrs.pop("__classcell__", None) request_bases = tuple(_resolve_annotation(b, REQUEST_ATTR) for b in bases) request_class = ModelMetaclass( name + request_suffix, request_bases, - _alter_attrs(attrs, name + request_suffix, REQUEST_ATTR), + _alter_attrs(anonymized_attrs, name + request_suffix, REQUEST_ATTR), **kwargs, ) request_class.__response__ = _lazily_initalize_models( @@ -161,7 +200,7 @@ def _generate_alternative_classes( lambda: ModelMetaclass( name + response_suffix, tuple(_resolve_annotation(b, RESPONSE_ATTR) for b in bases), - _alter_attrs(attrs, name + response_suffix, RESPONSE_ATTR), + _alter_attrs(anonymized_attrs, name + response_suffix, RESPONSE_ATTR), **kwargs, ), ) @@ -171,7 +210,9 @@ def _generate_alternative_classes( lambda: ModelMetaclass( name + patch_request_suffix, tuple(_resolve_annotation(b, PATCH_REQUEST_ATTR) for b in bases), - _alter_attrs(attrs, name + patch_request_suffix, PATCH_REQUEST_ATTR), + _alter_attrs( + anonymized_attrs, name + patch_request_suffix, PATCH_REQUEST_ATTR + ), **kwargs, ), ) @@ -182,7 +223,12 @@ def _generate_alternative_classes( def __getattribute__(self, attr: str): # Note here that RESPONSE_ATTR and PATCH_REQUEST_ATTR goes into REQUEST_ATTR's __getattribute__ method - if attr in {REQUEST_ATTR, "__new__", "_generate_base_alternative_classes", "_generate_alternative_classes"}: + if attr in { + REQUEST_ATTR, + "__new__", + "_generate_base_alternative_classes", + "_generate_alternative_classes", + }: return type.__getattribute__(self, attr) return getattr(type.__getattribute__(self, REQUEST_ATTR), attr) @@ -202,10 +248,14 @@ def __hash__(self) -> int: return hash(self.__request__) def __instancecheck__(cls, instance) -> bool: - return type.__instancecheck__(cls, instance) or isinstance(instance, cls.__request__) + return type.__instancecheck__(cls, instance) or isinstance( + instance, cls.__request__ + ) def __subclasscheck__(cls, subclass: type): - return type.__subclasscheck__(cls, subclass) or issubclass(subclass, cls.__request__) + return type.__subclasscheck__(cls, subclass) or issubclass( + subclass, cls.__request__ + ) def generate_dual_base_model( diff --git a/tests/test_duality.py b/tests/test_duality.py index b07f43f..6eb356c 100644 --- a/tests/test_duality.py +++ b/tests/test_duality.py @@ -151,15 +151,21 @@ class SubSchema(DualBaseModel): def test_ignore_forbid_attrs(schemas): assert ( - schemas["A"].__request__.__response__.__response__.__request__.__response__.__request__.Config.extra + schemas[ + "A" + ].__request__.__response__.__response__.__request__.__response__.__request__.Config.extra == Extra.forbid ) assert ( - schemas["A"].__request__.__response__.__response__.__request__.__response__.__patch_request__.Config.extra + schemas[ + "A" + ].__request__.__response__.__response__.__request__.__response__.__patch_request__.Config.extra == Extra.forbid ) assert ( - schemas["A"].__request__.__response__.__response__.__request__.__response__.__response__.Config.extra + schemas[ + "A" + ].__request__.__response__.__response__.__request__.__response__.__response__.Config.extra == Extra.ignore ) @@ -218,16 +224,33 @@ class ChildSchema2(DualBaseModel): obj: str class Schema(DualBaseModel): - child: Annotated[ChildSchema1 | ChildSchema2, Field(discriminator="object_type")] + child: Annotated[ + ChildSchema1 | ChildSchema2, Field(discriminator="object_type") + ] for object_type in (1, 2): - child_schema = Schema.parse_obj({"child": {"object_type": object_type, "obj": object_type}}) - child_req_schema = Schema.__request__.parse_obj({"child": {"object_type": object_type, "obj": object_type}}) - child_resp_schema = Schema.__response__.parse_obj({"child": {"object_type": object_type, "obj": object_type}}) + child_schema = Schema.parse_obj( + {"child": {"object_type": object_type, "obj": object_type}} + ) + child_req_schema = Schema.__request__.parse_obj( + {"child": {"object_type": object_type, "obj": object_type}} + ) + child_resp_schema = Schema.__response__.parse_obj( + {"child": {"object_type": object_type, "obj": object_type}} + ) - assert type(child_schema.child) is locals()[f"ChildSchema{object_type}"].__request__ - assert type(child_req_schema.child) is locals()[f"ChildSchema{object_type}"].__request__ - assert type(child_resp_schema.child) is locals()[f"ChildSchema{object_type}"].__response__ + assert ( + type(child_schema.child) + is locals()[f"ChildSchema{object_type}"].__request__ + ) + assert ( + type(child_req_schema.child) + is locals()[f"ChildSchema{object_type}"].__request__ + ) + assert ( + type(child_resp_schema.child) + is locals()[f"ChildSchema{object_type}"].__response__ + ) with pytest.raises(ValidationError): Schema.parse_obj( { @@ -269,7 +292,9 @@ class Schema(DualBaseModel): ) -@pytest.mark.parametrize("field_type", [Annotated[int, "Hello"], Annotated[int, "Hello", "Darkness"]]) +@pytest.mark.parametrize( + "field_type", [Annotated[int, "Hello"], Annotated[int, "Hello", "Darkness"]] +) def test_annotated_model_creation_with_regular_metadata(field_type): class Schema(DualBaseModel): field: field_type @@ -344,3 +369,33 @@ class Schema(DualBaseModel, extra=Extra.ignore): assert Schema.__patch_request__.Config.extra == Extra.forbid Schema(field=1, extra=2) + + +def test_model_can_be_created_with_super_init_in_init(): + class MyModel(DualBaseModel): + one: str + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + +def test_model_can_be_created_with_init_subclass(): + class MyModel(DualBaseModel): + one: str + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + + class MyModelChild(MyModel): + pass + + +def test_model_can_be_created_with_classmethod(): + class MyModel(DualBaseModel): + one: str + + @classmethod + def get_stuff(cls): + return super().parse_obj + + MyModel.get_stuff()