Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 55 additions & 1 deletion packages/runtime/src/client/crud/dialects/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { invariant, isPlainObject } from '@zenstackhq/common-helpers';
import type { Expression, ExpressionBuilder, ExpressionWrapper, SqlBool, ValueNode } from 'kysely';
import { expressionBuilder, sql, type SelectQueryBuilder } from 'kysely';
import { match, P } from 'ts-pattern';
import type { BuiltinType, DataSourceProviderType, FieldDef, GetModels, SchemaDef } from '../../../schema';
import type { BuiltinType, DataSourceProviderType, FieldDef, GetModels, ModelDef, SchemaDef } from '../../../schema';
import { enumerate } from '../../../utils/enumerate';
import type { OrArray } from '../../../utils/type-utils';
import { AGGREGATE_OPERATORS, DELEGATE_JOINED_FIELD_PREFIX, LOGICAL_COMBINATORS } from '../../constants';
Expand Down Expand Up @@ -963,6 +963,31 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
return result;
}

protected buildModelSelect(
eb: ExpressionBuilder<any, any>,
model: GetModels<Schema>,
subQueryAlias: string,
payload: true | FindArgs<Schema, GetModels<Schema>, true>,
selectAllFields: boolean,
) {
let subQuery = this.buildSelectModel(eb, model, subQueryAlias);

if (selectAllFields) {
subQuery = this.buildSelectAllFields(
model,
subQuery,
typeof payload === 'object' ? payload?.omit : undefined,
subQueryAlias,
);
}

if (payload && typeof payload === 'object') {
subQuery = this.buildFilterSortTake(model, payload, subQuery, subQueryAlias);
}

return subQuery;
}

buildSelectField(
query: SelectQueryBuilder<any, any, any>,
model: string,
Expand Down Expand Up @@ -1115,6 +1140,35 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
return buildFieldRef(this.schema, model, field, this.options, eb, modelAlias, inlineComputedField);
}

protected canJoinWithoutNestedSelect(
modelDef: ModelDef,
payload: boolean | FindArgs<Schema, GetModels<Schema>, true>,
) {
if (modelDef.computedFields) {
// computed fields requires explicit select
return false;
}

if (modelDef.baseModel || modelDef.isDelegate) {
// delete models require upward/downward joins
return false;
}

if (
typeof payload === 'object' &&
(payload.orderBy ||
payload.skip !== undefined ||
payload.take !== undefined ||
payload.cursor ||
(payload as any).distinct)
) {
// ordering/pagination/distinct needs to be handled before joining
return false;
}

return true;
}

// #endregion

// #region abstract methods
Expand Down
178 changes: 101 additions & 77 deletions packages/runtime/src/client/crud/dialects/postgresql.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,127 +58,151 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale
parentAlias: string,
payload: true | FindArgs<Schema, GetModels<Schema>, true>,
): SelectQueryBuilder<any, any, any> {
const joinedQuery = this.buildRelationJSON(model, query, relationField, parentAlias, payload);

return joinedQuery.select(`${parentAlias}$${relationField}.$t as ${relationField}`);
const relationResultName = `${parentAlias}$${relationField}`;
const joinedQuery = this.buildRelationJSON(
model,
query,
relationField,
parentAlias,
payload,
relationResultName,
);
return joinedQuery.select(`${relationResultName}.$data as ${relationField}`);
}

private buildRelationJSON(
model: string,
qb: SelectQueryBuilder<any, any, any>,
relationField: string,
parentName: string,
parentAlias: string,
payload: true | FindArgs<Schema, GetModels<Schema>, true>,
resultName: string,
) {
const relationFieldDef = requireField(this.schema, model, relationField);
const relationModel = relationFieldDef.type as GetModels<Schema>;

return qb.leftJoinLateral(
(eb) => {
const joinTableName = `${parentName}$${relationField}`;

// simple select by default
let result = eb.selectFrom(`${relationModel} as ${joinTableName}`);
const relationSelectName = `${resultName}$sub`;
const relationModelDef = requireModel(this.schema, relationModel);

// however if there're filter/orderBy/take/skip,
// we need to build a subquery to handle them before aggregation
let tbl: SelectQueryBuilder<any, any, any>;

// give sub query an alias to avoid conflict with parent scope
// (e.g., for cases like self-relation)
const subQueryAlias = `${relationModel}$${relationField}$sub`;
if (this.canJoinWithoutNestedSelect(relationModelDef, payload)) {
// build join directly
tbl = this.buildModelSelect(eb, relationModel, relationSelectName, payload, false);

result = eb.selectFrom(() => {
let subQuery = this.buildSelectModel(eb, relationModel, subQueryAlias);
subQuery = this.buildSelectAllFields(
// parent join filter
tbl = this.buildRelationJoinFilter(
tbl,
model,
relationField,
relationModel,
subQuery,
typeof payload === 'object' ? payload?.omit : undefined,
subQueryAlias,
relationSelectName,
parentAlias,
);

if (payload && typeof payload === 'object') {
subQuery = this.buildFilterSortTake(relationModel, payload, subQuery, subQueryAlias);
}

// add join conditions

const m2m = getManyToManyRelation(this.schema, model, relationField);

if (m2m) {
// many-to-many relation
const parentIds = getIdFields(this.schema, model);
const relationIds = getIdFields(this.schema, relationModel);
invariant(parentIds.length === 1, 'many-to-many relation must have exactly one id field');
invariant(relationIds.length === 1, 'many-to-many relation must have exactly one id field');
subQuery = subQuery.where(
eb(
eb.ref(`${subQueryAlias}.${relationIds[0]}`),
'in',
eb
.selectFrom(m2m.joinTable)
.select(`${m2m.joinTable}.${m2m.otherFkName}`)
.whereRef(
`${parentName}.${parentIds[0]}`,
'=',
`${m2m.joinTable}.${m2m.parentFkName}`,
),
),
} else {
// join with a nested query
tbl = eb.selectFrom(() => {
let subQuery = this.buildModelSelect(
eb,
relationModel,
`${relationSelectName}$t`,
payload,
true,
);
} else {
const joinPairs = buildJoinPairs(this.schema, model, parentName, relationField, subQueryAlias);
subQuery = subQuery.where((eb) =>
this.and(eb, ...joinPairs.map(([left, right]) => eb(sql.ref(left), '=', sql.ref(right)))),

// parent join filter
subQuery = this.buildRelationJoinFilter(
subQuery,
model,
relationField,
relationModel,
`${relationSelectName}$t`,
parentAlias,
);
}

return subQuery.as(joinTableName);
});
return subQuery.as(relationSelectName);
});
}

result = this.buildRelationObjectSelect(
// select relation result
tbl = this.buildRelationObjectSelect(
relationModel,
joinTableName,
relationField,
relationSelectName,
relationFieldDef,
result,
tbl,
payload,
parentName,
resultName,
);

// add nested joins for each relation
result = this.buildRelationJoins(relationModel, relationField, result, payload, parentName);
tbl = this.buildRelationJoins(tbl, relationModel, relationSelectName, payload, resultName);

// alias the join table
return result.as(joinTableName);
return tbl.as(resultName);
},
(join) => join.onTrue(),
);
}

private buildRelationJoinFilter(
query: SelectQueryBuilder<any, any, {}>,
model: string,
relationField: string,
relationModel: GetModels<Schema>,
relationModelAlias: string,
parentAlias: string,
) {
const m2m = getManyToManyRelation(this.schema, model, relationField);
if (m2m) {
// many-to-many relation
const parentIds = getIdFields(this.schema, model);
const relationIds = getIdFields(this.schema, relationModel);
invariant(parentIds.length === 1, 'many-to-many relation must have exactly one id field');
invariant(relationIds.length === 1, 'many-to-many relation must have exactly one id field');
query = query.where((eb) =>
eb(
eb.ref(`${relationModelAlias}.${relationIds[0]}`),
'in',
eb
.selectFrom(m2m.joinTable)
.select(`${m2m.joinTable}.${m2m.otherFkName}`)
.whereRef(`${parentAlias}.${parentIds[0]}`, '=', `${m2m.joinTable}.${m2m.parentFkName}`),
),
);
} else {
const joinPairs = buildJoinPairs(this.schema, model, parentAlias, relationField, relationModelAlias);
query = query.where((eb) =>
this.and(eb, ...joinPairs.map(([left, right]) => eb(sql.ref(left), '=', sql.ref(right)))),
);
}
return query;
}

private buildRelationObjectSelect(
relationModel: string,
relationModelAlias: string,
relationField: string,
relationFieldDef: FieldDef,
qb: SelectQueryBuilder<any, any, any>,
payload: true | FindArgs<Schema, GetModels<Schema>, true>,
parentName: string,
parentResultName: string,
) {
qb = qb.select((eb) => {
const objArgs = this.buildRelationObjectArgs(
relationModel,
relationModelAlias,
relationField,
eb,
payload,
parentName,
parentResultName,
);

if (relationFieldDef.array) {
return eb.fn
.coalesce(sql`jsonb_agg(jsonb_build_object(${sql.join(objArgs)}))`, sql`'[]'::jsonb`)
.as('$t');
.as('$data');
} else {
return sql`jsonb_build_object(${sql.join(objArgs)})`.as('$t');
return sql`jsonb_build_object(${sql.join(objArgs)})`.as('$data');
}
});

Expand All @@ -188,10 +212,9 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale
private buildRelationObjectArgs(
relationModel: string,
relationModelAlias: string,
relationField: string,
eb: ExpressionBuilder<any, any>,
payload: true | FindArgs<Schema, GetModels<Schema>, true>,
parentAlias: string,
parentResultName: string,
) {
const relationModelDef = requireModel(this.schema, relationModel);
const objArgs: Array<
Expand Down Expand Up @@ -234,15 +257,15 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale
const subJson = this.buildCountJson(
relationModel as GetModels<Schema>,
eb,
`${parentAlias}$${relationField}`,
relationModelAlias,
value,
);
return [sql.lit(field), subJson];
} else {
const fieldDef = requireField(this.schema, relationModel, field);
const fieldValue = fieldDef.relation
? // reference the synthesized JSON field
eb.ref(`${parentAlias}$${relationField}$${field}.$t`)
eb.ref(`${parentResultName}$${field}.$data`)
: // reference a plain field
this.fieldRef(relationModel, field, eb, undefined, false);
return [sql.lit(field), fieldValue];
Expand All @@ -260,7 +283,7 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale
.map(([field]) => [
sql.lit(field),
// reference the synthesized JSON field
eb.ref(`${parentAlias}$${relationField}$${field}.$t`),
eb.ref(`${parentResultName}$${field}.$data`),
])
.flatMap((v) => v),
);
Expand All @@ -269,13 +292,13 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale
}

private buildRelationJoins(
query: SelectQueryBuilder<any, any, any>,
relationModel: string,
relationField: string,
qb: SelectQueryBuilder<any, any, any>,
relationModelAlias: string,
payload: true | FindArgs<Schema, GetModels<Schema>, true>,
parentName: string,
parentResultName: string,
) {
let result = qb;
let result = query;
if (typeof payload === 'object') {
const selectInclude = payload.include ?? payload.select;
if (selectInclude && typeof selectInclude === 'object') {
Expand All @@ -287,8 +310,9 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale
relationModel,
result,
field,
`${parentName}$${relationField}`,
relationModelAlias,
value,
`${parentResultName}$${field}`,
);
});
}
Expand Down
Loading