1
1
import re
2
2
from hashlib import sha256
3
- from typing import TYPE_CHECKING , Any , List , Set , Type , Union , cast
3
+ from typing import TYPE_CHECKING , Any , List , Optional , Set , Type , Union , cast
4
+
5
+ from pypika_tortoise .context import DEFAULT_SQL_CONTEXT
4
6
5
7
from tortoise .exceptions import ConfigurationError
6
8
from tortoise .fields import JSONField , TextField , UUIDField
@@ -23,8 +25,10 @@ class BaseSchemaGenerator:
23
25
DIALECT = "sql"
24
26
TABLE_CREATE_TEMPLATE = 'CREATE TABLE {exists}"{table_name}" ({fields}){extra}{comment};'
25
27
FIELD_TEMPLATE = '"{name}" {type}{nullable}{unique}{primary}{default}{comment}'
26
- INDEX_CREATE_TEMPLATE = 'CREATE INDEX {exists}"{index_name}" ON "{table_name}" ({fields});'
27
- UNIQUE_INDEX_CREATE_TEMPLATE = INDEX_CREATE_TEMPLATE .replace (" INDEX" , " UNIQUE INDEX" )
28
+ INDEX_CREATE_TEMPLATE = (
29
+ 'CREATE {index_type}INDEX {exists}"{index_name}" ON "{table_name}" ({fields}){extra};'
30
+ )
31
+ UNIQUE_INDEX_CREATE_TEMPLATE = INDEX_CREATE_TEMPLATE .replace ("INDEX" , "UNIQUE INDEX" )
28
32
UNIQUE_CONSTRAINT_CREATE_TEMPLATE = 'CONSTRAINT "{index_name}" UNIQUE ({fields})'
29
33
GENERATED_PK_TEMPLATE = '"{field_name}" {generated_sql}{comment}'
30
34
FK_TEMPLATE = ' REFERENCES "{table}" ("{field}") ON DELETE {on_delete}{comment}'
@@ -167,21 +171,33 @@ def _generate_fk_name(
167
171
)
168
172
return index_name
169
173
170
- def _get_index_sql (self , model : "Type[Model]" , field_names : List [str ], safe : bool ) -> str :
174
+ def _get_index_sql (
175
+ self ,
176
+ model : "Type[Model]" ,
177
+ field_names : List [str ],
178
+ safe : bool ,
179
+ index_name : Optional [str ] = None ,
180
+ index_type : Optional [str ] = None ,
181
+ extra : Optional [str ] = None ,
182
+ ) -> str :
171
183
return self .INDEX_CREATE_TEMPLATE .format (
172
184
exists = "IF NOT EXISTS " if safe else "" ,
173
- index_name = self ._generate_index_name ("idx" , model , field_names ),
185
+ index_name = index_name or self ._generate_index_name ("idx" , model , field_names ),
186
+ index_type = f"{ index_type } " if index_type else "" ,
174
187
table_name = model ._meta .db_table ,
175
188
fields = ", " .join ([self .quote (f ) for f in field_names ]),
189
+ extra = f"{ extra } " if extra else "" ,
176
190
)
177
191
178
192
def _get_unique_index_sql (self , exists : str , table_name : str , field_names : List [str ]) -> str :
179
193
index_name = self ._generate_index_name ("uidx" , table_name , field_names )
180
194
return self .UNIQUE_INDEX_CREATE_TEMPLATE .format (
181
195
exists = exists ,
182
196
index_name = index_name ,
197
+ index_type = "" ,
183
198
table_name = table_name ,
184
199
fields = ", " .join ([self .quote (f ) for f in field_names ]),
200
+ extra = "" ,
185
201
)
186
202
187
203
def _get_unique_constraint_sql (self , model : "Type[Model]" , field_names : List [str ]) -> str :
@@ -324,22 +340,37 @@ def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict:
324
340
self ._get_unique_constraint_sql (model , unique_together_to_create )
325
341
)
326
342
327
- # Indexes.
328
343
_indexes = [
329
344
self ._get_index_sql (model , [field_name ], safe = safe ) for field_name in fields_with_index
330
345
]
331
346
332
347
if model ._meta .indexes :
333
- for indexes_list in model ._meta .indexes :
334
- if not isinstance (indexes_list , Index ):
335
- indexes_to_create = []
336
- for field in indexes_list :
348
+ for index in model ._meta .indexes :
349
+ if not isinstance (index , Index ):
350
+ fields = []
351
+ for field in index :
337
352
field_object = model ._meta .fields_map [field ]
338
- indexes_to_create .append (field_object .source_field or field )
353
+ fields .append (field_object .source_field or field )
339
354
340
- _indexes .append (self ._get_index_sql (model , indexes_to_create , safe = safe ))
355
+ _indexes .append (self ._get_index_sql (model , fields , safe = safe ))
341
356
else :
342
- _indexes .append (indexes_list .get_sql (self , model , safe ))
357
+ if index .fields :
358
+ fields = [f for f in index .fields ]
359
+ elif index .expressions :
360
+ fields = [
361
+ f"({ expression .get_sql (DEFAULT_SQL_CONTEXT )} )"
362
+ for expression in index .expressions
363
+ ]
364
+ else :
365
+ raise ConfigurationError (
366
+ "At least one field or expression is required to define an index."
367
+ )
368
+
369
+ _indexes .append (
370
+ self ._get_index_sql (
371
+ model , fields , safe = safe , index_type = index .INDEX_TYPE , extra = index .extra
372
+ )
373
+ )
343
374
344
375
field_indexes_sqls = [val for val in list (dict .fromkeys (_indexes )) if val ]
345
376
0 commit comments