Skip to content

Commit 9ef7e18

Browse files
committed
optimize nested relation manipulation
1 parent ca06169 commit 9ef7e18

File tree

10 files changed

+106
-126
lines changed

10 files changed

+106
-126
lines changed

packages/runtime/src/client/crud/dialects/base.ts renamed to packages/runtime/src/client/crud/dialects/base-dialect.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1104,7 +1104,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
11041104
return (node as ValueNode).value === false || (node as ValueNode).value === 0;
11051105
}
11061106

1107-
protected and(eb: ExpressionBuilder<any, any>, ...args: Expression<SqlBool>[]) {
1107+
and(eb: ExpressionBuilder<any, any>, ...args: Expression<SqlBool>[]) {
11081108
const nonTrueArgs = args.filter((arg) => !this.isTrue(arg));
11091109
if (nonTrueArgs.length === 0) {
11101110
return this.true(eb);
@@ -1115,7 +1115,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
11151115
}
11161116
}
11171117

1118-
protected or(eb: ExpressionBuilder<any, any>, ...args: Expression<SqlBool>[]) {
1118+
or(eb: ExpressionBuilder<any, any>, ...args: Expression<SqlBool>[]) {
11191119
const nonFalseArgs = args.filter((arg) => !this.isFalse(arg));
11201120
if (nonFalseArgs.length === 0) {
11211121
return this.false(eb);
@@ -1126,7 +1126,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
11261126
}
11271127
}
11281128

1129-
protected not(eb: ExpressionBuilder<any, any>, ...args: Expression<SqlBool>[]) {
1129+
not(eb: ExpressionBuilder<any, any>, ...args: Expression<SqlBool>[]) {
11301130
return eb.not(this.and(eb, ...args));
11311131
}
11321132

packages/runtime/src/client/crud/dialects/index.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import { match } from 'ts-pattern';
22
import type { SchemaDef } from '../../../schema';
33
import type { ClientOptions } from '../../options';
4-
import type { BaseCrudDialect } from './base';
4+
import type { BaseCrudDialect } from './base-dialect';
55
import { PostgresCrudDialect } from './postgresql';
66
import { SqliteCrudDialect } from './sqlite';
77

packages/runtime/src/client/crud/dialects/postgresql.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ import {
2020
requireField,
2121
requireModel,
2222
} from '../../query-utils';
23-
import { BaseCrudDialect } from './base';
23+
import { BaseCrudDialect } from './base-dialect';
2424

2525
export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDialect<Schema> {
2626
override get provider() {

packages/runtime/src/client/crud/dialects/sqlite.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ import {
2020
requireField,
2121
requireModel,
2222
} from '../../query-utils';
23-
import { BaseCrudDialect } from './base';
23+
import { BaseCrudDialect } from './base-dialect';
2424

2525
export class SqliteCrudDialect<Schema extends SchemaDef> extends BaseCrudDialect<Schema> {
2626
override get provider() {

packages/runtime/src/client/crud/operations/base.ts

Lines changed: 74 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import {
77
UpdateResult,
88
type Compilable,
99
type IsolationLevel,
10-
type QueryResult,
1110
type SelectQueryBuilder,
1211
} from 'kysely';
1312
import { nanoid } from 'nanoid';
@@ -44,7 +43,7 @@ import {
4443
requireModel,
4544
} from '../../query-utils';
4645
import { getCrudDialect } from '../dialects';
47-
import type { BaseCrudDialect } from '../dialects/base';
46+
import type { BaseCrudDialect } from '../dialects/base-dialect';
4847
import { InputValidator } from '../validator';
4948

5049
export type CoreCrudOperation =
@@ -66,10 +65,16 @@ export type CoreCrudOperation =
6665

6766
export type AllCrudOperation = CoreCrudOperation | 'findUniqueOrThrow' | 'findFirstOrThrow';
6867

68+
// context for nested relation operations
6969
export type FromRelationContext<Schema extends SchemaDef> = {
70+
// the model where the relation field is defined
7071
model: GetModels<Schema>;
72+
// the relation field name
7173
field: string;
74+
// the parent entity's id fields and values
7275
ids: any;
76+
// for relations owned by model, record the parent updates needed after the relation is processed
77+
parentUpdates: Record<string, unknown>;
7378
};
7479

7580
export abstract class BaseOperationHandler<Schema extends SchemaDef> {
@@ -258,7 +263,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
258263
}
259264

260265
let createFields: any = {};
261-
let parentUpdateTask: ((entity: any) => Promise<unknown>) | undefined = undefined;
266+
let updateParent: ((entity: any) => void) | undefined = undefined;
262267

263268
let m2m: ReturnType<typeof getManyToManyRelation> = undefined;
264269

@@ -281,28 +286,10 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
281286
);
282287
Object.assign(createFields, parentFkFields);
283288
} else {
284-
parentUpdateTask = async (entity) => {
285-
const query = kysely
286-
.updateTable(fromRelation.model)
287-
.set(
288-
keyPairs.reduce(
289-
(acc, { fk, pk }) => ({
290-
...acc,
291-
[fk]: entity[pk],
292-
}),
293-
{} as any,
294-
),
295-
)
296-
.where((eb) => eb.and(fromRelation.ids))
297-
.modifyEnd(
298-
this.makeContextComment({
299-
model: fromRelation.model,
300-
operation: 'update',
301-
}),
302-
);
303-
const result = await this.executeQuery(kysely, query, 'update');
304-
if (!result.numAffectedRows) {
305-
throw new NotFoundError(fromRelation.model);
289+
// record parent fk update after entity is created
290+
updateParent = (entity) => {
291+
for (const { fk, pk } of keyPairs) {
292+
fromRelation.parentUpdates[fk] = entity[pk];
306293
}
307294
};
308295
}
@@ -406,8 +393,8 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
406393
}
407394

408395
// finally update parent if needed
409-
if (parentUpdateTask) {
410-
await parentUpdateTask(createdEntity);
396+
if (updateParent) {
397+
updateParent(createdEntity);
411398
}
412399

413400
return createdEntity;
@@ -611,10 +598,11 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
611598
const relationFieldDef = this.requireField(contextModel, relationFieldName);
612599
const relationModel = relationFieldDef.type as GetModels<Schema>;
613600
const tasks: Promise<unknown>[] = [];
614-
const fromRelationContext = {
601+
const fromRelationContext: FromRelationContext<Schema> = {
615602
model: contextModel,
616603
field: relationFieldName,
617604
ids: parentEntity,
605+
parentUpdates: {},
618606
};
619607

620608
for (const [action, subPayload] of Object.entries<any>(payload)) {
@@ -647,13 +635,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
647635
}
648636

649637
case 'connect': {
650-
tasks.push(
651-
this.connectRelation(kysely, relationModel, subPayload, {
652-
model: contextModel,
653-
field: relationFieldName,
654-
ids: parentEntity,
655-
}),
656-
);
638+
tasks.push(this.connectRelation(kysely, relationModel, subPayload, fromRelationContext));
657639
break;
658640
}
659641

@@ -662,16 +644,8 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
662644
...enumerate(subPayload).map((item) =>
663645
this.exists(kysely, relationModel, item.where).then((found) =>
664646
!found
665-
? this.create(kysely, relationModel, item.create, {
666-
model: contextModel,
667-
field: relationFieldName,
668-
ids: parentEntity,
669-
})
670-
: this.connectRelation(kysely, relationModel, found, {
671-
model: contextModel,
672-
field: relationFieldName,
673-
ids: parentEntity,
674-
}),
647+
? this.create(kysely, relationModel, item.create, fromRelationContext)
648+
: this.connectRelation(kysely, relationModel, found, fromRelationContext),
675649
),
676650
),
677651
);
@@ -1047,7 +1021,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
10471021
}
10481022
}
10491023
}
1050-
await this.processRelationUpdates(
1024+
const parentUpdates = await this.processRelationUpdates(
10511025
kysely,
10521026
model,
10531027
field,
@@ -1056,6 +1030,11 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
10561030
finalData[field],
10571031
throwIfNotFound,
10581032
);
1033+
1034+
if (Object.keys(parentUpdates).length > 0) {
1035+
// merge field updates propagated from nested relation processing
1036+
Object.assign(updateFields, parentUpdates);
1037+
}
10591038
}
10601039
}
10611040

@@ -1375,10 +1354,11 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
13751354
) {
13761355
const tasks: Promise<unknown>[] = [];
13771356
const fieldModel = fieldDef.type as GetModels<Schema>;
1378-
const fromRelationContext = {
1357+
const fromRelationContext: FromRelationContext<Schema> = {
13791358
model,
13801359
field,
13811360
ids: parentIds,
1361+
parentUpdates: {},
13821362
};
13831363

13841364
for (const [key, value] of Object.entries(args)) {
@@ -1509,6 +1489,8 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
15091489
}
15101490

15111491
await Promise.all(tasks);
1492+
1493+
return fromRelationContext.parentUpdates;
15121494
}
15131495

15141496
// #region relation manipulation
@@ -1553,42 +1535,21 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
15531535
fromRelation.model,
15541536
fromRelation.field,
15551537
);
1556-
let updateResult: QueryResult<unknown>;
1557-
let updateModel: GetModels<Schema>;
15581538

15591539
if (ownedByModel) {
1560-
updateModel = fromRelation.model;
1561-
1562-
// set parent fk directly
1540+
// record parent fk update
15631541
invariant(_data.length === 1, 'only one entity can be connected');
15641542
const target = await this.readUnique(kysely, model, {
15651543
where: _data[0],
15661544
});
15671545
if (!target) {
15681546
throw new NotFoundError(model);
15691547
}
1570-
const query = kysely
1571-
.updateTable(fromRelation.model)
1572-
.where((eb) => eb.and(fromRelation.ids))
1573-
.set(
1574-
keyPairs.reduce(
1575-
(acc, { fk, pk }) => ({
1576-
...acc,
1577-
[fk]: target[pk],
1578-
}),
1579-
{} as any,
1580-
),
1581-
)
1582-
.modifyEnd(
1583-
this.makeContextComment({
1584-
model: fromRelation.model,
1585-
operation: 'update',
1586-
}),
1587-
);
1588-
updateResult = await this.executeQuery(kysely, query, 'connect');
1589-
} else {
1590-
updateModel = model;
15911548

1549+
for (const { fk, pk } of keyPairs) {
1550+
fromRelation.parentUpdates[fk] = target[pk];
1551+
}
1552+
} else {
15921553
// disconnect current if it's a one-one relation
15931554
const relationFieldDef = this.requireField(fromRelation.model, fromRelation.field);
15941555

@@ -1625,13 +1586,13 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
16251586
operation: 'update',
16261587
}),
16271588
);
1628-
updateResult = await this.executeQuery(kysely, query, 'connect');
1629-
}
1589+
const updateResult = await this.executeQuery(kysely, query, 'connect');
16301590

1631-
// validate connect result
1632-
if (!updateResult.numAffectedRows || _data.length > updateResult.numAffectedRows) {
1633-
// some entities were not connected
1634-
throw new NotFoundError(updateModel);
1591+
// validate connect result
1592+
if (!updateResult.numAffectedRows || _data.length > updateResult.numAffectedRows) {
1593+
// some entities were not connected
1594+
throw new NotFoundError(model);
1595+
}
16351596
}
16361597
}
16371598
}
@@ -1715,42 +1676,42 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
17151676

17161677
const eb = expressionBuilder<any, any>();
17171678
if (ownedByModel) {
1718-
// set parent fk directly
1679+
// record parent fk update
17191680
invariant(disconnectConditions.length === 1, 'only one entity can be disconnected');
17201681
const condition = disconnectConditions[0];
1721-
const query = kysely
1722-
.updateTable(fromRelation.model)
1723-
// id filter
1724-
.where(eb.and(fromRelation.ids))
1725-
// merge extra disconnect conditions
1726-
.$if(condition !== true, (qb) =>
1727-
qb.where(
1728-
eb(
1729-
// @ts-ignore
1730-
eb.refTuple(...keyPairs.map(({ fk }) => fk)),
1731-
'in',
1732-
eb
1733-
.selectFrom(model)
1734-
.select(keyPairs.map(({ pk }) => pk))
1735-
.where(this.dialect.buildFilter(eb, model, model, condition)),
1736-
),
1737-
),
1738-
)
1739-
.set(keyPairs.reduce((acc, { fk }) => ({ ...acc, [fk]: null }), {} as any))
1740-
.modifyEnd(
1741-
this.makeContextComment({
1742-
model: fromRelation.model,
1743-
operation: 'update',
1744-
}),
1745-
);
1746-
const result = await this.executeQuery(kysely, query, 'disconnect');
1747-
if (!result.numAffectedRows) {
1748-
// determine if the parent entity doesn't exist, or the relation entity to be disconnected doesn't exist
1749-
const parentExists = await this.exists(kysely, fromRelation.model, fromRelation.ids);
1750-
if (!parentExists) {
1751-
throw new NotFoundError(fromRelation.model);
1752-
} else {
1753-
// silently ignore
1682+
1683+
if (condition === true) {
1684+
// just disconnect, record parent fk update
1685+
for (const { fk } of keyPairs) {
1686+
fromRelation.parentUpdates[fk] = null;
1687+
}
1688+
} else {
1689+
// disconnect with a filter
1690+
1691+
// read parent's fk
1692+
const fromEntity = await this.readUnique(kysely, fromRelation.model, {
1693+
where: fromRelation.ids,
1694+
select: fieldsToSelectObject(keyPairs.map(({ fk }) => fk)),
1695+
});
1696+
if (!fromEntity || keyPairs.some(({ fk }) => fromEntity[fk] == null)) {
1697+
return;
1698+
}
1699+
1700+
// check if the disconnect target exists under parent fk and the filter condition
1701+
const relationFilter = {
1702+
AND: [condition, Object.fromEntries(keyPairs.map(({ fk, pk }) => [pk, fromEntity[fk]]))],
1703+
};
1704+
1705+
// if the target exists, record parent fk update, otherwise do nothing
1706+
const targetExists = await this.read(kysely, model, {
1707+
where: relationFilter,
1708+
take: 1,
1709+
select: this.makeIdSelect(model),
1710+
} as any);
1711+
if (targetExists.length > 0) {
1712+
for (const { fk } of keyPairs) {
1713+
fromRelation.parentUpdates[fk] = null;
1714+
}
17541715
}
17551716
}
17561717
} else {

packages/runtime/src/client/options.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import type { Dialect, Expression, ExpressionBuilder, KyselyConfig } from 'kysel
22
import type { GetModel, GetModels, ProcedureDef, SchemaDef } from '../schema';
33
import type { PrependParameter } from '../utils/type-utils';
44
import type { ClientContract, CRUD, ProcedureFunc } from './contract';
5-
import type { BaseCrudDialect } from './crud/dialects/base';
5+
import type { BaseCrudDialect } from './crud/dialects/base-dialect';
66
import type { RuntimePlugin } from './plugin';
77
import type { ToKyselySchema } from './query-builder';
88

packages/runtime/src/plugins/policy/expression-transformer.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import {
2222
import { match } from 'ts-pattern';
2323
import type { CRUD } from '../../client/contract';
2424
import { getCrudDialect } from '../../client/crud/dialects';
25-
import type { BaseCrudDialect } from '../../client/crud/dialects/base';
25+
import type { BaseCrudDialect } from '../../client/crud/dialects/base-dialect';
2626
import { InternalError, QueryError } from '../../client/errors';
2727
import type { ClientOptions } from '../../client/options';
2828
import { getModel, getRelationForeignKeyFieldPairs, requireField } from '../../client/query-utils';

0 commit comments

Comments
 (0)