Skip to content

Commit 3a49973

Browse files
authored
fix: tighten up query input validation, fixed case-sensitivity compatibility with Prisma (#147)
* fix: tighten up query input validation, fixed case-sensitivity compatibility with Prisma * update
1 parent 6a62a2c commit 3a49973

File tree

4 files changed

+343
-201
lines changed

4 files changed

+343
-201
lines changed

packages/runtime/src/client/crud-types.ts

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ export type WhereInput<
223223
: FieldIsArray<Schema, Model, Key> extends true
224224
? ArrayFilter<GetModelFieldType<Schema, Model, Key>>
225225
: // primitive
226-
PrimitiveFilter<GetModelFieldType<Schema, Model, Key>, ModelFieldIsOptional<Schema, Model, Key>>;
226+
PrimitiveFilter<Schema, GetModelFieldType<Schema, Model, Key>, ModelFieldIsOptional<Schema, Model, Key>>;
227227
} & {
228228
$expr?: (eb: ExpressionBuilder<ToKyselySchema<Schema>, Model>) => OperandExpression<SqlBool>;
229229
} & {
@@ -249,47 +249,52 @@ type ArrayFilter<T extends string> = {
249249
isEmpty?: boolean;
250250
};
251251

252-
type PrimitiveFilter<T extends string, Nullable extends boolean> = T extends 'String'
253-
? StringFilter<Nullable>
252+
type PrimitiveFilter<Schema extends SchemaDef, T extends string, Nullable extends boolean> = T extends 'String'
253+
? StringFilter<Schema, Nullable>
254254
: T extends 'Int' | 'Float' | 'Decimal' | 'BigInt'
255-
? NumberFilter<T, Nullable>
255+
? NumberFilter<Schema, T, Nullable>
256256
: T extends 'Boolean'
257257
? BooleanFilter<Nullable>
258258
: T extends 'DateTime'
259-
? DateTimeFilter<Nullable>
259+
? DateTimeFilter<Schema, Nullable>
260260
: T extends 'Bytes'
261261
? BytesFilter<Nullable>
262262
: T extends 'Json'
263263
? 'Not implemented yet' // TODO: Json filter
264264
: never;
265265

266-
type CommonPrimitiveFilter<DataType, T extends BuiltinType, Nullable extends boolean> = {
266+
type CommonPrimitiveFilter<Schema extends SchemaDef, DataType, T extends BuiltinType, Nullable extends boolean> = {
267267
equals?: NullableIf<DataType, Nullable>;
268268
in?: DataType[];
269269
notIn?: DataType[];
270270
lt?: DataType;
271271
lte?: DataType;
272272
gt?: DataType;
273273
gte?: DataType;
274-
not?: PrimitiveFilter<T, Nullable>;
274+
not?: PrimitiveFilter<Schema, T, Nullable>;
275275
};
276276

277-
export type StringFilter<Nullable extends boolean> =
277+
export type StringFilter<Schema extends SchemaDef, Nullable extends boolean> =
278278
| NullableIf<string, Nullable>
279-
| (CommonPrimitiveFilter<string, 'String', Nullable> & {
279+
| (CommonPrimitiveFilter<Schema, string, 'String', Nullable> & {
280280
contains?: string;
281281
startsWith?: string;
282282
endsWith?: string;
283-
mode?: 'default' | 'insensitive';
284-
});
283+
} & (ProviderSupportsCaseSensitivity<Schema> extends true
284+
? {
285+
mode?: 'default' | 'insensitive';
286+
}
287+
: {}));
285288

286-
export type NumberFilter<T extends 'Int' | 'Float' | 'Decimal' | 'BigInt', Nullable extends boolean> =
287-
| NullableIf<number | bigint, Nullable>
288-
| CommonPrimitiveFilter<number, T, Nullable>;
289+
export type NumberFilter<
290+
Schema extends SchemaDef,
291+
T extends 'Int' | 'Float' | 'Decimal' | 'BigInt',
292+
Nullable extends boolean,
293+
> = NullableIf<number | bigint, Nullable> | CommonPrimitiveFilter<Schema, number, T, Nullable>;
289294

290-
export type DateTimeFilter<Nullable extends boolean> =
295+
export type DateTimeFilter<Schema extends SchemaDef, Nullable extends boolean> =
291296
| NullableIf<Date | string, Nullable>
292-
| CommonPrimitiveFilter<Date | string, 'DateTime', Nullable>;
297+
| CommonPrimitiveFilter<Schema, Date | string, 'DateTime', Nullable>;
293298

294299
export type BytesFilter<Nullable extends boolean> =
295300
| NullableIf<Uint8Array | Buffer, Nullable>
@@ -1192,4 +1197,6 @@ type HasToManyRelations<Schema extends SchemaDef, Model extends GetModels<Schema
11921197
? false
11931198
: true;
11941199

1200+
type ProviderSupportsCaseSensitivity<Schema extends SchemaDef> = Schema['provider'] extends 'postgresql' ? true : false;
1201+
11951202
// #endregion

packages/runtime/src/client/crud/dialects/base.ts

Lines changed: 39 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
457457
recurse: (value: unknown) => Expression<SqlBool>,
458458
throwIfInvalid = false,
459459
onlyForKeys: string[] | undefined = undefined,
460+
excludeKeys: string[] = [],
460461
) {
461462
if (payload === null || !isPlainObject(payload)) {
462463
return {
@@ -472,6 +473,9 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
472473
if (onlyForKeys && !onlyForKeys.includes(op)) {
473474
continue;
474475
}
476+
if (excludeKeys.includes(op)) {
477+
continue;
478+
}
475479
const rhs = Array.isArray(value) ? value.map(getRhs) : getRhs(value);
476480
const condition = match(op)
477481
.with('equals', () => (rhs === null ? eb(lhs, 'is', null) : eb(lhs, '=', rhs)))
@@ -513,20 +517,23 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
513517
return { conditions, consumedKeys };
514518
}
515519

516-
private buildStringFilter(eb: ExpressionBuilder<any, any>, fieldRef: Expression<any>, payload: StringFilter<true>) {
517-
let insensitive = false;
518-
if (payload && typeof payload === 'object' && 'mode' in payload && payload.mode === 'insensitive') {
519-
insensitive = true;
520-
fieldRef = eb.fn('lower', [fieldRef]);
520+
private buildStringFilter(
521+
eb: ExpressionBuilder<any, any>,
522+
fieldRef: Expression<any>,
523+
payload: StringFilter<Schema, true>,
524+
) {
525+
let mode: 'default' | 'insensitive' | undefined;
526+
if (payload && typeof payload === 'object' && 'mode' in payload) {
527+
mode = payload.mode;
521528
}
522529

523530
const { conditions, consumedKeys } = this.buildStandardFilter(
524531
eb,
525532
'String',
526533
payload,
527-
fieldRef,
528-
(value) => this.prepStringCasing(eb, value, insensitive),
529-
(value) => this.buildStringFilter(eb, fieldRef, value as StringFilter<true>),
534+
mode === 'insensitive' ? eb.fn('lower', [fieldRef]) : fieldRef,
535+
(value) => this.prepStringCasing(eb, value, mode),
536+
(value) => this.buildStringFilter(eb, fieldRef, value as StringFilter<Schema, true>),
530537
);
531538

532539
if (payload && typeof payload === 'object') {
@@ -538,22 +545,22 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
538545

539546
const condition = match(key)
540547
.with('contains', () =>
541-
insensitive
542-
? eb(fieldRef, 'ilike', sql.lit(`%${value}%`))
543-
: eb(fieldRef, 'like', sql.lit(`%${value}%`)),
548+
mode === 'insensitive'
549+
? eb(fieldRef, 'ilike', sql.val(`%${value}%`))
550+
: eb(fieldRef, 'like', sql.val(`%${value}%`)),
544551
)
545552
.with('startsWith', () =>
546-
insensitive
547-
? eb(fieldRef, 'ilike', sql.lit(`${value}%`))
548-
: eb(fieldRef, 'like', sql.lit(`${value}%`)),
553+
mode === 'insensitive'
554+
? eb(fieldRef, 'ilike', sql.val(`${value}%`))
555+
: eb(fieldRef, 'like', sql.val(`${value}%`)),
549556
)
550557
.with('endsWith', () =>
551-
insensitive
552-
? eb(fieldRef, 'ilike', sql.lit(`%${value}`))
553-
: eb(fieldRef, 'like', sql.lit(`%${value}`)),
558+
mode === 'insensitive'
559+
? eb(fieldRef, 'ilike', sql.val(`%${value}`))
560+
: eb(fieldRef, 'like', sql.val(`%${value}`)),
554561
)
555562
.otherwise(() => {
556-
throw new Error(`Invalid string filter key: ${key}`);
563+
throw new QueryError(`Invalid string filter key: ${key}`);
557564
});
558565

559566
if (condition) {
@@ -565,13 +572,21 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
565572
return this.and(eb, ...conditions);
566573
}
567574

568-
private prepStringCasing(eb: ExpressionBuilder<any, any>, value: unknown, toLower: boolean = true): any {
575+
private prepStringCasing(
576+
eb: ExpressionBuilder<any, any>,
577+
value: unknown,
578+
mode: 'default' | 'insensitive' | undefined,
579+
): any {
580+
if (!mode || mode === 'default') {
581+
return value === null ? value : sql.val(value);
582+
}
583+
569584
if (typeof value === 'string') {
570-
return toLower ? eb.fn('lower', [sql.lit(value)]) : sql.lit(value);
585+
return eb.fn('lower', [sql.val(value)]);
571586
} else if (Array.isArray(value)) {
572-
return value.map((v) => this.prepStringCasing(eb, v, toLower));
587+
return value.map((v) => this.prepStringCasing(eb, v, mode));
573588
} else {
574-
return value === null ? null : sql.lit(value);
589+
return value === null ? null : sql.val(value);
575590
}
576591
}
577592

@@ -613,15 +628,15 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
613628
private buildDateTimeFilter(
614629
eb: ExpressionBuilder<any, any>,
615630
fieldRef: Expression<any>,
616-
payload: DateTimeFilter<true>,
631+
payload: DateTimeFilter<Schema, true>,
617632
) {
618633
const { conditions } = this.buildStandardFilter(
619634
eb,
620635
'DateTime',
621636
payload,
622637
fieldRef,
623638
(value) => this.transformPrimitive(value, 'DateTime', false),
624-
(value) => this.buildDateTimeFilter(eb, fieldRef, value as DateTimeFilter<true>),
639+
(value) => this.buildDateTimeFilter(eb, fieldRef, value as DateTimeFilter<Schema, true>),
625640
true,
626641
);
627642
return this.and(eb, ...conditions);

0 commit comments

Comments
 (0)