Skip to content

Commit adfe92d

Browse files
committed
refactor(validation): clean up validation functions
1 parent e0040cb commit adfe92d

File tree

8 files changed

+200
-34
lines changed

8 files changed

+200
-34
lines changed

TODO.md

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@
5656
- [x] Array update
5757
- [x] Strict typing for checked/unchecked input
5858
- [x] Upsert
59-
- [ ] Implement with "on conflict"
6059
- [x] Delete
6160
- [x] Aggregation
6261
- [x] Count
@@ -86,7 +85,7 @@
8685
- [ ] Global omit
8786
- [ ] DbNull vs JsonNull
8887
- [ ] Migrate to tsdown
89-
- [ ] @default validation
88+
- [x] @default validation
9089
- [ ] Benchmark
9190
- [x] Plugin
9291
- [x] Post-mutation hooks should be called after transaction is committed
@@ -96,7 +95,7 @@
9695
- [x] ZModel
9796
- [x] Runtime
9897
- [x] Typing
99-
- [ ] Validation
98+
- [x] Validation
10099
- [ ] Access Policy
101100
- [ ] Short-circuit pre-create check for scalar-field only policies
102101
- [x] Inject "on conflict do update"

packages/language/res/stdlib.zmodel

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ enum AttributeTargetField {
4848
BytesField
4949
ModelField
5050
TypeDefField
51+
ListField
5152
}
5253

5354
/**
@@ -486,9 +487,9 @@ attribute @db.ByteA() @@@targetField([BytesField]) @@@prisma
486487
//////////////////////////////////////////////
487488

488489
/**
489-
* Validates length of a string field.
490+
* Validates length of a string field or list field.
490491
*/
491-
attribute @length(_ min: Int?, _ max: Int?, _ message: String?) @@@targetField([StringField]) @@@validation
492+
attribute @length(_ min: Int?, _ max: Int?, _ message: String?) @@@targetField([StringField, ListField]) @@@validation
492493

493494
/**
494495
* Validates a string field value starts with the given text.
@@ -566,9 +567,9 @@ attribute @lte(_ value: Any, _ message: String?) @@@targetField([IntField, Float
566567
attribute @@validate(_ value: Boolean, _ message: String?, _ path: String[]?) @@@validation
567568

568569
/**
569-
* Validates length of a string field.
570+
* Returns the length of a string field or a list field.
570571
*/
571-
function length(field: String, min: Int, max: Int?): Boolean {
572+
function length(field: Any): Int {
572573
} @@@expressionContext([ValidationRule])
573574

574575

@@ -581,19 +582,19 @@ function regex(field: String, regex: String): Boolean {
581582
/**
582583
* Validates a string field value is a valid email address.
583584
*/
584-
function email(field: String): Boolean {
585+
function isEmail(field: String): Boolean {
585586
} @@@expressionContext([ValidationRule])
586587

587588
/**
588589
* Validates a string field value is a valid ISO datetime.
589590
*/
590-
function datetime(field: String): Boolean {
591+
function isDateTime(field: String): Boolean {
591592
} @@@expressionContext([ValidationRule])
592593

593594
/**
594595
* Validates a string field value is a valid url.
595596
*/
596-
function url(field: String): Boolean {
597+
function isUrl(field: String): Boolean {
597598
} @@@expressionContext([ValidationRule])
598599

599600
//////////////////////////////////////////////

packages/language/src/validators/attribute-application-validator.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,9 @@ function isValidAttributeTarget(attrDecl: Attribute, targetDecl: DataField) {
491491
case 'TypeDefField':
492492
allowed = allowed || isTypeDef(targetDecl.type.reference?.ref);
493493
break;
494+
case 'ListField':
495+
allowed = allowed || (!isDataModel(targetDecl.type.reference?.ref) && targetDecl.type.array);
496+
break;
494497
default:
495498
break;
496499
}

packages/language/src/validators/function-invocation-validator.ts

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import {
1313
isDataFieldAttribute,
1414
isDataModel,
1515
isDataModelAttribute,
16+
isStringLiteral,
1617
} from '../generated/ast';
1718
import {
1819
getFunctionExpressionContext,
@@ -183,6 +184,53 @@ export default class FunctionInvocationValidator implements AstValidator<Express
183184
return true;
184185
}
185186

187+
@func('length')
188+
// @ts-expect-error
189+
private _checkLength(expr: InvocationExpr, accept: ValidationAcceptor) {
190+
const msg = 'argument must be a string or list field';
191+
const fieldArg = expr.args[0]!.value;
192+
if (!isDataFieldReference(fieldArg)) {
193+
accept('error', msg, {
194+
node: expr.args[0]!,
195+
});
196+
return;
197+
}
198+
199+
if (isDataModel(fieldArg.$resolvedType?.decl)) {
200+
accept('error', msg, {
201+
node: expr.args[0]!,
202+
});
203+
return;
204+
}
205+
206+
if (!fieldArg.$resolvedType?.array && fieldArg.$resolvedType?.decl !== 'String') {
207+
accept('error', msg, {
208+
node: expr.args[0]!,
209+
});
210+
}
211+
}
212+
213+
@func('regex')
214+
// @ts-expect-error
215+
private _checkRegex(expr: InvocationExpr, accept: ValidationAcceptor) {
216+
const regex = expr.args[1]?.value;
217+
if (!isStringLiteral(regex)) {
218+
accept('error', 'second argument must be a string literal', {
219+
node: expr.args[1]!,
220+
});
221+
return;
222+
}
223+
224+
try {
225+
// try to create a RegExp object to verify the pattern
226+
new RegExp(regex.value);
227+
} catch (e) {
228+
accept('error', 'invalid regular expression: ' + (e as Error).message, {
229+
node: expr.args[1]!,
230+
});
231+
}
232+
}
233+
186234
// TODO: move this to policy plugin
187235
@func('check')
188236
// @ts-expect-error

packages/runtime/src/client/crud/validator/index.ts

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ import {
4545
addBigIntValidation,
4646
addCustomValidation,
4747
addDecimalValidation,
48+
addListValidation,
4849
addNumberValidation,
4950
addStringValidation,
5051
} from './utils';
@@ -904,11 +905,12 @@ export class InputValidator<Schema extends SchemaDef> {
904905
let fieldSchema: ZodType = this.makePrimitiveSchema(fieldDef.type, fieldDef.attributes);
905906

906907
if (fieldDef.array) {
908+
fieldSchema = addListValidation(fieldSchema.array(), fieldDef.attributes);
907909
fieldSchema = z
908910
.union([
909-
z.array(fieldSchema),
911+
fieldSchema,
910912
z.strictObject({
911-
set: z.array(fieldSchema),
913+
set: fieldSchema,
912914
}),
913915
])
914916
.optional();
@@ -1186,13 +1188,14 @@ export class InputValidator<Schema extends SchemaDef> {
11861188
}
11871189

11881190
if (fieldDef.array) {
1191+
const arraySchema = addListValidation(fieldSchema.array(), fieldDef.attributes);
11891192
fieldSchema = z
11901193
.union([
1191-
fieldSchema.array(),
1194+
arraySchema,
11921195
z
11931196
.object({
1194-
set: z.array(fieldSchema).optional(),
1195-
push: this.orArray(fieldSchema, true).optional(),
1197+
set: arraySchema.optional(),
1198+
push: z.union([fieldSchema, arraySchema]).optional(),
11961199
})
11971200
.refine(
11981201
(v) => Object.keys(v).length === 1,

packages/runtime/src/client/crud/validator/utils.ts

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,32 @@ export function addDecimalValidation(
203203
return result;
204204
}
205205

206+
export function addListValidation(
207+
schema: z.ZodArray<any>,
208+
attributes: AttributeApplication[] | undefined,
209+
): z.ZodSchema {
210+
if (!attributes || attributes.length === 0) {
211+
return schema;
212+
}
213+
214+
let result = schema;
215+
for (const attr of attributes) {
216+
match(attr.name)
217+
.with('@length', () => {
218+
const min = getArgValue<number>(attr.args?.[0]?.value);
219+
if (min !== undefined) {
220+
result = result.min(min);
221+
}
222+
const max = getArgValue<number>(attr.args?.[1]?.value);
223+
if (max !== undefined) {
224+
result = result.max(max);
225+
}
226+
})
227+
.otherwise(() => {});
228+
}
229+
return result;
230+
}
231+
206232
export function addCustomValidation(schema: z.ZodSchema, attributes: AttributeApplication[] | undefined): z.ZodSchema {
207233
const attrs = attributes?.filter((a) => a.name === '@@validate');
208234
if (!attrs || attrs.length === 0) {
@@ -329,17 +355,11 @@ function evalCall(data: any, expr: CallExpression) {
329355
if (fieldArg === undefined || fieldArg === null) {
330356
return false;
331357
}
332-
invariant(typeof fieldArg === 'string', `"${f}" first argument must be a string`);
333-
334-
const min = getArgValue<number>(expr.args?.[1]);
335-
const max = getArgValue<number>(expr.args?.[2]);
336-
if (min !== undefined && fieldArg.length < min) {
337-
return false;
338-
}
339-
if (max !== undefined && fieldArg.length > max) {
340-
return false;
341-
}
342-
return true;
358+
invariant(
359+
typeof fieldArg === 'string' || Array.isArray(fieldArg),
360+
`"${f}" first argument must be a string or an list`,
361+
);
362+
return fieldArg.length;
343363
})
344364
.with(P.union('startsWith', 'endsWith', 'contains'), (f) => {
345365
if (fieldArg === undefined || fieldArg === null) {
@@ -370,11 +390,17 @@ function evalCall(data: any, expr: CallExpression) {
370390
invariant(pattern !== undefined, `"${f}" requires a pattern argument`);
371391
return new RegExp(pattern).test(fieldArg);
372392
})
373-
.with(P.union('email', 'url', 'datetime'), (f) => {
393+
.with(P.union('isEmail', 'isUrl', 'isDateTime'), (f) => {
374394
if (fieldArg === undefined || fieldArg === null) {
375395
return false;
376396
}
377-
return z.string()[f]().safeParse(fieldArg).success;
397+
invariant(typeof fieldArg === 'string', `"${f}" first argument must be a string`);
398+
const fn = match(f)
399+
.with('isEmail', () => 'email' as const)
400+
.with('isUrl', () => 'url' as const)
401+
.with('isDateTime', () => 'datetime' as const)
402+
.exhaustive();
403+
return z.string()[fn]().safeParse(fieldArg).success;
378404
})
379405
// list functions
380406
.with(P.union('has', 'hasEvery', 'hasSome'), (f) => {

tests/e2e/orm/validation/custom-validation.test.ts

Lines changed: 70 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { createTestClient } from '@zenstackhq/testtools';
1+
import { createTestClient, loadSchemaWithError } from '@zenstackhq/testtools';
22
import { describe, expect, it } from 'vitest';
33

44
describe('Custom validation tests', () => {
@@ -15,25 +15,28 @@ describe('Custom validation tests', () => {
1515
int1 Int?
1616
list1 Int[]
1717
list2 Int[]
18+
list3 Int[]
1819
1920
@@validate(
20-
(str1 == null || length(str1, 8, 10))
21+
(str1 == null || (length(str1) >= 8 && length(str1) <= 10))
2122
&& (int1 == null || (int1 > 1 && int1 < 4)),
2223
'invalid fields')
2324
2425
@@validate(str1 == null || (startsWith(str1, 'a') && endsWith(str1, 'm') && contains(str1, 'b')), 'invalid fields')
2526
2627
@@validate(str2 == null || regex(str2, '^x.*z$'), 'invalid str2')
2728
28-
@@validate(str3 == null || email(str3), 'invalid str3')
29+
@@validate(str3 == null || isEmail(str3), 'invalid str3')
2930
30-
@@validate(str4 == null || url(str4), 'invalid str4')
31+
@@validate(str4 == null || isUrl(str4), 'invalid str4')
3132
32-
@@validate(str5 == null || datetime(str5), 'invalid str5')
33+
@@validate(str5 == null || isDateTime(str5), 'invalid str5')
3334
3435
@@validate(list1 == null || (has(list1, 1) && hasSome(list1, [2, 3]) && hasEvery(list1, [4, 5])), 'invalid list1')
3536
3637
@@validate(list2 == null || isEmpty(list2), 'invalid list2', ['x', 'y'])
38+
39+
@@validate(list3 == null || length(list3) <2 , 'invalid list3')
3740
}
3841
`,
3942
{ provider: 'postgresql' },
@@ -93,6 +96,9 @@ describe('Custom validation tests', () => {
9396
}
9497
expect(thrown).toBe(true);
9598

99+
// validates list length
100+
await expect(_t({ list3: [1, 2] })).toBeRejectedByValidation(['invalid list3']);
101+
96102
// satisfies all
97103
await expect(
98104
_t({
@@ -104,6 +110,7 @@ describe('Custom validation tests', () => {
104110
int1: 2,
105111
list1: [1, 2, 4, 5],
106112
list2: [],
113+
list3: [1],
107114
}),
108115
).toResolveTruthy();
109116
}
@@ -115,7 +122,7 @@ describe('Custom validation tests', () => {
115122
model User {
116123
id Int @id @default(autoincrement())
117124
email String @unique @email
118-
@@validate(length(email, 8))
125+
@@validate(length(email) >= 8)
119126
@@allow('all', true)
120127
}
121128
`,
@@ -170,4 +177,61 @@ describe('Custom validation tests', () => {
170177
}),
171178
).toBeRejectedByValidation();
172179
});
180+
181+
it('checks arg type for validation functions', async () => {
182+
// length() on relation field
183+
await loadSchemaWithError(
184+
`
185+
model Foo {
186+
id Int @id @default(autoincrement())
187+
bars Bar[]
188+
@@validate(length(bars) > 0)
189+
}
190+
191+
model Bar {
192+
id Int @id @default(autoincrement())
193+
foo Foo @relation(fields: [fooId], references: [id])
194+
fooId Int
195+
}
196+
`,
197+
'argument must be a string or list field',
198+
);
199+
200+
// length() on non-string/list field
201+
await loadSchemaWithError(
202+
`
203+
model Foo {
204+
id Int @id @default(autoincrement())
205+
x Int
206+
@@validate(length(x) > 0)
207+
}
208+
`,
209+
'argument must be a string or list field',
210+
);
211+
212+
// invalid regex pattern
213+
await loadSchemaWithError(
214+
`
215+
model Foo {
216+
id Int @id @default(autoincrement())
217+
x String
218+
@@validate(regex(x, '[abc'))
219+
}
220+
`,
221+
'invalid regular expression',
222+
);
223+
224+
// using field as regex pattern
225+
await loadSchemaWithError(
226+
`
227+
model Foo {
228+
id Int @id @default(autoincrement())
229+
x String
230+
y String
231+
@@validate(regex(x, y))
232+
}
233+
`,
234+
'second argument must be a string literal',
235+
);
236+
});
173237
});

0 commit comments

Comments
 (0)