From a2ad5e5e3c7bfa11dd85136c66309612f4717c80 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Mon, 10 Nov 2025 09:50:17 -0800 Subject: [PATCH] fix: stricter default schema validation and proper transformation to Prisma schema --- .../src/validators/datasource-validator.ts | 26 ++++++-- .../sdk/src/prisma/prisma-schema-generator.ts | 63 ++++++++++++++---- .../orm/client-api/pg-custom-schema.test.ts | 65 +++++++++++++++++++ 3 files changed, 136 insertions(+), 18 deletions(-) diff --git a/packages/language/src/validators/datasource-validator.ts b/packages/language/src/validators/datasource-validator.ts index 9f6abd64..b667d2b2 100644 --- a/packages/language/src/validators/datasource-validator.ts +++ b/packages/language/src/validators/datasource-validator.ts @@ -39,10 +39,20 @@ export default class DataSourceValidator implements AstValidator { } const defaultSchemaField = ds.fields.find((f) => f.name === 'defaultSchema'); - if (defaultSchemaField && providerValue !== 'postgresql') { - accept('error', '"defaultSchema" is only supported for "postgresql" provider', { - node: defaultSchemaField, - }); + let defaultSchemaValue: string | undefined; + if (defaultSchemaField) { + if (providerValue !== 'postgresql') { + accept('error', '"defaultSchema" is only supported for "postgresql" provider', { + node: defaultSchemaField, + }); + } + + defaultSchemaValue = getStringLiteral(defaultSchemaField.value); + if (!defaultSchemaValue) { + accept('error', '"defaultSchema" must be a string literal', { + node: defaultSchemaField.value, + }); + } } const schemasField = ds.fields.find((f) => f.name === 'schemas'); @@ -60,6 +70,14 @@ export default class DataSourceValidator implements AstValidator { accept('error', '"schemas" must be an array of string literals', { node: schemasField, }); + } else if ( + // validate `defaultSchema` is included in `schemas` + defaultSchemaValue && + !schemasValue.items.some((e) => getStringLiteral(e) === defaultSchemaValue) + ) { + accept('error', `"${defaultSchemaValue}" must be included in the "schemas" array`, { + node: schemasField, + }); } } } diff --git a/packages/sdk/src/prisma/prisma-schema-generator.ts b/packages/sdk/src/prisma/prisma-schema-generator.ts index 3f3ba823..45ffed3c 100644 --- a/packages/sdk/src/prisma/prisma-schema-generator.ts +++ b/packages/sdk/src/prisma/prisma-schema-generator.ts @@ -19,6 +19,7 @@ import { InvocationExpr, isArrayExpr, isDataModel, + isDataSource, isInvocationExpr, isLiteralExpr, isNullExpr, @@ -29,9 +30,14 @@ import { Model, NumberLiteral, StringLiteral, - type AstNode, } from '@zenstackhq/language/ast'; -import { getAllAttributes, getAllFields, isAuthInvocation, isDelegateModel } from '@zenstackhq/language/utils'; +import { + getAllAttributes, + getAllFields, + getStringLiteral, + isAuthInvocation, + isDelegateModel, +} from '@zenstackhq/language/utils'; import { AstUtils } from 'langium'; import { match } from 'ts-pattern'; import { ModelUtils } from '..'; @@ -58,6 +64,9 @@ import { // Here we use a conservative value that should work for most cases, and truncate names if needed const IDENTIFIER_NAME_MAX_LENGTH = 50 - DELEGATE_AUX_RELATION_PREFIX.length; +// Datasource fields that only exist in ZModel but not in Prisma schema +const NON_PRISMA_DATASOURCE_FIELDS = ['defaultSchema']; + /** * Generates Prisma schema file */ @@ -101,10 +110,12 @@ export class PrismaSchemaGenerator { } private generateDataSource(prisma: PrismaModel, dataSource: DataSource) { - const fields: SimpleField[] = dataSource.fields.map((f) => ({ - name: f.name, - text: this.configExprToText(f.value), - })); + const fields: SimpleField[] = dataSource.fields + .filter((f) => !NON_PRISMA_DATASOURCE_FIELDS.includes(f.name)) + .map((f) => ({ + name: f.name, + text: this.configExprToText(f.value), + })); prisma.addDataSource(dataSource.name, fields); } @@ -171,13 +182,27 @@ export class PrismaSchemaGenerator { } } - const allAttributes = getAllAttributes(decl); - for (const attr of allAttributes.filter( + const allAttributes = getAllAttributes(decl).filter( (attr) => this.isPrismaAttribute(attr) && !this.isInheritedMapAttribute(attr, decl), - )) { + ); + + for (const attr of allAttributes) { this.generateContainerAttribute(model, attr); } + if ( + this.datasourceHasSchemasSetting(decl.$container) && + !allAttributes.some((attr) => attr.decl.ref?.name === '@@schema') + ) { + // if the datasource declared `schemas` and no @@schema attribute is defined, add a default one + model.addAttribute('@@schema', [ + new PrismaAttributeArg( + undefined, + new PrismaAttributeArgValue('String', this.getDefaultPostgresSchemaName(decl.$container)), + ), + ]); + } + // user defined comments pass-through decl.comments.forEach((c) => model.addComment(c)); @@ -188,6 +213,20 @@ export class PrismaSchemaGenerator { this.generateDelegateRelationForConcrete(model, decl); } + private getDatasourceField(zmodel: Model, fieldName: string) { + const dataSource = zmodel.declarations.find(isDataSource); + return dataSource?.fields.find((f) => f.name === fieldName); + } + + private datasourceHasSchemasSetting(zmodel: Model) { + return !!this.getDatasourceField(zmodel, 'schemas'); + } + + private getDefaultPostgresSchemaName(zmodel: Model) { + const defaultSchemaField = this.getDatasourceField(zmodel, 'defaultSchema'); + return getStringLiteral(defaultSchemaField?.value) ?? 'public'; + } + private isInheritedMapAttribute(attr: DataModelAttribute, contextModel: DataModel) { if (attr.$container === contextModel) { return false; @@ -206,7 +245,7 @@ export class PrismaSchemaGenerator { private getUnsupportedFieldType(fieldType: DataFieldType) { if (fieldType.unsupported) { - const value = this.getStringLiteral(fieldType.unsupported.value); + const value = getStringLiteral(fieldType.unsupported.value); if (value) { return `Unsupported("${value}")`; } else { @@ -217,10 +256,6 @@ export class PrismaSchemaGenerator { } } - private getStringLiteral(node: AstNode | undefined): string | undefined { - return isStringLiteral(node) ? node.value : undefined; - } - private generateModelField(model: PrismaDataModel, field: DataField, contextModel: DataModel, addToFront = false) { let fieldType: string | undefined; diff --git a/tests/e2e/orm/client-api/pg-custom-schema.test.ts b/tests/e2e/orm/client-api/pg-custom-schema.test.ts index 7b9252ce..4308e864 100644 --- a/tests/e2e/orm/client-api/pg-custom-schema.test.ts +++ b/tests/e2e/orm/client-api/pg-custom-schema.test.ts @@ -193,4 +193,69 @@ model Foo { ), ).rejects.toThrow('Schema "mySchema" is not defined in the datasource'); }); + + it('requires defaultSchema to be included in schemas', async () => { + await expect( + createTestClient( + ` +datasource db { + provider = 'postgresql' + defaultSchema = 'mySchema' + schemas = ['public'] +} + +model Foo { + id Int @id + name String +} +`, + ), + ).rejects.toThrow('"mySchema" must be included in the "schemas" array'); + }); + + it('allows specifying schema only on a few models', async () => { + let fooQueriesVerified = false; + let barQueriesVerified = false; + + const db = await createTestClient( + ` +datasource db { + provider = 'postgresql' + defaultSchema = 'somedefault' + schemas = ['mySchema', 'somedefault'] + url = '$DB_URL' +} + +model Foo { + id Int @id + name String + @@schema('mySchema') +} + +model Bar { + id Int @id + name String +} +`, + { + provider: 'postgresql', + usePrismaPush: true, + log: (event) => { + const sql = event.query.sql.toLowerCase(); + if (sql.includes('"myschema"."foo"')) { + fooQueriesVerified = true; + } + if (sql.includes('"somedefault"."bar"')) { + barQueriesVerified = true; + } + }, + }, + ); + + await expect(db.foo.create({ data: { id: 1, name: 'test' } })).toResolveTruthy(); + await expect(db.bar.create({ data: { id: 1, name: 'test' } })).toResolveTruthy(); + + expect(fooQueriesVerified).toBe(true); + expect(barQueriesVerified).toBe(true); + }); });