Skip to content

Commit 8074186

Browse files
Add read-only middleware support (#308)
## Why This PR adds support of read-only middleware support. Before a procedure is called, river will go through a chain of middleware synchronously and the middlewares are read-only so they shouldn't modify the context. <!-- Describe what you are trying to accomplish with this pull request --> ## What changed - Add `middlewares` in server options - Chain and run middlewares before procedure calls - Add middleware tests <!-- Describe the changes you made in this pull request or pointers for the reviewer --> ## Versioning - [ ] Breaking protocol change - [ ] Breaking ts/js API change <!-- Kind reminder to add tests and updated documentation if needed --> --------- Co-authored-by: masad-frost <farismasad@gmail.com>
1 parent c555a53 commit 8074186

File tree

5 files changed

+527
-152
lines changed

5 files changed

+527
-152
lines changed

__tests__/middleware.test.ts

Lines changed: 344 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,344 @@
1+
/* eslint-disable @typescript-eslint/no-unsafe-assignment */
2+
import { AsyncLocalStorage } from 'async_hooks';
3+
import { isReadableDone, readNextResult } from '../testUtil';
4+
import { afterEach, beforeEach, describe, expect, test, vi } from 'vitest';
5+
import {
6+
TestServiceSchema,
7+
SubscribableServiceSchema,
8+
UploadableServiceSchema,
9+
} from '../testUtil/fixtures/services';
10+
import {
11+
createClient,
12+
createServer,
13+
Ok,
14+
Procedure,
15+
ServiceSchema,
16+
Middleware,
17+
} from '../router';
18+
import { createMockTransportNetwork } from '../testUtil/fixtures/mockTransport';
19+
import { Type } from '@sinclair/typebox';
20+
21+
describe('middleware test', () => {
22+
let mockTransportNetwork: ReturnType<typeof createMockTransportNetwork>;
23+
24+
beforeEach(async () => {
25+
mockTransportNetwork = createMockTransportNetwork();
26+
});
27+
28+
afterEach(async () => {
29+
await mockTransportNetwork.cleanup();
30+
});
31+
32+
test('apply read-only middleware to rpc', async () => {
33+
const services = { test: TestServiceSchema };
34+
const middleware = vi.fn<Middleware>(({ next }) => next());
35+
createServer(mockTransportNetwork.getServerTransport(), services, {
36+
middlewares: [middleware],
37+
});
38+
const client = createClient<typeof services>(
39+
mockTransportNetwork.getClientTransport('client'),
40+
'SERVER',
41+
);
42+
43+
const result = await client.test.add.rpc({ n: 3 });
44+
expect(middleware).toHaveBeenCalledOnce();
45+
expect(middleware).toHaveBeenCalledWith(
46+
expect.objectContaining({
47+
ctx: expect.objectContaining({
48+
serviceName: 'test',
49+
procedureName: 'add',
50+
sessionId: expect.stringContaining('session-'),
51+
span: expect.objectContaining({}),
52+
streamId: expect.stringContaining(''),
53+
signal: expect.objectContaining({}),
54+
state: expect.objectContaining({}),
55+
}),
56+
reqInit: {
57+
n: 3,
58+
},
59+
}),
60+
);
61+
62+
expect(result).toStrictEqual({ ok: true, payload: { result: 3 } });
63+
});
64+
65+
test('apply read-only middleware to stream', async () => {
66+
const services = { test: TestServiceSchema };
67+
const middleware = vi.fn<Middleware>(({ next }) => next());
68+
createServer(mockTransportNetwork.getServerTransport(), services, {
69+
middlewares: [middleware],
70+
});
71+
const client = createClient<typeof services>(
72+
mockTransportNetwork.getClientTransport('client'),
73+
'SERVER',
74+
);
75+
76+
const { reqWritable, resReadable } = client.test.echo.stream({});
77+
78+
reqWritable.write({ msg: 'abc', ignore: false });
79+
reqWritable.write({ msg: 'def', ignore: true });
80+
reqWritable.write({ msg: 'ghi', ignore: false });
81+
reqWritable.close();
82+
83+
const result1 = await readNextResult(resReadable);
84+
expect(result1).toStrictEqual({ ok: true, payload: { response: 'abc' } });
85+
86+
const result2 = await readNextResult(resReadable);
87+
expect(result2).toStrictEqual({ ok: true, payload: { response: 'ghi' } });
88+
89+
expect(await isReadableDone(resReadable)).toEqual(true);
90+
91+
expect(middleware).toHaveBeenCalledOnce();
92+
expect(middleware).toHaveBeenCalledWith(
93+
expect.objectContaining({
94+
ctx: expect.objectContaining({
95+
serviceName: 'test',
96+
procedureName: 'echo',
97+
sessionId: expect.stringContaining('session-'),
98+
span: expect.objectContaining({}),
99+
streamId: expect.stringContaining(''),
100+
signal: expect.objectContaining({}),
101+
state: expect.objectContaining({}),
102+
}),
103+
reqInit: {},
104+
}),
105+
);
106+
});
107+
108+
test('apply read-only middleware to subscriptions', async () => {
109+
const services = { test: SubscribableServiceSchema };
110+
const middleware = vi.fn<Middleware>(({ next }) => next());
111+
createServer(mockTransportNetwork.getServerTransport(), services, {
112+
middlewares: [middleware],
113+
});
114+
const client = createClient<typeof services>(
115+
mockTransportNetwork.getClientTransport('client'),
116+
'SERVER',
117+
);
118+
119+
const { resReadable } = client.test.value.subscribe({});
120+
121+
const streamResult1 = await readNextResult(resReadable);
122+
expect(streamResult1).toStrictEqual({ ok: true, payload: { result: 0 } });
123+
124+
expect(middleware).toHaveBeenCalledOnce();
125+
expect(middleware).toHaveBeenCalledWith(
126+
expect.objectContaining({
127+
ctx: expect.objectContaining({
128+
serviceName: 'test',
129+
procedureName: 'value',
130+
sessionId: expect.stringContaining('session-'),
131+
span: expect.objectContaining({}),
132+
streamId: expect.stringContaining(''),
133+
signal: expect.objectContaining({}),
134+
state: expect.objectContaining({}),
135+
}),
136+
reqInit: {},
137+
}),
138+
);
139+
140+
const result = await client.test.add.rpc({ n: 3 });
141+
expect(result).toStrictEqual({ ok: true, payload: { result: 3 } });
142+
143+
expect(middleware).toHaveBeenCalledTimes(2);
144+
expect(middleware).toHaveBeenCalledWith(
145+
expect.objectContaining({
146+
ctx: expect.objectContaining({
147+
serviceName: 'test',
148+
procedureName: 'add',
149+
sessionId: expect.stringContaining('session-'),
150+
span: expect.objectContaining({}),
151+
streamId: expect.stringContaining(''),
152+
signal: expect.objectContaining({}),
153+
state: expect.objectContaining({}),
154+
}),
155+
reqInit: {
156+
n: 3,
157+
},
158+
}),
159+
);
160+
161+
const streamResult2 = await readNextResult(resReadable);
162+
expect(streamResult2).toStrictEqual({ ok: true, payload: { result: 3 } });
163+
});
164+
165+
test('apply read-only middleware to uploads', async () => {
166+
const services = { test: UploadableServiceSchema };
167+
const middleware = vi.fn<Middleware>(({ next }) => next());
168+
createServer(mockTransportNetwork.getServerTransport(), services, {
169+
middlewares: [middleware],
170+
});
171+
const client = createClient<typeof services>(
172+
mockTransportNetwork.getClientTransport('client'),
173+
'SERVER',
174+
);
175+
176+
const { reqWritable, finalize } = client.test.addMultiple.upload({});
177+
178+
reqWritable.write({ n: 1 });
179+
reqWritable.write({ n: 2 });
180+
reqWritable.close();
181+
expect(await finalize()).toStrictEqual({
182+
ok: true,
183+
payload: { result: 3 },
184+
});
185+
186+
expect(middleware).toHaveBeenCalledOnce();
187+
expect(middleware).toHaveBeenCalledWith(
188+
expect.objectContaining({
189+
ctx: expect.objectContaining({
190+
serviceName: 'test',
191+
procedureName: 'addMultiple',
192+
sessionId: expect.stringContaining('session-'),
193+
span: expect.objectContaining({}),
194+
streamId: expect.stringContaining(''),
195+
signal: expect.objectContaining({}),
196+
state: expect.objectContaining({}),
197+
}),
198+
reqInit: {},
199+
}),
200+
);
201+
});
202+
203+
test('apply multiple middlewares in order', async () => {
204+
const services = { test: TestServiceSchema };
205+
const middleware1 = vi.fn<Middleware>(({ next }) => {
206+
next();
207+
});
208+
const middleware2 = vi.fn<Middleware>(({ next }) => {
209+
next();
210+
});
211+
const middleware3 = vi.fn<Middleware>(({ next }) => {
212+
next();
213+
});
214+
createServer(mockTransportNetwork.getServerTransport(), services, {
215+
middlewares: [middleware1, middleware2, middleware3],
216+
});
217+
const client = createClient<typeof services>(
218+
mockTransportNetwork.getClientTransport('client'),
219+
'SERVER',
220+
);
221+
222+
const result = await client.test.add.rpc({ n: 3 });
223+
224+
expect(middleware1.mock.invocationCallOrder[0]).toBeLessThan(
225+
middleware2.mock.invocationCallOrder[0],
226+
);
227+
expect(middleware2.mock.invocationCallOrder[0]).toBeLessThan(
228+
middleware3.mock.invocationCallOrder[0],
229+
);
230+
expect(result).toStrictEqual({ ok: true, payload: { result: 3 } });
231+
});
232+
233+
// The reason we have a test for AsyncLocalStorage is that it depends on the
234+
// details of how the server applies middlewares; they have to be called within
235+
// callbacks so that the context is preserved.
236+
// Unfortunately we put our selves in a tough situation where vitest doesn't support
237+
// async hooks when using fake timers, and we have fake timers in our global setup, so
238+
// we only rely on Promise.resolve().then(() => {}) to test context propagation.
239+
test('AsyncLocalStorage context is propagated via AsyncLocalStorage.run', async () => {
240+
const storage = new AsyncLocalStorage<{
241+
readByHandler: boolean;
242+
readByHandlerSignal: boolean;
243+
readByOtherMiddleware: boolean;
244+
readByMiddlewareSignal: boolean;
245+
}>();
246+
247+
const AsyncStorageSchemas = ServiceSchema.define({
248+
gimmeStore: Procedure.rpc({
249+
requestInit: Type.Object({}),
250+
responseData: Type.Object({}),
251+
async handler({ ctx }) {
252+
ctx.signal.addEventListener('abort', () => {
253+
const s = storage.getStore();
254+
if (s) {
255+
s.readByHandlerSignal = true;
256+
}
257+
});
258+
259+
return Promise.resolve().then(() => {
260+
const s = storage.getStore();
261+
if (s) {
262+
s.readByHandler = true;
263+
}
264+
265+
return Ok({});
266+
});
267+
},
268+
}),
269+
});
270+
271+
// Kind of a funky AsyncLocalStorage set up where the store is
272+
// actually accessible everywhere but we promise to always get it from
273+
// the storage instance and only use store in our tests.
274+
const store = {
275+
readByHandler: false,
276+
readByHandlerSignal: false,
277+
readByOtherMiddleware: false,
278+
readByMiddlewareSignal: false,
279+
};
280+
281+
const middleware = vi.fn<Middleware>(({ ctx, next }) => {
282+
ctx.signal.addEventListener('abort', () => {
283+
const s = storage.getStore();
284+
if (s) {
285+
s.readByMiddlewareSignal = true;
286+
}
287+
});
288+
289+
storage.run(store, () => {
290+
next();
291+
});
292+
});
293+
// testing that middlewares in the chain inheret context from the previous
294+
const middlewarThatReadsFromStorage = vi.fn<Middleware>(({ next }) => {
295+
const s = storage.getStore();
296+
if (s) {
297+
s.readByOtherMiddleware = true;
298+
}
299+
300+
next();
301+
});
302+
// these extraneous looking middlewares are to make sure that different shapes of
303+
// middlewares running in the same context don't interfere with each other.
304+
const timeoutMiddleware = vi.fn<Middleware>(({ next }) => {
305+
void Promise.resolve().then(() => {
306+
next();
307+
});
308+
});
309+
const promiseMiddleware = vi.fn<Middleware>(async ({ next }) => {
310+
await Promise.resolve();
311+
312+
next();
313+
314+
await Promise.resolve();
315+
});
316+
317+
const services = { test: AsyncStorageSchemas };
318+
319+
createServer(mockTransportNetwork.getServerTransport(), services, {
320+
middlewares: [
321+
timeoutMiddleware,
322+
promiseMiddleware,
323+
middleware,
324+
middlewarThatReadsFromStorage,
325+
timeoutMiddleware,
326+
promiseMiddleware,
327+
],
328+
});
329+
const client = createClient<typeof services>(
330+
mockTransportNetwork.getClientTransport('client'),
331+
'SERVER',
332+
);
333+
334+
await client.test.gimmeStore.rpc({});
335+
336+
expect(middleware).toHaveBeenCalledOnce();
337+
expect(store).toStrictEqual({
338+
readByHandler: true,
339+
readByHandlerSignal: true,
340+
readByOtherMiddleware: true,
341+
readByMiddlewareSignal: true,
342+
});
343+
});
344+
});

0 commit comments

Comments
 (0)