diff --git a/CHANGELOG.rst b/CHANGELOG.rst index d61c32cfd..14bc791d0 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -9,6 +9,13 @@ Changelog 0.24 ==== +0.24.2 (Unreleased) +------ + +Fixed +^^^^^ +- Fix model with multi m2m fields generates wrong references name (#1897) + 0.24.1 ------ Added diff --git a/tests/schema/models_m2m_2.py b/tests/schema/models_m2m_2.py new file mode 100644 index 000000000..27932bb19 --- /dev/null +++ b/tests/schema/models_m2m_2.py @@ -0,0 +1,20 @@ +""" +This is the testing Models — Multi ManyToMany fields +""" + +from __future__ import annotations + +from tortoise import Model, fields + + +class One(Model): + threes: fields.ManyToManyRelation[Three] + + +class Two(Model): + threes: fields.ManyToManyRelation[Three] + + +class Three(Model): + ones: fields.ManyToManyRelation[One] = fields.ManyToManyField("models.One") + twos: fields.ManyToManyRelation[Two] = fields.ManyToManyField("models.Two") diff --git a/tests/schema/test_generate_schema.py b/tests/schema/test_generate_schema.py index 9bb141f4d..acb88d506 100644 --- a/tests/schema/test_generate_schema.py +++ b/tests/schema/test_generate_schema.py @@ -217,6 +217,13 @@ async def test_m2m_bad_model_name(self): ): await self.init_for("tests.schema.models_m2m_1") + async def test_multi_m2m_fields_in_a_model(self): + await self.init_for("tests.schema.models_m2m_2") + sql = self.get_sql("CASCADE") + self.assertNotRegex(sql, r'REFERENCES [`"]three_one[`"]') + self.assertNotRegex(sql, r'REFERENCES [`"]three_two[`"]') + self.assertRegex(sql, r'REFERENCES [`"](one|two|three)[`"]') + async def test_table_and_row_comment_generation(self): await self.init_for("tests.testmodels") sql = self.get_sql("comments") diff --git a/tests/test_order_by.py b/tests/test_order_by.py index bfb740408..24015fcef 100644 --- a/tests/test_order_by.py +++ b/tests/test_order_by.py @@ -9,7 +9,7 @@ from tortoise.contrib import test from tortoise.contrib.test.condition import NotEQ from tortoise.exceptions import ConfigurationError, FieldError -from tortoise.expressions import Q, Case, When +from tortoise.expressions import Case, Q, When from tortoise.functions import Count, Sum diff --git a/tests/test_values.py b/tests/test_values.py index 5dc901585..b681b318b 100644 --- a/tests/test_values.py +++ b/tests/test_values.py @@ -4,7 +4,7 @@ from tortoise.contrib import test from tortoise.contrib.test.condition import In, NotEQ from tortoise.exceptions import FieldError -from tortoise.expressions import Q, Case, Function, When +from tortoise.expressions import Case, Function, Q, When from tortoise.functions import Length, Trim diff --git a/tortoise/backends/base/schema_generator.py b/tortoise/backends/base/schema_generator.py index cfe76ed69..97a05aeeb 100644 --- a/tortoise/backends/base/schema_generator.py +++ b/tortoise/backends/base/schema_generator.py @@ -296,7 +296,7 @@ def _get_field_indexes_sqls( return [val for val in list(dict.fromkeys(indexes)) if val] def _get_m2m_tables( - self, model: type[Model], table_name: str, safe: bool, models_tables: list[str] + self, model: type[Model], db_table: str, safe: bool, models_tables: list[str] ) -> list[str]: m2m_tables_for_create = [] for m2m_field in model._meta.m2m_fields: @@ -308,7 +308,7 @@ def _get_m2m_tables( backward_fk = self._create_fk_string( "", backward_key, - table_name, + db_table, model._meta.db_pk_column, field_object.on_delete, "", @@ -324,12 +324,15 @@ def _get_m2m_tables( else: backward_fk = forward_fk = "" exists = "IF NOT EXISTS " if safe else "" - table_name = field_object.through + through_table_name = field_object.through backward_type = self._get_pk_field_sql_type(model._meta.pk) forward_type = self._get_pk_field_sql_type(field_object.related_model._meta.pk) + comment = "" + if desc := field_object.description: + comment = self._table_comment_generator(table=through_table_name, comment=desc) m2m_create_string = self.M2M_TABLE_TEMPLATE.format( exists=exists, - table_name=table_name, + table_name=through_table_name, backward_fk=backward_fk, forward_fk=forward_fk, backward_key=backward_key, @@ -337,13 +340,7 @@ def _get_m2m_tables( forward_key=forward_key, forward_type=forward_type, extra=self._table_generate_extra(table=field_object.through), - comment=( - self._table_comment_generator( - table=field_object.through, comment=field_object.description - ) - if field_object.description - else "" - ), + comment=comment, ) if not field_object.db_constraint: m2m_create_string = m2m_create_string.replace( @@ -355,7 +352,7 @@ def _get_m2m_tables( m2m_create_string += self._post_table_hook() if field_object.create_unique_index: unique_index_create_sql = self._get_unique_index_sql( - exists, table_name, [backward_key, forward_key] + exists, through_table_name, [backward_key, forward_key] ) if unique_index_create_sql.endswith(";"): m2m_create_string += "\n" + unique_index_create_sql diff --git a/tortoise/queryset.py b/tortoise/queryset.py index 2e5d2e31f..b289d1dcd 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -9,7 +9,7 @@ from pypika_tortoise.analytics import Count from pypika_tortoise.functions import Cast from pypika_tortoise.queries import QueryBuilder -from pypika_tortoise.terms import Case, Field, Star, Term, ValueWrapper, PseudoColumn +from pypika_tortoise.terms import Case, Field, PseudoColumn, Star, Term, ValueWrapper from typing_extensions import Literal, Protocol from tortoise.backends.base.client import BaseDBAsyncClient, Capabilities