Skip to content

Commit 332b1db

Browse files
authored
fix(delegate): column name mapping issue when delegates are involved (#296)
* fix(delegate): column name mapping issue when delegates are involved * fix build * fix tests
1 parent 0fa87c1 commit 332b1db

26 files changed

+998
-29
lines changed

packages/runtime/src/client/crud/dialects/base-dialect.ts

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
102102
if ('distinct' in args && (args as any).distinct) {
103103
const distinct = ensureArray((args as any).distinct) as string[];
104104
if (this.supportsDistinctOn) {
105-
result = result.distinctOn(distinct.map((f) => sql.ref(`${modelAlias}.${f}`)));
105+
result = result.distinctOn(distinct.map((f) => this.eb.ref(`${modelAlias}.${f}`)));
106106
} else {
107107
throw new QueryError(`"distinct" is not supported by "${this.schema.provider.type}" provider`);
108108
}
@@ -248,7 +248,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
248248

249249
if (ownedByModel && !fieldDef.originModel) {
250250
// can be short-circuited to FK null check
251-
return this.and(...keyPairs.map(({ fk }) => this.eb(sql.ref(`${modelAlias}.${fk}`), 'is', null)));
251+
return this.and(...keyPairs.map(({ fk }) => this.eb(this.eb.ref(`${modelAlias}.${fk}`), 'is', null)));
252252
} else {
253253
// translate it to `{ is: null }` filter
254254
return this.buildToOneRelationFilter(model, modelAlias, field, fieldDef, { is: null });
@@ -268,7 +268,9 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
268268

269269
const joinSelect = this.eb
270270
.selectFrom(`${fieldDef.type} as ${joinAlias}`)
271-
.where(() => this.and(...joinPairs.map(([left, right]) => this.eb(sql.ref(left), '=', sql.ref(right)))))
271+
.where(() =>
272+
this.and(...joinPairs.map(([left, right]) => this.eb(this.eb.ref(left), '=', this.eb.ref(right)))),
273+
)
272274
.select(() => this.eb.fn.count(this.eb.lit(1)).as(filterResultField));
273275

274276
const conditions: Expression<SqlBool>[] = [];
@@ -331,7 +333,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
331333
) {
332334
// null check needs to be converted to fk "is null" checks
333335
if (payload === null) {
334-
return this.eb(sql.ref(`${modelAlias}.${field}`), 'is', null);
336+
return this.eb(this.eb.ref(`${modelAlias}.${field}`), 'is', null);
335337
}
336338

337339
const relationModel = fieldDef.type;
@@ -351,15 +353,15 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
351353
invariant(relationIdFields.length === 1, 'many-to-many relation must have exactly one id field');
352354

353355
return eb(
354-
sql.ref(`${relationFilterSelectAlias}.${relationIdFields[0]}`),
356+
this.eb.ref(`${relationFilterSelectAlias}.${relationIdFields[0]}`),
355357
'in',
356358
eb
357359
.selectFrom(m2m.joinTable)
358360
.select(`${m2m.joinTable}.${m2m.otherFkName}`)
359361
.whereRef(
360-
sql.ref(`${m2m.joinTable}.${m2m.parentFkName}`),
362+
this.eb.ref(`${m2m.joinTable}.${m2m.parentFkName}`),
361363
'=',
362-
sql.ref(`${modelAlias}.${modelIdFields[0]}`),
364+
this.eb.ref(`${modelAlias}.${modelIdFields[0]}`),
363365
),
364366
);
365367
} else {
@@ -370,12 +372,20 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
370372
if (relationKeyPairs.ownedByModel) {
371373
result = this.and(
372374
result,
373-
eb(sql.ref(`${modelAlias}.${fk}`), '=', sql.ref(`${relationFilterSelectAlias}.${pk}`)),
375+
eb(
376+
this.eb.ref(`${modelAlias}.${fk}`),
377+
'=',
378+
this.eb.ref(`${relationFilterSelectAlias}.${pk}`),
379+
),
374380
);
375381
} else {
376382
result = this.and(
377383
result,
378-
eb(sql.ref(`${modelAlias}.${pk}`), '=', sql.ref(`${relationFilterSelectAlias}.${fk}`)),
384+
eb(
385+
this.eb.ref(`${modelAlias}.${pk}`),
386+
'=',
387+
this.eb.ref(`${relationFilterSelectAlias}.${fk}`),
388+
),
379389
);
380390
}
381391
}
@@ -833,7 +843,9 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
833843
const joinPairs = buildJoinPairs(this.schema, model, modelAlias, field, subQueryAlias);
834844
subQuery = subQuery.where(() =>
835845
this.and(
836-
...joinPairs.map(([left, right]) => eb(sql.ref(left), '=', sql.ref(right))),
846+
...joinPairs.map(([left, right]) =>
847+
eb(this.eb.ref(left), '=', this.eb.ref(right)),
848+
),
837849
),
838850
);
839851
subQuery = subQuery.select(() => eb.fn.count(eb.lit(1)).as('_count'));
@@ -845,7 +857,9 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
845857
result = result.leftJoin(relationModel, (join) => {
846858
const joinPairs = buildJoinPairs(this.schema, model, modelAlias, field, relationModel);
847859
return join.on((eb) =>
848-
this.and(...joinPairs.map(([left, right]) => eb(sql.ref(left), '=', sql.ref(right)))),
860+
this.and(
861+
...joinPairs.map(([left, right]) => eb(this.eb.ref(left), '=', this.eb.ref(right))),
862+
),
849863
);
850864
});
851865
result = this.buildOrderBy(result, fieldDef.type, relationModel, value, false, negated);
@@ -934,7 +948,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
934948
return query.select(() => this.fieldRef(model, field, modelAlias).as(field));
935949
} else if (!fieldDef.originModel) {
936950
// regular field
937-
return query.select(sql.ref(`${modelAlias}.${field}`).as(field));
951+
return query.select(this.eb.ref(`${modelAlias}.${field}`).as(field));
938952
} else {
939953
return this.buildSelectField(query, fieldDef.originModel, fieldDef.originModel, field);
940954
}

packages/runtime/src/client/crud/dialects/postgresql.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale
231231
} else {
232232
const joinPairs = buildJoinPairs(this.schema, model, parentAlias, relationField, relationModelAlias);
233233
query = query.where((eb) =>
234-
this.and(...joinPairs.map(([left, right]) => eb(sql.ref(left), '=', sql.ref(right)))),
234+
this.and(...joinPairs.map(([left, right]) => eb(this.eb.ref(left), '=', this.eb.ref(right)))),
235235
);
236236
}
237237
return query;

packages/runtime/src/client/crud/operations/aggregate.ts

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import { sql } from 'kysely';
21
import { match } from 'ts-pattern';
32
import type { SchemaDef } from '../../../schema';
43
import { getField } from '../../query-utils';
@@ -80,7 +79,9 @@ export class AggregateOperationHandler<Schema extends SchemaDef> extends BaseOpe
8079
);
8180
} else {
8281
query = query.select((eb) =>
83-
eb.cast(eb.fn.count(sql.ref(`$sub.${field}`)), 'integer').as(`${key}.${field}`),
82+
eb
83+
.cast(eb.fn.count(eb.ref(`$sub.${field}` as any)), 'integer')
84+
.as(`${key}.${field}`),
8485
);
8586
}
8687
}
@@ -102,7 +103,7 @@ export class AggregateOperationHandler<Schema extends SchemaDef> extends BaseOpe
102103
.with('_max', () => eb.fn.max)
103104
.with('_min', () => eb.fn.min)
104105
.exhaustive();
105-
return fn(sql.ref(`$sub.${field}`)).as(`${key}.${field}`);
106+
return fn(eb.ref(`$sub.${field}` as any)).as(`${key}.${field}`);
106107
});
107108
}
108109
});

packages/runtime/src/client/crud/operations/base.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1540,7 +1540,9 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
15401540
if (!relationFieldDef.array) {
15411541
const query = kysely
15421542
.updateTable(model)
1543-
.where((eb) => eb.and(keyPairs.map(({ fk, pk }) => eb(sql.ref(fk), '=', fromRelation.ids[pk]))))
1543+
.where((eb) =>
1544+
eb.and(keyPairs.map(({ fk, pk }) => eb(eb.ref(fk as any), '=', fromRelation.ids[pk]))),
1545+
)
15441546
.set(keyPairs.reduce((acc, { fk }) => ({ ...acc, [fk]: null }), {} as any))
15451547
.modifyEnd(
15461548
this.makeContextComment({

packages/runtime/src/client/crud/operations/count.ts

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import { sql } from 'kysely';
21
import type { SchemaDef } from '../../../schema';
32
import { BaseOperationHandler } from './base';
43

@@ -40,7 +39,7 @@ export class CountOperationHandler<Schema extends SchemaDef> extends BaseOperati
4039
Object.keys(parsedArgs.select!).map((key) =>
4140
key === '_all'
4241
? eb.cast(eb.fn.countAll(), 'integer').as('_all')
43-
: eb.cast(eb.fn.count(sql.ref(`${subQueryName}.${key}`)), 'integer').as(key),
42+
: eb.cast(eb.fn.count(eb.ref(`${subQueryName}.${key}` as any)), 'integer').as(key),
4443
),
4544
);
4645
const result = await this.executeQuery(this.kysely, query, 'count');

packages/runtime/src/client/executor/name-mapper.ts

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,10 @@ export class QueryNameMapper extends OperationNodeTransformer {
3838
this.modelToTableMap.set(modelName, mappedName);
3939
}
4040

41-
for (const [fieldName, fieldDef] of Object.entries(modelDef.fields)) {
41+
for (const fieldDef of this.getModelFields(modelDef)) {
4242
const mappedName = this.getMappedName(fieldDef);
4343
if (mappedName) {
44-
this.fieldToColumnMap.set(`${modelName}.${fieldName}`, mappedName);
44+
this.fieldToColumnMap.set(`${modelName}.${fieldDef.name}`, mappedName);
4545
}
4646
}
4747
}
@@ -72,11 +72,14 @@ export class QueryNameMapper extends OperationNodeTransformer {
7272
on: this.transformNode(join.on),
7373
}))
7474
: undefined;
75+
const selections = this.processSelectQuerySelections(node);
76+
const baseResult = super.transformSelectQuery(node);
77+
7578
return {
76-
...super.transformSelectQuery(node),
79+
...baseResult,
7780
from: FromNode.create(processedFroms.map((f) => f.node)),
7881
joins,
79-
selections: this.processSelectQuerySelections(node),
82+
selections,
8083
};
8184
});
8285
}
@@ -132,7 +135,8 @@ export class QueryNameMapper extends OperationNodeTransformer {
132135
mappedTableName ? TableNode.create(mappedTableName) : undefined,
133136
);
134137
} else {
135-
return super.transformReference(node);
138+
// no name mapping needed
139+
return node;
136140
}
137141
}
138142

@@ -270,7 +274,7 @@ export class QueryNameMapper extends OperationNodeTransformer {
270274
if (!modelDef) {
271275
continue;
272276
}
273-
if (modelDef.fields[name]) {
277+
if (this.getModelFields(modelDef).some((f) => f.name === name)) {
274278
return scope;
275279
}
276280
}

tests/e2e/orm/client-api/computed-fields.test.ts

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import { sql } from '@zenstackhq/runtime/helpers';
21
import { createTestClient } from '@zenstackhq/testtools';
32
import { afterEach, describe, expect, it } from 'vitest';
43

@@ -226,7 +225,7 @@ model Post {
226225
postCount: (eb: any, context: { modelAlias: string }) =>
227226
eb
228227
.selectFrom('Post')
229-
.whereRef('Post.authorId', '=', sql.ref(`${context.modelAlias}.id`))
228+
.whereRef('Post.authorId', '=', eb.ref(`${context.modelAlias}.id`))
230229
.select(() => eb.fn.countAll().as('count')),
231230
},
232231
},

tests/regression/test/v2-migrated/issue-1235.test.ts

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { createPolicyTestClient, testLogger } from '@zenstackhq/testtools';
1+
import { createPolicyTestClient } from '@zenstackhq/testtools';
22
import { describe, expect, it } from 'vitest';
33

44
describe('Regression for issue 1235', () => {
@@ -11,7 +11,6 @@ model Post {
1111
@@allow('all', true)
1212
}
1313
`,
14-
{ log: testLogger },
1514
);
1615

1716
const post = await db.post.create({ data: {} });
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import { createPolicyTestClient } from '@zenstackhq/testtools';
2+
import { it } from 'vitest';
3+
4+
it('verifies issue 1506', async () => {
5+
await createPolicyTestClient(
6+
`
7+
model A {
8+
id Int @id @default(autoincrement())
9+
value Int
10+
b B @relation(fields: [bId], references: [id])
11+
bId Int @unique
12+
13+
@@allow('read', true)
14+
}
15+
16+
model B {
17+
id Int @id @default(autoincrement())
18+
value Int
19+
a A?
20+
c C @relation(fields: [cId], references: [id])
21+
cId Int @unique
22+
23+
@@allow('read', value > c.value)
24+
}
25+
26+
model C {
27+
id Int @id @default(autoincrement())
28+
value Int
29+
b B?
30+
31+
@@allow('read', true)
32+
}
33+
`,
34+
);
35+
});
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import { createPolicyTestClient } from '@zenstackhq/testtools';
2+
import { expect, it } from 'vitest';
3+
4+
it('verifies issue 1507', async () => {
5+
const db = await createPolicyTestClient(
6+
`
7+
model User {
8+
id Int @id @default(autoincrement())
9+
age Int
10+
}
11+
12+
model Profile {
13+
id Int @id @default(autoincrement())
14+
age Int
15+
16+
@@allow('read', auth().age == age)
17+
}
18+
`,
19+
);
20+
21+
await db.$unuseAll().profile.create({ data: { age: 18 } });
22+
await db.$unuseAll().profile.create({ data: { age: 20 } });
23+
await expect(db.$setAuth({ id: 1, age: 18 }).profile.findMany()).resolves.toHaveLength(1);
24+
await expect(db.$setAuth({ id: 1, age: 18 }).profile.count()).resolves.toBe(1);
25+
});

0 commit comments

Comments
 (0)