diff --git a/TODO.md b/TODO.md index cd7e8eb8..49ae0537 100644 --- a/TODO.md +++ b/TODO.md @@ -56,7 +56,6 @@ - [x] Array update - [x] Strict typing for checked/unchecked input - [x] Upsert - - [ ] Implement with "on conflict" - [x] Delete - [x] Aggregation - [x] Count @@ -86,7 +85,7 @@ - [ ] Global omit - [ ] DbNull vs JsonNull - [ ] Migrate to tsdown - - [ ] @default validation + - [x] @default validation - [ ] Benchmark - [x] Plugin - [x] Post-mutation hooks should be called after transaction is committed @@ -96,7 +95,7 @@ - [x] ZModel - [x] Runtime - [x] Typing -- [ ] Validation +- [x] Validation - [ ] Access Policy - [ ] Short-circuit pre-create check for scalar-field only policies - [x] Inject "on conflict do update" diff --git a/packages/language/res/stdlib.zmodel b/packages/language/res/stdlib.zmodel index c49f2606..7ac57ba3 100644 --- a/packages/language/res/stdlib.zmodel +++ b/packages/language/res/stdlib.zmodel @@ -48,6 +48,7 @@ enum AttributeTargetField { BytesField ModelField TypeDefField + ListField } /** @@ -486,9 +487,9 @@ attribute @db.ByteA() @@@targetField([BytesField]) @@@prisma ////////////////////////////////////////////// /** - * Validates length of a string field. + * Validates length of a string field or list field. */ -attribute @length(_ min: Int?, _ max: Int?, _ message: String?) @@@targetField([StringField]) @@@validation +attribute @length(_ min: Int?, _ max: Int?, _ message: String?) @@@targetField([StringField, ListField]) @@@validation /** * Validates a string field value starts with the given text. @@ -566,9 +567,9 @@ attribute @lte(_ value: Any, _ message: String?) @@@targetField([IntField, Float attribute @@validate(_ value: Boolean, _ message: String?, _ path: String[]?) @@@validation /** - * Validates length of a string field. + * Returns the length of a string field or a list field. */ -function length(field: String, min: Int, max: Int?): Boolean { +function length(field: Any): Int { } @@@expressionContext([ValidationRule]) @@ -581,19 +582,19 @@ function regex(field: String, regex: String): Boolean { /** * Validates a string field value is a valid email address. */ -function email(field: String): Boolean { +function isEmail(field: String): Boolean { } @@@expressionContext([ValidationRule]) /** * Validates a string field value is a valid ISO datetime. */ -function datetime(field: String): Boolean { +function isDateTime(field: String): Boolean { } @@@expressionContext([ValidationRule]) /** * Validates a string field value is a valid url. */ -function url(field: String): Boolean { +function isUrl(field: String): Boolean { } @@@expressionContext([ValidationRule]) ////////////////////////////////////////////// diff --git a/packages/language/src/validators/attribute-application-validator.ts b/packages/language/src/validators/attribute-application-validator.ts index d1319cf0..981eb814 100644 --- a/packages/language/src/validators/attribute-application-validator.ts +++ b/packages/language/src/validators/attribute-application-validator.ts @@ -491,6 +491,9 @@ function isValidAttributeTarget(attrDecl: Attribute, targetDecl: DataField) { case 'TypeDefField': allowed = allowed || isTypeDef(targetDecl.type.reference?.ref); break; + case 'ListField': + allowed = allowed || (!isDataModel(targetDecl.type.reference?.ref) && targetDecl.type.array); + break; default: break; } diff --git a/packages/language/src/validators/function-invocation-validator.ts b/packages/language/src/validators/function-invocation-validator.ts index ae759904..a740b86e 100644 --- a/packages/language/src/validators/function-invocation-validator.ts +++ b/packages/language/src/validators/function-invocation-validator.ts @@ -13,6 +13,7 @@ import { isDataFieldAttribute, isDataModel, isDataModelAttribute, + isStringLiteral, } from '../generated/ast'; import { getFunctionExpressionContext, @@ -183,6 +184,53 @@ export default class FunctionInvocationValidator implements AstValidator { let fieldSchema: ZodType = this.makePrimitiveSchema(fieldDef.type, fieldDef.attributes); if (fieldDef.array) { + fieldSchema = addListValidation(fieldSchema.array(), fieldDef.attributes); fieldSchema = z .union([ - z.array(fieldSchema), + fieldSchema, z.strictObject({ - set: z.array(fieldSchema), + set: fieldSchema, }), ]) .optional(); @@ -1165,14 +1167,14 @@ export class InputValidator { uncheckedVariantFields[field] = fieldSchema; } } else { - let fieldSchema: ZodType = this.makePrimitiveSchema(fieldDef.type, fieldDef.attributes).optional(); + let fieldSchema: ZodType = this.makePrimitiveSchema(fieldDef.type, fieldDef.attributes); if (this.isNumericField(fieldDef)) { fieldSchema = z.union([ fieldSchema, z .object({ - set: this.nullableIf(z.number().optional(), !!fieldDef.optional), + set: this.nullableIf(z.number().optional(), !!fieldDef.optional).optional(), increment: z.number().optional(), decrement: z.number().optional(), multiply: z.number().optional(), @@ -1186,26 +1188,25 @@ export class InputValidator { } if (fieldDef.array) { - fieldSchema = z - .union([ - fieldSchema.array(), - z - .object({ - set: z.array(fieldSchema).optional(), - push: this.orArray(fieldSchema, true).optional(), - }) - .refine( - (v) => Object.keys(v).length === 1, - 'Only one of "set", "push" can be provided', - ), - ]) - .optional(); + const arraySchema = addListValidation(fieldSchema.array(), fieldDef.attributes); + fieldSchema = z.union([ + arraySchema, + z + .object({ + set: arraySchema.optional(), + push: z.union([fieldSchema, fieldSchema.array()]).optional(), + }) + .refine((v) => Object.keys(v).length === 1, 'Only one of "set", "push" can be provided'), + ]); } if (fieldDef.optional) { fieldSchema = fieldSchema.nullable(); } + // all fields are optional in update + fieldSchema = fieldSchema.optional(); + uncheckedVariantFields[field] = fieldSchema; if (!fieldDef.foreignKeyFor) { // non-fk field diff --git a/packages/runtime/src/client/crud/validator/utils.ts b/packages/runtime/src/client/crud/validator/utils.ts index 1fdecb25..980fff50 100644 --- a/packages/runtime/src/client/crud/validator/utils.ts +++ b/packages/runtime/src/client/crud/validator/utils.ts @@ -203,6 +203,32 @@ export function addDecimalValidation( return result; } +export function addListValidation( + schema: z.ZodArray, + attributes: AttributeApplication[] | undefined, +): z.ZodSchema { + if (!attributes || attributes.length === 0) { + return schema; + } + + let result = schema; + for (const attr of attributes) { + match(attr.name) + .with('@length', () => { + const min = getArgValue(attr.args?.[0]?.value); + if (min !== undefined) { + result = result.min(min); + } + const max = getArgValue(attr.args?.[1]?.value); + if (max !== undefined) { + result = result.max(max); + } + }) + .otherwise(() => {}); + } + return result; +} + export function addCustomValidation(schema: z.ZodSchema, attributes: AttributeApplication[] | undefined): z.ZodSchema { const attrs = attributes?.filter((a) => a.name === '@@validate'); if (!attrs || attrs.length === 0) { @@ -329,17 +355,11 @@ function evalCall(data: any, expr: CallExpression) { if (fieldArg === undefined || fieldArg === null) { return false; } - invariant(typeof fieldArg === 'string', `"${f}" first argument must be a string`); - - const min = getArgValue(expr.args?.[1]); - const max = getArgValue(expr.args?.[2]); - if (min !== undefined && fieldArg.length < min) { - return false; - } - if (max !== undefined && fieldArg.length > max) { - return false; - } - return true; + invariant( + typeof fieldArg === 'string' || Array.isArray(fieldArg), + `"${f}" first argument must be a string or a list`, + ); + return fieldArg.length; }) .with(P.union('startsWith', 'endsWith', 'contains'), (f) => { if (fieldArg === undefined || fieldArg === null) { @@ -370,11 +390,17 @@ function evalCall(data: any, expr: CallExpression) { invariant(pattern !== undefined, `"${f}" requires a pattern argument`); return new RegExp(pattern).test(fieldArg); }) - .with(P.union('email', 'url', 'datetime'), (f) => { + .with(P.union('isEmail', 'isUrl', 'isDateTime'), (f) => { if (fieldArg === undefined || fieldArg === null) { return false; } - return z.string()[f]().safeParse(fieldArg).success; + invariant(typeof fieldArg === 'string', `"${f}" first argument must be a string`); + const fn = match(f) + .with('isEmail', () => 'email' as const) + .with('isUrl', () => 'url' as const) + .with('isDateTime', () => 'datetime' as const) + .exhaustive(); + return z.string()[fn]().safeParse(fieldArg).success; }) // list functions .with(P.union('has', 'hasEvery', 'hasSome'), (f) => { diff --git a/tests/e2e/orm/validation/custom-validation.test.ts b/tests/e2e/orm/validation/custom-validation.test.ts index edd0c00e..955121b0 100644 --- a/tests/e2e/orm/validation/custom-validation.test.ts +++ b/tests/e2e/orm/validation/custom-validation.test.ts @@ -1,4 +1,4 @@ -import { createTestClient } from '@zenstackhq/testtools'; +import { createTestClient, loadSchemaWithError } from '@zenstackhq/testtools'; import { describe, expect, it } from 'vitest'; describe('Custom validation tests', () => { @@ -15,9 +15,10 @@ describe('Custom validation tests', () => { int1 Int? list1 Int[] list2 Int[] + list3 Int[] @@validate( - (str1 == null || length(str1, 8, 10)) + (str1 == null || (length(str1) >= 8 && length(str1) <= 10)) && (int1 == null || (int1 > 1 && int1 < 4)), 'invalid fields') @@ -25,15 +26,17 @@ describe('Custom validation tests', () => { @@validate(str2 == null || regex(str2, '^x.*z$'), 'invalid str2') - @@validate(str3 == null || email(str3), 'invalid str3') + @@validate(str3 == null || isEmail(str3), 'invalid str3') - @@validate(str4 == null || url(str4), 'invalid str4') + @@validate(str4 == null || isUrl(str4), 'invalid str4') - @@validate(str5 == null || datetime(str5), 'invalid str5') + @@validate(str5 == null || isDateTime(str5), 'invalid str5') @@validate(list1 == null || (has(list1, 1) && hasSome(list1, [2, 3]) && hasEvery(list1, [4, 5])), 'invalid list1') @@validate(list2 == null || isEmpty(list2), 'invalid list2', ['x', 'y']) + + @@validate(list3 == null || length(list3) <2 , 'invalid list3') } `, { provider: 'postgresql' }, @@ -93,6 +96,9 @@ describe('Custom validation tests', () => { } expect(thrown).toBe(true); + // validates list length + await expect(_t({ list3: [1, 2] })).toBeRejectedByValidation(['invalid list3']); + // satisfies all await expect( _t({ @@ -104,6 +110,7 @@ describe('Custom validation tests', () => { int1: 2, list1: [1, 2, 4, 5], list2: [], + list3: [1], }), ).toResolveTruthy(); } @@ -115,7 +122,7 @@ describe('Custom validation tests', () => { model User { id Int @id @default(autoincrement()) email String @unique @email - @@validate(length(email, 8)) + @@validate(length(email) >= 8) @@allow('all', true) } `, @@ -170,4 +177,61 @@ describe('Custom validation tests', () => { }), ).toBeRejectedByValidation(); }); + + it('checks arg type for validation functions', async () => { + // length() on relation field + await loadSchemaWithError( + ` + model Foo { + id Int @id @default(autoincrement()) + bars Bar[] + @@validate(length(bars) > 0) + } + + model Bar { + id Int @id @default(autoincrement()) + foo Foo @relation(fields: [fooId], references: [id]) + fooId Int + } + `, + 'argument must be a string or list field', + ); + + // length() on non-string/list field + await loadSchemaWithError( + ` + model Foo { + id Int @id @default(autoincrement()) + x Int + @@validate(length(x) > 0) + } + `, + 'argument must be a string or list field', + ); + + // invalid regex pattern + await loadSchemaWithError( + ` + model Foo { + id Int @id @default(autoincrement()) + x String + @@validate(regex(x, '[abc')) + } + `, + 'invalid regular expression', + ); + + // using field as regex pattern + await loadSchemaWithError( + ` + model Foo { + id Int @id @default(autoincrement()) + x String + y String + @@validate(regex(x, y)) + } + `, + 'second argument must be a string literal', + ); + }); }); diff --git a/tests/e2e/orm/validation/toplevel.test.ts b/tests/e2e/orm/validation/toplevel.test.ts index a7d76475..f4204b62 100644 --- a/tests/e2e/orm/validation/toplevel.test.ts +++ b/tests/e2e/orm/validation/toplevel.test.ts @@ -171,6 +171,28 @@ describe('Toplevel field validation tests', () => { await expect(db.foo.create({ data: { int1: '3.3', int2: new Decimal(3.9) } })).toResolveTruthy(); }); + it('works with list fields', async () => { + const db = await createTestClient( + ` + model Foo { + id Int @id @default(autoincrement()) + list1 Int[] @length(2, 4) + } + `, + { provider: 'postgresql' }, + ); + + await expect(db.foo.create({ data: { id: 1, list1: [1] } })).toBeRejectedByValidation(); + + await expect(db.foo.create({ data: { id: 1, list1: [1, 2, 3, 4, 5] } })).toBeRejectedByValidation(); + + await expect(db.foo.create({ data: { id: 1, list1: [1, 2, 3] } })).toResolveTruthy(); + + await expect(db.foo.update({ where: { id: 1 }, data: { list1: [1] } })).toBeRejectedByValidation(); + await expect(db.foo.update({ where: { id: 1 }, data: { list1: [1, 2, 3, 4, 5] } })).toBeRejectedByValidation(); + await expect(db.foo.update({ where: { id: 1 }, data: { list1: [2, 3, 4] } })).toResolveTruthy(); + }); + it('rejects accessing relation fields', async () => { await loadSchemaWithError( `