diff --git a/packages/eslint-config/base.js b/packages/config/eslint-config/base.js similarity index 100% rename from packages/eslint-config/base.js rename to packages/config/eslint-config/base.js diff --git a/packages/eslint-config/package.json b/packages/config/eslint-config/package.json similarity index 100% rename from packages/eslint-config/package.json rename to packages/config/eslint-config/package.json diff --git a/packages/typescript-config/base.json b/packages/config/typescript-config/base.json similarity index 100% rename from packages/typescript-config/base.json rename to packages/config/typescript-config/base.json diff --git a/packages/typescript-config/package.json b/packages/config/typescript-config/package.json similarity index 100% rename from packages/typescript-config/package.json rename to packages/config/typescript-config/package.json diff --git a/packages/vitest-config/base.config.js b/packages/config/vitest-config/base.config.js similarity index 100% rename from packages/vitest-config/base.config.js rename to packages/config/vitest-config/base.config.js diff --git a/packages/vitest-config/package.json b/packages/config/vitest-config/package.json similarity index 100% rename from packages/vitest-config/package.json rename to packages/config/vitest-config/package.json diff --git a/packages/language/res/stdlib.zmodel b/packages/language/res/stdlib.zmodel index 52d34ae4..85dc8e91 100644 --- a/packages/language/res/stdlib.zmodel +++ b/packages/language/res/stdlib.zmodel @@ -543,22 +543,22 @@ attribute @upper() @@@targetField([StringField]) @@@validation /** * Validates a number field is greater than the given value. */ -attribute @gt(_ value: Int, _ message: String?) @@@targetField([IntField, FloatField, DecimalField]) @@@validation +attribute @gt(_ value: Any, _ message: String?) @@@targetField([IntField, FloatField, DecimalField, BigIntField]) @@@validation /** * Validates a number field is greater than or equal to the given value. */ -attribute @gte(_ value: Int, _ message: String?) @@@targetField([IntField, FloatField, DecimalField]) @@@validation +attribute @gte(_ value: Any, _ message: String?) @@@targetField([IntField, FloatField, DecimalField, BigIntField]) @@@validation /** * Validates a number field is less than the given value. */ -attribute @lt(_ value: Int, _ message: String?) @@@targetField([IntField, FloatField, DecimalField]) @@@validation +attribute @lt(_ value: Any, _ message: String?) @@@targetField([IntField, FloatField, DecimalField, BigIntField]) @@@validation /** * Validates a number field is less than or equal to the given value. */ -attribute @lte(_ value: Int, _ message: String?) @@@targetField([IntField, FloatField, DecimalField]) @@@validation +attribute @lte(_ value: Any, _ message: String?) @@@targetField([IntField, FloatField, DecimalField, BigIntField]) @@@validation /** * Validates the entity with a complex condition. diff --git a/packages/runtime/src/client/contract.ts b/packages/runtime/src/client/contract.ts index 002f478c..2374bc6e 100644 --- a/packages/runtime/src/client/contract.ts +++ b/packages/runtime/src/client/contract.ts @@ -1,4 +1,4 @@ -import type { Decimal } from 'decimal.js'; +import type Decimal from 'decimal.js'; import { type GetModels, type IsDelegateModel, type ProcedureDef, type SchemaDef } from '../schema'; import type { AuthType } from '../schema/auth'; import type { OrUndefinedIf, Simplify, UnwrapTuplePromises } from '../utils/type-utils'; diff --git a/packages/runtime/src/client/crud/operations/base.ts b/packages/runtime/src/client/crud/operations/base.ts index 34924952..65bdbbc2 100644 --- a/packages/runtime/src/client/crud/operations/base.ts +++ b/packages/runtime/src/client/crud/operations/base.ts @@ -131,15 +131,10 @@ export abstract class BaseOperationHandler { model: GetModels, filter: any, ): Promise { - const idFields = requireIdFields(this.schema, model); - const _filter = flattenCompoundUniqueFilters(this.schema, model, filter); - const query = kysely - .selectFrom(model) - .where((eb) => eb.and(_filter)) - .select(idFields.map((f) => kysely.dynamic.ref(f))) - .limit(1) - .modifyEnd(this.makeContextComment({ model, operation: 'read' })); - return this.executeQueryTakeFirst(kysely, query, 'exists'); + return this.readUnique(kysely, model, { + where: filter, + select: this.makeIdSelect(model), + }); } protected async read( diff --git a/packages/runtime/src/client/crud/validator.ts b/packages/runtime/src/client/crud/validator/index.ts similarity index 95% rename from packages/runtime/src/client/crud/validator.ts rename to packages/runtime/src/client/crud/validator/index.ts index beb31faf..90cc67e0 100644 --- a/packages/runtime/src/client/crud/validator.ts +++ b/packages/runtime/src/client/crud/validator/index.ts @@ -4,17 +4,18 @@ import stableStringify from 'json-stable-stringify'; import { match, P } from 'ts-pattern'; import { z, ZodSchema, ZodType } from 'zod'; import { + type AttributeApplication, type BuiltinType, type EnumDef, type FieldDef, type GetModels, type ModelDef, type SchemaDef, -} from '../../schema'; -import { enumerate } from '../../utils/enumerate'; -import { extractFields } from '../../utils/object-utils'; -import { formatError } from '../../utils/zod-utils'; -import { AGGREGATE_OPERATORS, LOGICAL_COMBINATORS, NUMERIC_FIELD_TYPES } from '../constants'; +} from '../../../schema'; +import { enumerate } from '../../../utils/enumerate'; +import { extractFields } from '../../../utils/object-utils'; +import { formatError } from '../../../utils/zod-utils'; +import { AGGREGATE_OPERATORS, LOGICAL_COMBINATORS, NUMERIC_FIELD_TYPES } from '../../constants'; import { type AggregateArgs, type CountArgs, @@ -29,8 +30,8 @@ import { type UpdateManyAndReturnArgs, type UpdateManyArgs, type UpsertArgs, -} from '../crud-types'; -import { InputValidationError, InternalError } from '../errors'; +} from '../../crud-types'; +import { InputValidationError, InternalError } from '../../errors'; import { fieldHasDefaultValue, getDiscriminatorField, @@ -38,7 +39,14 @@ import { getUniqueFields, requireField, requireModel, -} from '../query-utils'; +} from '../../query-utils'; +import { + addBigIntValidation, + addCustomValidation, + addDecimalValidation, + addNumberValidation, + addStringValidation, +} from './utils'; type GetSchemaFunc = (model: GetModels, options: Options) => ZodType; @@ -191,11 +199,14 @@ export class InputValidator { schema = getSchema(model, options); this.schemaCache.set(cacheKey!, schema); } - const { error } = schema.safeParse(args); + const { error, data } = schema.safeParse(args); if (error) { - throw new InputValidationError(`Invalid ${operation} args: ${formatError(error)}`, error); + throw new InputValidationError( + `Invalid ${operation} args for model "${model}": ${formatError(error)}`, + error, + ); } - return args as T; + return data as T; } // #region Find @@ -235,17 +246,28 @@ export class InputValidator { return result; } - private makePrimitiveSchema(type: string) { + private makePrimitiveSchema(type: string, attributes?: AttributeApplication[]) { if (this.schema.typeDefs && type in this.schema.typeDefs) { return this.makeTypeDefSchema(type); } else { return match(type) - .with('String', () => z.string()) - .with('Int', () => z.number().int()) - .with('Float', () => z.number()) + .with('String', () => addStringValidation(z.string(), attributes)) + .with('Int', () => addNumberValidation(z.number().int(), attributes)) + .with('Float', () => addNumberValidation(z.number(), attributes)) .with('Boolean', () => z.boolean()) - .with('BigInt', () => z.union([z.number().int(), z.bigint()])) - .with('Decimal', () => z.union([z.number(), z.instanceof(Decimal), z.string()])) + .with('BigInt', () => + z.union([ + addNumberValidation(z.number().int(), attributes), + addBigIntValidation(z.bigint(), attributes), + ]), + ) + .with('Decimal', () => + z.union([ + addNumberValidation(z.number(), attributes), + addDecimalValidation(z.instanceof(Decimal), attributes), + addDecimalValidation(z.string(), attributes), + ]), + ) .with('DateTime', () => z.union([z.date(), z.string().datetime()])) .with('Bytes', () => z.instanceof(Uint8Array)) .otherwise(() => z.unknown()); @@ -860,7 +882,7 @@ export class InputValidator { uncheckedVariantFields[field] = fieldSchema; } } else { - let fieldSchema: ZodType = this.makePrimitiveSchema(fieldDef.type); + let fieldSchema: ZodType = this.makePrimitiveSchema(fieldDef.type, fieldDef.attributes); if (fieldDef.array) { fieldSchema = z @@ -889,14 +911,17 @@ export class InputValidator { } }); + const uncheckedCreateSchema = addCustomValidation(z.strictObject(uncheckedVariantFields), modelDef.attributes); + const checkedCreateSchema = addCustomValidation(z.strictObject(checkedVariantFields), modelDef.attributes); + if (!hasRelation) { - return this.orArray(z.strictObject(uncheckedVariantFields), canBeArray); + return this.orArray(uncheckedCreateSchema, canBeArray); } else { return z.union([ - z.strictObject(uncheckedVariantFields), - z.strictObject(checkedVariantFields), - ...(canBeArray ? [z.array(z.strictObject(uncheckedVariantFields))] : []), - ...(canBeArray ? [z.array(z.strictObject(checkedVariantFields))] : []), + uncheckedCreateSchema, + checkedCreateSchema, + ...(canBeArray ? [z.array(uncheckedCreateSchema)] : []), + ...(canBeArray ? [z.array(checkedCreateSchema)] : []), ]); } } @@ -1112,7 +1137,7 @@ export class InputValidator { uncheckedVariantFields[field] = fieldSchema; } } else { - let fieldSchema: ZodType = this.makePrimitiveSchema(fieldDef.type).optional(); + let fieldSchema: ZodType = this.makePrimitiveSchema(fieldDef.type, fieldDef.attributes).optional(); if (this.isNumericField(fieldDef)) { fieldSchema = z.union([ @@ -1161,10 +1186,12 @@ export class InputValidator { } }); + const uncheckedUpdateSchema = addCustomValidation(z.strictObject(uncheckedVariantFields), modelDef.attributes); + const checkedUpdateSchema = addCustomValidation(z.strictObject(checkedVariantFields), modelDef.attributes); if (!hasRelation) { - return z.strictObject(uncheckedVariantFields); + return uncheckedUpdateSchema; } else { - return z.union([z.strictObject(uncheckedVariantFields), z.strictObject(checkedVariantFields)]); + return z.union([uncheckedUpdateSchema, checkedUpdateSchema]); } } diff --git a/packages/runtime/src/client/crud/validator/utils.ts b/packages/runtime/src/client/crud/validator/utils.ts new file mode 100644 index 00000000..6b0a17d5 --- /dev/null +++ b/packages/runtime/src/client/crud/validator/utils.ts @@ -0,0 +1,412 @@ +import { invariant } from '@zenstackhq/common-helpers'; +import type { + AttributeApplication, + BinaryExpression, + CallExpression, + Expression, + FieldExpression, + MemberExpression, + UnaryExpression, +} from '@zenstackhq/sdk/schema'; +import Decimal from 'decimal.js'; +import { match, P } from 'ts-pattern'; +import { z } from 'zod'; +import { ExpressionUtils } from '../../../schema'; +import { QueryError } from '../../errors'; + +function getArgValue(expr: Expression | undefined): T | undefined { + if (!expr || !ExpressionUtils.isLiteral(expr)) { + return undefined; + } + return expr.value as T; +} + +export function addStringValidation(schema: z.ZodString, attributes: AttributeApplication[] | undefined): z.ZodSchema { + if (!attributes || attributes.length === 0) { + return schema; + } + + let result = schema; + for (const attr of attributes) { + match(attr.name) + .with('@length', () => { + const min = getArgValue(attr.args?.[0]?.value); + if (min !== undefined) { + result = result.min(min); + } + const max = getArgValue(attr.args?.[1]?.value); + if (max !== undefined) { + result = result.max(max); + } + }) + .with('@startsWith', () => { + const value = getArgValue(attr.args?.[0]?.value); + if (value !== undefined) { + result = result.startsWith(value); + } + }) + .with('@endsWith', () => { + const value = getArgValue(attr.args?.[0]?.value); + if (value !== undefined) { + result = result.endsWith(value); + } + }) + .with('@contains', () => { + const value = getArgValue(attr.args?.[0]?.value); + if (value !== undefined) { + result = result.includes(value); + } + }) + .with('@regex', () => { + const pattern = getArgValue(attr.args?.[0]?.value); + if (pattern !== undefined) { + result = result.regex(new RegExp(pattern)); + } + }) + .with('@email', () => { + result = result.email(); + }) + .with('@datetime', () => { + result = result.datetime(); + }) + .with('@url', () => { + result = result.url(); + }) + .with('@trim', () => { + result = result.trim(); + }) + .with('@lower', () => { + result = result.toLowerCase(); + }) + .with('@upper', () => { + result = result.toUpperCase(); + }); + } + return result; +} + +export function addNumberValidation(schema: z.ZodNumber, attributes: AttributeApplication[] | undefined): z.ZodSchema { + if (!attributes || attributes.length === 0) { + return schema; + } + + let result = schema; + for (const attr of attributes) { + const val = getArgValue(attr.args?.[0]?.value); + if (val === undefined) { + continue; + } + match(attr.name) + .with('@gt', () => { + result = result.gt(val); + }) + .with('@gte', () => { + result = result.gte(val); + }) + .with('@lt', () => { + result = result.lt(val); + }) + .with('@lte', () => { + result = result.lte(val); + }); + } + return result; +} + +export function addBigIntValidation(schema: z.ZodBigInt, attributes: AttributeApplication[] | undefined): z.ZodSchema { + if (!attributes || attributes.length === 0) { + return schema; + } + + let result = schema; + for (const attr of attributes) { + const val = getArgValue(attr.args?.[0]?.value); + if (val === undefined) { + continue; + } + const bigIntVal = BigInt(val); + match(attr.name) + .with('@gt', () => { + result = result.gt(bigIntVal); + }) + .with('@gte', () => { + result = result.gte(bigIntVal); + }) + .with('@lt', () => { + result = result.lt(bigIntVal); + }) + .with('@lte', () => { + result = result.lte(bigIntVal); + }); + } + return result; +} + +export function addDecimalValidation( + schema: z.ZodType | z.ZodString, + attributes: AttributeApplication[] | undefined, +): z.ZodSchema { + let result: z.ZodSchema = schema; + + // parse string to Decimal + if (schema instanceof z.ZodString) { + result = schema + .superRefine((v, ctx) => { + try { + new Decimal(v); + } catch (err) { + ctx.addIssue({ + code: z.ZodIssueCode.custom, + message: `Invalid decimal: ${err}`, + }); + } + }) + .transform((val) => new Decimal(val)); + } + + // add validations + + function refine(schema: z.ZodSchema, op: 'gt' | 'gte' | 'lt' | 'lte', value: number) { + return schema.superRefine((v, ctx) => { + const base = z.number(); + const { error } = base[op](value).safeParse((v as Decimal).toNumber()); + error?.errors.forEach((e) => { + ctx.addIssue(e); + }); + }); + } + + if (attributes) { + for (const attr of attributes) { + const val = getArgValue(attr.args?.[0]?.value); + if (val === undefined) { + continue; + } + + match(attr.name) + .with('@gt', () => { + result = refine(result, 'gt', val); + }) + .with('@gte', () => { + result = refine(result, 'gte', val); + }) + .with('@lt', () => { + result = refine(result, 'lt', val); + }) + .with('@lte', () => { + result = refine(result, 'lte', val); + }); + } + } + + return result; +} + +export function addCustomValidation(schema: z.ZodSchema, attributes: AttributeApplication[] | undefined): z.ZodSchema { + const attrs = attributes?.filter((a) => a.name === '@@validate'); + if (!attrs || attrs.length === 0) { + return schema; + } + + let result = schema; + for (const attr of attrs) { + const expr = attr.args?.[0]?.value; + if (!expr) { + continue; + } + const message = getArgValue(attr.args?.[1]?.value); + const pathExpr = attr.args?.[2]?.value; + let path: string[] | undefined = undefined; + if (pathExpr && ExpressionUtils.isArray(pathExpr)) { + path = pathExpr.items.map((e) => ExpressionUtils.getLiteralValue(e) as string); + } + result = applyValidation(result, expr, message, path); + } + return result; +} + +function applyValidation( + schema: z.ZodSchema, + expr: Expression, + message: string | undefined, + path: string[] | undefined, +) { + const options: z.CustomErrorParams = {}; + if (message) { + options.message = message; + } + if (path) { + options.path = path; + } + return schema.refine((data) => Boolean(evalExpression(data, expr)), options); +} + +function evalExpression(data: any, expr: Expression): unknown { + return match(expr) + .with({ kind: 'literal' }, (e) => e.value) + .with({ kind: 'array' }, (e) => e.items.map((item) => evalExpression(data, item))) + .with({ kind: 'field' }, (e) => evalField(data, e)) + .with({ kind: 'member' }, (e) => evalMember(data, e)) + .with({ kind: 'unary' }, (e) => evalUnary(data, e)) + .with({ kind: 'binary' }, (e) => evalBinary(data, e)) + .with({ kind: 'call' }, (e) => evalCall(data, e)) + .with({ kind: 'this' }, () => data ?? null) + .with({ kind: 'null' }, () => null) + .exhaustive(); +} + +function evalField(data: any, e: FieldExpression) { + return data?.[e.field] ?? null; +} + +function evalUnary(data: any, expr: UnaryExpression) { + const operand = evalExpression(data, expr.operand); + switch (expr.op) { + case '!': + return !operand; + default: + throw new Error(`Unsupported unary operator: ${expr.op}`); + } +} + +function evalBinary(data: any, expr: BinaryExpression) { + const left = evalExpression(data, expr.left); + const right = evalExpression(data, expr.right); + return match(expr.op) + .with('&&', () => Boolean(left) && Boolean(right)) + .with('||', () => Boolean(left) || Boolean(right)) + .with('==', () => left == right) + .with('!=', () => left != right) + .with('<', () => (left as any) < (right as any)) + .with('<=', () => (left as any) <= (right as any)) + .with('>', () => (left as any) > (right as any)) + .with('>=', () => (left as any) >= (right as any)) + .with('?', () => { + if (!Array.isArray(left)) { + return false; + } + return left.some((item) => item === right); + }) + .with('!', () => { + if (!Array.isArray(left)) { + return false; + } + return left.every((item) => item === right); + }) + .with('^', () => { + if (!Array.isArray(left)) { + return false; + } + return !left.some((item) => item === right); + }) + .with('in', () => { + if (!Array.isArray(right)) { + return false; + } + return right.includes(left); + }) + .exhaustive(); +} + +function evalMember(data: any, expr: MemberExpression) { + let result: any = evalExpression(data, expr.receiver); + for (const member of expr.members) { + if (!result || typeof result !== 'object') { + return undefined; + } + result = result[member]; + } + return result ?? null; +} + +function evalCall(data: any, expr: CallExpression) { + const fieldArg = expr.args?.[0] ? evalExpression(data, expr.args[0]) : undefined; + return ( + match(expr.function) + // string functions + .with('length', (f) => { + if (fieldArg === undefined || fieldArg === null) { + return false; + } + invariant(typeof fieldArg === 'string', `"${f}" first argument must be a string`); + + const min = getArgValue(expr.args?.[1]); + const max = getArgValue(expr.args?.[2]); + if (min !== undefined && fieldArg.length < min) { + return false; + } + if (max !== undefined && fieldArg.length > max) { + return false; + } + return true; + }) + .with(P.union('startsWith', 'endsWith', 'contains'), (f) => { + if (fieldArg === undefined || fieldArg === null) { + return false; + } + invariant(typeof fieldArg === 'string', `"${f}" first argument must be a string`); + invariant(expr.args?.[1], `"${f}" requires a search argument`); + + const search = getArgValue(expr.args?.[1])!; + const caseInsensitive = getArgValue(expr.args?.[2]) ?? false; + + const matcher = (x: string, y: string) => + match(f) + .with('startsWith', () => x.startsWith(y)) + .with('endsWith', () => x.endsWith(y)) + .with('contains', () => x.includes(y)) + .exhaustive(); + return caseInsensitive + ? matcher(fieldArg.toLowerCase(), search.toLowerCase()) + : matcher(fieldArg, search); + }) + .with('regex', (f) => { + if (fieldArg === undefined || fieldArg === null) { + return false; + } + invariant(typeof fieldArg === 'string', `"${f}" first argument must be a string`); + const pattern = getArgValue(expr.args?.[1])!; + invariant(pattern !== undefined, `"${f}" requires a pattern argument`); + return new RegExp(pattern).test(fieldArg); + }) + .with(P.union('email', 'url', 'datetime'), (f) => { + if (fieldArg === undefined || fieldArg === null) { + return false; + } + return z.string()[f]().safeParse(fieldArg).success; + }) + // list functions + .with(P.union('has', 'hasEvery', 'hasSome'), (f) => { + invariant(expr.args?.[1], `${f} requires a search argument`); + if (fieldArg === undefined || fieldArg === null) { + return false; + } + invariant(Array.isArray(fieldArg), `"${f}" first argument must be an array field`); + + const search = evalExpression(data, expr.args?.[1])!; + const matcher = (x: any[], y: any) => + match(f) + .with('has', () => x.some((item) => item === y)) + .with('hasEvery', () => { + invariant(Array.isArray(y), 'hasEvery second argument must be an array'); + return y.every((v) => x.some((item) => item === v)); + }) + .with('hasSome', () => { + invariant(Array.isArray(y), 'hasSome second argument must be an array'); + return y.some((v) => x.some((item) => item === v)); + }) + .exhaustive(); + return matcher(fieldArg, search); + }) + .with('isEmpty', (f) => { + if (fieldArg === undefined || fieldArg === null) { + return false; + } + invariant(Array.isArray(fieldArg), `"${f}" first argument must be an array field`); + return fieldArg.length === 0; + }) + .otherwise(() => { + throw new QueryError(`Unknown function "${expr.function}"`); + }) + ); +} diff --git a/packages/runtime/src/client/query-utils.ts b/packages/runtime/src/client/query-utils.ts index 869d3535..b5107cdf 100644 --- a/packages/runtime/src/client/query-utils.ts +++ b/packages/runtime/src/client/query-utils.ts @@ -304,16 +304,37 @@ export function flattenCompoundUniqueFilters(schema: SchemaDef, model: string, f return filter; } - const result: any = {}; + const flattenedResult: any = {}; + const restFilter: any = {}; + for (const [key, value] of Object.entries(filter)) { if (compoundUniques.some(({ name }) => name === key)) { // flatten the compound field - Object.assign(result, value); + Object.assign(flattenedResult, value); } else { - result[key] = value; + restFilter[key] = value; + } + } + + if (Object.keys(flattenedResult).length === 0) { + // nothing flattened + return filter; + } else if (Object.keys(restFilter).length === 0) { + // all flattened + return flattenedResult; + } else { + const flattenedKeys = Object.keys(flattenedResult); + const restKeys = Object.keys(restFilter); + if (flattenedKeys.some((k) => restKeys.includes(k))) { + // keys overlap, cannot merge directly, build an AND clause + return { + AND: [flattenedResult, restFilter], + }; + } else { + // safe to merge directly + return { ...flattenedResult, ...restFilter }; } } - return result; } export function ensureArray(value: T | T[]): T[] { diff --git a/packages/testtools/src/types.d.ts b/packages/testtools/src/types.d.ts index b547127c..9f58106f 100644 --- a/packages/testtools/src/types.d.ts +++ b/packages/testtools/src/types.d.ts @@ -7,6 +7,7 @@ interface CustomMatchers { toResolveWithLength: (length: number) => Promise; toBeRejectedNotFound: () => Promise; toBeRejectedByPolicy: (expectedMessages?: string[]) => Promise; + toBeRejectedByValidation: (expectedMessages?: string[]) => Promise; } declare module 'vitest' { diff --git a/packages/testtools/src/vitest-ext.ts b/packages/testtools/src/vitest-ext.ts index 70b5a61b..06b1709b 100644 --- a/packages/testtools/src/vitest-ext.ts +++ b/packages/testtools/src/vitest-ext.ts @@ -1,4 +1,4 @@ -import { NotFoundError, RejectedByPolicyError } from '@zenstackhq/runtime'; +import { InputValidationError, NotFoundError, RejectedByPolicyError } from '@zenstackhq/runtime'; import { expect } from 'vitest'; function isPromise(value: any) { @@ -19,6 +19,18 @@ function expectError(err: any, errorType: any) { } } +function expectErrorMessages(expectedMessages: string[], message: string) { + for (const m of expectedMessages) { + if (!message.includes(m)) { + return { + message: () => `expected message not found in error: ${m}, got message: ${message}`, + pass: false, + }; + } + } + return undefined; +} + expect.extend({ async toResolveTruthy(received: Promise) { if (!isPromise(received)) { @@ -84,14 +96,9 @@ expect.extend({ await received; } catch (err) { if (expectedMessages && err instanceof RejectedByPolicyError) { - const message = err.message || ''; - for (const m of expectedMessages) { - if (!message.includes(m)) { - return { - message: () => `expected message not found in error: ${m}, got message: ${message}`, - pass: false, - }; - } + const r = expectErrorMessages(expectedMessages, err.message || ''); + if (r) { + return r; } } return expectError(err, RejectedByPolicyError); @@ -101,4 +108,25 @@ expect.extend({ pass: false, }; }, + + async toBeRejectedByValidation(received: Promise, expectedMessages?: string[]) { + if (!isPromise(received)) { + return { message: () => 'a promise is expected', pass: false }; + } + try { + await received; + } catch (err) { + if (expectedMessages && err instanceof InputValidationError) { + const r = expectErrorMessages(expectedMessages, err.message || ''); + if (r) { + return r; + } + } + return expectError(err, InputValidationError); + } + return { + message: () => `expected InputValidationError, got no error`, + pass: false, + }; + }, }); diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 3c85aa5e..740f983e 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -134,7 +134,7 @@ importers: version: 0.2.6 '@zenstackhq/eslint-config': specifier: workspace:* - version: link:../eslint-config + version: link:../config/eslint-config '@zenstackhq/runtime': specifier: workspace:* version: link:../runtime @@ -143,10 +143,10 @@ importers: version: link:../testtools '@zenstackhq/typescript-config': specifier: workspace:* - version: link:../typescript-config + version: link:../config/typescript-config '@zenstackhq/vitest-config': specifier: workspace:* - version: link:../vitest-config + version: link:../config/vitest-config better-sqlite3: specifier: 'catalog:' version: 12.2.0 @@ -158,10 +158,16 @@ importers: devDependencies: '@zenstackhq/eslint-config': specifier: workspace:* - version: link:../eslint-config + version: link:../config/eslint-config '@zenstackhq/typescript-config': specifier: workspace:* - version: link:../typescript-config + version: link:../config/typescript-config + + packages/config/eslint-config: {} + + packages/config/typescript-config: {} + + packages/config/vitest-config: {} packages/create-zenstack: dependencies: @@ -177,10 +183,10 @@ importers: devDependencies: '@zenstackhq/eslint-config': specifier: workspace:* - version: link:../eslint-config + version: link:../config/eslint-config '@zenstackhq/typescript-config': specifier: workspace:* - version: link:../typescript-config + version: link:../config/typescript-config packages/dialects/sql.js: devDependencies: @@ -189,13 +195,13 @@ importers: version: 1.4.9 '@zenstackhq/eslint-config': specifier: workspace:* - version: link:../../eslint-config + version: link:../../config/eslint-config '@zenstackhq/typescript-config': specifier: workspace:* - version: link:../../typescript-config + version: link:../../config/typescript-config '@zenstackhq/vitest-config': specifier: workspace:* - version: link:../../vitest-config + version: link:../../config/vitest-config kysely: specifier: 'catalog:' version: 0.27.6 @@ -203,8 +209,6 @@ importers: specifier: ^1.13.0 version: 1.13.0 - packages/eslint-config: {} - packages/ide/vscode: dependencies: '@zenstackhq/language': @@ -225,10 +229,10 @@ importers: version: 1.101.0 '@zenstackhq/eslint-config': specifier: workspace:* - version: link:../../eslint-config + version: link:../../config/eslint-config '@zenstackhq/typescript-config': specifier: workspace:* - version: link:../../typescript-config + version: link:../../config/typescript-config packages/language: dependencies: @@ -253,13 +257,13 @@ importers: version: link:../common-helpers '@zenstackhq/eslint-config': specifier: workspace:* - version: link:../eslint-config + version: link:../config/eslint-config '@zenstackhq/typescript-config': specifier: workspace:* - version: link:../typescript-config + version: link:../config/typescript-config '@zenstackhq/vitest-config': specifier: workspace:* - version: link:../vitest-config + version: link:../config/vitest-config glob: specifier: ^11.0.2 version: 11.0.2 @@ -296,13 +300,13 @@ importers: version: 8.11.11 '@zenstackhq/eslint-config': specifier: workspace:* - version: link:../../eslint-config + version: link:../../config/eslint-config '@zenstackhq/typescript-config': specifier: workspace:* - version: link:../../typescript-config + version: link:../../config/typescript-config '@zenstackhq/vitest-config': specifier: workspace:* - version: link:../../vitest-config + version: link:../../config/vitest-config packages/runtime: dependencies: @@ -357,7 +361,7 @@ importers: version: 2.0.7 '@zenstackhq/eslint-config': specifier: workspace:* - version: link:../eslint-config + version: link:../config/eslint-config '@zenstackhq/language': specifier: workspace:* version: link:../language @@ -366,10 +370,10 @@ importers: version: link:../sdk '@zenstackhq/typescript-config': specifier: workspace:* - version: link:../typescript-config + version: link:../config/typescript-config '@zenstackhq/vitest-config': specifier: workspace:* - version: link:../vitest-config + version: link:../config/vitest-config tsx: specifier: ^4.19.2 version: 4.19.2 @@ -397,10 +401,10 @@ importers: devDependencies: '@zenstackhq/eslint-config': specifier: workspace:* - version: link:../eslint-config + version: link:../config/eslint-config '@zenstackhq/typescript-config': specifier: workspace:* - version: link:../typescript-config + version: link:../config/typescript-config decimal.js: specifier: ^10.4.3 version: 10.4.3 @@ -419,10 +423,10 @@ importers: devDependencies: '@zenstackhq/eslint-config': specifier: workspace:* - version: link:../eslint-config + version: link:../config/eslint-config '@zenstackhq/typescript-config': specifier: workspace:* - version: link:../typescript-config + version: link:../config/typescript-config packages/testtools: dependencies: @@ -477,10 +481,10 @@ importers: version: 0.2.6 '@zenstackhq/eslint-config': specifier: workspace:* - version: link:../eslint-config + version: link:../config/eslint-config '@zenstackhq/typescript-config': specifier: workspace:* - version: link:../typescript-config + version: link:../config/typescript-config copyfiles: specifier: ^2.4.1 version: 2.4.1 @@ -488,10 +492,6 @@ importers: specifier: 'catalog:' version: 5.8.3 - packages/typescript-config: {} - - packages/vitest-config: {} - packages/zod: dependencies: '@zenstackhq/runtime': @@ -503,10 +503,10 @@ importers: devDependencies: '@zenstackhq/eslint-config': specifier: workspace:* - version: link:../eslint-config + version: link:../config/eslint-config '@zenstackhq/typescript-config': specifier: workspace:* - version: link:../typescript-config + version: link:../config/typescript-config zod: specifier: ~3.25.0 version: 3.25.76 @@ -531,7 +531,7 @@ importers: version: link:../../packages/cli '@zenstackhq/typescript-config': specifier: workspace:* - version: link:../../packages/typescript-config + version: link:../../packages/config/typescript-config prisma: specifier: 'catalog:' version: 6.14.0(typescript@5.8.3) @@ -580,10 +580,10 @@ importers: version: 11.0.0 '@zenstackhq/typescript-config': specifier: workspace:* - version: link:../../packages/typescript-config + version: link:../../packages/config/typescript-config '@zenstackhq/vitest-config': specifier: workspace:* - version: link:../../packages/vitest-config + version: link:../../packages/config/vitest-config tests/regression: dependencies: @@ -605,10 +605,10 @@ importers: version: link:../../packages/sdk '@zenstackhq/typescript-config': specifier: workspace:* - version: link:../../packages/typescript-config + version: link:../../packages/config/typescript-config '@zenstackhq/vitest-config': specifier: workspace:* - version: link:../../packages/vitest-config + version: link:../../packages/config/vitest-config packages: diff --git a/tests/e2e/orm/client-api/compound-id.test.ts b/tests/e2e/orm/client-api/compound-id.test.ts index b983b045..dc11c253 100644 --- a/tests/e2e/orm/client-api/compound-id.test.ts +++ b/tests/e2e/orm/client-api/compound-id.test.ts @@ -1,5 +1,5 @@ -import { describe, expect, it } from 'vitest'; import { createTestClient } from '@zenstackhq/testtools'; +import { describe, expect, it } from 'vitest'; describe('Compound ID tests', () => { describe('to-one relation', () => { diff --git a/tests/e2e/orm/client-api/type-coverage.test.ts b/tests/e2e/orm/client-api/type-coverage.test.ts index 9ce29fce..a0c24880 100644 --- a/tests/e2e/orm/client-api/type-coverage.test.ts +++ b/tests/e2e/orm/client-api/type-coverage.test.ts @@ -1,6 +1,6 @@ +import { createTestClient, getTestDbProvider } from '@zenstackhq/testtools'; import Decimal from 'decimal.js'; import { describe, expect, it } from 'vitest'; -import { createTestClient, getTestDbProvider } from '@zenstackhq/testtools'; describe('Zmodel type coverage tests', () => { it('supports all types - plain', async () => { diff --git a/tests/e2e/orm/validation/custom-validation.test.ts b/tests/e2e/orm/validation/custom-validation.test.ts new file mode 100644 index 00000000..35667e4c --- /dev/null +++ b/tests/e2e/orm/validation/custom-validation.test.ts @@ -0,0 +1,111 @@ +import { createTestClient } from '@zenstackhq/testtools'; +import { describe, expect, it } from 'vitest'; + +describe('Custom validation tests', () => { + it('works with custom validation', async () => { + const db = await createTestClient( + ` + model Foo { + id Int @id @default(autoincrement()) + str1 String? + str2 String? + str3 String? + str4 String? + str5 String? + int1 Int? + list1 Int[] + list2 Int[] + + @@validate( + (str1 == null || length(str1, 8, 10)) + && (int1 == null || (int1 > 1 && int1 < 4)), + 'invalid fields') + + @@validate(str1 == null || (startsWith(str1, 'a') && endsWith(str1, 'm') && contains(str1, 'b')), 'invalid fields') + + @@validate(str2 == null || regex(str2, '^x.*z$'), 'invalid str2') + + @@validate(str3 == null || email(str3), 'invalid str3') + + @@validate(str4 == null || url(str4), 'invalid str4') + + @@validate(str5 == null || datetime(str5), 'invalid str5') + + @@validate(list1 == null || (has(list1, 1) && hasSome(list1, [2, 3]) && hasEvery(list1, [4, 5])), 'invalid list1') + + @@validate(list2 == null || isEmpty(list2), 'invalid list2', ['x', 'y']) + } + `, + { provider: 'postgresql' }, + ); + + await db.foo.create({ data: { id: 100 } }); + + for (const action of ['create', 'update']) { + const _t = + action === 'create' + ? (data: any) => db.foo.create({ data }) + : (data: any) => db.foo.update({ where: { id: 100 }, data }); + // violates length + await expect(_t({ str1: 'abd@efg.com' })).toBeRejectedByValidation(['invalid fields']); + await expect(_t({ str1: 'a@b.c' })).toBeRejectedByValidation(['invalid fields']); + + // violates int1 > 1 + await expect(_t({ int1: 1 })).toBeRejectedByValidation(['invalid fields']); + + // violates startsWith + await expect(_t({ str1: 'b@cd.com' })).toBeRejectedByValidation(['invalid fields']); + + // violates endsWith + await expect(_t({ str1: 'a@b.gov' })).toBeRejectedByValidation(['invalid fields']); + + // violates contains + await expect(_t({ str1: 'a@cd.com' })).toBeRejectedByValidation(['invalid fields']); + + // violates regex + await expect(_t({ str2: 'xab' })).toBeRejectedByValidation(['invalid str2']); + + // violates email + await expect(_t({ str3: 'not-an-email' })).toBeRejectedByValidation(['invalid str3']); + + // violates url + await expect(_t({ str4: 'not-an-url' })).toBeRejectedByValidation(['invalid str4']); + + // violates datetime + await expect(_t({ str5: 'not-an-datetime' })).toBeRejectedByValidation(['invalid str5']); + + // violates has + await expect(_t({ list1: [2, 3, 4, 5] })).toBeRejectedByValidation(['invalid list1']); + + // violates hasSome + await expect(_t({ list1: [1, 4, 5] })).toBeRejectedByValidation(['invalid list1']); + + // violates hasEvery + await expect(_t({ list1: [1, 2, 3, 4] })).toBeRejectedByValidation(['invalid list1']); + + // violates isEmpty + let thrown = false; + try { + await _t({ list2: [1] }); + } catch (err) { + thrown = true; + expect((err as any).cause.issues[0].path).toEqual(['data', 'x', 'y']); + } + expect(thrown).toBe(true); + + // satisfies all + await expect( + _t({ + str1: 'ab12345m', + str2: 'x...z', + str3: 'ab@c.com', + str4: 'http://a.b.c', + str5: new Date().toISOString(), + int1: 2, + list1: [1, 2, 4, 5], + list2: [], + }), + ).toResolveTruthy(); + } + }); +}); diff --git a/tests/e2e/orm/validation/nested.test.ts b/tests/e2e/orm/validation/nested.test.ts new file mode 100644 index 00000000..0949a503 --- /dev/null +++ b/tests/e2e/orm/validation/nested.test.ts @@ -0,0 +1,41 @@ +import { createTestClient } from '@zenstackhq/testtools'; +import { describe, expect, it } from 'vitest'; + +describe('Nested field validation tests', () => { + it('works with nested create/update', async () => { + const db = await createTestClient( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + } + + model Profile { + id Int @id @default(autoincrement()) + email String @email + user User @relation(fields: [userId], references: [id]) + userId Int @unique + @@validate(contains(email, 'zenstack'), 'email must be a zenstack email') + } + `, + ); + + await db.user.create({ data: { id: 1 } }); + + for (const action of ['create', 'update']) { + const _t = + action === 'create' + ? (data: any) => db.user.update({ where: { id: 1 }, data: { profile: { create: data } } }) + : (data: any) => db.user.update({ where: { id: 1 }, data: { profile: { update: data } } }); + + // violates email + await expect(_t({ email: 'zenstack' })).toBeRejectedByValidation(['Invalid email']); + + // violates custom validation + await expect(_t({ email: 'a@b.com' })).toBeRejectedByValidation(['email must be a zenstack email']); + + // satisfies all + await expect(_t({ email: 'me@zenstack.dev' })).toResolveTruthy(); + } + }); +}); diff --git a/tests/e2e/orm/validation/toplevel.test.ts b/tests/e2e/orm/validation/toplevel.test.ts new file mode 100644 index 00000000..a7d76475 --- /dev/null +++ b/tests/e2e/orm/validation/toplevel.test.ts @@ -0,0 +1,209 @@ +import { createTestClient, loadSchemaWithError } from '@zenstackhq/testtools'; +import Decimal from 'decimal.js'; +import { describe, expect, it } from 'vitest'; + +describe('Toplevel field validation tests', () => { + it('works with string fields', async () => { + const db = await createTestClient( + ` + model Foo { + id Int @id @default(autoincrement()) + str1 String? @length(2, 4) @startsWith('a') @endsWith('b') @contains('m') @regex('b{2}') + str2 String? @email + str3 String? @datetime + str4 String? @url + str5 String? @trim @lower + str6 String? @upper + } + `, + ); + + await db.foo.create({ data: { id: 100 } }); + + for (const action of ['create', 'update', 'upsert', 'updateMany']) { + console.log(`Testing action: ${action}`); + const _t = + action === 'create' + ? (data: any) => db.foo.create({ data }) + : action === 'update' + ? (data: any) => db.foo.update({ where: { id: 100 }, data }) + : action === 'upsert' + ? (data: any) => + db.foo.upsert({ where: { id: 100 }, create: { id: 101, ...data }, update: data }) + : (data: any) => db.foo.updateMany({ where: { id: 100 }, data }); + + // violates @length min + await expect(_t({ str1: 'a' })).toBeRejectedByValidation(); + + // violates @length max + await expect(_t({ str1: 'abcde' })).toBeRejectedByValidation(); + + // violates @startsWith + await expect(_t({ str1: 'bcd' })).toBeRejectedByValidation(); + + // violates @endsWith + await expect(_t({ str1: 'abc' })).toBeRejectedByValidation(); + + // violates @contains + await expect(_t({ str1: 'abz' })).toBeRejectedByValidation(); + + // violates @regex + await expect(_t({ str1: 'amcb' })).toBeRejectedByValidation(); + + // satisfies all + await expect(_t({ str1: 'ambb' })).toResolveTruthy(); + + // violates @email + await expect(_t({ str2: 'not-an-email' })).toBeRejectedByValidation(['Invalid email']); + + // satisfies @email + await expect(_t({ str2: 'test@example.com' })).toResolveTruthy(); + + // violates @datetime + await expect(_t({ str3: 'not-datetime' })).toBeRejectedByValidation(); + + // satisfies @datetime + await expect(_t({ str3: new Date().toISOString() })).toResolveTruthy(); + + // violates @url + await expect(_t({ str4: 'not-a-url' })).toBeRejectedByValidation(); + + // satisfies @url + await expect(_t({ str4: 'https://example.com' })).toResolveTruthy(); + + // test @trim and @lower + if (action !== 'updateMany') { + await expect(_t({ str5: ' AbC ' })).resolves.toMatchObject({ str5: 'abc' }); + } else { + await expect(_t({ str5: ' AbC ' })).resolves.toMatchObject({ count: 1 }); + } + + // test @upper + if (action !== 'updateMany') { + await expect(_t({ str6: 'aBc' })).resolves.toMatchObject({ str6: 'ABC' }); + } else { + await expect(_t({ str6: 'aBc' })).resolves.toMatchObject({ count: 1 }); + } + } + }); + + it('works with number fields', async () => { + const db = await createTestClient( + ` + model Foo { + id Int @id @default(autoincrement()) + int1 Int? @gt(2) @lt(4) + int2 Int? @gte(2) @lte(4) + } + `, + ); + + // violates @gt + await expect(db.foo.create({ data: { int1: 1 } })).toBeRejectedByValidation(); + + // violates @lt + await expect(db.foo.create({ data: { int1: 4 } })).toBeRejectedByValidation(); + + // violates @gte + await expect(db.foo.create({ data: { int2: 1 } })).toBeRejectedByValidation(); + + // violates @lte + await expect(db.foo.create({ data: { int2: 5 } })).toBeRejectedByValidation(); + + // satisfies all + await expect(db.foo.create({ data: { int1: 3, int2: 4 } })).toResolveTruthy(); + }); + + it('works with bigint fields', async () => { + const db = await createTestClient( + ` + model Foo { + id Int @id @default(autoincrement()) + int1 BigInt? @gt(2) @lt(4) + int2 BigInt? @gte(2) @lte(4) + } + `, + ); + + // violates @gt + await expect(db.foo.create({ data: { int1: 1 } })).toBeRejectedByValidation(); + + // violates @lt + await expect(db.foo.create({ data: { int1: 4 } })).toBeRejectedByValidation(); + + // violates @gte + await expect(db.foo.create({ data: { int2: 1n } })).toBeRejectedByValidation(); + + // violates @lte + await expect(db.foo.create({ data: { int2: 5n } })).toBeRejectedByValidation(); + + // satisfies all + await expect(db.foo.create({ data: { int1: 3, int2: 4 } })).toResolveTruthy(); + }); + + it('works with decimal fields', async () => { + const db = await createTestClient( + ` + model Foo { + id Int @id @default(autoincrement()) + int1 Decimal? @gt(2) @lt(4) + int2 Decimal? @gte(2) @lte(4) + } + `, + ); + + // violates @gt + await expect(db.foo.create({ data: { int1: 1 } })).toBeRejectedByValidation(); + + // violates @lt + await expect(db.foo.create({ data: { int1: new Decimal(4) } })).toBeRejectedByValidation(); + + // invalid decimal string + await expect(db.foo.create({ data: { int2: 'f1.2' } })).toBeRejectedByValidation(); + + // violates @gte + await expect(db.foo.create({ data: { int2: '1.1' } })).toBeRejectedByValidation(); + + // violates @lte + await expect(db.foo.create({ data: { int2: '5.12345678' } })).toBeRejectedByValidation(); + + // satisfies all + await expect(db.foo.create({ data: { int1: '3.3', int2: new Decimal(3.9) } })).toResolveTruthy(); + }); + + it('rejects accessing relation fields', async () => { + await loadSchemaWithError( + ` + model Foo { + id Int @id @default(autoincrement()) + bars Bar[] + @@validate(bars != null) + } + + model Bar { + id Int @id @default(autoincrement()) + foo Foo @relation(fields: [fooId], references: [id]) + fooId Int + } + `, + 'cannot use relation fields', + ); + + await loadSchemaWithError( + ` + model Foo { + id Int @id @default(autoincrement()) + bars Bar[] + @@validate(bars.fooId > 0) + } + + model Bar { + id Int @id @default(autoincrement()) + foo Foo @relation(fields: [fooId], references: [id]) + fooId Int + } + `, + 'cannot use relation fields', + ); + }); +});