Skip to content

Commit cd6ccb4

Browse files
authored
fix: issues and tests about self relations (#189)
1 parent 286b14f commit cd6ccb4

File tree

17 files changed

+1696
-825
lines changed

17 files changed

+1696
-825
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ ZenStack v3 allows you to define database-evaluated computed fields with the fol
249249
postCount: (eb) =>
250250
eb
251251
.selectFrom('Post')
252-
.whereRef('Post.authorId', '=', 'User.id')
252+
.whereRef('Post.authorId', '=', 'id')
253253
.select(({ fn }) =>
254254
fn.countAll<number>().as('postCount')
255255
),

packages/runtime/src/client/crud/dialects/base.ts

Lines changed: 52 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,13 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
4747

4848
// #region common query builders
4949

50-
buildSelectModel(eb: ExpressionBuilder<any, any>, model: string) {
50+
buildSelectModel(eb: ExpressionBuilder<any, any>, model: string, modelAlias: string) {
5151
const modelDef = requireModel(this.schema, model);
52-
let result = eb.selectFrom(model);
52+
let result = eb.selectFrom(model === modelAlias ? model : `${model} as ${modelAlias}`);
5353
// join all delegate bases
5454
let joinBase = modelDef.baseModel;
5555
while (joinBase) {
56-
result = this.buildDelegateJoin(model, joinBase, result);
56+
result = this.buildDelegateJoin(model, modelAlias, joinBase, result);
5757
joinBase = requireModel(this.schema, joinBase).baseModel;
5858
}
5959
return result;
@@ -63,12 +63,13 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
6363
model: GetModels<Schema>,
6464
args: FindArgs<Schema, GetModels<Schema>, true>,
6565
query: SelectQueryBuilder<any, any, {}>,
66+
modelAlias: string,
6667
) {
6768
let result = query;
6869

6970
// where
7071
if (args.where) {
71-
result = result.where((eb) => this.buildFilter(eb, model, model, args?.where));
72+
result = result.where((eb) => this.buildFilter(eb, model, modelAlias, args?.where));
7273
}
7374

7475
// skip && take
@@ -85,7 +86,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
8586
result = this.buildOrderBy(
8687
result,
8788
model,
88-
model,
89+
modelAlias,
8990
args.orderBy,
9091
skip !== undefined || take !== undefined,
9192
negateOrderBy,
@@ -95,14 +96,14 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
9596
if ('distinct' in args && (args as any).distinct) {
9697
const distinct = ensureArray((args as any).distinct) as string[];
9798
if (this.supportsDistinctOn) {
98-
result = result.distinctOn(distinct.map((f) => sql.ref(`${model}.${f}`)));
99+
result = result.distinctOn(distinct.map((f) => sql.ref(`${modelAlias}.${f}`)));
99100
} else {
100101
throw new QueryError(`"distinct" is not supported by "${this.schema.provider.type}" provider`);
101102
}
102103
}
103104

104105
if (args.cursor) {
105-
result = this.buildCursorFilter(model, result, args.cursor, args.orderBy, negateOrderBy);
106+
result = this.buildCursorFilter(model, result, args.cursor, args.orderBy, negateOrderBy, modelAlias);
106107
}
107108
return result;
108109
}
@@ -172,13 +173,15 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
172173
cursor: FindArgs<Schema, GetModels<Schema>, true>['cursor'],
173174
orderBy: FindArgs<Schema, GetModels<Schema>, true>['orderBy'],
174175
negateOrderBy: boolean,
176+
modelAlias: string,
175177
) {
176178
const _orderBy = orderBy ?? makeDefaultOrderBy(this.schema, model);
177179

178180
const orderByItems = ensureArray(_orderBy).flatMap((obj) => Object.entries<SortOrder>(obj));
179181

180182
const eb = expressionBuilder<any, any>();
181-
const cursorFilter = this.buildFilter(eb, model, model, cursor);
183+
const subQueryAlias = `${model}$cursor$sub`;
184+
const cursorFilter = this.buildFilter(eb, model, subQueryAlias, cursor);
182185

183186
let result = query;
184187
const filters: ExpressionWrapper<any, any, any>[] = [];
@@ -192,9 +195,11 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
192195
const op = j === i ? (_order === 'asc' ? '>=' : '<=') : '=';
193196
andFilters.push(
194197
eb(
195-
eb.ref(`${model}.${field}`),
198+
eb.ref(`${modelAlias}.${field}`),
196199
op,
197-
eb.selectFrom(model).select(`${model}.${field}`).where(cursorFilter),
200+
this.buildSelectModel(eb, model, subQueryAlias)
201+
.select(`${subQueryAlias}.${field}`)
202+
.where(cursorFilter),
198203
),
199204
);
200205
}
@@ -341,34 +346,38 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
341346
private buildToManyRelationFilter(
342347
eb: ExpressionBuilder<any, any>,
343348
model: string,
344-
table: string,
349+
modelAlias: string,
345350
field: string,
346351
fieldDef: FieldDef,
347352
payload: any,
348353
) {
349354
// null check needs to be converted to fk "is null" checks
350355
if (payload === null) {
351-
return eb(sql.ref(`${table}.${field}`), 'is', null);
356+
return eb(sql.ref(`${modelAlias}.${field}`), 'is', null);
352357
}
353358

354359
const relationModel = fieldDef.type;
355360

361+
// evaluating the filter involves creating an inner select,
362+
// give it an alias to avoid conflict
363+
const relationFilterSelectAlias = `${modelAlias}$${field}$filter`;
364+
356365
const buildPkFkWhereRefs = (eb: ExpressionBuilder<any, any>) => {
357366
const m2m = getManyToManyRelation(this.schema, model, field);
358367
if (m2m) {
359368
// many-to-many relation
360369
const modelIdField = getIdFields(this.schema, model)[0]!;
361370
const relationIdField = getIdFields(this.schema, relationModel)[0]!;
362371
return eb(
363-
sql.ref(`${relationModel}.${relationIdField}`),
372+
sql.ref(`${relationFilterSelectAlias}.${relationIdField}`),
364373
'in',
365374
eb
366375
.selectFrom(m2m.joinTable)
367376
.select(`${m2m.joinTable}.${m2m.otherFkName}`)
368377
.whereRef(
369378
sql.ref(`${m2m.joinTable}.${m2m.parentFkName}`),
370379
'=',
371-
sql.ref(`${table}.${modelIdField}`),
380+
sql.ref(`${modelAlias}.${modelIdField}`),
372381
),
373382
);
374383
} else {
@@ -380,13 +389,13 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
380389
result = this.and(
381390
eb,
382391
result,
383-
eb(sql.ref(`${table}.${fk}`), '=', sql.ref(`${relationModel}.${pk}`)),
392+
eb(sql.ref(`${modelAlias}.${fk}`), '=', sql.ref(`${relationFilterSelectAlias}.${pk}`)),
384393
);
385394
} else {
386395
result = this.and(
387396
eb,
388397
result,
389-
eb(sql.ref(`${table}.${pk}`), '=', sql.ref(`${relationModel}.${fk}`)),
398+
eb(sql.ref(`${modelAlias}.${pk}`), '=', sql.ref(`${relationFilterSelectAlias}.${fk}`)),
390399
);
391400
}
392401
}
@@ -407,10 +416,12 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
407416
eb,
408417
result,
409418
eb(
410-
this.buildSelectModel(eb, relationModel)
419+
this.buildSelectModel(eb, relationModel, relationFilterSelectAlias)
411420
.select((eb1) => eb1.fn.count(eb1.lit(1)).as('$count'))
412421
.where(buildPkFkWhereRefs(eb))
413-
.where((eb1) => this.buildFilter(eb1, relationModel, relationModel, subPayload)),
422+
.where((eb1) =>
423+
this.buildFilter(eb1, relationModel, relationFilterSelectAlias, subPayload),
424+
),
414425
'>',
415426
0,
416427
),
@@ -423,11 +434,13 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
423434
eb,
424435
result,
425436
eb(
426-
this.buildSelectModel(eb, relationModel)
437+
this.buildSelectModel(eb, relationModel, relationFilterSelectAlias)
427438
.select((eb1) => eb1.fn.count(eb1.lit(1)).as('$count'))
428439
.where(buildPkFkWhereRefs(eb))
429440
.where((eb1) =>
430-
eb1.not(this.buildFilter(eb1, relationModel, relationModel, subPayload)),
441+
eb1.not(
442+
this.buildFilter(eb1, relationModel, relationFilterSelectAlias, subPayload),
443+
),
431444
),
432445
'=',
433446
0,
@@ -441,10 +454,12 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
441454
eb,
442455
result,
443456
eb(
444-
this.buildSelectModel(eb, relationModel)
457+
this.buildSelectModel(eb, relationModel, relationFilterSelectAlias)
445458
.select((eb1) => eb1.fn.count(eb1.lit(1)).as('$count'))
446459
.where(buildPkFkWhereRefs(eb))
447-
.where((eb1) => this.buildFilter(eb1, relationModel, relationModel, subPayload)),
460+
.where((eb1) =>
461+
this.buildFilter(eb1, relationModel, relationFilterSelectAlias, subPayload),
462+
),
448463
'=',
449464
0,
450465
),
@@ -874,8 +889,9 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
874889
);
875890
const sort = this.negateSort(value._count, negated);
876891
result = result.orderBy((eb) => {
877-
let subQuery = this.buildSelectModel(eb, relationModel);
878-
const joinPairs = buildJoinPairs(this.schema, model, modelAlias, field, relationModel);
892+
const subQueryAlias = `${modelAlias}$orderBy$${field}$count`;
893+
let subQuery = this.buildSelectModel(eb, relationModel, subQueryAlias);
894+
const joinPairs = buildJoinPairs(this.schema, model, modelAlias, field, subQueryAlias);
879895
subQuery = subQuery.where(() =>
880896
this.and(
881897
eb,
@@ -909,7 +925,8 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
909925
buildSelectAllFields(
910926
model: string,
911927
query: SelectQueryBuilder<any, any, any>,
912-
omit?: Record<string, boolean | undefined>,
928+
omit: Record<string, boolean | undefined> | undefined,
929+
modelAlias: string,
913930
) {
914931
const modelDef = requireModel(this.schema, model);
915932
let result = query;
@@ -921,13 +938,13 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
921938
if (omit?.[field] === true) {
922939
continue;
923940
}
924-
result = this.buildSelectField(result, model, model, field);
941+
result = this.buildSelectField(result, model, modelAlias, field);
925942
}
926943

927944
// select all fields from delegate descendants and pack into a JSON field `$delegate$Model`
928945
const descendants = getDelegateDescendantModels(this.schema, model);
929946
for (const subModel of descendants) {
930-
result = this.buildDelegateJoin(model, subModel.name, result);
947+
result = this.buildDelegateJoin(model, modelAlias, subModel.name, result);
931948
result = result.select((eb) => {
932949
const jsonObject: Record<string, Expression<any>> = {};
933950
for (const field of Object.keys(subModel.fields)) {
@@ -964,11 +981,16 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
964981
}
965982
}
966983

967-
buildDelegateJoin(thisModel: string, otherModel: string, query: SelectQueryBuilder<any, any, any>) {
984+
buildDelegateJoin(
985+
thisModel: string,
986+
thisModelAlias: string,
987+
otherModelAlias: string,
988+
query: SelectQueryBuilder<any, any, any>,
989+
) {
968990
const idFields = getIdFields(this.schema, thisModel);
969-
query = query.leftJoin(otherModel, (qb) => {
991+
query = query.leftJoin(otherModelAlias, (qb) => {
970992
for (const idField of idFields) {
971-
qb = qb.onRef(`${thisModel}.${idField}`, '=', `${otherModel}.${idField}`);
993+
qb = qb.onRef(`${thisModelAlias}.${idField}`, '=', `${otherModelAlias}.${idField}`);
972994
}
973995
return qb;
974996
});

packages/runtime/src/client/crud/dialects/postgresql.ts

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,16 +82,22 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale
8282

8383
// however if there're filter/orderBy/take/skip,
8484
// we need to build a subquery to handle them before aggregation
85+
86+
// give sub query an alias to avoid conflict with parent scope
87+
// (e.g., for cases like self-relation)
88+
const subQueryAlias = `${relationModel}$${relationField}$sub`;
89+
8590
result = eb.selectFrom(() => {
86-
let subQuery = this.buildSelectModel(eb, relationModel);
91+
let subQuery = this.buildSelectModel(eb, relationModel, subQueryAlias);
8792
subQuery = this.buildSelectAllFields(
8893
relationModel,
8994
subQuery,
9095
typeof payload === 'object' ? payload?.omit : undefined,
96+
subQueryAlias,
9197
);
9298

9399
if (payload && typeof payload === 'object') {
94-
subQuery = this.buildFilterSortTake(relationModel, payload, subQuery);
100+
subQuery = this.buildFilterSortTake(relationModel, payload, subQuery, subQueryAlias);
95101
}
96102

97103
// add join conditions
@@ -106,7 +112,7 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale
106112
invariant(relationIds.length === 1, 'many-to-many relation must have exactly one id field');
107113
subQuery = subQuery.where(
108114
eb(
109-
eb.ref(`${relationModel}.${relationIds[0]}`),
115+
eb.ref(`${subQueryAlias}.${relationIds[0]}`),
110116
'in',
111117
eb
112118
.selectFrom(m2m.joinTable)
@@ -119,7 +125,7 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale
119125
),
120126
);
121127
} else {
122-
const joinPairs = buildJoinPairs(this.schema, model, parentName, relationField, relationModel);
128+
const joinPairs = buildJoinPairs(this.schema, model, parentName, relationField, subQueryAlias);
123129
subQuery = subQuery.where((eb) =>
124130
this.and(eb, ...joinPairs.map(([left, right]) => eb(sql.ref(left), '=', sql.ref(right)))),
125131
);
@@ -130,6 +136,7 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale
130136

131137
result = this.buildRelationObjectSelect(
132138
relationModel,
139+
joinTableName,
133140
relationField,
134141
relationFieldDef,
135142
result,
@@ -149,14 +156,22 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale
149156

150157
private buildRelationObjectSelect(
151158
relationModel: string,
159+
relationModelAlias: string,
152160
relationField: string,
153161
relationFieldDef: FieldDef,
154162
qb: SelectQueryBuilder<any, any, any>,
155163
payload: true | FindArgs<Schema, GetModels<Schema>, true>,
156164
parentName: string,
157165
) {
158166
qb = qb.select((eb) => {
159-
const objArgs = this.buildRelationObjectArgs(relationModel, relationField, eb, payload, parentName);
167+
const objArgs = this.buildRelationObjectArgs(
168+
relationModel,
169+
relationModelAlias,
170+
relationField,
171+
eb,
172+
payload,
173+
parentName,
174+
);
160175

161176
if (relationFieldDef.array) {
162177
return eb.fn
@@ -172,6 +187,7 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale
172187

173188
private buildRelationObjectArgs(
174189
relationModel: string,
190+
relationModelAlias: string,
175191
relationField: string,
176192
eb: ExpressionBuilder<any, any>,
177193
payload: true | FindArgs<Schema, GetModels<Schema>, true>,
@@ -202,7 +218,10 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale
202218
...Object.entries(relationModelDef.fields)
203219
.filter(([, value]) => !value.relation)
204220
.filter(([name]) => !(typeof payload === 'object' && (payload.omit as any)?.[name] === true))
205-
.map(([field]) => [sql.lit(field), this.fieldRef(relationModel, field, eb, undefined, false)])
221+
.map(([field]) => [
222+
sql.lit(field),
223+
this.fieldRef(relationModel, field, eb, relationModelAlias, false),
224+
])
206225
.flatMap((v) => v),
207226
);
208227
} else if (payload.select) {

0 commit comments

Comments
 (0)