11import { createId } from '@paralleldrive/cuid2' ;
2- import { invariant } from '@zenstackhq/common-helpers' ;
2+ import { invariant , isPlainObject } from '@zenstackhq/common-helpers' ;
33import {
44 DeleteResult ,
55 expressionBuilder ,
66 ExpressionWrapper ,
77 sql ,
88 UpdateResult ,
9- type ExpressionBuilder ,
109 type Expression as KyselyExpression ,
1110 type SelectQueryBuilder ,
1211} from 'kysely' ;
@@ -292,45 +291,36 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
292291 for ( const [ field , value ] of Object . entries ( selections . select ) ) {
293292 const fieldDef = requireField ( this . schema , model , field ) ;
294293 const fieldModel = fieldDef . type ;
295- const jointTable = `${ parentAlias } $${ field } $count` ;
296- const joinPairs = buildJoinPairs ( this . schema , model , parentAlias , field , jointTable ) ;
297-
298- query = query . leftJoin (
299- ( eb ) => {
300- let result = eb . selectFrom ( fieldModel ) . selectAll ( ) ;
301- if (
302- value &&
303- typeof value === 'object' &&
304- 'where' in value &&
305- value . where &&
306- typeof value . where === 'object'
307- ) {
308- const filter = this . dialect . buildFilter ( eb , fieldModel , fieldModel , value . where ) ;
309- result = result . where ( filter ) ;
310- }
311- return result . as ( jointTable ) ;
312- } ,
313- ( join ) => {
314- for ( const [ left , right ] of joinPairs ) {
315- join = join . onRef ( left , '=' , right ) ;
316- }
317- return join ;
318- } ,
319- ) ;
294+ const joinPairs = buildJoinPairs ( this . schema , model , parentAlias , field , fieldModel ) ;
295+
296+ // build a nested query to count the number of records in the relation
297+ let fieldCountQuery = eb . selectFrom ( fieldModel ) . select ( eb . fn . countAll ( ) . as ( `_count$${ field } ` ) ) ;
298+
299+ // join conditions
300+ for ( const [ left , right ] of joinPairs ) {
301+ fieldCountQuery = fieldCountQuery . whereRef ( left , '=' , right ) ;
302+ }
303+
304+ // merge _count filter
305+ if (
306+ value &&
307+ typeof value === 'object' &&
308+ 'where' in value &&
309+ value . where &&
310+ typeof value . where === 'object'
311+ ) {
312+ const filter = this . dialect . buildFilter ( eb , fieldModel , fieldModel , value . where ) ;
313+ fieldCountQuery = fieldCountQuery . where ( filter ) ;
314+ }
320315
321- jsonObject [ field ] = this . countIdDistinct ( eb , fieldDef . type , jointTable ) ;
316+ jsonObject [ field ] = fieldCountQuery ;
322317 }
323318
324319 query = query . select ( ( eb ) => this . dialect . buildJsonObject ( eb , jsonObject ) . as ( '_count' ) ) ;
325320
326321 return query ;
327322 }
328323
329- private countIdDistinct ( eb : ExpressionBuilder < any , any > , model : string , table : string ) {
330- const idFields = getIdFields ( this . schema , model ) ;
331- return eb . fn . count ( sql . join ( idFields . map ( ( f ) => sql . ref ( `${ table } .${ f } ` ) ) ) ) . distinct ( ) ;
332- }
333-
334324 private buildSelectAllScalarFields (
335325 model : string ,
336326 query : SelectQueryBuilder < any , any , any > ,
@@ -479,7 +469,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
479469 } else {
480470 const subM2M = getManyToManyRelation ( this . schema , model , field ) ;
481471 if ( ! subM2M && fieldDef . relation ?. fields && fieldDef . relation ?. references ) {
482- const fkValues = await this . processOwnedRelation ( kysely , fieldDef , value ) ;
472+ const fkValues = await this . processOwnedRelationForCreate ( kysely , fieldDef , value ) ;
483473 for ( let i = 0 ; i < fieldDef . relation . fields . length ; i ++ ) {
484474 createFields [ fieldDef . relation . fields [ i ] ! ] = fkValues [ fieldDef . relation . references [ i ] ! ] ;
485475 }
@@ -519,7 +509,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
519509 if ( Object . keys ( postCreateRelations ) . length > 0 ) {
520510 // process nested creates that need to happen after the current entity is created
521511 const relationPromises = Object . entries ( postCreateRelations ) . map ( ( [ field , subPayload ] ) => {
522- return this . processNoneOwnedRelation ( kysely , model , field , subPayload , createdEntity ) ;
512+ return this . processNoneOwnedRelationForCreate ( kysely , model , field , subPayload , createdEntity ) ;
523513 } ) ;
524514
525515 // await relation creation
@@ -633,7 +623,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
633623 . execute ( ) ;
634624 }
635625
636- private async processOwnedRelation ( kysely : ToKysely < Schema > , relationField : FieldDef , payload : any ) {
626+ private async processOwnedRelationForCreate ( kysely : ToKysely < Schema > , relationField : FieldDef , payload : any ) {
637627 if ( ! payload ) {
638628 return ;
639629 }
@@ -696,7 +686,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
696686 return result ;
697687 }
698688
699- private processNoneOwnedRelation (
689+ private processNoneOwnedRelationForCreate (
700690 kysely : ToKysely < Schema > ,
701691 contextModel : GetModels < Schema > ,
702692 relationFieldName : string ,
@@ -706,6 +696,11 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
706696 const relationFieldDef = this . requireField ( contextModel , relationFieldName ) ;
707697 const relationModel = relationFieldDef . type as GetModels < Schema > ;
708698 const tasks : Promise < unknown > [ ] = [ ] ;
699+ const fromRelationContext = {
700+ model : contextModel ,
701+ field : relationFieldName ,
702+ ids : parentEntity ,
703+ } ;
709704
710705 for ( const [ action , subPayload ] of Object . entries < any > ( payload ) ) {
711706 if ( ! subPayload ) {
@@ -716,11 +711,21 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
716711 // create with a parent entity
717712 tasks . push (
718713 ...enumerate ( subPayload ) . map ( ( item ) =>
719- this . create ( kysely , relationModel , item , {
720- model : contextModel ,
721- field : relationFieldName ,
722- ids : parentEntity ,
723- } ) ,
714+ this . create ( kysely , relationModel , item , fromRelationContext ) ,
715+ ) ,
716+ ) ;
717+ break ;
718+ }
719+
720+ case 'createMany' : {
721+ invariant ( relationFieldDef . array , 'relation must be an array for createMany' ) ;
722+ tasks . push (
723+ this . createMany (
724+ kysely ,
725+ relationModel ,
726+ subPayload as { data : any ; skipDuplicates : boolean } ,
727+ false ,
728+ fromRelationContext ,
724729 ) ,
725730 ) ;
726731 break ;
@@ -776,6 +781,11 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
776781 returnData : ReturnData ,
777782 fromRelation ?: FromRelationContext < Schema > ,
778783 ) : Promise < Result > {
784+ if ( ! input . data || ( Array . isArray ( input . data ) && input . data . length === 0 ) ) {
785+ // nothing todo
786+ return returnData ? ( [ ] as Result ) : ( { count : 0 } as Result ) ;
787+ }
788+
779789 const modelDef = this . requireModel ( model ) ;
780790
781791 let relationKeyPairs : { fk : string ; pk : string } [ ] = [ ] ;
@@ -1916,4 +1926,28 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
19161926 where : uniqueFilter ,
19171927 } ) ;
19181928 }
1929+
1930+ /**
1931+ * Normalize input args to strip `undefined` fields
1932+ */
1933+ protected normalizeArgs ( args : unknown ) {
1934+ if ( ! args ) {
1935+ return ;
1936+ }
1937+ const newArgs = clone ( args ) ;
1938+ this . doNormalizeArgs ( newArgs ) ;
1939+ return newArgs ;
1940+ }
1941+
1942+ private doNormalizeArgs ( args : unknown ) {
1943+ if ( args && typeof args === 'object' ) {
1944+ for ( const [ key , value ] of Object . entries ( args ) ) {
1945+ if ( value === undefined ) {
1946+ delete args [ key as keyof typeof args ] ;
1947+ } else if ( value && isPlainObject ( value ) ) {
1948+ this . doNormalizeArgs ( value ) ;
1949+ }
1950+ }
1951+ }
1952+ }
19191953}
0 commit comments