Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 32 additions & 12 deletions packages/runtime/src/client/crud/dialects/postgresql.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import {
buildJoinPairs,
getIdFields,
getManyToManyRelation,
isRelationField,
requireField,
requireModel,
} from '../../query-utils';
Expand Down Expand Up @@ -216,10 +217,15 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale
objArgs.push(
...Object.entries(payload.select)
.filter(([, value]) => value)
.map(([field]) => [
sql.lit(field),
buildFieldRef(this.schema, relationModel, field, this.options, eb),
])
.map(([field]) => {
const fieldDef = requireField(this.schema, relationModel, field);
const fieldValue = fieldDef.relation
? // reference the synthesized JSON field
eb.ref(`${parentName}$${relationField}$${field}.$j`)
: // reference a plain field
buildFieldRef(this.schema, relationModel, field, this.options, eb);
return [sql.lit(field), fieldValue];
})
.flatMap((v) => v),
);
}
Expand All @@ -229,27 +235,41 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale
objArgs.push(
...Object.entries<any>(payload.include)
.filter(([, value]) => value)
.map(([field]) => [sql.lit(field), eb.ref(`${parentName}$${relationField}$${field}.$j`)])
.map(([field]) => [
sql.lit(field),
// reference the synthesized JSON field
eb.ref(`${parentName}$${relationField}$${field}.$j`),
])
.flatMap((v) => v),
);
}
return objArgs;
}

private buildRelationJoins(
model: string,
relationModel: string,
relationField: string,
qb: SelectQueryBuilder<any, any, any>,
payload: true | FindArgs<Schema, GetModels<Schema>, true>,
parentName: string,
) {
let result = qb;
if (typeof payload === 'object' && payload.include && typeof payload.include === 'object') {
Object.entries<any>(payload.include)
.filter(([, value]) => value)
.forEach(([field, value]) => {
result = this.buildRelationJSON(model, result, field, `${parentName}$${relationField}`, value);
});
if (typeof payload === 'object') {
const selectInclude = payload.include ?? payload.select;
if (selectInclude && typeof selectInclude === 'object') {
Object.entries<any>(selectInclude)
.filter(([, value]) => value)
.filter(([field]) => isRelationField(this.schema, relationModel, field))
.forEach(([field, value]) => {
result = this.buildRelationJSON(
relationModel,
result,
field,
`${parentName}$${relationField}`,
value,
);
});
}
}
return result;
}
Expand Down
16 changes: 10 additions & 6 deletions packages/runtime/src/client/crud/operations/aggregate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@ import { BaseOperationHandler } from './base';

export class AggregateOperationHandler<Schema extends SchemaDef> extends BaseOperationHandler<Schema> {
async handle(_operation: 'aggregate', args: unknown | undefined) {
const validatedArgs = this.inputValidator.validateAggregateArgs(this.model, args);
// normalize args to strip `undefined` fields
const normalizeArgs = this.normalizeArgs(args);

// parse args
const parsedArgs = this.inputValidator.validateAggregateArgs(this.model, normalizeArgs);

let query = this.kysely.selectFrom((eb) => {
// nested query for filtering and pagination
Expand All @@ -15,11 +19,11 @@ export class AggregateOperationHandler<Schema extends SchemaDef> extends BaseOpe
let subQuery = eb
.selectFrom(this.model)
.selectAll(this.model as any) // TODO: check typing
.where((eb1) => this.dialect.buildFilter(eb1, this.model, this.model, validatedArgs?.where));
.where((eb1) => this.dialect.buildFilter(eb1, this.model, this.model, parsedArgs?.where));

// skip & take
const skip = validatedArgs?.skip;
let take = validatedArgs?.take;
const skip = parsedArgs?.skip;
let take = parsedArgs?.take;
let negateOrderBy = false;
if (take !== undefined && take < 0) {
negateOrderBy = true;
Expand All @@ -32,7 +36,7 @@ export class AggregateOperationHandler<Schema extends SchemaDef> extends BaseOpe
subQuery,
this.model,
this.model,
validatedArgs.orderBy,
parsedArgs.orderBy,
skip !== undefined || take !== undefined,
negateOrderBy,
);
Expand All @@ -41,7 +45,7 @@ export class AggregateOperationHandler<Schema extends SchemaDef> extends BaseOpe
});

// aggregations
for (const [key, value] of Object.entries(validatedArgs)) {
for (const [key, value] of Object.entries(parsedArgs)) {
switch (key) {
case '_count': {
if (value === true) {
Expand Down
118 changes: 76 additions & 42 deletions packages/runtime/src/client/crud/operations/base.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import { createId } from '@paralleldrive/cuid2';
import { invariant } from '@zenstackhq/common-helpers';
import { invariant, isPlainObject } from '@zenstackhq/common-helpers';
import {
DeleteResult,
expressionBuilder,
ExpressionWrapper,
sql,
UpdateResult,
type ExpressionBuilder,
type Expression as KyselyExpression,
type SelectQueryBuilder,
} from 'kysely';
Expand Down Expand Up @@ -292,45 +291,36 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
for (const [field, value] of Object.entries(selections.select)) {
const fieldDef = requireField(this.schema, model, field);
const fieldModel = fieldDef.type;
const jointTable = `${parentAlias}$${field}$count`;
const joinPairs = buildJoinPairs(this.schema, model, parentAlias, field, jointTable);

query = query.leftJoin(
(eb) => {
let result = eb.selectFrom(fieldModel).selectAll();
if (
value &&
typeof value === 'object' &&
'where' in value &&
value.where &&
typeof value.where === 'object'
) {
const filter = this.dialect.buildFilter(eb, fieldModel, fieldModel, value.where);
result = result.where(filter);
}
return result.as(jointTable);
},
(join) => {
for (const [left, right] of joinPairs) {
join = join.onRef(left, '=', right);
}
return join;
},
);
const joinPairs = buildJoinPairs(this.schema, model, parentAlias, field, fieldModel);

// build a nested query to count the number of records in the relation
let fieldCountQuery = eb.selectFrom(fieldModel).select(eb.fn.countAll().as(`_count$${field}`));

// join conditions
for (const [left, right] of joinPairs) {
fieldCountQuery = fieldCountQuery.whereRef(left, '=', right);
}

// merge _count filter
if (
value &&
typeof value === 'object' &&
'where' in value &&
value.where &&
typeof value.where === 'object'
) {
const filter = this.dialect.buildFilter(eb, fieldModel, fieldModel, value.where);
fieldCountQuery = fieldCountQuery.where(filter);
}

jsonObject[field] = this.countIdDistinct(eb, fieldDef.type, jointTable);
jsonObject[field] = fieldCountQuery;
}

query = query.select((eb) => this.dialect.buildJsonObject(eb, jsonObject).as('_count'));

return query;
}

private countIdDistinct(eb: ExpressionBuilder<any, any>, model: string, table: string) {
const idFields = getIdFields(this.schema, model);
return eb.fn.count(sql.join(idFields.map((f) => sql.ref(`${table}.${f}`)))).distinct();
}

private buildSelectAllScalarFields(
model: string,
query: SelectQueryBuilder<any, any, any>,
Expand Down Expand Up @@ -479,7 +469,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
} else {
const subM2M = getManyToManyRelation(this.schema, model, field);
if (!subM2M && fieldDef.relation?.fields && fieldDef.relation?.references) {
const fkValues = await this.processOwnedRelation(kysely, fieldDef, value);
const fkValues = await this.processOwnedRelationForCreate(kysely, fieldDef, value);
for (let i = 0; i < fieldDef.relation.fields.length; i++) {
createFields[fieldDef.relation.fields[i]!] = fkValues[fieldDef.relation.references[i]!];
}
Expand Down Expand Up @@ -519,7 +509,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
if (Object.keys(postCreateRelations).length > 0) {
// process nested creates that need to happen after the current entity is created
const relationPromises = Object.entries(postCreateRelations).map(([field, subPayload]) => {
return this.processNoneOwnedRelation(kysely, model, field, subPayload, createdEntity);
return this.processNoneOwnedRelationForCreate(kysely, model, field, subPayload, createdEntity);
});

// await relation creation
Expand Down Expand Up @@ -633,7 +623,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
.execute();
}

private async processOwnedRelation(kysely: ToKysely<Schema>, relationField: FieldDef, payload: any) {
private async processOwnedRelationForCreate(kysely: ToKysely<Schema>, relationField: FieldDef, payload: any) {
if (!payload) {
return;
}
Expand Down Expand Up @@ -696,7 +686,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
return result;
}

private processNoneOwnedRelation(
private processNoneOwnedRelationForCreate(
kysely: ToKysely<Schema>,
contextModel: GetModels<Schema>,
relationFieldName: string,
Expand All @@ -706,6 +696,11 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
const relationFieldDef = this.requireField(contextModel, relationFieldName);
const relationModel = relationFieldDef.type as GetModels<Schema>;
const tasks: Promise<unknown>[] = [];
const fromRelationContext = {
model: contextModel,
field: relationFieldName,
ids: parentEntity,
};

for (const [action, subPayload] of Object.entries<any>(payload)) {
if (!subPayload) {
Expand All @@ -716,11 +711,21 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
// create with a parent entity
tasks.push(
...enumerate(subPayload).map((item) =>
this.create(kysely, relationModel, item, {
model: contextModel,
field: relationFieldName,
ids: parentEntity,
}),
this.create(kysely, relationModel, item, fromRelationContext),
),
);
break;
}

case 'createMany': {
invariant(relationFieldDef.array, 'relation must be an array for createMany');
tasks.push(
this.createMany(
kysely,
relationModel,
subPayload as { data: any; skipDuplicates: boolean },
false,
fromRelationContext,
),
);
break;
Expand Down Expand Up @@ -776,6 +781,11 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
returnData: ReturnData,
fromRelation?: FromRelationContext<Schema>,
): Promise<Result> {
if (!input.data || (Array.isArray(input.data) && input.data.length === 0)) {
// nothing todo
return returnData ? ([] as Result) : ({ count: 0 } as Result);
}

const modelDef = this.requireModel(model);

let relationKeyPairs: { fk: string; pk: string }[] = [];
Expand Down Expand Up @@ -1916,4 +1926,28 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
where: uniqueFilter,
});
}

/**
* Normalize input args to strip `undefined` fields
*/
protected normalizeArgs(args: unknown) {
if (!args) {
return;
}
const newArgs = clone(args);
this.doNormalizeArgs(newArgs);
return newArgs;
}

private doNormalizeArgs(args: unknown) {
if (args && typeof args === 'object') {
for (const [key, value] of Object.entries(args)) {
if (value === undefined) {
delete args[key as keyof typeof args];
} else if (value && isPlainObject(value)) {
this.doNormalizeArgs(value);
}
}
}
}
}
14 changes: 9 additions & 5 deletions packages/runtime/src/client/crud/operations/count.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,26 @@ import { BaseOperationHandler } from './base';

export class CountOperationHandler<Schema extends SchemaDef> extends BaseOperationHandler<Schema> {
async handle(_operation: 'count', args: unknown | undefined) {
const validatedArgs = this.inputValidator.validateCountArgs(this.model, args);
// normalize args to strip `undefined` fields
const normalizeArgs = this.normalizeArgs(args);

// parse args
const parsedArgs = this.inputValidator.validateCountArgs(this.model, normalizeArgs);

let query = this.kysely.selectFrom((eb) => {
// nested query for filtering and pagination
let subQuery = eb
.selectFrom(this.model)
.selectAll()
.where((eb1) => this.dialect.buildFilter(eb1, this.model, this.model, validatedArgs?.where));
subQuery = this.dialect.buildSkipTake(subQuery, validatedArgs?.skip, validatedArgs?.take);
.where((eb1) => this.dialect.buildFilter(eb1, this.model, this.model, parsedArgs?.where));
subQuery = this.dialect.buildSkipTake(subQuery, parsedArgs?.skip, parsedArgs?.take);
return subQuery.as('$sub');
});

if (validatedArgs?.select && typeof validatedArgs.select === 'object') {
if (parsedArgs?.select && typeof parsedArgs.select === 'object') {
// count with field selection
query = query.select((eb) =>
Object.keys(validatedArgs.select!).map((key) =>
Object.keys(parsedArgs.select!).map((key) =>
key === '_all'
? eb.cast(eb.fn.countAll(), 'integer').as('_all')
: eb.cast(eb.fn.count(sql.ref(`$sub.${key}`)), 'integer').as(key),
Expand Down
9 changes: 6 additions & 3 deletions packages/runtime/src/client/crud/operations/create.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,17 @@ import { BaseOperationHandler } from './base';

export class CreateOperationHandler<Schema extends SchemaDef> extends BaseOperationHandler<Schema> {
async handle(operation: 'create' | 'createMany' | 'createManyAndReturn', args: unknown | undefined) {
// normalize args to strip `undefined` fields
const normalizeArgs = this.normalizeArgs(args);

return match(operation)
.with('create', () => this.runCreate(this.inputValidator.validateCreateArgs(this.model, args)))
.with('create', () => this.runCreate(this.inputValidator.validateCreateArgs(this.model, normalizeArgs)))
.with('createMany', () => {
return this.runCreateMany(this.inputValidator.validateCreateManyArgs(this.model, args));
return this.runCreateMany(this.inputValidator.validateCreateManyArgs(this.model, normalizeArgs));
})
.with('createManyAndReturn', () => {
return this.runCreateManyAndReturn(
this.inputValidator.validateCreateManyAndReturnArgs(this.model, args),
this.inputValidator.validateCreateManyAndReturnArgs(this.model, normalizeArgs),
);
})
.exhaustive();
Expand Down
9 changes: 7 additions & 2 deletions packages/runtime/src/client/crud/operations/delete.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,14 @@ import { BaseOperationHandler } from './base';

export class DeleteOperationHandler<Schema extends SchemaDef> extends BaseOperationHandler<Schema> {
async handle(operation: 'delete' | 'deleteMany', args: unknown | undefined) {
// normalize args to strip `undefined` fields
const normalizeArgs = this.normalizeArgs(args);

return match(operation)
.with('delete', () => this.runDelete(this.inputValidator.validateDeleteArgs(this.model, args)))
.with('deleteMany', () => this.runDeleteMany(this.inputValidator.validateDeleteManyArgs(this.model, args)))
.with('delete', () => this.runDelete(this.inputValidator.validateDeleteArgs(this.model, normalizeArgs)))
.with('deleteMany', () =>
this.runDeleteMany(this.inputValidator.validateDeleteManyArgs(this.model, normalizeArgs)),
)
.exhaustive();
}

Expand Down
Loading
Loading