Skip to content

Commit 2f17553

Browse files
authored
feat: postgres custom schema (#385)
* feat: postgres custom schema * update
1 parent fd5db63 commit 2f17553

File tree

10 files changed

+336
-37
lines changed

10 files changed

+336
-37
lines changed

packages/language/res/stdlib.zmodel

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -454,12 +454,12 @@ attribute @db.JsonB() @@@targetField([JsonField]) @@@prisma
454454

455455
attribute @db.ByteA() @@@targetField([BytesField]) @@@prisma
456456

457-
// /**
458-
// * Specifies the schema to use in a multi-schema database. https://www.prisma.io/docs/guides/database/multi-schema.
459-
// *
460-
// * @param: The name of the database schema.
461-
// */
462-
// attribute @@schema(_ name: String) @@@prisma
457+
/**
458+
* Specifies the schema to use in a multi-schema PostgreSQL database.
459+
*
460+
* @param name: The name of the database schema.
461+
*/
462+
attribute @@schema(_ name: String) @@@prisma
463463

464464
//////////////////////////////////////////////
465465
// Begin validation attributes and functions

packages/language/src/validators/attribute-application-validator.ts

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import { invariant } from '@zenstackhq/common-helpers';
12
import { AstUtils, type ValidationAcceptor } from 'langium';
23
import pluralize from 'pluralize';
34
import type { BinaryExpr, DataModel, Expression } from '../ast';
@@ -13,9 +14,13 @@ import {
1314
ReferenceExpr,
1415
isArrayExpr,
1516
isAttribute,
17+
isConfigArrayExpr,
1618
isDataField,
1719
isDataModel,
20+
isDataSource,
1821
isEnum,
22+
isLiteralExpr,
23+
isModel,
1924
isReferenceExpr,
2025
isTypeDef,
2126
} from '../generated/ast';
@@ -332,6 +337,28 @@ export default class AttributeApplicationValidator implements AstValidator<Attri
332337
}
333338
}
334339

340+
@check('@@schema')
341+
private _checkSchema(attr: AttributeApplication, accept: ValidationAcceptor) {
342+
const schemaName = getStringLiteral(attr.args[0]?.value);
343+
invariant(schemaName, `@@schema expects a string literal`);
344+
345+
// verify the schema name is defined in the datasource
346+
const zmodel = AstUtils.getContainerOfType(attr, isModel)!;
347+
const datasource = zmodel.declarations.find(isDataSource);
348+
if (datasource) {
349+
let found = false;
350+
const schemas = datasource.fields.find((f) => f.name === 'schemas');
351+
if (schemas && isConfigArrayExpr(schemas.value)) {
352+
found = schemas.value.items.some((item) => isLiteralExpr(item) && item.value === schemaName);
353+
}
354+
if (!found) {
355+
accept('error', `Schema "${schemaName}" is not defined in the datasource`, {
356+
node: attr,
357+
});
358+
}
359+
}
360+
}
361+
335362
private validatePolicyKinds(
336363
kind: string,
337364
candidates: string[],
Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import type { ValidationAcceptor } from 'langium';
22
import { SUPPORTED_PROVIDERS } from '../constants';
3-
import { DataSource, isInvocationExpr } from '../generated/ast';
3+
import { DataSource, isConfigArrayExpr, isInvocationExpr, isLiteralExpr } from '../generated/ast';
44
import { getStringLiteral } from '../utils';
55
import { validateDuplicatedDeclarations, type AstValidator } from './common';
66

@@ -12,7 +12,6 @@ export default class DataSourceValidator implements AstValidator<DataSource> {
1212
validateDuplicatedDeclarations(ds, ds.fields, accept);
1313
this.validateProvider(ds, accept);
1414
this.validateUrl(ds, accept);
15-
this.validateRelationMode(ds, accept);
1615
}
1716

1817
private validateProvider(ds: DataSource, accept: ValidationAcceptor) {
@@ -24,20 +23,45 @@ export default class DataSourceValidator implements AstValidator<DataSource> {
2423
return;
2524
}
2625

27-
const value = getStringLiteral(provider.value);
28-
if (!value) {
26+
const providerValue = getStringLiteral(provider.value);
27+
if (!providerValue) {
2928
accept('error', '"provider" must be set to a string literal', {
3029
node: provider.value,
3130
});
32-
} else if (!SUPPORTED_PROVIDERS.includes(value)) {
31+
} else if (!SUPPORTED_PROVIDERS.includes(providerValue)) {
3332
accept(
3433
'error',
35-
`Provider "${value}" is not supported. Choose from ${SUPPORTED_PROVIDERS.map((p) => '"' + p + '"').join(
36-
' | ',
37-
)}.`,
34+
`Provider "${providerValue}" is not supported. Choose from ${SUPPORTED_PROVIDERS.map(
35+
(p) => '"' + p + '"',
36+
).join(' | ')}.`,
3837
{ node: provider.value },
3938
);
4039
}
40+
41+
const defaultSchemaField = ds.fields.find((f) => f.name === 'defaultSchema');
42+
if (defaultSchemaField && providerValue !== 'postgresql') {
43+
accept('error', '"defaultSchema" is only supported for "postgresql" provider', {
44+
node: defaultSchemaField,
45+
});
46+
}
47+
48+
const schemasField = ds.fields.find((f) => f.name === 'schemas');
49+
if (schemasField) {
50+
if (providerValue !== 'postgresql') {
51+
accept('error', '"schemas" is only supported for "postgresql" provider', {
52+
node: schemasField,
53+
});
54+
}
55+
const schemasValue = schemasField.value;
56+
if (
57+
!isConfigArrayExpr(schemasValue) ||
58+
!schemasValue.items.every((e) => isLiteralExpr(e) && typeof getStringLiteral(e) === 'string')
59+
) {
60+
accept('error', '"schemas" must be an array of string literals', {
61+
node: schemasField,
62+
});
63+
}
64+
}
4165
}
4266

4367
private validateUrl(ds: DataSource, accept: ValidationAcceptor) {
@@ -53,14 +77,4 @@ export default class DataSourceValidator implements AstValidator<DataSource> {
5377
});
5478
}
5579
}
56-
57-
private validateRelationMode(ds: DataSource, accept: ValidationAcceptor) {
58-
const field = ds.fields.find((f) => f.name === 'relationMode');
59-
if (field) {
60-
const val = getStringLiteral(field.value);
61-
if (!val || !['foreignKeys', 'prisma'].includes(val)) {
62-
accept('error', '"relationMode" must be set to "foreignKeys" or "prisma"', { node: field.value });
63-
}
64-
}
65-
}
6680
}

packages/orm/src/client/executor/name-mapper.ts

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,10 +129,9 @@ export class QueryNameMapper extends OperationNodeTransformer {
129129
mappedTableName = this.mapTableName(scope.model);
130130
}
131131
}
132-
133132
return ReferenceNode.create(
134133
ColumnNode.create(mappedFieldName),
135-
mappedTableName ? TableNode.create(mappedTableName) : undefined,
134+
mappedTableName ? this.createTableNode(mappedTableName, undefined) : undefined,
136135
);
137136
} else {
138137
// no name mapping needed
@@ -316,7 +315,9 @@ export class QueryNameMapper extends OperationNodeTransformer {
316315
if (!TableNode.is(node)) {
317316
return super.transformNode(node);
318317
}
319-
return TableNode.create(this.mapTableName(node.table.identifier.name));
318+
const mappedName = this.mapTableName(node.table.identifier.name);
319+
const tableSchema = this.getTableSchema(node.table.identifier.name);
320+
return this.createTableNode(mappedName, tableSchema);
320321
}
321322

322323
private getMappedName(def: ModelDef | FieldDef) {
@@ -362,8 +363,9 @@ export class QueryNameMapper extends OperationNodeTransformer {
362363
const modelName = innerNode.table.identifier.name;
363364
const mappedName = this.mapTableName(modelName);
364365
const finalAlias = alias ?? (mappedName !== modelName ? IdentifierNode.create(modelName) : undefined);
366+
const tableSchema = this.getTableSchema(modelName);
365367
return {
366-
node: this.wrapAlias(TableNode.create(mappedName), finalAlias),
368+
node: this.wrapAlias(this.createTableNode(mappedName, tableSchema), finalAlias),
367369
scope: {
368370
alias: alias ?? IdentifierNode.create(modelName),
369371
model: modelName,
@@ -384,6 +386,21 @@ export class QueryNameMapper extends OperationNodeTransformer {
384386
}
385387
}
386388

389+
private getTableSchema(model: string) {
390+
if (this.schema.provider.type !== 'postgresql') {
391+
return undefined;
392+
}
393+
let schema = this.schema.provider.defaultSchema ?? 'public';
394+
const schemaAttr = this.schema.models[model]?.attributes?.find((attr) => attr.name === '@@schema');
395+
if (schemaAttr) {
396+
const nameArg = schemaAttr.args?.find((arg) => arg.name === 'name');
397+
if (nameArg && nameArg.value.kind === 'literal') {
398+
schema = nameArg.value.value as string;
399+
}
400+
}
401+
return schema;
402+
}
403+
387404
private createSelectAllFields(model: string, alias: OperationNode | undefined) {
388405
const modelDef = requireModel(this.schema, model);
389406
return this.getModelFields(modelDef).map((fieldDef) => {
@@ -454,5 +471,9 @@ export class QueryNameMapper extends OperationNodeTransformer {
454471
});
455472
}
456473

474+
private createTableNode(tableName: string, schemaName: string | undefined) {
475+
return schemaName ? TableNode.createWithSchema(schemaName, tableName) : TableNode.create(tableName);
476+
}
477+
457478
// #endregion
458479
}

packages/orm/src/client/executor/zenstack-query-executor.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,10 @@ export class ZenStackQueryExecutor<Schema extends SchemaDef> extends DefaultQuer
5555
) {
5656
super(compiler, adapter, connectionProvider, plugins);
5757

58-
if (this.schemaHasMappedNames(client.$schema)) {
58+
if (
59+
client.$schema.provider.type === 'postgresql' || // postgres queries need to be schema-qualified
60+
this.schemaHasMappedNames(client.$schema)
61+
) {
5962
this.nameMapper = new QueryNameMapper(client.$schema);
6063
}
6164
}

packages/schema/src/schema.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ export type DataSourceProviderType = 'sqlite' | 'postgresql';
55

66
export type DataSourceProvider = {
77
type: DataSourceProviderType;
8+
defaultSchema?: string;
89
};
910

1011
export type SchemaDef = {

packages/sdk/src/ts-schema-generator.ts

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -236,8 +236,20 @@ export class TsSchemaGenerator {
236236

237237
private createProviderObject(model: Model): ts.Expression {
238238
const dsProvider = this.getDataSourceProvider(model);
239+
const defaultSchema = this.getDataSourceDefaultSchema(model);
240+
239241
return ts.factory.createObjectLiteralExpression(
240-
[ts.factory.createPropertyAssignment('type', ts.factory.createStringLiteral(dsProvider.type))],
242+
[
243+
ts.factory.createPropertyAssignment('type', ts.factory.createStringLiteral(dsProvider)),
244+
...(defaultSchema
245+
? [
246+
ts.factory.createPropertyAssignment(
247+
'defaultSchema',
248+
ts.factory.createStringLiteral(defaultSchema),
249+
),
250+
]
251+
: []),
252+
],
241253
true,
242254
);
243255
}
@@ -621,9 +633,26 @@ export class TsSchemaGenerator {
621633
invariant(dataSource, 'No data source found in the model');
622634

623635
const providerExpr = dataSource.fields.find((f) => f.name === 'provider')?.value;
624-
invariant(isLiteralExpr(providerExpr), 'Provider must be a literal');
625-
const type = providerExpr.value as string;
626-
return { type };
636+
invariant(
637+
isLiteralExpr(providerExpr) && typeof providerExpr.value === 'string',
638+
'Provider must be a string literal',
639+
);
640+
return providerExpr.value as string;
641+
}
642+
643+
private getDataSourceDefaultSchema(model: Model) {
644+
const dataSource = model.declarations.find(isDataSource);
645+
invariant(dataSource, 'No data source found in the model');
646+
647+
const defaultSchemaExpr = dataSource.fields.find((f) => f.name === 'defaultSchema')?.value;
648+
if (!defaultSchemaExpr) {
649+
return undefined;
650+
}
651+
invariant(
652+
isLiteralExpr(defaultSchemaExpr) && typeof defaultSchemaExpr.value === 'string',
653+
'Default schema must be a string literal',
654+
);
655+
return defaultSchemaExpr.value as string;
627656
}
628657

629658
private getFieldMappedDefault(

packages/testtools/src/client.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import { invariant } from '@zenstackhq/common-helpers';
22
import type { Model } from '@zenstackhq/language/ast';
3-
import { PolicyPlugin } from '@zenstackhq/plugin-policy';
43
import { ZenStackClient, type ClientContract, type ClientOptions } from '@zenstackhq/orm';
54
import type { SchemaDef } from '@zenstackhq/orm/schema';
5+
import { PolicyPlugin } from '@zenstackhq/plugin-policy';
66
import { PrismaSchemaGenerator } from '@zenstackhq/sdk';
77
import SQLite from 'better-sqlite3';
88
import { PostgresDialect, SqliteDialect, type LogEvent } from 'kysely';
@@ -59,7 +59,6 @@ export async function createTestClient<Schema extends SchemaDef>(
5959
let _schema: Schema;
6060
const provider = options?.provider ?? getTestDbProvider() ?? 'sqlite';
6161
const dbName = options?.dbName ?? getTestDbName(provider);
62-
6362
const dbUrl =
6463
provider === 'sqlite'
6564
? `file:${dbName}`
@@ -68,13 +67,14 @@ export async function createTestClient<Schema extends SchemaDef>(
6867
let model: Model | undefined;
6968

7069
if (typeof schema === 'string') {
71-
const generated = await generateTsSchema(schema, provider, dbUrl, options?.extraSourceFiles);
70+
const generated = await generateTsSchema(schema, provider, dbUrl, options?.extraSourceFiles, undefined);
7271
workDir = generated.workDir;
7372
model = generated.model;
7473
// replace schema's provider
7574
_schema = {
7675
...generated.schema,
7776
provider: {
77+
...generated.schema.provider,
7878
type: provider,
7979
},
8080
} as Schema;

packages/testtools/src/schema.ts

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@ datasource db {
3232
.exhaustive();
3333
}
3434

35+
function replacePlaceholders(schemaText: string, provider: 'sqlite' | 'postgresql', dbUrl: string | undefined) {
36+
const url = dbUrl ?? (provider === 'sqlite' ? 'file:./test.db' : 'postgres://postgres:postgres@localhost:5432/db');
37+
return schemaText.replace(/\$DB_URL/g, url).replace(/\$PROVIDER/g, provider);
38+
}
39+
3540
export async function generateTsSchema(
3641
schemaText: string,
3742
provider: 'sqlite' | 'postgresql' = 'sqlite',
@@ -43,7 +48,10 @@ export async function generateTsSchema(
4348

4449
const zmodelPath = path.join(workDir, 'schema.zmodel');
4550
const noPrelude = schemaText.includes('datasource ');
46-
fs.writeFileSync(zmodelPath, `${noPrelude ? '' : makePrelude(provider, dbUrl)}\n\n${schemaText}`);
51+
fs.writeFileSync(
52+
zmodelPath,
53+
`${noPrelude ? '' : makePrelude(provider, dbUrl)}\n\n${replacePlaceholders(schemaText, provider, dbUrl)}`,
54+
);
4755

4856
const result = await loadDocumentWithPlugins(zmodelPath);
4957
if (!result.success) {

0 commit comments

Comments
 (0)