@@ -71,74 +71,46 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
7171 }
7272
7373 if ( ! this . isMutationQueryNode ( node ) ) {
74- // transform and proceed read without transaction
74+ // transform and proceed with read directly
7575 return proceed ( this . transformNode ( node ) ) ;
7676 }
7777
78- let mutationRequiresTransaction = false ;
7978 const { mutationModel } = this . getMutationModel ( node ) ;
8079
81- const isManyToManyJoinTable = this . isManyToManyJoinTable ( mutationModel ) ;
80+ if ( InsertQueryNode . is ( node ) ) {
81+ // pre-create policy evaluation happens before execution of the query
82+ const isManyToManyJoinTable = this . isManyToManyJoinTable ( mutationModel ) ;
83+ let needCheckPreCreate = true ;
84+
85+ // many-to-many join table is not a model so can't have policies on it
86+ if ( ! isManyToManyJoinTable ) {
87+ // check constant policies
88+ const constCondition = this . tryGetConstantPolicy ( mutationModel , 'create' ) ;
89+ if ( constCondition === true ) {
90+ needCheckPreCreate = false ;
91+ } else if ( constCondition === false ) {
92+ throw new RejectedByPolicyError ( mutationModel ) ;
93+ }
94+ }
8295
83- if ( InsertQueryNode . is ( node ) && ! isManyToManyJoinTable ) {
84- // reject create if unconditional deny
85- const constCondition = this . tryGetConstantPolicy ( mutationModel , 'create' ) ;
86- if ( constCondition === false ) {
87- throw new RejectedByPolicyError ( mutationModel ) ;
88- } else if ( constCondition === undefined ) {
89- mutationRequiresTransaction = true ;
96+ if ( needCheckPreCreate ) {
97+ await this . enforcePreCreatePolicy ( node , mutationModel , isManyToManyJoinTable , proceed ) ;
9098 }
9199 }
92100
93- if ( ! mutationRequiresTransaction && ! node . returning ) {
94- // transform and proceed mutation without transaction
95- return proceed ( this . transformNode ( node ) ) ;
96- }
101+ // proceed with query
97102
98- if ( InsertQueryNode . is ( node ) ) {
99- await this . enforcePreCreatePolicy ( node , mutationModel , isManyToManyJoinTable , proceed ) ;
100- }
101- const transformedNode = this . transformNode ( node ) ;
102- const result = await proceed ( transformedNode ) ;
103+ const result = await proceed ( this . transformNode ( node ) ) ;
103104
104- if ( ! this . onlyReturningId ( node ) ) {
105+ if ( ! node . returning || this . onlyReturningId ( node ) ) {
106+ return result ;
107+ } else {
105108 const readBackResult = await this . processReadBack ( node , result , proceed ) ;
106109 if ( readBackResult . rows . length !== result . rows . length ) {
107110 throw new RejectedByPolicyError ( mutationModel , 'result is not allowed to be read back' ) ;
108111 }
109112 return readBackResult ;
110- } else {
111- // reading id fields bypasses policy
112- return result ;
113113 }
114-
115- // TODO: run in transaction
116- // let readBackError = false;
117-
118- // transform and post-process in a transaction
119- // const result = await transaction(async (txProceed) => {
120- // if (InsertQueryNode.is(node)) {
121- // await this.enforcePreCreatePolicy(node, txProceed);
122- // }
123- // const transformedNode = this.transformNode(node);
124- // const result = await txProceed(transformedNode);
125-
126- // if (!this.onlyReturningId(node)) {
127- // const readBackResult = await this.processReadBack(node, result, txProceed);
128- // if (readBackResult.rows.length !== result.rows.length) {
129- // readBackError = true;
130- // }
131- // return readBackResult;
132- // } else {
133- // return result;
134- // }
135- // });
136-
137- // if (readBackError) {
138- // throw new RejectedByPolicyError(mutationModel, 'result is not allowed to be read back');
139- // }
140-
141- // return result;
142114 }
143115
144116 // #region overrides
@@ -296,11 +268,81 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
296268 ? this . unwrapCreateValueRows ( node . values , mutationModel , fields , isManyToManyJoinTable )
297269 : [ [ ] ] ;
298270 for ( const values of valueRows ) {
299- await this . enforcePreCreatePolicyForOne (
300- mutationModel ,
301- fields ,
302- values . map ( ( v ) => v . node ) ,
303- proceed ,
271+ if ( isManyToManyJoinTable ) {
272+ await this . enforcePreCreatePolicyForManyToManyJoinTable (
273+ mutationModel ,
274+ fields ,
275+ values . map ( ( v ) => v . node ) ,
276+ proceed ,
277+ ) ;
278+ } else {
279+ await this . enforcePreCreatePolicyForOne (
280+ mutationModel ,
281+ fields ,
282+ values . map ( ( v ) => v . node ) ,
283+ proceed ,
284+ ) ;
285+ }
286+ }
287+ }
288+
289+ private async enforcePreCreatePolicyForManyToManyJoinTable (
290+ tableName : GetModels < Schema > ,
291+ fields : string [ ] ,
292+ values : OperationNode [ ] ,
293+ proceed : ProceedKyselyQueryFunction ,
294+ ) {
295+ const m2m = this . resolveManyToManyJoinTable ( tableName ) ;
296+ invariant ( m2m ) ;
297+
298+ // m2m create requires both sides to be updatable
299+ invariant ( fields . includes ( 'A' ) && fields . includes ( 'B' ) , 'many-to-many join table must have A and B fk fields' ) ;
300+
301+ const aIndex = fields . indexOf ( 'A' ) ;
302+ const aNode = values [ aIndex ] ! ;
303+ const bIndex = fields . indexOf ( 'B' ) ;
304+ const bNode = values [ bIndex ] ! ;
305+ invariant ( ValueNode . is ( aNode ) && ValueNode . is ( bNode ) , 'A and B values must be ValueNode' ) ;
306+
307+ const aValue = aNode . value ;
308+ const bValue = bNode . value ;
309+ invariant ( aValue !== null && aValue !== undefined , 'A value cannot be null or undefined' ) ;
310+ invariant ( bValue !== null && bValue !== undefined , 'B value cannot be null or undefined' ) ;
311+
312+ const eb = expressionBuilder < any , any > ( ) ;
313+
314+ const filterA = this . buildPolicyFilter ( m2m . firstModel as GetModels < Schema > , undefined , 'update' ) ;
315+ const queryA = eb
316+ . selectFrom ( m2m . firstModel )
317+ . where ( eb ( eb . ref ( `${ m2m . firstModel } .${ m2m . firstIdField } ` ) , '=' , aValue ) )
318+ . select ( ( ) => new ExpressionWrapper ( filterA ) . as ( '$t' ) ) ;
319+
320+ const filterB = this . buildPolicyFilter ( m2m . secondModel as GetModels < Schema > , undefined , 'update' ) ;
321+ const queryB = eb
322+ . selectFrom ( m2m . secondModel )
323+ . where ( eb ( eb . ref ( `${ m2m . secondModel } .${ m2m . secondIdField } ` ) , '=' , bValue ) )
324+ . select ( ( ) => new ExpressionWrapper ( filterB ) . as ( '$t' ) ) ;
325+
326+ // select both conditions in one query
327+ const queryNode : SelectQueryNode = {
328+ kind : 'SelectQueryNode' ,
329+ selections : [
330+ SelectionNode . create ( AliasNode . create ( queryA . toOperationNode ( ) , IdentifierNode . create ( '$conditionA' ) ) ) ,
331+ SelectionNode . create ( AliasNode . create ( queryB . toOperationNode ( ) , IdentifierNode . create ( '$conditionB' ) ) ) ,
332+ ] ,
333+ } ;
334+
335+ const result = await proceed ( queryNode ) ;
336+ if ( ! result . rows [ 0 ] ?. $conditionA ) {
337+ throw new RejectedByPolicyError (
338+ m2m . firstModel as GetModels < Schema > ,
339+ `many-to-many relation participant model "${ m2m . firstModel } " not updatable` ,
340+ ) ;
341+ }
342+ if ( ! result . rows [ 0 ] ?. $conditionB ) {
343+ throw new RejectedByPolicyError (
344+ m2m . secondModel as GetModels < Schema > ,
345+ `many-to-many relation participant model "${ m2m . secondModel } " not updatable` ,
304346 ) ;
305347 }
306348 }
@@ -658,77 +700,100 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
658700 return result ;
659701 }
660702
661- private isManyToManyJoinTable ( tableName : string ) {
662- return Object . values ( this . client . $schema . models ) . some ( ( modelDef ) => {
663- return Object . values ( modelDef . fields ) . some ( ( field ) => {
664- const m2m = getManyToManyRelation ( this . client . $schema , modelDef . name , field . name ) ;
665- return m2m ?. joinTable === tableName ;
666- } ) ;
667- } ) ;
668- }
669-
670- private getModelPolicyFilterForManyToManyJoinTable (
671- tableName : string ,
672- alias : string | undefined ,
673- operation : PolicyOperation ,
674- ) : OperationNode | undefined {
675- // find the m2m relation for this join table
703+ private resolveManyToManyJoinTable ( tableName : string ) {
676704 for ( const model of Object . values ( this . client . $schema . models ) ) {
677705 for ( const field of Object . values ( model . fields ) ) {
678706 const m2m = getManyToManyRelation ( this . client . $schema , model . name , field . name ) ;
679- if ( m2m ?. joinTable !== tableName ) {
680- continue ;
681- }
682-
683- // determine A/B side
684- const sortedRecords = [
685- {
686- model : model . name ,
687- field : field . name ,
688- } ,
689- {
690- model : m2m . otherModel ,
691- field : m2m . otherField ,
692- } ,
693- ] . sort ( ( a , b ) =>
694- // the implicit m2m join table's "A", "B" fk fields' order is determined
695- // by model name's sort order, and when identical (for self-relations),
696- // field name's sort order
697- a . model !== b . model ? a . model . localeCompare ( b . model ) : a . field . localeCompare ( b . field ) ,
698- ) ;
699-
700- // join table's permission:
701- // - read: requires both sides to be readable
702- // - mutation: requires both sides to be updatable
703-
704- const queries : SelectQueryBuilder < any , any , any > [ ] = [ ] ;
705- const eb = expressionBuilder < any , any > ( ) ;
706-
707- for ( const [ fk , entry ] of zip ( [ 'A' , 'B' ] , sortedRecords ) ) {
708- const idFields = requireIdFields ( this . client . $schema , entry . model ) ;
707+ if ( m2m ?. joinTable === tableName ) {
708+ const sortedRecord = [
709+ {
710+ model : model . name ,
711+ field : field . name ,
712+ } ,
713+ {
714+ model : m2m . otherModel ,
715+ field : m2m . otherField ,
716+ } ,
717+ ] . sort ( this . manyToManySorter ) ;
718+
719+ const firstIdFields = requireIdFields ( this . client . $schema , sortedRecord [ 0 ] ! . model ) ;
720+ const secondIdFields = requireIdFields ( this . client . $schema , sortedRecord [ 1 ] ! . model ) ;
709721 invariant (
710- idFields . length === 1 ,
722+ firstIdFields . length === 1 && secondIdFields . length === 1 ,
711723 'only single-field id is supported for implicit many-to-many join table' ,
712724 ) ;
713725
714- const policyFilter = this . buildPolicyFilter (
715- entry . model as GetModels < Schema > ,
716- undefined ,
717- operation === 'read' ? 'read' : 'update' ,
718- ) ;
719- const query = eb
720- . selectFrom ( entry . model )
721- . whereRef ( `${ entry . model } .${ idFields [ 0 ] } ` , '=' , `${ alias ?? tableName } .${ fk } ` )
722- . select ( new ExpressionWrapper ( policyFilter ) . as ( `$condition${ fk } ` ) ) ;
723- queries . push ( query ) ;
726+ return {
727+ firstModel : sortedRecord [ 0 ] ! . model ,
728+ firstField : sortedRecord [ 0 ] ! . field ,
729+ firstIdField : firstIdFields [ 0 ] ! ,
730+ secondModel : sortedRecord [ 1 ] ! . model ,
731+ secondField : sortedRecord [ 1 ] ! . field ,
732+ secondIdField : secondIdFields [ 0 ] ! ,
733+ } ;
724734 }
725-
726- return eb . and ( queries ) . toOperationNode ( ) ;
727735 }
728736 }
729-
730737 return undefined ;
731738 }
732739
740+ private manyToManySorter ( a : { model : string ; field : string } , b : { model : string ; field : string } ) : number {
741+ // the implicit m2m join table's "A", "B" fk fields' order is determined
742+ // by model name's sort order, and when identical (for self-relations),
743+ // field name's sort order
744+ return a . model !== b . model ? a . model . localeCompare ( b . model ) : a . field . localeCompare ( b . field ) ;
745+ }
746+
747+ private isManyToManyJoinTable ( tableName : string ) {
748+ return ! ! this . resolveManyToManyJoinTable ( tableName ) ;
749+ }
750+
751+ private getModelPolicyFilterForManyToManyJoinTable (
752+ tableName : string ,
753+ alias : string | undefined ,
754+ operation : PolicyOperation ,
755+ ) : OperationNode | undefined {
756+ const m2m = this . resolveManyToManyJoinTable ( tableName ) ;
757+ if ( ! m2m ) {
758+ return undefined ;
759+ }
760+
761+ const sortedRecords = [
762+ {
763+ model : m2m . firstModel ,
764+ field : m2m . firstField ,
765+ } ,
766+ {
767+ model : m2m . secondModel ,
768+ field : m2m . secondField ,
769+ } ,
770+ ] ;
771+
772+ // join table's permission:
773+ // - read: requires both sides to be readable
774+ // - mutation: requires both sides to be updatable
775+
776+ const queries : SelectQueryBuilder < any , any , any > [ ] = [ ] ;
777+ const eb = expressionBuilder < any , any > ( ) ;
778+
779+ for ( const [ fk , entry ] of zip ( [ 'A' , 'B' ] , sortedRecords ) ) {
780+ const idFields = requireIdFields ( this . client . $schema , entry . model ) ;
781+ invariant ( idFields . length === 1 , 'only single-field id is supported for implicit many-to-many join table' ) ;
782+
783+ const policyFilter = this . buildPolicyFilter (
784+ entry . model as GetModels < Schema > ,
785+ undefined ,
786+ operation === 'read' ? 'read' : 'update' ,
787+ ) ;
788+ const query = eb
789+ . selectFrom ( entry . model )
790+ . whereRef ( `${ entry . model } .${ idFields [ 0 ] } ` , '=' , `${ alias ?? tableName } .${ fk } ` )
791+ . select ( new ExpressionWrapper ( policyFilter ) . as ( `$condition${ fk } ` ) ) ;
792+ queries . push ( query ) ;
793+ }
794+
795+ return eb . and ( queries ) . toOperationNode ( ) ;
796+ }
797+
733798 // #endregion
734799}
0 commit comments