Fix schema creation for non-default PostgreSQL schemas #1979
Open
Pritish053 wants to merge 5 commits into tortoise:develop from Pritish053:fix-postgresql-schema-creation
Commits (5, all by Pritish053):
d2d86da  Fix schema creation for non-default PostgreSQL schemas
5e9d528  Clean up documentation files for cleaner PR
5fd7d42  Fix line length issues for static analysis compliance
fb80641  Clean up test file formatting for static analysis
5c35bc7  Add changelog entry for PostgreSQL schema fix
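For context, here is a minimal sketch of the situation this PR targets, assuming a model that sets the schema option in its Meta (the same attribute the diff reads as model._meta.schema); the model, schema name, and expected DDL are illustrative, not taken from the PR:

from tortoise import fields
from tortoise.models import Model

class User(Model):
    id = fields.IntField(primary_key=True)
    name = fields.CharField(max_length=50)

    class Meta:
        table = "users"
        schema = "pgdev"  # non-default PostgreSQL schema (illustrative)

# With this change, schema generation is expected to emit roughly:
#   CREATE SCHEMA IF NOT EXISTS "pgdev";
#   CREATE TABLE IF NOT EXISTS "pgdev"."users" (...);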
New test file
@@ -0,0 +1,28 @@
"""Tests for automatic PostgreSQL schema creation functionality."""

from tortoise.backends.base_postgres.schema_generator import BasePostgresSchemaGenerator
from tortoise.contrib import test


class TestPostgresSchemaCreation(test.TestCase):
    """Test automatic PostgreSQL schema creation."""

    def test_postgres_schema_creation_sql(self):
        """Test that BasePostgresSchemaGenerator can create schema SQL."""
        # Mock client for testing
        class MockClient:
            def __init__(self):
                self.capabilities = type('obj', (object,), {
                    'inline_comment': False,
                    'safe': True
                })()

        mock_client = MockClient()
        generator = BasePostgresSchemaGenerator(mock_client)

        # Test schema creation SQL generation
        schema_sql = generator._get_create_schema_sql("pgdev", safe=True)
        self.assertEqual(schema_sql, 'CREATE SCHEMA IF NOT EXISTS "pgdev";')

        schema_sql_unsafe = generator._get_create_schema_sql("pgdev", safe=False)
        self.assertEqual(schema_sql_unsafe, 'CREATE SCHEMA "pgdev";')
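For reference, a hedged sketch of how the generated DDL would typically be applied end to end; the connection URL and models module below are placeholders, not values from this PR:

from tortoise import Tortoise, run_async

async def create_schemas() -> None:
    await Tortoise.init(
        db_url="postgres://user:pass@localhost:5432/mydb",  # placeholder DSN
        modules={"models": ["app.models"]},  # placeholder module path
    )
    # safe=True corresponds to the IF NOT EXISTS variants asserted in the test above
    await Tortoise.generate_schemas(safe=True)
    await Tortoise.close_connections()

run_async(create_schemas())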
Changes to the BaseSchemaGenerator class:
@@ -24,17 +24,21 @@
class BaseSchemaGenerator:
    DIALECT = "sql"
    TABLE_CREATE_TEMPLATE = 'CREATE TABLE {exists}"{table_name}" ({fields}){extra}{comment};'
    TABLE_CREATE_TEMPLATE = (
        "CREATE TABLE {exists}{table_name} ({fields}){extra}{comment};"
    )
    FIELD_TEMPLATE = '"{name}" {type}{nullable}{unique}{primary}{default}{comment}'
    INDEX_CREATE_TEMPLATE = (
        'CREATE {index_type}INDEX {exists}"{index_name}" ON "{table_name}" ({fields}){extra};'
        'CREATE {index_type}INDEX {exists}"{index_name}" ON {table_name} ({fields}){extra};'
    )
    UNIQUE_INDEX_CREATE_TEMPLATE = INDEX_CREATE_TEMPLATE.replace(
        "INDEX", "UNIQUE INDEX"
    )
    UNIQUE_INDEX_CREATE_TEMPLATE = INDEX_CREATE_TEMPLATE.replace("INDEX", "UNIQUE INDEX")
    UNIQUE_CONSTRAINT_CREATE_TEMPLATE = 'CONSTRAINT "{index_name}" UNIQUE ({fields})'
    GENERATED_PK_TEMPLATE = '"{field_name}" {generated_sql}{comment}'
    FK_TEMPLATE = ' REFERENCES "{table}" ("{field}") ON DELETE {on_delete}{comment}'
    M2M_TABLE_TEMPLATE = (
        'CREATE TABLE {exists}"{table_name}" (\n'
        "CREATE TABLE {exists}{table_name} (\n"
        ' "{backward_key}" {backward_type} NOT NULL{backward_fk},\n'
        ' "{forward_key}" {forward_type} NOT NULL{forward_fk}\n'
        "){extra}{comment};"

@@ -75,7 +79,11 @@ def _create_fk_string(
        comment: str,
    ) -> str:
        return self.FK_TEMPLATE.format(
            db_column=db_column, table=table, field=field, on_delete=on_delete, comment=comment
            db_column=db_column,
            table=table,
            field=field,
            on_delete=on_delete,
            comment=comment,
        )

    def _table_comment_generator(self, table: str, comment: str) -> str:

@@ -139,6 +147,21 @@ def _get_inner_statements(self) -> list[str]:
    def quote(self, val: str) -> str:
        return f'"{val}"'

    def _get_qualified_table_name(self, model: type[Model]) -> str:
        """Get the fully qualified table name including schema if present."""
        table_name = model._meta.db_table
        if model._meta.schema:
            return f'"{model._meta.schema}"."{table_name}"'
        return f'"{table_name}"'

    def _get_qualified_m2m_table_name(
        self, model: type[Model], through_table_name: str
    ) -> str:
        """Get the fully qualified M2M table name including schema if present."""
        if model._meta.schema:
            return f'"{model._meta.schema}"."{through_table_name}"'
        return f'"{through_table_name}"'

    @staticmethod
    def _make_hash(*args: str, length: int) -> str:
        # Hash a set of string values and get a digest of the given length.
@@ -156,8 +179,11 @@ def _get_index_name(
        hashed = self._make_hash(table_name, *field_names, length=6)
        return f"{prefix}_{table}_{field}_{hashed}"

    def _get_fk_name(self, from_table: str, from_field: str, to_table: str, to_field: str) -> str:
        # NOTE: for compatibility, index name should not be longer than 30 characters (Oracle limit).
    def _get_fk_name(
        self, from_table: str, from_field: str, to_table: str, to_field: str
    ) -> str:
        # NOTE: for compatibility, index name should not be longer than 30 characters
        # (Oracle limit).
        # That's why we slice some of the strings here.
        hashed = self._make_hash(from_table, from_field, to_table, to_field, length=8)
        return f"fk_{from_table[:8]}_{to_table[:8]}_{hashed}"

@@ -175,7 +201,7 @@ def _get_index_sql(
            exists="IF NOT EXISTS " if safe else "",
            index_name=index_name or self._get_index_name("idx", model, field_names),
            index_type=f"{index_type} " if index_type else "",
            table_name=model._meta.db_table,
            table_name=self._get_qualified_table_name(model),
            fields=", ".join([self.quote(f) for f in field_names]),
            extra=f"{extra}" if extra else "",
        )

@@ -184,16 +210,19 @@ def _get_unique_index_sql(
        self, exists: str, table_name: str, field_names: Sequence[str]
    ) -> str:
        index_name = self._get_index_name("uidx", table_name, field_names)
        quoted_table_name = self.quote(table_name)
        return self.UNIQUE_INDEX_CREATE_TEMPLATE.format(
            exists=exists,
            index_name=index_name,
            index_type="",
            table_name=table_name,
            table_name=quoted_table_name,
            fields=", ".join([self.quote(f) for f in field_names]),
            extra="",
        )

    def _get_unique_constraint_sql(self, model: type[Model], field_names: Sequence[str]) -> str:
    def _get_unique_constraint_sql(
        self, model: type[Model], field_names: Sequence[str]
    ) -> str:
        return self.UNIQUE_CONSTRAINT_CREATE_TEMPLATE.format(
            index_name=self._get_index_name("uid", model, field_names),
            fields=", ".join([self.quote(f) for f in field_names]),

@@ -206,7 +235,9 @@ def _get_pk_field_sql_type(self, pk_field: Field) -> str:
            return sql_type
        raise ConfigurationError(f"Can't get SQL type of {pk_field} for {self.DIALECT}")

    def _get_pk_create_sql(self, field_object: Field, column_name: str, comment: str) -> str:
    def _get_pk_create_sql(
        self, field_object: Field, column_name: str, comment: str
    ) -> str:
        if field_object.pk and field_object.generated:
            generated_sql = field_object.get_for_dialect(self.DIALECT, "GENERATED_SQL")
            if generated_sql:  # pragma: nobranch

@@ -217,15 +248,22 @@ def _get_pk_create_sql(self, field_object: Field, column_name: str, comment: str
                )
        return ""

    def _get_field_comment(self, field_object: Field, table_name: str, column_name: str) -> str:
    def _get_field_comment(
        self, field_object: Field, table_name: str, column_name: str
    ) -> str:
        if desc := field_object.description:
            return self._column_comment_generator(
                table=table_name, column=column_name, comment=desc
            )
        return ""

    def _get_field_sql_and_related_table(
        self, field_object: Field, table_name: str, column_name: str, default: str, comment: str
        self,
        field_object: Field,
        table_name: str,
        column_name: str,
        default: str,
        comment: str,
    ) -> tuple[str, str]:
        nullable = " NOT NULL" if not field_object.null else ""
        unique = " UNIQUE" if field_object.unique else ""
@@ -241,6 +279,10 @@ def _get_field_sql_and_related_table(
            to_field_name = reference.to_field_instance.model_field_name

            related_table_name = reference.related_model._meta.db_table
            # Get qualified table name for FK reference if related model has schema
            qualified_related_table_name = self._get_qualified_table_name(
                reference.related_model
            ).strip('"')
Review comment: Why would you remove the leading and trailing double quotes?
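To illustrate the round trip in question (values are illustrative, not from the PR): the qualified name is built with embedded quotes, the outer quotes are stripped here, and FK_TEMPLATE then wraps the whole string in quotes again:

qualified = '"pgdev"."users"'        # as returned by _get_qualified_table_name()
stripped = qualified.strip('"')      # 'pgdev"."users'
fk_clause = f' REFERENCES "{stripped}" ("id")'
# -> ' REFERENCES "pgdev"."users" ("id")' (the outer quotes come back from the template)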

            if reference.db_constraint:
                field_creation_string = self._create_string(
                    db_column=column_name,

@@ -258,7 +300,7 @@ def _get_field_sql_and_related_table(
                        to_field_name,
                    ),
                    db_column=column_name,
                    table=related_table_name,
                    table=qualified_related_table_name,
                    field=to_field_name,
                    on_delete=reference.on_delete,
                    comment=comment,

@@ -278,15 +320,18 @@ def _get_field_sql_and_related_table(
    def _get_field_indexes_sqls(
        self, model: type[Model], field_names: Sequence[str], safe: bool
    ) -> list[str]:
        indexes = [self._get_index_sql(model, [field], safe=safe) for field in field_names]
        indexes = [
            self._get_index_sql(model, [field], safe=safe) for field in field_names
        ]

        if model._meta.indexes:
            for index in model._meta.indexes:
                if isinstance(index, Index):
                    idx_sql = index.get_sql(self, model, safe)
                else:
                    fields = [
                        model._meta.fields_map[field].source_field or field for field in index
                        model._meta.fields_map[field].source_field or field
                        for field in index
                    ]
                    idx_sql = self._get_index_sql(model, fields, safe=safe)
@@ -300,23 +345,36 @@ def _get_m2m_tables(
    ) -> list[str]:
        m2m_tables_for_create = []
        for m2m_field in model._meta.m2m_fields:
            field_object = cast("ManyToManyFieldInstance", model._meta.fields_map[m2m_field])
            field_object = cast(
                "ManyToManyFieldInstance", model._meta.fields_map[m2m_field]
            )
            if field_object._generated or field_object.through in models_tables:
                continue
            backward_key, forward_key = field_object.backward_key, field_object.forward_key
            backward_key, forward_key = (
                field_object.backward_key,
                field_object.forward_key,
            )
            if field_object.db_constraint:
                # Get qualified table names for M2M foreign key references
                qualified_backward_table = self._get_qualified_table_name(model).strip(
                    '"'
                )
                qualified_forward_table = self._get_qualified_table_name(
                    field_object.related_model
                ).strip('"')

                backward_fk = self._create_fk_string(
                    "",
                    backward_key,
                    db_table,
                    qualified_backward_table,
                    model._meta.db_pk_column,
                    field_object.on_delete,
                    "",
                )
                forward_fk = self._create_fk_string(
                    "",
                    forward_key,
                    field_object.related_model._meta.db_table,
                    qualified_forward_table,
                    field_object.related_model._meta.db_pk_column,
                    field_object.on_delete,
                    "",

@@ -325,14 +383,21 @@ def _get_m2m_tables(
                backward_fk = forward_fk = ""
            exists = "IF NOT EXISTS " if safe else ""
            through_table_name = field_object.through
            qualified_through_table_name = self._get_qualified_m2m_table_name(
                model, through_table_name
            )
            backward_type = self._get_pk_field_sql_type(model._meta.pk)
            forward_type = self._get_pk_field_sql_type(field_object.related_model._meta.pk)
            forward_type = self._get_pk_field_sql_type(
                field_object.related_model._meta.pk
            )
            comment = ""
            if desc := field_object.description:
                comment = self._table_comment_generator(table=through_table_name, comment=desc)
                comment = self._table_comment_generator(
                    table=through_table_name, comment=desc
                )
            m2m_create_string = self.M2M_TABLE_TEMPLATE.format(
                exists=exists,
                table_name=through_table_name,
                table_name=qualified_through_table_name,
                backward_fk=backward_fk,
                forward_fk=forward_fk,
                backward_key=backward_key,

@@ -354,6 +419,11 @@ def _get_m2m_tables(
            unique_index_create_sql = self._get_unique_index_sql(
                exists, through_table_name, [backward_key, forward_key]
            )
            # Replace unqualified table name with qualified name in the SQL if schema exists
            if model._meta.schema:
                unique_index_create_sql = unique_index_create_sql.replace(
                    f'"{through_table_name}"', qualified_through_table_name
                )
            if unique_index_create_sql.endswith(";"):
                m2m_create_string += "\n" + unique_index_create_sql
            else:
@@ -394,18 +464,26 @@ def _get_table_sql(self, model: type[Model], safe: bool = True) -> dict:
        references = set()
        models_to_create: list[type[Model]] = self._get_models_to_create()
        table_name = model._meta.db_table
        qualified_table_name = self._get_qualified_table_name(model)
        models_tables = [model._meta.db_table for model in models_to_create]
        for field_name, column_name in model._meta.fields_db_projection.items():
            field_object = model._meta.fields_map[field_name]
            comment = self._get_field_comment(field_object, table_name, column_name)
            default = self._get_field_default(field_object, table_name, column_name, model)
            default = self._get_field_default(
                field_object, table_name, column_name, model
            )

            # TODO: PK generation needs to move out of schema generator.
            if create_pk_field := self._get_pk_create_sql(field_object, column_name, comment):
            if create_pk_field := self._get_pk_create_sql(
                field_object, column_name, comment
            ):
                fields_to_create.append(create_pk_field)
                continue

            field_creation_string, related_table_name = self._get_field_sql_and_related_table(
            (
                field_creation_string,
                related_table_name,
            ) = self._get_field_sql_and_related_table(
                field_object, table_name, column_name, default, comment
            )
            if related_table_name:

@@ -425,20 +503,24 @@ def _get_table_sql(self, model: type[Model], safe: bool = True) -> dict:
                    self._get_unique_constraint_sql(model, unique_together_to_create)
                )

        field_indexes_sqls = self._get_field_indexes_sqls(model, fields_with_index, safe)
        field_indexes_sqls = self._get_field_indexes_sqls(
            model, fields_with_index, safe
        )

        fields_to_create.extend(self._get_inner_statements())

        table_fields_string = "\n {}\n".format(",\n ".join(fields_to_create))
        table_comment = (
            self._table_comment_generator(table=table_name, comment=model._meta.table_description)
            self._table_comment_generator(
                table=table_name, comment=model._meta.table_description
            )
            if model._meta.table_description
            else ""
        )

        table_create_string = self.TABLE_CREATE_TEMPLATE.format(
            exists="IF NOT EXISTS " if safe else "",
            table_name=table_name,
            table_name=qualified_table_name,
            fields=table_fields_string,
            comment=table_comment,
            extra=self._table_generate_extra(table=table_name),

@@ -448,7 +530,9 @@ def _get_table_sql(self, model: type[Model], safe: bool = True) -> dict:

        table_create_string += self._post_table_hook()

        m2m_tables_for_create = self._get_m2m_tables(model, table_name, safe, models_tables)
        m2m_tables_for_create = self._get_m2m_tables(
            model, table_name, safe, models_tables
        )

        return {
            "table": table_name,

@@ -491,13 +575,19 @@ def get_create_schema_sql(self, safe: bool = True) -> str:
                    if t["references"].issubset(created_tables | {t["table"]})
                )
            except StopIteration:
                raise ConfigurationError("Can't create schema due to cyclic fk references")
                raise ConfigurationError(
                    "Can't create schema due to cyclic fk references"
                )
            tables_to_create.remove(next_table_for_create)
            created_tables.add(next_table_for_create["table"])
            ordered_tables_for_create.append(next_table_for_create["table_creation_string"])
            ordered_tables_for_create.append(
                next_table_for_create["table_creation_string"]
            )
            m2m_tables_to_create += next_table_for_create["m2m_tables"]

        schema_creation_string = "\n".join(ordered_tables_for_create + m2m_tables_to_create)
        schema_creation_string = "\n".join(
            ordered_tables_for_create + m2m_tables_to_create
        )
        return schema_creation_string

    async def generate_from_string(self, creation_string: str) -> None:
Review comment: 0.25.1 has already been released; you need to add a 0.25.2 (unreleased) section and put your changes there.
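For reference, a rough sketch of what such an entry might look like; the exact heading markup should follow the project's existing changelog conventions:

0.25.2 (unreleased)
-------------------
Fixed
^^^^^
- Create non-default PostgreSQL schemas and qualify table names with the model's schema. (#1979)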