Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions tests/fields/test_db_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

from pypika_tortoise.terms import Field

from tests.testmodels import ModelWithIndexes
from tortoise import fields
from tortoise.contrib import test
from tortoise.exceptions import ConfigurationError
from tortoise.indexes import Index
from tests.testmodels import ModelWithIndexes


class CustomIndex(Index):
Expand Down Expand Up @@ -99,7 +99,10 @@ class TestIndexAliasChar(TestIndexAlias):

class TestModelWithIndexes(test.TestCase):
def test_meta(self):
self.assertEqual(ModelWithIndexes._meta.indexes, [Index(fields=("f1", "f2"))])
self.assertEqual(
ModelWithIndexes._meta.indexes,
[Index(fields=("f1", "f2")), Index(fields=("f3",), name="model_with_indexes__f3")],
)
self.assertTrue(ModelWithIndexes._meta.fields_map["id"].index)
self.assertTrue(ModelWithIndexes._meta.fields_map["indexed"].index)
self.assertTrue(ModelWithIndexes._meta.fields_map["unique_indexed"].unique)
5 changes: 5 additions & 0 deletions tests/schema/test_generate_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,11 @@ async def test_create_index(self):
sql = self.get_sql("CREATE INDEX")
self.assertIsNotNone(re.search(r"idx_tournament_created_\w+", sql))

async def test_create_index_with_custom_name(self):
await self.init_for("tests.testmodels")
sql = self.get_sql("f3")
self.assertIn("model_with_indexes__f3", sql)

async def test_fk_bad_model_name(self):
with self.assertRaisesRegex(
ConfigurationError, 'ForeignKeyField accepts model name in format "app.Model"'
Expand Down
2 changes: 2 additions & 0 deletions tests/testmodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -1059,11 +1059,13 @@ class ModelWithIndexes(Model):
unique_indexed = fields.CharField(max_length=16, unique=True)
f1 = fields.CharField(max_length=16)
f2 = fields.CharField(max_length=16)
f3 = fields.CharField(max_length=16)
u1 = fields.IntField()
u2 = fields.IntField()

class Meta:
indexes = [
Index(fields=["f1", "f2"]),
Index(fields=["f3"], name="model_with_indexes__f3"),
]
unique_together = [("u1", "u2")]
11 changes: 10 additions & 1 deletion tests/utils/test_describe_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1568,7 +1568,16 @@ def test_describe_indexes_serializable(self):

self.assertEqual(
val["indexes"],
[{"fields": ["f1", "f2"], "expressions": [], "name": None, "type": "", "extra": ""}],
[
{"fields": ["f1", "f2"], "expressions": [], "name": None, "type": "", "extra": ""},
{
"fields": ["f3"],
"expressions": [],
"name": "model_with_indexes__f3",
"type": "",
"extra": "",
},
],
)

def test_describe_indexes_not_serializable(self):
Expand Down
35 changes: 13 additions & 22 deletions tortoise/backends/base/schema_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
from hashlib import sha256
from typing import TYPE_CHECKING, Any, List, Optional, Set, Type, Union, cast

from pypika_tortoise.context import DEFAULT_SQL_CONTEXT

from tortoise.exceptions import ConfigurationError
from tortoise.fields import JSONField, TextField, UUIDField
from tortoise.fields.relational import OneToOneFieldInstance
Expand Down Expand Up @@ -346,31 +344,24 @@ def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict:

if model._meta.indexes:
for index in model._meta.indexes:
if not isinstance(index, Index):
if isinstance(index, Index):
idx_sql = self._get_index_sql(
model,
index.field_names,
safe=safe,
index_name=index.name,
index_type=index.INDEX_TYPE,
extra=index.extra,
)
else:
fields = []
for field in index:
field_object = model._meta.fields_map[field]
fields.append(field_object.source_field or field)
idx_sql = self._get_index_sql(model, fields, safe=safe)

_indexes.append(self._get_index_sql(model, fields, safe=safe))
else:
if index.fields:
fields = [f for f in index.fields]
elif index.expressions:
fields = [
f"({expression.get_sql(DEFAULT_SQL_CONTEXT)})"
for expression in index.expressions
]
else:
raise ConfigurationError(
"At least one field or expression is required to define an index."
)

_indexes.append(
self._get_index_sql(
model, fields, safe=safe, index_type=index.INDEX_TYPE, extra=index.extra
)
)
if idx_sql:
_indexes.append(idx_sql)

field_indexes_sqls = [val for val in list(dict.fromkeys(_indexes)) if val]

Expand Down
14 changes: 14 additions & 0 deletions tortoise/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import Any

from pypika_tortoise.context import DEFAULT_SQL_CONTEXT
from pypika_tortoise.terms import Term, ValueWrapper

from tortoise.exceptions import ConfigurationError
Expand Down Expand Up @@ -46,6 +47,19 @@ def describe(self) -> dict:
"extra": self.extra,
}

@property
def field_names(self) -> list[str]:
if self.fields:
return list(self.fields)
elif self.expressions:
return [
f"({expression.get_sql(DEFAULT_SQL_CONTEXT)})" for expression in self.expressions
]
else:
raise ConfigurationError(
"At least one field or expression is required to define an index."
)

def __repr__(self) -> str:
argument = ""
if self.expressions:
Expand Down
Loading