|
5 | 5 | """ |
6 | 6 | from dataclasses import fields |
7 | 7 | from decimal import Decimal |
8 | | -from typing import Annotated |
| 8 | +from typing import Annotated, Type |
9 | 9 |
|
10 | 10 | from sqlalchemy import Dialect, MetaData, String |
11 | 11 | import sqlalchemy.types as types |
@@ -51,37 +51,41 @@ class ScBase(MappedAsDataclass, DeclarativeBase): |
51 | 51 | }) |
52 | 52 |
|
53 | 53 | def __init_subclass__(cls) -> None: |
54 | | - if settings.TABLE_PREFIX: |
55 | | - # Check if the prefix has already been applied |
56 | | - # unless the class is a test mapper class |
57 | | - if ( |
58 | | - cls.__name__.startswith('Test') |
59 | | - or not cls.__tablename__.startswith(settings.TABLE_PREFIX) |
60 | | - ): |
61 | | - # Prefix the table name |
62 | | - cls.__tablename__ = f"{settings.TABLE_PREFIX}_{cls.__tablename__}" |
63 | | - |
64 | | - for dfield in fields(cls): |
65 | | - field = getattr(cls, dfield.name, None) |
66 | | - if field and hasattr(field, 'foreign_keys'): |
67 | | - # Prefix all foreign key column references |
68 | | - for fk in field.foreign_keys: |
69 | | - if not fk._colspec.startswith(settings.TABLE_PREFIX): |
70 | | - fk._colspec = f"{settings.TABLE_PREFIX}_{fk._colspec}" |
71 | | - if fk.name and not fk.name.startswith(settings.TABLE_PREFIX): |
72 | | - fk.name = f"{settings.TABLE_PREFIX}_{fk.name}" |
73 | | - |
| 54 | + prefix_tables_and_foreign_keys(cls) |
74 | 55 | super().__init_subclass__() |
75 | 56 |
|
76 | 57 |
|
77 | | -def create_test_mappers(): |
| 58 | +def prefix_tables_and_foreign_keys(subclass: Type[DeclarativeBase]): |
| 59 | + if settings.TABLE_PREFIX: |
| 60 | + assert issubclass(subclass, MappedAsDataclass) |
| 61 | + # Check if the prefix has already been applied |
| 62 | + # unless the class is a test mapper class |
| 63 | + if ( |
| 64 | + subclass.__name__.startswith('Test') |
| 65 | + or not subclass.__tablename__.startswith(settings.TABLE_PREFIX) |
| 66 | + ): |
| 67 | + # Prefix the table name |
| 68 | + subclass.__tablename__ = f"{settings.TABLE_PREFIX}_{subclass.__tablename__}" |
| 69 | + |
| 70 | + for dfield in fields(subclass): |
| 71 | + field = getattr(subclass, dfield.name, None) |
| 72 | + if field and hasattr(field, 'foreign_keys'): |
| 73 | + # Prefix all foreign key column references |
| 74 | + for fk in field.foreign_keys: |
| 75 | + if not fk._colspec.startswith(settings.TABLE_PREFIX): |
| 76 | + fk._colspec = f"{settings.TABLE_PREFIX}_{fk._colspec}" |
| 77 | + if fk.name and not fk.name.startswith(settings.TABLE_PREFIX): |
| 78 | + fk.name = f"{settings.TABLE_PREFIX}_{fk.name}" |
| 79 | + |
| 80 | + |
| 81 | +def create_test_mappers(base_class: Type[DeclarativeBase]): |
78 | 82 | """ |
79 | 83 | Create testnet versions of all registered models. |
80 | 84 | These mappers are not needed at runtime, but they must exist when generating migrations. |
81 | 85 | """ |
82 | 86 | models = [ |
83 | 87 | mapper.class_ |
84 | | - for mapper in ScBase.registry.mappers |
| 88 | + for mapper in base_class.registry.mappers |
85 | 89 | ] |
86 | 90 | # TODO: generate indexes for test tables |
87 | 91 | settings.TABLE_PREFIX = "test" |
|
0 commit comments