@@ -134,6 +134,102 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
134134 // return result;
135135 }
136136
137+ // #region overrides
138+
139+ protected override transformSelectQuery ( node : SelectQueryNode ) {
140+ let whereNode = node . where ;
141+
142+ node . from ?. froms . forEach ( ( from ) => {
143+ const extractResult = this . extractTableName ( from ) ;
144+ if ( extractResult ) {
145+ const { model, alias } = extractResult ;
146+ const filter = this . buildPolicyFilter ( model , alias , 'read' ) ;
147+ whereNode = WhereNode . create (
148+ whereNode ?. where ? conjunction ( this . dialect , [ whereNode . where , filter ] ) : filter ,
149+ ) ;
150+ }
151+ } ) ;
152+
153+ const baseResult = super . transformSelectQuery ( {
154+ ...node ,
155+ where : undefined ,
156+ } ) ;
157+
158+ return {
159+ ...baseResult ,
160+ where : whereNode ,
161+ } ;
162+ }
163+
164+ protected override transformInsertQuery ( node : InsertQueryNode ) {
165+ // pre-insert check is done in `handle()`
166+
167+ let onConflict = node . onConflict ;
168+
169+ if ( onConflict ?. updates ) {
170+ // for "on conflict do update", we need to apply policy filter to the "where" clause
171+ const mutationModel = this . getMutationModel ( node ) ;
172+ const filter = this . buildPolicyFilter ( mutationModel , undefined , 'update' ) ;
173+ if ( onConflict . updateWhere ) {
174+ onConflict = {
175+ ...onConflict ,
176+ updateWhere : WhereNode . create ( conjunction ( this . dialect , [ onConflict . updateWhere . where , filter ] ) ) ,
177+ } ;
178+ } else {
179+ onConflict = {
180+ ...onConflict ,
181+ updateWhere : WhereNode . create ( filter ) ,
182+ } ;
183+ }
184+ }
185+
186+ // merge updated onConflict
187+ const processedNode = onConflict ? { ...node , onConflict } : node ;
188+
189+ const result = super . transformInsertQuery ( processedNode ) ;
190+
191+ if ( ! node . returning ) {
192+ return result ;
193+ }
194+
195+ if ( this . onlyReturningId ( node ) ) {
196+ return result ;
197+ } else {
198+ // only return ID fields, that's enough for reading back the inserted row
199+ const idFields = getIdFields ( this . client . $schema , this . getMutationModel ( node ) ) ;
200+ return {
201+ ...result ,
202+ returning : ReturningNode . create (
203+ idFields . map ( ( field ) => SelectionNode . create ( ColumnNode . create ( field ) ) ) ,
204+ ) ,
205+ } ;
206+ }
207+ }
208+
209+ protected override transformUpdateQuery ( node : UpdateQueryNode ) {
210+ const result = super . transformUpdateQuery ( node ) ;
211+ const mutationModel = this . getMutationModel ( node ) ;
212+ const filter = this . buildPolicyFilter ( mutationModel , undefined , 'update' ) ;
213+ return {
214+ ...result ,
215+ where : WhereNode . create ( result . where ? conjunction ( this . dialect , [ result . where . where , filter ] ) : filter ) ,
216+ } ;
217+ }
218+
219+ protected override transformDeleteQuery ( node : DeleteQueryNode ) {
220+ const result = super . transformDeleteQuery ( node ) ;
221+ const mutationModel = this . getMutationModel ( node ) ;
222+ const filter = this . buildPolicyFilter ( mutationModel , undefined , 'delete' ) ;
223+ return {
224+ ...result ,
225+ where : WhereNode . create ( result . where ? conjunction ( this . dialect , [ result . where . where , filter ] ) : filter ) ,
226+ } ;
227+ }
228+
229+ // #endregion
230+
231+ // #region helpers
232+
137233 private onlyReturningId ( node : MutationQueryNode ) {
138234 if ( ! node . returning ) {
139235 return true ;
@@ -397,70 +493,6 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
397493 return combinedPolicy ;
398494 }
399495
400- protected override transformSelectQuery ( node : SelectQueryNode ) {
401- let whereNode = node . where ;
402-
403- node . from ?. froms . forEach ( ( from ) => {
404- const extractResult = this . extractTableName ( from ) ;
405- if ( extractResult ) {
406- const { model, alias } = extractResult ;
407- const filter = this . buildPolicyFilter ( model , alias , 'read' ) ;
408- whereNode = WhereNode . create (
409- whereNode ?. where ? conjunction ( this . dialect , [ whereNode . where , filter ] ) : filter ,
410- ) ;
411- }
412- } ) ;
413-
414- const baseResult = super . transformSelectQuery ( {
415- ...node ,
416- where : undefined ,
417- } ) ;
418-
419- return {
420- ...baseResult ,
421- where : whereNode ,
422- } ;
423- }
424-
425- protected override transformInsertQuery ( node : InsertQueryNode ) {
426- const result = super . transformInsertQuery ( node ) ;
427- if ( ! node . returning ) {
428- return result ;
429- }
430- if ( this . onlyReturningId ( node ) ) {
431- return result ;
432- } else {
433- // only return ID fields, that's enough for reading back the inserted row
434- const idFields = getIdFields ( this . client . $schema , this . getMutationModel ( node ) ) ;
435- return {
436- ...result ,
437- returning : ReturningNode . create (
438- idFields . map ( ( field ) => SelectionNode . create ( ColumnNode . create ( field ) ) ) ,
439- ) ,
440- } ;
441- }
442- }
443-
444- protected override transformUpdateQuery ( node : UpdateQueryNode ) {
445- const result = super . transformUpdateQuery ( node ) ;
446- const mutationModel = this . getMutationModel ( node ) ;
447- const filter = this . buildPolicyFilter ( mutationModel , undefined , 'update' ) ;
448- return {
449- ...result ,
450- where : WhereNode . create ( result . where ? conjunction ( this . dialect , [ result . where . where , filter ] ) : filter ) ,
451- } ;
452- }
453-
454- protected override transformDeleteQuery ( node : DeleteQueryNode ) {
455- const result = super . transformDeleteQuery ( node ) ;
456- const mutationModel = this . getMutationModel ( node ) ;
457- const filter = this . buildPolicyFilter ( mutationModel , undefined , 'delete' ) ;
458- return {
459- ...result ,
460- where : WhereNode . create ( result . where ? conjunction ( this . dialect , [ result . where . where , filter ] ) : filter ) ,
461- } ;
462- }
463-
464496 private extractTableName ( from : OperationNode ) : { model : GetModels < Schema > ; alias ?: string } | undefined {
465497 if ( TableNode . is ( from ) ) {
466498 return { model : from . table . identifier . name as GetModels < Schema > } ;
@@ -528,4 +560,6 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
528560 }
529561 return result ;
530562 }
563+
564+ // #endregion
531565}
0 commit comments