Skip to content

Commit cd143aa

Browse files
committed
more robust alias handling
1 parent 23ce493 commit cd143aa

File tree

2 files changed

+41
-33
lines changed

2 files changed

+41
-33
lines changed

packages/runtime/src/client/query-utils.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ export function requireField(schema: SchemaDef, modelOrType: string, field: stri
5656

5757
export function getIdFields<Schema extends SchemaDef>(schema: SchemaDef, model: GetModels<Schema>) {
5858
const modelDef = requireModel(schema, model);
59-
return modelDef?.idFields as GetModels<Schema>[];
59+
return modelDef?.idFields;
6060
}
6161

6262
export function requireIdFields(schema: SchemaDef, model: string) {

packages/runtime/src/plugins/policy/policy-handler.ts

Lines changed: 40 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ import type { ClientContract } from '../../client';
3232
import type { CRUD } from '../../client/contract';
3333
import { getCrudDialect } from '../../client/crud/dialects';
3434
import type { BaseCrudDialect } from '../../client/crud/dialects/base-dialect';
35-
import { InternalError } from '../../client/errors';
35+
import { InternalError, QueryError } from '../../client/errors';
3636
import type { ProceedKyselyQueryFunction } from '../../client/plugin';
3737
import { getIdFields, requireField, requireModel } from '../../client/query-utils';
3838
import { ExpressionUtils, type BuiltinType, type Expression, type GetModels, type SchemaDef } from '../../schema';
@@ -73,7 +73,7 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
7373
}
7474

7575
let mutationRequiresTransaction = false;
76-
const mutationModel = this.getMutationModel(node);
76+
const { mutationModel } = this.getMutationModel(node);
7777

7878
if (InsertQueryNode.is(node)) {
7979
// reject create if unconditional deny
@@ -168,7 +168,7 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
168168
}
169169

170170
// build a nested query with policy filter applied
171-
const filter = this.buildPolicyFilter(table.model, undefined, 'read');
171+
const filter = this.buildPolicyFilter(table.model, table.alias, 'read');
172172
const nestedSelect: SelectQueryNode = {
173173
kind: 'SelectQueryNode',
174174
from: FromNode.create([node.table]),
@@ -188,8 +188,8 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
188188

189189
if (onConflict?.updates) {
190190
// for "on conflict do update", we need to apply policy filter to the "where" clause
191-
const mutationModel = this.getMutationModel(node);
192-
const filter = this.buildPolicyFilter(mutationModel, undefined, 'update');
191+
const { mutationModel, alias } = this.getMutationModel(node);
192+
const filter = this.buildPolicyFilter(mutationModel, alias, 'update');
193193
if (onConflict.updateWhere) {
194194
onConflict = {
195195
...onConflict,
@@ -216,7 +216,8 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
216216
return result;
217217
} else {
218218
// only return ID fields, that's enough for reading back the inserted row
219-
const idFields = getIdFields(this.client.$schema, this.getMutationModel(node));
219+
const { mutationModel } = this.getMutationModel(node);
220+
const idFields = getIdFields(this.client.$schema, mutationModel);
220221
return {
221222
...result,
222223
returning: ReturningNode.create(
@@ -228,8 +229,8 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
228229

229230
protected override transformUpdateQuery(node: UpdateQueryNode) {
230231
const result = super.transformUpdateQuery(node);
231-
const mutationModel = this.getMutationModel(node);
232-
let filter = this.buildPolicyFilter(mutationModel, undefined, 'update');
232+
const { mutationModel, alias } = this.getMutationModel(node);
233+
let filter = this.buildPolicyFilter(mutationModel, alias, 'update');
233234

234235
if (node.from) {
235236
// for update with from (join), we need to merge join tables' policy filters to the "where" clause
@@ -247,8 +248,8 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
247248

248249
protected override transformDeleteQuery(node: DeleteQueryNode) {
249250
const result = super.transformDeleteQuery(node);
250-
const mutationModel = this.getMutationModel(node);
251-
let filter = this.buildPolicyFilter(mutationModel, undefined, 'delete');
251+
const { mutationModel, alias } = this.getMutationModel(node);
252+
let filter = this.buildPolicyFilter(mutationModel, alias, 'delete');
252253

253254
if (node.using) {
254255
// for delete with using (join), we need to merge join tables' policy filters to the "where" clause
@@ -272,19 +273,20 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
272273
if (!node.returning) {
273274
return true;
274275
}
275-
const idFields = getIdFields(this.client.$schema, this.getMutationModel(node));
276+
const { mutationModel } = this.getMutationModel(node);
277+
const idFields = getIdFields(this.client.$schema, mutationModel);
276278
const collector = new ColumnCollector();
277279
const selectedColumns = collector.collect(node.returning);
278280
return selectedColumns.every((c) => idFields.includes(c));
279281
}
280282

281283
private async enforcePreCreatePolicy(node: InsertQueryNode, proceed: ProceedKyselyQueryFunction) {
282-
const model = this.getMutationModel(node);
284+
const { mutationModel } = this.getMutationModel(node);
283285
const fields = node.columns?.map((c) => c.column.name) ?? [];
284-
const valueRows = node.values ? this.unwrapCreateValueRows(node.values, model, fields) : [[]];
286+
const valueRows = node.values ? this.unwrapCreateValueRows(node.values, mutationModel, fields) : [[]];
285287
for (const values of valueRows) {
286288
await this.enforcePreCreatePolicyForOne(
287-
model,
289+
mutationModel,
288290
fields,
289291
values.map((v) => v.node),
290292
proceed,
@@ -431,17 +433,13 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
431433
}
432434

433435
// do a select (with policy) in place of returning
434-
const table = this.getMutationModel(node);
435-
if (!table) {
436-
throw new InternalError(`Unable to get table name for query node: ${node}`);
437-
}
438-
439-
const idConditions = this.buildIdConditions(table, result.rows);
440-
const policyFilter = this.buildPolicyFilter(table, undefined, 'read');
436+
const { mutationModel } = this.getMutationModel(node);
437+
const idConditions = this.buildIdConditions(mutationModel, result.rows);
438+
const policyFilter = this.buildPolicyFilter(mutationModel, undefined, 'read');
441439

442440
const select: SelectQueryNode = {
443441
kind: 'SelectQueryNode',
444-
from: FromNode.create([TableNode.create(table)]),
442+
from: FromNode.create([TableNode.create(mutationModel)]),
445443
where: WhereNode.create(conjunction(this.dialect, [idConditions, policyFilter])),
446444
selections: node.returning.selections,
447445
};
@@ -470,13 +468,23 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
470468

471469
private getMutationModel(node: InsertQueryNode | UpdateQueryNode | DeleteQueryNode) {
472470
const r = match(node)
473-
.when(InsertQueryNode.is, (node) => getTableName(node.into) as GetModels<Schema>)
474-
.when(UpdateQueryNode.is, (node) => getTableName(node.table) as GetModels<Schema>)
471+
.when(InsertQueryNode.is, (node) => ({
472+
mutationModel: getTableName(node.into) as GetModels<Schema>,
473+
alias: undefined,
474+
}))
475+
.when(UpdateQueryNode.is, (node) => {
476+
if (!node.table) {
477+
throw new QueryError('Update query must have a table');
478+
}
479+
const r = this.extractTableName(node.table);
480+
return r ? { mutationModel: r.model, alias: r.alias } : undefined;
481+
})
475482
.when(DeleteQueryNode.is, (node) => {
476483
if (node.from.froms.length !== 1) {
477-
throw new InternalError('Only one from table is supported for delete');
484+
throw new QueryError('Only one from table is supported for delete');
478485
}
479-
return getTableName(node.from.froms[0]) as GetModels<Schema>;
486+
const r = this.extractTableName(node.from.froms[0]!);
487+
return r ? { mutationModel: r.model, alias: r.alias } : undefined;
480488
})
481489
.exhaustive();
482490
if (!r) {
@@ -531,18 +539,18 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
531539
return combinedPolicy;
532540
}
533541

534-
private extractTableName(from: OperationNode): { model: GetModels<Schema>; alias?: string } | undefined {
535-
if (TableNode.is(from)) {
536-
return { model: from.table.identifier.name as GetModels<Schema> };
542+
private extractTableName(node: OperationNode): { model: GetModels<Schema>; alias?: string } | undefined {
543+
if (TableNode.is(node)) {
544+
return { model: node.table.identifier.name as GetModels<Schema> };
537545
}
538-
if (AliasNode.is(from)) {
539-
const inner = this.extractTableName(from.node);
546+
if (AliasNode.is(node)) {
547+
const inner = this.extractTableName(node.node);
540548
if (!inner) {
541549
return undefined;
542550
}
543551
return {
544552
model: inner.model,
545-
alias: IdentifierNode.is(from.alias) ? from.alias.name : undefined,
553+
alias: IdentifierNode.is(node.alias) ? node.alias.name : undefined,
546554
};
547555
} else {
548556
// this can happen for subqueries, which will be handled when nested

0 commit comments

Comments
 (0)