@@ -32,7 +32,7 @@ import type { ClientContract } from '../../client';
3232import type { CRUD } from '../../client/contract' ;
3333import { getCrudDialect } from '../../client/crud/dialects' ;
3434import type { BaseCrudDialect } from '../../client/crud/dialects/base-dialect' ;
35- import { InternalError } from '../../client/errors' ;
35+ import { InternalError , QueryError } from '../../client/errors' ;
3636import type { ProceedKyselyQueryFunction } from '../../client/plugin' ;
3737import { getIdFields , requireField , requireModel } from '../../client/query-utils' ;
3838import { ExpressionUtils , type BuiltinType , type Expression , type GetModels , type SchemaDef } from '../../schema' ;
@@ -73,7 +73,7 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
7373 }
7474
7575 let mutationRequiresTransaction = false ;
76- const mutationModel = this . getMutationModel ( node ) ;
76+ const { mutationModel } = this . getMutationModel ( node ) ;
7777
7878 if ( InsertQueryNode . is ( node ) ) {
7979 // reject create if unconditional deny
@@ -168,7 +168,7 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
168168 }
169169
170170 // build a nested query with policy filter applied
171- const filter = this . buildPolicyFilter ( table . model , undefined , 'read' ) ;
171+ const filter = this . buildPolicyFilter ( table . model , table . alias , 'read' ) ;
172172 const nestedSelect : SelectQueryNode = {
173173 kind : 'SelectQueryNode' ,
174174 from : FromNode . create ( [ node . table ] ) ,
@@ -188,8 +188,8 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
188188
189189 if ( onConflict ?. updates ) {
190190 // for "on conflict do update", we need to apply policy filter to the "where" clause
191- const mutationModel = this . getMutationModel ( node ) ;
192- const filter = this . buildPolicyFilter ( mutationModel , undefined , 'update' ) ;
191+ const { mutationModel, alias } = this . getMutationModel ( node ) ;
192+ const filter = this . buildPolicyFilter ( mutationModel , alias , 'update' ) ;
193193 if ( onConflict . updateWhere ) {
194194 onConflict = {
195195 ...onConflict ,
@@ -216,7 +216,8 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
216216 return result ;
217217 } else {
218218 // only return ID fields, that's enough for reading back the inserted row
219- const idFields = getIdFields ( this . client . $schema , this . getMutationModel ( node ) ) ;
219+ const { mutationModel } = this . getMutationModel ( node ) ;
220+ const idFields = getIdFields ( this . client . $schema , mutationModel ) ;
220221 return {
221222 ...result ,
222223 returning : ReturningNode . create (
@@ -228,8 +229,8 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
228229
229230 protected override transformUpdateQuery ( node : UpdateQueryNode ) {
230231 const result = super . transformUpdateQuery ( node ) ;
231- const mutationModel = this . getMutationModel ( node ) ;
232- let filter = this . buildPolicyFilter ( mutationModel , undefined , 'update' ) ;
232+ const { mutationModel, alias } = this . getMutationModel ( node ) ;
233+ let filter = this . buildPolicyFilter ( mutationModel , alias , 'update' ) ;
233234
234235 if ( node . from ) {
235236 // for update with from (join), we need to merge join tables' policy filters to the "where" clause
@@ -247,8 +248,8 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
247248
248249 protected override transformDeleteQuery ( node : DeleteQueryNode ) {
249250 const result = super . transformDeleteQuery ( node ) ;
250- const mutationModel = this . getMutationModel ( node ) ;
251- let filter = this . buildPolicyFilter ( mutationModel , undefined , 'delete' ) ;
251+ const { mutationModel, alias } = this . getMutationModel ( node ) ;
252+ let filter = this . buildPolicyFilter ( mutationModel , alias , 'delete' ) ;
252253
253254 if ( node . using ) {
254255 // for delete with using (join), we need to merge join tables' policy filters to the "where" clause
@@ -272,19 +273,20 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
272273 if ( ! node . returning ) {
273274 return true ;
274275 }
275- const idFields = getIdFields ( this . client . $schema , this . getMutationModel ( node ) ) ;
276+ const { mutationModel } = this . getMutationModel ( node ) ;
277+ const idFields = getIdFields ( this . client . $schema , mutationModel ) ;
276278 const collector = new ColumnCollector ( ) ;
277279 const selectedColumns = collector . collect ( node . returning ) ;
278280 return selectedColumns . every ( ( c ) => idFields . includes ( c ) ) ;
279281 }
280282
281283 private async enforcePreCreatePolicy ( node : InsertQueryNode , proceed : ProceedKyselyQueryFunction ) {
282- const model = this . getMutationModel ( node ) ;
284+ const { mutationModel } = this . getMutationModel ( node ) ;
283285 const fields = node . columns ?. map ( ( c ) => c . column . name ) ?? [ ] ;
284- const valueRows = node . values ? this . unwrapCreateValueRows ( node . values , model , fields ) : [ [ ] ] ;
286+ const valueRows = node . values ? this . unwrapCreateValueRows ( node . values , mutationModel , fields ) : [ [ ] ] ;
285287 for ( const values of valueRows ) {
286288 await this . enforcePreCreatePolicyForOne (
287- model ,
289+ mutationModel ,
288290 fields ,
289291 values . map ( ( v ) => v . node ) ,
290292 proceed ,
@@ -431,17 +433,13 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
431433 }
432434
433435 // do a select (with policy) in place of returning
434- const table = this . getMutationModel ( node ) ;
435- if ( ! table ) {
436- throw new InternalError ( `Unable to get table name for query node: ${ node } ` ) ;
437- }
438-
439- const idConditions = this . buildIdConditions ( table , result . rows ) ;
440- const policyFilter = this . buildPolicyFilter ( table , undefined , 'read' ) ;
436+ const { mutationModel } = this . getMutationModel ( node ) ;
437+ const idConditions = this . buildIdConditions ( mutationModel , result . rows ) ;
438+ const policyFilter = this . buildPolicyFilter ( mutationModel , undefined , 'read' ) ;
441439
442440 const select : SelectQueryNode = {
443441 kind : 'SelectQueryNode' ,
444- from : FromNode . create ( [ TableNode . create ( table ) ] ) ,
442+ from : FromNode . create ( [ TableNode . create ( mutationModel ) ] ) ,
445443 where : WhereNode . create ( conjunction ( this . dialect , [ idConditions , policyFilter ] ) ) ,
446444 selections : node . returning . selections ,
447445 } ;
@@ -470,13 +468,23 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
470468
471469 private getMutationModel ( node : InsertQueryNode | UpdateQueryNode | DeleteQueryNode ) {
472470 const r = match ( node )
473- . when ( InsertQueryNode . is , ( node ) => getTableName ( node . into ) as GetModels < Schema > )
474- . when ( UpdateQueryNode . is , ( node ) => getTableName ( node . table ) as GetModels < Schema > )
471+ . when ( InsertQueryNode . is , ( node ) => ( {
472+ mutationModel : getTableName ( node . into ) as GetModels < Schema > ,
473+ alias : undefined ,
474+ } ) )
475+ . when ( UpdateQueryNode . is , ( node ) => {
476+ if ( ! node . table ) {
477+ throw new QueryError ( 'Update query must have a table' ) ;
478+ }
479+ const r = this . extractTableName ( node . table ) ;
480+ return r ? { mutationModel : r . model , alias : r . alias } : undefined ;
481+ } )
475482 . when ( DeleteQueryNode . is , ( node ) => {
476483 if ( node . from . froms . length !== 1 ) {
477- throw new InternalError ( 'Only one from table is supported for delete' ) ;
484+ throw new QueryError ( 'Only one from table is supported for delete' ) ;
478485 }
479- return getTableName ( node . from . froms [ 0 ] ) as GetModels < Schema > ;
486+ const r = this . extractTableName ( node . from . froms [ 0 ] ! ) ;
487+ return r ? { mutationModel : r . model , alias : r . alias } : undefined ;
480488 } )
481489 . exhaustive ( ) ;
482490 if ( ! r ) {
@@ -531,18 +539,18 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
531539 return combinedPolicy ;
532540 }
533541
534- private extractTableName ( from : OperationNode ) : { model : GetModels < Schema > ; alias ?: string } | undefined {
535- if ( TableNode . is ( from ) ) {
536- return { model : from . table . identifier . name as GetModels < Schema > } ;
542+ private extractTableName ( node : OperationNode ) : { model : GetModels < Schema > ; alias ?: string } | undefined {
543+ if ( TableNode . is ( node ) ) {
544+ return { model : node . table . identifier . name as GetModels < Schema > } ;
537545 }
538- if ( AliasNode . is ( from ) ) {
539- const inner = this . extractTableName ( from . node ) ;
546+ if ( AliasNode . is ( node ) ) {
547+ const inner = this . extractTableName ( node . node ) ;
540548 if ( ! inner ) {
541549 return undefined ;
542550 }
543551 return {
544552 model : inner . model ,
545- alias : IdentifierNode . is ( from . alias ) ? from . alias . name : undefined ,
553+ alias : IdentifierNode . is ( node . alias ) ? node . alias . name : undefined ,
546554 } ;
547555 } else {
548556 // this can happen for subqueries, which will be handled when nested
0 commit comments