Skip to content

Commit 080114b

Browse files
authored
fix(orm): properly cast values for array filters for postgres (#662)
* fix(orm): properly cast values for array filters for postgres fixes #651 * add missing Array conversion * refactor
1 parent 274871b commit 080114b

File tree

6 files changed

+128
-5
lines changed

6 files changed

+128
-5
lines changed

packages/orm/src/client/crud/dialects/base-dialect.ts

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -513,7 +513,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
513513
}
514514

515515
case 'has': {
516-
clauses.push(this.buildArrayContains(receiver, this.eb.val(value)));
516+
clauses.push(this.buildArrayContains(receiver, this.eb.val(value), fieldType));
517517
break;
518518
}
519519

@@ -1442,7 +1442,11 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
14421442
/**
14431443
* Builds an expression that checks if an array contains a single value.
14441444
*/
1445-
abstract buildArrayContains(field: Expression<unknown>, value: Expression<unknown>): AliasableExpression<SqlBool>;
1445+
abstract buildArrayContains(
1446+
field: Expression<unknown>,
1447+
value: Expression<unknown>,
1448+
elemType?: string,
1449+
): AliasableExpression<SqlBool>;
14461450

14471451
/**
14481452
* Builds an expression that checks if an array contains all values from another array.

packages/orm/src/client/crud/dialects/mysql.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,7 @@ export class MySqlCrudDialect<Schema extends SchemaDef> extends LateralJoinDiale
231231
override buildArrayContains(
232232
_field: Expression<unknown>,
233233
_value: Expression<unknown>,
234+
_elemType?: string,
234235
): AliasableExpression<SqlBool> {
235236
throw createNotSupportedError('MySQL does not support native array operations');
236237
}

packages/orm/src/client/crud/dialects/postgresql.ts

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -284,9 +284,20 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends LateralJoinDi
284284
return this.eb.cast(arr, sql`${sql.raw(mappedType)}[]`);
285285
}
286286

287-
override buildArrayContains(field: Expression<unknown>, value: Expression<unknown>): AliasableExpression<SqlBool> {
288-
// PostgreSQL @> operator expects array on both sides, so wrap single value in array
289-
return this.eb(field, '@>', sql`ARRAY[${value}]`);
287+
override buildArrayContains(
288+
field: Expression<unknown>,
289+
value: Expression<unknown>,
290+
elemType?: string,
291+
): AliasableExpression<SqlBool> {
292+
// PostgreSQL @> operator expects array on both sides, so wrap single value in a typed array
293+
const arrayExpr = sql`ARRAY[${value}]`;
294+
if (elemType) {
295+
const mappedType = this.getSqlType(elemType);
296+
const typedArray = this.eb.cast(arrayExpr, sql`${sql.raw(mappedType)}[]`);
297+
return this.eb(field, '@>', typedArray);
298+
} else {
299+
return this.eb(field, '@>', arrayExpr);
300+
}
290301
}
291302

292303
override buildArrayHasEvery(field: Expression<unknown>, values: Expression<unknown>): AliasableExpression<SqlBool> {

packages/orm/src/client/crud/dialects/sqlite.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,7 @@ export class SqliteCrudDialect<Schema extends SchemaDef> extends BaseCrudDialect
456456
override buildArrayContains(
457457
_field: Expression<unknown>,
458458
_value: Expression<unknown>,
459+
_elemType?: string,
459460
): AliasableExpression<SqlBool> {
460461
throw createNotSupportedError('SQLite does not support native array operations');
461462
}

packages/orm/src/client/helpers/schema-db-pusher.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,7 @@ export class SchemaDbPusher<Schema extends SchemaDef> {
440440

441441
private get floatType() {
442442
return match<string, ColumnDataType | RawBuilder<unknown>>(this.schema.provider.type)
443+
.with('postgresql', () => 'double precision')
443444
.with('mysql', () => sql.raw('double'))
444445
.otherwise(() => 'real');
445446
}
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
import { createTestClient } from '@zenstackhq/testtools';
2+
import { describe, expect, it } from 'vitest';
3+
4+
describe('Regression for issue 651', () => {
5+
it('float array queries should work with all operators on PostgreSQL', async () => {
6+
const db = await createTestClient(
7+
`
8+
model User {
9+
id Int @id @default(autoincrement())
10+
email String @unique
11+
floatArray Float[]
12+
}
13+
`,
14+
{ provider: 'postgresql', usePrismaPush: true },
15+
);
16+
17+
// Create test users with different float arrays
18+
const user1 = await db.user.create({
19+
data: {
20+
email: 'user1@example.com',
21+
floatArray: [1.1, 2.2, 3.3],
22+
},
23+
});
24+
25+
const user2 = await db.user.create({
26+
data: {
27+
email: 'user2@example.com',
28+
floatArray: [1.1, 2.2, 3.3, 4.4, 5.5],
29+
},
30+
});
31+
32+
const user3 = await db.user.create({
33+
data: {
34+
email: 'user3@example.com',
35+
floatArray: [],
36+
},
37+
});
38+
39+
// Test 'equals' operator
40+
const equalsResult = await db.user.findMany({
41+
where: {
42+
floatArray: {
43+
equals: [1.1, 2.2, 3.3],
44+
},
45+
},
46+
});
47+
expect(equalsResult).toHaveLength(1);
48+
expect(equalsResult[0].id).toBe(user1.id);
49+
50+
// Test 'has' operator - contains single value
51+
const hasResult = await db.user.findMany({
52+
where: {
53+
floatArray: {
54+
has: 4.4,
55+
},
56+
},
57+
});
58+
expect(hasResult).toHaveLength(1);
59+
expect(hasResult[0].id).toBe(user2.id);
60+
61+
// Test 'hasSome' operator - contains any of the values
62+
const hasSomeResult = await db.user.findMany({
63+
where: {
64+
floatArray: {
65+
hasSome: [3.3, 6.6, 7.7],
66+
},
67+
},
68+
});
69+
expect(hasSomeResult).toHaveLength(2);
70+
expect(hasSomeResult.map((u: any) => u.id).sort()).toEqual([user1.id, user2.id].sort());
71+
72+
// Test 'hasEvery' operator - contains all values
73+
const hasEveryResult = await db.user.findMany({
74+
where: {
75+
floatArray: {
76+
hasEvery: [1.1, 2.2],
77+
},
78+
},
79+
});
80+
expect(hasEveryResult).toHaveLength(2);
81+
expect(hasEveryResult.map((u: any) => u.id).sort()).toEqual([user1.id, user2.id].sort());
82+
83+
// Test 'isEmpty' operator
84+
const isEmptyResult = await db.user.findMany({
85+
where: {
86+
floatArray: {
87+
isEmpty: true,
88+
},
89+
},
90+
});
91+
expect(isEmptyResult).toHaveLength(1);
92+
expect(isEmptyResult[0].id).toBe(user3.id);
93+
94+
// Test 'isEmpty: false'
95+
const notEmptyResult = await db.user.findMany({
96+
where: {
97+
floatArray: {
98+
isEmpty: false,
99+
},
100+
},
101+
});
102+
expect(notEmptyResult).toHaveLength(2);
103+
expect(notEmptyResult.map((u: any) => u.id).sort()).toEqual([user1.id, user2.id].sort());
104+
});
105+
});

0 commit comments

Comments
 (0)