11/* eslint-disable @typescript-eslint/no-explicit-any */
22
33import { PrismaClientKnownRequestError , PrismaClientUnknownRequestError } from '@prisma/client/runtime' ;
4- import { AUXILIARY_FIELDS , CrudFailureReason , TRANSACTION_FIELD_NAME } from '@zenstackhq/sdk' ;
4+ import { AUXILIARY_FIELDS , CrudFailureReason , GUARD_FIELD_NAME , TRANSACTION_FIELD_NAME } from '@zenstackhq/sdk' ;
55import { camelCase } from 'change-case' ;
66import cuid from 'cuid' ;
77import deepcopy from 'deepcopy' ;
@@ -42,8 +42,7 @@ export class PolicyUtil {
4242 and ( ...conditions : ( boolean | object ) [ ] ) : any {
4343 if ( conditions . includes ( false ) ) {
4444 // always false
45- // TODO: custom id field
46- return { id : { in : [ ] } } ;
45+ return { [ GUARD_FIELD_NAME ] : false } ;
4746 }
4847
4948 const filtered = conditions . filter (
@@ -64,7 +63,7 @@ export class PolicyUtil {
6463 or ( ...conditions : ( boolean | object ) [ ] ) : any {
6564 if ( conditions . includes ( true ) ) {
6665 // always true
67- return { id : { notIn : [ ] } } ;
66+ return { [ GUARD_FIELD_NAME ] : true } ;
6867 }
6968
7069 const filtered = conditions . filter ( ( c ) : c is object => typeof c === 'object' && ! ! c ) ;
@@ -276,7 +275,7 @@ export class PolicyUtil {
276275 return ;
277276 }
278277
279- const idField = this . getIdField ( model ) ;
278+ const idFields = this . getIdFields ( model ) ;
280279 for ( const field of getModelFields ( injectTarget ) ) {
281280 const fieldInfo = resolveField ( this . modelMeta , model , field ) ;
282281 if ( ! fieldInfo || ! fieldInfo . isDataModel ) {
@@ -292,10 +291,16 @@ export class PolicyUtil {
292291
293292 await this . injectAuthGuard ( injectTarget [ field ] , fieldInfo . type , 'read' ) ;
294293 } else {
295- // there's no way of injecting condition for to-one relation, so we
296- // make sure 'id' field is selected and check them against query result
297- if ( injectTarget [ field ] ?. select && injectTarget [ field ] ?. select ?. [ idField . name ] !== true ) {
298- injectTarget [ field ] . select [ idField . name ] = true ;
294+ // there's no way of injecting condition for to-one relation, so if there's
295+ // "select" clause we make sure 'id' fields are selected and check them against
296+ // query result; nothing needs to be done for "include" clause because all
297+ // fields are already selected
298+ if ( injectTarget [ field ] ?. select ) {
299+ for ( const idField of idFields ) {
300+ if ( injectTarget [ field ] . select [ idField . name ] !== true ) {
301+ injectTarget [ field ] . select [ idField . name ] = true ;
302+ }
303+ }
299304 }
300305 }
301306
@@ -310,7 +315,8 @@ export class PolicyUtil {
310315 * omitted.
311316 */
312317 async postProcessForRead ( entityData : any , model : string , args : any , operation : PolicyOperationKind ) {
313- if ( ! this . getEntityId ( model , entityData ) ) {
318+ const ids = this . getEntityIds ( model , entityData ) ;
319+ if ( Object . keys ( ids ) . length === 0 ) {
314320 return ;
315321 }
316322
@@ -330,21 +336,23 @@ export class PolicyUtil {
330336 // post-check them
331337
332338 for ( const field of getModelFields ( injectTarget ) ) {
339+ if ( ! entityData ?. [ field ] ) {
340+ continue ;
341+ }
342+
333343 const fieldInfo = resolveField ( this . modelMeta , model , field ) ;
334344 if ( ! fieldInfo || ! fieldInfo . isDataModel || fieldInfo . isArray ) {
335345 continue ;
336346 }
337347
338- const idField = this . getIdField ( fieldInfo . type ) ;
339- const relatedEntityId = entityData ?. [ field ] ?. [ idField . name ] ;
348+ const ids = this . getEntityIds ( fieldInfo . type , entityData [ field ] ) ;
340349
341- if ( ! relatedEntityId ) {
350+ if ( Object . keys ( ids ) . length === 0 ) {
342351 continue ;
343352 }
344353
345- this . logger . info ( `Validating read of to-one relation: ${ fieldInfo . type } #${ relatedEntityId } ` ) ;
346-
347- await this . checkPolicyForFilter ( fieldInfo . type , { [ idField . name ] : relatedEntityId } , operation , this . db ) ;
354+ this . logger . info ( `Validating read of to-one relation: ${ fieldInfo . type } #${ formatObject ( ids ) } ` ) ;
355+ await this . checkPolicyForFilter ( fieldInfo . type , ids , operation , this . db ) ;
348356
349357 // recurse
350358 await this . postProcessForRead ( entityData [ field ] , fieldInfo . type , injectTarget [ field ] , operation ) ;
@@ -366,14 +374,18 @@ export class PolicyUtil {
366374
367375 // record model entities that are updated, together with their
368376 // values before update, so we can post-check if they satisfy
369- // model => id => entity value
370- const updatedModels = new Map < string , Map < string , any > > ( ) ;
377+ // model => { ids, entity value }
378+ const updatedModels = new Map < string , Array < { ids : Record < string , unknown > ; value : any } > > ( ) ;
371379
372- const idField = this . getIdField ( model ) ;
373- if ( args . select && ! args . select [ idField . name ] ) {
380+ const idFields = this . getIdFields ( model ) ;
381+ if ( args . select ) {
374382 // make sure 'id' field is selected, we need it to
375383 // read back the updated entity
376- args . select [ idField . name ] = true ;
384+ for ( const idField of idFields ) {
385+ if ( ! args . select [ idField . name ] ) {
386+ args . select [ idField . name ] = true ;
387+ }
388+ }
377389 }
378390
379391 // use a transaction to conduct write, so in case any create or nested create
@@ -496,7 +508,7 @@ export class PolicyUtil {
496508 if ( postGuard !== true || schema ) {
497509 let modelEntities = updatedModels . get ( model ) ;
498510 if ( ! modelEntities ) {
499- modelEntities = new Map < string , any > ( ) ;
511+ modelEntities = [ ] ;
500512 updatedModels . set ( model , modelEntities ) ;
501513 }
502514
@@ -509,11 +521,19 @@ export class PolicyUtil {
509521 // e.g.: { a_b: { a: '1', b: '1' } } => { a: '1', b: '1' }
510522 await this . flattenGeneratedUniqueField ( model , filter ) ;
511523
512- const idField = this . getIdField ( model ) ;
513- const query = { where : filter , select : { ...preValueSelect , [ idField . name ] : true } } ;
524+ const idFields = this . getIdFields ( model ) ;
525+ // eslint-disable-next-line @typescript-eslint/no-unused-vars
526+ const select : any = { ...preValueSelect } ;
527+ for ( const idField of idFields ) {
528+ select [ idField . name ] = true ;
529+ }
530+
531+ const query = { where : filter , select } ;
514532 this . logger . info ( `fetching pre-update entities for ${ model } : ${ formatObject ( query ) } )}` ) ;
515533 const entities = await this . db [ model ] . findMany ( query ) ;
516- entities . forEach ( ( entity ) => modelEntities ?. set ( this . getEntityId ( model , entity ) , entity ) ) ;
534+ entities . forEach ( ( entity ) =>
535+ modelEntities ?. push ( { ids : this . getEntityIds ( model , entity ) , value : entity } )
536+ ) ;
517537 }
518538 } ;
519539
@@ -622,8 +642,8 @@ export class PolicyUtil {
622642 await Promise . all (
623643 [ ...updatedModels . entries ( ) ]
624644 . map ( ( [ model , modelEntities ] ) =>
625- [ ... modelEntities . entries ( ) ] . map ( async ( [ id , preValue ] ) =>
626- this . checkPostUpdate ( model , id , tx , preValue )
645+ modelEntities . map ( async ( { ids , value : preValue } ) =>
646+ this . checkPostUpdate ( model , ids , tx , preValue )
627647 )
628648 )
629649 . flat ( )
@@ -716,14 +736,18 @@ export class PolicyUtil {
716736 }
717737 }
718738
719- private async checkPostUpdate ( model : string , id : any , db : Record < string , DbOperations > , preValue : any ) {
720- this . logger . info ( `Checking post-update policy for ${ model } #${ id } , preValue: ${ formatObject ( preValue ) } ` ) ;
739+ private async checkPostUpdate (
740+ model : string ,
741+ ids : Record < string , unknown > ,
742+ db : Record < string , DbOperations > ,
743+ preValue : any
744+ ) {
745+ this . logger . info ( `Checking post-update policy for ${ model } #${ ids } , preValue: ${ formatObject ( preValue ) } ` ) ;
721746
722747 const guard = await this . getAuthGuard ( model , 'postUpdate' , preValue ) ;
723748
724749 // build a query condition with policy injected
725- const idField = this . getIdField ( model ) ;
726- const guardedQuery = { where : this . and ( { [ idField . name ] : id } , guard ) } ;
750+ const guardedQuery = { where : this . and ( ids , guard ) } ;
727751
728752 // query with policy injected
729753 const entity = await db [ model ] . findFirst ( guardedQuery ) ;
@@ -760,13 +784,13 @@ export class PolicyUtil {
760784 /**
761785 * Gets "id" field for a given model.
762786 */
763- getIdField ( model : string ) {
787+ getIdFields ( model : string ) {
764788 const fields = this . modelMeta . fields [ camelCase ( model ) ] ;
765789 if ( ! fields ) {
766790 throw this . unknownError ( `Unable to load fields for ${ model } ` ) ;
767791 }
768- const result = Object . values ( fields ) . find ( ( f ) => f . isId ) ;
769- if ( ! result ) {
792+ const result = Object . values ( fields ) . filter ( ( f ) => f . isId ) ;
793+ if ( result . length === 0 ) {
770794 throw this . unknownError ( `model ${ model } does not have an id field` ) ;
771795 }
772796 return result ;
@@ -775,8 +799,12 @@ export class PolicyUtil {
775799 /**
776800 * Gets id field value from an entity.
777801 */
778- getEntityId ( model : string , entityData : any ) {
779- const idField = this . getIdField ( model ) ;
780- return entityData [ idField . name ] ;
802+ getEntityIds ( model : string , entityData : any ) {
803+ const idFields = this . getIdFields ( model ) ;
804+ const result : Record < string , unknown > = { } ;
805+ for ( const idField of idFields ) {
806+ result [ idField . name ] = entityData [ idField . name ] ;
807+ }
808+ return result ;
781809 }
782810}
0 commit comments