Skip to content

Commit 948ccdb

Browse files
authored
Fix index class argument name not work (#1856)
* Fix index class argument `name` not work * refactor: add `field_names` property to Index class * Check custom index name in generated schema * Add `index_name` and `get_sql` back to Index class for aerich
1 parent de48e77 commit 948ccdb

File tree

6 files changed

+63
-25
lines changed

6 files changed

+63
-25
lines changed

tests/fields/test_db_index.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,10 @@ class TestIndexAliasChar(TestIndexAlias):
9999

100100
class TestModelWithIndexes(test.TestCase):
101101
def test_meta(self):
102-
self.assertEqual(ModelWithIndexes._meta.indexes, [Index(fields=("f1", "f2"))])
102+
self.assertEqual(
103+
ModelWithIndexes._meta.indexes,
104+
[Index(fields=("f1", "f2")), Index(fields=("f3",), name="model_with_indexes__f3")],
105+
)
103106
self.assertTrue(ModelWithIndexes._meta.fields_map["id"].index)
104107
self.assertTrue(ModelWithIndexes._meta.fields_map["indexed"].index)
105108
self.assertTrue(ModelWithIndexes._meta.fields_map["unique_indexed"].unique)

tests/schema/test_generate_schema.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,11 @@ async def test_create_index(self):
174174
sql = self.get_sql("CREATE INDEX")
175175
self.assertIsNotNone(re.search(r"idx_tournament_created_\w+", sql))
176176

177+
async def test_create_index_with_custom_name(self):
178+
await self.init_for("tests.testmodels")
179+
sql = self.get_sql("f3")
180+
self.assertIn("model_with_indexes__f3", sql)
181+
177182
async def test_fk_bad_model_name(self):
178183
with self.assertRaisesRegex(
179184
ConfigurationError, 'ForeignKeyField accepts model name in format "app.Model"'

tests/testmodels.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1059,11 +1059,13 @@ class ModelWithIndexes(Model):
10591059
unique_indexed = fields.CharField(max_length=16, unique=True)
10601060
f1 = fields.CharField(max_length=16)
10611061
f2 = fields.CharField(max_length=16)
1062+
f3 = fields.CharField(max_length=16)
10621063
u1 = fields.IntField()
10631064
u2 = fields.IntField()
10641065

10651066
class Meta:
10661067
indexes = [
10671068
Index(fields=["f1", "f2"]),
1069+
Index(fields=["f3"], name="model_with_indexes__f3"),
10681070
]
10691071
unique_together = [("u1", "u2")]

tests/utils/test_describe_model.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1568,7 +1568,16 @@ def test_describe_indexes_serializable(self):
15681568

15691569
self.assertEqual(
15701570
val["indexes"],
1571-
[{"fields": ["f1", "f2"], "expressions": [], "name": None, "type": "", "extra": ""}],
1571+
[
1572+
{"fields": ["f1", "f2"], "expressions": [], "name": None, "type": "", "extra": ""},
1573+
{
1574+
"fields": ["f3"],
1575+
"expressions": [],
1576+
"name": "model_with_indexes__f3",
1577+
"type": "",
1578+
"extra": "",
1579+
},
1580+
],
15721581
)
15731582

15741583
def test_describe_indexes_not_serializable(self):

tortoise/backends/base/schema_generator.py

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
from hashlib import sha256
55
from typing import TYPE_CHECKING, Any, Type, cast
66

7-
from pypika_tortoise.context import DEFAULT_SQL_CONTEXT
8-
97
from tortoise.exceptions import ConfigurationError
108
from tortoise.fields import JSONField, TextField, UUIDField
119
from tortoise.fields.relational import OneToOneFieldInstance
@@ -348,31 +346,17 @@ def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict:
348346

349347
if model._meta.indexes:
350348
for index in model._meta.indexes:
351-
if not isinstance(index, Index):
349+
if isinstance(index, Index):
350+
idx_sql = index.get_sql(self, model, safe)
351+
else:
352352
fields = []
353353
for field in index:
354354
field_object = model._meta.fields_map[field]
355355
fields.append(field_object.source_field or field)
356+
idx_sql = self._get_index_sql(model, fields, safe=safe)
356357

357-
_indexes.append(self._get_index_sql(model, fields, safe=safe))
358-
else:
359-
if index.fields:
360-
fields = [f for f in index.fields]
361-
elif index.expressions:
362-
fields = [
363-
f"({expression.get_sql(DEFAULT_SQL_CONTEXT)})"
364-
for expression in index.expressions
365-
]
366-
else:
367-
raise ConfigurationError(
368-
"At least one field or expression is required to define an index."
369-
)
370-
371-
_indexes.append(
372-
self._get_index_sql(
373-
model, fields, safe=safe, index_type=index.INDEX_TYPE, extra=index.extra
374-
)
375-
)
358+
if idx_sql:
359+
_indexes.append(idx_sql)
376360

377361
field_indexes_sqls = [val for val in list(dict.fromkeys(_indexes)) if val]
378362

tortoise/indexes.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
from __future__ import annotations
22

3-
from typing import Any
3+
from typing import TYPE_CHECKING, Any, Type
44

5+
from pypika_tortoise.context import DEFAULT_SQL_CONTEXT
56
from pypika_tortoise.terms import Term, ValueWrapper
67

78
from tortoise.exceptions import ConfigurationError
89

10+
if TYPE_CHECKING:
11+
from tortoise.backends.base.schema_generator import BaseSchemaGenerator
12+
from tortoise.models import Model
13+
914

1015
class Index:
1116
INDEX_TYPE = ""
@@ -46,6 +51,36 @@ def describe(self) -> dict:
4651
"extra": self.extra,
4752
}
4853

54+
def index_name(self, schema_generator: "BaseSchemaGenerator", model: "Type[Model]") -> str:
55+
# This function is required by aerich
56+
return self.name or schema_generator._generate_index_name("idx", model, self.field_names)
57+
58+
def get_sql(
59+
self, schema_generator: "BaseSchemaGenerator", model: "Type[Model]", safe: bool
60+
) -> str:
61+
# This function is required by aerich
62+
return schema_generator._get_index_sql(
63+
model,
64+
self.field_names,
65+
safe,
66+
index_name=self.name,
67+
index_type=self.INDEX_TYPE,
68+
extra=self.extra,
69+
)
70+
71+
@property
72+
def field_names(self) -> list[str]:
73+
if self.fields:
74+
return list(self.fields)
75+
elif self.expressions:
76+
return [
77+
f"({expression.get_sql(DEFAULT_SQL_CONTEXT)})" for expression in self.expressions
78+
]
79+
else:
80+
raise ConfigurationError(
81+
"At least one field or expression is required to define an index."
82+
)
83+
4984
def __repr__(self) -> str:
5085
argument = ""
5186
if self.expressions:

0 commit comments

Comments
 (0)