Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@
- [x] Array update
- [x] Strict typing for checked/unchecked input
- [x] Upsert
- [ ] Implement with "on conflict"
- [x] Delete
- [x] Aggregation
- [x] Count
Expand Down Expand Up @@ -86,7 +85,7 @@
- [ ] Global omit
- [ ] DbNull vs JsonNull
- [ ] Migrate to tsdown
- [ ] @default validation
- [x] @default validation
- [ ] Benchmark
- [x] Plugin
- [x] Post-mutation hooks should be called after transaction is committed
Expand All @@ -96,7 +95,7 @@
- [x] ZModel
- [x] Runtime
- [x] Typing
- [ ] Validation
- [x] Validation
- [ ] Access Policy
- [ ] Short-circuit pre-create check for scalar-field only policies
- [x] Inject "on conflict do update"
Expand Down
15 changes: 8 additions & 7 deletions packages/language/res/stdlib.zmodel
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ enum AttributeTargetField {
BytesField
ModelField
TypeDefField
ListField
}

/**
Expand Down Expand Up @@ -486,9 +487,9 @@ attribute @db.ByteA() @@@targetField([BytesField]) @@@prisma
//////////////////////////////////////////////

/**
* Validates length of a string field.
* Validates length of a string field or list field.
*/
attribute @length(_ min: Int?, _ max: Int?, _ message: String?) @@@targetField([StringField]) @@@validation
attribute @length(_ min: Int?, _ max: Int?, _ message: String?) @@@targetField([StringField, ListField]) @@@validation

/**
* Validates a string field value starts with the given text.
Expand Down Expand Up @@ -566,9 +567,9 @@ attribute @lte(_ value: Any, _ message: String?) @@@targetField([IntField, Float
attribute @@validate(_ value: Boolean, _ message: String?, _ path: String[]?) @@@validation

/**
* Validates length of a string field.
* Returns the length of a string field or a list field.
*/
function length(field: String, min: Int, max: Int?): Boolean {
function length(field: Any): Int {
} @@@expressionContext([ValidationRule])


Expand All @@ -581,19 +582,19 @@ function regex(field: String, regex: String): Boolean {
/**
* Validates a string field value is a valid email address.
*/
function email(field: String): Boolean {
function isEmail(field: String): Boolean {
} @@@expressionContext([ValidationRule])

/**
* Validates a string field value is a valid ISO datetime.
*/
function datetime(field: String): Boolean {
function isDateTime(field: String): Boolean {
} @@@expressionContext([ValidationRule])

/**
* Validates a string field value is a valid url.
*/
function url(field: String): Boolean {
function isUrl(field: String): Boolean {
} @@@expressionContext([ValidationRule])

//////////////////////////////////////////////
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,9 @@ function isValidAttributeTarget(attrDecl: Attribute, targetDecl: DataField) {
case 'TypeDefField':
allowed = allowed || isTypeDef(targetDecl.type.reference?.ref);
break;
case 'ListField':
allowed = allowed || (!isDataModel(targetDecl.type.reference?.ref) && targetDecl.type.array);
break;
default:
break;
}
Expand Down
48 changes: 48 additions & 0 deletions packages/language/src/validators/function-invocation-validator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import {
isDataFieldAttribute,
isDataModel,
isDataModelAttribute,
isStringLiteral,
} from '../generated/ast';
import {
getFunctionExpressionContext,
Expand Down Expand Up @@ -183,6 +184,53 @@ export default class FunctionInvocationValidator implements AstValidator<Express
return true;
}

@func('length')
// @ts-expect-error
private _checkLength(expr: InvocationExpr, accept: ValidationAcceptor) {
const msg = 'argument must be a string or list field';
const fieldArg = expr.args[0]!.value;
if (!isDataFieldReference(fieldArg)) {
accept('error', msg, {
node: expr.args[0]!,
});
return;
}

if (isDataModel(fieldArg.$resolvedType?.decl)) {
accept('error', msg, {
node: expr.args[0]!,
});
return;
}

if (!fieldArg.$resolvedType?.array && fieldArg.$resolvedType?.decl !== 'String') {
accept('error', msg, {
node: expr.args[0]!,
});
}
}

@func('regex')
// @ts-expect-error
private _checkRegex(expr: InvocationExpr, accept: ValidationAcceptor) {
const regex = expr.args[1]?.value;
if (!isStringLiteral(regex)) {
accept('error', 'second argument must be a string literal', {
node: expr.args[1]!,
});
return;
}

try {
// try to create a RegExp object to verify the pattern
new RegExp(regex.value);
} catch (e) {
accept('error', 'invalid regular expression: ' + (e as Error).message, {
node: expr.args[1]!,
});
}
}

// TODO: move this to policy plugin
@func('check')
// @ts-expect-error
Expand Down
13 changes: 8 additions & 5 deletions packages/runtime/src/client/crud/validator/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import {
addBigIntValidation,
addCustomValidation,
addDecimalValidation,
addListValidation,
addNumberValidation,
addStringValidation,
} from './utils';
Expand Down Expand Up @@ -904,11 +905,12 @@ export class InputValidator<Schema extends SchemaDef> {
let fieldSchema: ZodType = this.makePrimitiveSchema(fieldDef.type, fieldDef.attributes);

if (fieldDef.array) {
fieldSchema = addListValidation(fieldSchema.array(), fieldDef.attributes);
fieldSchema = z
.union([
z.array(fieldSchema),
fieldSchema,
z.strictObject({
set: z.array(fieldSchema),
set: fieldSchema,
}),
])
.optional();
Expand Down Expand Up @@ -1186,13 +1188,14 @@ export class InputValidator<Schema extends SchemaDef> {
}

if (fieldDef.array) {
const arraySchema = addListValidation(fieldSchema.array(), fieldDef.attributes);
fieldSchema = z
.union([
fieldSchema.array(),
arraySchema,
z
.object({
set: z.array(fieldSchema).optional(),
push: this.orArray(fieldSchema, true).optional(),
set: arraySchema.optional(),
push: z.union([fieldSchema, arraySchema]).optional(),
})
.refine(
(v) => Object.keys(v).length === 1,
Expand Down
52 changes: 39 additions & 13 deletions packages/runtime/src/client/crud/validator/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,32 @@ export function addDecimalValidation(
return result;
}

export function addListValidation(
schema: z.ZodArray<any>,
attributes: AttributeApplication[] | undefined,
): z.ZodSchema {
if (!attributes || attributes.length === 0) {
return schema;
}

let result = schema;
for (const attr of attributes) {
match(attr.name)
.with('@length', () => {
const min = getArgValue<number>(attr.args?.[0]?.value);
if (min !== undefined) {
result = result.min(min);
}
const max = getArgValue<number>(attr.args?.[1]?.value);
if (max !== undefined) {
result = result.max(max);
}
})
.otherwise(() => {});
}
return result;
}

export function addCustomValidation(schema: z.ZodSchema, attributes: AttributeApplication[] | undefined): z.ZodSchema {
const attrs = attributes?.filter((a) => a.name === '@@validate');
if (!attrs || attrs.length === 0) {
Expand Down Expand Up @@ -329,17 +355,11 @@ function evalCall(data: any, expr: CallExpression) {
if (fieldArg === undefined || fieldArg === null) {
return false;
}
invariant(typeof fieldArg === 'string', `"${f}" first argument must be a string`);

const min = getArgValue<number>(expr.args?.[1]);
const max = getArgValue<number>(expr.args?.[2]);
if (min !== undefined && fieldArg.length < min) {
return false;
}
if (max !== undefined && fieldArg.length > max) {
return false;
}
return true;
invariant(
typeof fieldArg === 'string' || Array.isArray(fieldArg),
`"${f}" first argument must be a string or an list`,
);
return fieldArg.length;
})
.with(P.union('startsWith', 'endsWith', 'contains'), (f) => {
if (fieldArg === undefined || fieldArg === null) {
Expand Down Expand Up @@ -370,11 +390,17 @@ function evalCall(data: any, expr: CallExpression) {
invariant(pattern !== undefined, `"${f}" requires a pattern argument`);
return new RegExp(pattern).test(fieldArg);
})
.with(P.union('email', 'url', 'datetime'), (f) => {
.with(P.union('isEmail', 'isUrl', 'isDateTime'), (f) => {
if (fieldArg === undefined || fieldArg === null) {
return false;
}
return z.string()[f]().safeParse(fieldArg).success;
invariant(typeof fieldArg === 'string', `"${f}" first argument must be a string`);
const fn = match(f)
.with('isEmail', () => 'email' as const)
.with('isUrl', () => 'url' as const)
.with('isDateTime', () => 'datetime' as const)
.exhaustive();
return z.string()[fn]().safeParse(fieldArg).success;
})
// list functions
.with(P.union('has', 'hasEvery', 'hasSome'), (f) => {
Expand Down
76 changes: 70 additions & 6 deletions tests/e2e/orm/validation/custom-validation.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { createTestClient } from '@zenstackhq/testtools';
import { createTestClient, loadSchemaWithError } from '@zenstackhq/testtools';
import { describe, expect, it } from 'vitest';

describe('Custom validation tests', () => {
Expand All @@ -15,25 +15,28 @@ describe('Custom validation tests', () => {
int1 Int?
list1 Int[]
list2 Int[]
list3 Int[]

@@validate(
(str1 == null || length(str1, 8, 10))
(str1 == null || (length(str1) >= 8 && length(str1) <= 10))
&& (int1 == null || (int1 > 1 && int1 < 4)),
'invalid fields')

@@validate(str1 == null || (startsWith(str1, 'a') && endsWith(str1, 'm') && contains(str1, 'b')), 'invalid fields')

@@validate(str2 == null || regex(str2, '^x.*z$'), 'invalid str2')

@@validate(str3 == null || email(str3), 'invalid str3')
@@validate(str3 == null || isEmail(str3), 'invalid str3')

@@validate(str4 == null || url(str4), 'invalid str4')
@@validate(str4 == null || isUrl(str4), 'invalid str4')

@@validate(str5 == null || datetime(str5), 'invalid str5')
@@validate(str5 == null || isDateTime(str5), 'invalid str5')

@@validate(list1 == null || (has(list1, 1) && hasSome(list1, [2, 3]) && hasEvery(list1, [4, 5])), 'invalid list1')

@@validate(list2 == null || isEmpty(list2), 'invalid list2', ['x', 'y'])

@@validate(list3 == null || length(list3) <2 , 'invalid list3')
}
`,
{ provider: 'postgresql' },
Expand Down Expand Up @@ -93,6 +96,9 @@ describe('Custom validation tests', () => {
}
expect(thrown).toBe(true);

// validates list length
await expect(_t({ list3: [1, 2] })).toBeRejectedByValidation(['invalid list3']);

// satisfies all
await expect(
_t({
Expand All @@ -104,6 +110,7 @@ describe('Custom validation tests', () => {
int1: 2,
list1: [1, 2, 4, 5],
list2: [],
list3: [1],
}),
).toResolveTruthy();
}
Expand All @@ -115,7 +122,7 @@ describe('Custom validation tests', () => {
model User {
id Int @id @default(autoincrement())
email String @unique @email
@@validate(length(email, 8))
@@validate(length(email) >= 8)
@@allow('all', true)
}
`,
Expand Down Expand Up @@ -170,4 +177,61 @@ describe('Custom validation tests', () => {
}),
).toBeRejectedByValidation();
});

it('checks arg type for validation functions', async () => {
// length() on relation field
await loadSchemaWithError(
`
model Foo {
id Int @id @default(autoincrement())
bars Bar[]
@@validate(length(bars) > 0)
}

model Bar {
id Int @id @default(autoincrement())
foo Foo @relation(fields: [fooId], references: [id])
fooId Int
}
`,
'argument must be a string or list field',
);

// length() on non-string/list field
await loadSchemaWithError(
`
model Foo {
id Int @id @default(autoincrement())
x Int
@@validate(length(x) > 0)
}
`,
'argument must be a string or list field',
);

// invalid regex pattern
await loadSchemaWithError(
`
model Foo {
id Int @id @default(autoincrement())
x String
@@validate(regex(x, '[abc'))
}
`,
'invalid regular expression',
);

// using field as regex pattern
await loadSchemaWithError(
`
model Foo {
id Int @id @default(autoincrement())
x String
y String
@@validate(regex(x, y))
}
`,
'second argument must be a string literal',
);
});
});
Loading