Skip to content

Commit 7a207bf

Browse files
authored
fix: stricter default schema validation and proper transformation to Prisma schema (#390)
1 parent 267f98d commit 7a207bf

File tree

3 files changed

+136
-18
lines changed

3 files changed

+136
-18
lines changed

packages/language/src/validators/datasource-validator.ts

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,20 @@ export default class DataSourceValidator implements AstValidator<DataSource> {
3939
}
4040

4141
const defaultSchemaField = ds.fields.find((f) => f.name === 'defaultSchema');
42-
if (defaultSchemaField && providerValue !== 'postgresql') {
43-
accept('error', '"defaultSchema" is only supported for "postgresql" provider', {
44-
node: defaultSchemaField,
45-
});
42+
let defaultSchemaValue: string | undefined;
43+
if (defaultSchemaField) {
44+
if (providerValue !== 'postgresql') {
45+
accept('error', '"defaultSchema" is only supported for "postgresql" provider', {
46+
node: defaultSchemaField,
47+
});
48+
}
49+
50+
defaultSchemaValue = getStringLiteral(defaultSchemaField.value);
51+
if (!defaultSchemaValue) {
52+
accept('error', '"defaultSchema" must be a string literal', {
53+
node: defaultSchemaField.value,
54+
});
55+
}
4656
}
4757

4858
const schemasField = ds.fields.find((f) => f.name === 'schemas');
@@ -60,6 +70,14 @@ export default class DataSourceValidator implements AstValidator<DataSource> {
6070
accept('error', '"schemas" must be an array of string literals', {
6171
node: schemasField,
6272
});
73+
} else if (
74+
// validate `defaultSchema` is included in `schemas`
75+
defaultSchemaValue &&
76+
!schemasValue.items.some((e) => getStringLiteral(e) === defaultSchemaValue)
77+
) {
78+
accept('error', `"${defaultSchemaValue}" must be included in the "schemas" array`, {
79+
node: schemasField,
80+
});
6381
}
6482
}
6583
}

packages/sdk/src/prisma/prisma-schema-generator.ts

Lines changed: 49 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import {
1919
InvocationExpr,
2020
isArrayExpr,
2121
isDataModel,
22+
isDataSource,
2223
isInvocationExpr,
2324
isLiteralExpr,
2425
isNullExpr,
@@ -29,9 +30,14 @@ import {
2930
Model,
3031
NumberLiteral,
3132
StringLiteral,
32-
type AstNode,
3333
} from '@zenstackhq/language/ast';
34-
import { getAllAttributes, getAllFields, isAuthInvocation, isDelegateModel } from '@zenstackhq/language/utils';
34+
import {
35+
getAllAttributes,
36+
getAllFields,
37+
getStringLiteral,
38+
isAuthInvocation,
39+
isDelegateModel,
40+
} from '@zenstackhq/language/utils';
3541
import { AstUtils } from 'langium';
3642
import { match } from 'ts-pattern';
3743
import { ModelUtils } from '..';
@@ -58,6 +64,9 @@ import {
5864
// Here we use a conservative value that should work for most cases, and truncate names if needed
5965
const IDENTIFIER_NAME_MAX_LENGTH = 50 - DELEGATE_AUX_RELATION_PREFIX.length;
6066

67+
// Datasource fields that only exist in ZModel but not in Prisma schema
68+
const NON_PRISMA_DATASOURCE_FIELDS = ['defaultSchema'];
69+
6170
/**
6271
* Generates Prisma schema file
6372
*/
@@ -101,10 +110,12 @@ export class PrismaSchemaGenerator {
101110
}
102111

103112
private generateDataSource(prisma: PrismaModel, dataSource: DataSource) {
104-
const fields: SimpleField[] = dataSource.fields.map((f) => ({
105-
name: f.name,
106-
text: this.configExprToText(f.value),
107-
}));
113+
const fields: SimpleField[] = dataSource.fields
114+
.filter((f) => !NON_PRISMA_DATASOURCE_FIELDS.includes(f.name))
115+
.map((f) => ({
116+
name: f.name,
117+
text: this.configExprToText(f.value),
118+
}));
108119
prisma.addDataSource(dataSource.name, fields);
109120
}
110121

@@ -171,13 +182,27 @@ export class PrismaSchemaGenerator {
171182
}
172183
}
173184

174-
const allAttributes = getAllAttributes(decl);
175-
for (const attr of allAttributes.filter(
185+
const allAttributes = getAllAttributes(decl).filter(
176186
(attr) => this.isPrismaAttribute(attr) && !this.isInheritedMapAttribute(attr, decl),
177-
)) {
187+
);
188+
189+
for (const attr of allAttributes) {
178190
this.generateContainerAttribute(model, attr);
179191
}
180192

193+
if (
194+
this.datasourceHasSchemasSetting(decl.$container) &&
195+
!allAttributes.some((attr) => attr.decl.ref?.name === '@@schema')
196+
) {
197+
// if the datasource declared `schemas` and no @@schema attribute is defined, add a default one
198+
model.addAttribute('@@schema', [
199+
new PrismaAttributeArg(
200+
undefined,
201+
new PrismaAttributeArgValue('String', this.getDefaultPostgresSchemaName(decl.$container)),
202+
),
203+
]);
204+
}
205+
181206
// user defined comments pass-through
182207
decl.comments.forEach((c) => model.addComment(c));
183208

@@ -188,6 +213,20 @@ export class PrismaSchemaGenerator {
188213
this.generateDelegateRelationForConcrete(model, decl);
189214
}
190215

216+
private getDatasourceField(zmodel: Model, fieldName: string) {
217+
const dataSource = zmodel.declarations.find(isDataSource);
218+
return dataSource?.fields.find((f) => f.name === fieldName);
219+
}
220+
221+
private datasourceHasSchemasSetting(zmodel: Model) {
222+
return !!this.getDatasourceField(zmodel, 'schemas');
223+
}
224+
225+
private getDefaultPostgresSchemaName(zmodel: Model) {
226+
const defaultSchemaField = this.getDatasourceField(zmodel, 'defaultSchema');
227+
return getStringLiteral(defaultSchemaField?.value) ?? 'public';
228+
}
229+
191230
private isInheritedMapAttribute(attr: DataModelAttribute, contextModel: DataModel) {
192231
if (attr.$container === contextModel) {
193232
return false;
@@ -206,7 +245,7 @@ export class PrismaSchemaGenerator {
206245

207246
private getUnsupportedFieldType(fieldType: DataFieldType) {
208247
if (fieldType.unsupported) {
209-
const value = this.getStringLiteral(fieldType.unsupported.value);
248+
const value = getStringLiteral(fieldType.unsupported.value);
210249
if (value) {
211250
return `Unsupported("${value}")`;
212251
} else {
@@ -217,10 +256,6 @@ export class PrismaSchemaGenerator {
217256
}
218257
}
219258

220-
private getStringLiteral(node: AstNode | undefined): string | undefined {
221-
return isStringLiteral(node) ? node.value : undefined;
222-
}
223-
224259
private generateModelField(model: PrismaDataModel, field: DataField, contextModel: DataModel, addToFront = false) {
225260
let fieldType: string | undefined;
226261

tests/e2e/orm/client-api/pg-custom-schema.test.ts

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,4 +193,69 @@ model Foo {
193193
),
194194
).rejects.toThrow('Schema "mySchema" is not defined in the datasource');
195195
});
196+
197+
it('requires defaultSchema to be included in schemas', async () => {
198+
await expect(
199+
createTestClient(
200+
`
201+
datasource db {
202+
provider = 'postgresql'
203+
defaultSchema = 'mySchema'
204+
schemas = ['public']
205+
}
206+
207+
model Foo {
208+
id Int @id
209+
name String
210+
}
211+
`,
212+
),
213+
).rejects.toThrow('"mySchema" must be included in the "schemas" array');
214+
});
215+
216+
it('allows specifying schema only on a few models', async () => {
217+
let fooQueriesVerified = false;
218+
let barQueriesVerified = false;
219+
220+
const db = await createTestClient(
221+
`
222+
datasource db {
223+
provider = 'postgresql'
224+
defaultSchema = 'somedefault'
225+
schemas = ['mySchema', 'somedefault']
226+
url = '$DB_URL'
227+
}
228+
229+
model Foo {
230+
id Int @id
231+
name String
232+
@@schema('mySchema')
233+
}
234+
235+
model Bar {
236+
id Int @id
237+
name String
238+
}
239+
`,
240+
{
241+
provider: 'postgresql',
242+
usePrismaPush: true,
243+
log: (event) => {
244+
const sql = event.query.sql.toLowerCase();
245+
if (sql.includes('"myschema"."foo"')) {
246+
fooQueriesVerified = true;
247+
}
248+
if (sql.includes('"somedefault"."bar"')) {
249+
barQueriesVerified = true;
250+
}
251+
},
252+
},
253+
);
254+
255+
await expect(db.foo.create({ data: { id: 1, name: 'test' } })).toResolveTruthy();
256+
await expect(db.bar.create({ data: { id: 1, name: 'test' } })).toResolveTruthy();
257+
258+
expect(fooQueriesVerified).toBe(true);
259+
expect(barQueriesVerified).toBe(true);
260+
});
196261
});

0 commit comments

Comments
 (0)