diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index a0f865d09..4d2e046ff 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -21,6 +21,10 @@ Added
 - Add `no_key` parameter to `queryset.select_for_update`.
 - `F()` supports referencing JSONField attributes, e.g. `F("json_field__custom_field__nested_id")` (#1960)
 
+Fixed
+^^^^^
+- Fix PostgreSQL schema creation for non-default schemas - automatically create schemas if they don't exist (#1671)
+
 0.25.0
 ------
 Fixed
diff --git a/tests/schema/test_schema_creation.py b/tests/schema/test_schema_creation.py
new file mode 100644
index 000000000..ddd2d989f
--- /dev/null
+++ b/tests/schema/test_schema_creation.py
@@ -0,0 +1,28 @@
+"""Tests for automatic PostgreSQL schema creation functionality."""
+
+from tortoise.backends.base_postgres.schema_generator import BasePostgresSchemaGenerator
+from tortoise.contrib import test
+
+
+class TestPostgresSchemaCreation(test.TestCase):
+    """Test automatic PostgreSQL schema creation."""
+
+    def test_postgres_schema_creation_sql(self):
+        """Test that BasePostgresSchemaGenerator can create schema SQL."""
+        # Mock client for testing
+        class MockClient:
+            def __init__(self):
+                self.capabilities = type('obj', (object,), {
+                    'inline_comment': False,
+                    'safe': True
+                })()
+
+        mock_client = MockClient()
+        generator = BasePostgresSchemaGenerator(mock_client)
+
+        # Test schema creation SQL generation
+        schema_sql = generator._get_create_schema_sql("pgdev", safe=True)
+        self.assertEqual(schema_sql, 'CREATE SCHEMA IF NOT EXISTS "pgdev";')
+
+        schema_sql_unsafe = generator._get_create_schema_sql("pgdev", safe=False)
+        self.assertEqual(schema_sql_unsafe, 'CREATE SCHEMA "pgdev";')
diff --git a/tortoise/backends/base/schema_generator.py b/tortoise/backends/base/schema_generator.py
index 61e9a68e7..227f7c09e 100644
--- a/tortoise/backends/base/schema_generator.py
+++ b/tortoise/backends/base/schema_generator.py
@@ -24,17 +24,21 @@ class BaseSchemaGenerator:
     DIALECT = "sql"
-    TABLE_CREATE_TEMPLATE = 'CREATE TABLE {exists}"{table_name}" ({fields}){extra}{comment};'
+    TABLE_CREATE_TEMPLATE = (
+        "CREATE TABLE {exists}{table_name} ({fields}){extra}{comment};"
+    )
     FIELD_TEMPLATE = '"{name}" {type}{nullable}{unique}{primary}{default}{comment}'
     INDEX_CREATE_TEMPLATE = (
-        'CREATE {index_type}INDEX {exists}"{index_name}" ON "{table_name}" ({fields}){extra};'
+        'CREATE {index_type}INDEX {exists}"{index_name}" ON {table_name} ({fields}){extra};'
+    )
+    UNIQUE_INDEX_CREATE_TEMPLATE = INDEX_CREATE_TEMPLATE.replace(
+        "INDEX", "UNIQUE INDEX"
     )
-    UNIQUE_INDEX_CREATE_TEMPLATE = INDEX_CREATE_TEMPLATE.replace("INDEX", "UNIQUE INDEX")
     UNIQUE_CONSTRAINT_CREATE_TEMPLATE = 'CONSTRAINT "{index_name}" UNIQUE ({fields})'
     GENERATED_PK_TEMPLATE = '"{field_name}" {generated_sql}{comment}'
     FK_TEMPLATE = ' REFERENCES "{table}" ("{field}") ON DELETE {on_delete}{comment}'
     M2M_TABLE_TEMPLATE = (
-        'CREATE TABLE {exists}"{table_name}" (\n'
+        "CREATE TABLE {exists}{table_name} (\n"
         '    "{backward_key}" {backward_type} NOT NULL{backward_fk},\n'
         '    "{forward_key}" {forward_type} NOT NULL{forward_fk}\n'
         "){extra}{comment};"
     )
@@ -75,7 +79,11 @@ def _create_fk_string(
         comment: str,
     ) -> str:
         return self.FK_TEMPLATE.format(
-            db_column=db_column, table=table, field=field, on_delete=on_delete, comment=comment
+            db_column=db_column,
+            table=table,
+            field=field,
+            on_delete=on_delete,
+            comment=comment,
         )
 
     def _table_comment_generator(self, table: str, comment: str) -> str:
@@ -139,6 +147,21 @@ def _get_inner_statements(self) -> list[str]:
     def quote(self, val: str) -> str:
f'"{val}"' + def _get_qualified_table_name(self, model: type[Model]) -> str: + """Get the fully qualified table name including schema if present.""" + table_name = model._meta.db_table + if model._meta.schema: + return f'"{model._meta.schema}"."{table_name}"' + return f'"{table_name}"' + + def _get_qualified_m2m_table_name( + self, model: type[Model], through_table_name: str + ) -> str: + """Get the fully qualified M2M table name including schema if present.""" + if model._meta.schema: + return f'"{model._meta.schema}"."{through_table_name}"' + return f'"{through_table_name}"' + @staticmethod def _make_hash(*args: str, length: int) -> str: # Hash a set of string values and get a digest of the given length. @@ -156,8 +179,11 @@ def _get_index_name( hashed = self._make_hash(table_name, *field_names, length=6) return f"{prefix}_{table}_{field}_{hashed}" - def _get_fk_name(self, from_table: str, from_field: str, to_table: str, to_field: str) -> str: - # NOTE: for compatibility, index name should not be longer than 30 characters (Oracle limit). + def _get_fk_name( + self, from_table: str, from_field: str, to_table: str, to_field: str + ) -> str: + # NOTE: for compatibility, index name should not be longer than 30 characters + # (Oracle limit). # That's why we slice some of the strings here. hashed = self._make_hash(from_table, from_field, to_table, to_field, length=8) return f"fk_{from_table[:8]}_{to_table[:8]}_{hashed}" @@ -175,7 +201,7 @@ def _get_index_sql( exists="IF NOT EXISTS " if safe else "", index_name=index_name or self._get_index_name("idx", model, field_names), index_type=f"{index_type} " if index_type else "", - table_name=model._meta.db_table, + table_name=self._get_qualified_table_name(model), fields=", ".join([self.quote(f) for f in field_names]), extra=f"{extra}" if extra else "", ) @@ -184,16 +210,19 @@ def _get_unique_index_sql( self, exists: str, table_name: str, field_names: Sequence[str] ) -> str: index_name = self._get_index_name("uidx", table_name, field_names) + quoted_table_name = self.quote(table_name) return self.UNIQUE_INDEX_CREATE_TEMPLATE.format( exists=exists, index_name=index_name, index_type="", - table_name=table_name, + table_name=quoted_table_name, fields=", ".join([self.quote(f) for f in field_names]), extra="", ) - def _get_unique_constraint_sql(self, model: type[Model], field_names: Sequence[str]) -> str: + def _get_unique_constraint_sql( + self, model: type[Model], field_names: Sequence[str] + ) -> str: return self.UNIQUE_CONSTRAINT_CREATE_TEMPLATE.format( index_name=self._get_index_name("uid", model, field_names), fields=", ".join([self.quote(f) for f in field_names]), @@ -206,7 +235,9 @@ def _get_pk_field_sql_type(self, pk_field: Field) -> str: return sql_type raise ConfigurationError(f"Can't get SQL type of {pk_field} for {self.DIALECT}") - def _get_pk_create_sql(self, field_object: Field, column_name: str, comment: str) -> str: + def _get_pk_create_sql( + self, field_object: Field, column_name: str, comment: str + ) -> str: if field_object.pk and field_object.generated: generated_sql = field_object.get_for_dialect(self.DIALECT, "GENERATED_SQL") if generated_sql: # pragma: nobranch @@ -217,7 +248,9 @@ def _get_pk_create_sql(self, field_object: Field, column_name: str, comment: str ) return "" - def _get_field_comment(self, field_object: Field, table_name: str, column_name: str) -> str: + def _get_field_comment( + self, field_object: Field, table_name: str, column_name: str + ) -> str: if desc := field_object.description: return 
             return self._column_comment_generator(
                 table=table_name, column=column_name, comment=desc
@@ -225,7 +258,12 @@ def _get_field_comment(self, field_object: Field, table_name: str, column_name:
         return ""
 
     def _get_field_sql_and_related_table(
-        self, field_object: Field, table_name: str, column_name: str, default: str, comment: str
+        self,
+        field_object: Field,
+        table_name: str,
+        column_name: str,
+        default: str,
+        comment: str,
     ) -> tuple[str, str]:
         nullable = " NOT NULL" if not field_object.null else ""
         unique = " UNIQUE" if field_object.unique else ""
@@ -241,6 +279,10 @@ def _get_field_sql_and_related_table(
                 to_field_name = reference.to_field_instance.model_field_name
 
             related_table_name = reference.related_model._meta.db_table
+            # Get qualified table name for FK reference if related model has schema
+            qualified_related_table_name = self._get_qualified_table_name(
+                reference.related_model
+            ).strip('"')
             if reference.db_constraint:
                 field_creation_string = self._create_string(
                     db_column=column_name,
@@ -258,7 +300,7 @@ def _get_field_sql_and_related_table(
                         to_field_name,
                     ),
                     db_column=column_name,
-                    table=related_table_name,
+                    table=qualified_related_table_name,
                     field=to_field_name,
                     on_delete=reference.on_delete,
                     comment=comment,
@@ -278,7 +320,9 @@ def _get_field_indexes_sqls(
         self, model: type[Model], field_names: Sequence[str], safe: bool
     ) -> list[str]:
-        indexes = [self._get_index_sql(model, [field], safe=safe) for field in field_names]
+        indexes = [
+            self._get_index_sql(model, [field], safe=safe) for field in field_names
+        ]
 
         if model._meta.indexes:
             for index in model._meta.indexes:
@@ -286,7 +330,8 @@ def _get_field_indexes_sqls(
                     idx_sql = index.get_sql(self, model, safe)
                 else:
                     fields = [
-                        model._meta.fields_map[field].source_field or field for field in index
+                        model._meta.fields_map[field].source_field or field
+                        for field in index
                     ]
                     idx_sql = self._get_index_sql(model, fields, safe=safe)
 
@@ -300,15 +345,28 @@ def _get_m2m_tables(
     ) -> list[str]:
         m2m_tables_for_create = []
         for m2m_field in model._meta.m2m_fields:
-            field_object = cast("ManyToManyFieldInstance", model._meta.fields_map[m2m_field])
+            field_object = cast(
+                "ManyToManyFieldInstance", model._meta.fields_map[m2m_field]
+            )
             if field_object._generated or field_object.through in models_tables:
                 continue
-            backward_key, forward_key = field_object.backward_key, field_object.forward_key
+            backward_key, forward_key = (
+                field_object.backward_key,
+                field_object.forward_key,
+            )
             if field_object.db_constraint:
+                # Get qualified table names for M2M foreign key references
+                qualified_backward_table = self._get_qualified_table_name(model).strip(
+                    '"'
+                )
+                qualified_forward_table = self._get_qualified_table_name(
+                    field_object.related_model
+                ).strip('"')
+
                 backward_fk = self._create_fk_string(
                     "",
                     backward_key,
-                    db_table,
+                    qualified_backward_table,
                     model._meta.db_pk_column,
                     field_object.on_delete,
                     "",
@@ -316,7 +374,7 @@ def _get_m2m_tables(
                 forward_fk = self._create_fk_string(
                     "",
                     forward_key,
-                    field_object.related_model._meta.db_table,
+                    qualified_forward_table,
                     field_object.related_model._meta.db_pk_column,
                     field_object.on_delete,
                     "",
@@ -325,14 +383,21 @@ def _get_m2m_tables(
                 backward_fk = forward_fk = ""
             exists = "IF NOT EXISTS " if safe else ""
             through_table_name = field_object.through
+            qualified_through_table_name = self._get_qualified_m2m_table_name(
+                model, through_table_name
+            )
             backward_type = self._get_pk_field_sql_type(model._meta.pk)
-            forward_type = self._get_pk_field_sql_type(field_object.related_model._meta.pk)
+            forward_type = self._get_pk_field_sql_type(
+                field_object.related_model._meta.pk
+            )
             comment = ""
             if desc := field_object.description:
-                comment = self._table_comment_generator(table=through_table_name, comment=desc)
+                comment = self._table_comment_generator(
+                    table=through_table_name, comment=desc
+                )
             m2m_create_string = self.M2M_TABLE_TEMPLATE.format(
                 exists=exists,
-                table_name=through_table_name,
+                table_name=qualified_through_table_name,
                 backward_fk=backward_fk,
                 forward_fk=forward_fk,
                 backward_key=backward_key,
@@ -354,6 +419,11 @@ def _get_m2m_tables(
             unique_index_create_sql = self._get_unique_index_sql(
                 exists, through_table_name, [backward_key, forward_key]
             )
+            # Replace unqualified table name with qualified name in the SQL if schema exists
+            if model._meta.schema:
+                unique_index_create_sql = unique_index_create_sql.replace(
+                    f'"{through_table_name}"', qualified_through_table_name
+                )
             if unique_index_create_sql.endswith(";"):
                 m2m_create_string += "\n" + unique_index_create_sql
             else:
@@ -394,18 +464,26 @@ def _get_table_sql(self, model: type[Model], safe: bool = True) -> dict:
         references = set()
         models_to_create: list[type[Model]] = self._get_models_to_create()
         table_name = model._meta.db_table
+        qualified_table_name = self._get_qualified_table_name(model)
         models_tables = [model._meta.db_table for model in models_to_create]
         for field_name, column_name in model._meta.fields_db_projection.items():
             field_object = model._meta.fields_map[field_name]
             comment = self._get_field_comment(field_object, table_name, column_name)
-            default = self._get_field_default(field_object, table_name, column_name, model)
+            default = self._get_field_default(
+                field_object, table_name, column_name, model
+            )
             # TODO: PK generation needs to move out of schema generator.
-            if create_pk_field := self._get_pk_create_sql(field_object, column_name, comment):
+            if create_pk_field := self._get_pk_create_sql(
+                field_object, column_name, comment
+            ):
                 fields_to_create.append(create_pk_field)
                 continue
-            field_creation_string, related_table_name = self._get_field_sql_and_related_table(
+            (
+                field_creation_string,
+                related_table_name,
+            ) = self._get_field_sql_and_related_table(
                 field_object, table_name, column_name, default, comment
             )
             if related_table_name:
@@ -425,20 +503,24 @@ def _get_table_sql(self, model: type[Model], safe: bool = True) -> dict:
                 self._get_unique_constraint_sql(model, unique_together_to_create)
             )
 
-        field_indexes_sqls = self._get_field_indexes_sqls(model, fields_with_index, safe)
+        field_indexes_sqls = self._get_field_indexes_sqls(
+            model, fields_with_index, safe
+        )
 
         fields_to_create.extend(self._get_inner_statements())
 
         table_fields_string = "\n    {}\n".format(",\n    ".join(fields_to_create))
         table_comment = (
-            self._table_comment_generator(table=table_name, comment=model._meta.table_description)
+            self._table_comment_generator(
+                table=table_name, comment=model._meta.table_description
+            )
             if model._meta.table_description
             else ""
         )
 
         table_create_string = self.TABLE_CREATE_TEMPLATE.format(
             exists="IF NOT EXISTS " if safe else "",
-            table_name=table_name,
+            table_name=qualified_table_name,
             fields=table_fields_string,
             comment=table_comment,
             extra=self._table_generate_extra(table=table_name),
@@ -448,7 +530,9 @@ def _get_table_sql(self, model: type[Model], safe: bool = True) -> dict:
 
         table_create_string += self._post_table_hook()
 
-        m2m_tables_for_create = self._get_m2m_tables(model, table_name, safe, models_tables)
+        m2m_tables_for_create = self._get_m2m_tables(
+            model, table_name, safe, models_tables
+        )
 
         return {
             "table": table_name,
@@ -491,13 +575,19 @@ def get_create_schema_sql(self, safe: bool = True) -> str:
                     if t["references"].issubset(created_tables | {t["table"]})
                 )
             except StopIteration:
-                raise ConfigurationError("Can't create schema due to cyclic fk references")
+                raise ConfigurationError(
+                    "Can't create schema due to cyclic fk references"
+                )
             tables_to_create.remove(next_table_for_create)
             created_tables.add(next_table_for_create["table"])
-            ordered_tables_for_create.append(next_table_for_create["table_creation_string"])
+            ordered_tables_for_create.append(
+                next_table_for_create["table_creation_string"]
+            )
             m2m_tables_to_create += next_table_for_create["m2m_tables"]
 
-        schema_creation_string = "\n".join(ordered_tables_for_create + m2m_tables_to_create)
+        schema_creation_string = "\n".join(
+            ordered_tables_for_create + m2m_tables_to_create
+        )
         return schema_creation_string
 
     async def generate_from_string(self, creation_string: str) -> None:
diff --git a/tortoise/backends/base_postgres/schema_generator.py b/tortoise/backends/base_postgres/schema_generator.py
index 72cbeeaf5..ce6adba8a 100644
--- a/tortoise/backends/base_postgres/schema_generator.py
+++ b/tortoise/backends/base_postgres/schema_generator.py
@@ -14,9 +14,11 @@ class BasePostgresSchemaGenerator(BaseSchemaGenerator):
     DIALECT = "postgres"
     INDEX_CREATE_TEMPLATE = (
-        'CREATE INDEX {exists}"{index_name}" ON "{table_name}" {index_type}({fields}){extra};'
+        'CREATE INDEX {exists}"{index_name}" ON {table_name} {index_type}({fields}){extra};'
+    )
+    UNIQUE_INDEX_CREATE_TEMPLATE = INDEX_CREATE_TEMPLATE.replace(
+        "INDEX", "UNIQUE INDEX"
     )
-    UNIQUE_INDEX_CREATE_TEMPLATE = INDEX_CREATE_TEMPLATE.replace("INDEX", "UNIQUE INDEX")
'{comment}';" COLUMN_COMMENT_TEMPLATE = 'COMMENT ON COLUMN "{table}"."{column}" IS \'{comment}\';' GENERATED_PK_TEMPLATE = '"{field_name}" {generated_sql}' @@ -70,6 +72,20 @@ def _escape_default_value(self, default: Any): return default return encoders.get(type(default))(default) # type: ignore + def _get_create_schema_sql(self, schema: str, safe: bool = True) -> str: + """Generate CREATE SCHEMA SQL for PostgreSQL.""" + if safe: + return f'CREATE SCHEMA IF NOT EXISTS "{schema}";' + return f'CREATE SCHEMA "{schema}";' + + def _get_schemas_to_create(self) -> set[str]: + """Get all unique schemas that need to be created.""" + schemas = set() + for model in self._get_models_to_create(): + if model._meta.schema: + schemas.add(model._meta.schema) + return schemas + def _get_index_sql( self, model: type[Model], @@ -83,5 +99,31 @@ def _get_index_sql( index_type = f"USING {index_type}" return super()._get_index_sql( - model, field_names, safe, index_name=index_name, index_type=index_type, extra=extra + model, + field_names, + safe, + index_name=index_name, + index_type=index_type, + extra=extra, + ) + + def get_create_schema_sql(self, safe: bool = True) -> str: + """Generate complete schema creation SQL including schemas and tables.""" + # Get all schemas that need to be created + schemas_to_create = self._get_schemas_to_create() + + # Generate CREATE SCHEMA statements + schema_creation_sqls = [] + for schema in schemas_to_create: + schema_creation_sqls.append(self._get_create_schema_sql(schema, safe)) + + # Generate table creation SQL (from parent class) + table_creation_sql = super().get_create_schema_sql(safe) + + # Combine schema and table creation + all_sqls = ( + schema_creation_sqls + [table_creation_sql] + if table_creation_sql + else schema_creation_sqls ) + return "\n".join(all_sqls)