Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ Added
- Add `no_key` parameter to `queryset.select_for_update`.
- `F()` supports referencing JSONField attributes, e.g. `F("json_field__custom_field__nested_id")` (#1960)

Fixed
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

0.25.1 has been already released, you need to add 0.25.2 (unreleased) and put your changes there.

^^^^^
- Fix PostgreSQL schema creation for non-default schemas - automatically create schemas if they don't exist (#1671)

0.25.0
------
Fixed
Expand Down
28 changes: 28 additions & 0 deletions tests/schema/test_schema_creation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""Tests for automatic PostgreSQL schema creation functionality."""

from tortoise.backends.base_postgres.schema_generator import BasePostgresSchemaGenerator
from tortoise.contrib import test


class TestPostgresSchemaCreation(test.TestCase):
"""Test automatic PostgreSQL schema creation."""

def test_postgres_schema_creation_sql(self):
"""Test that BasePostgresSchemaGenerator can create schema SQL."""
# Mock client for testing
class MockClient:
def __init__(self):
self.capabilities = type('obj', (object,), {
'inline_comment': False,
'safe': True
})()

mock_client = MockClient()
generator = BasePostgresSchemaGenerator(mock_client)

# Test schema creation SQL generation
schema_sql = generator._get_create_schema_sql("pgdev", safe=True)
self.assertEqual(schema_sql, 'CREATE SCHEMA IF NOT EXISTS "pgdev";')

schema_sql_unsafe = generator._get_create_schema_sql("pgdev", safe=False)
self.assertEqual(schema_sql_unsafe, 'CREATE SCHEMA "pgdev";')
156 changes: 123 additions & 33 deletions tortoise/backends/base/schema_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,21 @@

class BaseSchemaGenerator:
DIALECT = "sql"
TABLE_CREATE_TEMPLATE = 'CREATE TABLE {exists}"{table_name}" ({fields}){extra}{comment};'
TABLE_CREATE_TEMPLATE = (
"CREATE TABLE {exists}{table_name} ({fields}){extra}{comment};"
)
FIELD_TEMPLATE = '"{name}" {type}{nullable}{unique}{primary}{default}{comment}'
INDEX_CREATE_TEMPLATE = (
'CREATE {index_type}INDEX {exists}"{index_name}" ON "{table_name}" ({fields}){extra};'
'CREATE {index_type}INDEX {exists}"{index_name}" ON {table_name} ({fields}){extra};'
)
UNIQUE_INDEX_CREATE_TEMPLATE = INDEX_CREATE_TEMPLATE.replace(
"INDEX", "UNIQUE INDEX"
)
UNIQUE_INDEX_CREATE_TEMPLATE = INDEX_CREATE_TEMPLATE.replace("INDEX", "UNIQUE INDEX")
UNIQUE_CONSTRAINT_CREATE_TEMPLATE = 'CONSTRAINT "{index_name}" UNIQUE ({fields})'
GENERATED_PK_TEMPLATE = '"{field_name}" {generated_sql}{comment}'
FK_TEMPLATE = ' REFERENCES "{table}" ("{field}") ON DELETE {on_delete}{comment}'
M2M_TABLE_TEMPLATE = (
'CREATE TABLE {exists}"{table_name}" (\n'
"CREATE TABLE {exists}{table_name} (\n"
' "{backward_key}" {backward_type} NOT NULL{backward_fk},\n'
' "{forward_key}" {forward_type} NOT NULL{forward_fk}\n'
"){extra}{comment};"
Expand Down Expand Up @@ -75,7 +79,11 @@ def _create_fk_string(
comment: str,
) -> str:
return self.FK_TEMPLATE.format(
db_column=db_column, table=table, field=field, on_delete=on_delete, comment=comment
db_column=db_column,
table=table,
field=field,
on_delete=on_delete,
comment=comment,
)

def _table_comment_generator(self, table: str, comment: str) -> str:
Expand Down Expand Up @@ -139,6 +147,21 @@ def _get_inner_statements(self) -> list[str]:
def quote(self, val: str) -> str:
return f'"{val}"'

def _get_qualified_table_name(self, model: type[Model]) -> str:
"""Get the fully qualified table name including schema if present."""
table_name = model._meta.db_table
if model._meta.schema:
return f'"{model._meta.schema}"."{table_name}"'
return f'"{table_name}"'

def _get_qualified_m2m_table_name(
self, model: type[Model], through_table_name: str
) -> str:
"""Get the fully qualified M2M table name including schema if present."""
if model._meta.schema:
return f'"{model._meta.schema}"."{through_table_name}"'
return f'"{through_table_name}"'

@staticmethod
def _make_hash(*args: str, length: int) -> str:
# Hash a set of string values and get a digest of the given length.
Expand All @@ -156,8 +179,11 @@ def _get_index_name(
hashed = self._make_hash(table_name, *field_names, length=6)
return f"{prefix}_{table}_{field}_{hashed}"

def _get_fk_name(self, from_table: str, from_field: str, to_table: str, to_field: str) -> str:
# NOTE: for compatibility, index name should not be longer than 30 characters (Oracle limit).
def _get_fk_name(
self, from_table: str, from_field: str, to_table: str, to_field: str
) -> str:
# NOTE: for compatibility, index name should not be longer than 30 characters
# (Oracle limit).
# That's why we slice some of the strings here.
hashed = self._make_hash(from_table, from_field, to_table, to_field, length=8)
return f"fk_{from_table[:8]}_{to_table[:8]}_{hashed}"
Expand All @@ -175,7 +201,7 @@ def _get_index_sql(
exists="IF NOT EXISTS " if safe else "",
index_name=index_name or self._get_index_name("idx", model, field_names),
index_type=f"{index_type} " if index_type else "",
table_name=model._meta.db_table,
table_name=self._get_qualified_table_name(model),
fields=", ".join([self.quote(f) for f in field_names]),
extra=f"{extra}" if extra else "",
)
Expand All @@ -184,16 +210,19 @@ def _get_unique_index_sql(
self, exists: str, table_name: str, field_names: Sequence[str]
) -> str:
index_name = self._get_index_name("uidx", table_name, field_names)
quoted_table_name = self.quote(table_name)
return self.UNIQUE_INDEX_CREATE_TEMPLATE.format(
exists=exists,
index_name=index_name,
index_type="",
table_name=table_name,
table_name=quoted_table_name,
fields=", ".join([self.quote(f) for f in field_names]),
extra="",
)

def _get_unique_constraint_sql(self, model: type[Model], field_names: Sequence[str]) -> str:
def _get_unique_constraint_sql(
self, model: type[Model], field_names: Sequence[str]
) -> str:
return self.UNIQUE_CONSTRAINT_CREATE_TEMPLATE.format(
index_name=self._get_index_name("uid", model, field_names),
fields=", ".join([self.quote(f) for f in field_names]),
Expand All @@ -206,7 +235,9 @@ def _get_pk_field_sql_type(self, pk_field: Field) -> str:
return sql_type
raise ConfigurationError(f"Can't get SQL type of {pk_field} for {self.DIALECT}")

def _get_pk_create_sql(self, field_object: Field, column_name: str, comment: str) -> str:
def _get_pk_create_sql(
self, field_object: Field, column_name: str, comment: str
) -> str:
if field_object.pk and field_object.generated:
generated_sql = field_object.get_for_dialect(self.DIALECT, "GENERATED_SQL")
if generated_sql: # pragma: nobranch
Expand All @@ -217,15 +248,22 @@ def _get_pk_create_sql(self, field_object: Field, column_name: str, comment: str
)
return ""

def _get_field_comment(self, field_object: Field, table_name: str, column_name: str) -> str:
def _get_field_comment(
self, field_object: Field, table_name: str, column_name: str
) -> str:
if desc := field_object.description:
return self._column_comment_generator(
table=table_name, column=column_name, comment=desc
)
return ""

def _get_field_sql_and_related_table(
self, field_object: Field, table_name: str, column_name: str, default: str, comment: str
self,
field_object: Field,
table_name: str,
column_name: str,
default: str,
comment: str,
) -> tuple[str, str]:
nullable = " NOT NULL" if not field_object.null else ""
unique = " UNIQUE" if field_object.unique else ""
Expand All @@ -241,6 +279,10 @@ def _get_field_sql_and_related_table(
to_field_name = reference.to_field_instance.model_field_name

related_table_name = reference.related_model._meta.db_table
# Get qualified table name for FK reference if related model has schema
qualified_related_table_name = self._get_qualified_table_name(
reference.related_model
).strip('"')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why would you remove the leading and trailing " but leave them in the middle around .?

if reference.db_constraint:
field_creation_string = self._create_string(
db_column=column_name,
Expand All @@ -258,7 +300,7 @@ def _get_field_sql_and_related_table(
to_field_name,
),
db_column=column_name,
table=related_table_name,
table=qualified_related_table_name,
field=to_field_name,
on_delete=reference.on_delete,
comment=comment,
Expand All @@ -278,15 +320,18 @@ def _get_field_sql_and_related_table(
def _get_field_indexes_sqls(
self, model: type[Model], field_names: Sequence[str], safe: bool
) -> list[str]:
indexes = [self._get_index_sql(model, [field], safe=safe) for field in field_names]
indexes = [
self._get_index_sql(model, [field], safe=safe) for field in field_names
]

if model._meta.indexes:
for index in model._meta.indexes:
if isinstance(index, Index):
idx_sql = index.get_sql(self, model, safe)
else:
fields = [
model._meta.fields_map[field].source_field or field for field in index
model._meta.fields_map[field].source_field or field
for field in index
]
idx_sql = self._get_index_sql(model, fields, safe=safe)

Expand All @@ -300,23 +345,36 @@ def _get_m2m_tables(
) -> list[str]:
m2m_tables_for_create = []
for m2m_field in model._meta.m2m_fields:
field_object = cast("ManyToManyFieldInstance", model._meta.fields_map[m2m_field])
field_object = cast(
"ManyToManyFieldInstance", model._meta.fields_map[m2m_field]
)
if field_object._generated or field_object.through in models_tables:
continue
backward_key, forward_key = field_object.backward_key, field_object.forward_key
backward_key, forward_key = (
field_object.backward_key,
field_object.forward_key,
)
if field_object.db_constraint:
# Get qualified table names for M2M foreign key references
qualified_backward_table = self._get_qualified_table_name(model).strip(
'"'
)
qualified_forward_table = self._get_qualified_table_name(
field_object.related_model
).strip('"')

backward_fk = self._create_fk_string(
"",
backward_key,
db_table,
qualified_backward_table,
model._meta.db_pk_column,
field_object.on_delete,
"",
)
forward_fk = self._create_fk_string(
"",
forward_key,
field_object.related_model._meta.db_table,
qualified_forward_table,
field_object.related_model._meta.db_pk_column,
field_object.on_delete,
"",
Expand All @@ -325,14 +383,21 @@ def _get_m2m_tables(
backward_fk = forward_fk = ""
exists = "IF NOT EXISTS " if safe else ""
through_table_name = field_object.through
qualified_through_table_name = self._get_qualified_m2m_table_name(
model, through_table_name
)
backward_type = self._get_pk_field_sql_type(model._meta.pk)
forward_type = self._get_pk_field_sql_type(field_object.related_model._meta.pk)
forward_type = self._get_pk_field_sql_type(
field_object.related_model._meta.pk
)
comment = ""
if desc := field_object.description:
comment = self._table_comment_generator(table=through_table_name, comment=desc)
comment = self._table_comment_generator(
table=through_table_name, comment=desc
)
m2m_create_string = self.M2M_TABLE_TEMPLATE.format(
exists=exists,
table_name=through_table_name,
table_name=qualified_through_table_name,
backward_fk=backward_fk,
forward_fk=forward_fk,
backward_key=backward_key,
Expand All @@ -354,6 +419,11 @@ def _get_m2m_tables(
unique_index_create_sql = self._get_unique_index_sql(
exists, through_table_name, [backward_key, forward_key]
)
# Replace unqualified table name with qualified name in the SQL if schema exists
if model._meta.schema:
unique_index_create_sql = unique_index_create_sql.replace(
f'"{through_table_name}"', qualified_through_table_name
)
if unique_index_create_sql.endswith(";"):
m2m_create_string += "\n" + unique_index_create_sql
else:
Expand Down Expand Up @@ -394,18 +464,26 @@ def _get_table_sql(self, model: type[Model], safe: bool = True) -> dict:
references = set()
models_to_create: list[type[Model]] = self._get_models_to_create()
table_name = model._meta.db_table
qualified_table_name = self._get_qualified_table_name(model)
models_tables = [model._meta.db_table for model in models_to_create]
for field_name, column_name in model._meta.fields_db_projection.items():
field_object = model._meta.fields_map[field_name]
comment = self._get_field_comment(field_object, table_name, column_name)
default = self._get_field_default(field_object, table_name, column_name, model)
default = self._get_field_default(
field_object, table_name, column_name, model
)

# TODO: PK generation needs to move out of schema generator.
if create_pk_field := self._get_pk_create_sql(field_object, column_name, comment):
if create_pk_field := self._get_pk_create_sql(
field_object, column_name, comment
):
fields_to_create.append(create_pk_field)
continue

field_creation_string, related_table_name = self._get_field_sql_and_related_table(
(
field_creation_string,
related_table_name,
) = self._get_field_sql_and_related_table(
field_object, table_name, column_name, default, comment
)
if related_table_name:
Expand All @@ -425,20 +503,24 @@ def _get_table_sql(self, model: type[Model], safe: bool = True) -> dict:
self._get_unique_constraint_sql(model, unique_together_to_create)
)

field_indexes_sqls = self._get_field_indexes_sqls(model, fields_with_index, safe)
field_indexes_sqls = self._get_field_indexes_sqls(
model, fields_with_index, safe
)

fields_to_create.extend(self._get_inner_statements())

table_fields_string = "\n {}\n".format(",\n ".join(fields_to_create))
table_comment = (
self._table_comment_generator(table=table_name, comment=model._meta.table_description)
self._table_comment_generator(
table=table_name, comment=model._meta.table_description
)
if model._meta.table_description
else ""
)

table_create_string = self.TABLE_CREATE_TEMPLATE.format(
exists="IF NOT EXISTS " if safe else "",
table_name=table_name,
table_name=qualified_table_name,
fields=table_fields_string,
comment=table_comment,
extra=self._table_generate_extra(table=table_name),
Expand All @@ -448,7 +530,9 @@ def _get_table_sql(self, model: type[Model], safe: bool = True) -> dict:

table_create_string += self._post_table_hook()

m2m_tables_for_create = self._get_m2m_tables(model, table_name, safe, models_tables)
m2m_tables_for_create = self._get_m2m_tables(
model, table_name, safe, models_tables
)

return {
"table": table_name,
Expand Down Expand Up @@ -491,13 +575,19 @@ def get_create_schema_sql(self, safe: bool = True) -> str:
if t["references"].issubset(created_tables | {t["table"]})
)
except StopIteration:
raise ConfigurationError("Can't create schema due to cyclic fk references")
raise ConfigurationError(
"Can't create schema due to cyclic fk references"
)
tables_to_create.remove(next_table_for_create)
created_tables.add(next_table_for_create["table"])
ordered_tables_for_create.append(next_table_for_create["table_creation_string"])
ordered_tables_for_create.append(
next_table_for_create["table_creation_string"]
)
m2m_tables_to_create += next_table_for_create["m2m_tables"]

schema_creation_string = "\n".join(ordered_tables_for_create + m2m_tables_to_create)
schema_creation_string = "\n".join(
ordered_tables_for_create + m2m_tables_to_create
)
return schema_creation_string

async def generate_from_string(self, creation_string: str) -> None:
Expand Down
Loading
Loading