Skip to content

Commit 6df80b2

Browse files
authored
feat: support @@validate in type declarations (#1868)
1 parent c7f333d commit 6df80b2

File tree

8 files changed

+149
-134
lines changed

8 files changed

+149
-134
lines changed

packages/schema/src/language-server/validator/function-invocation-validator.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import {
1515
} from '@zenstackhq/language/ast';
1616
import {
1717
ExpressionContext,
18-
getDataModelFieldReference,
18+
getFieldReference,
1919
getFunctionExpressionContext,
2020
getLiteral,
2121
isDataModelFieldReference,
@@ -96,7 +96,7 @@ export default class FunctionInvocationValidator implements AstValidator<Express
9696
// first argument must refer to a model field
9797
const firstArg = expr.args?.[0]?.value;
9898
if (firstArg) {
99-
if (!getDataModelFieldReference(firstArg)) {
99+
if (!getFieldReference(firstArg)) {
100100
accept('error', 'first argument must be a field reference', { node: firstArg });
101101
}
102102
}

packages/schema/src/language-server/zmodel-scope.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,13 +137,14 @@ export class ZModelScopeProvider extends DefaultScopeProvider {
137137
const node = context.container as MemberAccessExpr;
138138

139139
// typedef's fields are only added to the scope if the access starts with `auth().`
140-
const allowTypeDefScope = isAuthOrAuthMemberAccess(node.operand);
140+
// or the member access resides inside a typedef
141+
const allowTypeDefScope = isAuthOrAuthMemberAccess(node.operand) || !!getContainerOfType(node, isTypeDef);
141142

142143
return match(node.operand)
143144
.when(isReferenceExpr, (operand) => {
144145
// operand is a reference, it can only be a model/type-def field
145146
const ref = operand.target.ref;
146-
if (isDataModelField(ref)) {
147+
if (isDataModelField(ref) || isTypeDefField(ref)) {
147148
return this.createScopeForContainer(ref.type.reference?.ref, globalScope, allowTypeDefScope);
148149
}
149150
return EMPTY_SCOPE;

packages/schema/src/plugins/prisma/schema-generator.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ import {
3030
ReferenceExpr,
3131
StringLiteral,
3232
} from '@zenstackhq/language/ast';
33+
import { getIdFields } from '@zenstackhq/sdk';
3334
import { getPrismaVersion } from '@zenstackhq/sdk/prisma';
3435
import { match } from 'ts-pattern';
35-
import { getIdFields } from '../../utils/ast-utils';
3636

3737
import { DELEGATE_AUX_RELATION_PREFIX, PRISMA_MINIMUM_VERSION } from '@zenstackhq/runtime';
3838
import {

packages/schema/src/plugins/zod/generator.ts

Lines changed: 96 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
11
import {
2+
ExpressionContext,
23
PluginError,
34
PluginGlobalOptions,
45
PluginOptions,
56
RUNTIME_PACKAGE,
7+
TypeScriptExpressionTransformer,
8+
TypeScriptExpressionTransformerError,
69
ensureEmptyDir,
10+
getAttributeArg,
11+
getAttributeArgLiteral,
712
getDataModels,
13+
getLiteralArray,
814
hasAttribute,
15+
isDataModelFieldReference,
916
isDiscriminatorField,
1017
isEnumFieldReference,
1118
isForeignKeyField,
@@ -15,7 +22,7 @@ import {
1522
resolvePath,
1623
saveSourceFile,
1724
} from '@zenstackhq/sdk';
18-
import { DataModel, EnumField, Model, TypeDef, isDataModel, isEnum, isTypeDef } from '@zenstackhq/sdk/ast';
25+
import { DataModel, EnumField, Model, TypeDef, isArrayExpr, isDataModel, isEnum, isTypeDef } from '@zenstackhq/sdk/ast';
1926
import { addMissingInputObjectTypes, resolveAggregateOperationSupport } from '@zenstackhq/sdk/dmmf-helpers';
2027
import { getPrismaClientImportSpec, supportCreateMany, type DMMF } from '@zenstackhq/sdk/prisma';
2128
import { streamAllContents } from 'langium';
@@ -26,7 +33,7 @@ import { name } from '.';
2633
import { getDefaultOutputFolder } from '../plugin-utils';
2734
import Transformer from './transformer';
2835
import { ObjectMode } from './types';
29-
import { makeFieldSchema, makeValidationRefinements } from './utils/schema-gen';
36+
import { makeFieldSchema } from './utils/schema-gen';
3037

3138
export class ZodSchemaGenerator {
3239
private readonly sourceFiles: SourceFile[] = [];
@@ -294,7 +301,7 @@ export class ZodSchemaGenerator {
294301
sf.replaceWithText((writer) => {
295302
this.addPreludeAndImports(typeDef, writer, output);
296303

297-
writer.write(`export const ${typeDef.name}Schema = z.object(`);
304+
writer.write(`const baseSchema = z.object(`);
298305
writer.inlineBlock(() => {
299306
typeDef.fields.forEach((field) => {
300307
writer.writeLine(`${field.name}: ${makeFieldSchema(field)},`);
@@ -313,9 +320,24 @@ export class ZodSchemaGenerator {
313320
writer.writeLine(').strict();');
314321
break;
315322
}
316-
});
317323

318-
// TODO: "@@validate" refinements
324+
// compile "@@validate" to a function calling zod's `.refine()`
325+
const refineFuncName = this.createRefineFunction(typeDef, writer);
326+
327+
if (refineFuncName) {
328+
// export a schema without refinement for extensibility: `[Model]WithoutRefineSchema`
329+
const noRefineSchema = `${upperCaseFirst(typeDef.name)}WithoutRefineSchema`;
330+
writer.writeLine(`
331+
/**
332+
* \`${typeDef.name}\` schema prior to calling \`.refine()\` for extensibility.
333+
*/
334+
export const ${noRefineSchema} = baseSchema;
335+
export const ${typeDef.name}Schema = ${refineFuncName}(${noRefineSchema});
336+
`);
337+
} else {
338+
writer.writeLine(`export const ${typeDef.name}Schema = baseSchema;`);
339+
}
340+
});
319341

320342
return schemaName;
321343
}
@@ -436,22 +458,7 @@ export class ZodSchemaGenerator {
436458
}
437459

438460
// compile "@@validate" to ".refine"
439-
const refinements = makeValidationRefinements(model);
440-
let refineFuncName: string | undefined;
441-
if (refinements.length > 0) {
442-
refineFuncName = `refine${upperCaseFirst(model.name)}`;
443-
writer.writeLine(
444-
`
445-
/**
446-
* Schema refinement function for applying \`@@validate\` rules.
447-
*/
448-
export function ${refineFuncName}<T, D extends z.ZodTypeDef>(schema: z.ZodType<T, D, T>) { return schema${refinements.join(
449-
'\n'
450-
)};
451-
}
452-
`
453-
);
454-
}
461+
const refineFuncName = this.createRefineFunction(model, writer);
455462

456463
// delegate discriminator fields are to be excluded from mutation schemas
457464
const delegateDiscriminatorFields = model.fields.filter((field) => isDiscriminatorField(field));
@@ -658,6 +665,74 @@ export const ${upperCaseFirst(model.name)}UpdateSchema = ${updateSchema};
658665
return schemaName;
659666
}
660667

668+
private createRefineFunction(decl: DataModel | TypeDef, writer: CodeBlockWriter) {
669+
const refinements = this.makeValidationRefinements(decl);
670+
let refineFuncName: string | undefined;
671+
if (refinements.length > 0) {
672+
refineFuncName = `refine${upperCaseFirst(decl.name)}`;
673+
writer.writeLine(
674+
`
675+
/**
676+
* Schema refinement function for applying \`@@validate\` rules.
677+
*/
678+
export function ${refineFuncName}<T, D extends z.ZodTypeDef>(schema: z.ZodType<T, D, T>) { return schema${refinements.join(
679+
'\n'
680+
)};
681+
}
682+
`
683+
);
684+
return refineFuncName;
685+
} else {
686+
return undefined;
687+
}
688+
}
689+
690+
private makeValidationRefinements(decl: DataModel | TypeDef) {
691+
const attrs = decl.attributes.filter((attr) => attr.decl.ref?.name === '@@validate');
692+
const refinements = attrs
693+
.map((attr) => {
694+
const valueArg = getAttributeArg(attr, 'value');
695+
if (!valueArg) {
696+
return undefined;
697+
}
698+
699+
const messageArg = getAttributeArgLiteral<string>(attr, 'message');
700+
const message = messageArg ? `message: ${JSON.stringify(messageArg)},` : '';
701+
702+
const pathArg = getAttributeArg(attr, 'path');
703+
const path =
704+
pathArg && isArrayExpr(pathArg)
705+
? `path: ['${getLiteralArray<string>(pathArg)?.join(`', '`)}'],`
706+
: '';
707+
708+
const options = `, { ${message} ${path} }`;
709+
710+
try {
711+
let expr = new TypeScriptExpressionTransformer({
712+
context: ExpressionContext.ValidationRule,
713+
fieldReferenceContext: 'value',
714+
}).transform(valueArg);
715+
716+
if (isDataModelFieldReference(valueArg)) {
717+
// if the expression is a simple field reference, treat undefined
718+
// as true since the all fields are optional in validation context
719+
expr = `${expr} ?? true`;
720+
}
721+
722+
return `.refine((value: any) => ${expr}${options})`;
723+
} catch (err) {
724+
if (err instanceof TypeScriptExpressionTransformerError) {
725+
throw new PluginError(name, err.message);
726+
} else {
727+
throw err;
728+
}
729+
}
730+
})
731+
.filter((r) => !!r);
732+
733+
return refinements;
734+
}
735+
661736
private makePartial(schema: string, fields?: string[]) {
662737
if (fields) {
663738
if (fields.length === 0) {

packages/schema/src/plugins/zod/utils/schema-gen.ts

Lines changed: 1 addition & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,7 @@
1+
import { getLiteral, isFromStdlib } from '@zenstackhq/sdk';
12
import {
2-
ExpressionContext,
3-
getAttributeArg,
4-
getAttributeArgLiteral,
5-
getLiteral,
6-
getLiteralArray,
7-
isDataModelFieldReference,
8-
isFromStdlib,
9-
PluginError,
10-
TypeScriptExpressionTransformer,
11-
TypeScriptExpressionTransformerError,
12-
} from '@zenstackhq/sdk';
13-
import {
14-
DataModel,
153
DataModelField,
164
DataModelFieldAttribute,
17-
isArrayExpr,
185
isBooleanLiteral,
196
isDataModel,
207
isEnum,
@@ -25,7 +12,6 @@ import {
2512
TypeDefField,
2613
} from '@zenstackhq/sdk/ast';
2714
import { upperCaseFirst } from 'upper-case-first';
28-
import { name } from '..';
2915
import { isDefaultWithAuth } from '../../enhancer/enhancer-utils';
3016

3117
export function makeFieldSchema(field: DataModelField | TypeDefField) {
@@ -222,50 +208,6 @@ function makeZodSchema(field: DataModelField | TypeDefField) {
222208
return schema;
223209
}
224210

225-
export function makeValidationRefinements(model: DataModel) {
226-
const attrs = model.attributes.filter((attr) => attr.decl.ref?.name === '@@validate');
227-
const refinements = attrs
228-
.map((attr) => {
229-
const valueArg = getAttributeArg(attr, 'value');
230-
if (!valueArg) {
231-
return undefined;
232-
}
233-
234-
const messageArg = getAttributeArgLiteral<string>(attr, 'message');
235-
const message = messageArg ? `message: ${JSON.stringify(messageArg)},` : '';
236-
237-
const pathArg = getAttributeArg(attr, 'path');
238-
const path =
239-
pathArg && isArrayExpr(pathArg) ? `path: ['${getLiteralArray<string>(pathArg)?.join(`', '`)}'],` : '';
240-
241-
const options = `, { ${message} ${path} }`;
242-
243-
try {
244-
let expr = new TypeScriptExpressionTransformer({
245-
context: ExpressionContext.ValidationRule,
246-
fieldReferenceContext: 'value',
247-
}).transform(valueArg);
248-
249-
if (isDataModelFieldReference(valueArg)) {
250-
// if the expression is a simple field reference, treat undefined
251-
// as true since the all fields are optional in validation context
252-
expr = `${expr} ?? true`;
253-
}
254-
255-
return `.refine((value: any) => ${expr}${options})`;
256-
} catch (err) {
257-
if (err instanceof TypeScriptExpressionTransformerError) {
258-
throw new PluginError(name, err.message);
259-
} else {
260-
throw err;
261-
}
262-
}
263-
})
264-
.filter((r) => !!r);
265-
266-
return refinements;
267-
}
268-
269211
function getAttrLiteralArg<T extends string | number>(attr: DataModelFieldAttribute, paramName: string) {
270212
const arg = attr.args.find((arg) => arg.$resolvedParam?.name === paramName);
271213
return arg && getLiteral<T>(arg.value);

packages/schema/src/utils/ast-utils.ts

Lines changed: 1 addition & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,19 @@ import {
22
BinaryExpr,
33
DataModel,
44
DataModelAttribute,
5-
DataModelField,
65
Expression,
76
InheritableNode,
8-
isArrayExpr,
97
isBinaryExpr,
108
isDataModel,
119
isDataModelField,
1210
isInvocationExpr,
13-
isMemberAccessExpr,
1411
isModel,
15-
isReferenceExpr,
1612
isTypeDef,
1713
Model,
1814
ModelImport,
19-
ReferenceExpr,
2015
TypeDef,
2116
} from '@zenstackhq/language/ast';
22-
import {
23-
getInheritanceChain,
24-
getModelFieldsWithBases,
25-
getRecursiveBases,
26-
isDelegateModel,
27-
isFromStdlib,
28-
} from '@zenstackhq/sdk';
17+
import { getInheritanceChain, getRecursiveBases, isDelegateModel, isFromStdlib } from '@zenstackhq/sdk';
2918
import {
3019
AstNode,
3120
copyAstNode,
@@ -151,29 +140,6 @@ function cloneAst<T extends InheritableNode>(
151140
return clone;
152141
}
153142

154-
export function getIdFields(dataModel: DataModel) {
155-
const fieldLevelId = getModelFieldsWithBases(dataModel).find((f) =>
156-
f.attributes.some((attr) => attr.decl.$refText === '@id')
157-
);
158-
if (fieldLevelId) {
159-
return [fieldLevelId];
160-
} else {
161-
// get model level @@id attribute
162-
const modelIdAttr = dataModel.attributes.find((attr) => attr.decl?.ref?.name === '@@id');
163-
if (modelIdAttr) {
164-
// get fields referenced in the attribute: @@id([field1, field2]])
165-
if (!isArrayExpr(modelIdAttr.args[0]?.value)) {
166-
return [];
167-
}
168-
const argValue = modelIdAttr.args[0].value;
169-
return argValue.items
170-
.filter((expr): expr is ReferenceExpr => isReferenceExpr(expr) && !!getDataModelFieldReference(expr))
171-
.map((expr) => expr.target.ref as DataModelField);
172-
}
173-
}
174-
return [];
175-
}
176-
177143
export function isAuthInvocation(node: AstNode) {
178144
return isInvocationExpr(node) && node.function.ref?.name === 'auth' && isFromStdlib(node.function.ref);
179145
}
@@ -186,16 +152,6 @@ export function isCheckInvocation(node: AstNode) {
186152
return isInvocationExpr(node) && node.function.ref?.name === 'check' && isFromStdlib(node.function.ref);
187153
}
188154

189-
export function getDataModelFieldReference(expr: Expression): DataModelField | undefined {
190-
if (isReferenceExpr(expr) && isDataModelField(expr.target.ref)) {
191-
return expr.target.ref;
192-
} else if (isMemberAccessExpr(expr) && isDataModelField(expr.member.ref)) {
193-
return expr.member.ref;
194-
} else {
195-
return undefined;
196-
}
197-
}
198-
199155
export function resolveImportUri(imp: ModelImport): URI | undefined {
200156
if (!imp.path) return undefined; // This will return true if imp.path is undefined, null, or an empty string ("").
201157

0 commit comments

Comments
 (0)