From 502da9412962292326df5b082294cdd6ff2c950f Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Wed, 6 Aug 2025 16:54:08 +0800 Subject: [PATCH 1/2] fix: tighten up query input validation, fixed case-sensitivity compatibility with Prisma --- packages/runtime/src/client/crud-types.ts | 39 ++- .../runtime/src/client/crud/dialects/base.ts | 61 ++-- packages/runtime/src/client/crud/validator.ts | 312 ++++++++++-------- .../runtime/test/client-api/filter.test.ts | 130 +++++++- 4 files changed, 342 insertions(+), 200 deletions(-) diff --git a/packages/runtime/src/client/crud-types.ts b/packages/runtime/src/client/crud-types.ts index 20fd4be9..4c4986b7 100644 --- a/packages/runtime/src/client/crud-types.ts +++ b/packages/runtime/src/client/crud-types.ts @@ -223,7 +223,7 @@ export type WhereInput< : FieldIsArray extends true ? ArrayFilter> : // primitive - PrimitiveFilter, ModelFieldIsOptional>; + PrimitiveFilter, ModelFieldIsOptional>; } & { $expr?: (eb: ExpressionBuilder, Model>) => OperandExpression; } & { @@ -249,21 +249,21 @@ type ArrayFilter = { isEmpty?: boolean; }; -type PrimitiveFilter = T extends 'String' - ? StringFilter +type PrimitiveFilter = T extends 'String' + ? StringFilter : T extends 'Int' | 'Float' | 'Decimal' | 'BigInt' - ? NumberFilter + ? NumberFilter : T extends 'Boolean' ? BooleanFilter : T extends 'DateTime' - ? DateTimeFilter + ? DateTimeFilter : T extends 'Bytes' ? BytesFilter : T extends 'Json' ? 'Not implemented yet' // TODO: Json filter : never; -type CommonPrimitiveFilter = { +type CommonPrimitiveFilter = { equals?: NullableIf; in?: DataType[]; notIn?: DataType[]; @@ -271,25 +271,30 @@ type CommonPrimitiveFilter; + not?: PrimitiveFilter; }; -export type StringFilter = +export type StringFilter = | NullableIf - | (CommonPrimitiveFilter & { + | (CommonPrimitiveFilter & { contains?: string; startsWith?: string; endsWith?: string; - mode?: 'default' | 'insensitive'; - }); + } & (ProviderSupportsCaseSensitivity extends true + ? { + mode?: 'default' | 'insensitive'; + } + : {})); -export type NumberFilter = - | NullableIf - | CommonPrimitiveFilter; +export type NumberFilter< + Schema extends SchemaDef, + T extends 'Int' | 'Float' | 'Decimal' | 'BigInt', + Nullable extends boolean, +> = NullableIf | CommonPrimitiveFilter; -export type DateTimeFilter = +export type DateTimeFilter = | NullableIf - | CommonPrimitiveFilter; + | CommonPrimitiveFilter; export type BytesFilter = | NullableIf @@ -1192,4 +1197,6 @@ type HasToManyRelations = Schema['provider'] extends 'postgresql' ? true : false; + // #endregion diff --git a/packages/runtime/src/client/crud/dialects/base.ts b/packages/runtime/src/client/crud/dialects/base.ts index c1bc7660..ee316346 100644 --- a/packages/runtime/src/client/crud/dialects/base.ts +++ b/packages/runtime/src/client/crud/dialects/base.ts @@ -457,6 +457,7 @@ export abstract class BaseCrudDialect { recurse: (value: unknown) => Expression, throwIfInvalid = false, onlyForKeys: string[] | undefined = undefined, + excludeKeys: string[] = [], ) { if (payload === null || !isPlainObject(payload)) { return { @@ -472,6 +473,9 @@ export abstract class BaseCrudDialect { if (onlyForKeys && !onlyForKeys.includes(op)) { continue; } + if (excludeKeys.includes(op)) { + continue; + } const rhs = Array.isArray(value) ? value.map(getRhs) : getRhs(value); const condition = match(op) .with('equals', () => (rhs === null ? eb(lhs, 'is', null) : eb(lhs, '=', rhs))) @@ -513,20 +517,23 @@ export abstract class BaseCrudDialect { return { conditions, consumedKeys }; } - private buildStringFilter(eb: ExpressionBuilder, fieldRef: Expression, payload: StringFilter) { - let insensitive = false; - if (payload && typeof payload === 'object' && 'mode' in payload && payload.mode === 'insensitive') { - insensitive = true; - fieldRef = eb.fn('lower', [fieldRef]); + private buildStringFilter( + eb: ExpressionBuilder, + fieldRef: Expression, + payload: StringFilter, + ) { + let mode: 'default' | 'insensitive' | undefined; + if (payload && typeof payload === 'object' && 'mode' in payload) { + mode = payload.mode; } const { conditions, consumedKeys } = this.buildStandardFilter( eb, 'String', payload, - fieldRef, - (value) => this.prepStringCasing(eb, value, insensitive), - (value) => this.buildStringFilter(eb, fieldRef, value as StringFilter), + mode === 'insensitive' ? eb.fn('lower', [fieldRef]) : fieldRef, + (value) => this.prepStringCasing(eb, value, mode), + (value) => this.buildStringFilter(eb, fieldRef, value as StringFilter), ); if (payload && typeof payload === 'object') { @@ -538,19 +545,19 @@ export abstract class BaseCrudDialect { const condition = match(key) .with('contains', () => - insensitive - ? eb(fieldRef, 'ilike', sql.lit(`%${value}%`)) - : eb(fieldRef, 'like', sql.lit(`%${value}%`)), + mode === 'insensitive' + ? eb(fieldRef, 'ilike', sql.val(`%${value}%`)) + : eb(fieldRef, 'like', sql.val(`%${value}%`)), ) .with('startsWith', () => - insensitive - ? eb(fieldRef, 'ilike', sql.lit(`${value}%`)) - : eb(fieldRef, 'like', sql.lit(`${value}%`)), + mode === 'insensitive' + ? eb(fieldRef, 'ilike', sql.val(`${value}%`)) + : eb(fieldRef, 'like', sql.val(`${value}%`)), ) .with('endsWith', () => - insensitive - ? eb(fieldRef, 'ilike', sql.lit(`%${value}`)) - : eb(fieldRef, 'like', sql.lit(`%${value}`)), + mode === 'insensitive' + ? eb(fieldRef, 'ilike', sql.val(`%${value}`)) + : eb(fieldRef, 'like', sql.val(`%${value}`)), ) .otherwise(() => { throw new Error(`Invalid string filter key: ${key}`); @@ -565,13 +572,21 @@ export abstract class BaseCrudDialect { return this.and(eb, ...conditions); } - private prepStringCasing(eb: ExpressionBuilder, value: unknown, toLower: boolean = true): any { + private prepStringCasing( + eb: ExpressionBuilder, + value: unknown, + mode: 'default' | 'insensitive' | undefined, + ): any { + if (!mode || mode === 'default') { + return value === null ? value : sql.val(value); + } + if (typeof value === 'string') { - return toLower ? eb.fn('lower', [sql.lit(value)]) : sql.lit(value); + return eb.fn('lower', [sql.val(value)]); } else if (Array.isArray(value)) { - return value.map((v) => this.prepStringCasing(eb, v, toLower)); + return value.map((v) => this.prepStringCasing(eb, v, mode)); } else { - return value === null ? null : sql.lit(value); + return value === null ? null : sql.val(value); } } @@ -613,7 +628,7 @@ export abstract class BaseCrudDialect { private buildDateTimeFilter( eb: ExpressionBuilder, fieldRef: Expression, - payload: DateTimeFilter, + payload: DateTimeFilter, ) { const { conditions } = this.buildStandardFilter( eb, @@ -621,7 +636,7 @@ export abstract class BaseCrudDialect { payload, fieldRef, (value) => this.transformPrimitive(value, 'DateTime', false), - (value) => this.buildDateTimeFilter(eb, fieldRef, value as DateTimeFilter), + (value) => this.buildDateTimeFilter(eb, fieldRef, value as DateTimeFilter), true, ); return this.and(eb, ...conditions); diff --git a/packages/runtime/src/client/crud/validator.ts b/packages/runtime/src/client/crud/validator.ts index 32ab09cb..c586abfd 100644 --- a/packages/runtime/src/client/crud/validator.ts +++ b/packages/runtime/src/client/crud/validator.ts @@ -210,12 +210,12 @@ export class InputValidator { fields['cursor'] = this.makeCursorSchema(model).optional(); if (options.collection) { - fields['skip'] = z.number().int().nonnegative().optional(); - fields['take'] = z.number().int().optional(); + fields['skip'] = this.makeSkipSchema().optional(); + fields['take'] = this.makeTakeSchema().optional(); fields['orderBy'] = this.orArray(this.makeOrderBySchema(model, true, false), true).optional(); } - let result: ZodType = z.object(fields).strict(); + let result: ZodType = z.strictObject(fields); result = this.refineForSelectIncludeMutuallyExclusive(result); result = this.refineForSelectOmitMutuallyExclusive(result); @@ -292,7 +292,7 @@ export class InputValidator { // to-many relation fieldSchema = z.union([ fieldSchema, - z.object({ + z.strictObject({ some: fieldSchema.optional(), every: fieldSchema.optional(), none: fieldSchema.optional(), @@ -302,7 +302,7 @@ export class InputValidator { // to-one relation fieldSchema = z.union([ fieldSchema, - z.object({ + z.strictObject({ is: fieldSchema.optional(), isNot: fieldSchema.optional(), }), @@ -381,7 +381,7 @@ export class InputValidator { true, ).optional(); - const baseWhere = z.object(fields).strict(); + const baseWhere = z.strictObject(fields); let result: ZodType = baseWhere; if (unique) { @@ -414,7 +414,7 @@ export class InputValidator { ); return z.union([ this.nullableIf(baseSchema, optional), - z.object({ + z.strictObject({ equals: components.equals, in: components.in, notIn: components.notIn, @@ -424,7 +424,7 @@ export class InputValidator { } private makeArrayFilterSchema(type: BuiltinType) { - return z.object({ + return z.strictObject({ equals: this.makePrimitiveSchema(type).array().optional(), has: this.makePrimitiveSchema(type).optional(), hasEvery: this.makePrimitiveSchema(type).array().optional(), @@ -468,7 +468,7 @@ export class InputValidator { private makeBooleanFilterSchema(optional: boolean): ZodType { return z.union([ this.nullableIf(z.boolean(), optional), - z.object({ + z.strictObject({ equals: this.nullableIf(z.boolean(), optional).optional(), not: z.lazy(() => this.makeBooleanFilterSchema(optional)).optional(), }), @@ -482,7 +482,7 @@ export class InputValidator { ); return z.union([ this.nullableIf(baseSchema, optional), - z.object({ + z.strictObject({ equals: components.equals, in: components.in, notIn: components.notIn, @@ -508,7 +508,7 @@ export class InputValidator { private makeCommonPrimitiveFilterSchema(baseSchema: ZodType, optional: boolean, makeThis: () => ZodType) { return z.union([ this.nullableIf(baseSchema, optional), - z.object(this.makeCommonPrimitiveFilterComponents(baseSchema, optional, makeThis)), + z.strictObject(this.makeCommonPrimitiveFilterComponents(baseSchema, optional, makeThis)), ]); } @@ -519,9 +519,26 @@ export class InputValidator { } private makeStringFilterSchema(optional: boolean): ZodType { - return this.makeCommonPrimitiveFilterSchema(z.string(), optional, () => - z.lazy(() => this.makeStringFilterSchema(optional)), - ); + return z.union([ + this.nullableIf(z.string(), optional), + z.strictObject({ + ...this.makeCommonPrimitiveFilterComponents(z.string(), optional, () => + z.lazy(() => this.makeStringFilterSchema(optional)), + ), + startsWith: z.string().optional(), + endsWith: z.string().optional(), + contains: z.string().optional(), + ...(this.providerSupportsCaseSensitivity + ? { + mode: this.makeStringModeSchema().optional(), + } + : {}), + }), + ]); + } + + private makeStringModeSchema() { + return z.union([z.literal('default'), z.literal('insensitive')]); } private makeSelectSchema(model: string) { @@ -533,7 +550,7 @@ export class InputValidator { fields[field] = z .union([ z.literal(true), - z.object({ + z.strictObject({ select: z.lazy(() => this.makeSelectSchema(fieldDef.type)).optional(), include: z.lazy(() => this.makeIncludeSchema(fieldDef.type)).optional(), }), @@ -550,27 +567,29 @@ export class InputValidator { fields['_count'] = z .union([ z.literal(true), - z.object( - toManyRelations.reduce( - (acc, fieldDef) => ({ - ...acc, - [fieldDef.name]: z - .union([ - z.boolean(), - z.object({ - where: this.makeWhereSchema(fieldDef.type, false, false), - }), - ]) - .optional(), - }), - {} as Record, + z.strictObject({ + select: z.strictObject( + toManyRelations.reduce( + (acc, fieldDef) => ({ + ...acc, + [fieldDef.name]: z + .union([ + z.boolean(), + z.strictObject({ + where: this.makeWhereSchema(fieldDef.type, false, false), + }), + ]) + .optional(), + }), + {} as Record, + ), ), - ), + }), ]) .optional(); } - return z.object(fields).strict(); + return z.strictObject(fields); } private makeOmitSchema(model: string) { @@ -582,7 +601,7 @@ export class InputValidator { fields[field] = z.boolean().optional(); } } - return z.object(fields).strict(); + return z.strictObject(fields); } private makeIncludeSchema(model: string) { @@ -594,17 +613,22 @@ export class InputValidator { fields[field] = z .union([ z.literal(true), - z.object({ + z.strictObject({ select: z.lazy(() => this.makeSelectSchema(fieldDef.type)).optional(), include: z.lazy(() => this.makeIncludeSchema(fieldDef.type)).optional(), + omit: z.lazy(() => this.makeOmitSchema(fieldDef.type)).optional(), where: z.lazy(() => this.makeWhereSchema(fieldDef.type, false)).optional(), + orderBy: z.lazy(() => this.makeOrderBySchema(fieldDef.type, true, false)).optional(), + skip: this.makeSkipSchema().optional(), + take: this.makeTakeSchema().optional(), + distinct: this.makeDistinctSchema(fieldDef.type).optional(), }), ]) .optional(); } } - return z.object(fields).strict(); + return z.strictObject(fields); } private makeOrderBySchema(model: string, withRelation: boolean, WithAggregation: boolean) { @@ -616,9 +640,15 @@ export class InputValidator { if (fieldDef.relation) { // relations if (withRelation) { - fields[field] = z.lazy(() => - this.makeOrderBySchema(fieldDef.type, withRelation, WithAggregation).optional(), - ); + fields[field] = z.lazy(() => { + let relationOrderBy = this.makeOrderBySchema(fieldDef.type, withRelation, WithAggregation); + if (fieldDef.array) { + relationOrderBy = relationOrderBy.extend({ + _count: sort, + }); + } + return relationOrderBy.optional(); + }); } } else { // scalars @@ -626,7 +656,7 @@ export class InputValidator { fields[field] = z .union([ sort, - z.object({ + z.strictObject({ sort, nulls: z.union([z.literal('first'), z.literal('last')]), }), @@ -646,7 +676,7 @@ export class InputValidator { } } - return z.object(fields); + return z.strictObject(fields); } private makeDistinctSchema(model: string) { @@ -665,14 +695,12 @@ export class InputValidator { private makeCreateSchema(model: string) { const dataSchema = this.makeCreateDataSchema(model, false); - const schema = z - .object({ - data: dataSchema, - select: this.makeSelectSchema(model).optional(), - include: this.makeIncludeSchema(model).optional(), - omit: this.makeOmitSchema(model).optional(), - }) - .strict(); + const schema = z.object({ + data: dataSchema, + select: this.makeSelectSchema(model).optional(), + include: this.makeIncludeSchema(model).optional(), + omit: this.makeOmitSchema(model).optional(), + }); return this.refineForSelectIncludeMutuallyExclusive(schema); } @@ -683,7 +711,7 @@ export class InputValidator { private makeCreateManyAndReturnSchema(model: string) { const base = this.makeCreateManyDataSchema(model, []); const result = base.merge( - z.object({ + z.strictObject({ select: this.makeSelectSchema(model).optional(), omit: this.makeOmitSchema(model).optional(), }), @@ -769,7 +797,7 @@ export class InputValidator { fieldSchema = z .union([ z.array(fieldSchema), - z.object({ + z.strictObject({ set: z.array(fieldSchema), }), ]) @@ -793,13 +821,13 @@ export class InputValidator { }); if (!hasRelation) { - return this.orArray(z.object(uncheckedVariantFields).strict(), canBeArray); + return this.orArray(z.strictObject(uncheckedVariantFields), canBeArray); } else { return z.union([ - z.object(uncheckedVariantFields).strict(), - z.object(checkedVariantFields).strict(), - ...(canBeArray ? [z.array(z.object(uncheckedVariantFields).strict())] : []), - ...(canBeArray ? [z.array(z.object(checkedVariantFields).strict())] : []), + z.strictObject(uncheckedVariantFields), + z.strictObject(checkedVariantFields), + ...(canBeArray ? [z.array(z.strictObject(uncheckedVariantFields))] : []), + ...(canBeArray ? [z.array(z.strictObject(checkedVariantFields))] : []), ]); } } @@ -838,7 +866,7 @@ export class InputValidator { fields['update'] = array ? this.orArray( - z.object({ + z.strictObject({ where: this.makeWhereSchema(fieldType, true), data: this.makeUpdateDataSchema(fieldType, withoutFields), }), @@ -846,7 +874,7 @@ export class InputValidator { ).optional() : z .union([ - z.object({ + z.strictObject({ where: this.makeWhereSchema(fieldType, true), data: this.makeUpdateDataSchema(fieldType, withoutFields), }), @@ -855,7 +883,7 @@ export class InputValidator { .optional(); fields['upsert'] = this.orArray( - z.object({ + z.strictObject({ where: this.makeWhereSchema(fieldType, true), create: this.makeCreateDataSchema(fieldType, false, withoutFields), update: this.makeUpdateDataSchema(fieldType, withoutFields), @@ -868,7 +896,7 @@ export class InputValidator { fields['set'] = this.makeSetDataSchema(fieldType, true).optional(); fields['updateMany'] = this.orArray( - z.object({ + z.strictObject({ where: this.makeWhereSchema(fieldType, false, true), data: this.makeUpdateDataSchema(fieldType, withoutFields), }), @@ -879,7 +907,7 @@ export class InputValidator { } } - return z.object(fields).strict(); + return z.strictObject(fields); } private makeSetDataSchema(model: string, canBeArray: boolean) { @@ -911,23 +939,19 @@ export class InputValidator { const whereSchema = this.makeWhereSchema(model, true); const createSchema = this.makeCreateDataSchema(model, false, withoutFields); return this.orArray( - z - .object({ - where: whereSchema, - create: createSchema, - }) - .strict(), + z.object({ + where: whereSchema, + create: createSchema, + }), canBeArray, ); } private makeCreateManyDataSchema(model: string, withoutFields: string[]) { - return z - .object({ - data: this.makeCreateDataSchema(model, true, withoutFields, true), - skipDuplicates: z.boolean().optional(), - }) - .strict(); + return z.object({ + data: this.makeCreateDataSchema(model, true, withoutFields, true), + skipDuplicates: z.boolean().optional(), + }); } // #endregion @@ -935,33 +959,28 @@ export class InputValidator { // #region Update private makeUpdateSchema(model: string) { - const schema = z - .object({ - where: this.makeWhereSchema(model, true), - data: this.makeUpdateDataSchema(model), - select: this.makeSelectSchema(model).optional(), - include: this.makeIncludeSchema(model).optional(), - omit: this.makeOmitSchema(model).optional(), - }) - .strict(); - + const schema = z.object({ + where: this.makeWhereSchema(model, true), + data: this.makeUpdateDataSchema(model), + select: this.makeSelectSchema(model).optional(), + include: this.makeIncludeSchema(model).optional(), + omit: this.makeOmitSchema(model).optional(), + }); return this.refineForSelectIncludeMutuallyExclusive(schema); } private makeUpdateManySchema(model: string) { - return z - .object({ - where: this.makeWhereSchema(model, false).optional(), - data: this.makeUpdateDataSchema(model, [], true), - limit: z.number().int().nonnegative().optional(), - }) - .strict(); + return z.object({ + where: this.makeWhereSchema(model, false).optional(), + data: this.makeUpdateDataSchema(model, [], true), + limit: z.number().int().nonnegative().optional(), + }); } private makeUpdateManyAndReturnSchema(model: string) { const base = this.makeUpdateManySchema(model); const result = base.merge( - z.object({ + z.strictObject({ select: this.makeSelectSchema(model).optional(), omit: this.makeOmitSchema(model).optional(), }), @@ -970,17 +989,14 @@ export class InputValidator { } private makeUpsertSchema(model: string) { - const schema = z - .object({ - where: this.makeWhereSchema(model, true), - create: this.makeCreateDataSchema(model, false), - update: this.makeUpdateDataSchema(model), - select: this.makeSelectSchema(model).optional(), - include: this.makeIncludeSchema(model).optional(), - omit: this.makeOmitSchema(model).optional(), - }) - .strict(); - + const schema = z.object({ + where: this.makeWhereSchema(model, true), + create: this.makeCreateDataSchema(model, false), + update: this.makeUpdateDataSchema(model), + select: this.makeSelectSchema(model).optional(), + include: this.makeIncludeSchema(model).optional(), + omit: this.makeOmitSchema(model).optional(), + }); return this.refineForSelectIncludeMutuallyExclusive(schema); } @@ -1074,9 +1090,9 @@ export class InputValidator { }); if (!hasRelation) { - return z.object(uncheckedVariantFields).strict(); + return z.strictObject(uncheckedVariantFields); } else { - return z.union([z.object(uncheckedVariantFields).strict(), z.object(checkedVariantFields).strict()]); + return z.union([z.strictObject(uncheckedVariantFields), z.strictObject(checkedVariantFields)]); } } @@ -1085,13 +1101,11 @@ export class InputValidator { // #region Delete private makeDeleteSchema(model: GetModels) { - const schema = z - .object({ - where: this.makeWhereSchema(model, true), - select: this.makeSelectSchema(model).optional(), - include: this.makeIncludeSchema(model).optional(), - }) - .strict(); + const schema = z.object({ + where: this.makeWhereSchema(model, true), + select: this.makeSelectSchema(model).optional(), + include: this.makeIncludeSchema(model).optional(), + }); return this.refineForSelectIncludeMutuallyExclusive(schema); } @@ -1101,7 +1115,7 @@ export class InputValidator { where: this.makeWhereSchema(model, false).optional(), limit: z.number().int().nonnegative().optional(), }) - .strict() + .optional(); } @@ -1113,12 +1127,12 @@ export class InputValidator { return z .object({ where: this.makeWhereSchema(model, false).optional(), - skip: z.number().int().nonnegative().optional(), - take: z.number().int().optional(), + skip: this.makeSkipSchema().optional(), + take: this.makeTakeSchema().optional(), orderBy: this.orArray(this.makeOrderBySchema(model, true, false), true).optional(), select: this.makeCountAggregateInputSchema(model).optional(), }) - .strict() + .optional(); } @@ -1126,18 +1140,16 @@ export class InputValidator { const modelDef = requireModel(this.schema, model); return z.union([ z.literal(true), - z - .object({ - _all: z.literal(true).optional(), - ...Object.keys(modelDef.fields).reduce( - (acc, field) => { - acc[field] = z.literal(true).optional(); - return acc; - }, - {} as Record, - ), - }) - .strict(), + z.object({ + _all: z.literal(true).optional(), + ...Object.keys(modelDef.fields).reduce( + (acc, field) => { + acc[field] = z.literal(true).optional(); + return acc; + }, + {} as Record, + ), + }), ]); } @@ -1149,8 +1161,8 @@ export class InputValidator { return z .object({ where: this.makeWhereSchema(model, false).optional(), - skip: z.number().int().nonnegative().optional(), - take: z.number().int().optional(), + skip: this.makeSkipSchema().optional(), + take: this.makeTakeSchema().optional(), orderBy: this.orArray(this.makeOrderBySchema(model, true, false), true).optional(), _count: this.makeCountAggregateInputSchema(model).optional(), _avg: this.makeSumAvgInputSchema(model).optional(), @@ -1158,13 +1170,13 @@ export class InputValidator { _min: this.makeMinMaxInputSchema(model).optional(), _max: this.makeMinMaxInputSchema(model).optional(), }) - .strict() + .optional(); } makeSumAvgInputSchema(model: GetModels) { const modelDef = requireModel(this.schema, model); - return z.object( + return z.strictObject( Object.keys(modelDef.fields).reduce( (acc, field) => { const fieldDef = requireField(this.schema, model, field); @@ -1180,7 +1192,7 @@ export class InputValidator { makeMinMaxInputSchema(model: GetModels) { const modelDef = requireModel(this.schema, model); - return z.object( + return z.strictObject( Object.keys(modelDef.fields).reduce( (acc, field) => { const fieldDef = requireField(this.schema, model, field); @@ -1198,22 +1210,19 @@ export class InputValidator { const modelDef = requireModel(this.schema, model); const nonRelationFields = Object.keys(modelDef.fields).filter((field) => !modelDef.fields[field]?.relation); - let schema = z - .object({ - where: this.makeWhereSchema(model, false).optional(), - orderBy: this.orArray(this.makeOrderBySchema(model, false, true), true).optional(), - by: this.orArray(z.enum(nonRelationFields), true), - having: this.makeWhereSchema(model, false, true).optional(), - skip: z.number().int().nonnegative().optional(), - take: z.number().int().optional(), - _count: this.makeCountAggregateInputSchema(model).optional(), - _avg: this.makeSumAvgInputSchema(model).optional(), - _sum: this.makeSumAvgInputSchema(model).optional(), - _min: this.makeMinMaxInputSchema(model).optional(), - _max: this.makeMinMaxInputSchema(model).optional(), - }) - .strict(); - + let schema = z.object({ + where: this.makeWhereSchema(model, false).optional(), + orderBy: this.orArray(this.makeOrderBySchema(model, false, true), true).optional(), + by: this.orArray(z.enum(nonRelationFields), true), + having: this.makeWhereSchema(model, false, true).optional(), + skip: this.makeSkipSchema().optional(), + take: this.makeTakeSchema().optional(), + _count: this.makeCountAggregateInputSchema(model).optional(), + _avg: this.makeSumAvgInputSchema(model).optional(), + _sum: this.makeSumAvgInputSchema(model).optional(), + _min: this.makeMinMaxInputSchema(model).optional(), + _max: this.makeMinMaxInputSchema(model).optional(), + }); schema = schema.refine((value) => { const bys = typeof value.by === 'string' ? [value.by] : value.by; if ( @@ -1249,6 +1258,14 @@ export class InputValidator { // #region Helpers + private makeSkipSchema() { + return z.number().int().nonnegative(); + } + + private makeTakeSchema() { + return z.number().int(); + } + private refineForSelectIncludeMutuallyExclusive(schema: ZodType) { return schema.refine( (value: any) => !(value['select'] && value['include']), @@ -1275,5 +1292,8 @@ export class InputValidator { return NUMERIC_FIELD_TYPES.includes(fieldDef.type) && !fieldDef.array; } + private get providerSupportsCaseSensitivity() { + return this.schema.provider.type === 'postgresql'; + } // #endregion } diff --git a/packages/runtime/test/client-api/filter.test.ts b/packages/runtime/test/client-api/filter.test.ts index e9f49ce1..b7ec82af 100644 --- a/packages/runtime/test/client-api/filter.test.ts +++ b/packages/runtime/test/client-api/filter.test.ts @@ -5,7 +5,7 @@ import { createClientSpecs } from './client-specs'; const PG_DB_NAME = 'client-api-filter-tests'; -describe.each(createClientSpecs(PG_DB_NAME))('Client filter tests for $provider', ({ createClient }) => { +describe.each(createClientSpecs(PG_DB_NAME))('Client filter tests for $provider', ({ createClient, provider }) => { let client: ClientContract; beforeEach(async () => { @@ -76,19 +76,117 @@ describe.each(createClientSpecs(PG_DB_NAME))('Client filter tests for $provider' }), ).toResolveTruthy(); - // case-insensitive - await expect( - client.user.findFirst({ - where: { email: { equals: 'u1@Test.com' } }, - }), - ).toResolveFalsy(); - await expect( - client.user.findFirst({ - where: { - email: { equals: 'u1@Test.com', mode: 'insensitive' }, - }, - }), - ).toResolveTruthy(); + if (provider === 'sqlite') { + // sqlite: equalities are case-sensitive, match is case-insensitive + await expect( + client.user.findFirst({ + where: { email: { equals: 'u1@Test.com' } }, + }), + ).toResolveFalsy(); + + await expect( + client.user.findFirst({ + where: { email: { equals: 'u1@test.com' } }, + }), + ).toResolveTruthy(); + + await expect( + client.user.findFirst({ + where: { email: { contains: 'test' } }, + }), + ).toResolveTruthy(); + await expect( + client.user.findFirst({ + where: { email: { contains: 'Test' } }, + }), + ).toResolveTruthy(); + + await expect( + client.user.findFirst({ + where: { email: { startsWith: 'u1' } }, + }), + ).toResolveTruthy(); + await expect( + client.user.findFirst({ + where: { email: { startsWith: 'U1' } }, + }), + ).toResolveTruthy(); + + await expect( + client.user.findFirst({ + where: { + email: { in: ['u1@Test.com'] }, + }, + }), + ).toResolveFalsy(); + await expect( + client.user.findFirst({ + where: { + email: { in: ['u1@test.com'] }, + }, + }), + ).toResolveTruthy(); + } else if (provider === 'postgresql') { + // postgresql: default is case-sensitive, but can be toggled with "mode" + + await expect( + client.user.findFirst({ + where: { email: { equals: 'u1@Test.com' } }, + }), + ).toResolveFalsy(); + await expect( + client.user.findFirst({ + where: { + email: { equals: 'u1@Test.com', mode: 'insensitive' } as any, + }, + }), + ).toResolveTruthy(); + + await expect( + client.user.findFirst({ + where: { + email: { contains: 'u1@Test.com' }, + }, + }), + ).toResolveFalsy(); + await expect( + client.user.findFirst({ + where: { + email: { contains: 'u1@Test.com', mode: 'insensitive' } as any, + }, + }), + ).toResolveTruthy(); + + await expect( + client.user.findFirst({ + where: { + email: { endsWith: 'Test.com' }, + }, + }), + ).toResolveFalsy(); + await expect( + client.user.findFirst({ + where: { + email: { endsWith: 'Test.com', mode: 'insensitive' } as any, + }, + }), + ).toResolveTruthy(); + + await expect( + client.user.findFirst({ + where: { + email: { in: ['u1@Test.com'] }, + }, + }), + ).toResolveFalsy(); + await expect( + client.user.findFirst({ + where: { + email: { in: ['u1@Test.com'], mode: 'insensitive' } as any, + }, + }), + ).toResolveTruthy(); + } // in await expect( @@ -225,7 +323,9 @@ describe.each(createClientSpecs(PG_DB_NAME))('Client filter tests for $provider' // equals await expect(client.profile.findFirst({ where: { age: 20 } })).resolves.toMatchObject({ id: '1' }); - await expect(client.profile.findFirst({ where: { age: { equals: 20 } } })).resolves.toMatchObject({ id: '1' }); + await expect(client.profile.findFirst({ where: { age: { equals: 20 } } })).resolves.toMatchObject({ + id: '1', + }); await expect(client.profile.findFirst({ where: { age: { equals: 10 } } })).toResolveFalsy(); await expect(client.profile.findFirst({ where: { age: null } })).resolves.toMatchObject({ id: '2' }); await expect(client.profile.findFirst({ where: { age: { equals: null } } })).resolves.toMatchObject({ From f43b2eac572064418acb6c028cffb5e02ad9a3ae Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Wed, 6 Aug 2025 16:57:45 +0800 Subject: [PATCH 2/2] update --- packages/runtime/src/client/crud/dialects/base.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/runtime/src/client/crud/dialects/base.ts b/packages/runtime/src/client/crud/dialects/base.ts index ee316346..d6bb705e 100644 --- a/packages/runtime/src/client/crud/dialects/base.ts +++ b/packages/runtime/src/client/crud/dialects/base.ts @@ -560,7 +560,7 @@ export abstract class BaseCrudDialect { : eb(fieldRef, 'like', sql.val(`%${value}`)), ) .otherwise(() => { - throw new Error(`Invalid string filter key: ${key}`); + throw new QueryError(`Invalid string filter key: ${key}`); }); if (condition) {