88 FunctionNode ,
99 IdentifierNode ,
1010 InsertQueryNode ,
11+ JoinNode ,
1112 OperationNodeTransformer ,
1213 OperatorNode ,
1314 ParensNode ,
@@ -138,18 +139,15 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
138139 // #region overrides
139140
140141 protected override transformSelectQuery ( node : SelectQueryNode ) {
141- let whereNode = node . where ;
142+ let whereNode = this . transformNode ( node . where ) ;
142143
143- node . from ?. froms . forEach ( ( from ) => {
144- const extractResult = this . extractTableName ( from ) ;
145- if ( extractResult ) {
146- const { model, alias } = extractResult ;
147- const filter = this . buildPolicyFilter ( model , alias , 'read' ) ;
148- whereNode = WhereNode . create (
149- whereNode ?. where ? conjunction ( this . dialect , [ whereNode . where , filter ] ) : filter ,
150- ) ;
151- }
152- } ) ;
144+ // get combined policy filter for all froms, and merge into where clause
145+ const policyFilter = this . createPolicyFilterForFrom ( node . from ) ;
146+ if ( policyFilter ) {
147+ whereNode = WhereNode . create (
148+ whereNode ?. where ? conjunction ( this . dialect , [ whereNode . where , policyFilter ] ) : policyFilter ,
149+ ) ;
150+ }
153151
154152 const baseResult = super . transformSelectQuery ( {
155153 ...node ,
@@ -162,6 +160,27 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
162160 } ;
163161 }
164162
163+ protected override transformJoin ( node : JoinNode ) {
164+ const table = this . extractTableName ( node . table ) ;
165+ if ( ! table ) {
166+ // unable to extract table name, can be a subquery, which will be handled when nested transformation happens
167+ return super . transformJoin ( node ) ;
168+ }
169+
170+ // build a nested query with policy filter applied
171+ const filter = this . buildPolicyFilter ( table . model , undefined , 'read' ) ;
172+ const nestedSelect : SelectQueryNode = {
173+ kind : 'SelectQueryNode' ,
174+ from : FromNode . create ( [ node . table ] ) ,
175+ selections : [ SelectionNode . createSelectAll ( ) ] ,
176+ where : WhereNode . create ( filter ) ,
177+ } ;
178+ return {
179+ ...node ,
180+ table : AliasNode . create ( ParensNode . create ( nestedSelect ) , IdentifierNode . create ( table . alias ?? table . model ) ) ,
181+ } ;
182+ }
183+
165184 protected override transformInsertQuery ( node : InsertQueryNode ) {
166185 // pre-insert check is done in `handle()`
167186
@@ -210,7 +229,16 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
210229 protected override transformUpdateQuery ( node : UpdateQueryNode ) {
211230 const result = super . transformUpdateQuery ( node ) ;
212231 const mutationModel = this . getMutationModel ( node ) ;
213- const filter = this . buildPolicyFilter ( mutationModel , undefined , 'update' ) ;
232+ let filter = this . buildPolicyFilter ( mutationModel , undefined , 'update' ) ;
233+
234+ if ( node . from ) {
235+ // for update with from (join), we need to merge join tables' policy filters to the "where" clause
236+ const joinFilter = this . createPolicyFilterForFrom ( node . from ) ;
237+ if ( joinFilter ) {
238+ filter = conjunction ( this . dialect , [ filter , joinFilter ] ) ;
239+ }
240+ }
241+
214242 return {
215243 ...result ,
216244 where : WhereNode . create ( result . where ? conjunction ( this . dialect , [ result . where . where , filter ] ) : filter ) ,
@@ -220,7 +248,16 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
220248 protected override transformDeleteQuery ( node : DeleteQueryNode ) {
221249 const result = super . transformDeleteQuery ( node ) ;
222250 const mutationModel = this . getMutationModel ( node ) ;
223- const filter = this . buildPolicyFilter ( mutationModel , undefined , 'delete' ) ;
251+ let filter = this . buildPolicyFilter ( mutationModel , undefined , 'delete' ) ;
252+
253+ if ( node . using ) {
254+ // for delete with using (join), we need to merge join tables' policy filters to the "where" clause
255+ const joinFilter = this . createPolicyFilterForTables ( node . using . tables ) ;
256+ if ( joinFilter ) {
257+ filter = conjunction ( this . dialect , [ filter , joinFilter ] ) ;
258+ }
259+ }
260+
224261 return {
225262 ...result ,
226263 where : WhereNode . create ( result . where ? conjunction ( this . dialect , [ result . where . where , filter ] ) : filter ) ,
@@ -466,11 +503,11 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
466503
467504 const allows = policies
468505 . filter ( ( policy ) => policy . kind === 'allow' )
469- . map ( ( policy ) => this . transformPolicyCondition ( model , alias , operation , policy ) ) ;
506+ . map ( ( policy ) => this . compilePolicyCondition ( model , alias , operation , policy ) ) ;
470507
471508 const denies = policies
472509 . filter ( ( policy ) => policy . kind === 'deny' )
473- . map ( ( policy ) => this . transformPolicyCondition ( model , alias , operation , policy ) ) ;
510+ . map ( ( policy ) => this . compilePolicyCondition ( model , alias , operation , policy ) ) ;
474511
475512 let combinedPolicy : OperationNode ;
476513
@@ -514,7 +551,26 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
514551 }
515552 }
516553
517- private transformPolicyCondition (
554+ private createPolicyFilterForFrom ( node : FromNode | undefined ) {
555+ if ( ! node ) {
556+ return undefined ;
557+ }
558+ return this . createPolicyFilterForTables ( node . froms ) ;
559+ }
560+
561+ private createPolicyFilterForTables ( tables : readonly OperationNode [ ] ) {
562+ return tables . reduce < OperationNode | undefined > ( ( acc , table ) => {
563+ const extractResult = this . extractTableName ( table ) ;
564+ if ( extractResult ) {
565+ const { model, alias } = extractResult ;
566+ const filter = this . buildPolicyFilter ( model , alias , 'read' ) ;
567+ return acc ? conjunction ( this . dialect , [ acc , filter ] ) : filter ;
568+ }
569+ return acc ;
570+ } , undefined ) ;
571+ }
572+
573+ private compilePolicyCondition (
518574 model : GetModels < Schema > ,
519575 alias : string | undefined ,
520576 operation : CRUD ,
0 commit comments