Skip to content

Commit d6bb0c0

Browse files
authored
feat: trigger after mutation hooks after transaction is committed (#123)
* feat: trigger after mutation hooks after transaction is committed * update
1 parent 1667ce1 commit d6bb0c0

File tree

12 files changed

+265
-430
lines changed

12 files changed

+265
-430
lines changed

TODO.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
- [ ] format
1212
- [ ] db seed
1313
- [ ] ZModel
14+
- [ ] Import
1415
- [ ] View support
1516
- [ ] ORM
1617
- [x] Create
@@ -80,8 +81,8 @@
8081
- [ ] Strict undefined checks
8182
- [ ] DbNull vs JsonNull
8283
- [ ] Benchmark
83-
- [ ] Plugin
84-
- [ ] Post-mutation hooks should be called after transaction is committed
84+
- [x] Plugin
85+
- [x] Post-mutation hooks should be called after transaction is committed
8586
- [x] TypeDef and mixin
8687
- [ ] Strongly typed JSON
8788
- [x] Polymorphism

packages/runtime/src/client/crud/operations/aggregate.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,11 +112,11 @@ export class AggregateOperationHandler<Schema extends SchemaDef> extends BaseOpe
112112
}
113113
}
114114

115-
const result = await query.executeTakeFirstOrThrow();
115+
const result = await this.executeQuery(this.kysely, query, 'aggregate');
116116
const ret: any = {};
117117

118118
// postprocess result to convert flat fields into nested objects
119-
for (const [key, value] of Object.entries(result as object)) {
119+
for (const [key, value] of Object.entries(result.rows[0] as object)) {
120120
if (key === '_count') {
121121
ret[key] = value;
122122
continue;

packages/runtime/src/client/crud/operations/base.ts

Lines changed: 50 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@ import {
66
ExpressionWrapper,
77
sql,
88
UpdateResult,
9+
type Compilable,
910
type IsolationLevel,
1011
type Expression as KyselyExpression,
12+
type QueryResult,
1113
type SelectQueryBuilder,
1214
} from 'kysely';
1315
import { nanoid } from 'nanoid';
@@ -125,7 +127,11 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
125127
return getField(this.schema, model, field);
126128
}
127129

128-
protected exists(kysely: ToKysely<Schema>, model: GetModels<Schema>, filter: any): Promise<unknown | undefined> {
130+
protected async exists(
131+
kysely: ToKysely<Schema>,
132+
model: GetModels<Schema>,
133+
filter: any,
134+
): Promise<unknown | undefined> {
129135
const idFields = getIdFields(this.schema, model);
130136
const _filter = flattenCompoundUniqueFilters(this.schema, model, filter);
131137
const query = kysely
@@ -134,7 +140,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
134140
.select(idFields.map((f) => kysely.dynamic.ref(f)))
135141
.limit(1)
136142
.modifyEnd(this.makeContextComment({ model, operation: 'read' }));
137-
return query.executeTakeFirst();
143+
return this.executeQueryTakeFirst(kysely, query, 'exists');
138144
}
139145

140146
protected async read(
@@ -444,7 +450,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
444450
operation: 'update',
445451
}),
446452
);
447-
return query.execute();
453+
return this.executeQuery(kysely, query, 'update');
448454
};
449455
}
450456
}
@@ -511,10 +517,10 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
511517
}),
512518
);
513519

514-
const createdEntity = await query.executeTakeFirst();
520+
const createdEntity = await this.executeQueryTakeFirst(kysely, query, 'create');
515521

516522
// try {
517-
// createdEntity = await query.executeTakeFirst();
523+
// createdEntity = await this.executeQueryTakeFirst(kysely, query, 'create');
518524
// } catch (err) {
519525
// const { sql, parameters } = query.compile();
520526
// throw new QueryError(
@@ -893,8 +899,8 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
893899
);
894900

895901
if (!returnData) {
896-
const result = await query.executeTakeFirstOrThrow();
897-
return { count: Number(result.numInsertedOrUpdatedRows) } as Result;
902+
const result = await this.executeQuery(kysely, query, 'createMany');
903+
return { count: Number(result.numAffectedRows) } as Result;
898904
} else {
899905
const idFields = getIdFields(this.schema, model);
900906
const result = await query.returning(idFields as any).execute();
@@ -1160,10 +1166,10 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
11601166
}),
11611167
);
11621168

1163-
const updatedEntity = await query.executeTakeFirst();
1169+
const updatedEntity = await this.executeQueryTakeFirst(kysely, query, 'update');
11641170

11651171
// try {
1166-
// updatedEntity = await query.executeTakeFirst();
1172+
// updatedEntity = await this.executeQueryTakeFirst(kysely, query, 'update');
11671173
// } catch (err) {
11681174
// const { sql, parameters } = query.compile();
11691175
// throw new QueryError(
@@ -1401,8 +1407,8 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
14011407
query = query.modifyEnd(this.makeContextComment({ model, operation: 'update' }));
14021408

14031409
if (!returnData) {
1404-
const result = await query.executeTakeFirstOrThrow();
1405-
return { count: Number(result.numUpdatedRows) } as Result;
1410+
const result = await this.executeQuery(kysely, query, 'update');
1411+
return { count: Number(result.numAffectedRows) } as Result;
14061412
} else {
14071413
const idFields = getIdFields(this.schema, model);
14081414
const result = await query.returning(idFields as any).execute();
@@ -1636,7 +1642,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
16361642
fromRelation.model,
16371643
fromRelation.field,
16381644
);
1639-
let updateResult: UpdateResult;
1645+
let updateResult: QueryResult<unknown>;
16401646

16411647
if (ownedByModel) {
16421648
// set parent fk directly
@@ -1665,7 +1671,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
16651671
operation: 'update',
16661672
}),
16671673
);
1668-
updateResult = await query.executeTakeFirstOrThrow();
1674+
updateResult = await this.executeQuery(kysely, query, 'connect');
16691675
} else {
16701676
// disconnect current if it's a one-one relation
16711677
const relationFieldDef = this.requireField(fromRelation.model, fromRelation.field);
@@ -1681,7 +1687,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
16811687
operation: 'update',
16821688
}),
16831689
);
1684-
await query.execute();
1690+
await this.executeQuery(kysely, query, 'disconnect');
16851691
}
16861692

16871693
// connect
@@ -1703,11 +1709,11 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
17031709
operation: 'update',
17041710
}),
17051711
);
1706-
updateResult = await query.executeTakeFirstOrThrow();
1712+
updateResult = await this.executeQuery(kysely, query, 'connect');
17071713
}
17081714

17091715
// validate connect result
1710-
if (_data.length > updateResult.numUpdatedRows) {
1716+
if (_data.length > updateResult.numAffectedRows!) {
17111717
// some entities were not connected
17121718
throw new NotFoundError(model);
17131719
}
@@ -1821,7 +1827,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
18211827
operation: 'update',
18221828
}),
18231829
);
1824-
await query.executeTakeFirstOrThrow();
1830+
await this.executeQuery(kysely, query, 'disconnect');
18251831
} else {
18261832
// disconnect
18271833
const query = kysely
@@ -1841,7 +1847,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
18411847
operation: 'update',
18421848
}),
18431849
);
1844-
await query.executeTakeFirstOrThrow();
1850+
await this.executeQuery(kysely, query, 'disconnect');
18451851
}
18461852
}
18471853
}
@@ -1920,7 +1926,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
19201926
operation: 'update',
19211927
}),
19221928
);
1923-
await query.execute();
1929+
await this.executeQuery(kysely, query, 'disconnect');
19241930

19251931
// connect
19261932
if (_data.length > 0) {
@@ -1942,10 +1948,10 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
19421948
operation: 'update',
19431949
}),
19441950
);
1945-
const r = await query.executeTakeFirstOrThrow();
1951+
const r = await this.executeQuery(kysely, query, 'connect');
19461952

19471953
// validate result
1948-
if (_data.length > r.numUpdatedRows!) {
1954+
if (_data.length > r.numAffectedRows!) {
19491955
// some entities were not connected
19501956
throw new NotFoundError(model);
19511957
}
@@ -2109,8 +2115,8 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
21092115
await this.processDelegateRelationDelete(kysely, modelDef, where, limit);
21102116

21112117
query = query.modifyEnd(this.makeContextComment({ model, operation: 'delete' }));
2112-
const result = await query.executeTakeFirstOrThrow();
2113-
return { count: Number(result.numDeletedRows) };
2118+
const result = await this.executeQuery(kysely, query, 'delete');
2119+
return { count: Number(result.numAffectedRows) };
21142120
}
21152121

21162122
private async processDelegateRelationDelete(
@@ -2240,4 +2246,25 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
22402246
}
22412247
}
22422248
}
2249+
2250+
protected makeQueryId(operation: string) {
2251+
return { queryId: `${operation}-${createId()}` };
2252+
}
2253+
2254+
protected executeQuery(kysely: ToKysely<Schema>, query: Compilable, operation: string) {
2255+
return kysely.executeQuery(query.compile(), this.makeQueryId(operation));
2256+
}
2257+
2258+
protected async executeQueryTakeFirst(kysely: ToKysely<Schema>, query: Compilable, operation: string) {
2259+
const result = await kysely.executeQuery(query.compile(), this.makeQueryId(operation));
2260+
return result.rows[0];
2261+
}
2262+
2263+
protected async executeQueryTakeFirstOrThrow(kysely: ToKysely<Schema>, query: Compilable, operation: string) {
2264+
const result = await kysely.executeQuery(query.compile(), this.makeQueryId(operation));
2265+
if (result.rows.length === 0) {
2266+
throw new QueryError('No rows found');
2267+
}
2268+
return result.rows[0];
2269+
}
22432270
}

packages/runtime/src/client/crud/operations/count.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,13 @@ export class CountOperationHandler<Schema extends SchemaDef> extends BaseOperati
4444
: eb.cast(eb.fn.count(sql.ref(`${subQueryName}.${key}`)), 'integer').as(key),
4545
),
4646
);
47-
48-
return query.executeTakeFirstOrThrow();
47+
const result = await this.executeQuery(this.kysely, query, 'count');
48+
return result.rows[0];
4949
} else {
5050
// simple count all
5151
query = query.select((eb) => eb.cast(eb.fn.countAll(), 'integer').as('count'));
52-
const result = await query.executeTakeFirstOrThrow();
53-
return (result as any).count as number;
52+
const result = await this.executeQuery(this.kysely, query, 'count');
53+
return (result.rows[0] as any).count as number;
5454
}
5555
}
5656
}

packages/runtime/src/client/crud/operations/delete.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ export class DeleteOperationHandler<Schema extends SchemaDef> extends BaseOperat
3030

3131
// TODO: avoid using transaction for simple delete
3232
await this.safeTransaction(async (tx) => {
33-
const result = await this.delete(tx, this.model, args.where, undefined);
33+
const result = await this.delete(tx, this.model, args.where);
3434
if (result.count === 0) {
3535
throw new NotFoundError(this.model);
3636
}

packages/runtime/src/client/crud/operations/group-by.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,8 @@ export class GroupByOperationHandler<Schema extends SchemaDef> extends BaseOpera
108108
}
109109
}
110110

111-
const result = await query.execute();
112-
return result.map((row) => this.postProcessRow(row));
111+
const result = await this.executeQuery(this.kysely, query, 'groupBy');
112+
return result.rows.map((row) => this.postProcessRow(row));
113113
}
114114

115115
private postProcessRow(row: any) {

packages/runtime/src/client/executor/zenstack-driver.ts

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@ import type { CompiledQuery, DatabaseConnection, Driver, Log, QueryResult, Trans
66
export class ZenStackDriver implements Driver {
77
readonly #driver: Driver;
88
readonly #log: Log;
9-
txConnection: DatabaseConnection | undefined;
109

1110
#initPromise?: Promise<void>;
1211
#initDone: boolean;
1312
#destroyPromise?: Promise<void>;
1413
#connections = new WeakSet<DatabaseConnection>();
14+
#txConnections = new WeakMap<DatabaseConnection, Array<() => Promise<unknown>>>();
1515

1616
constructor(driver: Driver, log: Log) {
1717
this.#initDone = false;
@@ -67,23 +67,33 @@ export class ZenStackDriver implements Driver {
6767

6868
async beginTransaction(connection: DatabaseConnection, settings: TransactionSettings): Promise<void> {
6969
const result = await this.#driver.beginTransaction(connection, settings);
70-
this.txConnection = connection;
70+
this.#txConnections.set(connection, []);
7171
return result;
7272
}
7373

74-
commitTransaction(connection: DatabaseConnection): Promise<void> {
74+
async commitTransaction(connection: DatabaseConnection): Promise<void> {
7575
try {
76-
return this.#driver.commitTransaction(connection);
77-
} finally {
78-
this.txConnection = undefined;
76+
const result = await this.#driver.commitTransaction(connection);
77+
const callbacks = this.#txConnections.get(connection);
78+
// delete from the map immediately to avoid accidental re-triggering
79+
this.#txConnections.delete(connection);
80+
if (callbacks) {
81+
for (const callback of callbacks) {
82+
await callback();
83+
}
84+
}
85+
return result;
86+
} catch (err) {
87+
this.#txConnections.delete(connection);
88+
throw err;
7989
}
8090
}
8191

82-
rollbackTransaction(connection: DatabaseConnection): Promise<void> {
92+
async rollbackTransaction(connection: DatabaseConnection): Promise<void> {
8393
try {
84-
return this.#driver.rollbackTransaction(connection);
94+
return await this.#driver.rollbackTransaction(connection);
8595
} finally {
86-
this.txConnection = undefined;
96+
this.#txConnections.delete(connection);
8797
}
8898
}
8999

@@ -175,6 +185,22 @@ export class ZenStackDriver implements Driver {
175185
#calculateDurationMillis(startTime: number): number {
176186
return performanceNow() - startTime;
177187
}
188+
189+
isTransactionConnection(connection: DatabaseConnection): boolean {
190+
return this.#txConnections.has(connection);
191+
}
192+
193+
registerTransactionCommitCallback(connection: DatabaseConnection, callback: () => Promise<unknown>): void {
194+
if (!this.#txConnections.has(connection)) {
195+
return;
196+
}
197+
const callbacks = this.#txConnections.get(connection);
198+
if (callbacks) {
199+
callbacks.push(callback);
200+
} else {
201+
this.#txConnections.set(connection, [callback]);
202+
}
203+
}
178204
}
179205

180206
export function performanceNow() {

0 commit comments

Comments
 (0)