Skip to content

Commit a6c26f5

Browse files
committed
Add index_name and get_sql back to Index class for aerich
1 parent 81892fe commit a6c26f5

File tree

2 files changed

+23
-9
lines changed

2 files changed

+23
-9
lines changed

tortoise/backends/base/schema_generator.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -345,14 +345,7 @@ def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict:
345345
if model._meta.indexes:
346346
for index in model._meta.indexes:
347347
if isinstance(index, Index):
348-
idx_sql = self._get_index_sql(
349-
model,
350-
index.field_names,
351-
safe=safe,
352-
index_name=index.name,
353-
index_type=index.INDEX_TYPE,
354-
extra=index.extra,
355-
)
348+
idx_sql = index.get_sql(self, model, safe)
356349
else:
357350
fields = []
358351
for field in index:

tortoise/indexes.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
from __future__ import annotations
22

3-
from typing import Any
3+
from typing import TYPE_CHECKING, Any, Type
44

55
from pypika_tortoise.context import DEFAULT_SQL_CONTEXT
66
from pypika_tortoise.terms import Term, ValueWrapper
77

88
from tortoise.exceptions import ConfigurationError
99

10+
if TYPE_CHECKING:
11+
from tortoise.backends.base.schema_generator import BaseSchemaGenerator
12+
from tortoise.models import Model
13+
1014

1115
class Index:
1216
INDEX_TYPE = ""
@@ -47,6 +51,23 @@ def describe(self) -> dict:
4751
"extra": self.extra,
4852
}
4953

54+
def index_name(self, schema_generator: "BaseSchemaGenerator", model: "Type[Model]") -> str:
55+
# This function is required by aerich
56+
return self.name or schema_generator._generate_index_name("idx", model, self.field_names)
57+
58+
def get_sql(
59+
self, schema_generator: "BaseSchemaGenerator", model: "Type[Model]", safe: bool
60+
) -> str:
61+
# This function is required by aerich
62+
return schema_generator._get_index_sql(
63+
model,
64+
self.field_names,
65+
safe,
66+
index_name=self.name,
67+
index_type=self.INDEX_TYPE,
68+
extra=self.extra,
69+
)
70+
5071
@property
5172
def field_names(self) -> list[str]:
5273
if self.fields:

0 commit comments

Comments
 (0)