diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 78e832ee1..1a4a22ab2 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -17,6 +17,7 @@ Added - ``QuerySet.contains()`` method to check if an object exists in a queryset. - Added comprehensive EXPLAIN support for MySQL and PostgreSQL. - Built-in ``DomainNameValidator``, ``URLValidator``, and ``EmailValidator`` classes for common validation patterns. (#2162) +- Typed ``**kwargs`` on field constructors via PEP 692 (``Unpack[TypedDict]``), so IDEs and type checkers can autocomplete and validate common field arguments (``default``, ``null``, ``unique``, ``db_index``, ``description``, etc.). (#2168) Fixed ^^^^^ diff --git a/tortoise/fields/base.py b/tortoise/fields/base.py index 81db786c0..bd6562c0f 100644 --- a/tortoise/fields/base.py +++ b/tortoise/fields/base.py @@ -7,7 +7,7 @@ from collections.abc import Callable from enum import Enum from functools import reduce -from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload +from typing import TYPE_CHECKING, Any, Generic, TypedDict, TypeVar, overload from pypika_tortoise.terms import Term @@ -90,6 +90,76 @@ class OnDelete(StrEnum): NO_ACTION = OnDelete.NO_ACTION +class _FieldKwargsCommon(TypedDict, total=False): + """:class:`Field` constructor arguments that are never declared as explicit parameters. + + Used with :data:`typing.Unpack` to give ``**kwargs`` explicit type hints. This is the + smallest set; fields that declare ``unique``/``db_index``/``primary_key`` explicitly + (e.g. ``TextField``) unpack this directly to avoid PEP 692 parameter-name collisions. + """ + + source_field: str | None + generated: bool + default: Any + db_default: Any + description: str | None + model: Model | None + validators: list[Validator | Callable] + pk: bool # deprecated alias for primary_key + index: bool # deprecated alias for db_index + + +class _FieldKwargsNoPk(_FieldKwargsCommon, total=False): + """Common arguments excluding ``primary_key`` and ``null``. + + For constructors that declare ``primary_key`` and ``null`` as explicit parameters + (e.g. ``IntField``). + """ + + unique: bool + db_index: bool | None + + +class FieldKwargs(_FieldKwargsNoPk, total=False): + """Common arguments excluding ``null``. + + For constructors that declare only ``null`` as an explicit parameter (the majority). + """ + + primary_key: bool | None + + +class JSONFieldKwargs(FieldKwargs, total=False): + """Constructor arguments for :class:`JSONField`. + + ``JSONField`` declares neither ``null`` nor ``primary_key`` explicitly, and also accepts + a custom ``field_type`` (e.g. a Pydantic model class). + """ + + null: bool + field_type: Any + + +class RelationalFieldKwargs(FieldKwargs, total=False): + """Constructor arguments for :func:`ForeignKeyField` and :func:`OneToOneField`. + + Extends the common :class:`~tortoise.fields.base.FieldKwargs` with ``to_field``. + ``null`` is declared as an explicit parameter on those constructors, so it is omitted. + """ + + to_field: str | None + + +class ManyToManyFieldKwargs(_FieldKwargsCommon, total=False): + """Constructor arguments for :func:`ManyToManyField`. + + ``unique`` is declared as an explicit parameter, so it is omitted here; the deprecated + ``create_unique_index`` alias is still accepted. + """ + + create_unique_index: bool # deprecated alias for unique + + class _FieldMeta(type): # TODO: Require functions to return field instances instead of this hack def __new__(mcs, name: str, bases: tuple[type, ...], attrs: dict) -> type: diff --git a/tortoise/fields/data.py b/tortoise/fields/data.py index 6109d0022..f36332432 100644 --- a/tortoise/fields/data.py +++ b/tortoise/fields/data.py @@ -4,6 +4,7 @@ import datetime import functools import json +import sys import warnings from collections.abc import Callable from decimal import Decimal @@ -17,7 +18,13 @@ from tortoise import timezone from tortoise.exceptions import ConfigurationError, FieldError -from tortoise.fields.base import Field +from tortoise.fields.base import ( + Field, + FieldKwargs, + JSONFieldKwargs, + _FieldKwargsCommon, + _FieldKwargsNoPk, +) from tortoise.timezone import get_default_timezone, get_timezone, get_use_tz, localtime from tortoise.validators import MaxLengthValidator @@ -30,7 +37,9 @@ try: from pydantic import BaseModel as _PydanticBaseModel - from pydantic._internal._model_construction import ModelMetaclass as _PydanticModelMetaclass + from pydantic._internal._model_construction import ( + ModelMetaclass as _PydanticModelMetaclass, + ) except ImportError: _PydanticBaseModel = None # type: ignore[assignment,misc] _PydanticModelMetaclass = None # type: ignore[assignment,misc] @@ -38,6 +47,12 @@ if TYPE_CHECKING: # pragma: nocoverage from tortoise.models import Model + +if sys.version_info >= (3, 11): + from typing import Unpack +else: # pragma: no cover + from typing_extensions import Unpack + __all__ = ( "BigIntField", "BinaryField", @@ -107,7 +122,7 @@ def __init__( primary_key: bool | None = None, *, null: Literal[False] = False, - **kwargs: Any, + **kwargs: Unpack[_FieldKwargsNoPk], ) -> None: ... @overload @@ -116,7 +131,7 @@ def __init__( primary_key: bool | None = None, *, null: Literal[True], - **kwargs: Any, + **kwargs: Unpack[_FieldKwargsNoPk], ) -> None: ... def __init__(self, primary_key: bool | None = None, **kwargs: Any) -> None: @@ -222,12 +237,20 @@ class CharField(Field[T_STR]): @overload def __init__( - self: CharField[str], max_length: int, *, null: Literal[False] = False, **kwargs: Any + self: CharField[str], + max_length: int, + *, + null: Literal[False] = False, + **kwargs: Unpack[FieldKwargs], ) -> None: ... @overload def __init__( - self: CharField[str | None], max_length: int, *, null: Literal[True], **kwargs: Any + self: CharField[str | None], + max_length: int, + *, + null: Literal[True], + **kwargs: Unpack[FieldKwargs], ) -> None: ... def __init__(self, max_length: int, **kwargs: Any) -> None: @@ -256,7 +279,7 @@ def SQL_TYPE(self) -> str: return f"NVARCHAR2({self.field.max_length})" -class TextField(Field[str], str): # type: ignore +class TextField(Field[T_STR], str): # type: ignore """ Large Text field. """ @@ -264,6 +287,28 @@ class TextField(Field[str], str): # type: ignore indexable = False SQL_TYPE = "TEXT" + @overload + def __init__( + self: TextField[str], + *, + primary_key: bool | None = None, + unique: bool = False, + db_index: bool = False, + null: Literal[False] = False, + **kwargs: Unpack[_FieldKwargsCommon], + ) -> None: ... + + @overload + def __init__( + self: TextField[str | None], + *, + primary_key: bool | None = None, + unique: bool = False, + db_index: bool = False, + null: Literal[True], + **kwargs: Unpack[_FieldKwargsCommon], + ) -> None: ... + def __init__( self, primary_key: bool | None = None, @@ -315,12 +360,18 @@ class BooleanField(Field[T_BOOL]): @overload def __init__( - self: BooleanField[bool], *, null: Literal[False] = False, **kwargs: Any + self: BooleanField[bool], + *, + null: Literal[False] = False, + **kwargs: Unpack[FieldKwargs], ) -> None: ... @overload def __init__( - self: BooleanField[bool | None], *, null: Literal[True], **kwargs: Any + self: BooleanField[bool | None], + *, + null: Literal[True], + **kwargs: Unpack[FieldKwargs], ) -> None: ... def __init__(self, **kwargs: Any) -> None: @@ -357,7 +408,7 @@ def __init__( decimal_places: int, *, null: Literal[False] = False, - **kwargs: Any, + **kwargs: Unpack[FieldKwargs], ) -> None: ... @overload @@ -367,7 +418,7 @@ def __init__( decimal_places: int, *, null: Literal[True], - **kwargs: Any, + **kwargs: Unpack[FieldKwargs], ) -> None: ... def __init__(self, max_digits: int, decimal_places: int, **kwargs: Any) -> None: @@ -440,7 +491,7 @@ def __init__( auto_now_add: bool = False, *, null: Literal[False] = False, - **kwargs: Any, + **kwargs: Unpack[FieldKwargs], ) -> None: ... @overload @@ -450,7 +501,7 @@ def __init__( auto_now_add: bool = False, *, null: Literal[True], - **kwargs: Any, + **kwargs: Unpack[FieldKwargs], ) -> None: ... def __init__(self, auto_now: bool = False, auto_now_add: bool = False, **kwargs: Any) -> None: @@ -530,12 +581,18 @@ class DateField(Field[T_DATE], datetime.date): @overload def __init__( - self: DateField[datetime.date], *, null: Literal[False] = False, **kwargs: Any + self: DateField[datetime.date], + *, + null: Literal[False] = False, + **kwargs: Unpack[FieldKwargs], ) -> None: ... @overload def __init__( - self: DateField[datetime.date | None], *, null: Literal[True], **kwargs: Any + self: DateField[datetime.date | None], + *, + null: Literal[True], + **kwargs: Unpack[FieldKwargs], ) -> None: ... def __init__(self, **kwargs: Any) -> None: @@ -574,7 +631,7 @@ def __init__( auto_now_add: bool = False, *, null: Literal[False] = False, - **kwargs: Any, + **kwargs: Unpack[FieldKwargs], ) -> None: ... @overload @@ -584,7 +641,7 @@ def __init__( auto_now_add: bool = False, *, null: Literal[True], - **kwargs: Any, + **kwargs: Unpack[FieldKwargs], ) -> None: ... def __init__(self, auto_now: bool = False, auto_now_add: bool = False, **kwargs: Any) -> None: @@ -654,12 +711,18 @@ class TimeDeltaField(Field[T_TIMEDELTA]): @overload def __init__( - self: TimeDeltaField[datetime.timedelta], *, null: Literal[False] = False, **kwargs: Any + self: TimeDeltaField[datetime.timedelta], + *, + null: Literal[False] = False, + **kwargs: Unpack[FieldKwargs], ) -> None: ... @overload def __init__( - self: TimeDeltaField[datetime.timedelta | None], *, null: Literal[True], **kwargs: Any + self: TimeDeltaField[datetime.timedelta | None], + *, + null: Literal[True], + **kwargs: Unpack[FieldKwargs], ) -> None: ... def __init__(self, **kwargs: Any) -> None: @@ -692,11 +755,19 @@ class FloatField(Field[T_FLOAT], float): @overload def __init__( - self: FloatField[float], *, null: Literal[False] = False, **kwargs: Any + self: FloatField[float], + *, + null: Literal[False] = False, + **kwargs: Unpack[FieldKwargs], ) -> None: ... @overload - def __init__(self: FloatField[float | None], *, null: Literal[True], **kwargs: Any) -> None: ... + def __init__( + self: FloatField[float | None], + *, + null: Literal[True], + **kwargs: Unpack[FieldKwargs], + ) -> None: ... def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) @@ -748,7 +819,7 @@ def __init__( self, encoder: JsonDumpsFunc = JSON_DUMPS, decoder: JsonLoadsFunc = JSON_LOADS, - **kwargs: Any, + **kwargs: Unpack[JSONFieldKwargs], ) -> None: super().__init__(**kwargs) self.encoder = encoder @@ -820,10 +891,20 @@ class _db_postgres: SQL_TYPE = "UUID" @overload - def __init__(self: UUIDField[UUID], *, null: Literal[False] = False, **kwargs: Any) -> None: ... + def __init__( + self: UUIDField[UUID], + *, + null: Literal[False] = False, + **kwargs: Unpack[FieldKwargs], + ) -> None: ... @overload - def __init__(self: UUIDField[UUID | None], *, null: Literal[True], **kwargs: Any) -> None: ... + def __init__( + self: UUIDField[UUID | None], + *, + null: Literal[True], + **kwargs: Unpack[FieldKwargs], + ) -> None: ... def __init__(self, **kwargs: Any) -> None: if (kwargs.get("primary_key") or kwargs.get("pk", False)) and "default" not in kwargs: @@ -852,12 +933,18 @@ class BinaryField(Field[T_BINARY], bytes): # type: ignore @overload def __init__( - self: BinaryField[bytes], *, null: Literal[False] = False, **kwargs: Any + self: BinaryField[bytes], + *, + null: Literal[False] = False, + **kwargs: Unpack[FieldKwargs], ) -> None: ... @overload def __init__( - self: BinaryField[bytes | None], *, null: Literal[True], **kwargs: Any + self: BinaryField[bytes | None], + *, + null: Literal[True], + **kwargs: Unpack[FieldKwargs], ) -> None: ... def __init__(self, **kwargs: Any) -> None: diff --git a/tortoise/fields/relational.py b/tortoise/fields/relational.py index 442707848..b5e7ab568 100644 --- a/tortoise/fields/relational.py +++ b/tortoise/fields/relational.py @@ -1,5 +1,6 @@ from __future__ import annotations +import sys import warnings from collections.abc import AsyncGenerator, Generator, Iterator from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, overload @@ -7,7 +8,19 @@ from pypika_tortoise.queries import Table from tortoise.exceptions import ConfigurationError, NoValuesFetched, OperationalError -from tortoise.fields.base import CASCADE, SET_NULL, Field, OnDelete +from tortoise.fields.base import ( + CASCADE, + SET_NULL, + Field, + ManyToManyFieldKwargs, + OnDelete, + RelationalFieldKwargs, +) + +if sys.version_info >= (3, 11): + from typing import Unpack +else: # pragma: no cover + from typing_extensions import Unpack if TYPE_CHECKING: # pragma: nocoverage from tortoise.backends.base.client import BaseDBAsyncClient @@ -441,7 +454,7 @@ def OneToOneField( db_constraint: bool = True, *, null: Literal[True], - **kwargs: Any, + **kwargs: Unpack[RelationalFieldKwargs], ) -> OneToOneNullableRelation[MODEL]: ... @@ -452,7 +465,7 @@ def OneToOneField( on_delete: OnDelete = CASCADE, db_constraint: bool = True, null: Literal[False] = False, - **kwargs: Any, + **kwargs: Unpack[RelationalFieldKwargs], ) -> OneToOneRelation[MODEL]: ... @@ -516,7 +529,7 @@ def ForeignKeyField( db_constraint: bool = True, *, null: Literal[True], - **kwargs: Any, + **kwargs: Unpack[RelationalFieldKwargs], ) -> ForeignKeyNullableRelation[MODEL]: ... @@ -527,7 +540,7 @@ def ForeignKeyField( on_delete: OnDelete = CASCADE, db_constraint: bool = True, null: Literal[False] = False, - **kwargs: Any, + **kwargs: Unpack[RelationalFieldKwargs], ) -> ForeignKeyRelation[MODEL]: ... @@ -592,7 +605,7 @@ def ManyToManyField( on_delete: OnDelete = CASCADE, db_constraint: bool = True, unique: bool = True, - **kwargs: Any, + **kwargs: Unpack[ManyToManyFieldKwargs], ) -> ManyToManyRelation[MODEL]: """ ManyToMany relation field.