diff --git a/Makefile b/Makefile
index 9be4a5610..01c60a225 100644
--- a/Makefile
+++ b/Makefile
@@ -41,9 +41,9 @@ endif
 
 lint: deps build
 ifneq ($(shell which black),)
-	black --check $(checkfiles) || (echo "Please run 'make style' to auto-fix style issues" && false)
+	black $(checkfiles)
 endif
-	ruff check $(checkfiles)
+	ruff check --fix $(checkfiles)
 	mypy $(checkfiles)
 	#pylint $(checkfiles)
 	bandit -c pyproject.toml -r $(checkfiles)
diff --git a/examples/blacksheep/server.py b/examples/blacksheep/server.py
index e3cba59fc..3e0c84a95 100644
--- a/examples/blacksheep/server.py
+++ b/examples/blacksheep/server.py
@@ -1,5 +1,6 @@
 # pylint: disable=E0401,E0611
-from typing import Union
+from __future__ import annotations
+
 from uuid import UUID
 
 from blacksheep import Response
@@ -25,7 +26,7 @@
 
 
 @app.router.get("/")
-async def users_list() -> Union[UserPydanticOut]:
+async def users_list() -> UserPydanticOut:
     return ok(await UserPydanticOut.from_queryset(Users.all()))
diff --git a/examples/fastapi/schemas.py b/examples/fastapi/schemas.py
index 7afab345b..a9944d44d 100644
--- a/examples/fastapi/schemas.py
+++ b/examples/fastapi/schemas.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import TYPE_CHECKING
 
 from models import Users
diff --git a/examples/signals.py b/examples/signals.py
index ea3fc692c..60f7538c5 100644
--- a/examples/signals.py
+++ b/examples/signals.py
@@ -2,7 +2,7 @@
 This example demonstrates model signals usage
 """
 
-from typing import Optional
+from __future__ import annotations
 
 from tortoise import BaseDBAsyncClient, Tortoise, fields, run_async
 from tortoise.models import Model
@@ -21,18 +21,16 @@ def __str__(self):
 
 
 @pre_save(Signal)
-async def signal_pre_save(
-    sender: "type[Signal]", instance: Signal, using_db, update_fields
-) -> None:
+async def signal_pre_save(sender: type[Signal], instance: Signal, using_db, update_fields) -> None:
     print(sender, instance, using_db, update_fields)
 
 
 @post_save(Signal)
 async def signal_post_save(
-    sender: "type[Signal]",
+    sender: type[Signal],
     instance: Signal,
     created: bool,
-    using_db: "Optional[BaseDBAsyncClient]",
+    using_db: BaseDBAsyncClient | None,
     update_fields: list[str],
 ) -> None:
     print(sender, instance, using_db, created, update_fields)
@@ -40,14 +38,14 @@ async def signal_post_save(
 
 @pre_delete(Signal)
 async def signal_pre_delete(
-    sender: "type[Signal]", instance: Signal, using_db: "Optional[BaseDBAsyncClient]"
+    sender: type[Signal], instance: Signal, using_db: BaseDBAsyncClient | None
 ) -> None:
     print(sender, instance, using_db)
 
 
 @post_delete(Signal)
 async def signal_post_delete(
-    sender: "type[Signal]", instance: Signal, using_db: "Optional[BaseDBAsyncClient]"
+    sender: type[Signal], instance: Signal, using_db: BaseDBAsyncClient | None
 ) -> None:
     print(sender, instance, using_db)
diff --git a/pyproject.toml b/pyproject.toml
index 03295f399..454e021f7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -198,6 +198,11 @@ show_missing = true
 line-length = 100
 [tool.ruff.lint]
 ignore = ["E501"]
+extend-select = [
+    "FA",     # https://docs.astral.sh/ruff/rules/#flake8-future-annotations-fa
+    "UP",     # https://docs.astral.sh/ruff/rules/#pyupgrade-up
+    "RUF100", # https://docs.astral.sh/ruff/rules/#ruff-specific-rules-ruf
+]
 
 [tool.bandit]
 exclude_dirs = ["tests", 'examples/*/_tests.py', "conftest.py"]
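Reviewer note: the new `FA` (flake8-future-annotations) and `UP` (pyupgrade) selections drive most of the mechanical rewrites below, and `RUF100` strips `# noqa` comments that no longer fire. A minimal sketch of what the future import buys on Python 3.9 (names hypothetical):

    from __future__ import annotations  # PEP 563: annotations are stored as strings

    def find_user(user_id: int) -> User | None:  # PEP 604 union, unquoted forward ref
        return None

    class User:  # defined after the annotation above; fine, it is never evaluated
        id: int

Without the import, `User | None` would fail at import time on 3.9; with it, `Optional["User"]` can be spelled `User | None` everywhere, which is exactly the rewrite applied throughout this patch.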
diff --git a/tests/backends/test_capabilities.py b/tests/backends/test_capabilities.py
index 1a9c1be21..362213bd3 100644
--- a/tests/backends/test_capabilities.py
+++ b/tests/backends/test_capabilities.py
@@ -6,7 +6,7 @@ class TestCapabilities(test.TestCase):
     # pylint: disable=E1101
 
     async def asyncSetUp(self) -> None:
-        await super(TestCapabilities, self).asyncSetUp()
+        await super().asyncSetUp()
         self.db = connections.get("models")
         self.caps = self.db.capabilities
diff --git a/tests/benchmarks/test_bulk_create.py b/tests/benchmarks/test_bulk_create.py
index bca894009..f554c8887 100644
--- a/tests/benchmarks/test_bulk_create.py
+++ b/tests/benchmarks/test_bulk_create.py
@@ -9,7 +9,8 @@ def test_bulk_create_few_fields(benchmark):
     data = [
         BenchmarkFewFields(
-            level=random.choice([10, 20, 30, 40, 50]), text=f"Insert from C, item {i}"  # nosec
+            level=random.choice([10, 20, 30, 40, 50]),  # nosec
+            text=f"Insert from C, item {i}",
         )
         for i in range(100)
     ]
diff --git a/tests/contrib/test_pydantic.py b/tests/contrib/test_pydantic.py
index a4fcd4466..56ebeaf37 100644
--- a/tests/contrib/test_pydantic.py
+++ b/tests/contrib/test_pydantic.py
@@ -29,7 +29,7 @@
 class TestPydantic(test.TestCase):
     async def asyncSetUp(self) -> None:
-        await super(TestPydantic, self).asyncSetUp()
+        await super().asyncSetUp()
         self.Event_Pydantic = pydantic_model_creator(Event)
         self.Event_Pydantic_List = pydantic_queryset_creator(Event)
         self.Tournament_Pydantic = pydantic_model_creator(Tournament)
@@ -1392,7 +1392,7 @@ def test_exclude_readonly(self):
 
 class TestPydanticCycle(test.TestCase):
     async def asyncSetUp(self) -> None:
-        await super(TestPydanticCycle, self).asyncSetUp()
+        await super().asyncSetUp()
         self.Employee_Pydantic = pydantic_model_creator(Employee)
 
         self.root = await Employee.create(name="Root")
@@ -1599,7 +1599,7 @@ async def test_serialisation(self):
 
 class TestPydanticComputed(test.TestCase):
     async def asyncSetUp(self) -> None:
-        await super(TestPydanticComputed, self).asyncSetUp()
+        await super().asyncSetUp()
         self.Employee_Pydantic = pydantic_model_creator(Employee)
         self.employee = await Employee.create(name="Some Employee")
         self.maxDiff = None
diff --git a/tests/fields/subclass_fields.py b/tests/fields/subclass_fields.py
index e82a9609d..ce7d00e51 100644
--- a/tests/fields/subclass_fields.py
+++ b/tests/fields/subclass_fields.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from enum import Enum, IntEnum
 from typing import Any
diff --git a/tests/fields/test_db_index.py b/tests/fields/test_db_index.py
index 6b3a15039..82d69510e 100644
--- a/tests/fields/test_db_index.py
+++ b/tests/fields/test_db_index.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import Any
 
 from pypika_tortoise.terms import Field
@@ -45,7 +47,7 @@ def test_index_repr(self):
         assert repr(Index(fields=("id",))) == "Index(fields=['id'])"
         assert repr(Index(fields=("id", "name"))) == "Index(fields=['id', 'name'])"
         assert repr(Index(fields=("id",), name="MyIndex")) == "Index(fields=['id'], name='MyIndex')"
-        assert repr(Index(Field("id"))) == f'Index({str(Field("id"))})'
+        assert repr(Index(Field("id"))) == f"Index({str(Field('id'))})"
         assert repr(Index(Field("a"), name="Id")) == f"Index({str(Field('a'))}, name='Id')"
         with self.assertRaises(ConfigurationError):
             Index(Field("id"), fields=("name",))
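The `test_index_repr` hunk is quote normalization only: both spellings evaluate to the same text, the rewrite just changes which quote style sits inside the f-string expression. A quick sanity check with plain `str` standing in for pypika's `Field`:

    # Quote style inside the f-string expression does not change the result:
    assert f'Index({str("id")})' == f"Index({str('id')})" == "Index(id)"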
diff --git a/tests/fields/test_decimal.py b/tests/fields/test_decimal.py
index 9144b7834..c0b16b1f2 100644
--- a/tests/fields/test_decimal.py
+++ b/tests/fields/test_decimal.py
@@ -12,7 +12,7 @@ class TestDecimalFields(test.TestCase):
     def test_max_digits_empty(self):
         with self.assertRaisesRegex(
             TypeError,
-            "missing 2 required positional arguments: 'max_digits' and" " 'decimal_places'",
+            "missing 2 required positional arguments: 'max_digits' and 'decimal_places'",
         ):
             fields.DecimalField()  # pylint: disable=E1120
@@ -169,25 +169,31 @@ async def test_aggregate_sum_no_exist_field_with_f_expression(self):
             FieldError,
             "There is no non-virtual field not_exist on Model DecimalFields",
         ):
-            await testmodels.DecimalFields.all().annotate(sum_decimal=Sum(F("not_exist"))).values(
-                "sum_decimal"
+            await (
+                testmodels.DecimalFields.all()
+                .annotate(sum_decimal=Sum(F("not_exist")))
+                .values("sum_decimal")
             )
 
     async def test_aggregate_sum_different_field_type_at_right_with_f_expression(self):
         with self.assertRaisesRegex(
             FieldError, "Cannot use arithmetic expression between different field type"
         ):
-            await testmodels.DecimalFields.all().annotate(
-                sum_decimal=Sum(F("decimal") + F("id"))
-            ).values("sum_decimal")
+            await (
+                testmodels.DecimalFields.all()
+                .annotate(sum_decimal=Sum(F("decimal") + F("id")))
+                .values("sum_decimal")
+            )
 
     async def test_aggregate_sum_different_field_type_at_left_with_f_expression(self):
         with self.assertRaisesRegex(
             FieldError, "Cannot use arithmetic expression between different field type"
         ):
-            await testmodels.DecimalFields.all().annotate(
-                sum_decimal=Sum(F("id") + F("decimal"))
-            ).values("sum_decimal")
+            await (
+                testmodels.DecimalFields.all()
+                .annotate(sum_decimal=Sum(F("id") + F("decimal")))
+                .values("sum_decimal")
+            )
 
     async def test_aggregate_avg(self):
         await testmodels.DecimalFields.create(decimal=Decimal("0"), decimal_nodec=1)
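This file (and the test_fk* files below) also merges pairs of adjacent string literals left over from an earlier line wrap. Worth recalling in review: Python concatenates adjacent literals at compile time, so the asserted regex is identical before and after:

    msg = "missing 2 required positional arguments: 'max_digits' and" " 'decimal_places'"
    assert msg == "missing 2 required positional arguments: 'max_digits' and 'decimal_places'"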
diff --git a/tests/fields/test_fk.py b/tests/fields/test_fk.py
index ff69cbbd2..c50657229 100644
--- a/tests/fields/test_fk.py
+++ b/tests/fields/test_fk.py
@@ -120,7 +120,7 @@ async def test_minimal__unfetched_contains(self):
         tour = await testmodels.Tournament.create(name="Team1")
         with self.assertRaisesRegex(
             NoValuesFetched,
-            "No values were fetched for this relation," " first use .fetch_related()",
+            "No values were fetched for this relation, first use .fetch_related()",
         ):
             "a" in tour.minrelations  # pylint: disable=W0104
 
@@ -128,7 +128,7 @@ async def test_minimal__unfetched_iter(self):
         tour = await testmodels.Tournament.create(name="Team1")
         with self.assertRaisesRegex(
             NoValuesFetched,
-            "No values were fetched for this relation," " first use .fetch_related()",
+            "No values were fetched for this relation, first use .fetch_related()",
         ):
             for _ in tour.minrelations:
                 pass
@@ -137,7 +137,7 @@ async def test_minimal__unfetched_len(self):
         tour = await testmodels.Tournament.create(name="Team1")
         with self.assertRaisesRegex(
             NoValuesFetched,
-            "No values were fetched for this relation," " first use .fetch_related()",
+            "No values were fetched for this relation, first use .fetch_related()",
         ):
             len(tour.minrelations)
 
@@ -145,7 +145,7 @@ async def test_minimal__unfetched_bool(self):
         tour = await testmodels.Tournament.create(name="Team1")
         with self.assertRaisesRegex(
             NoValuesFetched,
-            "No values were fetched for this relation," " first use .fetch_related()",
+            "No values were fetched for this relation, first use .fetch_related()",
         ):
             bool(tour.minrelations)
 
@@ -153,7 +153,7 @@ async def test_minimal__unfetched_getitem(self):
         tour = await testmodels.Tournament.create(name="Team1")
         with self.assertRaisesRegex(
             NoValuesFetched,
-            "No values were fetched for this relation," " first use .fetch_related()",
+            "No values were fetched for this relation, first use .fetch_related()",
         ):
             tour.minrelations[0]  # pylint: disable=W0104
diff --git a/tests/fields/test_fk_uuid.py b/tests/fields/test_fk_uuid.py
index 1f6d8358f..31c88ad87 100644
--- a/tests/fields/test_fk_uuid.py
+++ b/tests/fields/test_fk_uuid.py
@@ -112,7 +112,7 @@ async def test_unfetched_contains(self):
         tour = await self.UUIDPkModel.create()
         with self.assertRaisesRegex(
             NoValuesFetched,
-            "No values were fetched for this relation," " first use .fetch_related()",
+            "No values were fetched for this relation, first use .fetch_related()",
         ):
             "a" in tour.children  # pylint: disable=W0104
 
@@ -120,7 +120,7 @@ async def test_unfetched_iter(self):
         tour = await self.UUIDPkModel.create()
         with self.assertRaisesRegex(
             NoValuesFetched,
-            "No values were fetched for this relation," " first use .fetch_related()",
+            "No values were fetched for this relation, first use .fetch_related()",
         ):
             for _ in tour.children:
                 pass
@@ -129,7 +129,7 @@ async def test_unfetched_len(self):
         tour = await self.UUIDPkModel.create()
         with self.assertRaisesRegex(
             NoValuesFetched,
-            "No values were fetched for this relation," " first use .fetch_related()",
+            "No values were fetched for this relation, first use .fetch_related()",
         ):
             len(tour.children)
 
@@ -137,7 +137,7 @@ async def test_unfetched_bool(self):
         tour = await self.UUIDPkModel.create()
         with self.assertRaisesRegex(
             NoValuesFetched,
-            "No values were fetched for this relation," " first use .fetch_related()",
+            "No values were fetched for this relation, first use .fetch_related()",
         ):
             bool(tour.children)
 
@@ -145,7 +145,7 @@ async def test_unfetched_getitem(self):
         tour = await self.UUIDPkModel.create()
         with self.assertRaisesRegex(
             NoValuesFetched,
-            "No values were fetched for this relation," " first use .fetch_related()",
+            "No values were fetched for this relation, first use .fetch_related()",
        ):
             tour.children[0]  # pylint: disable=W0104
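All of these hunks touch the same `NoValuesFetched` message. For context, the exception fires whenever a reverse relation is read before being loaded; a small usage sketch with the models from these tests:

    tour = await testmodels.Tournament.create(name="Team1")
    # len(tour.minrelations)                  # would raise NoValuesFetched here
    await tour.fetch_related("minrelations")  # load the reverse relation first
    assert len(tour.minrelations) == 0        # now safe to inspect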
diff --git a/tests/fields/test_fk_with_unique.py b/tests/fields/test_fk_with_unique.py
index 6e5ac622c..6dc71c852 100644
--- a/tests/fields/test_fk_with_unique.py
+++ b/tests/fields/test_fk_with_unique.py
@@ -104,7 +104,7 @@ async def test_student__unfetched_contains(self):
         school = await testmodels.School.create(id=1024, name="School1")
         with self.assertRaisesRegex(
             NoValuesFetched,
-            "No values were fetched for this relation," " first use .fetch_related()",
+            "No values were fetched for this relation, first use .fetch_related()",
         ):
             "a" in school.students  # pylint: disable=W0104
 
@@ -112,7 +112,7 @@ async def test_stduent__unfetched_iter(self):
         school = await testmodels.School.create(id=1024, name="School1")
         with self.assertRaisesRegex(
             NoValuesFetched,
-            "No values were fetched for this relation," " first use .fetch_related()",
+            "No values were fetched for this relation, first use .fetch_related()",
         ):
             for _ in school.students:
                 pass
@@ -121,7 +121,7 @@ async def test_student__unfetched_len(self):
         school = await testmodels.School.create(id=1024, name="School1")
         with self.assertRaisesRegex(
             NoValuesFetched,
-            "No values were fetched for this relation," " first use .fetch_related()",
+            "No values were fetched for this relation, first use .fetch_related()",
         ):
             len(school.students)
 
@@ -129,7 +129,7 @@ async def test_student__unfetched_bool(self):
         school = await testmodels.School.create(id=1024, name="School1")
         with self.assertRaisesRegex(
             NoValuesFetched,
-            "No values were fetched for this relation," " first use .fetch_related()",
+            "No values were fetched for this relation, first use .fetch_related()",
         ):
             bool(school.students)
 
@@ -137,7 +137,7 @@ async def test_student__unfetched_getitem(self):
         school = await testmodels.School.create(id=1024, name="School1")
         with self.assertRaisesRegex(
             NoValuesFetched,
-            "No values were fetched for this relation," " first use .fetch_related()",
+            "No values were fetched for this relation, first use .fetch_related()",
         ):
             school.students[0]  # pylint: disable=W0104
diff --git a/tests/model_setup/model_bad_rel2.py b/tests/model_setup/model_bad_rel2.py
index 6a6e286a8..9a324f0c7 100644
--- a/tests/model_setup/model_bad_rel2.py
+++ b/tests/model_setup/model_bad_rel2.py
@@ -3,6 +3,8 @@
 The model 'Tour' does not exist
 """
 
+from __future__ import annotations
+
 from typing import Any
 
 from tortoise import fields
@@ -14,6 +16,6 @@ class Tournament(Model):
 
 
 class Event(Model):
-    tournament: fields.ForeignKeyRelation["Any"] = fields.ForeignKeyField(
+    tournament: fields.ForeignKeyRelation[Any] = fields.ForeignKeyField(
         "models.Tour", related_name="events"
     )
diff --git a/tests/model_setup/test_bad_relation_reference.py b/tests/model_setup/test_bad_relation_reference.py
index c2b3d3f20..5c28e406b 100644
--- a/tests/model_setup/test_bad_relation_reference.py
+++ b/tests/model_setup/test_bad_relation_reference.py
@@ -15,7 +15,7 @@ async def asyncSetUp(self):
 
     async def asyncTearDown(self) -> None:
         await Tortoise._reset_apps()
-        await super(TestBadRelationReferenceErrors, self).asyncTearDown()
+        await super().asyncTearDown()
 
     async def test_wrong_app_init(self):
         with self.assertRaisesRegex(ConfigurationError, "No app with name 'app' registered."):
diff --git a/tests/schema/models_fk_1.py b/tests/schema/models_fk_1.py
index f82fd3df4..e3d0023ee 100644
--- a/tests/schema/models_fk_1.py
+++ b/tests/schema/models_fk_1.py
@@ -2,6 +2,8 @@
 This is the testing Models — FK bad model name
 """
 
+from __future__ import annotations
+
 from typing import Any
 
 from tortoise import fields
diff --git a/tests/schema/models_fk_2.py b/tests/schema/models_fk_2.py
index 82878df0a..76500bae9 100644
--- a/tests/schema/models_fk_2.py
+++ b/tests/schema/models_fk_2.py
@@ -9,5 +9,6 @@
 
 class One(Model):
     tournament: fields.ForeignKeyRelation[Two] = fields.ForeignKeyField(
-        "models.Two", on_delete="WABOOM"  # type:ignore
+        "models.Two",
+        on_delete="WABOOM",  # type:ignore
     )
diff --git a/tests/schema/models_o2o_2.py b/tests/schema/models_o2o_2.py
index 829557bdc..95d5485eb 100644
--- a/tests/schema/models_o2o_2.py
+++ b/tests/schema/models_o2o_2.py
@@ -9,5 +9,6 @@
 
 class One(Model):
     tournament: fields.OneToOneRelation[Two] = fields.OneToOneField(
-        "models.Two", on_delete="WABOOM"  # type:ignore
+        "models.Two",
+        on_delete="WABOOM",  # type:ignore
     )
diff --git a/tests/test_default.py b/tests/test_default.py
index c4b870e49..a361778a8 100644
--- a/tests/test_default.py
+++ b/tests/test_default.py
@@ -16,7 +16,7 @@
 class TestDefault(test.TestCase):
     async def asyncSetUp(self) -> None:
-        await super(TestDefault, self).asyncSetUp()
+        await super().asyncSetUp()
         db = connections.get("models")
         if isinstance(db, MySQLClient):
             await db.execute_query(
diff --git a/tests/test_filters.py b/tests/test_filters.py
index dd80a84e8..3454c20da 100644
--- a/tests/test_filters.py
+++ b/tests/test_filters.py
@@ -343,7 +343,7 @@ async def test_isnull(self):
     async def test_not_isnull(self):
         self.assertSetEqual(
             set(await CharPkModel.filter(children__not_isnull=True).values_list("id", flat=True)),
-            {"17", "17", "12"},
+            {"17", "12"},
         )
         self.assertSetEqual(
             set(await CharPkModel.filter(children__not_isnull=False).values_list("id", flat=True)),
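The test_filters hunk is a cleanup rather than a behavior change: a set literal collapses duplicates, so the old expected value was already equal to the new one:

    assert {"17", "17", "12"} == {"17", "12"}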
diff --git a/tests/test_group_by.py b/tests/test_group_by.py
index d8a35d66a..aab39a4e3 100644
--- a/tests/test_group_by.py
+++ b/tests/test_group_by.py
@@ -5,7 +5,7 @@
 class TestGroupBy(test.TestCase):
     async def asyncSetUp(self) -> None:
-        await super(TestGroupBy, self).asyncSetUp()
+        await super().asyncSetUp()
         self.a1 = await Author.create(name="author1")
         self.a2 = await Author.create(name="author2")
         self.books1 = [
diff --git a/tests/test_only.py b/tests/test_only.py
index f3245fc99..0d6a8a6a2 100644
--- a/tests/test_only.py
+++ b/tests/test_only.py
@@ -5,7 +5,7 @@
 class TestOnlyStraight(test.TestCase):
     async def asyncSetUp(self) -> None:
-        await super(TestOnlyStraight, self).asyncSetUp()
+        await super().asyncSetUp()
         self.model = StraightFields
         self.instance = await self.model.create(chars="Test")
diff --git a/tests/test_posix_regex_filter.py b/tests/test_posix_regex_filter.py
index 422f100fb..836d1cdfa 100644
--- a/tests/test_posix_regex_filter.py
+++ b/tests/test_posix_regex_filter.py
@@ -8,7 +8,6 @@ async def asyncSetUp(self) -> None:
 
 
 class TestPosixRegexFilter(test.TestCase):
-
     @test.requireCapability(support_for_posix_regex_queries=True)
     async def test_regex_filter(self):
         author = await testmodels.Author.create(name="Johann Wolfgang von Goethe")
diff --git a/tests/test_prefetching.py b/tests/test_prefetching.py
index c9246a401..12b366656 100644
--- a/tests/test_prefetching.py
+++ b/tests/test_prefetching.py
@@ -35,9 +35,11 @@ async def test_prefetch_unknown_field(self):
             tournament = await Tournament.create(name="tournament")
             await Event.create(name="First", tournament=tournament)
             await Event.create(name="Second", tournament=tournament)
-            await Tournament.all().prefetch_related(
-                Prefetch("events1", queryset=Event.filter(name="First"))
-            ).first()
+            await (
+                Tournament.all()
+                .prefetch_related(Prefetch("events1", queryset=Event.filter(name="First")))
+                .first()
+            )
 
     async def test_prefetch_m2m(self):
         tournament = await Tournament.create(name="tournament")
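The parenthesized `await (...)` rewrites here and in test_decimal.py are formatting only: `await` applies to the whole attribute/call chain either way, so wrapping the chain in parentheses just lets black split it across lines. A sketch of the equivalence:

    # Both forms build one coroutine from the chained queryset and await it once:
    first = await Tournament.all().prefetch_related("events").first()
    first = await (
        Tournament.all()
        .prefetch_related("events")
        .first()
    )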
diff --git a/tests/test_primary_key.py b/tests/test_primary_key.py
index 3998e511a..dbe224db4 100644
--- a/tests/test_primary_key.py
+++ b/tests/test_primary_key.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import uuid
 from typing import Any
diff --git a/tests/test_relations.py b/tests/test_relations.py
index a134df539..9502e4f05 100644
--- a/tests/test_relations.py
+++ b/tests/test_relations.py
@@ -326,9 +326,7 @@ async def test_select_related_sets_valid_nulls(self) -> None:
         root = await DoubleFK.create(name="root", left=left_1st_lvl)
 
         retrieved_root = (
-            await DoubleFK.all()
-            .select_related("left__left__left", "right")
-            .get(id=getattr(root, "id"))
+            await DoubleFK.all().select_related("left__left__left", "right").get(id=root.pk)
         )
         self.assertIsNone(retrieved_root.right)
         assert retrieved_root.left is not None
diff --git a/tests/test_signals.py b/tests/test_signals.py
index c915c3f4c..75f209f11 100644
--- a/tests/test_signals.py
+++ b/tests/test_signals.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from __future__ import annotations
 
 from tests.testmodels import Signals
 from tortoise import BaseDBAsyncClient
@@ -8,7 +8,7 @@
 
 @pre_save(Signals)
 async def signal_pre_save(
-    sender: "type[Signals]", instance: Signals, using_db, update_fields
+    sender: type[Signals], instance: Signals, using_db, update_fields
 ) -> None:
     await Signals.filter(name="test1").update(name="test_pre-save")
     await Signals.filter(name="test5").update(name="test_pre-save")
@@ -16,10 +16,10 @@ async def signal_pre_save(
 
 @post_save(Signals)
 async def signal_post_save(
-    sender: "type[Signals]",
+    sender: type[Signals],
     instance: Signals,
     created: bool,
-    using_db: "Optional[BaseDBAsyncClient]",
+    using_db: BaseDBAsyncClient | None,
     update_fields: list,
 ) -> None:
     await Signals.filter(name="test2").update(name="test_post-save")
@@ -28,14 +28,14 @@ async def signal_post_save(
 
 @pre_delete(Signals)
 async def signal_pre_delete(
-    sender: "type[Signals]", instance: Signals, using_db: "Optional[BaseDBAsyncClient]"
+    sender: type[Signals], instance: Signals, using_db: BaseDBAsyncClient | None
 ) -> None:
     await Signals.filter(name="test3").update(name="test_pre-delete")
 
 
 @post_delete(Signals)
 async def signal_post_delete(
-    sender: "type[Signals]", instance: Signals, using_db: "Optional[BaseDBAsyncClient]"
+    sender: type[Signals], instance: Signals, using_db: BaseDBAsyncClient | None
 ) -> None:
     await Signals.filter(name="test4").update(name="test_post-delete")
diff --git a/tests/test_update.py b/tests/test_update.py
index 6cbabdbb8..66dc592ce 100644
--- a/tests/test_update.py
+++ b/tests/test_update.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import uuid
 from datetime import datetime, timedelta
 from typing import Any
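Both signal modules now annotate receivers with `type[...]` and `| None` directly. For reference, the receiver contract the decorators expect (a sketch; `Signals` is the test model above):

    from tortoise.signals import post_save

    @post_save(Signals)
    async def on_saved(sender, instance, created, using_db, update_fields) -> None:
        # `created` is True for an insert and False for an update;
        # `using_db` is the connection in use, or None for the default one.
        ...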
diff --git a/tests/testmodels.py b/tests/testmodels.py
index dbb8b2ba1..c0b95f352 100644
--- a/tests/testmodels.py
+++ b/tests/testmodels.py
@@ -2,6 +2,8 @@
 This is the testing Models
 """
 
+from __future__ import annotations
+
 import binascii
 import datetime
 import os
@@ -71,9 +73,9 @@ class Tournament(Model):
     desc = fields.TextField(null=True)
     created = fields.DatetimeField(auto_now_add=True, db_index=True)
 
-    events: fields.ReverseRelation["Event"]
-    minrelations: fields.ReverseRelation["MinRelation"]
-    uniquetogetherfieldswithfks: fields.ReverseRelation["UniqueTogetherFieldsWithFK"]
+    events: fields.ReverseRelation[Event]
+    minrelations: fields.ReverseRelation[MinRelation]
+    uniquetogetherfieldswithfks: fields.ReverseRelation[UniqueTogetherFieldsWithFK]
 
     class PydanticMeta:
         exclude = ("minrelations", "uniquetogetherfieldswithfks")
@@ -88,7 +90,7 @@ class Reporter(Model):
     id = fields.IntField(primary_key=True)
     name = fields.TextField()
 
-    events: fields.ReverseRelation["Event"]
+    events: fields.ReverseRelation[Event]
 
     class Meta:
         table = "re_port_er"
@@ -104,13 +106,13 @@ class Event(Model):
     #: The name
     name = fields.TextField()
     #: What tournaments is a happenin'
-    tournament: fields.ForeignKeyRelation["Tournament"] = fields.ForeignKeyField(
+    tournament: fields.ForeignKeyRelation[Tournament] = fields.ForeignKeyField(
         "models.Tournament", related_name="events"
     )
     reporter: fields.ForeignKeyNullableRelation[Reporter] = fields.ForeignKeyField(
         "models.Reporter", null=True
     )
-    participants: fields.ManyToManyRelation["Team"] = fields.ManyToManyField(
+    participants: fields.ManyToManyRelation[Team] = fields.ManyToManyField(
         "models.Team",
         related_name="events",
         through="event_team",
@@ -175,7 +177,7 @@ class Address(Model):
 
 class M2mWithO2oPk(Model):
     name = fields.CharField(max_length=64)
-    address: fields.ManyToManyRelation["Address"] = fields.ManyToManyField("models.Address")
+    address: fields.ManyToManyRelation[Address] = fields.ManyToManyField("models.Address")
 
 
 class O2oPkModelWithM2m(Model):
@@ -184,7 +186,7 @@
         on_delete=fields.CASCADE,
         primary_key=True,
     )
-    nodes: fields.ManyToManyRelation["Node"] = fields.ManyToManyField("models.Node")
+    nodes: fields.ManyToManyRelation[Node] = fields.ManyToManyField("models.Node")
 
 
 class Dest_null(Model):
@@ -210,7 +212,7 @@ class Team(Model):
     name = fields.TextField()
 
     events: fields.ManyToManyRelation[Event]
-    minrelation_through: fields.ManyToManyRelation["MinRelation"]
+    minrelation_through: fields.ManyToManyRelation[MinRelation]
     alias = fields.IntField(null=True)
 
     class Meta:
@@ -228,7 +230,7 @@ class EventTwo(Model):
     name = fields.TextField()
     tournament_id = fields.IntField()
     # Here we make link to events.Team, not models.Team
-    participants: fields.ManyToManyRelation["TeamTwo"] = fields.ManyToManyField("events.TeamTwo")
+    participants: fields.ManyToManyRelation[TeamTwo] = fields.ManyToManyField("events.TeamTwo")
 
     class Meta:
         app = "events"
@@ -331,7 +333,7 @@ class FloatFields(Model):
     floatnum_null = fields.FloatField(null=True)
 
 
-def raise_if_not_dict_or_list(value: Union[dict, list]):
+def raise_if_not_dict_or_list(value: dict | list):
     if not isinstance(value, (dict, list)):
         raise ValidationError("Value must be a dict or list.")
 
@@ -373,7 +375,7 @@ class MinRelation(Model):
 
 class M2MOne(Model):
     id = fields.IntField(primary_key=True)
     name = fields.CharField(max_length=255, null=True)
-    two: fields.ManyToManyRelation["M2MTwo"] = fields.ManyToManyField(
+    two: fields.ManyToManyRelation[M2MTwo] = fields.ManyToManyField(
         "models.M2MTwo", related_name="one"
     )
@@ -422,9 +424,9 @@ class ImplicitPkModel(Model):
 
 class UUIDPkModel(Model):
     id = fields.UUIDField(primary_key=True)
 
-    children: fields.ReverseRelation["UUIDFkRelatedModel"]
-    children_null: fields.ReverseRelation["UUIDFkRelatedNullModel"]
-    peers: fields.ManyToManyRelation["UUIDM2MRelatedModel"]
+    children: fields.ReverseRelation[UUIDFkRelatedModel]
+    children_null: fields.ReverseRelation[UUIDFkRelatedNullModel]
+    peers: fields.ManyToManyRelation[UUIDM2MRelatedModel]
 
 
 class UUIDFkRelatedModel(Model):
@@ -555,15 +557,15 @@ class Meta:
 
 class Employee(Model):
     name = fields.CharField(max_length=50)
 
-    manager: fields.ForeignKeyNullableRelation["Employee"] = fields.ForeignKeyField(
+    manager: fields.ForeignKeyNullableRelation[Employee] = fields.ForeignKeyField(
         "models.Employee", related_name="team_members", null=True, on_delete=NO_ACTION
     )
-    team_members: fields.ReverseRelation["Employee"]
+    team_members: fields.ReverseRelation[Employee]
 
-    talks_to: fields.ManyToManyRelation["Employee"] = fields.ManyToManyField(
+    talks_to: fields.ManyToManyRelation[Employee] = fields.ManyToManyField(
         "models.Employee", related_name="gets_talked_to", on_delete=NO_ACTION
     )
-    gets_talked_to: fields.ManyToManyRelation["Employee"]
+    gets_talked_to: fields.ManyToManyRelation[Employee]
 
     def __str__(self):
         return self.name
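The `Employee` hunk drops the quotes from self-referential relations, which only works because of the `from __future__ import annotations` added at the top of this file. As a reminder of how the two ends of that FK behave (usage sketch):

    root = await Employee.create(name="Root")
    worker = await Employee.create(name="Worker", manager=root)
    await root.fetch_related("team_members")  # reverse side of the same FK
    assert worker in root.team_members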
@@ -650,16 +652,16 @@ class StraightFields(Model):
     blip = fields.CharField(max_length=50, default="BLIP")
     nullable = fields.CharField(max_length=50, null=True)
 
-    fk: fields.ForeignKeyNullableRelation["StraightFields"] = fields.ForeignKeyField(
+    fk: fields.ForeignKeyNullableRelation[StraightFields] = fields.ForeignKeyField(
         "models.StraightFields",
         related_name="fkrev",
         null=True,
         description="Tree!",
         on_delete=NO_ACTION,
     )
-    fkrev: fields.ReverseRelation["StraightFields"]
+    fkrev: fields.ReverseRelation[StraightFields]
 
-    o2o: fields.OneToOneNullableRelation["StraightFields"] = fields.OneToOneField(
+    o2o: fields.OneToOneNullableRelation[StraightFields] = fields.OneToOneField(
         "models.StraightFields",
         related_name="o2o_rev",
         null=True,
@@ -668,13 +670,13 @@ class StraightFields(Model):
     )
     o2o_rev: fields.Field
 
-    rel_to: fields.ManyToManyRelation["StraightFields"] = fields.ManyToManyField(
+    rel_to: fields.ManyToManyRelation[StraightFields] = fields.ManyToManyField(
         "models.StraightFields",
         related_name="rel_from",
         description="M2M to myself",
         on_delete=fields.NO_ACTION,
     )
-    rel_from: fields.ManyToManyRelation["StraightFields"]
+    rel_from: fields.ManyToManyRelation[StraightFields]
 
     class Meta:
         unique_together = [["chars", "blip"]]
@@ -698,7 +700,7 @@ class SourceFields(Model):
     blip = fields.CharField(max_length=50, default="BLIP", source_field="da_blip")
     nullable = fields.CharField(max_length=50, null=True, source_field="some_nullable")
 
-    fk: fields.ForeignKeyNullableRelation["SourceFields"] = fields.ForeignKeyField(
+    fk: fields.ForeignKeyNullableRelation[SourceFields] = fields.ForeignKeyField(
         "models.SourceFields",
         related_name="fkrev",
         null=True,
@@ -706,9 +708,9 @@ class SourceFields(Model):
         description="Tree!",
         on_delete=NO_ACTION,
     )
-    fkrev: fields.ReverseRelation["SourceFields"]
+    fkrev: fields.ReverseRelation[SourceFields]
 
-    o2o: fields.OneToOneNullableRelation["SourceFields"] = fields.OneToOneField(
+    o2o: fields.OneToOneNullableRelation[SourceFields] = fields.OneToOneField(
         "models.SourceFields",
         related_name="o2o_rev",
         null=True,
@@ -718,7 +720,7 @@ class SourceFields(Model):
     )
     o2o_rev: fields.Field
 
-    rel_to: fields.ManyToManyRelation["SourceFields"] = fields.ManyToManyField(
+    rel_to: fields.ManyToManyRelation[SourceFields] = fields.ManyToManyField(
         "models.SourceFields",
         related_name="rel_from",
         through="sometable_self",
@@ -727,7 +729,7 @@ class SourceFields(Model):
         description="M2M to myself",
         on_delete=fields.NO_ACTION,
     )
-    rel_from: fields.ManyToManyRelation["SourceFields"]
+    rel_from: fields.ManyToManyRelation[SourceFields]
 
     class Meta:
         table = "sometable"
@@ -754,10 +756,10 @@ class EnumFields(Model):
 
 class DoubleFK(Model):
     name = fields.CharField(max_length=50)
-    left: fields.ForeignKeyNullableRelation["DoubleFK"] = fields.ForeignKeyField(
+    left: fields.ForeignKeyNullableRelation[DoubleFK] = fields.ForeignKeyField(
         "models.DoubleFK", null=True, related_name="left_rel", on_delete=NO_ACTION
     )
-    right: fields.ForeignKeyNullableRelation["DoubleFK"] = fields.ForeignKeyField(
+    right: fields.ForeignKeyNullableRelation[DoubleFK] = fields.ForeignKeyField(
         "models.DoubleFK", null=True, related_name="right_rel", on_delete=NO_ACTION
     )
@@ -803,8 +805,8 @@ class School(Model):
     name = fields.TextField()
     id = fields.IntField(unique=True)
 
-    students: fields.ReverseRelation["Student"]
-    principal: fields.ReverseRelation["Principal"]
+    students: fields.ReverseRelation[Student]
+    principal: fields.ReverseRelation[Principal]
 
 
 class Student(Model):
diff --git a/tests/utils/test_describe_model.py b/tests/utils/test_describe_model.py
index ca9700ae6..abc80cd64 100644
--- a/tests/utils/test_describe_model.py
+++ b/tests/utils/test_describe_model.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import json
 import uuid
 from typing import Union
diff --git a/tortoise/__init__.py b/tortoise/__init__.py
index 9cf7e4627..1b3413370 100644
--- a/tortoise/__init__.py
+++ b/tortoise/__init__.py
@@ -34,8 +34,8 @@
 
 
 class Tortoise:
-    apps: dict[str, dict[str, type["Model"]]] = {}
-    table_name_generator: Callable[[type["Model"]], str] | None = None
+    apps: dict[str, dict[str, type[Model]]] = {}
+    table_name_generator: Callable[[type[Model]], str] | None = None
     _inited: bool = False
 
     @classmethod
@@ -53,7 +53,7 @@ def get_connection(cls, connection_name: str) -> BaseDBAsyncClient:
 
     @classmethod
     def describe_model(
-        cls, model: type["Model"], serializable: bool = True
+        cls, model: type[Model], serializable: bool = True
     ) -> dict[str, Any]:  # pragma: nocoverage
         """
         Describes the given list of models or ALL registered models.
@@ -79,7 +79,7 @@ def describe_model(
 
     @classmethod
     def describe_models(
-        cls, models: list[type["Model"]] | None = None, serializable: bool = True
+        cls, models: list[type[Model]] | None = None, serializable: bool = True
     ) -> dict[str, dict[str, Any]]:
         """
         Describes the given list of models or ALL registered models.
@@ -115,7 +115,7 @@ def describe_models(
 
     @classmethod
     def _init_relations(cls) -> None:
-        def get_related_model(related_app_name: str, related_model_name: str) -> type["Model"]:
+        def get_related_model(related_app_name: str, related_model_name: str) -> type[Model]:
             """
             Test, if app and model really exist. Throws a ConfigurationError with a hopefully
             helpful message. If successful, returns the requested model.
@@ -151,7 +151,7 @@ def split_reference(reference: str) -> tuple[str, str]:
                 )
             return items[0], items[1]
 
-        def init_fk_o2o_field(model: type["Model"], field: str, is_o2o=False) -> None:
+        def init_fk_o2o_field(model: type[Model], field: str, is_o2o=False) -> None:
             fk_object = cast(
                 "OneToOneFieldInstance | ForeignKeyFieldInstance", model._meta.fields_map[field]
             )
@@ -284,7 +284,7 @@ def init_fk_o2o_field(model: type[Model], field: str, is_o2o=False) -> None:
             related_model._meta.add_field(backward_relation_name, m2m_relation)
 
     @classmethod
-    def _discover_models(cls, models_path: ModuleType | str, app_label: str) -> list[type["Model"]]:
+    def _discover_models(cls, models_path: ModuleType | str, app_label: str) -> list[type[Model]]:
         if isinstance(models_path, ModuleType):
             module = models_path
         else:
@@ -365,10 +365,10 @@ def _get_config_from_config_file(cls, config_file: str) -> dict:
         if extension in (".yml", ".yaml"):
             import yaml  # pylint: disable=C0415
 
-            with open(config_file, "r") as f:
+            with open(config_file) as f:
                 config = yaml.safe_load(f)
         elif extension == ".json":
-            with open(config_file, "r") as f:
+            with open(config_file) as f:
                 config = json.load(f)
         else:
             raise ConfigurationError(
@@ -399,7 +399,7 @@ async def init(
         use_tz: bool = False,
         timezone: str = "UTC",
         routers: list[str | type] | None = None,
-        table_name_generator: Callable[[type["Model"]], str] | None = None,
+        table_name_generator: Callable[[type[Model]], str] | None = None,
     ) -> None:
         """
         Sets up Tortoise-ORM: loads apps and models, configures database connections but does not
@@ -530,7 +530,7 @@ def star_password(connections_config) -> str:
             str_connection_config = str_connection_config.replace(
                 password,
                 # Show one third of the password at beginning (may be better for debugging purposes)
-                f"{password[0:len(password) // 3]}***",
+                f"{password[0 : len(password) // 3]}***",
             )
         return str_connection_config
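The `star_password` hunk is black's new slice spacing only; the masking arithmetic is untouched. For illustration (password invented):

    password = "hunter2secret"                         # 13 characters
    masked = f"{password[0 : len(password) // 3]}***"  # keep the first third
    assert masked == "hunt***"                         # 13 // 3 == 4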
diff --git a/tortoise/backends/asyncpg/client.py b/tortoise/backends/asyncpg/client.py
index 5128a8f3b..1bd31cba2 100644
--- a/tortoise/backends/asyncpg/client.py
+++ b/tortoise/backends/asyncpg/client.py
@@ -1,6 +1,8 @@
+from __future__ import annotations
+
 import asyncio
 from collections.abc import Callable
-from typing import Any, Optional, TypeVar
+from typing import Any, TypeVar
 
 import asyncpg
 from asyncpg.transaction import Transaction
@@ -33,8 +35,8 @@ class AsyncpgDBClient(BasePostgresClient):
     executor_class = AsyncpgExecutor
     schema_generator = AsyncpgSchemaGenerator
     connection_class = asyncpg.connection.Connection
-    _pool: Optional[asyncpg.Pool]
-    _connection: Optional[asyncpg.connection.Connection] = None
+    _pool: asyncpg.Pool | None
+    _connection: asyncpg.connection.Connection | None = None
 
     async def create_connection(self, with_db: bool) -> None:
         if self.schema:
@@ -99,11 +101,11 @@ async def db_delete(self) -> None:
             pass
         await self.close()
 
-    def _in_transaction(self) -> "TransactionContext":
+    def _in_transaction(self) -> TransactionContext:
         return TransactionContextPooled(TransactionWrapper(self), self._pool_init_lock)
 
     @translate_exceptions
-    async def execute_insert(self, query: str, values: list) -> Optional[asyncpg.Record]:
+    async def execute_insert(self, query: str, values: list) -> asyncpg.Record | None:
         async with self.acquire_connection() as connection:
             self.log.debug("%s: %s", query, values)
             # TODO: Cache prepared statement
@@ -125,9 +127,7 @@ async def execute_many(self, query: str, values: list) -> None:
                 await transaction.commit()
 
     @translate_exceptions
-    async def execute_query(
-        self, query: str, values: Optional[list] = None
-    ) -> tuple[int, list[dict]]:
+    async def execute_query(self, query: str, values: list | None = None) -> tuple[int, list[dict]]:
         async with self.acquire_connection() as connection:
             self.log.debug("%s: %s", query, values)
             if values:
@@ -146,7 +146,7 @@ async def execute_query(
         return len(rows), rows
 
     @translate_exceptions
-    async def execute_query_dict(self, query: str, values: Optional[list] = None) -> list[dict]:
+    async def execute_query_dict(self, query: str, values: list | None = None) -> list[dict]:
         async with self.acquire_connection() as connection:
             self.log.debug("%s: %s", query, values)
             if values:
@@ -165,11 +165,11 @@ def __init__(self, connection: AsyncpgDBClient) -> None:
         self._lock = asyncio.Lock()
         self.log = connection.log
         self.connection_name = connection.connection_name
-        self.transaction: Optional[Transaction] = None
+        self.transaction: Transaction | None = None
         self._finalized = False
         self._parent: AsyncpgDBClient = connection
 
-    def _in_transaction(self) -> "TransactionContext":
+    def _in_transaction(self) -> TransactionContext:
         # since we need to store the transaction object for each transaction block,
         # we need to wrap the connection with its own TransactionWrapper
         return NestedTransactionContext(TransactionWrapper(self))
diff --git a/tortoise/backends/asyncpg/executor.py b/tortoise/backends/asyncpg/executor.py
index 4468f892a..b51a7394b 100644
--- a/tortoise/backends/asyncpg/executor.py
+++ b/tortoise/backends/asyncpg/executor.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from __future__ import annotations
 
 import asyncpg
 
@@ -7,7 +7,5 @@
 
 class AsyncpgExecutor(BasePostgresExecutor):
-    async def _process_insert_result(
-        self, instance: Model, results: Optional[asyncpg.Record]
-    ) -> None:
+    async def _process_insert_result(self, instance: Model, results: asyncpg.Record | None) -> None:
         return await super()._process_insert_result(instance, results)
diff --git a/tortoise/backends/asyncpg/schema_generator.py b/tortoise/backends/asyncpg/schema_generator.py
index fdd7c9a14..05afb6d17 100644
--- a/tortoise/backends/asyncpg/schema_generator.py
+++ b/tortoise/backends/asyncpg/schema_generator.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import TYPE_CHECKING
 
 from tortoise.backends.base_postgres.schema_generator import BasePostgresSchemaGenerator
@@ -7,5 +9,5 @@
 
 class AsyncpgSchemaGenerator(BasePostgresSchemaGenerator):
-    def __init__(self, client: "AsyncpgDBClient") -> None:
+    def __init__(self, client: AsyncpgDBClient) -> None:
         super().__init__(client)
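The `_in_transaction` override on the wrapper is where the nested-transaction design lives: each nested block wraps the connection in its own `TransactionWrapper` so it can carry its own `Transaction` object. From user code this is just nested `in_transaction()` blocks (a sketch; the inner block behaves like a savepoint):

    from tortoise.transactions import in_transaction

    async def transfer() -> None:
        async with in_transaction():      # outer block: real transaction
            async with in_transaction():  # inner block: NestedTransactionContext
                ...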
diff --git a/tortoise/backends/base/client.py b/tortoise/backends/base/client.py
index 98bfafd5e..cd009b970 100644
--- a/tortoise/backends/base/client.py
+++ b/tortoise/backends/base/client.py
@@ -106,7 +106,7 @@ class BaseDBAsyncClient(abc.ABC):
     """
 
     _connection: Any
-    _parent: "BaseDBAsyncClient"
+    _parent: BaseDBAsyncClient
    _pool: Any
     connection_name: str
     query_class: type[Query] = Query
@@ -154,14 +154,14 @@ async def db_delete(self) -> None:
         """
         raise NotImplementedError()  # pragma: nocoverage
 
-    def acquire_connection(self) -> "ConnectionWrapper" | "PoolConnectionWrapper":
+    def acquire_connection(self) -> ConnectionWrapper | PoolConnectionWrapper:
         """
         Acquires a connection from the pool.
         Will return the current context connection if already in a transaction.
         """
         raise NotImplementedError()  # pragma: nocoverage
 
-    def _in_transaction(self) -> "TransactionContext":
+    def _in_transaction(self) -> TransactionContext:
         raise NotImplementedError()  # pragma: nocoverage
 
     async def execute_insert(self, query: str, values: list) -> Any:
diff --git a/tortoise/backends/base/config_generator.py b/tortoise/backends/base/config_generator.py
index 87576a5d4..260f7b422 100644
--- a/tortoise/backends/base/config_generator.py
+++ b/tortoise/backends/base/config_generator.py
@@ -1,8 +1,10 @@
+from __future__ import annotations
+
 import urllib.parse as urlparse
 import uuid
 from collections.abc import Iterable
 from types import ModuleType
-from typing import Any, Optional, Union
+from typing import Any
 
 from tortoise.exceptions import ConfigurationError
 
@@ -135,7 +137,7 @@ def expand_db_url(db_url: str, testing: bool = False) -> dict:
     db_backend = url.scheme
     db = DB_LOOKUP[db_backend]
     if db.get("skip_first_char", True):
-        path: Optional[str] = url.path[1:]
+        path: str | None = url.path[1:]
     else:
         path = url.netloc + url.path
 
@@ -183,8 +185,8 @@ def expand_db_url(db_url: str, testing: bool = False) -> dict:
 
 def generate_config(
     db_url: str,
-    app_modules: dict[str, Iterable[Union[str, ModuleType]]],
-    connection_label: Optional[str] = None,
+    app_modules: dict[str, Iterable[str | ModuleType]],
+    connection_label: str | None = None,
     testing: bool = False,
 ) -> dict:
     _connection_label = connection_label or "default"
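`expand_db_url` and `generate_config` change annotations only; behavior is untouched. For reference, a typical call (a sketch; the exact dict layout depends on the backend entry in `DB_LOOKUP`):

    from tortoise.backends.base.config_generator import generate_config

    config = generate_config(
        "sqlite://:memory:",
        app_modules={"models": ["tests.testmodels"]},
    )
    # config["connections"]["default"] and config["apps"]["models"] are populated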
diff --git a/tortoise/backends/base/executor.py b/tortoise/backends/base/executor.py
index 53082817e..2b7f07bdb 100644
--- a/tortoise/backends/base/executor.py
+++ b/tortoise/backends/base/executor.py
@@ -1,9 +1,11 @@
+from __future__ import annotations
+
 import asyncio
 import datetime
 import decimal
 from collections.abc import Callable, Iterable, Sequence
 from copy import copy
-from typing import TYPE_CHECKING, Any, Optional, Union, cast
+from typing import TYPE_CHECKING, Any, cast
 
 from pypika_tortoise import JoinType, Parameter, Table
 from pypika_tortoise.queries import QueryBuilder
@@ -26,7 +28,7 @@
     from tortoise.queryset import QuerySet
 
 EXECUTOR_CACHE: dict[
-    tuple[str, Optional[str], str],
+    tuple[str, str | None, str],
     tuple[list, str, list, str, str, dict[str, str]],
 ] = {}
@@ -38,16 +40,16 @@ class BaseExecutor:
 
     def __init__(
         self,
-        model: "type[Model]",
-        db: "BaseDBAsyncClient",
-        prefetch_map: "Optional[dict[str, set[Union[str, Prefetch]]]]" = None,
-        prefetch_queries: Optional[dict[str, list[tuple[Optional[str], "QuerySet"]]]] = None,
-        select_related_idx: Optional[
-            list[tuple["type[Model]", int, str, "type[Model]", Iterable[Optional[str]]]]
-        ] = None,
+        model: type[Model],
+        db: BaseDBAsyncClient,
+        prefetch_map: dict[str, set[str | Prefetch]] | None = None,
+        prefetch_queries: dict[str, list[tuple[str | None, QuerySet]]] | None = None,
+        select_related_idx: (
+            list[tuple[type[Model], int, str, type[Model], Iterable[str | None]]] | None
+        ) = None,
     ) -> None:
         self.model = model
-        self.db: "BaseDBAsyncClient" = db
+        self.db: BaseDBAsyncClient = db
         self.prefetch_map = prefetch_map or {}
         self._prefetch_queries = prefetch_queries or {}
         self.select_related_idx = select_related_idx
@@ -98,8 +100,8 @@ async def execute_explain(self, sql: str) -> Any:
     async def execute_select(
         self,
         sql: str,
-        values: Optional[list] = None,
-        custom_fields: Optional[list] = None,
+        values: list | None = None,
+        custom_fields: list | None = None,
     ) -> list:
         _, raw_results = await self.db.execute_query(sql, values)
         instance_list = []
@@ -107,7 +109,7 @@ async def execute_select(
             if self.select_related_idx:
                 _, current_idx, _, _, path = self.select_related_idx[0]
                 row_items = list(dict(row).items())
-                instance: "Model" = self.model._init_from_db(**dict(row_items[:current_idx]))
+                instance: Model = self.model._init_from_db(**dict(row_items[:current_idx]))
                 instances: dict[Any, Any] = {path: instance}
                 for model, index, *__, full_path in self.select_related_idx[1:]:
                     (*path, attr) = full_path
@@ -157,13 +159,13 @@ def _prepare_insert_statement(
             query = query.on_conflict().do_nothing()
         return query
 
-    async def _process_insert_result(self, instance: "Model", results: Any) -> None:
+    async def _process_insert_result(self, instance: Model, results: Any) -> None:
         raise NotImplementedError()  # pragma: nocoverage
 
     def parameter(self, pos: int) -> Parameter:
         return Parameter(idx=pos + 1)
 
-    async def execute_insert(self, instance: "Model") -> None:
+    async def execute_insert(self, instance: Model) -> None:
         if not instance._custom_generated_pk:
             values = [
                 self.model._meta.fields_map[field_name].to_db_value(
@@ -185,8 +187,8 @@ async def execute_insert(self, instance: "Model") -> None:
 
     async def execute_bulk_insert(
         self,
-        instances: "Iterable[Model]",
-        batch_size: Optional[int] = None,
+        instances: Iterable[Model],
+        batch_size: int | None = None,
     ) -> None:
         for instance_chunk in chunk(instances, batch_size):
             values_lists_all = []
@@ -218,8 +220,8 @@ async def execute_bulk_insert(
 
     def get_update_sql(
         self,
-        update_fields: Optional[Iterable[str]],
-        expressions: Optional[dict[str, Expression]],
+        update_fields: Iterable[str] | None,
+        expressions: dict[str, Expression] | None,
     ) -> str:
         """
         Generates the SQL for updating a model depending on provided update_fields.
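`execute_bulk_insert` keeps its chunking; only the annotations move to PEP 604 style. The public entry point that feeds it is `Model.bulk_create`, e.g. (usage sketch, reusing the benchmark model from earlier in this patch):

    # batch_size bounds how many rows go into each INSERT chunk
    await BenchmarkFewFields.bulk_create(objects, batch_size=100)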
@@ -262,7 +264,7 @@ def get_update_sql(
         return sql
 
     async def execute_update(
-        self, instance: "Union[type[Model], Model]", update_fields: Optional[Iterable[str]]
+        self, instance: type[Model] | Model, update_fields: Iterable[str] | None
     ) -> int:
         values = []
         expressions = {}
@@ -284,7 +286,7 @@ async def execute_update(
             await self.db.execute_query(self.get_update_sql(update_fields, expressions), values)
         )[0]
 
-    async def execute_delete(self, instance: "Union[type[Model], Model]") -> int:
+    async def execute_delete(self, instance: type[Model] | Model) -> int:
         return (
             await self.db.execute_query(
                 self.delete_query, [self.model._meta.pk.to_db_value(instance.pk, instance)]
@@ -293,10 +295,10 @@ async def execute_delete(self, instance: "Union[type[Model], Model]") -> int:
 
     async def _prefetch_reverse_relation(
         self,
-        instance_list: "Iterable[Model]",
+        instance_list: Iterable[Model],
         field: str,
-        related_query: tuple[Optional[str], "QuerySet"],
-    ) -> "Iterable[Model]":
+        related_query: tuple[str | None, QuerySet],
+    ) -> Iterable[Model]:
         to_attr, related_query = related_query
         related_objects_for_fetch: dict[str, list] = {}
         related_field: BackwardFKRelation = self.model._meta.fields_map[field]  # type: ignore
@@ -322,7 +324,7 @@ async def _prefetch_reverse_relation(
         related_object_map: dict[str, list] = {}
         for entry in related_object_list:
             object_id = getattr(entry, relation_field)
-            if object_id in related_object_map.keys():
+            if object_id in related_object_map:
                 related_object_map[object_id].append(entry)
             else:
                 related_object_map[object_id] = [entry]
@@ -336,10 +338,10 @@ async def _prefetch_reverse_relation(
 
     async def _prefetch_reverse_o2o_relation(
         self,
-        instance_list: "Iterable[Model]",
+        instance_list: Iterable[Model],
         field: str,
-        related_query: tuple[Optional[str], "QuerySet"],
-    ) -> "Iterable[Model]":
+        related_query: tuple[str | None, QuerySet],
+    ) -> Iterable[Model]:
         to_attr, related_query = related_query
         related_objects_for_fetch: dict[str, list] = {}
         related_field: BackwardOneToOneRelation = self.model._meta.fields_map[field]  # type: ignore
@@ -377,10 +379,10 @@ async def _prefetch_reverse_o2o_relation(
 
     async def _prefetch_m2m_relation(
         self,
-        instance_list: "Iterable[Model]",
+        instance_list: Iterable[Model],
         field: str,
-        related_query: tuple[Optional[str], "QuerySet"],
-    ) -> "Iterable[Model]":
+        related_query: tuple[str | None, QuerySet],
+    ) -> Iterable[Model]:
         to_attr, related_query = related_query
         instance_id_set: set = {
             instance._meta.pk.to_db_value(instance.pk, instance) for instance in instance_list
@@ -437,7 +439,7 @@ async def _prefetch_m2m_relation(
         _, raw_results = await self.db.execute_query(*query.get_parameterized_sql())
         relations: list[tuple[Any, Any]] = []
-        related_object_list: list["Model"] = []
+        related_object_list: list[Model] = []
         model_pk, related_pk = self.model._meta.pk, field_object.related_model._meta.pk
         for e in raw_results:
             pk_values: tuple[Any, Any] = (
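The membership-test cleanup in `_prefetch_reverse_relation` is semantics-preserving: `key in d` already tests a dict's keys, so `.keys()` only allocated a redundant view. The surrounding if/else is also the classic `setdefault` shape, should anyone want to shrink it further:

    related_object_map: dict[str, list] = {}
    for entry in related_object_list:
        related_object_map.setdefault(getattr(entry, relation_field), []).append(entry)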
@@ -464,14 +466,14 @@ async def _prefetch_m2m_relation(
 
     async def _prefetch_direct_relation(
         self,
-        instance_list: "Iterable[Model]",
+        instance_list: Iterable[Model],
         field: str,
-        related_query: tuple[Optional[str], "QuerySet"],
-    ) -> "Iterable[Model]":
+        related_query: tuple[str | None, QuerySet],
+    ) -> Iterable[Model]:
         to_attr, related_queryset = related_query
         related_objects_for_fetch: dict[str, list] = {}
         relation_key_field = f"{field}_id"
-        model_to_field: dict["type[Model]", str] = {}
+        model_to_field: dict[type[Model], str] = {}
         for instance in instance_list:
             if (value := getattr(instance, relation_key_field)) is not None:
                 if (model_cls := instance.__class__) in model_to_field:
@@ -515,7 +517,7 @@ def _make_prefetch_queries(self) -> None:
                 to_attr, related_query = self._prefetch_queries[field_name][0]
             else:
                 relation_field = self.model._meta.fields_map[field_name]
-                related_model: "type[Model]" = relation_field.related_model  # type: ignore
+                related_model: type[Model] = relation_field.related_model  # type: ignore
                 related_query = related_model.all().using_db(self.db)
             related_query.query = copy(
                 related_query.model._meta.basequery
@@ -526,10 +528,10 @@ def _make_prefetch_queries(self) -> None:
 
     async def _do_prefetch(
         self,
-        instance_id_list: "Iterable[Model]",
+        instance_id_list: Iterable[Model],
         field: str,
-        related_query: tuple[Optional[str], "QuerySet"],
-    ) -> "Iterable[Model]":
+        related_query: tuple[str | None, QuerySet],
+    ) -> Iterable[Model]:
         if field in self.model._meta.backward_fk_fields:
             return await self._prefetch_reverse_relation(instance_id_list, field, related_query)
 
@@ -540,9 +542,7 @@ async def _do_prefetch(
             return await self._prefetch_m2m_relation(instance_id_list, field, related_query)
         return await self._prefetch_direct_relation(instance_id_list, field, related_query)
 
-    async def _execute_prefetch_queries(
-        self, instance_list: "Iterable[Model]"
-    ) -> "Iterable[Model]":
+    async def _execute_prefetch_queries(self, instance_list: Iterable[Model]) -> Iterable[Model]:
         if instance_list and (self.prefetch_map or self._prefetch_queries):
             self._make_prefetch_queries()
             prefetch_tasks = []
@@ -553,9 +553,7 @@ async def _execute_prefetch_queries(
 
         return instance_list
 
-    async def fetch_for_list(
-        self, instance_list: "Iterable[Model]", *args: str
-    ) -> "Iterable[Model]":
+    async def fetch_for_list(self, instance_list: Iterable[Model], *args: str) -> Iterable[Model]:
         self.prefetch_map = {}
         for relation in args:
             first_level_field, __, forwarded_prefetch = relation.partition("__")
@@ -574,5 +572,5 @@ async def fetch_for_list(
         return instance_list
 
     @classmethod
-    def get_overridden_filter_func(cls, filter_func: Callable) -> Optional[Callable]:
+    def get_overridden_filter_func(cls, filter_func: Callable) -> Callable | None:
         return cls.FILTER_FUNC_OVERRIDE.get(filter_func)
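Every prefetch helper above unpacks a `(to_attr, queryset)` pair; those pairs originate from `Prefetch` objects in user code. A usage sketch tying the two together (assuming `to_attr` support as unpacked here):

    from tortoise.query_utils import Prefetch

    tournaments = await Tournament.all().prefetch_related(
        Prefetch("events", queryset=Event.filter(name="First"), to_attr="first_events")
    )
    # the filtered events land on .first_events instead of .events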
diff --git a/tortoise/backends/base/schema_generator.py b/tortoise/backends/base/schema_generator.py
index 47d77163c..cfe76ed69 100644
--- a/tortoise/backends/base/schema_generator.py
+++ b/tortoise/backends/base/schema_generator.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import re
+from collections.abc import Sequence
 from hashlib import sha256
 from typing import TYPE_CHECKING, Any, cast
 
@@ -39,7 +40,7 @@ class BaseSchemaGenerator:
         "){extra}{comment};"
     )
 
-    def __init__(self, client: "BaseDBAsyncClient") -> None:
+    def __init__(self, client: BaseDBAsyncClient) -> None:
         self.client = client
 
     def _create_string(
@@ -143,38 +144,28 @@ def _make_hash(*args: str, length: int) -> str:
         # Hash a set of string values and get a digest of the given length.
         return sha256(";".join(args).encode("utf-8")).hexdigest()[:length]
 
-    def _generate_index_name(
-        self, prefix: str, model: "type[Model] | str", field_names: list[str]
+    def _get_index_name(
+        self, prefix: str, model: type[Model] | str, field_names: Sequence[str]
     ) -> str:
         # NOTE: for compatibility, index name should not be longer than 30
         # characters (Oracle limit).
         # That's why we slice some of the strings here.
         table_name = model if isinstance(model, str) else model._meta.db_table
-        index_name = "{}_{}_{}_{}".format(
-            prefix,
-            table_name[:11],
-            field_names[0][:7],
-            self._make_hash(table_name, *field_names, length=6),
-        )
-        return index_name
+        table = table_name[:11]
+        field = field_names[0][:7]
+        hashed = self._make_hash(table_name, *field_names, length=6)
+        return f"{prefix}_{table}_{field}_{hashed}"
 
-    def _generate_fk_name(
-        self, from_table: str, from_field: str, to_table: str, to_field: str
-    ) -> str:
-        # NOTE: for compatibility, index name should not be longer than 30
-        # characters (Oracle limit).
+    def _get_fk_name(self, from_table: str, from_field: str, to_table: str, to_field: str) -> str:
+        # NOTE: for compatibility, index name should not be longer than 30 characters (Oracle limit).
         # That's why we slice some of the strings here.
-        index_name = "fk_{f}_{t}_{h}".format(
-            f=from_table[:8],
-            t=to_table[:8],
-            h=self._make_hash(from_table, from_field, to_table, to_field, length=8),
-        )
-        return index_name
+        hashed = self._make_hash(from_table, from_field, to_table, to_field, length=8)
+        return f"fk_{from_table[:8]}_{to_table[:8]}_{hashed}"
 
     def _get_index_sql(
         self,
-        model: "type[Model]",
-        field_names: list[str],
+        model: type[Model],
+        field_names: Sequence[str],
         safe: bool,
         index_name: str | None = None,
         index_type: str | None = None,
@@ -182,15 +173,17 @@ def _get_index_sql(
     ) -> str:
         return self.INDEX_CREATE_TEMPLATE.format(
             exists="IF NOT EXISTS " if safe else "",
-            index_name=index_name or self._generate_index_name("idx", model, field_names),
+            index_name=index_name or self._get_index_name("idx", model, field_names),
             index_type=f"{index_type} " if index_type else "",
             table_name=model._meta.db_table,
             fields=", ".join([self.quote(f) for f in field_names]),
             extra=f"{extra}" if extra else "",
         )
 
-    def _get_unique_index_sql(self, exists: str, table_name: str, field_names: list[str]) -> str:
-        index_name = self._generate_index_name("uidx", table_name, field_names)
+    def _get_unique_index_sql(
+        self, exists: str, table_name: str, field_names: Sequence[str]
+    ) -> str:
+        index_name = self._get_index_name("uidx", table_name, field_names)
         return self.UNIQUE_INDEX_CREATE_TEMPLATE.format(
             exists=exists,
             index_name=index_name,
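The f-string form of `_get_index_name` makes the 30-character budget easy to audit: prefix, three underscores, and 11 + 7 + 6 sliced characters. With the 3-character `idx` prefix that is at most 30; note the 4-character `uidx` prefix can reach 31, exactly as the old `format` call could. A worked example (digest invented):

    # prefix + "_" + table[:11] + "_" + field[:7] + "_" + 6-char hash
    name = f"idx_{'tournament'[:11]}_{'created'[:7]}_a1b2c3"
    assert name == "idx_tournament_created_a1b2c3" and len(name) <= 30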
@@ -200,200 +193,122 @@ def _get_unique_index_sql(self, exists: str, table_name: str, field_names: list[
             extra="",
         )
 
-    def _get_unique_constraint_sql(self, model: "type[Model]", field_names: list[str]) -> str:
+    def _get_unique_constraint_sql(self, model: type[Model], field_names: Sequence[str]) -> str:
         return self.UNIQUE_CONSTRAINT_CREATE_TEMPLATE.format(
-            index_name=self._generate_index_name("uid", model, field_names),
+            index_name=self._get_index_name("uid", model, field_names),
             fields=", ".join([self.quote(f) for f in field_names]),
         )
 
-    def _get_pk_field_sql_type(self, pk_field: "Field") -> str:
+    def _get_pk_field_sql_type(self, pk_field: Field) -> str:
         if isinstance(pk_field, OneToOneFieldInstance):
             return self._get_pk_field_sql_type(pk_field.related_model._meta.pk)
         if sql_type := pk_field.get_for_dialect(self.DIALECT, "SQL_TYPE"):
             return sql_type
         raise ConfigurationError(f"Can't get SQL type of {pk_field} for {self.DIALECT}")
 
-    def _get_table_sql(self, model: "type[Model]", safe: bool = True) -> dict:
-        fields_to_create = []
-        fields_with_index = []
-        m2m_tables_for_create = []
-        references = set()
-        models_to_create: "list[type[Model]]" = []
-
-        self._get_models_to_create(models_to_create)
-        models_tables = [model._meta.db_table for model in models_to_create]
-        for field_name, column_name in model._meta.fields_db_projection.items():
-            field_object = model._meta.fields_map[field_name]
-            comment = (
-                self._column_comment_generator(
-                    table=model._meta.db_table, column=column_name, comment=field_object.description
+    def _get_pk_create_sql(self, field_object: Field, column_name: str, comment: str) -> str:
+        if field_object.pk and field_object.generated:
+            generated_sql = field_object.get_for_dialect(self.DIALECT, "GENERATED_SQL")
+            if generated_sql:  # pragma: nobranch
+                return self.GENERATED_PK_TEMPLATE.format(
+                    field_name=column_name,
+                    generated_sql=generated_sql,
+                    comment=comment,
                 )
-                if field_object.description
-                else ""
+        return ""
+
+    def _get_field_comment(self, field_object: Field, table_name: str, column_name: str) -> str:
+        if desc := field_object.description:
+            return self._column_comment_generator(
+                table=table_name, column=column_name, comment=desc
             )
+        return ""
 
-            default = field_object.default
-            auto_now_add = getattr(field_object, "auto_now_add", False)
-            auto_now = getattr(field_object, "auto_now", False)
-            if default is not None or auto_now or auto_now_add:
-                if callable(default) or isinstance(field_object, (UUIDField, TextField, JSONField)):
-                    default = ""
-                else:
-                    default = field_object.to_db_value(default, model)
-                    try:
-                        default = self._column_default_generator(
-                            model._meta.db_table,
-                            column_name,
-                            self._escape_default_value(default),
-                            auto_now_add,
-                            auto_now,
-                        )
-                    except NotImplementedError:
-                        default = ""
-            else:
-                default = ""
+    def _get_field_sql_and_related_table(
+        self, field_object: Field, table_name: str, column_name: str, default: str, comment: str
+    ) -> tuple[str, str]:
+        nullable = " NOT NULL" if not field_object.null else ""
+        unique = " UNIQUE" if field_object.unique else ""
+        field_type = field_object.get_for_dialect(self.DIALECT, "SQL_TYPE")
 
-            # TODO: PK generation needs to move out of schema generator.
-            if field_object.pk:
-                if field_object.generated:
-                    generated_sql = field_object.get_for_dialect(self.DIALECT, "GENERATED_SQL")
-                    if generated_sql:  # pragma: nobranch
-                        fields_to_create.append(
-                            self.GENERATED_PK_TEMPLATE.format(
-                                field_name=column_name,
-                                generated_sql=generated_sql,
-                                comment=comment,
-                            )
-                        )
-                        continue
-
-            nullable = " NOT NULL" if not field_object.null else ""
-            unique = " UNIQUE" if field_object.unique else ""
-
-            if getattr(field_object, "reference", None):
-                reference = cast("ForeignKeyFieldInstance", field_object.reference)
-                comment = (
-                    self._column_comment_generator(
-                        table=model._meta.db_table,
-                        column=column_name,
-                        comment=reference.description,
-                    )
-                    if reference.description
-                    else ""
-                )
+        field_creation_string, related_table_name = "", ""
+        if getattr(field_object, "reference", None):
+            reference = cast("ForeignKeyFieldInstance", field_object.reference)
+            comment = self._get_field_comment(reference, table_name, column_name)
 
-                to_field_name = reference.to_field_instance.source_field
-                if not to_field_name:
-                    to_field_name = reference.to_field_instance.model_field_name
+            to_field_name = reference.to_field_instance.source_field
+            if not to_field_name:
+                to_field_name = reference.to_field_instance.model_field_name
+            related_table_name = reference.related_model._meta.db_table
+            if reference.db_constraint:
                 field_creation_string = self._create_string(
                     db_column=column_name,
-                    field_type=field_object.get_for_dialect(self.DIALECT, "SQL_TYPE"),
+                    field_type=field_type,
                     nullable=nullable,
                     unique=unique,
                     is_primary_key=field_object.pk,
-                    comment=comment if not reference.db_constraint else "",
+                    comment="",
                     default=default,
-                ) + (
-                    self._create_fk_string(
-                        constraint_name=self._generate_fk_name(
-                            model._meta.db_table,
-                            column_name,
-                            reference.related_model._meta.db_table,
-                            to_field_name,
-                        ),
-                        db_column=column_name,
-                        table=reference.related_model._meta.db_table,
-                        field=to_field_name,
-                        on_delete=reference.on_delete,
-                        comment=comment,
-                    )
-                    if reference.db_constraint
-                    else ""
-                )
-                references.add(reference.related_model._meta.db_table)
-            else:
-                field_creation_string = self._create_string(
+                ) + self._create_fk_string(
+                    constraint_name=self._get_fk_name(
+                        table_name,
+                        column_name,
+                        related_table_name,
+                        to_field_name,
+                    ),
                     db_column=column_name,
-                    field_type=field_object.get_for_dialect(self.DIALECT, "SQL_TYPE"),
-                    nullable=nullable,
-                    unique=unique,
-                    is_primary_key=field_object.pk,
+                    table=related_table_name,
+                    field=to_field_name,
+                    on_delete=reference.on_delete,
                     comment=comment,
-                    default=default,
-                )
-
-            fields_to_create.append(field_creation_string)
-
-            if field_object.index and not field_object.pk:
-                fields_with_index.append(column_name)
-
-        if model._meta.unique_together:
-            for unique_together_list in model._meta.unique_together:
-                unique_together_to_create = []
-
-                for field in unique_together_list:
-                    field_object = model._meta.fields_map[field]
-                    unique_together_to_create.append(field_object.source_field or field)
-
-                fields_to_create.append(
-                    self._get_unique_constraint_sql(model, unique_together_to_create)
                 )
+        if not field_creation_string:
+            field_creation_string = self._create_string(
+                db_column=column_name,
+                field_type=field_type,
+                nullable=nullable,
+                unique=unique,
+                is_primary_key=field_object.pk,
+                comment=comment,
+                default=default,
+            )
+        return field_creation_string, related_table_name
 
-        _indexes = [
-            self._get_index_sql(model, [field_name], safe=safe) for field_name in fields_with_index
-        ]
+    def _get_field_indexes_sqls(
+        self, model: type[Model], field_names: Sequence[str], safe: bool
+    ) -> list[str]:
+        indexes = [self._get_index_sql(model, [field], safe=safe) for field in field_names]
         if model._meta.indexes:
             for index in model._meta.indexes:
                 if isinstance(index, Index):
                     idx_sql = index.get_sql(self, model, safe)
                 else:
-                    fields = []
-                    for field in index:
-                        field_object = model._meta.fields_map[field]
-                        fields.append(field_object.source_field or field)
+                    fields = [
+                        model._meta.fields_map[field].source_field or field for field in index
+                    ]
                     idx_sql = self._get_index_sql(model, fields, safe=safe)
                 if idx_sql:
-                    _indexes.append(idx_sql)
+                    indexes.append(idx_sql)
 
-        field_indexes_sqls = [val for val in list(dict.fromkeys(_indexes)) if val]
-
-        fields_to_create.extend(self._get_inner_statements())
-
-        table_fields_string = "\n    {}\n".format(",\n    ".join(fields_to_create))
-        table_comment = (
-            self._table_comment_generator(
-                table=model._meta.db_table, comment=model._meta.table_description
             )
-            if model._meta.table_description
-            else ""
-        )
-
-        table_create_string = self.TABLE_CREATE_TEMPLATE.format(
-            exists="IF NOT EXISTS " if safe else "",
-            table_name=model._meta.db_table,
-            fields=table_fields_string,
-            comment=table_comment,
-            extra=self._table_generate_extra(table=model._meta.db_table),
-        )
-
-        table_create_string = "\n".join([table_create_string, *field_indexes_sqls])
-
-        table_create_string += self._post_table_hook()
+        return [val for val in list(dict.fromkeys(indexes)) if val]
 
+    def _get_m2m_tables(
+        self, model: type[Model], table_name: str, safe: bool, models_tables: list[str]
+    ) -> list[str]:
+        m2m_tables_for_create = []
         for m2m_field in model._meta.m2m_fields:
             field_object = cast("ManyToManyFieldInstance", model._meta.fields_map[m2m_field])
             if field_object._generated or field_object.through in models_tables:
                 continue
             backward_key, forward_key = field_object.backward_key, field_object.forward_key
-            backward_fk = forward_fk = ""
             if field_object.db_constraint:
                 backward_fk = self._create_fk_string(
                     "",
                     backward_key,
-                    model._meta.db_table,
+                    table_name,
                     model._meta.db_pk_column,
                     field_object.on_delete,
                     "",
@@ -406,6 +321,8 @@ def _get_table_sql(self, model: "type[Model]", safe: bool = True) -> dict:
                     field_object.on_delete,
                     "",
                 )
+            else:
+                backward_fk = forward_fk = ""
             exists = "IF NOT EXISTS " if safe else ""
             table_name = field_object.through
             backward_type = self._get_pk_field_sql_type(model._meta.pk)
@@ -449,28 +366,114 @@ def _get_table_sql(self, model: "type[Model]", safe: bool = True) -> dict:
             lines.insert(-1, indent + unique_index_create_sql)
             m2m_create_string = "\n".join(lines)
             m2m_tables_for_create.append(m2m_create_string)
+        return m2m_tables_for_create
+
+    def _get_field_default(
+        self, field_object: Field, table_name: str, column_name: str, model: type[Model]
+    ) -> str:
+        auto_now_add = getattr(field_object, "auto_now_add", False)
+        auto_now = getattr(field_object, "auto_now", False)
+        default = field_object.default
+        if default is not None or auto_now or auto_now_add:
+            if not callable(default) and not isinstance(
+                field_object, (UUIDField, TextField, JSONField)
+            ):
+                default = field_object.to_db_value(default, model)
+            try:
+                return self._column_default_generator(
+                    table_name,
+                    column_name,
+                    self._escape_default_value(default),
+                    auto_now_add,
+                    auto_now,
+                )
+            except NotImplementedError:
+                pass
+        return ""
+
+    def _get_table_sql(self, model: type[Model], safe: bool = True) -> dict:
+        fields_to_create = []
+        fields_with_index = []
+        references = set()
+        models_to_create: list[type[Model]] = self._get_models_to_create()
+        table_name = model._meta.db_table
+        models_tables = [model._meta.db_table for model in models_to_create]
+        for field_name, column_name in model._meta.fields_db_projection.items():
+            field_object = model._meta.fields_map[field_name]
+            comment = self._get_field_comment(field_object, table_name, column_name)
+            default = self._get_field_default(field_object, table_name, column_name, model)
+
+            # TODO: PK generation needs to move out of schema generator.
+            if create_pk_field := self._get_pk_create_sql(field_object, column_name, comment):
+                fields_to_create.append(create_pk_field)
+                continue
+
+            field_creation_string, related_table_name = self._get_field_sql_and_related_table(
+                field_object, table_name, column_name, default, comment
+            )
+            if related_table_name:
+                references.add(related_table_name)
+            fields_to_create.append(field_creation_string)
+
+            if field_object.index and not field_object.pk:
+                fields_with_index.append(column_name)
+
+        if model._meta.unique_together:
+            for unique_together_list in model._meta.unique_together:
+                unique_together_to_create = [
+                    model._meta.fields_map[field].source_field or field
+                    for field in unique_together_list
+                ]
+                fields_to_create.append(
+                    self._get_unique_constraint_sql(model, unique_together_to_create)
+                )
+
+        field_indexes_sqls = self._get_field_indexes_sqls(model, fields_with_index, safe)
+
+        fields_to_create.extend(self._get_inner_statements())
+
+        table_fields_string = "\n    {}\n".format(",\n    ".join(fields_to_create))
+        table_comment = (
+            self._table_comment_generator(table=table_name, comment=model._meta.table_description)
+            if model._meta.table_description
+            else ""
+        )
+
+        table_create_string = self.TABLE_CREATE_TEMPLATE.format(
+            exists="IF NOT EXISTS " if safe else "",
+            table_name=table_name,
+            fields=table_fields_string,
+            comment=table_comment,
+            extra=self._table_generate_extra(table=table_name),
+        )
+
+        table_create_string = "\n".join([table_create_string, *field_indexes_sqls])
+
+        table_create_string += self._post_table_hook()
+
+        m2m_tables_for_create = self._get_m2m_tables(model, table_name, safe, models_tables)
 
         return {
-            "table": model._meta.db_table,
+            "table": table_name,
             "model": model,
             "table_creation_string": table_create_string,
             "references": references,
             "m2m_tables": m2m_tables_for_create,
         }
 
-    def _get_models_to_create(self, models_to_create: "list[type[Model]]") -> None:
+    def _get_models_to_create(self) -> list[type[Model]]:
         from tortoise import Tortoise
 
+        models_to_create: list[type[Model]] = []
         for app in Tortoise.apps.values():
             for model in app.values():
                 if model._meta.db == self.client:
                     model._check()
                     models_to_create.append(model)
+        return models_to_create
 
     def get_create_schema_sql(self, safe: bool = True) -> str:
-        models_to_create: "list[type[Model]]" = []
-
-        self._get_models_to_create(models_to_create)
+        models_to_create = self._get_models_to_create()
 
         tables_to_create = []
         for model in models_to_create:
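Both `_get_index_name` and `_get_fk_name` keep generated identifiers within the 30-character Oracle ceiling by slicing the raw names and appending a short stable hash, so two long names sharing a prefix still get distinct identifiers. A minimal sketch of the idea; `_make_hash` is not shown in this patch, so the sha256-based helper below is an assumption for illustration:

from hashlib import sha256

def make_hash(*args: str, length: int) -> str:
    # Stable digest over all the name parts, truncated to `length` hex chars.
    return sha256(";".join(args).encode()).hexdigest()[:length]

def get_index_name(prefix: str, table_name: str, field_names: list[str]) -> str:
    # 3 + 1 + 11 + 1 + 7 + 1 + 6 = 30 chars at most for an "idx" prefix.
    table = table_name[:11]
    field = field_names[0][:7]
    hashed = make_hash(table_name, *field_names, length=6)
    return f"{prefix}_{table}_{field}_{hashed}"

assert len(get_index_name("idx", "tournamentinformation", ["name_very_long"])) <= 30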
diff --git a/tortoise/backends/base_postgres/client.py b/tortoise/backends/base_postgres/client.py
index c7ffdbf8f..10ac8435e 100644
--- a/tortoise/backends/base_postgres/client.py
+++ b/tortoise/backends/base_postgres/client.py
@@ -1,9 +1,11 @@
+from __future__ import annotations
+
 import abc
 import asyncio
 from asyncio.events import AbstractEventLoop
 from collections.abc import Callable, Coroutine
 from functools import wraps
-from typing import TYPE_CHECKING, Any, Optional, SupportsInt, TypeVar, Union
+from typing import TYPE_CHECKING, Any, SupportsInt, TypeVar
 
 from pypika_tortoise import PostgreSQLQuery
@@ -45,17 +47,17 @@ class BasePostgresClient(BaseDBAsyncClient, abc.ABC):
     capabilities = Capabilities(
         "postgres", support_update_limit_order_by=False, support_for_posix_regex_queries=True
     )
-    connection_class: "Optional[Union[AsyncConnection, Connection]]" = None
-    loop: Optional[AbstractEventLoop] = None
-    _pool: Optional[Any] = None
-    _connection: Optional[Any] = None
+    connection_class: AsyncConnection | Connection | None = None
+    loop: AbstractEventLoop | None = None
+    _pool: Any | None = None
+    _connection: Any | None = None
 
     def __init__(
         self,
-        user: Optional[str] = None,
-        password: Optional[str] = None,
-        database: Optional[str] = None,
-        host: Optional[str] = None,
+        user: str | None = None,
+        password: str | None = None,
+        database: str | None = None,
+        host: str | None = None,
         port: SupportsInt = 5432,
         **kwargs: Any,
     ) -> None:
@@ -120,15 +122,15 @@ async def db_delete(self) -> None:
         finally:
             await self.close()
 
-    def acquire_connection(self) -> Union[ConnectionWrapper, PoolConnectionWrapper]:
+    def acquire_connection(self) -> ConnectionWrapper | PoolConnectionWrapper:
         return PoolConnectionWrapper(self, self._pool_init_lock)
 
     @abc.abstractmethod
-    def _in_transaction(self) -> "TransactionContext":
+    def _in_transaction(self) -> TransactionContext:
         raise NotImplementedError("_in_transaction is not implemented")
 
     @abc.abstractmethod
-    async def execute_insert(self, query: str, values: list) -> Optional[Any]:
+    async def execute_insert(self, query: str, values: list) -> Any | None:
         raise NotImplementedError("execute_insert is not implemented")
 
     @abc.abstractmethod
@@ -136,13 +138,11 @@ async def execute_many(self, query: str, values: list) -> None:
         raise NotImplementedError("execute_many is not implemented")
 
     @abc.abstractmethod
-    async def execute_query(
-        self, query: str, values: Optional[list] = None
-    ) -> tuple[int, list[dict]]:
+    async def execute_query(self, query: str, values: list | None = None) -> tuple[int, list[dict]]:
        raise NotImplementedError("execute_query is not implemented")
 
     @abc.abstractmethod
-    async def execute_query_dict(self, query: str, values: Optional[list] = None) -> list[dict]:
+    async def execute_query_dict(self, query: str, values: list | None = None) -> list[dict]:
         raise NotImplementedError("execute_query_dict is not implemented")
 
     @translate_exceptions
diff --git a/tortoise/backends/base_postgres/executor.py b/tortoise/backends/base_postgres/executor.py
index e51632e30..b6e958b0e 100644
--- a/tortoise/backends/base_postgres/executor.py
+++ b/tortoise/backends/base_postgres/executor.py
@@ -1,6 +1,8 @@
+from __future__ import annotations
+
 import uuid
 from collections.abc import Sequence
-from typing import Optional, cast
+from typing import cast
 
 from pypika_tortoise.dialects import PostgreSQLQueryBuilder
 from pypika_tortoise.terms import Term
@@ -68,7 +70,7 @@ def _prepare_insert_statement(
             query = query.on_conflict().do_nothing()
         return query
 
-    async def _process_insert_result(self, instance: Model, results: Optional[dict]) -> None:
+    async def _process_insert_result(self, instance: Model, results: dict | None) -> None:
         if results:
             generated_fields = self.model._meta.generated_db_fields
             db_projection = instance._meta.fields_db_projection_reverse
diff --git a/tortoise/backends/base_postgres/schema_generator.py b/tortoise/backends/base_postgres/schema_generator.py
index bf22c80d8..72cbeeaf5 100644
--- a/tortoise/backends/base_postgres/schema_generator.py
+++ b/tortoise/backends/base_postgres/schema_generator.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+from collections.abc import Sequence
 from typing import TYPE_CHECKING, Any
 
 from tortoise.backends.base.schema_generator import BaseSchemaGenerator
@@ -20,7 +21,7 @@ class BasePostgresSchemaGenerator(BaseSchemaGenerator):
     COLUMN_COMMENT_TEMPLATE = 'COMMENT ON COLUMN "{table}"."{column}" IS \'{comment}\';'
     GENERATED_PK_TEMPLATE = '"{field_name}" {generated_sql}'
 
-    def __init__(self, client: "BasePostgresClient") -> None:
+    def __init__(self, client: BasePostgresClient) -> None:
         super().__init__(client)
         self.comments_array: list[str] = []
 
@@ -71,8 +72,8 @@ def _escape_default_value(self, default: Any):
 
     def _get_index_sql(
         self,
-        model: "type[Model]",
-        field_names: list[str],
+        model: type[Model],
+        field_names: Sequence[str],
         safe: bool,
         index_name: str | None = None,
         index_type: str | None = None,
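The recurring pattern across these files is `from __future__ import annotations` (ruff's FA rules) plus PEP 604 unions (the UP rules): with the future import, every annotation is stored as an unevaluated string, so `X | None` is legal in annotation position even on interpreters that predate runtime `|` support. A small self-contained illustration, not tortoise code:

from __future__ import annotations  # PEP 563: annotations become lazy strings

def lookup(alias: str, cache: dict[str, int] | None = None) -> int | None:
    # The `dict[str, int] | None` union is never evaluated at import time;
    # type checkers still read it as usual.
    return (cache or {}).get(alias)

print(lookup.__annotations__["cache"])  # the plain string "dict[str, int] | None"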
diff --git a/tortoise/backends/mssql/client.py b/tortoise/backends/mssql/client.py
index 11fd5763c..8cd60ec43 100644
--- a/tortoise/backends/mssql/client.py
+++ b/tortoise/backends/mssql/client.py
@@ -1,5 +1,7 @@
+from __future__ import annotations
+
 from itertools import count
-from typing import Any, Optional, SupportsInt
+from typing import Any, SupportsInt
 
 from pypika_tortoise.dialects import MSSQLQuery
@@ -40,7 +42,7 @@ def __init__(
         super().__init__(**kwargs)
         self.dsn = f"DRIVER={driver};SERVER={host},{port};UID={user};PWD={password};"
 
-    def _in_transaction(self) -> "TransactionContext":
+    def _in_transaction(self) -> TransactionContext:
         return TransactionContextPooled(TransactionWrapper(self), self._pool_init_lock)
 
     @translate_exceptions
@@ -60,9 +62,9 @@ def _gen_savepoint_name(_c=count()) -> str:
 class TransactionWrapper(ODBCTransactionWrapper, MSSQLClient):
     def __init__(self, connection: ODBCClient) -> None:
         super().__init__(connection)
-        self._savepoint: Optional[str] = None
+        self._savepoint: str | None = None
 
-    def _in_transaction(self) -> "TransactionContext":
+    def _in_transaction(self) -> TransactionContext:
         return NestedTransactionContext(TransactionWrapper(self))
 
     async def begin(self) -> None:
diff --git a/tortoise/backends/mssql/executor.py b/tortoise/backends/mssql/executor.py
index e2b9fbcb3..db17684c6 100644
--- a/tortoise/backends/mssql/executor.py
+++ b/tortoise/backends/mssql/executor.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import Any
 
 from tortoise.backends.odbc.executor import ODBCExecutor
diff --git a/tortoise/backends/mssql/schema_generator.py b/tortoise/backends/mssql/schema_generator.py
index 1b4c8edac..16d09ce16 100644
--- a/tortoise/backends/mssql/schema_generator.py
+++ b/tortoise/backends/mssql/schema_generator.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+from collections.abc import Sequence
 from typing import TYPE_CHECKING, Any
 
 from tortoise.backends.base.schema_generator import BaseSchemaGenerator
@@ -29,7 +30,7 @@ class MSSQLSchemaGenerator(BaseSchemaGenerator):
         "){extra};"
     )
 
-    def __init__(self, client: "MSSQLClient") -> None:
+    def __init__(self, client: MSSQLClient) -> None:
         super().__init__(client)
         self._field_indexes = []  # type: list[str]
         self._foreign_keys = []  # type: list[str]
@@ -63,8 +64,8 @@ def _escape_default_value(self, default: Any):
 
     def _get_index_sql(
         self,
-        model: "type[Model]",
-        field_names: list[str],
+        model: type[Model],
+        field_names: Sequence[str],
         safe: bool,
         index_name: str | None = None,
         index_type: str | None = None,
@@ -74,7 +75,7 @@
             model, field_names, False, index_name=index_name, index_type=index_type, extra=extra
         )
 
-    def _get_table_sql(self, model: "type[Model]", safe: bool = True) -> dict:
+    def _get_table_sql(self, model: type[Model], safe: bool = True) -> dict:
         return super()._get_table_sql(model, False)
 
     def _create_fk_string(
diff --git a/tortoise/backends/mysql/client.py b/tortoise/backends/mysql/client.py
index c6bb07118..b3731a7c5 100644
--- a/tortoise/backends/mysql/client.py
+++ b/tortoise/backends/mysql/client.py
@@ -1,8 +1,10 @@
+from __future__ import annotations
+
 import asyncio
 from collections.abc import Callable, Coroutine
 from functools import wraps
 from itertools import count
-from typing import Any, Optional, SupportsInt, TypeVar, Union
+from typing import Any, SupportsInt, TypeVar
 
 try:
     import asyncmy as mysql
@@ -11,9 +13,9 @@
     from asyncmy.constants import COMMAND
 except ImportError:
     import aiomysql as mysql
+    from pymysql import err as errors
     from pymysql.charset import charset_by_name
     from pymysql.constants import COMMAND
-    from pymysql import err as errors
 
 from pypika_tortoise import MySQLQuery
@@ -101,7 +103,7 @@ def __init__(
         self.pool_maxsize = int(self.extra.pop("maxsize", 5))
 
         self._template: dict = {}
-        self._pool: Optional[mysql.Pool] = None
+        self._pool: mysql.Pool | None = None
         self._connection = None
 
         self._pool_init_lock = asyncio.Lock()
@@ -132,7 +134,7 @@ async def create_connection(self, with_db: bool) -> None:
                 if self.storage_engine.lower() != "innodb":  # pragma: nobranch
                     self.capabilities.__dict__["supports_transactions"] = False
                 hours = timezone.now().utcoffset().seconds / 3600  # type: ignore
-                tz = "{:+d}:{:02d}".format(int(hours), int((hours % 1) * 60))
+                tz = f"{int(hours):+d}:{int((hours % 1) * 60):02d}"
                 await cursor.execute(f"SET time_zone='{tz}';")
             self.log.debug("Created connection %s pool with params: %s", self._pool, self._template)
         except errors.OperationalError:
@@ -167,10 +169,10 @@ async def db_delete(self) -> None:
             pass
         await self.close()
 
-    def acquire_connection(self) -> Union["ConnectionWrapper", "PoolConnectionWrapper"]:
+    def acquire_connection(self) -> ConnectionWrapper | PoolConnectionWrapper:
         return PoolConnectionWrapper(self, self._pool_init_lock)
 
-    def _in_transaction(self) -> "TransactionContext":
+    def _in_transaction(self) -> TransactionContext:
         return TransactionContextPooled(TransactionWrapper(self), self._pool_init_lock)
 
     @translate_exceptions
@@ -199,9 +201,7 @@ async def execute_many(self, query: str, values: list) -> None:
                 await cursor.executemany(query, values)
 
     @translate_exceptions
-    async def execute_query(
-        self, query: str, values: Optional[list] = None
-    ) -> tuple[int, list[dict]]:
+    async def execute_query(self, query: str, values: list | None = None) -> tuple[int, list[dict]]:
         async with self.acquire_connection() as connection:
             self.log.debug("%s: %s", query, values)
             async with connection.cursor() as cursor:
@@ -212,7 +212,7 @@ async def execute_query(
                     return cursor.rowcount, [dict(zip(fields, row)) for row in rows]
                 return cursor.rowcount, []
 
-    async def execute_query_dict(self, query: str, values: Optional[list] = None) -> list[dict]:
+    async def execute_query_dict(self, query: str, values: list | None = None) -> list[dict]:
         return (await self.execute_query(query, values))[1]
 
     @translate_exceptions
@@ -228,13 +228,13 @@ def __init__(self, connection: MySQLClient) -> None:
         self.connection_name = connection.connection_name
         self._connection: mysql.Connection = connection._connection
         self._lock = asyncio.Lock()
-        self._savepoint: Optional[str] = None
+        self._savepoint: str | None = None
         self.log = connection.log
         self._finalized: bool = False
         self.fetch_inserted = connection.fetch_inserted
         self._parent = connection
 
-    def _in_transaction(self) -> "TransactionContext":
+    def _in_transaction(self) -> TransactionContext:
         return NestedTransactionContext(TransactionWrapper(self))
 
     def acquire_connection(self) -> ConnectionWrapper[mysql.Connection]:
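One behaviour-preserving rewrite worth calling out is the timezone line above: the `str.format` call becomes an f-string, with the same format specs kept after the colon. A quick check that the two spellings agree, e.g. for a UTC offset of 5.5 hours:

hours = 5.5
# :+d forces an explicit sign, :02d zero-pads the minutes to two digits.
tz = f"{int(hours):+d}:{int((hours % 1) * 60):02d}"
assert tz == "+5:30" == "{:+d}:{:02d}".format(int(hours), int((hours % 1) * 60))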
diff --git a/tortoise/backends/mysql/schema_generator.py b/tortoise/backends/mysql/schema_generator.py
index a07ba9978..48e2af2f1 100644
--- a/tortoise/backends/mysql/schema_generator.py
+++ b/tortoise/backends/mysql/schema_generator.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+from collections.abc import Sequence
 from typing import TYPE_CHECKING, Any
 
 from tortoise.backends.base.schema_generator import BaseSchemaGenerator
@@ -31,7 +32,7 @@ class MySQLSchemaGenerator(BaseSchemaGenerator):
         "){extra}{comment};"
     )
 
-    def __init__(self, client: "MySQLClient") -> None:
+    def __init__(self, client: MySQLClient) -> None:
         super().__init__(client)
         self._field_indexes = []  # type: list[str]
         self._foreign_keys = []  # type: list[str]
@@ -72,8 +73,8 @@ def _escape_default_value(self, default: Any):
 
     def _get_index_sql(
         self,
-        model: "type[Model]",
-        field_names: list[str],
+        model: type[Model],
+        field_names: Sequence[str],
         safe: bool,
         index_name: str | None = None,
         index_type: str | None = None,
diff --git a/tortoise/backends/odbc/client.py b/tortoise/backends/odbc/client.py
index 5fae1db1b..91a436948 100644
--- a/tortoise/backends/odbc/client.py
+++ b/tortoise/backends/odbc/client.py
@@ -1,8 +1,10 @@
+from __future__ import annotations
+
 import asyncio
 from abc import ABC
 from collections.abc import Callable, Coroutine
 from functools import wraps
-from typing import Any, Optional, TypeVar, Union
+from typing import Any, TypeVar, Union
 
 import asyncodbc
 import pyodbc
@@ -66,10 +68,10 @@ def __init__(
         self.maxsize = self._kwargs.pop("maxsize", 10)
         self.pool_recycle = self._kwargs.pop("pool_recycle", -1)
         self.echo = self._kwargs.pop("echo", False)
-        self.dsn: Optional[str] = None
+        self.dsn: str | None = None
 
         self._template: dict = {}
-        self._pool: Optional[asyncodbc.Pool] = None
+        self._pool: asyncodbc.Pool | None = None
         self._connection = None
 
         self._pool_init_lock = asyncio.Lock()
@@ -132,9 +134,7 @@ async def execute_many(self, query: str, values: list) -> None:
                 await cursor.commit()
 
     @translate_exceptions
-    async def execute_query(
-        self, query: str, values: Optional[list] = None
-    ) -> tuple[int, list[dict]]:
+    async def execute_query(self, query: str, values: list | None = None) -> tuple[int, list[dict]]:
         async with self.acquire_connection() as connection:
             self.log.debug("%s: %s", query, values)
             async with connection.cursor() as cursor:
@@ -153,7 +153,7 @@ async def execute_query(
                     return cursor.rowcount, [dict(zip(fields, row)) for row in rows]
                 return cursor.rowcount, []
 
-    async def execute_query_dict(self, query: str, values: Optional[list] = None) -> list[dict]:
+    async def execute_query_dict(self, query: str, values: list | None = None) -> list[dict]:
         return (await self.execute_query(query, values))[1]
 
     @translate_exceptions
@@ -175,7 +175,7 @@ def __init__(self, connection: ODBCClient) -> None:
         self.fetch_inserted = connection.fetch_inserted
         self._parent = connection
 
-    def _in_transaction(self) -> "TransactionContext":
+    def _in_transaction(self) -> TransactionContext:
         return NestedTransactionContext(self)
 
     def acquire_connection(self) -> ConnWrapperType:
diff --git a/tortoise/backends/oracle/client.py b/tortoise/backends/oracle/client.py
index 3b112fa66..cd3a762d4 100644
--- a/tortoise/backends/oracle/client.py
+++ b/tortoise/backends/oracle/client.py
@@ -1,6 +1,8 @@
+from __future__ import annotations
+
 import datetime
 import functools
-from typing import TYPE_CHECKING, Any, SupportsInt, Union, cast
+from typing import TYPE_CHECKING, Any, SupportsInt, cast
 
 import pyodbc
 import pytz
@@ -56,10 +58,10 @@ def __init__(
             dbq += f"/{self.database}"
         self.dsn = f"DRIVER={driver};DBQ={dbq};UID={user};PWD={password};"
 
-    def _in_transaction(self) -> "TransactionContext":
+    def _in_transaction(self) -> TransactionContext:
         return TransactionContextPooled(TransactionWrapper(self), self._pool_init_lock)
 
-    def acquire_connection(self) -> Union["ConnectionWrapper", "PoolConnectionWrapper"]:
+    def acquire_connection(self) -> ConnectionWrapper | PoolConnectionWrapper:
         return OraclePoolConnectionWrapper(self, self._pool_init_lock)
 
     async def db_create(self) -> None:
@@ -102,7 +104,7 @@ def _timestamp_convert(self, value: bytes) -> datetime.date:
         except ValueError:
             return parse_datetime(value.decode()[:-32]).astimezone(tz=pytz.utc)
 
-    async def __aenter__(self) -> "asyncodbc.Connection":
+    async def __aenter__(self) -> asyncodbc.Connection:
         connection = await super().__aenter__()
         if getattr(self.client, "database", False) and not hasattr(connection, "current_schema"):
             client = cast(OracleClient, self.client)
@@ -114,7 +116,7 @@ async def __aenter__(self) -> "asyncodbc.Connection":
             await connection.add_output_converter(
                 pyodbc.SQL_TYPE_TIMESTAMP, self._timestamp_convert
             )
-            setattr(connection, "current_schema", client.user)
+            connection.current_schema = client.user
         return connection
diff --git a/tortoise/backends/oracle/executor.py b/tortoise/backends/oracle/executor.py
index 07a4b9de1..accc5756a 100644
--- a/tortoise/backends/oracle/executor.py
+++ b/tortoise/backends/oracle/executor.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import TYPE_CHECKING, cast
 
 from tortoise import Model
diff --git a/tortoise/backends/oracle/schema_generator.py b/tortoise/backends/oracle/schema_generator.py
index d3cb54300..fed43ff9d 100644
--- a/tortoise/backends/oracle/schema_generator.py
+++ b/tortoise/backends/oracle/schema_generator.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+from collections.abc import Sequence
 from typing import TYPE_CHECKING, Any
 
 from tortoise.backends.base.schema_generator import BaseSchemaGenerator
@@ -32,7 +33,7 @@ class OracleSchemaGenerator(BaseSchemaGenerator):
         "){extra};"
     )
 
-    def __init__(self, client: "OracleClient") -> None:
+    def __init__(self, client: OracleClient) -> None:
         super().__init__(client)
         self._field_indexes: list[str] = []
         self._foreign_keys: list[str] = []
@@ -89,8 +90,8 @@ def _escape_default_value(self, default: Any):
 
     def _get_index_sql(
         self,
-        model: "type[Model]",
-        field_names: list[str],
+        model: type[Model],
+        field_names: Sequence[str],
         safe: bool,
         index_name: str | None = None,
         index_type: str | None = None,
@@ -100,7 +101,7 @@ def _get_index_sql(
         return super()._get_index_sql(
             model, field_names, False, index_name=index_name, index_type=index_type, extra=extra
         )
 
-    def _get_table_sql(self, model: "type[Model]", safe: bool = True) -> dict:
+    def _get_table_sql(self, model: type[Model], safe: bool = True) -> dict:
         return super()._get_table_sql(model, False)
 
     def _create_fk_string(
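Each `_get_index_sql` override above widens `field_names` from `list[str]` to `Sequence[str]`. The methods only read the names, and `Sequence` being covariant and index-accessible lets callers pass lists, or the tuples that `unique_together` and `Index(fields=...)` carry, without copying or a type error. A sketch of the difference, assuming a trivial quoting helper:

from collections.abc import Sequence

def quote_fields(field_names: Sequence[str]) -> str:
    # Read-only access, so any ordered container of strings is fine.
    return ", ".join(f'"{name}"' for name in field_names)

print(quote_fields(["id", "name"]))  # a list works
print(quote_fields(("id", "name")))  # so does a unique_together-style tuple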
diff --git a/tortoise/backends/psycopg/client.py b/tortoise/backends/psycopg/client.py
index 2a2f59bd1..7acbf956f 100644
--- a/tortoise/backends/psycopg/client.py
+++ b/tortoise/backends/psycopg/client.py
@@ -36,7 +36,7 @@ async def release(self, connection: psycopg.AsyncConnection):
 
 
 class PsycopgSQLQuery(PostgreSQLQuery):
     @classmethod
-    def _builder(cls, **kwargs) -> "PostgreSQLQueryBuilder":
+    def _builder(cls, **kwargs) -> PostgreSQLQueryBuilder:
         return PsycopgSQLQueryBuilder(**kwargs)
diff --git a/tortoise/backends/sqlite/client.py b/tortoise/backends/sqlite/client.py
index 418143198..dc8ac09b5 100644
--- a/tortoise/backends/sqlite/client.py
+++ b/tortoise/backends/sqlite/client.py
@@ -1,10 +1,12 @@
+from __future__ import annotations
+
 import asyncio
 import os
 import sqlite3
 from collections.abc import Callable, Coroutine, Sequence
 from functools import wraps
 from itertools import count
-from typing import Any, Optional, TypeVar, cast
+from typing import Any, TypeVar, cast
 
 import aiosqlite
 from pypika_tortoise import SQLLiteQuery
@@ -71,7 +73,7 @@ def __init__(self, file_path: str, **kwargs: Any) -> None:
         self.pragmas.setdefault("journal_size_limit", 16384)
         self.pragmas.setdefault("foreign_keys", "ON")
 
-        self._connection: Optional[aiosqlite.Connection] = None
+        self._connection: aiosqlite.Connection | None = None
         self._lock = asyncio.Lock()
 
     async def create_connection(self, with_db: bool) -> None:
@@ -119,7 +121,7 @@ async def db_delete(self) -> None:
     def acquire_connection(self) -> ConnectionWrapper:
         return ConnectionWrapper(self._lock, self)
 
-    def _in_transaction(self) -> "TransactionContext":
+    def _in_transaction(self) -> TransactionContext:
         return SqliteTransactionContext(SqliteTransactionWrapper(self), self._lock)
 
     @translate_exceptions
@@ -144,7 +146,7 @@ async def execute_many(self, query: str, values: list[list]) -> None:
 
     @translate_exceptions
     async def execute_query(
-        self, query: str, values: Optional[list] = None
+        self, query: str, values: list | None = None
     ) -> tuple[int, Sequence[dict]]:
         query = query.replace("\x00", "'||CHAR(0)||'")
         async with self.acquire_connection() as connection:
@@ -154,7 +156,7 @@ async def execute_query(
             return (connection.total_changes - start) or len(rows), rows
 
     @translate_exceptions
-    async def execute_query_dict(self, query: str, values: Optional[list] = None) -> list[dict]:
+    async def execute_query_dict(self, query: str, values: list | None = None) -> list[dict]:
         query = query.replace("\x00", "'||CHAR(0)||'")
         async with self.acquire_connection() as connection:
             self.log.debug("%s: %s", query, values)
@@ -215,13 +217,13 @@ def __init__(self, connection: SqliteClient) -> None:
         self.connection_name = connection.connection_name
         self._connection: aiosqlite.Connection = cast(aiosqlite.Connection, connection._connection)
         self._lock = asyncio.Lock()
-        self._savepoint: Optional[str] = None
+        self._savepoint: str | None = None
         self.log = connection.log
         self._finalized = False
         self.fetch_inserted = connection.fetch_inserted
         self._parent = connection
 
-    def _in_transaction(self) -> "TransactionContext":
+    def _in_transaction(self) -> TransactionContext:
         return NestedTransactionContext(SqliteTransactionWrapper(self))
 
     @translate_exceptions
diff --git a/tortoise/backends/sqlite/schema_generator.py b/tortoise/backends/sqlite/schema_generator.py
index 94507e68c..14ded1a9e 100644
--- a/tortoise/backends/sqlite/schema_generator.py
+++ b/tortoise/backends/sqlite/schema_generator.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import Any
 
 from tortoise.backends.base.schema_generator import BaseSchemaGenerator
diff --git a/tortoise/connection.py b/tortoise/connection.py
index 44e166fb7..f7329fa68 100644
--- a/tortoise/connection.py
+++ b/tortoise/connection.py
@@ -1,9 +1,11 @@
+from __future__ import annotations
+
 import asyncio
 import contextvars
 import importlib
 from contextvars import ContextVar
 from copy import copy
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any
 
 from tortoise.backends.base.config_generator import expand_db_url
 from tortoise.exceptions import ConfigurationError
@@ -15,16 +17,16 @@
 
 
 class ConnectionHandler:
-    _conn_storage: ContextVar[dict[str, "BaseDBAsyncClient"]] = contextvars.ContextVar(
+    _conn_storage: ContextVar[dict[str, BaseDBAsyncClient]] = contextvars.ContextVar(
         "_conn_storage", default={}
     )
 
     def __init__(self) -> None:
         """Unified connection management interface."""
-        self._db_config: Optional["DBConfigType"] = None
+        self._db_config: DBConfigType | None = None
         self._create_db: bool = False
 
-    async def _init(self, db_config: "DBConfigType", create_db: bool) -> None:
+    async def _init(self, db_config: DBConfigType, create_db: bool) -> None:
         if self._db_config is None:
             self._db_config = db_config
         else:
@@ -33,7 +35,7 @@ async def _init(self, db_config: "DBConfigType", create_db: bool) -> None:
         await self._init_connections()
 
     @property
-    def db_config(self) -> "DBConfigType":
+    def db_config(self) -> DBConfigType:
         """
         Return the DB config.
 
@@ -52,20 +54,20 @@ def db_config(self) -> "DBConfigType":
             )
         return self._db_config
 
-    def _get_storage(self) -> dict[str, "BaseDBAsyncClient"]:
+    def _get_storage(self) -> dict[str, BaseDBAsyncClient]:
         return self._conn_storage.get()
 
-    def _set_storage(self, new_storage: dict[str, "BaseDBAsyncClient"]) -> contextvars.Token:
+    def _set_storage(self, new_storage: dict[str, BaseDBAsyncClient]) -> contextvars.Token:
         # Should be used only for testing purposes.
         return self._conn_storage.set(new_storage)
 
-    def _copy_storage(self) -> dict[str, "BaseDBAsyncClient"]:
+    def _copy_storage(self) -> dict[str, BaseDBAsyncClient]:
         return copy(self._get_storage())
 
     def _clear_storage(self) -> None:
         self._get_storage().clear()
 
-    def _discover_client_class(self, db_info: dict) -> type["BaseDBAsyncClient"]:
+    def _discover_client_class(self, db_info: dict) -> type[BaseDBAsyncClient]:
         # Let exception bubble up for transparency
         engine_str = db_info.get("engine", "")
         engine_module = importlib.import_module(engine_str)
@@ -80,7 +82,7 @@ def _discover_client_class(self, db_info: dict) -> type["BaseDBAsyncClient"]:
             )
         return client_class
 
-    def _get_db_info(self, conn_alias: str) -> Union[str, dict]:
+    def _get_db_info(self, conn_alias: str) -> str | dict:
         try:
             return self.db_config[conn_alias]
         except KeyError:
@@ -91,21 +93,21 @@ def _get_db_info(self, conn_alias: str) -> Union[str, dict]:
 
     async def _init_connections(self) -> None:
         for alias in self.db_config:
-            connection: "BaseDBAsyncClient" = self.get(alias)
+            connection: BaseDBAsyncClient = self.get(alias)
             if self._create_db:
                 await connection.db_create()
 
-    def _create_connection(self, conn_alias: str) -> "BaseDBAsyncClient":
+    def _create_connection(self, conn_alias: str) -> BaseDBAsyncClient:
         db_info = self._get_db_info(conn_alias)
         if isinstance(db_info, str):
             db_info = expand_db_url(db_info)
         client_class = self._discover_client_class(db_info)
         db_params = db_info["credentials"].copy()
         db_params.update({"connection_name": conn_alias})
-        connection: "BaseDBAsyncClient" = client_class(**db_params)
+        connection: BaseDBAsyncClient = client_class(**db_params)
         return connection
 
-    def get(self, conn_alias: str) -> "BaseDBAsyncClient":
+    def get(self, conn_alias: str) -> BaseDBAsyncClient:
         """
         Return the connection object for the given alias, creating it if needed.
 
@@ -117,7 +119,7 @@ def get(self, conn_alias: str) -> "BaseDBAsyncClient":
 
         :raises ConfigurationError: If the connection alias does not exist.
         """
-        storage: dict[str, "BaseDBAsyncClient"] = self._get_storage()
+        storage: dict[str, BaseDBAsyncClient] = self._get_storage()
         try:
             return storage[conn_alias]
         except KeyError:
@@ -125,7 +127,7 @@ def get(self, conn_alias: str) -> "BaseDBAsyncClient":
             storage[conn_alias] = connection
             return connection
 
-    def set(self, conn_alias: str, conn_obj: "BaseDBAsyncClient") -> contextvars.Token:
+    def set(self, conn_alias: str, conn_obj: BaseDBAsyncClient) -> contextvars.Token:
         """
         Sets the given alias to the provided connection object.
 
@@ -142,7 +144,7 @@ def set(self, conn_alias: str, conn_obj: "BaseDBAsyncClient") -> contextvars.Token:
         storage_copy[conn_alias] = conn_obj
         return self._conn_storage.set(storage_copy)
 
-    def discard(self, conn_alias: str) -> Optional["BaseDBAsyncClient"]:
+    def discard(self, conn_alias: str) -> BaseDBAsyncClient | None:
         """
         Discards the given alias from the storage in the `current context`.
 
@@ -174,7 +176,7 @@ def reset(self, token: contextvars.Token) -> None:
             if alias not in prev_storage:
                 prev_storage[alias] = conn
 
-    def all(self) -> list["BaseDBAsyncClient"]:
+    def all(self) -> list[BaseDBAsyncClient]:
         """Returns a list of connection objects from the storage in the `current context`."""
         # The reason this method iterates over db_config and not over `storage` directly is
         # because: assume that someone calls `discard` with a certain alias, and calls this
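`ConnectionHandler` keeps its alias-to-connection dict in a `ContextVar`, and `set()` copies the mapping before mutating it, so the returned token can later `reset()` the outer context untouched. A minimal sketch of that copy-on-set pattern (illustrative names, not the tortoise API):

import contextvars

_storage: contextvars.ContextVar[dict[str, object]] = contextvars.ContextVar(
    "_storage", default={}
)

def set_conn(alias: str, conn: object) -> contextvars.Token:
    new_storage = dict(_storage.get())  # copy, so other contexts are unaffected
    new_storage[alias] = conn
    return _storage.set(new_storage)

token = set_conn("default", object())
assert "default" in _storage.get()
_storage.reset(token)  # restore the mapping seen before set_conn()
assert "default" not in _storage.get()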
diff --git a/tortoise/contrib/aiohttp/__init__.py b/tortoise/contrib/aiohttp/__init__.py
index 2496f1276..0bd7424df 100644
--- a/tortoise/contrib/aiohttp/__init__.py
+++ b/tortoise/contrib/aiohttp/__init__.py
@@ -1,6 +1,7 @@
+from __future__ import annotations
+
 from collections.abc import Iterable
 from types import ModuleType
-from typing import Optional, Union
 
 from aiohttp import web  # pylint: disable=E0401
 
@@ -10,10 +11,10 @@
 
 def register_tortoise(
     app: web.Application,
-    config: Optional[dict] = None,
-    config_file: Optional[str] = None,
-    db_url: Optional[str] = None,
-    modules: Optional[dict[str, Iterable[Union[str, ModuleType]]]] = None,
+    config: dict | None = None,
+    config_file: str | None = None,
+    db_url: str | None = None,
+    modules: dict[str, Iterable[str | ModuleType]] | None = None,
     generate_schemas: bool = False,
 ) -> None:
     """
diff --git a/tortoise/contrib/blacksheep/__init__.py b/tortoise/contrib/blacksheep/__init__.py
index 52d6e43bc..f92c9b740 100644
--- a/tortoise/contrib/blacksheep/__init__.py
+++ b/tortoise/contrib/blacksheep/__init__.py
@@ -1,6 +1,7 @@
+from __future__ import annotations
+
 from collections.abc import Iterable
 from types import ModuleType
-from typing import Optional, Union
 
 from blacksheep import Request
 from blacksheep.server import Application
@@ -13,10 +14,10 @@
 
 def register_tortoise(
     app: Application,
-    config: Optional[dict] = None,
-    config_file: Optional[str] = None,
-    db_url: Optional[str] = None,
-    modules: Optional[dict[str, Iterable[Union[str, ModuleType]]]] = None,
+    config: dict | None = None,
+    config_file: str | None = None,
+    db_url: str | None = None,
+    modules: dict[str, Iterable[str | ModuleType]] | None = None,
     generate_schemas: bool = False,
     add_exception_handlers: bool = False,
 ) -> None:
diff --git a/tortoise/contrib/fastapi/__init__.py b/tortoise/contrib/fastapi/__init__.py
index e9890a8b3..391fc5649 100644
--- a/tortoise/contrib/fastapi/__init__.py
+++ b/tortoise/contrib/fastapi/__init__.py
@@ -24,10 +24,10 @@ def tortoise_exception_handlers() -> dict:
     from fastapi.responses import JSONResponse
 
-    async def doesnotexist_exception_handler(request: "Request", exc: DoesNotExist):
+    async def doesnotexist_exception_handler(request: Request, exc: DoesNotExist):
         return JSONResponse(status_code=404, content={"detail": str(exc)})
 
-    async def integrityerror_exception_handler(request: "Request", exc: IntegrityError):
+    async def integrityerror_exception_handler(request: Request, exc: IntegrityError):
         return JSONResponse(
             status_code=422,
             content={"detail": [{"loc": [], "msg": str(exc), "type": "IntegrityError"}]},
@@ -188,7 +188,7 @@ async def _self() -> Self:
 
 
 def register_tortoise(
-    app: "FastAPI",
+    app: FastAPI,
     config: dict | None = None,
     config_file: str | None = None,
     db_url: str | None = None,
@@ -267,7 +267,7 @@ def register_tortoise(
     # So people can upgrade tortoise-orm in running project without changing any code
 
     @asynccontextmanager
-    async def orm_lifespan(app_instance: "FastAPI"):
+    async def orm_lifespan(app_instance: FastAPI):
         async with RegisterTortoise(
             app_instance,
             config,
diff --git a/tortoise/contrib/mysql/fields.py b/tortoise/contrib/mysql/fields.py
index 30717bd10..92b9ae1e4 100644
--- a/tortoise/contrib/mysql/fields.py
+++ b/tortoise/contrib/mysql/fields.py
@@ -1,11 +1,13 @@
-from typing import TYPE_CHECKING, Any, Optional, Union
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
 from uuid import UUID, uuid4
 
 from tortoise.fields import Field
 from tortoise.fields import UUIDField as UUIDFieldBase
 
 if TYPE_CHECKING:  # pragma: nocoverage
-    from tortoise.models import Model  # noqa pylint: disable=unused-import
+    from tortoise.models import Model
 
 
 class GeometryField(Field):
@@ -38,7 +40,7 @@ def __init__(self, binary_compression: bool = True, **kwargs: Any) -> None:
             self.SQL_TYPE = "BINARY(16)"
         self._binary_compression = binary_compression
 
-    def to_db_value(self, value: Any, instance: "Union[type[Model], Model]") -> Optional[Union[str, bytes]]:  # type: ignore
+    def to_db_value(self, value: Any, instance: type[Model] | Model) -> str | bytes | None:  # type: ignore
         # Make sure that value is a UUIDv4
         # If not, raise an error
         # This is to prevent UUIDv1 or any other version from being stored in the database
@@ -48,7 +50,7 @@ def to_db_value(self, value: Any, instance: "Union[type[Model], Model]") -> Optional[Union[str, bytes]]:
             return value.bytes
         return value and str(value)
 
-    def to_python_value(self, value: Any) -> Optional[UUID]:
+    def to_python_value(self, value: Any) -> UUID | None:
         if value is None or isinstance(value, UUID):
             return value
         elif self._binary_compression and isinstance(value, bytes):
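The MySQL `UUIDField` above persists a UUID either as its 16 raw bytes (`BINARY(16)`) or as the 36-character text form, and `to_python_value` reverses both. The round-trip it leans on is plain stdlib behaviour:

from uuid import UUID, uuid4

value = uuid4()
packed = value.bytes                # 16 bytes, what a BINARY(16) column stores
assert UUID(bytes=packed) == value  # the to_python_value direction
assert len(str(value)) == 36        # the uncompressed CHAR(36) form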
diff --git a/tortoise/contrib/mysql/indexes.py b/tortoise/contrib/mysql/indexes.py
index 16627cceb..536a1d7a6 100644
--- a/tortoise/contrib/mysql/indexes.py
+++ b/tortoise/contrib/mysql/indexes.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from __future__ import annotations
 
 from pypika_tortoise.terms import Term
 
@@ -11,9 +11,9 @@ class FullTextIndex(Index):
     def __init__(
         self,
         *expressions: Term,
-        fields: Optional[tuple[str, ...]] = None,
-        name: Optional[str] = None,
-        parser_name: Optional[str] = None,
+        fields: tuple[str, ...] | None = None,
+        name: str | None = None,
+        parser_name: str | None = None,
     ) -> None:
         super().__init__(*expressions, fields=fields, name=name)
         if parser_name:
diff --git a/tortoise/contrib/mysql/search.py b/tortoise/contrib/mysql/search.py
index 888aa3b60..6d680d9a1 100644
--- a/tortoise/contrib/mysql/search.py
+++ b/tortoise/contrib/mysql/search.py
@@ -1,5 +1,7 @@
+from __future__ import annotations
+
 from enum import Enum
-from typing import Any, Optional
+from typing import Any
 
 from pypika_tortoise import SqlContext
 from pypika_tortoise.enums import Comparator
@@ -25,7 +27,7 @@ def __init__(self, *columns: Term) -> None:
 
 
 class Against(PypikaFunction):
-    def __init__(self, expr: Term, mode: Optional[Mode] = None) -> None:
+    def __init__(self, expr: Term, mode: Mode | None = None) -> None:
         super().__init__("AGAINST", expr)
         self.mode = mode
 
@@ -40,5 +42,5 @@ class SearchCriterion(BasicCriterion):
     Only support for CharField, TextField with full search indexes.
     """
 
-    def __init__(self, *columns: Term, expr: Term, mode: Optional[Mode] = None) -> None:
+    def __init__(self, *columns: Term, expr: Term, mode: Mode | None = None) -> None:
         super().__init__(Comp.search, Match(*columns), Against(expr, mode))
diff --git a/tortoise/contrib/postgres/fields.py b/tortoise/contrib/postgres/fields.py
index 463b3c7b8..4ff8549a2 100644
--- a/tortoise/contrib/postgres/fields.py
+++ b/tortoise/contrib/postgres/fields.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import Any
 
 from tortoise.fields import Field
diff --git a/tortoise/contrib/postgres/regex.py b/tortoise/contrib/postgres/regex.py
index 18fd9870f..28dd1bf80 100644
--- a/tortoise/contrib/postgres/regex.py
+++ b/tortoise/contrib/postgres/regex.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import enum
 from typing import cast
diff --git a/tortoise/contrib/postgres/search.py b/tortoise/contrib/postgres/search.py
index ac41d25a6..35f5b7ed3 100644
--- a/tortoise/contrib/postgres/search.py
+++ b/tortoise/contrib/postgres/search.py
@@ -1,4 +1,4 @@
-from typing import Union
+from __future__ import annotations
 
 from pypika_tortoise.enums import Comparator
 from pypika_tortoise.terms import BasicCriterion, Function, Term
@@ -11,7 +11,7 @@ class Comp(Comparator):
 
 
 class SearchCriterion(BasicCriterion):
-    def __init__(self, field: Term, expr: Union[Term, Function]) -> None:
+    def __init__(self, field: Term, expr: Term | Function) -> None:
         if isinstance(expr, Function):
             _expr = expr
         else:
diff --git a/tortoise/contrib/pydantic/base.py b/tortoise/contrib/pydantic/base.py
index cecf8255a..80c985731 100644
--- a/tortoise/contrib/pydantic/base.py
+++ b/tortoise/contrib/pydantic/base.py
@@ -1,5 +1,7 @@
+from __future__ import annotations
+
 import sys
-from typing import TYPE_CHECKING, List, Union
+from typing import TYPE_CHECKING, Union
 
 import pydantic
 from pydantic import BaseModel, ConfigDict, RootModel
@@ -16,9 +18,7 @@
     from tortoise.queryset import QuerySet, QuerySetSingle
 
 
-def _get_fetch_fields(
-    pydantic_class: "type[PydanticModel]", model_class: "type[Model]"
-) -> list[str]:
+def _get_fetch_fields(pydantic_class: type[PydanticModel], model_class: type[Model]) -> list[str]:
     """
     Recursively collect fields needed to fetch
     :param pydantic_class: The pydantic model class
@@ -28,7 +28,7 @@
     fetch_fields = []
     for field_name, field_type in pydantic_class.__annotations__.items():
         origin = getattr(field_type, "__origin__", None)
-        if origin in (list, List, Union):
+        if origin in (list, Union):
             field_type = field_type.__args__[0]
 
         # noinspection PyProtectedMember
@@ -65,7 +65,7 @@ def _tortoise_convert(cls, value):  # pylint: disable=E0213
         return value
 
     @classmethod
-    async def from_tortoise_orm(cls, obj: "Model") -> Self:
+    async def from_tortoise_orm(cls, obj: Model) -> Self:
         """
         Returns a serializable pydantic model instance built from the provided model instance.
 
@@ -92,7 +92,7 @@ async def from_tortoise_orm(cls, obj: "Model") -> Self:
         return cls.model_validate(obj)
 
     @classmethod
-    async def from_queryset_single(cls, queryset: "QuerySetSingle") -> Self:
+    async def from_queryset_single(cls, queryset: QuerySetSingle) -> Self:
         """
         Returns a serializable pydantic model instance for a single model
         from the provided queryset.
@@ -105,7 +105,7 @@ async def from_queryset_single(cls, queryset: "QuerySetSingle") -> Self:
         return cls.model_validate(await queryset.prefetch_related(*fetch_fields))
 
     @classmethod
-    async def from_queryset(cls, queryset: "QuerySet") -> list[Self]:
+    async def from_queryset(cls, queryset: QuerySet) -> list[Self]:
         """
         Returns a serializable pydantic model instance that contains a list of models,
         from the provided queryset.
@@ -127,7 +127,7 @@ class PydanticListModel(RootModel):
     """
 
     @classmethod
-    async def from_queryset(cls, queryset: "QuerySet") -> Self:
+    async def from_queryset(cls, queryset: QuerySet) -> Self:
         """
         Returns a serializable pydantic model instance that contains a list of models,
         from the provided queryset.
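The `origin in (list, Union)` test in `_get_fetch_fields` works because `typing.List[x]`, builtin `list[x]` and `Optional[x]` all normalise to one of those two runtime origins, so dropping the `typing.List` import loses nothing. A quick runnable check:

from typing import List, Optional, Union

assert List[int].__origin__ is list       # typing.List normalises to builtin list
assert list[int].__origin__ is list
assert Optional[str].__origin__ is Union  # Optional is sugar for Union[..., None]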
diff --git a/tortoise/contrib/pydantic/creator.py b/tortoise/contrib/pydantic/creator.py
index e4a98cba6..bef5ebd67 100644
--- a/tortoise/contrib/pydantic/creator.py
+++ b/tortoise/contrib/pydantic/creator.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import inspect
 from base64 import b32encode
 from collections.abc import MutableMapping
@@ -52,7 +54,7 @@ def _cleandoc(obj: Any) -> str:
 
 
 def _pydantic_recursion_protector(
-    cls: "type[Model]",
+    cls: type[Model],
     *,
     stack: tuple,
     exclude: tuple[str, ...] = (),
@@ -60,8 +62,8 @@ def _pydantic_recursion_protector(
     computed: tuple[str, ...] = (),
     name=None,
     allow_cycles: bool = False,
-    sort_alphabetically: Optional[bool] = None,
-) -> Optional[type[PydanticModel]]:
+    sort_alphabetically: bool | None = None,
+) -> type[PydanticModel] | None:
     """
     It is an inner function to protect pydantic model creator against cyclic recursion
     """
@@ -99,8 +101,8 @@ def _pydantic_recursion_protector(
 
 
 class FieldMap(MutableMapping[str, Union[Field, ComputedFieldDescription]]):
-    def __init__(self, meta: PydanticMetaData, pk_field: Optional[Field] = None):
-        self._field_map: dict[str, Union[Field, ComputedFieldDescription]] = {}
+    def __init__(self, meta: PydanticMetaData, pk_field: Field | None = None):
+        self._field_map: dict[str, Field | ComputedFieldDescription] = {}
         self.pk_raw_field = pk_field.model_field_name if pk_field is not None else ""
         if pk_field:
             self.pk_raw_field = pk_field.model_field_name
@@ -125,7 +127,7 @@ def __setitem__(self, __key, __value):
     def sort_alphabetically(self) -> None:
         self._field_map = {k: self._field_map[k] for k in sorted(self._field_map)}
 
-    def sort_definition_order(self, cls: "type[Model]", computed: tuple[str, ...]) -> None:
+    def sort_definition_order(self, cls: type[Model], computed: tuple[str, ...]) -> None:
         self._field_map = {
             k: self._field_map[k]
             for k in tuple(cls._meta.fields_map.keys()) + computed
@@ -149,7 +151,7 @@ def field_map_update(self, fields: list[Field], meta: PydanticMetaData) -> None:
                 self.pop(raw_field, None)
             self[name] = field
 
-    def computed_field_map_update(self, computed: tuple[str, ...], cls: "type[Model]"):
+    def computed_field_map_update(self, computed: tuple[str, ...], cls: type[Model]):
         self._field_map.update(
             {
                 k: ComputedFieldDescription(
@@ -163,14 +165,14 @@ def computed_field_map_update(self, computed: tuple[str, ...], cls: "type[Model]"):
 
 
 def pydantic_queryset_creator(
-    cls: "type[Model]",
+    cls: type[Model],
    *,
    name=None,
     exclude: tuple[str, ...] = (),
     include: tuple[str, ...] = (),
     computed: tuple[str, ...] = (),
-    allow_cycles: Optional[bool] = None,
-    sort_alphabetically: Optional[bool] = None,
+    allow_cycles: bool | None = None,
+    sort_alphabetically: bool | None = None,
 ) -> type[PydanticListModel]:
     """
     Function to build a `Pydantic Model `__ list off Tortoise Model.
@@ -224,24 +226,24 @@ def pydantic_queryset_creator(
 class PydanticModelCreator:
     def __init__(
         self,
-        cls: "type[Model]",
-        name: Optional[str] = None,
-        exclude: Optional[tuple[str, ...]] = None,
-        include: Optional[tuple[str, ...]] = None,
-        computed: Optional[tuple[str, ...]] = None,
-        optional: Optional[tuple[str, ...]] = None,
-        allow_cycles: Optional[bool] = None,
-        sort_alphabetically: Optional[bool] = None,
+        cls: type[Model],
+        name: str | None = None,
+        exclude: tuple[str, ...] | None = None,
+        include: tuple[str, ...] | None = None,
+        computed: tuple[str, ...] | None = None,
+        optional: tuple[str, ...] | None = None,
+        allow_cycles: bool | None = None,
+        sort_alphabetically: bool | None = None,
         exclude_readonly: bool = False,
-        meta_override: Optional[type] = None,
-        model_config: Optional[ConfigDict] = None,
-        validators: Optional[dict[str, Any]] = None,
+        meta_override: type | None = None,
+        model_config: ConfigDict | None = None,
+        validators: dict[str, Any] | None = None,
         module: str = __name__,
         _stack: tuple = (),
         _as_submodel: bool = False,
     ) -> None:
-        self._cls: "type[Model]" = cls
-        self._stack: tuple[tuple["type[Model]", str, int], ...] = (
+        self._cls: type[Model] = cls
+        self._stack: tuple[tuple[type[Model], str, int], ...] = (
             _stack  # ((type[Model], field_name, max_recursion),)
         )
         self._is_default: bool = (
@@ -401,13 +403,13 @@ def create_pydantic_model(self) -> type[PydanticModel]:
     def _process_field(
         self,
         field_name: str,
-        field: Union[Field, ComputedFieldDescription],
+        field: Field | ComputedFieldDescription,
     ) -> None:
         json_schema_extra: dict[str, Any] = {}
         fconfig: dict[str, Any] = {
             "json_schema_extra": json_schema_extra,
         }
-        field_property: Optional[Any] = None
+        field_property: Any | None = None
         is_to_one_relation: bool = False
         if isinstance(field, Field):
             field_property, is_to_one_relation = self._process_normal_field(
@@ -448,7 +450,7 @@ def _process_normal_field(
         field: Field,
         json_schema_extra: dict[str, Any],
         fconfig: dict[str, Any],
-    ) -> tuple[Optional[Any], bool]:
+    ) -> tuple[Any | None, bool]:
         if isinstance(
             field, (ForeignKeyFieldInstance, OneToOneFieldInstance, BackwardOneToOneRelation)
         ):
@@ -462,11 +464,11 @@ def _process_normal_field(
     def _process_single_field_relation(
         self,
         field_name: str,
-        field: Union[ForeignKeyFieldInstance, OneToOneFieldInstance, BackwardOneToOneRelation],
+        field: ForeignKeyFieldInstance | OneToOneFieldInstance | BackwardOneToOneRelation,
         json_schema_extra: dict[str, Any],
-    ) -> Optional[type[PydanticModel]]:
+    ) -> type[PydanticModel] | None:
         python_type = getattr(field, "related_model", field.field_type)
-        model: Optional[type[PydanticModel]] = self._get_submodel(python_type, field_name)
+        model: type[PydanticModel] | None = self._get_submodel(python_type, field_name)
         if model:
             self._relational_fields_index.append((field_name, model.__name__))
             if field.null:
@@ -480,8 +482,8 @@ def _process_single_field_relation(
     def _process_many_field_relation(
         self,
         field_name: str,
-        field: Union[BackwardFKRelation, ManyToManyFieldInstance],
-    ) -> Optional[type[list[type[PydanticModel]]]]:
+        field: BackwardFKRelation | ManyToManyFieldInstance,
+    ) -> type[list[type[PydanticModel]]] | None:
         python_type = field.related_model
         model = self._get_submodel(python_type, field_name)
         if model:
@@ -495,14 +497,14 @@ def _process_data_field(
         field: Field,
         json_schema_extra: dict[str, Any],
         fconfig: dict[str, Any],
-    ) -> Optional[Any]:
+    ) -> Any | None:
         annotation = self._annotations.get(field_name, None)
         constraints = copy(field.constraints)
         if "readOnly" in constraints:
             json_schema_extra["readOnly"] = constraints["readOnly"]
             del constraints["readOnly"]
         fconfig.update(constraints)
-        python_type: Union[type[Enum], type[IntEnum], type]
+        python_type: type[Enum] | type[IntEnum] | type
         if isinstance(field, (IntEnumFieldInstance, CharEnumFieldInstance)):
             python_type = field.enum_type
         else:
@@ -521,7 +523,7 @@ def _process_data_field(
     def _process_computed_field(
         self,
         field: ComputedFieldDescription,
-    ) -> Optional[Any]:
+    ) -> Any | None:
         func = field.function
         annotation = get_annotations(self._cls, func).get("return", None)
         comment = _cleandoc(func)
@@ -532,8 +534,8 @@ def _process_computed_field(
         return None
 
     def _get_submodel(
-        self, _model: Optional["type[Model]"], field_name: str
-    ) -> Optional[type[PydanticModel]]:
+        self, _model: type[Model] | None, field_name: str
+    ) -> type[PydanticModel] | None:
         """Get Pydantic model for the submodel"""
 
         if _model:
@@ -567,19 +569,19 @@ def get_fields_to_carry_on(field_tuple: tuple[str, ...]) -> tuple[str, ...]:
 
 
 def pydantic_model_creator(
-    cls: "type[Model]",
+    cls: type[Model],
     *,
     name=None,
-    exclude: Optional[tuple[str, ...]] = None,
-    include: Optional[tuple[str, ...]] = None,
-    computed: Optional[tuple[str, ...]] = None,
-    optional: Optional[tuple[str, ...]] = None,
-    allow_cycles: Optional[bool] = None,
-    sort_alphabetically: Optional[bool] = None,
+    exclude: tuple[str, ...] | None = None,
+    include: tuple[str, ...] | None = None,
+    computed: tuple[str, ...] | None = None,
+    optional: tuple[str, ...] | None = None,
+    allow_cycles: bool | None = None,
+    sort_alphabetically: bool | None = None,
     exclude_readonly: bool = False,
-    meta_override: Optional[type] = None,
-    model_config: Optional[ConfigDict] = None,
-    validators: Optional[dict[str, Any]] = None,
+    meta_override: type | None = None,
+    model_config: ConfigDict | None = None,
+    validators: dict[str, Any] | None = None,
     module: str = __name__,
 ) -> type[PydanticModel]:
     """
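All the creator keywords above stay optional, so a typical call only names the model and perhaps what to hide. A hedged usage sketch with a made-up `User` model (the field names here are illustrative, not from this patch):

from tortoise import Model, fields
from tortoise.contrib.pydantic import pydantic_model_creator

class User(Model):
    id = fields.IntField(primary_key=True)
    name = fields.TextField()
    password_hash = fields.CharField(max_length=128)

# A serializer that never exposes the hash column:
UserOut = pydantic_model_creator(User, name="UserOut", exclude=("password_hash",))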
= (), - allow_cycles: Optional[bool] = None, - sort_alphabetically: Optional[bool] = None, - model_config: Optional[ConfigDict] = None, - ) -> "PydanticMetaData": + allow_cycles: bool | None = None, + sort_alphabetically: bool | None = None, + model_config: ConfigDict | None = None, + ) -> PydanticMetaData: _sort_fields: bool = ( self.sort_alphabetically if sort_alphabetically is None else sort_alphabetically ) diff --git a/tortoise/contrib/pydantic/utils.py b/tortoise/contrib/pydantic/utils.py index e2d577326..fec3f8548 100644 --- a/tortoise/contrib/pydantic/utils.py +++ b/tortoise/contrib/pydantic/utils.py @@ -1,11 +1,13 @@ +from __future__ import annotations + from collections.abc import Callable -from typing import TYPE_CHECKING, Any, Optional, get_type_hints +from typing import TYPE_CHECKING, Any, get_type_hints if TYPE_CHECKING: # pragma: nocoverage from tortoise.models import Model -def get_annotations(cls: "type[Model]", method: Optional[Callable] = None) -> dict[str, Any]: +def get_annotations(cls: type[Model], method: Callable | None = None) -> dict[str, Any]: """ Get all annotations including base classes :param cls: The model class we need annotations from diff --git a/tortoise/contrib/pylint/__init__.py b/tortoise/contrib/pylint/__init__.py index 8ab49535f..4a85ffa9f 100644 --- a/tortoise/contrib/pylint/__init__.py +++ b/tortoise/contrib/pylint/__init__.py @@ -2,13 +2,17 @@ Tortoise PyLint plugin """ +from __future__ import annotations + from collections.abc import Iterator -from typing import Any +from typing import TYPE_CHECKING, Any from astroid import MANAGER, inference_tip, nodes from astroid.exceptions import AstroidError from astroid.nodes import AnnAssign, Assign, ClassDef -from pylint.lint import PyLinter + +if TYPE_CHECKING: + from pylint.lint import PyLinter MODELS: dict[str, ClassDef] = {} FUTURE_RELATIONS: dict[str, list] = {} diff --git a/tortoise/contrib/quart/__init__.py b/tortoise/contrib/quart/__init__.py index fd79c5ee4..3a2fbf643 100644 --- a/tortoise/contrib/quart/__init__.py +++ b/tortoise/contrib/quart/__init__.py @@ -1,8 +1,9 @@ +from __future__ import annotations + import asyncio import logging from collections.abc import Iterable from types import ModuleType -from typing import Optional, Union from quart import Quart # pylint: disable=E0401 @@ -12,10 +13,10 @@ def register_tortoise( app: Quart, - config: Optional[dict] = None, - config_file: Optional[str] = None, - db_url: Optional[str] = None, - modules: Optional[dict[str, Iterable[Union[str, ModuleType]]]] = None, + config: dict | None = None, + config_file: str | None = None, + db_url: str | None = None, + modules: dict[str, Iterable[str | ModuleType]] | None = None, generate_schemas: bool = False, ) -> None: """ diff --git a/tortoise/contrib/sanic/__init__.py b/tortoise/contrib/sanic/__init__.py index af5798827..02fb5a585 100644 --- a/tortoise/contrib/sanic/__init__.py +++ b/tortoise/contrib/sanic/__init__.py @@ -1,6 +1,7 @@ +from __future__ import annotations + from collections.abc import Iterable from types import ModuleType -from typing import Optional, Union from sanic import Sanic # pylint: disable=E0401 @@ -10,10 +11,10 @@ def register_tortoise( app: Sanic, - config: Optional[dict] = None, - config_file: Optional[str] = None, - db_url: Optional[str] = None, - modules: Optional[dict[str, Iterable[Union[str, ModuleType]]]] = None, + config: dict | None = None, + config_file: str | None = None, + db_url: str | None = None, + modules: dict[str, Iterable[str | ModuleType]] | None = None, 
generate_schemas: bool = False, ) -> None: """ diff --git a/tortoise/contrib/sqlite/regex.py b/tortoise/contrib/sqlite/regex.py index e70e87a5f..9000bea92 100644 --- a/tortoise/contrib/sqlite/regex.py +++ b/tortoise/contrib/sqlite/regex.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import enum import re from typing import cast diff --git a/tortoise/contrib/starlette/__init__.py b/tortoise/contrib/starlette/__init__.py index 82890dac9..ac7114b84 100644 --- a/tortoise/contrib/starlette/__init__.py +++ b/tortoise/contrib/starlette/__init__.py @@ -1,6 +1,7 @@ +from __future__ import annotations + from collections.abc import Iterable from types import ModuleType -from typing import Optional, Union from starlette.applications import Starlette # pylint: disable=E0401 @@ -10,10 +11,10 @@ def register_tortoise( app: Starlette, - config: Optional[dict] = None, - config_file: Optional[str] = None, - db_url: Optional[str] = None, - modules: Optional[dict[str, Iterable[Union[str, ModuleType]]]] = None, + config: dict | None = None, + config_file: str | None = None, + db_url: str | None = None, + modules: dict[str, Iterable[str | ModuleType]] | None = None, generate_schemas: bool = False, ) -> None: """ diff --git a/tortoise/contrib/test/__init__.py b/tortoise/contrib/test/__init__.py index 774293f81..14438ab1b 100644 --- a/tortoise/contrib/test/__init__.py +++ b/tortoise/contrib/test/__init__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import inspect import os as _os @@ -8,7 +10,7 @@ from collections.abc import Callable, Coroutine, Iterable from functools import partial, wraps from types import ModuleType -from typing import Any, Optional, TypeVar, Union, cast +from typing import Any, TypeVar, Union, cast from unittest import SkipTest, expectedFailure, skip, skipIf, skipUnless from tortoise import Model, Tortoise, connections @@ -51,11 +53,11 @@ _CONFIG: dict = {} _CONNECTIONS: dict = {} _LOOP: AbstractEventLoop = None # type: ignore -_MODULES: Iterable[Union[str, ModuleType]] = [] +_MODULES: Iterable[str | ModuleType] = [] _CONN_CONFIG: dict = {} -def getDBConfig(app_label: str, modules: Iterable[Union[str, ModuleType]]) -> dict: +def getDBConfig(app_label: str, modules: Iterable[str | ModuleType]) -> dict: """ DB Config factory, for use in testing. @@ -102,10 +104,10 @@ async def truncate_all_models() -> None: def initializer( - modules: Iterable[Union[str, ModuleType]], - db_url: Optional[str] = None, + modules: Iterable[str | ModuleType], + db_url: str | None = None, app_label: str = "models", - loop: Optional[AbstractEventLoop] = None, + loop: AbstractEventLoop | None = None, ) -> None: """ Sets up the DB for testing. Must be called as part of test environment setup. @@ -233,7 +235,7 @@ async def asyncTearDown(self) -> None: Tortoise._inited = False def assertListSortEqual( - self, list1: list[Any], list2: list[Any], msg: Any = ..., sorted_key: Optional[str] = None + self, list1: list[Any], list2: list[Any], msg: Any = ..., sorted_key: str | None = None ) -> None: if isinstance(list1[0], Model): super().assertListEqual( @@ -265,7 +267,7 @@ class IsolatedTestCase(SimpleTestCase): If you define a ``tortoise_test_modules`` list, it overrides the DB setup module for the tests. 
""" - tortoise_test_modules: Iterable[Union[str, ModuleType]] = [] + tortoise_test_modules: Iterable[str | ModuleType] = [] async def _setUpDB(self) -> None: await super()._setUpDB() @@ -407,7 +409,7 @@ def skip_wrapper(*args, **kwargs): @typing.overload -def init_memory_sqlite(models: Union[ModulesConfigType, None] = None) -> AsyncFuncDeco: ... +def init_memory_sqlite(models: ModulesConfigType | None = None) -> AsyncFuncDeco: ... @typing.overload @@ -415,8 +417,8 @@ def init_memory_sqlite(models: AsyncFunc) -> AsyncFunc: ... def init_memory_sqlite( - models: Union[ModulesConfigType, AsyncFunc, None] = None, -) -> Union[AsyncFunc, AsyncFuncDeco]: + models: ModulesConfigType | AsyncFunc | None = None, +) -> AsyncFunc | AsyncFuncDeco: """ For single file style to run code with memory sqlite diff --git a/tortoise/contrib/test/condition.py b/tortoise/contrib/test/condition.py index cc6dd24c5..3bc54795e 100644 --- a/tortoise/contrib/test/condition.py +++ b/tortoise/contrib/test/condition.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Any diff --git a/tortoise/converters.py b/tortoise/converters.py index d98a59939..d2638166f 100644 --- a/tortoise/converters.py +++ b/tortoise/converters.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import datetime import time from collections.abc import Sequence @@ -79,15 +81,15 @@ def escape_int(value: int, mapping=None) -> str: def escape_float(value: float, mapping=None) -> str: - return "%.15g" % value + return f"{value:.15g}" def escape_unicode(value: str, mapping=None) -> str: - return "'%s'" % _escape_unicode(value) + return f"'{_escape_unicode(value)}'" def escape_str(value: str, mapping=None) -> str: - return "'%s'" % escape_string(str(value), mapping) + return f"'{escape_string(str(value), mapping)}'" def escape_None(value: None, mapping=None) -> str: diff --git a/tortoise/exceptions.py b/tortoise/exceptions.py index 0662e29cc..5c75e728b 100644 --- a/tortoise/exceptions.py +++ b/tortoise/exceptions.py @@ -1,4 +1,6 @@ -from typing import TYPE_CHECKING, Any, Optional, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from tortoise import Model @@ -55,8 +57,8 @@ class NoValuesFetched(OperationalError): class NotExistOrMultiple(OperationalError): TEMPLATE = "" - def __init__(self, model: "Union[type[Model], str]", *args) -> None: - self.model: "Optional[type[Model]]" = None + def __init__(self, model: type[Model] | str, *args) -> None: + self.model: type[Model] | None = None if isinstance(model, str): args = (model,) + args else: @@ -83,8 +85,8 @@ class ObjectDoesNotExistError(OperationalError, KeyError): The DoesNotExist exception is raised when an item with the passed primary key does not exist """ - def __init__(self, model: "type[Model]", pk_name: str, pk_val: Any) -> None: - self.model: "type[Model]" = model + def __init__(self, model: type[Model], pk_name: str, pk_val: Any) -> None: + self.model: type[Model] = model self.pk_name: str = pk_name self.pk_val: Any = pk_val diff --git a/tortoise/expressions.py b/tortoise/expressions.py index f91d1c124..0641e206f 100644 --- a/tortoise/expressions.py +++ b/tortoise/expressions.py @@ -11,9 +11,15 @@ from pypika_tortoise import Field as PypikaField from pypika_tortoise import SqlContext, Table from pypika_tortoise.functions import AggregateFunction, DistinctOptionFunction -from pypika_tortoise.terms import ArithmeticExpression, Criterion +from pypika_tortoise.terms import ( + ArithmeticExpression, + Criterion, +) from 
pypika_tortoise.terms import Function as PypikaFunction -from pypika_tortoise.terms import Term, ValueWrapper +from pypika_tortoise.terms import ( + Term, + ValueWrapper, +) from pypika_tortoise.utils import format_alias_sql from tortoise.exceptions import FieldError, OperationalError @@ -36,7 +42,7 @@ @dataclass(frozen=True) class ResolveContext: - model: type["Model"] + model: type[Model] table: Table annotations: dict[str, Any] custom_filters: dict[str, FilterInfoDict] @@ -200,7 +206,7 @@ def __rpow__(self, other) -> CombinedExpression: class Subquery(Term): - def __init__(self, query: "AwaitableQuery") -> None: + def __init__(self, query: AwaitableQuery) -> None: super().__init__() self.query = query @@ -209,7 +215,7 @@ def get_sql(self, ctx: SqlContext) -> str: self.query._make_query() return self.query.query.get_parameterized_sql(ctx)[0] - def as_(self, alias: str) -> "Selectable": # type: ignore + def as_(self, alias: str) -> Selectable: # type: ignore self.query._choose_db_if_not_chosen() self.query._make_query() return self.query.query.as_(alias) @@ -246,7 +252,7 @@ class Q: AND = "AND" OR = "OR" - def __init__(self, *args: "Q", join_type: str = AND, **kwargs: Any) -> None: + def __init__(self, *args: Q, join_type: str = AND, **kwargs: Any) -> None: if args and kwargs: newarg = Q(join_type=join_type, **kwargs) args = (newarg,) + args @@ -263,7 +269,7 @@ def __init__(self, *args: "Q", join_type: str = AND, **kwargs: Any) -> None: self.join_type = join_type self._is_negated = False - def __and__(self, other: "Q") -> "Q": + def __and__(self, other: Q) -> Q: """ Returns a binary AND of Q objects, use ``AND`` operator. @@ -273,7 +279,7 @@ def __and__(self, other: "Q") -> "Q": raise OperationalError("AND operation requires a Q node") return Q(self, other, join_type=self.AND) - def __or__(self, other: "Q") -> "Q": + def __or__(self, other: Q) -> Q: """ Returns a binary OR of Q objects, use ``OR`` operator. @@ -283,7 +289,7 @@ def __or__(self, other: "Q") -> "Q": raise OperationalError("OR operation requires a Q node") return Q(self, other, join_type=self.OR) - def __invert__(self) -> "Q": + def __invert__(self) -> Q: """ Returns a negated instance of the Q object, use ``~`` operator. 
""" @@ -350,7 +356,7 @@ def _resolve_custom_kwarg( return modifier def _process_filter_kwarg( - self, model: "type[Model]", key: str, value: Any, table: Table + self, model: type[Model], key: str, value: Any, table: Table ) -> tuple[Criterion, tuple[Table, Criterion] | None]: join = None @@ -502,10 +508,10 @@ class Function(Expression): populate_field_object = False def __init__( - self, field: str | F | CombinedExpression | "Function", *default_values: Any + self, field: str | F | CombinedExpression | Function, *default_values: Any ) -> None: self.field = field - self.field_object: "Field | None" = None + self.field_object: Field | None = None self.default_values = default_values def _get_function_field(self, field: Term | str, *default_values) -> PypikaFunction: diff --git a/tortoise/fields/base.py b/tortoise/fields/base.py index c7626e055..61473d2ba 100644 --- a/tortoise/fields/base.py +++ b/tortoise/fields/base.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import sys import warnings from collections.abc import Callable from enum import Enum -from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, overload +from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union, overload from pypika_tortoise.terms import Term @@ -142,40 +144,38 @@ def function_cast(self, term: Term) -> Term: has_db_field: bool = True skip_to_python_if_native: bool = False allows_generated: bool = False - function_cast: Optional[Callable[[Term], Term]] = None + function_cast: Callable[[Term], Term] | None = None SQL_TYPE: str = None # type: ignore GENERATED_SQL: str = None # type: ignore # These methods are just to make IDE/Linters happy: if TYPE_CHECKING: - def __new__(cls, *args: Any, **kwargs: Any) -> "Field[VALUE]": + def __new__(cls, *args: Any, **kwargs: Any) -> Field[VALUE]: return super().__new__(cls) @overload - def __get__(self, instance: None, owner: type["Model"]) -> "Field[VALUE]": ... + def __get__(self, instance: None, owner: type[Model]) -> Field[VALUE]: ... @overload - def __get__(self, instance: "Model", owner: type["Model"]) -> VALUE: ... + def __get__(self, instance: Model, owner: type[Model]) -> VALUE: ... - def __get__( - self, instance: Optional["Model"], owner: type["Model"] - ) -> "Field[VALUE] | VALUE": ... + def __get__(self, instance: Model | None, owner: type[Model]) -> Field[VALUE] | VALUE: ... - def __set__(self, instance: "Model", value: VALUE) -> None: ... + def __set__(self, instance: Model, value: VALUE) -> None: ... 
def __init__( self, - source_field: Optional[str] = None, + source_field: str | None = None, generated: bool = False, - primary_key: Optional[bool] = None, + primary_key: bool | None = None, null: bool = False, default: Any = None, unique: bool = False, - db_index: Optional[bool] = None, - description: Optional[str] = None, - model: "Optional[Model]" = None, - validators: Optional[list[Union[Validator, Callable]]] = None, + db_index: bool | None = None, + description: str | None = None, + model: Model | None = None, + validators: list[Validator | Callable] | None = None, **kwargs: Any, ) -> None: if (index := kwargs.pop("index", None)) is not None: @@ -225,13 +225,13 @@ def __init__( self.index = bool(db_index) self.model_field_name = "" self.description = description - self.docstring: Optional[str] = None - self.validators: list[Union[Validator, Callable]] = validators or [] + self.docstring: str | None = None + self.validators: list[Validator | Callable] = validators or [] # TODO: consider making this not be set from constructor - self.model: type["Model"] = model # type: ignore - self.reference: "Optional[Field]" = None + self.model: type[Model] = model # type: ignore + self.reference: Field | None = None - def to_db_value(self, value: Any, instance: "Union[type[Model], Model]") -> Any: + def to_db_value(self, value: Any, instance: type[Model] | Model) -> Any: """ Converts from the Python type to the DB type. @@ -324,7 +324,7 @@ def get_db_field_type(self) -> str: dialect = self.model._meta.db.capabilities.dialect return self.get_for_dialect(dialect, "SQL_TYPE") - def get_db_field_types(self) -> Optional[dict[str, str]]: + def get_db_field_types(self) -> dict[str, str] | None: """ Returns the DB types for this field. @@ -333,7 +333,7 @@ def get_db_field_types(self) -> Optional[dict[str, str]]: """ if not self.has_db_field: # pragma: nocoverage return None - default = getattr(self, "SQL_TYPE") + default = self.SQL_TYPE return { "": default, **{ @@ -420,7 +420,7 @@ def _type_name(typ: type) -> str: return str(typ).replace("typing.", "") return f"{typ.__module__}.{typ.__name__}" - def type_name(typ: Any) -> Union[str, list[str]]: + def type_name(typ: Any) -> str | list[str]: try: return typ._meta.full_name except (AttributeError, TypeError): @@ -433,7 +433,7 @@ def type_name(typ: Any) -> Union[str, list[str]]: except TypeError: return str(typ) - def default_name(default: Any) -> Optional[Union[int, float, str, bool]]: + def default_name(default: Any) -> int | float | str | bool | None: if isinstance(default, (int, float, str, bool, type(None))): return default if callable(default): diff --git a/tortoise/fields/data.py b/tortoise/fields/data.py index 307fb6df0..c0b791c57 100644 --- a/tortoise/fields/data.py +++ b/tortoise/fields/data.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import datetime import functools import json @@ -5,7 +7,7 @@ from collections.abc import Callable from decimal import Decimal from enum import Enum, IntEnum -from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, TypeVar, Union from uuid import UUID, uuid4 from pypika_tortoise import functions @@ -76,7 +78,7 @@ class IntField(Field[int], int): SQL_TYPE = "INT" allows_generated = True - def __init__(self, primary_key: Optional[bool] = None, **kwargs: Any) -> None: + def __init__(self, primary_key: bool | None = None, **kwargs: Any) -> None: if primary_key or kwargs.get("pk"): kwargs["generated"] = bool(kwargs.get("generated", True)) 
super().__init__(primary_key=primary_key, **kwargs) @@ -195,7 +197,7 @@ def SQL_TYPE(self) -> str: # type: ignore return f"VARCHAR({self.max_length})" class _db_oracle: - def __init__(self, field: "CharField") -> None: + def __init__(self, field: CharField) -> None: self.field = field @property @@ -213,7 +215,7 @@ class TextField(Field[str], str): # type: ignore def __init__( self, - primary_key: Optional[bool] = None, + primary_key: bool | None = None, unique: bool = False, db_index: bool = False, **kwargs: Any, @@ -286,7 +288,7 @@ def __init__(self, max_digits: int, decimal_places: int, **kwargs: Any) -> None: self.decimal_places = decimal_places self.quant = Decimal("1" if decimal_places == 0 else f"1.{('0' * decimal_places)}") - def to_python_value(self, value: Any) -> Optional[Decimal]: + def to_python_value(self, value: Any) -> Decimal | None: if value is not None: value = Decimal(value).quantize(self.quant).normalize() return value @@ -345,7 +347,7 @@ def __init__(self, auto_now: bool = False, auto_now_add: bool = False, **kwargs: self.auto_now = auto_now self.auto_now_add = auto_now | auto_now_add - def to_python_value(self, value: Any) -> Optional[datetime.datetime]: + def to_python_value(self, value: Any) -> datetime.datetime | None: if value is not None: if isinstance(value, datetime.datetime): value = value @@ -360,8 +362,8 @@ def to_python_value(self, value: Any) -> Optional[datetime.datetime]: return value def to_db_value( - self, value: Optional[DatetimeFieldQueryValueType], instance: "Union[type[Model], Model]" - ) -> Optional[DatetimeFieldQueryValueType]: + self, value: DatetimeFieldQueryValueType | None, instance: type[Model] | Model + ) -> DatetimeFieldQueryValueType | None: # Only do this if it is a Model instance, not class. Test for guaranteed instance var if hasattr(instance, "_saved_in_db") and ( self.auto_now @@ -374,8 +376,8 @@ def to_db_value( if isinstance(value, datetime.datetime) and get_use_tz(): if timezone.is_naive(value): warnings.warn( - "DateTimeField %s received a naive datetime (%s)" - " while time zone support is active." 
% (self.model_field_name, value), + f"DateTimeField {self.model_field_name} received a naive datetime ({value})" + " while time zone support is active.", RuntimeWarning, ) value = timezone.make_aware(value, "UTC") @@ -404,14 +406,14 @@ class DateField(Field[datetime.date], datetime.date): skip_to_python_if_native = True SQL_TYPE = "DATE" - def to_python_value(self, value: Any) -> Optional[datetime.date]: + def to_python_value(self, value: Any) -> datetime.date | None: if value is not None and not isinstance(value, datetime.date): value = parse_datetime(value).date() return value def to_db_value( - self, value: Optional[Union[datetime.date, str]], instance: "Union[type[Model], Model]" - ) -> Optional[datetime.date]: + self, value: datetime.date | str | None, instance: type[Model] | Model + ) -> datetime.date | None: if value is not None and not isinstance(value, datetime.date): value = parse_datetime(value).date() self.validate(value) @@ -436,7 +438,7 @@ def __init__(self, auto_now: bool = False, auto_now_add: bool = False, **kwargs: self.auto_now = auto_now self.auto_now_add = auto_now | auto_now_add - def to_python_value(self, value: Any) -> Optional[Union[datetime.time, datetime.timedelta]]: + def to_python_value(self, value: Any) -> datetime.time | datetime.timedelta | None: if value is not None: if isinstance(value, str): value = datetime.time.fromisoformat(value) @@ -448,9 +450,9 @@ def to_python_value(self, value: Any) -> Optional[Union[datetime.time, datetime. def to_db_value( self, - value: Optional[Union[datetime.time, datetime.timedelta]], - instance: "Union[type[Model], Model]", - ) -> Optional[Union[datetime.time, datetime.timedelta]]: + value: datetime.time | datetime.timedelta | None, + instance: type[Model] | Model, + ) -> datetime.time | datetime.timedelta | None: # Only do this if it is a Model instance, not class. Test for guaranteed instance var if hasattr(instance, "_saved_in_db") and ( self.auto_now @@ -465,8 +467,8 @@ def to_db_value( if get_use_tz(): if timezone.is_naive(value): warnings.warn( - "TimeField %s received a naive time (%s)" - " while time zone support is active." 
% (self.model_field_name, value), + f"TimeField {self.model_field_name} received a naive time ({value})" + " while time zone support is active.", RuntimeWarning, ) value = value.replace(tzinfo=get_default_timezone()) @@ -491,14 +493,14 @@ class TimeDeltaField(Field[datetime.timedelta]): class _db_oracle: SQL_TYPE = "NUMBER(19)" - def to_python_value(self, value: Any) -> Optional[datetime.timedelta]: + def to_python_value(self, value: Any) -> datetime.timedelta | None: if value is None or isinstance(value, datetime.timedelta): return value return datetime.timedelta(microseconds=value) def to_db_value( - self, value: Optional[datetime.timedelta], instance: "Union[type[Model], Model]" - ) -> Optional[int]: + self, value: datetime.timedelta | None, instance: type[Model] | Model + ) -> int | None: self.validate(value) if value is None: @@ -565,14 +567,14 @@ def __init__( super().__init__(**kwargs) self.encoder = encoder self.decoder = decoder - if field_type := kwargs.get("field_type", None): + if field_type := kwargs.get("field_type"): self.field_type = field_type def to_db_value( self, - value: Optional[Union[T, dict, list, str, bytes]], - instance: "Union[type[Model], Model]", - ) -> Optional[str]: + value: T | dict | list | str | bytes | None, + instance: type[Model] | Model, + ) -> str | None: self.validate(value) if value is None: return None @@ -597,8 +599,8 @@ def to_db_value( return self.encoder(value) def to_python_value( - self, value: Optional[Union[T, str, bytes, dict, list]] - ) -> Optional[Union[T, dict, list]]: + self, value: T | str | bytes | dict | list | None + ) -> T | dict | list | None: if isinstance(value, (str, bytes)): try: data = self.decoder(value) @@ -639,10 +641,10 @@ def __init__(self, **kwargs: Any) -> None: kwargs["default"] = uuid4 super().__init__(**kwargs) - def to_db_value(self, value: Any, instance: "Union[type[Model], Model]") -> Optional[str]: + def to_db_value(self, value: Any, instance: type[Model] | Model) -> str | None: return value and str(value) - def to_python_value(self, value: Any) -> Optional[UUID]: + def to_python_value(self, value: Any) -> UUID | None: if value is None or isinstance(value, UUID): return value return UUID(value) @@ -673,7 +675,7 @@ class IntEnumFieldInstance(SmallIntField): def __init__( self, enum_type: type[IntEnum], - description: Optional[str] = None, + description: str | None = None, generated: bool = False, **kwargs: Any, ) -> None: @@ -686,7 +688,7 @@ def __init__( raise ConfigurationError("IntEnumField only supports integer enums!") if not minimum <= value < 32768: raise ConfigurationError( - "The valid range of IntEnumField's values is {}..32767!".format(minimum) + f"The valid range of IntEnumField's values is {minimum}..32767!" 
) # Automatic description for the field if not specified by the user @@ -696,13 +698,11 @@ def __init__( super().__init__(description=description, **kwargs) self.enum_type = enum_type - def to_python_value(self, value: Union[int, None]) -> Union[IntEnum, None]: + def to_python_value(self, value: int | None) -> IntEnum | None: value = self.enum_type(value) if value is not None else None return value - def to_db_value( - self, value: Union[IntEnum, None, int], instance: "Union[type[Model], Model]" - ) -> Union[int, None]: + def to_db_value(self, value: IntEnum | None | int, instance: type[Model] | Model) -> int | None: if isinstance(value, IntEnum): value = int(value.value) if isinstance(value, int): @@ -716,7 +716,7 @@ def to_db_value( def IntEnumField( enum_type: type[IntEnumType], - description: Optional[str] = None, + description: str | None = None, **kwargs: Any, ) -> IntEnumType: """ @@ -743,7 +743,7 @@ class CharEnumFieldInstance(CharField): def __init__( self, enum_type: type[Enum], - description: Optional[str] = None, + description: str | None = None, max_length: int = 0, **kwargs: Any, ) -> None: @@ -761,12 +761,10 @@ def __init__( super().__init__(description=description, max_length=max_length, **kwargs) self.enum_type = enum_type - def to_python_value(self, value: Union[str, None]) -> Union[Enum, None]: + def to_python_value(self, value: str | None) -> Enum | None: return self.enum_type(value) if value is not None else None - def to_db_value( - self, value: Union[Enum, None, str], instance: "Union[type[Model], Model]" - ) -> Union[str, None]: + def to_db_value(self, value: Enum | None | str, instance: type[Model] | Model) -> str | None: self.validate(value) if isinstance(value, Enum): return str(value.value) @@ -780,7 +778,7 @@ def to_db_value( def CharEnumField( enum_type: type[CharEnumType], - description: Optional[str] = None, + description: str | None = None, max_length: int = 0, **kwargs: Any, ) -> CharEnumType: diff --git a/tortoise/fields/relational.py b/tortoise/fields/relational.py index 016824203..3055d5932 100644 --- a/tortoise/fields/relational.py +++ b/tortoise/fields/relational.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from collections.abc import AsyncGenerator, Generator, Iterator from typing import ( TYPE_CHECKING, @@ -6,7 +8,6 @@ Literal, Optional, TypeVar, - Union, overload, ) @@ -45,7 +46,7 @@ def __init__( self, remote_model: type[MODEL], relation_field: str, - instance: "Model", + instance: Model, from_field: str, ) -> None: self.remote_model = remote_model @@ -57,7 +58,7 @@ def __init__( self.related_objects: list[MODEL] = [] @property - def _query(self) -> "QuerySet[MODEL]": + def _query(self) -> QuerySet[MODEL]: if not self.instance._saved_in_db: raise OperationalError( "This objects hasn't been instanced, call .save() before calling related queries" @@ -70,7 +71,7 @@ def __contains__(self, item: Any) -> bool: self._raise_if_not_fetched() return item in self.related_objects - def __iter__(self) -> "Iterator[MODEL]": + def __iter__(self) -> Iterator[MODEL]: self._raise_if_not_fetched() return self.related_objects.__iter__() @@ -95,37 +96,37 @@ async def __aiter__(self) -> AsyncGenerator[Any, MODEL]: for val in self.related_objects: yield val - def filter(self, *args: "Q", **kwargs: Any) -> "QuerySet[MODEL]": + def filter(self, *args: Q, **kwargs: Any) -> QuerySet[MODEL]: """ Returns a QuerySet with related elements filtered by args/kwargs. 
""" return self._query.filter(*args, **kwargs) - def all(self) -> "QuerySet[MODEL]": + def all(self) -> QuerySet[MODEL]: """ Returns a QuerySet with all related elements. """ return self._query - def order_by(self, *orderings: str) -> "QuerySet[MODEL]": + def order_by(self, *orderings: str) -> QuerySet[MODEL]: """ Returns a QuerySet related elements in order. """ return self._query.order_by(*orderings) - def limit(self, limit: int) -> "QuerySet[MODEL]": + def limit(self, limit: int) -> QuerySet[MODEL]: """ Returns a QuerySet with at most «limit» related elements. """ return self._query.limit(limit) - def offset(self, offset: int) -> "QuerySet[MODEL]": + def offset(self, offset: int) -> QuerySet[MODEL]: """ Returns a QuerySet with all related elements offset by «offset». """ return self._query.offset(offset) - def _set_result_for_query(self, sequence: list[MODEL], attr: Optional[str] = None) -> None: + def _set_result_for_query(self, sequence: list[MODEL], attr: str | None = None) -> None: self._fetched = True self.related_objects = sequence if attr: @@ -143,12 +144,12 @@ class ManyToManyRelation(ReverseRelation[MODEL]): Many-to-many relation container for :func:`.ManyToManyField`. """ - def __init__(self, instance: "Model", m2m_field: "ManyToManyFieldInstance[MODEL]") -> None: + def __init__(self, instance: Model, m2m_field: ManyToManyFieldInstance[MODEL]) -> None: super().__init__(m2m_field.related_model, m2m_field.related_name, instance, "pk") self.field = m2m_field self.instance = instance - async def add(self, *instances: MODEL, using_db: "Optional[BaseDBAsyncClient]" = None) -> None: + async def add(self, *instances: MODEL, using_db: BaseDBAsyncClient | None = None) -> None: """ Adds one or more of ``instances`` to the relation. @@ -193,15 +194,13 @@ async def add(self, *instances: MODEL, using_db: "Optional[BaseDBAsyncClient]" = query = query.insert(pk_f, pk_b) await db.execute_query(*query.get_parameterized_sql()) - async def clear(self, using_db: "Optional[BaseDBAsyncClient]" = None) -> None: + async def clear(self, using_db: BaseDBAsyncClient | None = None) -> None: """ Clears ALL relations. """ await self._remove_or_clear(using_db=using_db) - async def remove( - self, *instances: MODEL, using_db: "Optional[BaseDBAsyncClient]" = None - ) -> None: + async def remove(self, *instances: MODEL, using_db: BaseDBAsyncClient | None = None) -> None: """ Removes one or more of ``instances`` from the relation. @@ -213,8 +212,8 @@ async def remove( async def _remove_or_clear( self, - instances: Optional[tuple[MODEL, ...]] = None, - using_db: "Optional[BaseDBAsyncClient]" = None, + instances: tuple[MODEL, ...] | None = None, + using_db: BaseDBAsyncClient | None = None, ) -> None: db = using_db or self.remote_model._meta.db through_table = Table(self.field.through) @@ -242,13 +241,13 @@ class RelationalField(Field[MODEL]): def __init__( self, - related_model: "type[MODEL]", - to_field: Optional[str] = None, + related_model: type[MODEL], + to_field: str | None = None, db_constraint: bool = True, **kwargs: Any, ) -> None: super().__init__(**kwargs) - self.related_model: "type[MODEL]" = related_model + self.related_model: type[MODEL] = related_model self.to_field: str = to_field # type: ignore self.to_field_instance: Field = None # type: ignore self.db_constraint = db_constraint @@ -256,16 +255,16 @@ def __init__( if TYPE_CHECKING: @overload - def __get__(self, instance: None, owner: type["Model"]) -> "RelationalField[MODEL]": ... 
+ def __get__(self, instance: None, owner: type[Model]) -> RelationalField[MODEL]: ... @overload - def __get__(self, instance: "Model", owner: type["Model"]) -> MODEL: ... + def __get__(self, instance: Model, owner: type[Model]) -> MODEL: ... def __get__( - self, instance: Optional["Model"], owner: type["Model"] - ) -> "RelationalField[MODEL] | MODEL": ... + self, instance: Model | None, owner: type[Model] + ) -> RelationalField[MODEL] | MODEL: ... - def __set__(self, instance: "Model", value: MODEL) -> None: ... + def __set__(self, instance: Model, value: MODEL) -> None: ... def describe(self, serializable: bool) -> dict: desc = super().describe(serializable) @@ -284,7 +283,7 @@ class ForeignKeyFieldInstance(RelationalField[MODEL]): def __init__( self, model_name: str, - related_name: Union[Optional[str], Literal[False]] = None, + related_name: str | None | Literal[False] = None, on_delete: OnDelete = CASCADE, **kwargs: Any, ) -> None: @@ -310,24 +309,24 @@ def describe(self, serializable: bool) -> dict: class BackwardFKRelation(RelationalField[MODEL]): def __init__( self, - field_type: "type[MODEL]", + field_type: type[MODEL], relation_field: str, relation_source_field: str, null: bool, - description: Optional[str], + description: str | None, **kwargs: Any, ) -> None: super().__init__(field_type, null=null, **kwargs) self.relation_field: str = relation_field self.relation_source_field: str = relation_source_field - self.description: Optional[str] = description + self.description: str | None = description class OneToOneFieldInstance(ForeignKeyFieldInstance[MODEL]): def __init__( self, model_name: str, - related_name: Union[Optional[str], Literal[False]] = None, + related_name: str | None | Literal[False] = None, on_delete: OnDelete = CASCADE, **kwargs: Any, ) -> None: @@ -345,12 +344,12 @@ class ManyToManyFieldInstance(RelationalField[MODEL]): def __init__( self, model_name: str, - through: Optional[str] = None, - forward_key: Optional[str] = None, + through: str | None = None, + forward_key: str | None = None, backward_key: str = "", related_name: str = "", on_delete: OnDelete = CASCADE, - field_type: "type[MODEL]" = None, # type: ignore + field_type: type[MODEL] = None, # type: ignore create_unique_index: bool = True, **kwargs: Any, ) -> None: @@ -382,34 +381,34 @@ def describe(self, serializable: bool) -> dict: @overload def OneToOneField( model_name: str, - related_name: Union[Optional[str], Literal[False]] = None, + related_name: str | None | Literal[False] = None, on_delete: OnDelete = CASCADE, db_constraint: bool = True, *, null: Literal[True], **kwargs: Any, -) -> "OneToOneNullableRelation[MODEL]": ... +) -> OneToOneNullableRelation[MODEL]: ... @overload def OneToOneField( model_name: str, - related_name: Union[Optional[str], Literal[False]] = None, + related_name: str | None | Literal[False] = None, on_delete: OnDelete = CASCADE, db_constraint: bool = True, null: Literal[False] = False, **kwargs: Any, -) -> "OneToOneRelation[MODEL]": ... +) -> OneToOneRelation[MODEL]: ... def OneToOneField( model_name: str, - related_name: Union[Optional[str], Literal[False]] = None, + related_name: str | None | Literal[False] = None, on_delete: OnDelete = CASCADE, db_constraint: bool = True, null: bool = False, **kwargs: Any, -) -> "OneToOneRelation[MODEL] | OneToOneNullableRelation[MODEL]": +) -> OneToOneRelation[MODEL] | OneToOneNullableRelation[MODEL]: """ OneToOne relation field. 
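Editor's note: the overload pairs above (and the ForeignKeyField ones below) encode nullability in the type system by dispatching on a `Literal` value of the `null` flag. A minimal runnable sketch of the same pattern follows, with hypothetical names (`Widget`, `widget_field`); note that under `from __future__ import annotations` the `Widget | None` returns are plain strings at runtime, so this also runs on Python 3.9.

```python
from __future__ import annotations

from typing import Literal, overload


class Widget:
    pass


@overload
def widget_field(*, null: Literal[True]) -> Widget | None: ...
@overload
def widget_field(*, null: Literal[False] = False) -> Widget: ...
def widget_field(*, null: bool = False) -> Widget | None:
    # Shared runtime body; only the checker-visible return type narrows.
    return None if null else Widget()


w = widget_field()                 # a checker infers: Widget
maybe_w = widget_field(null=True)  # a checker infers: Widget | None
print(type(w).__name__, maybe_w)   # Widget None
```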
@@ -457,34 +456,34 @@ def OneToOneField( @overload def ForeignKeyField( model_name: str, - related_name: Union[Optional[str], Literal[False]] = None, + related_name: str | None | Literal[False] = None, on_delete: OnDelete = CASCADE, db_constraint: bool = True, *, null: Literal[True], **kwargs: Any, -) -> "ForeignKeyNullableRelation[MODEL]": ... +) -> ForeignKeyNullableRelation[MODEL]: ... @overload def ForeignKeyField( model_name: str, - related_name: Union[Optional[str], Literal[False]] = None, + related_name: str | None | Literal[False] = None, on_delete: OnDelete = CASCADE, db_constraint: bool = True, null: Literal[False] = False, **kwargs: Any, -) -> "ForeignKeyRelation[MODEL]": ... +) -> ForeignKeyRelation[MODEL]: ... def ForeignKeyField( model_name: str, - related_name: Union[Optional[str], Literal[False]] = None, + related_name: str | None | Literal[False] = None, on_delete: OnDelete = CASCADE, db_constraint: bool = True, null: bool = False, **kwargs: Any, -) -> "ForeignKeyRelation[MODEL] | ForeignKeyNullableRelation[MODEL]": +) -> ForeignKeyRelation[MODEL] | ForeignKeyNullableRelation[MODEL]: """ ForeignKey relation field. @@ -531,15 +530,15 @@ def ForeignKeyField( def ManyToManyField( model_name: str, - through: Optional[str] = None, - forward_key: Optional[str] = None, + through: str | None = None, + forward_key: str | None = None, backward_key: str = "", related_name: str = "", on_delete: OnDelete = CASCADE, db_constraint: bool = True, create_unique_index: bool = True, **kwargs: Any, -) -> "ManyToManyRelation[Any]": +) -> ManyToManyRelation[Any]: """ ManyToMany relation field. diff --git a/tortoise/filters.py b/tortoise/filters.py index b97a81212..dcb4c7f78 100644 --- a/tortoise/filters.py +++ b/tortoise/filters.py @@ -1,9 +1,9 @@ from __future__ import annotations import operator -from collections.abc import Callable, Iterable +from collections.abc import Callable, Iterable, Sequence from functools import partial -from typing import TYPE_CHECKING, Any, Optional, Sequence, TypedDict, Union +from typing import TYPE_CHECKING, Any, TypedDict from pypika_tortoise import SqlContext, Table from pypika_tortoise.enums import DatePart, Matching, SqlTypes @@ -40,7 +40,7 @@ def __init__(self, left, right, alias=None, escape=" ESCAPE '\\'") -> None: def get_sql(self, ctx: SqlContext): sql = super().get_sql(ctx.copy(with_alias=False)) + str(self.escape) if ctx.with_alias and self.alias: # pragma: nocoverage - return '{sql} "{alias}"'.format(sql=sql, alias=self.alias) + return f'{sql} "{self.alias}"' return sql @@ -54,35 +54,35 @@ def escape_like(val: str) -> str: ############################################################################## -def list_encoder(values: Iterable[Any], instance: "Model", field: Field) -> list: +def list_encoder(values: Iterable[Any], instance: Model, field: Field) -> list: """Encodes an iterable of a given field into a database-compatible format.""" return [field.to_db_value(element, instance) for element in values] -def related_list_encoder(values: Iterable[Any], instance: "Model", field: Field) -> list: +def related_list_encoder(values: Iterable[Any], instance: Model, field: Field) -> list: return [ field.to_db_value(element.pk if hasattr(element, "pk") else element, instance) for element in values ] -def bool_encoder(value: Any, instance: "Model", field: Field) -> bool: +def bool_encoder(value: Any, instance: Model, field: Field) -> bool: return bool(value) -def string_encoder(value: Any, instance: "Model", field: Field) -> str: +def string_encoder(value: Any, 
instance: Model, field: Field) -> str: return str(value) -def int_encoder(value: Any, instance: "Model", field: Field) -> int: +def int_encoder(value: Any, instance: Model, field: Field) -> int: return int(value) -def json_encoder(value: Any, instance: "Model", field: Field) -> dict: +def json_encoder(value: Any, instance: Model, field: Field) -> dict: return value -def array_encoder(value: Union[Any, Sequence[Any]], instance: "Model", field: Field) -> Any: +def array_encoder(value: Any | Sequence[Any], instance: Model, field: Field) -> Any: # Casting to the exact type of the field to avoid issues with psycopg that tries # to use the smallest possible type which can lead to errors, # e.g. {1,2} will be casted to smallint[] instead of integer[]. @@ -238,15 +238,15 @@ def json_filter(field: Term, value: dict) -> Criterion: raise NotImplementedError("must be overridden in each executor") -def array_contains(field: Term, value: Union[Any, Sequence[Any]]) -> Criterion: +def array_contains(field: Term, value: Any | Sequence[Any]) -> Criterion: raise NotImplementedError("must be overridden in each executor") -def array_contained_by(field: Term, value: Union[Any, Sequence[Any]]) -> Criterion: +def array_contained_by(field: Term, value: Any | Sequence[Any]) -> Criterion: raise NotImplementedError("must be overridden in each executor") -def array_overlap(field: Term, value: Union[Any, Sequence[Any]]) -> Criterion: +def array_overlap(field: Term, value: Any | Sequence[Any]) -> Criterion: raise NotImplementedError("must be overridden in each executor") @@ -462,7 +462,7 @@ def get_array_filter( def get_filters_for_field( - field_name: str, field: Optional[Field], source_field: str + field_name: str, field: Field | None, source_field: str ) -> dict[str, FilterInfoDict]: if field is not None: if isinstance(field, ManyToManyFieldInstance): diff --git a/tortoise/indexes.py b/tortoise/indexes.py index 4b1eb43fe..8042c047f 100644 --- a/tortoise/indexes.py +++ b/tortoise/indexes.py @@ -32,7 +32,7 @@ def __init__( self.fields = list(fields or []) if not expressions and not fields: raise ConfigurationError( - "At least one field or expression is required to define an " "index." + "At least one field or expression is required to define an index."
) if expressions and fields: raise ConfigurationError( @@ -51,13 +51,11 @@ def describe(self) -> dict: "extra": self.extra, } - def index_name(self, schema_generator: "BaseSchemaGenerator", model: "type[Model]") -> str: + def index_name(self, schema_generator: BaseSchemaGenerator, model: type[Model]) -> str: # This function is required by aerich return self.name or schema_generator._generate_index_name("idx", model, self.field_names) - def get_sql( - self, schema_generator: "BaseSchemaGenerator", model: "type[Model]", safe: bool - ) -> str: + def get_sql(self, schema_generator: BaseSchemaGenerator, model: type[Model], safe: bool) -> str: # This function is required by aerich return schema_generator._get_index_sql( model, diff --git a/tortoise/manager.py b/tortoise/manager.py index bb2c948be..0e5bdc850 100644 --- a/tortoise/manager.py +++ b/tortoise/manager.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Any from tortoise.queryset import QuerySet diff --git a/tortoise/models.py b/tortoise/models.py index 3276831d2..5888d0a46 100644 --- a/tortoise/models.py +++ b/tortoise/models.py @@ -1,10 +1,12 @@ +from __future__ import annotations + import asyncio import inspect import re from collections.abc import Awaitable, Callable, Generator, Iterable from copy import copy, deepcopy from functools import partial -from typing import Any, Optional, TypedDict, TypeVar, Union, cast +from typing import Any, TypedDict, TypeVar, cast from pypika_tortoise import Order, Query, Table from pypika_tortoise.terms import Term @@ -56,7 +58,7 @@ EMPTY = object() -def get_together(meta: "Model.Meta", together: str) -> tuple[tuple[str, ...], ...]: +def get_together(meta: Model.Meta, together: str) -> tuple[tuple[str, ...], ...]: _together = getattr(meta, together, ()) if _together and isinstance(_together, (list, tuple)) and isinstance(_together[0], str): @@ -66,7 +68,7 @@ def get_together(meta: "Model.Meta", together: str) -> tuple[tuple[str, ...], ..
return _together -def prepare_default_ordering(meta: "Model.Meta") -> tuple[tuple[str, Order], ...]: +def prepare_default_ordering(meta: Model.Meta) -> tuple[tuple[str, Order], ...]: ordering_list = getattr(meta, "ordering", ()) parsed_ordering = tuple( @@ -83,8 +85,8 @@ class FkSetterKwargs(TypedDict): def _fk_setter( - self: "Model", - value: "Optional[Model]", + self: Model, + value: Model | None, _key: str, relation_field: str, to_field: str, @@ -94,7 +96,7 @@ def _fk_setter( def _fk_getter( - self: "Model", _key: str, ftype: "type[Model]", relation_field: str, to_field: str + self: Model, _key: str, ftype: type[Model], relation_field: str, to_field: str ) -> Awaitable: try: return getattr(self, _key) @@ -106,7 +108,7 @@ def _fk_getter( def _rfk_getter( - self: "Model", _key: str, ftype: "type[Model]", frelfield: str, from_field: str + self: Model, _key: str, ftype: type[Model], frelfield: str, from_field: str ) -> ReverseRelation: val = getattr(self, _key, None) if val is None: @@ -116,8 +118,8 @@ def _rfk_getter( def _ro2o_getter( - self: "Model", _key: str, ftype: "type[Model]", frelfield: str, from_field: str -) -> "QuerySetSingle[Optional[Model]]": + self: Model, _key: str, ftype: type[Model], frelfield: str, from_field: str +) -> QuerySetSingle[Model | None]: if hasattr(self, _key): return getattr(self, _key) @@ -127,7 +129,7 @@ def _ro2o_getter( def _m2m_getter( - self: "Model", _key: str, field_object: ManyToManyFieldInstance + self: Model, _key: str, field_object: ManyToManyFieldInstance ) -> ManyToManyRelation: val = getattr(self, _key, None) if val is None: @@ -136,7 +138,7 @@ def _m2m_getter( return val -def _get_comments(cls: "type[Model]") -> dict[str, str]: +def _get_comments(cls: type[Model]) -> dict[str, str]: """ Get comments exactly before attributes @@ -207,14 +209,14 @@ class MetaInfo: "_ordering_validated", ) - def __init__(self, meta: "Model.Meta") -> None: + def __init__(self, meta: Model.Meta) -> None: self.abstract: bool = getattr(meta, "abstract", False) self.manager: Manager = getattr(meta, "manager", Manager()) self.db_table: str = getattr(meta, "table", "") - self.schema: Optional[str] = getattr(meta, "schema", None) - self.app: Optional[str] = getattr(meta, "app", None) + self.schema: str | None = getattr(meta, "schema", None) + self.app: str | None = getattr(meta, "app", None) self.unique_together: tuple[tuple[str, ...], ...] = get_together(meta, "unique_together") - self.indexes: tuple[Union[tuple[str, ...], Index], ...] = get_together(meta, "indexes") + self.indexes: tuple[tuple[str, ...] | Index, ...] = get_together(meta, "indexes") self._default_ordering: tuple[tuple[str, Order], ...] = prepare_default_ordering(meta) self._ordering_validated: bool = False self.fields: set[str] = set() @@ -231,13 +233,13 @@ def __init__(self, meta: "Model.Meta") -> None: self.filters: dict[str, FilterInfoDict] = {} self.fields_map: dict[str, Field] = {} self._inited: bool = False - self.default_connection: Optional[str] = None + self.default_connection: str | None = None self.basequery: Query = Query() self.basequery_all_fields: Query = Query() self.basetable: Table = Table("") self.pk_attr: str = getattr(meta, "pk_attr", "") self.generated_db_fields: tuple[str, ...] 
= None # type: ignore - self._model: type["Model"] = None # type: ignore + self._model: type[Model] = None # type: ignore self.table_description: str = getattr(meta, "table_description", "") self.pk: Field = None # type: ignore self.db_pk_column: str = "" @@ -474,9 +476,9 @@ def _generate_filters(self) -> None: class ModelMeta(type): __slots__ = () - def __new__(cls, name: str, bases: tuple[type, ...], attrs: dict[str, Any]) -> "ModelMeta": + def __new__(cls, name: str, bases: tuple[type, ...], attrs: dict[str, Any]) -> ModelMeta: fields_db_projection: dict[str, str] = {} - meta_class: "Model.Meta" = attrs.get("Meta", type("Meta", (), {})) + meta_class: Model.Meta = attrs.get("Meta", type("Meta", (), {})) pk_attr: str = "id" # Start searching for fields in the base classes. @@ -636,7 +638,7 @@ def _dispatch_fields(attrs: dict, fields_db_projection: dict, is_abstract) -> tu @staticmethod def build_meta( - meta_class: "Model.Meta", + meta_class: Model.Meta, fields_map: dict[str, Field], fields_db_projection: dict[str, str], filters: dict[str, FilterInfoDict], @@ -831,7 +833,7 @@ def _set_pk_val(self, value: Any) -> None: """ @classmethod - def _validate_relation_type(cls, field_key: str, value: Optional["Model"]) -> None: + def _validate_relation_type(cls, field_key: str, value: Model | None) -> None: if value is None: return @@ -927,31 +929,31 @@ async def _wait_for_listeners(self, signal: Signals, *listener_args) -> None: listeners = [listener(self.__class__, self, *listener_args) for listener in cls_listeners] await asyncio.gather(*listeners) - async def _pre_delete(self, using_db: Optional[BaseDBAsyncClient] = None) -> None: + async def _pre_delete(self, using_db: BaseDBAsyncClient | None = None) -> None: await self._wait_for_listeners(Signals.pre_delete, using_db) - async def _post_delete(self, using_db: Optional[BaseDBAsyncClient] = None) -> None: + async def _post_delete(self, using_db: BaseDBAsyncClient | None = None) -> None: await self._wait_for_listeners(Signals.post_delete, using_db) async def _pre_save( self, - using_db: Optional[BaseDBAsyncClient] = None, - update_fields: Optional[Iterable[str]] = None, + using_db: BaseDBAsyncClient | None = None, + update_fields: Iterable[str] | None = None, ) -> None: await self._wait_for_listeners(Signals.pre_save, using_db, update_fields) async def _post_save( self, - using_db: Optional[BaseDBAsyncClient] = None, + using_db: BaseDBAsyncClient | None = None, created: bool = False, - update_fields: Optional[Iterable[str]] = None, + update_fields: Iterable[str] | None = None, ) -> None: await self._wait_for_listeners(Signals.post_save, created, using_db, update_fields) async def save( self, - using_db: Optional[BaseDBAsyncClient] = None, - update_fields: Optional[Iterable[str]] = None, + using_db: BaseDBAsyncClient | None = None, + update_fields: Iterable[str] | None = None, force_create: bool = False, force_update: bool = False, ) -> None: @@ -1014,7 +1016,7 @@ async def save( self._saved_in_db = True await self._post_save(db, created, update_fields) - async def delete(self, using_db: Optional[BaseDBAsyncClient] = None) -> None: + async def delete(self, using_db: BaseDBAsyncClient | None = None) -> None: """ Deletes the current model object. 
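Editor's note: the `_pre_*`/`_post_*` hooks retyped above all funnel into `_wait_for_listeners`, which starts every registered listener coroutine and awaits them together, so one slow signal handler does not serialize the others. A runnable sketch of that fan-out, with hypothetical names:

```python
from __future__ import annotations

import asyncio
from collections.abc import Awaitable, Callable

Listener = Callable[[str], Awaitable[None]]
listeners: list[Listener] = []


async def on_saved(event: str) -> None:
    await asyncio.sleep(0)  # stand-in for real handler work
    print("handled", event)


listeners.append(on_saved)


async def fire(event: str) -> None:
    # Same shape as: await asyncio.gather(*[listener(...) for listener in cls_listeners])
    await asyncio.gather(*(listener(event) for listener in listeners))


asyncio.run(fire("post_save"))
```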
@@ -1029,7 +1031,7 @@ async def delete(self, using_db: Optional[BaseDBAsyncClient] = None) -> None: await db.executor_class(model=self.__class__, db=db).execute_delete(self) await self._post_delete(db) - async def fetch_related(self, *args: Any, using_db: Optional[BaseDBAsyncClient] = None) -> None: + async def fetch_related(self, *args: Any, using_db: BaseDBAsyncClient | None = None) -> None: """ Fetch related fields. @@ -1045,8 +1047,8 @@ async def fetch_related(self, *args: Any, using_db: Optional[BaseDBAsyncClient] async def refresh_from_db( self, - fields: Optional[Iterable[str]] = None, - using_db: Optional[BaseDBAsyncClient] = None, + fields: Iterable[str] | None = None, + using_db: BaseDBAsyncClient | None = None, ) -> None: """ Refresh latest data from db. When this method is called without arguments @@ -1087,8 +1089,8 @@ def _choose_db(cls, for_write: bool = False) -> BaseDBAsyncClient: @classmethod async def get_or_create( cls, - defaults: Optional[dict] = None, - using_db: Optional[BaseDBAsyncClient] = None, + defaults: dict | None = None, + using_db: BaseDBAsyncClient | None = None, **kwargs: Any, ) -> tuple[Self, bool]: """ @@ -1131,7 +1133,7 @@ async def _create_or_get( @classmethod def _db_queryset( - cls, using_db: Optional[BaseDBAsyncClient] = None, for_write: bool = False + cls, using_db: BaseDBAsyncClient | None = None, for_write: bool = False ) -> QuerySet[Self]: db = using_db or cls._choose_db(for_write) return cls._meta.manager.get_queryset().using_db(db) @@ -1142,7 +1144,7 @@ def select_for_update( nowait: bool = False, skip_locked: bool = False, of: tuple[str, ...] = (), - using_db: Optional[BaseDBAsyncClient] = None, + using_db: BaseDBAsyncClient | None = None, ) -> QuerySet[Self]: """ Make QuerySet select for update. @@ -1155,8 +1157,8 @@ def select_for_update( @classmethod async def update_or_create( cls: type[MODEL], - defaults: Optional[dict] = None, - using_db: Optional[BaseDBAsyncClient] = None, + defaults: dict | None = None, + using_db: BaseDBAsyncClient | None = None, **kwargs: Any, ) -> tuple[MODEL, bool]: """ @@ -1178,7 +1180,7 @@ async def update_or_create( @classmethod async def create( - cls: type[MODEL], using_db: Optional[BaseDBAsyncClient] = None, **kwargs: Any + cls: type[MODEL], using_db: BaseDBAsyncClient | None = None, **kwargs: Any ) -> MODEL: """ Create a record in the DB and returns the object. @@ -1208,9 +1210,9 @@ def bulk_update( cls: type[MODEL], objects: Iterable[MODEL], fields: Iterable[str], - batch_size: Optional[int] = None, - using_db: Optional[BaseDBAsyncClient] = None, - ) -> "BulkUpdateQuery[MODEL]": + batch_size: int | None = None, + using_db: BaseDBAsyncClient | None = None, + ) -> BulkUpdateQuery[MODEL]: """ Update the given fields in each of the given objects in the database. This method efficiently updates the given fields on the provided model instances, generally with one query. 
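Editor's note: the recurring `using_db: BaseDBAsyncClient | None = None` parameters modernized above all follow one convention, visible in `_db_queryset`: accept an optionally injected connection (for example, one bound to a transaction) and fall back to the model's default otherwise. A sketch with hypothetical stand-ins (`Connection`, `run_query`):

```python
from __future__ import annotations


class Connection:
    def __init__(self, name: str) -> None:
        self.name = name


default_connection = Connection("default")


def run_query(sql: str, using_db: Connection | None = None) -> str:
    # Same fallback shape as Model._db_queryset:
    #   db = using_db or cls._choose_db(for_write)
    db = using_db or default_connection
    return f"{db.name}: {sql}"


print(run_query("SELECT 1"))                             # default: SELECT 1
print(run_query("SELECT 1", using_db=Connection("tx")))  # tx: SELECT 1
```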
@@ -1236,9 +1238,9 @@ def bulk_update( @classmethod async def in_bulk( cls: type[MODEL], - id_list: Iterable[Union[str, int]], + id_list: Iterable[str | int], field_name: str = "pk", - using_db: Optional[BaseDBAsyncClient] = None, + using_db: BaseDBAsyncClient | None = None, ) -> dict[str, MODEL]: """ Return a dictionary mapping each of the given IDs to the object with @@ -1254,12 +1256,12 @@ async def in_bulk( def bulk_create( cls: type[MODEL], objects: Iterable[MODEL], - batch_size: Optional[int] = None, + batch_size: int | None = None, ignore_conflicts: bool = False, - update_fields: Optional[Iterable[str]] = None, - on_conflict: Optional[Iterable[str]] = None, - using_db: Optional[BaseDBAsyncClient] = None, - ) -> "BulkCreateQuery[MODEL]": + update_fields: Iterable[str] | None = None, + on_conflict: Iterable[str] | None = None, + using_db: BaseDBAsyncClient | None = None, + ) -> BulkCreateQuery[MODEL]: """ Bulk insert operation: @@ -1292,14 +1294,14 @@ def bulk_create( ) @classmethod - def first(cls, using_db: Optional[BaseDBAsyncClient] = None) -> QuerySetSingle[Optional[Self]]: + def first(cls, using_db: BaseDBAsyncClient | None = None) -> QuerySetSingle[Self | None]: """ Generates a QuerySet that returns the first record. """ return cls._db_queryset(using_db).first() @classmethod - def last(cls, using_db: Optional[BaseDBAsyncClient] = None) -> QuerySetSingle[Optional[Self]]: + def last(cls, using_db: BaseDBAsyncClient | None = None) -> QuerySetSingle[Self | None]: """ Generates a QuerySet that returns the last record. """ @@ -1316,7 +1318,7 @@ def filter(cls, *args: Q, **kwargs: Any) -> QuerySet[Self]: return cls._meta.manager.get_queryset().filter(*args, **kwargs) @classmethod - def latest(cls, *orderings: str) -> QuerySetSingle[Optional[Self]]: + def latest(cls, *orderings: str) -> QuerySetSingle[Self | None]: """ Generates a QuerySet with the filter applied that returns the last record. @@ -1325,7 +1327,7 @@ def latest(cls, *orderings: str) -> QuerySetSingle[Optional[Self]]: return cls._meta.manager.get_queryset().latest(*orderings) @classmethod - def earliest(cls, *orderings: str) -> QuerySetSingle[Optional[Self]]: + def earliest(cls, *orderings: str) -> QuerySetSingle[Self | None]: """ Generates a QuerySet with the filter applied that returns the first record. @@ -1344,7 +1346,7 @@ def exclude(cls, *args: Q, **kwargs: Any) -> QuerySet[Self]: return cls._meta.manager.get_queryset().exclude(*args, **kwargs) @classmethod - def annotate(cls, **kwargs: Union[Expression, Term]) -> QuerySet[Self]: + def annotate(cls, **kwargs: Expression | Term) -> QuerySet[Self]: """ Annotates the result set with extra Functions/Aggregations/Expressions. @@ -1353,7 +1355,7 @@ def annotate(cls, **kwargs: Union[Expression, Term]) -> QuerySet[Self]: return cls._meta.manager.get_queryset().annotate(**kwargs) @classmethod - def all(cls, using_db: Optional[BaseDBAsyncClient] = None) -> QuerySet[Self]: + def all(cls, using_db: BaseDBAsyncClient | None = None) -> QuerySet[Self]: """ Returns the complete QuerySet. """ @@ -1361,7 +1363,7 @@ def all(cls, using_db: Optional[BaseDBAsyncClient] = None) -> QuerySet[Self]: @classmethod def get( - cls, *args: Q, using_db: Optional[BaseDBAsyncClient] = None, **kwargs: Any + cls, *args: Q, using_db: BaseDBAsyncClient | None = None, **kwargs: Any ) -> QuerySetSingle[Self]: """ Fetches a single record for a Model type using the provided filter parameters. 
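Editor's note: `first()`/`last()` above now return `QuerySetSingle[Self | None]`; `Self` preserves the subclass, so a checker sees `Book.first()` as `Book | None` rather than `Model | None`. A sketch of the idea, using the same version-gated import seen earlier in `descriptions.py` (the `typing_extensions` fallback is an assumption here):

```python
from __future__ import annotations

import sys

if sys.version_info >= (3, 11):
    from typing import Self
else:
    from typing_extensions import Self  # assumed available as a dependency


class Record:
    _rows: list[Record] = []

    @classmethod
    def first(cls) -> Self | None:
        # First stored instance of the *calling* class, or None.
        return next((r for r in cls._rows if isinstance(r, cls)), None)


class Book(Record):
    pass


Record._rows.append(Book())
book = Book.first()         # a checker infers: Book | None
print(type(book).__name__)  # Book
```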
@@ -1380,7 +1382,7 @@ def get(
         return cls._db_queryset(using_db).get(*args, **kwargs)

     @classmethod
-    def raw(cls, sql: str, using_db: Optional[BaseDBAsyncClient] = None) -> "RawSQLQuery":
+    def raw(cls, sql: str, using_db: BaseDBAsyncClient | None = None) -> RawSQLQuery:
         """
         Executes a RAW SQL and returns the result

@@ -1395,7 +1397,7 @@ def raw(cls, sql: str, using_db: Optional[BaseDBAsyncClient] = None) -> "RawSQLQ
     @classmethod
     def exists(
-        cls: type[MODEL], *args: Q, using_db: Optional[BaseDBAsyncClient] = None, **kwargs: Any
+        cls: type[MODEL], *args: Q, using_db: BaseDBAsyncClient | None = None, **kwargs: Any
     ) -> ExistsQuery:
         """
         Return True/False whether record exists with the provided filter parameters.
@@ -1412,8 +1414,8 @@ def exists(
     @classmethod
     def get_or_none(
-        cls, *args: Q, using_db: Optional[BaseDBAsyncClient] = None, **kwargs: Any
-    ) -> QuerySetSingle[Optional[Self]]:
+        cls, *args: Q, using_db: BaseDBAsyncClient | None = None, **kwargs: Any
+    ) -> QuerySetSingle[Self | None]:
         """
         Fetches a single record for a Model type using the provided filter parameters or None.

@@ -1430,9 +1432,9 @@ def get_or_none(
     @classmethod
     async def fetch_for_list(
         cls,
-        instance_list: "Iterable[Model]",
+        instance_list: Iterable[Model],
         *args: Any,
-        using_db: Optional[BaseDBAsyncClient] = None,
+        using_db: BaseDBAsyncClient | None = None,
     ) -> None:
         """
         Fetches related models for provided list of Model objects.
@@ -1489,8 +1491,8 @@ def _check_together(cls, together: str) -> None:
     @classmethod
     def _describe_index(
-        cls, index: Union[Index, tuple[str, ...]], serializable: bool
-    ) -> Union[Index, tuple[str, ...], dict]:
+        cls, index: Index | tuple[str, ...], serializable: bool
+    ) -> Index | tuple[str, ...] | dict:
         if isinstance(index, Index):
             return index.describe() if serializable else index
diff --git a/tortoise/query_utils.py b/tortoise/query_utils.py
index 6cfe38640..5a7235cd9 100644
--- a/tortoise/query_utils.py
+++ b/tortoise/query_utils.py
@@ -1,7 +1,7 @@
 from __future__ import annotations

 from copy import copy
-from typing import TYPE_CHECKING, Optional, cast
+from typing import TYPE_CHECKING, cast

 from pypika_tortoise import Table
 from pypika_tortoise.terms import Criterion, Term
@@ -79,8 +79,8 @@ def get_joins_for_related_field(


 def resolve_nested_field(
-    model: type["Model"], table: Table, field: str
-) -> tuple[Term, list[TableCriterionTuple], Optional[Field]]:
+    model: type[Model], table: Table, field: str
+) -> tuple[Term, list[TableCriterionTuple], Field | None]:
     """
     Resolves a nested field string like events__participants__name and
     returns the pypika term, required joins and the Field that can be used for
@@ -164,9 +164,9 @@ class QueryModifier:

     def __init__(
         self,
-        where_criterion: Optional[Criterion] = None,
-        joins: Optional[list[TableCriterionTuple]] = None,
-        having_criterion: Optional[Criterion] = None,
+        where_criterion: Criterion | None = None,
+        joins: list[TableCriterionTuple] | None = None,
+        having_criterion: Criterion | None = None,
     ) -> None:
         self.where_criterion: Criterion = where_criterion or EmptyCriterion()
         self.joins = joins or []
@@ -215,13 +215,13 @@ class Prefetch:

     __slots__ = ("relation", "queryset", "to_attr")

-    def __init__(self, relation: str, queryset: "QuerySet", to_attr: Optional[str] = None) -> None:
+    def __init__(self, relation: str, queryset: QuerySet, to_attr: str | None = None) -> None:
         self.to_attr = to_attr
         self.relation = relation
         self.queryset = queryset
         self.queryset.query = copy(self.queryset.model._meta.basequery)

-    def resolve_for_queryset(self, queryset: "QuerySet") -> None:
+    def resolve_for_queryset(self, queryset: QuerySet) -> None:
         """
         Called internally to generate prefetching query.
diff --git a/tortoise/queryset.py b/tortoise/queryset.py
index 2d1fc4d42..27da2e10b 100644
--- a/tortoise/queryset.py
+++ b/tortoise/queryset.py
@@ -1,7 +1,9 @@
+from __future__ import annotations
+
 import types
 from collections.abc import AsyncIterator, Callable, Generator, Iterable
 from copy import copy
-from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast, overload
+from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, cast, overload

 from pypika_tortoise import JoinType, Order, Table
 from pypika_tortoise.analytics import Count
@@ -55,24 +57,24 @@ class QuerySetSingle(Protocol[T_co]):
     def __await__(self) -> Generator[Any, None, T_co]: ...  # pragma: nocoverage

     def prefetch_related(
-        self, *args: Union[str, Prefetch]
-    ) -> "QuerySetSingle[T_co]": ...  # pragma: nocoverage
+        self, *args: str | Prefetch
+    ) -> QuerySetSingle[T_co]: ...  # pragma: nocoverage

-    def select_related(self, *args: str) -> "QuerySetSingle[T_co]": ...  # pragma: nocoverage
+    def select_related(self, *args: str) -> QuerySetSingle[T_co]: ...  # pragma: nocoverage

     def annotate(
-        self, **kwargs: Union[Expression, Term]
-    ) -> "QuerySetSingle[T_co]": ...  # pragma: nocoverage
+        self, **kwargs: Expression | Term
+    ) -> QuerySetSingle[T_co]: ...  # pragma: nocoverage

-    def only(self, *fields_for_select: str) -> "QuerySetSingle[T_co]": ...  # pragma: nocoverage
+    def only(self, *fields_for_select: str) -> QuerySetSingle[T_co]: ...  # pragma: nocoverage

     def values_list(
         self, *fields_: str, flat: bool = False
-    ) -> "ValuesListQuery[Literal[True]]": ...  # pragma: nocoverage
+    ) -> ValuesListQuery[Literal[True]]: ...  # pragma: nocoverage

     def values(
         self, *args: str, **kwargs: str
-    ) -> "ValuesQuery[Literal[True]]": ...  # pragma: nocoverage
+    ) -> ValuesQuery[Literal[True]]: ...  # pragma: nocoverage


 class AwaitableQuery(Generic[MODEL]):
@@ -89,11 +91,11 @@ class AwaitableQuery(Generic[MODEL]):
     def __init__(self, model: type[MODEL]) -> None:
         self._joined_tables: list[Table] = []
-        self.model: "type[MODEL]" = model
+        self.model: type[MODEL] = model
         self.query: QueryBuilder = QUERY
         self._db: BaseDBAsyncClient = None  # type: ignore
         self.capabilities: Capabilities = model._meta.db.capabilities
-        self._annotations: dict[str, Union[Expression, Term]] = {}
+        self._annotations: dict[str, Expression | Term] = {}
         self._custom_filters: dict[str, FilterInfoDict] = {}
         self._q_objects: list[Q] = []

@@ -174,9 +176,9 @@ def _resolve_ordering_string(ordering: str, reverse: bool = False) -> tuple[str,

     def resolve_ordering(
         self,
-        model: "type[Model]",
+        model: type[Model],
         table: Table,
-        orderings: Iterable[tuple[str, Union[str, Order]]],
+        orderings: Iterable[tuple[str, str | Order]],
         annotations: dict[str, Any],
     ) -> None:
         """
@@ -319,12 +321,12 @@ class QuerySet(AwaitableQuery[MODEL]):
     def __init__(self, model: type[MODEL]) -> None:
         super().__init__(model)
         self.fields: set[str] = model._meta.db_fields
-        self._prefetch_map: dict[str, set[Union[str, Prefetch]]] = {}
-        self._prefetch_queries: dict[str, list[tuple[Optional[str], QuerySet]]] = {}
+        self._prefetch_map: dict[str, set[str | Prefetch]] = {}
+        self._prefetch_queries: dict[str, list[tuple[str | None, QuerySet]]] = {}
         self._single: bool = False
         self._raise_does_not_exist: bool = False
-        self._limit: Optional[int] = None
-        self._offset: Optional[int] = None
+        self._limit: int | None = None
+        self._offset: int | None = None
         self._filter_kwargs: dict[str, Any] = {}
         self._orderings: list[tuple[str, Any]] = []
         self._distinct: bool = False
@@ -337,12 +339,12 @@ def __init__(self, model: type[MODEL]) -> None:
         self._select_for_update_of: set[str] = set()
         self._select_related: set[str] = set()
         self._select_related_idx: list[
-            tuple["type[Model]", int, Union[Table, str], "type[Model]", Iterable[Optional[str]]]
+            tuple[type[Model], int, Table | str, type[Model], Iterable[str | None]]
         ] = []  # format with: model,idx,model_name,parent_model
         self._force_indexes: set[str] = set()
         self._use_indexes: set[str] = set()

-    def _clone(self) -> "QuerySet[MODEL]":
+    def _clone(self) -> QuerySet[MODEL]:
         queryset = self.__class__.__new__(self.__class__)
         queryset.fields = self.fields
         queryset.model = self.model
@@ -375,7 +377,7 @@ def _clone(self) -> "QuerySet[MODEL]":
         queryset._use_indexes = self._use_indexes
         return queryset

-    def _filter_or_exclude(self, *args: Q, negate: bool, **kwargs: Any) -> "QuerySet[MODEL]":
+    def _filter_or_exclude(self, *args: Q, negate: bool, **kwargs: Any) -> QuerySet[MODEL]:
         queryset = self._clone()
         for arg in args:
             if not isinstance(arg, Q):
@@ -393,7 +395,7 @@ def _filter_or_exclude(self, *args: Q, negate: bool, **kwargs: Any) -> "QuerySet
         return queryset

-    def filter(self, *args: Q, **kwargs: Any) -> "QuerySet[MODEL]":
+    def filter(self, *args: Q, **kwargs: Any) -> QuerySet[MODEL]:
         """
         Filters QuerySet by given kwargs.
         You can filter by related objects like this:
@@ -405,7 +407,7 @@ def filter(self, *args: Q, **kwargs: Any) -> "QuerySet[MODEL]":
         """
         return self._filter_or_exclude(negate=False, *args, **kwargs)

-    def exclude(self, *args: Q, **kwargs: Any) -> "QuerySet[MODEL]":
+    def exclude(self, *args: Q, **kwargs: Any) -> QuerySet[MODEL]:
         """
         Same as .filter(), but with appends all args with NOT
         """
@@ -433,7 +435,7 @@ def _parse_orderings(
             new_ordering.append((field_name, order_type))
         return new_ordering

-    def order_by(self, *orderings: str) -> "QuerySet[MODEL]":
+    def order_by(self, *orderings: str) -> QuerySet[MODEL]:
         """
         Accept args to filter by in format like this:

@@ -450,12 +452,12 @@ def order_by(self, *orderings: str) -> "QuerySet[MODEL]":
         queryset._orderings = self._parse_orderings(orderings)
         return queryset

-    def _as_single(self) -> QuerySetSingle[Optional[MODEL]]:
+    def _as_single(self) -> QuerySetSingle[MODEL | None]:
         self._single = True
         self._limit = 1
         return cast(QuerySetSingle[Optional[MODEL]], self)

-    def latest(self, *orderings: str) -> QuerySetSingle[Optional[MODEL]]:
+    def latest(self, *orderings: str) -> QuerySetSingle[MODEL | None]:
         """
         Returns the most recent object by ordering descending on the providers fields.

@@ -469,7 +471,7 @@ def latest(self, *orderings: str) -> QuerySetSingle[Optional[MODEL]]:
         queryset._orderings = self._parse_orderings(orderings, reverse=True)
         return queryset._as_single()

-    def earliest(self, *orderings: str) -> QuerySetSingle[Optional[MODEL]]:
+    def earliest(self, *orderings: str) -> QuerySetSingle[MODEL | None]:
         """
         Returns the earliest object by ordering ascending on the specified field.

@@ -483,7 +485,7 @@ def earliest(self, *orderings: str) -> QuerySetSingle[Optional[MODEL]]:
         queryset._orderings = self._parse_orderings(orderings)
         return queryset._as_single()

-    def limit(self, limit: int) -> "QuerySet[MODEL]":
+    def limit(self, limit: int) -> QuerySet[MODEL]:
         """
         Limits QuerySet to given length.

@@ -496,7 +498,7 @@ def limit(self, limit: int) -> "QuerySet[MODEL]":
         queryset._limit = limit
         return queryset

-    def offset(self, offset: int) -> "QuerySet[MODEL]":
+    def offset(self, offset: int) -> QuerySet[MODEL]:
         """
         Query offset for QuerySet.

@@ -511,7 +513,7 @@ def offset(self, offset: int) -> "QuerySet[MODEL]":
             queryset._limit = 1000000
         return queryset

-    def __getitem__(self, key: slice) -> "QuerySet[MODEL]":
+    def __getitem__(self, key: slice) -> QuerySet[MODEL]:
         """
         Query offset and limit for Queryset.

@@ -544,7 +546,7 @@ def __getitem__(self, key: slice) -> "QuerySet[MODEL]":
             queryset = queryset.limit(key.stop - start)
         return queryset

-    def distinct(self) -> "QuerySet[MODEL]":
+    def distinct(self) -> QuerySet[MODEL]:
         """
         Make QuerySet distinct.

@@ -557,7 +559,7 @@ def distinct(self) -> "QuerySet[MODEL]":

     def select_for_update(
         self, nowait: bool = False, skip_locked: bool = False, of: tuple[str, ...] = ()
-    ) -> "QuerySet[MODEL]":
+    ) -> QuerySet[MODEL]:
         """
         Make QuerySet select for update.

@@ -573,7 +575,7 @@ def select_for_update(
             return queryset
         return self

-    def annotate(self, **kwargs: Union[Expression, Term]) -> "QuerySet[MODEL]":
+    def annotate(self, **kwargs: Expression | Term) -> QuerySet[MODEL]:
         """
         Annotate result with aggregation or function result.
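The quoted forward references being dropped throughout this file (e.g. "QuerySet[MODEL]" becoming QuerySet[MODEL]) rest on the same future import: under postponed evaluation a method may name its own, not-yet-complete class directly in an annotation. A tiny self-contained sketch of that pattern (the Box class is invented purely for illustration):

from __future__ import annotations


class Box:
    def __init__(self, value: int) -> None:
        self.value = value

    # Without the future import this annotation would need quotes, because
    # Box is not fully defined while its class body executes.
    def doubled(self) -> Box:
        return Box(self.value * 2)


print(Box(3).doubled().value)  # -> 6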
@@ -589,7 +591,7 @@ def annotate(self, **kwargs: Union[Expression, Term]) -> "QuerySet[MODEL]":
             queryset._custom_filters.update(get_filters_for_field(key, None, key))
         return queryset

-    def group_by(self, *fields: str) -> "QuerySet[MODEL]":
+    def group_by(self, *fields: str) -> QuerySet[MODEL]:
         """
         Make QuerySet returns list of dict or tuple with group by.

@@ -599,7 +601,7 @@ def group_by(self, *fields: str) -> "QuerySet[MODEL]":
         queryset._group_bys = fields
         return queryset

-    def values_list(self, *fields_: str, flat: bool = False) -> "ValuesListQuery[Literal[False]]":
+    def values_list(self, *fields_: str, flat: bool = False) -> ValuesListQuery[Literal[False]]:
         """
         Make QuerySet returns list of tuples for given args instead of objects.

@@ -632,7 +634,7 @@ def values_list(self, *fields_: str, flat: bool = False) -> "ValuesListQuery[Lit
             use_indexes=self._use_indexes,
         )

-    def values(self, *args: str, **kwargs: str) -> "ValuesQuery[Literal[False]]":
+    def values(self, *args: str, **kwargs: str) -> ValuesQuery[Literal[False]]:
         """
         Make QuerySet return dicts instead of objects.

@@ -682,7 +684,7 @@ def values(self, *args: str, **kwargs: str) -> "ValuesQuery[Literal[False]]":
             use_indexes=self._use_indexes,
         )

-    def delete(self) -> "DeleteQuery":
+    def delete(self) -> DeleteQuery:
         """
         Delete all objects in QuerySet.
         """
@@ -696,7 +698,7 @@ def delete(self) -> "DeleteQuery":
             orderings=self._orderings,
         )

-    def update(self, **kwargs: Any) -> "UpdateQuery":
+    def update(self, **kwargs: Any) -> UpdateQuery:
         """
         Update all objects in QuerySet with given kwargs.

@@ -719,7 +721,7 @@ def update(self, **kwargs: Any) -> "UpdateQuery":
             orderings=self._orderings,
         )

-    def count(self) -> "CountQuery":
+    def count(self) -> CountQuery:
         """
         Return count of objects in queryset instead of objects.
         """
@@ -735,7 +737,7 @@ def count(self) -> "CountQuery":
             use_indexes=self._use_indexes,
         )

-    def exists(self) -> "ExistsQuery":
+    def exists(self) -> ExistsQuery:
         """
         Return True/False whether queryset exists.
         """
@@ -749,27 +751,27 @@ def exists(self) -> "ExistsQuery":
             use_indexes=self._use_indexes,
         )

-    def all(self) -> "QuerySet[MODEL]":
+    def all(self) -> QuerySet[MODEL]:
         """
         Return the whole QuerySet.
         Essentially a no-op except as the only operation.
         """
         return self._clone()

-    def raw(self, sql: str) -> "RawSQLQuery":
+    def raw(self, sql: str) -> RawSQLQuery:
         """
         Return the QuerySet from raw SQL
         """
         return RawSQLQuery(model=self.model, db=self._db, sql=sql)

-    def first(self) -> QuerySetSingle[Optional[MODEL]]:
+    def first(self) -> QuerySetSingle[MODEL | None]:
         """
         Limit queryset to one object and return one object instead of list.
         """
         queryset = self._clone()
         return queryset._as_single()

-    def last(self) -> QuerySetSingle[Optional[MODEL]]:
+    def last(self) -> QuerySetSingle[MODEL | None]:
         """
         Limit queryset to one object and return the last object instead of list.
         """
@@ -799,9 +801,7 @@ def get(self, *args: Q, **kwargs: Any) -> QuerySetSingle[MODEL]:
             queryset._raise_does_not_exist = True
         return queryset  # type: ignore

-    async def in_bulk(
-        self, id_list: Iterable[Union[str, int]], field_name: str
-    ) -> dict[str, MODEL]:
+    async def in_bulk(self, id_list: Iterable[str | int], field_name: str) -> dict[str, MODEL]:
         """
         Return a dictionary mapping each of the given IDs to the object with
         that ID. If `id_list` isn't provided, evaluate the entire QuerySet.
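For readers unfamiliar with the in_bulk() contract restated in this hunk, here is a rough pure-Python analogue (illustrative only; Tortoise's real implementation issues a filtered SQL query rather than scanning rows):

def in_bulk(rows: list[dict], id_list: list, field_name: str = "pk") -> dict:
    # Map each requested key to the row whose `field_name` value matches it,
    # silently skipping keys that have no match.
    index = {row[field_name]: row for row in rows}
    return {key: index[key] for key in id_list if key in index}


rows = [{"pk": 1, "name": "a"}, {"pk": 2, "name": "b"}]
print(in_bulk(rows, [2, 3]))  # -> {2: {'pk': 2, 'name': 'b'}}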
@@ -815,11 +815,11 @@ async def in_bulk(
     def bulk_create(
         self,
         objects: Iterable[MODEL],
-        batch_size: Optional[int] = None,
+        batch_size: int | None = None,
         ignore_conflicts: bool = False,
-        update_fields: Optional[Iterable[str]] = None,
-        on_conflict: Optional[Iterable[str]] = None,
-    ) -> "BulkCreateQuery[MODEL]":
+        update_fields: Iterable[str] | None = None,
+        on_conflict: Iterable[str] | None = None,
+    ) -> BulkCreateQuery[MODEL]:
         """
         This method inserts the provided list of objects into the database
         in an efficient manner (generally only 1 query, no matter how many objects there are).
@@ -853,8 +853,8 @@ def bulk_update(
         self,
         objects: Iterable[MODEL],
         fields: Iterable[str],
-        batch_size: Optional[int] = None,
-    ) -> "BulkUpdateQuery[MODEL]":
+        batch_size: int | None = None,
+    ) -> BulkUpdateQuery[MODEL]:
         """
         Update the given fields in each of the given objects in the database.

@@ -879,7 +879,7 @@ def bulk_update(
             batch_size=batch_size,
         )

-    def get_or_none(self, *args: Q, **kwargs: Any) -> QuerySetSingle[Optional[MODEL]]:
+    def get_or_none(self, *args: Q, **kwargs: Any) -> QuerySetSingle[MODEL | None]:
         """
         Fetch exactly one object matching the parameters.
         """
@@ -888,7 +888,7 @@ def get_or_none(self, *args: Q, **kwargs: Any) -> QuerySetSingle[Optional[MODEL]
             queryset._single = True
         return queryset  # type: ignore

-    def only(self, *fields_for_select: str) -> "QuerySet[MODEL]":
+    def only(self, *fields_for_select: str) -> QuerySet[MODEL]:
         """
         Fetch ONLY the specified fields to create a partial model.

@@ -910,7 +910,7 @@ def only(self, *fields_for_select: str) -> "QuerySet[MODEL]":
         queryset._fields_for_select = fields_for_select
         return queryset

-    def select_related(self, *fields: str) -> "QuerySet[MODEL]":
+    def select_related(self, *fields: str) -> QuerySet[MODEL]:
         """
         Return a new QuerySet instance that will select related objects.

@@ -923,7 +923,7 @@ def select_related(self, *fields: str) -> "QuerySet[MODEL]":
             queryset._select_related.add(field)
         return queryset

-    def force_index(self, *index_names: str) -> "QuerySet[MODEL]":
+    def force_index(self, *index_names: str) -> QuerySet[MODEL]:
         """
         The FORCE INDEX hint acts like USE INDEX (index_list),
         with the addition that a table scan is assumed to be very expensive.
@@ -935,7 +935,7 @@ def force_index(self, *index_names: str) -> "QuerySet[MODEL]":
             return queryset
         return self

-    def use_index(self, *index_names: str) -> "QuerySet[MODEL]":
+    def use_index(self, *index_names: str) -> QuerySet[MODEL]:
         """
         The USE INDEX (index_list) hint tells MySQL to use only one of the named indexes to find rows in the table.
         """
@@ -946,7 +946,7 @@ def use_index(self, *index_names: str) -> "QuerySet[MODEL]":
             return queryset
         return self

-    def prefetch_related(self, *args: Union[str, Prefetch]) -> "QuerySet[MODEL]":
+    def prefetch_related(self, *args: str | Prefetch) -> QuerySet[MODEL]:
         """
         Like ``.fetch_related()`` on instance, but works on all objects in QuerySet.

@@ -996,7 +996,7 @@ async def explain(self) -> Any:
             self.query.get_sql()
         )

-    def using_db(self, _db: Optional[BaseDBAsyncClient]) -> "QuerySet[MODEL]":
+    def using_db(self, _db: BaseDBAsyncClient | None) -> QuerySet[MODEL]:
         """
         Executes query in provided db client.
         Useful for transactions workaround.
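The `_db: BaseDBAsyncClient | None` parameter of using_db() above follows the "explicit client, else fall back to a default" idiom that recurs through this file. A standalone sketch of the shape of that pattern (all names invented; the real code resolves the default connection differently):

from __future__ import annotations


class Client:
    def __init__(self, name: str) -> None:
        self.name = name


DEFAULT_CLIENT = Client("default")


def choose_db(using_db: Client | None = None) -> Client:
    # Prefer the explicitly supplied client; otherwise use the default.
    return using_db if using_db is not None else DEFAULT_CLIENT


print(choose_db().name)              # -> default
print(choose_db(Client("tx")).name)  # -> tx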
@@ -1007,12 +1007,12 @@ def using_db(self, _db: Optional[BaseDBAsyncClient]) -> "QuerySet[MODEL]":

     def _join_table_with_select_related(
         self,
-        model: "type[Model]",
+        model: type[Model],
         table: Table,
         field: str,
         forwarded_fields: str,
-        path: Iterable[Optional[str]],
-    ) -> "QueryBuilder":
+        path: Iterable[str | None],
+    ) -> QueryBuilder:
         if field in model._meta.fields_db_projection and forwarded_fields:
             raise FieldError(f'Field "{field}" for model "{model.__name__}" is not relation')

@@ -1158,7 +1158,7 @@ def __init__(
         q_objects: list[Q],
         annotations: dict[str, Any],
         custom_filters: dict[str, FilterInfoDict],
-        limit: Optional[int],
+        limit: int | None,
         orderings: list[tuple[str, str]],
     ) -> None:
         super().__init__(model)
@@ -1236,7 +1236,7 @@ def __init__(
         q_objects: list[Q],
         annotations: dict[str, Any],
         custom_filters: dict[str, FilterInfoDict],
-        limit: Optional[int],
+        limit: int | None,
         orderings: list[tuple[str, str]],
     ) -> None:
         super().__init__(model)
@@ -1334,8 +1334,8 @@ def __init__(
         q_objects: list[Q],
         annotations: dict[str, Any],
         custom_filters: dict[str, FilterInfoDict],
-        limit: Optional[int],
-        offset: Optional[int],
+        limit: int | None,
+        offset: int | None,
         force_indexes: set[str],
         use_indexes: set[str],
     ) -> None:
@@ -1400,8 +1400,8 @@ def _join_table_with_forwarded_fields(

         if field in self.model._meta.fetch_fields and not forwarded_fields:
             raise ValueError(
-                'Selecting relation "{}" is not possible, select concrete '
-                "field on related model".format(field)
+                f'Selecting relation "{field}" is not possible, select concrete '
+                "field on related model"
             )

         field_object = cast(RelationalField, model._meta.fields_map.get(field))
@@ -1432,8 +1432,8 @@ def add_field_to_select_query(self, field: str, return_as: str) -> None:

         if field in self.model._meta.fetch_fields:
             raise ValueError(
-                'Selecting relation "{}" is not possible, select '
-                "concrete field on related model".format(field)
+                f'Selecting relation "{field}" is not possible, select '
+                "concrete field on related model"
             )

         field_, __, forwarded_fields = field.partition("__")
@@ -1517,9 +1517,9 @@ def __init__(
         q_objects: list[Q],
         single: bool,
         raise_does_not_exist: bool,
-        fields_for_select_list: Union[tuple[str, ...], list[str]],
-        limit: Optional[int],
-        offset: Optional[int],
+        fields_for_select_list: tuple[str, ...] | list[str],
+        limit: int | None,
+        offset: int | None,
         distinct: bool,
         orderings: list[tuple[str, str]],
         flat: bool,
@@ -1582,24 +1582,24 @@ def _make_query(self) -> None:

     @overload
     def __await__(
-        self: "ValuesListQuery[Literal[False]]",
+        self: ValuesListQuery[Literal[False]],
     ) -> Generator[Any, None, list[tuple[Any, ...]]]: ...

     @overload
     def __await__(
-        self: "ValuesListQuery[Literal[True]]",
+        self: ValuesListQuery[Literal[True]],
     ) -> Generator[Any, None, tuple[Any, ...]]: ...
-    def __await__(self) -> Generator[Any, None, Union[list[Any], tuple[Any, ...]]]:
+    def __await__(self) -> Generator[Any, None, list[Any] | tuple[Any, ...]]:
         self._choose_db_if_not_chosen()
         self._make_query()
         return self._execute().__await__()  # pylint: disable=E1101

-    async def __aiter__(self: "ValuesListQuery[Any]") -> AsyncIterator[Any]:
+    async def __aiter__(self: ValuesListQuery[Any]) -> AsyncIterator[Any]:
         for val in await self:
             yield val

-    async def _execute(self) -> Union[list[Any], tuple]:
+    async def _execute(self) -> list[Any] | tuple:
         _, result = await self._db.execute_query(*self.query.get_parameterized_sql())
         columns = [
             (key, self.resolve_to_python_value(self.model, name))
@@ -1646,8 +1646,8 @@ def __init__(
         single: bool,
         raise_does_not_exist: bool,
         fields_for_select: dict[str, str],
-        limit: Optional[int],
-        offset: Optional[int],
+        limit: int | None,
+        offset: int | None,
         distinct: bool,
         orderings: list[tuple[str, str]],
         annotations: dict[str, Any],
@@ -1709,26 +1709,26 @@ def _make_query(self) -> None:

     @overload
     def __await__(
-        self: "ValuesQuery[Literal[False]]",
+        self: ValuesQuery[Literal[False]],
     ) -> Generator[Any, None, list[dict[str, Any]]]: ...

     @overload
     def __await__(
-        self: "ValuesQuery[Literal[True]]",
+        self: ValuesQuery[Literal[True]],
     ) -> Generator[Any, None, dict[str, Any]]: ...

     def __await__(
         self,
-    ) -> Generator[Any, None, Union[list[dict[str, Any]], dict[str, Any]]]:
+    ) -> Generator[Any, None, list[dict[str, Any]] | dict[str, Any]]:
         self._choose_db_if_not_chosen()
         self._make_query()
         return self._execute().__await__()  # pylint: disable=E1101

-    async def __aiter__(self: "ValuesQuery[Any]") -> AsyncIterator[dict[str, Any]]:
+    async def __aiter__(self: ValuesQuery[Any]) -> AsyncIterator[dict[str, Any]]:
         for val in await self:
             yield val

-    async def _execute(self) -> Union[list[dict], dict]:
+    async def _execute(self) -> list[dict] | dict:
         result = await self._db.execute_query_dict(*self.query.get_parameterized_sql())
         columns = [
             val
@@ -1785,11 +1785,11 @@ def __init__(
         q_objects: list[Q],
         annotations: dict[str, Any],
         custom_filters: dict[str, FilterInfoDict],
-        limit: Optional[int],
+        limit: int | None,
         orderings: list[tuple[str, str]],
         objects: Iterable[MODEL],
         fields: Iterable[str],
-        batch_size: Optional[int] = None,
+        batch_size: int | None = None,
     ):
         super().__init__(
             model,
@@ -1883,10 +1883,10 @@ def __init__(
         model: type[MODEL],
         db: BaseDBAsyncClient,
         objects: Iterable[MODEL],
-        batch_size: Optional[int] = None,
+        batch_size: int | None = None,
         ignore_conflicts: bool = False,
-        update_fields: Optional[Iterable[str]] = None,
-        on_conflict: Optional[Iterable[str]] = None,
+        update_fields: Iterable[str] | None = None,
+        on_conflict: Iterable[str] | None = None,
     ):
         super().__init__(model)
         self._objects = objects
diff --git a/tortoise/router.py b/tortoise/router.py
index e428dc052..34946ab0e 100644
--- a/tortoise/router.py
+++ b/tortoise/router.py
@@ -17,7 +17,7 @@ def __init__(self) -> None:
     def init_routers(self, routers: list[Callable]) -> None:
         self._routers = [r() for r in routers]

-    def _router_func(self, model: type["Model"], action: str) -> Any:
+    def _router_func(self, model: type[Model], action: str) -> Any:
         for r in self._routers:
             try:
                 method = getattr(r, action)
@@ -29,16 +29,16 @@ def _router_func(self, model: type["Model"], action: str) -> Any:
             if chosen_db:
                 return chosen_db

-    def _db_route(self, model: type["Model"], action: str) -> "BaseDBAsyncClient" | None:
+    def _db_route(self, model: type[Model], action: str) -> BaseDBAsyncClient | None:
         try:
             return connections.get(self._router_func(model, action))
         except ConfigurationError:
             return None

-    def db_for_read(self, model: type["Model"]) -> "BaseDBAsyncClient" | None:
+    def db_for_read(self, model: type[Model]) -> BaseDBAsyncClient | None:
         return self._db_route(model, "db_for_read")

-    def db_for_write(self, model: type["Model"]) -> "BaseDBAsyncClient" | None:
+    def db_for_write(self, model: type[Model]) -> BaseDBAsyncClient | None:
         return self._db_route(model, "db_for_write")
diff --git a/tortoise/signals.py b/tortoise/signals.py
index a0aeb60c7..5a00b7d90 100644
--- a/tortoise/signals.py
+++ b/tortoise/signals.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from collections.abc import Callable
 from enum import Enum
 from typing import TypeVar
diff --git a/tortoise/timezone.py b/tortoise/timezone.py
index edee30b1e..1255d8fd7 100644
--- a/tortoise/timezone.py
+++ b/tortoise/timezone.py
@@ -1,7 +1,8 @@
+from __future__ import annotations
+
 import functools
 import os
 from datetime import datetime, time, tzinfo
-from typing import Optional, Union

 import pytz

@@ -49,7 +50,7 @@ def _reset_timezone_cache() -> None:
     get_timezone.cache_clear()


-def localtime(value: Optional[datetime] = None, timezone: Optional[str] = None) -> datetime:
+def localtime(value: datetime | None = None, timezone: str | None = None) -> datetime:
     """
     Convert an aware datetime.datetime to local time.

@@ -69,7 +70,7 @@ def localtime(value: Optional[datetime] = None, timezone: Optional[str] = None)
     return value.astimezone(tz)


-def is_aware(value: Union[datetime, time]) -> bool:
+def is_aware(value: datetime | time) -> bool:
     """
     Determine if a given datetime.datetime or datetime.time is aware.

@@ -82,7 +83,7 @@ def is_aware(value: Union[datetime, time]) -> bool:
     return value.utcoffset() is not None


-def is_naive(value: Union[datetime, time]) -> bool:
+def is_naive(value: datetime | time) -> bool:
     """
     Determine if a given datetime.datetime or datetime.time is naive.

@@ -96,7 +97,7 @@ def is_naive(value: Union[datetime, time]) -> bool:

 def make_aware(
-    value: datetime, timezone: Optional[str] = None, is_dst: Optional[bool] = None
+    value: datetime, timezone: str | None = None, is_dst: bool | None = None
 ) -> datetime:
     """
     Make a naive datetime.datetime in a given time zone aware.
@@ -107,12 +108,12 @@
     if hasattr(tz, "localize"):
         return tz.localize(value, is_dst=is_dst)
     if is_aware(value):
-        raise ValueError("make_aware expects a naive datetime, got %s" % value)
+        raise ValueError(f"make_aware expects a naive datetime, got {value}")
     # This may be wrong around DST changes!
     return value.replace(tzinfo=tz)


-def make_naive(value: datetime, timezone: Optional[str] = None) -> datetime:
+def make_naive(value: datetime, timezone: str | None = None) -> datetime:
     """
     Make an aware datetime.datetime naive in a given time zone.
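The timezone helpers above all reduce to one stdlib rule: a datetime (or time) is aware exactly when utcoffset() returns a non-None value. A self-contained illustration using only the standard library (the module itself uses pytz, which is not needed for this check):

from datetime import datetime, timezone


def is_aware(value: datetime) -> bool:
    # Mirrors the check in tortoise.timezone: aware iff an offset is known.
    return value.utcoffset() is not None


naive = datetime(2024, 1, 1, 12, 0)
aware = naive.replace(tzinfo=timezone.utc)
print(is_aware(naive))  # -> False
print(is_aware(aware))  # -> True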
diff --git a/tortoise/transactions.py b/tortoise/transactions.py
index 12eb4d154..4b95ffe32 100644
--- a/tortoise/transactions.py
+++ b/tortoise/transactions.py
@@ -1,6 +1,8 @@
+from __future__ import annotations
+
 from collections.abc import Callable
 from functools import wraps
-from typing import TYPE_CHECKING, Optional, TypeVar, cast
+from typing import TYPE_CHECKING, TypeVar, cast

 from tortoise import connections
 from tortoise.exceptions import ParamsError
@@ -13,7 +15,7 @@
     F = TypeVar("F", bound=FuncType)


-def _get_connection(connection_name: Optional[str]) -> "BaseDBAsyncClient":
+def _get_connection(connection_name: str | None) -> BaseDBAsyncClient:
     if connection_name:
         connection = connections.get(connection_name)
     elif len(connections.db_config) == 1:
@@ -27,7 +29,7 @@ def _get_connection(connection_name: Optional[str]) -> "BaseDBAsyncClient":
     return connection


-def in_transaction(connection_name: Optional[str] = None) -> "TransactionContext":
+def in_transaction(connection_name: str | None = None) -> TransactionContext:
     """
     Transaction context manager.

@@ -41,7 +43,7 @@ def in_transaction(connection_name: Optional[str] = None) -> "TransactionContext
     return connection._in_transaction()


-def atomic(connection_name: Optional[str] = None) -> Callable[[F], F]:
+def atomic(connection_name: str | None = None) -> Callable[[F], F]:
     """
     Transaction decorator.
diff --git a/tortoise/utils.py b/tortoise/utils.py
index 1ca7e5e9f..fe982576a 100644
--- a/tortoise/utils.py
+++ b/tortoise/utils.py
@@ -1,6 +1,8 @@
+from __future__ import annotations
+
 import sys
 from collections.abc import Iterable
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any

 from tortoise.log import logger

@@ -19,7 +21,7 @@ def batched(iterable: Iterable[Any], n: int) -> Iterable[tuple[Any]]:
     from tortoise.backends.base.client import BaseDBAsyncClient


-def get_schema_sql(client: "BaseDBAsyncClient", safe: bool) -> str:
+def get_schema_sql(client: BaseDBAsyncClient, safe: bool) -> str:
     """
     Generates the SQL schema for the given client.

@@ -30,7 +32,7 @@ def get_schema_sql(client: "BaseDBAsyncClient", safe: bool) -> str:
     return generator.get_create_schema_sql(safe)


-async def generate_schema_for_client(client: "BaseDBAsyncClient", safe: bool) -> None:
+async def generate_schema_for_client(client: BaseDBAsyncClient, safe: bool) -> None:
     """
     Generates and applies the SQL schema directly to the given client.

@@ -44,7 +46,7 @@ async def generate_schema_for_client(client: "BaseDBAsyncClient", safe: bool) ->
     await generator.generate_from_string(schema)


-def chunk(instances: Iterable[Any], batch_size: Optional[int] = None) -> Iterable[Iterable[Any]]:
+def chunk(instances: Iterable[Any], batch_size: int | None = None) -> Iterable[Iterable[Any]]:
     """
     Generate iterable chunk by batch_size
     # noqa: DAR301
diff --git a/tortoise/validators.py b/tortoise/validators.py
index 2fa83350a..36785bb74 100644
--- a/tortoise/validators.py
+++ b/tortoise/validators.py
@@ -107,10 +107,10 @@ class CommaSeparatedIntegerListValidator(Validator):
     """

     def __init__(self, allow_negative: bool = False) -> None:
-        pattern = r"^%(neg)s\d+(?:%(sep)s%(neg)s\d+)*\Z" % {
-            "neg": "(-)?" if allow_negative else "",
-            "sep": re.escape(","),
-        }
+        pattern = r"^{neg}\d+(?:{sep}{neg}\d+)*\Z".format(
+            neg="(-)?" if allow_negative else "",
+            sep=re.escape(","),
+        )
         self.regex = RegexValidator(pattern, re.I)

     def __call__(self, value: str) -> None:
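The validators.py change swaps %-interpolation for str.format() without altering the generated pattern. A quick equivalence check one could run by hand (illustrative script, not part of the repository's test suite):

import re

for allow_negative in (False, True):
    neg = "(-)?" if allow_negative else ""
    old = r"^%(neg)s\d+(?:%(sep)s%(neg)s\d+)*\Z" % {"neg": neg, "sep": re.escape(",")}
    new = r"^{neg}\d+(?:{sep}{neg}\d+)*\Z".format(neg=neg, sep=re.escape(","))
    assert old == new  # both formatting styles yield the identical regex
    print(bool(re.match(new, "1,2,3", re.I)))  # -> True in both iterations
    print(bool(re.match(new, "-1,2", re.I)))   # -> True only when allow_negative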