11import re
22from hashlib import sha256
3- from typing import TYPE_CHECKING , Any , List , Set , Type , Union , cast
3+ from typing import TYPE_CHECKING , Any , List , Optional , Set , Type , Union , cast
4+
5+ from pypika_tortoise .context import DEFAULT_SQL_CONTEXT
46
57from tortoise .exceptions import ConfigurationError
68from tortoise .fields import JSONField , TextField , UUIDField
@@ -23,8 +25,10 @@ class BaseSchemaGenerator:
2325 DIALECT = "sql"
2426 TABLE_CREATE_TEMPLATE = 'CREATE TABLE {exists}"{table_name}" ({fields}){extra}{comment};'
2527 FIELD_TEMPLATE = '"{name}" {type}{nullable}{unique}{primary}{default}{comment}'
26- INDEX_CREATE_TEMPLATE = 'CREATE INDEX {exists}"{index_name}" ON "{table_name}" ({fields});'
27- UNIQUE_INDEX_CREATE_TEMPLATE = INDEX_CREATE_TEMPLATE .replace (" INDEX" , " UNIQUE INDEX" )
28+ INDEX_CREATE_TEMPLATE = (
29+ 'CREATE {index_type}INDEX {exists}"{index_name}" ON "{table_name}" ({fields}){extra};'
30+ )
31+ UNIQUE_INDEX_CREATE_TEMPLATE = INDEX_CREATE_TEMPLATE .replace ("INDEX" , "UNIQUE INDEX" )
2832 UNIQUE_CONSTRAINT_CREATE_TEMPLATE = 'CONSTRAINT "{index_name}" UNIQUE ({fields})'
2933 GENERATED_PK_TEMPLATE = '"{field_name}" {generated_sql}{comment}'
3034 FK_TEMPLATE = ' REFERENCES "{table}" ("{field}") ON DELETE {on_delete}{comment}'
@@ -167,21 +171,33 @@ def _generate_fk_name(
167171 )
168172 return index_name
169173
170- def _get_index_sql (self , model : "Type[Model]" , field_names : List [str ], safe : bool ) -> str :
174+ def _get_index_sql (
175+ self ,
176+ model : "Type[Model]" ,
177+ field_names : List [str ],
178+ safe : bool ,
179+ index_name : Optional [str ] = None ,
180+ index_type : Optional [str ] = None ,
181+ extra : Optional [str ] = None ,
182+ ) -> str :
171183 return self .INDEX_CREATE_TEMPLATE .format (
172184 exists = "IF NOT EXISTS " if safe else "" ,
173- index_name = self ._generate_index_name ("idx" , model , field_names ),
185+ index_name = index_name or self ._generate_index_name ("idx" , model , field_names ),
186+ index_type = f"{ index_type } " if index_type else "" ,
174187 table_name = model ._meta .db_table ,
175188 fields = ", " .join ([self .quote (f ) for f in field_names ]),
189+ extra = f"{ extra } " if extra else "" ,
176190 )
177191
178192 def _get_unique_index_sql (self , exists : str , table_name : str , field_names : List [str ]) -> str :
179193 index_name = self ._generate_index_name ("uidx" , table_name , field_names )
180194 return self .UNIQUE_INDEX_CREATE_TEMPLATE .format (
181195 exists = exists ,
182196 index_name = index_name ,
197+ index_type = "" ,
183198 table_name = table_name ,
184199 fields = ", " .join ([self .quote (f ) for f in field_names ]),
200+ extra = "" ,
185201 )
186202
187203 def _get_unique_constraint_sql (self , model : "Type[Model]" , field_names : List [str ]) -> str :
@@ -324,22 +340,37 @@ def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict:
324340 self ._get_unique_constraint_sql (model , unique_together_to_create )
325341 )
326342
327- # Indexes.
328343 _indexes = [
329344 self ._get_index_sql (model , [field_name ], safe = safe ) for field_name in fields_with_index
330345 ]
331346
332347 if model ._meta .indexes :
333- for indexes_list in model ._meta .indexes :
334- if not isinstance (indexes_list , Index ):
335- indexes_to_create = []
336- for field in indexes_list :
348+ for index in model ._meta .indexes :
349+ if not isinstance (index , Index ):
350+ fields = []
351+ for field in index :
337352 field_object = model ._meta .fields_map [field ]
338- indexes_to_create .append (field_object .source_field or field )
353+ fields .append (field_object .source_field or field )
339354
340- _indexes .append (self ._get_index_sql (model , indexes_to_create , safe = safe ))
355+ _indexes .append (self ._get_index_sql (model , fields , safe = safe ))
341356 else :
342- _indexes .append (indexes_list .get_sql (self , model , safe ))
357+ if index .fields :
358+ fields = [f for f in index .fields ]
359+ elif index .expressions :
360+ fields = [
361+ f"({ expression .get_sql (DEFAULT_SQL_CONTEXT )} )"
362+ for expression in index .expressions
363+ ]
364+ else :
365+ raise ConfigurationError (
366+ "At least one field or expression is required to define an index."
367+ )
368+
369+ _indexes .append (
370+ self ._get_index_sql (
371+ model , fields , safe = safe , index_type = index .INDEX_TYPE , extra = index .extra
372+ )
373+ )
343374
344375 field_indexes_sqls = [val for val in list (dict .fromkeys (_indexes )) if val ]
345376
0 commit comments