diff --git a/README.md b/README.md index 1ce6826e..ed8829bb 100644 --- a/README.md +++ b/README.md @@ -189,10 +189,14 @@ If your application is stateful on either the server or the client, the service ```ts transport.addEventListener('sessionStatus', (evt) => { - if (evt.status === 'connect') { + if (evt.status === 'created') { // do something - } else if (evt.status === 'disconnect') { - // do something else + } else if (evt.status === 'closing') { + // do other things + } else if (evt.status === 'closed') { + // note that evt.session only has id + to + // this is useful for doing things like creating a new session if + // a session just got yanked } }); diff --git a/__tests__/negative.test.ts b/__tests__/negative.test.ts index c4fcb9e9..0d07f3c7 100644 --- a/__tests__/negative.test.ts +++ b/__tests__/negative.test.ts @@ -205,7 +205,7 @@ describe('should handle incompatabilities', async () => { expect(errMock).toHaveBeenCalledTimes(0); expect(spy).toHaveBeenCalledWith( expect.objectContaining({ - status: 'connect', + status: 'created', }), ); diff --git a/package-lock.json b/package-lock.json index 7efba3af..4931f62e 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "@replit/river", - "version": "0.204.0", + "version": "0.205.0", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "@replit/river", - "version": "0.204.0", + "version": "0.205.0", "license": "MIT", "dependencies": { "@msgpack/msgpack": "^3.0.0-beta2", diff --git a/package.json b/package.json index 2b1bd0cf..e6fff93c 100644 --- a/package.json +++ b/package.json @@ -1,7 +1,7 @@ { "name": "@replit/river", "description": "It's like tRPC but... with JSON Schema Support, duplex streaming and support for service multiplexing. Transport agnostic!", - "version": "0.204.0", + "version": "0.205.0", "type": "module", "exports": { ".": { diff --git a/router/client.ts b/router/client.ts index 6e013109..279eaa12 100644 --- a/router/client.ts +++ b/router/client.ts @@ -456,7 +456,7 @@ function handleProc( function onSessionStatus(evt: EventMap['sessionStatus']) { if ( - evt.status !== 'disconnect' || + evt.status !== 'closing' || evt.session.to !== serverId || session.id !== evt.session.id ) { diff --git a/router/server.ts b/router/server.ts index 7c6e53bf..0a6ff61b 100644 --- a/router/server.ts +++ b/router/server.ts @@ -216,7 +216,7 @@ class RiverServer }; const handleSessionStatus = (evt: EventMap['sessionStatus']) => { - if (evt.status !== 'disconnect') return; + if (evt.status !== 'closing') return; const disconnectedClientId = evt.session.to; this.log?.info( diff --git a/testUtil/fixtures/matrix.ts b/testUtil/fixtures/matrix.ts index bde4ff3b..d89ea20f 100644 --- a/testUtil/fixtures/matrix.ts +++ b/testUtil/fixtures/matrix.ts @@ -17,7 +17,7 @@ interface TestMatrixEntry { /** * Defines a selector type that pairs a valid transport with a valid codec. */ -type Selector = [ValidTransports, ValidCodecs]; +type Selector = [ValidTransports | 'all', ValidCodecs | 'all']; /** * Generates a matrix of test entries for each combination of transport and codec. @@ -26,16 +26,23 @@ type Selector = [ValidTransports, ValidCodecs]; * @param selector An optional tuple specifying a transport and codec to filter the matrix. * @returns An array of TestMatrixEntry objects representing the combinations of transport and codec. */ -export const testMatrix = (selector?: Selector): Array => - transports +export const testMatrix = ( + [transportSelector, codecSelector]: Selector = ['all', 'all'], +): Array => { + const filteredTransports = transports.filter( + (t) => transportSelector === 'all' || t.name === transportSelector, + ); + + const filteredCodecs = codecs.filter( + (c) => codecSelector === 'all' || c.name === codecSelector, + ); + + return filteredTransports .map((transport) => - // If a selector is provided, filter transport + codecs to match the selector; otherwise, use all codecs. - (selector - ? codecs.filter((codec) => selector[1] === codec.name) - : codecs - ).map((codec) => ({ + filteredCodecs.map((codec) => ({ transport, codec, })), ) .flat(); +}; diff --git a/transport/client.ts b/transport/client.ts index 5025fd73..642dede6 100644 --- a/transport/client.ts +++ b/transport/client.ts @@ -372,7 +372,9 @@ export abstract class ClientTransport< * and don't want to wait for the grace period to elapse. */ hardDisconnect() { - for (const session of this.sessions.values()) { + // create a copy of the sessions to avoid modifying the map while iterating + const sessions = Array.from(this.sessions.values()); + for (const session of sessions) { this.deleteSession(session); } } diff --git a/transport/events.ts b/transport/events.ts index 470dbd26..c0c590a6 100644 --- a/transport/events.ts +++ b/transport/events.ts @@ -16,10 +16,15 @@ export type ProtocolErrorType = export interface EventMap { message: OpaqueTransportMessage; - sessionStatus: { - status: 'connect' | 'disconnect'; - session: Session; - }; + sessionStatus: + | { + status: 'created' | 'closing'; + session: Session; + } + | { + status: 'closed'; + session: Pick, 'id' | 'to'>; + }; sessionTransition: | { state: SessionState.Connected } | { state: SessionState.Handshaking } diff --git a/transport/transport.test.ts b/transport/transport.test.ts index e006aa4d..32c8c53f 100644 --- a/transport/transport.test.ts +++ b/transport/transport.test.ts @@ -254,7 +254,7 @@ describe.each(testMatrix())( const msg = createDummyTransportMessage(); const msgPromise = waitForMessage(serverTransport); const sendHandle = (evt: EventMap['sessionStatus']) => { - if (evt.status === 'connect') { + if (evt.status === 'created') { getClientSendFn(clientTransport, serverTransport)(msg); } }; @@ -354,7 +354,7 @@ describe.each(testMatrix())( }); }); - test('both client and server transport get connect/disconnect notifs', async () => { + test('both client and server transport get session created/closing/closed notifs', async () => { const clientTransport = getClientTransport('client'); const serverTransport = getServerTransport(); const clientConnStart = vi.fn(); @@ -367,13 +367,17 @@ describe.each(testMatrix())( }; const clientSessStart = vi.fn(); + const clientSessStopping = vi.fn(); const clientSessStop = vi.fn(); const clientSessHandler = (evt: EventMap['sessionStatus']) => { switch (evt.status) { - case 'connect': + case 'created': clientSessStart(); break; - case 'disconnect': + case 'closing': + clientSessStopping(); + break; + case 'closed': clientSessStop(); break; } @@ -389,13 +393,17 @@ describe.each(testMatrix())( }; const serverSessStart = vi.fn(); + const serverSessStopping = vi.fn(); const serverSessStop = vi.fn(); const serverSessHandler = (evt: EventMap['sessionStatus']) => { switch (evt.status) { - case 'connect': + case 'created': serverSessStart(); break; - case 'disconnect': + case 'closing': + serverSessStopping(); + break; + case 'closed': serverSessStop(); break; } @@ -438,6 +446,8 @@ describe.each(testMatrix())( expect(serverSessStart).toHaveBeenCalledTimes(0); expect(clientSessStop).toHaveBeenCalledTimes(0); expect(serverSessStop).toHaveBeenCalledTimes(0); + expect(clientSessStopping).toHaveBeenCalledTimes(0); + expect(serverSessStopping).toHaveBeenCalledTimes(0); clientTransport.connect(serverTransport.clientId); const clientSendFn = getClientSendFn(clientTransport, serverTransport); @@ -455,6 +465,8 @@ describe.each(testMatrix())( expect(serverSessStart).toHaveBeenCalledTimes(1); expect(clientSessStop).toHaveBeenCalledTimes(0); expect(serverSessStop).toHaveBeenCalledTimes(0); + expect(clientSessStopping).toHaveBeenCalledTimes(0); + expect(serverSessStopping).toHaveBeenCalledTimes(0); // clean disconnect + reconnect within grace period closeAllConnections(clientTransport); @@ -471,6 +483,8 @@ describe.each(testMatrix())( await waitFor(() => expect(serverSessStart).toHaveBeenCalledTimes(1)); await waitFor(() => expect(clientSessStop).toHaveBeenCalledTimes(0)); await waitFor(() => expect(serverSessStop).toHaveBeenCalledTimes(0)); + await waitFor(() => expect(clientSessStopping).toHaveBeenCalledTimes(0)); + await waitFor(() => expect(serverSessStopping).toHaveBeenCalledTimes(0)); // by this point the client should have reconnected // session > c----------| (connected) @@ -486,6 +500,8 @@ describe.each(testMatrix())( expect(clientSessStop).toHaveBeenCalledTimes(0); expect(serverSessStart).toHaveBeenCalledTimes(1); expect(serverSessStop).toHaveBeenCalledTimes(0); + expect(clientSessStopping).toHaveBeenCalledTimes(0); + expect(serverSessStopping).toHaveBeenCalledTimes(0); // disconnect session entirely // session > c------------x | (disconnected) @@ -500,6 +516,8 @@ describe.each(testMatrix())( await waitFor(() => expect(serverSessStart).toHaveBeenCalledTimes(1)); await waitFor(() => expect(clientSessStop).toHaveBeenCalledTimes(1)); await waitFor(() => expect(serverSessStop).toHaveBeenCalledTimes(1)); + await waitFor(() => expect(clientSessStopping).toHaveBeenCalledTimes(1)); + await waitFor(() => expect(serverSessStopping).toHaveBeenCalledTimes(1)); await testFinishesCleanly({ clientTransports: [clientTransport], @@ -591,6 +609,57 @@ describe.each(testMatrix())( serverTransport, }); }); + + test('listening on session disconnect and manually reconnecting works', async () => { + const clientTransport = getClientTransport('client'); + const serverTransport = getServerTransport(); + clientTransport.connect(serverTransport.clientId); + + addPostTestCleanup(async () => { + await cleanupTransports([clientTransport, serverTransport]); + }); + + await waitFor(() => expect(numberOfConnections(clientTransport)).toBe(1)); + await waitFor(() => expect(numberOfConnections(serverTransport)).toBe(1)); + + const onSessionDisconnect = vi.fn(); + const onSessionConnect = vi.fn(); + const sessionStatusListener = (evt: EventMap['sessionStatus']) => { + if (evt.status === 'created') { + onSessionConnect(); + } + + if (evt.status === 'closed') { + onSessionDisconnect(); + clientTransport.connect(serverTransport.clientId); + } + }; + + clientTransport.addEventListener('sessionStatus', sessionStatusListener); + addPostTestCleanup(async () => { + clientTransport.removeEventListener( + 'sessionStatus', + sessionStatusListener, + ); + }); + + expect(onSessionDisconnect).toHaveBeenCalledTimes(0); + expect(onSessionDisconnect).toHaveBeenCalledTimes(0); + + // cause a session disconnect + clientTransport.hardDisconnect(); + + await waitFor(() => expect(onSessionDisconnect).toHaveBeenCalledTimes(1)); + await waitFor(() => expect(onSessionConnect).toHaveBeenCalledTimes(1)); + + await waitFor(() => expect(numberOfConnections(clientTransport)).toBe(1)); + await waitFor(() => expect(numberOfConnections(serverTransport)).toBe(1)); + + await testFinishesCleanly({ + clientTransports: [clientTransport], + serverTransport, + }); + }); }, ); @@ -883,7 +952,7 @@ describe.each(testMatrix())( const onSessionDisconnect = vi.fn(); const sessionStatusListener = (evt: EventMap['sessionStatus']) => { - if (evt.status === 'disconnect') { + if (evt.status === 'closed') { onSessionDisconnect(); } }; @@ -953,10 +1022,10 @@ describe.each(testMatrix())( const serverSessStop = vi.fn(); const serverSessHandler = (evt: EventMap['sessionStatus']) => { switch (evt.status) { - case 'connect': + case 'created': serverSessStart(); break; - case 'disconnect': + case 'closed': serverSessStop(); break; } @@ -1045,10 +1114,10 @@ describe.each(testMatrix())( const clientSessStop = vi.fn(); const clientSessHandler = (evt: EventMap['sessionStatus']) => { switch (evt.status) { - case 'connect': + case 'created': clientSessStart(); break; - case 'disconnect': + case 'closed': clientSessStop(); break; } @@ -1177,10 +1246,10 @@ describe.each(testMatrix())( const clientSessStop = vi.fn(); const clientSessHandler = (evt: EventMap['sessionStatus']) => { switch (evt.status) { - case 'connect': + case 'created': clientSessStart(); break; - case 'disconnect': + case 'closed': clientSessStop(); break; } @@ -1190,10 +1259,10 @@ describe.each(testMatrix())( const serverSessStop = vi.fn(); const serverSessHandler = (evt: EventMap['sessionStatus']) => { switch (evt.status) { - case 'connect': + case 'created': serverSessStart(); break; - case 'disconnect': + case 'closed': serverSessStop(); break; } diff --git a/transport/transport.ts b/transport/transport.ts index fd230a58..4e4bf8e7 100644 --- a/transport/transport.ts +++ b/transport/transport.ts @@ -165,7 +165,8 @@ export abstract class Transport { close() { this.status = 'closed'; - for (const session of this.sessions.values()) { + const sessions = Array.from(this.sessions.values()); + for (const session of sessions) { this.deleteSession(session); } @@ -195,7 +196,7 @@ export abstract class Transport { this.sessions.set(session.to, session); this.eventDispatcher.dispatchEvent('sessionStatus', { - status: 'connect', + status: 'created', session: session, }); @@ -246,13 +247,17 @@ export abstract class Transport { session.log?.info(`closing session ${session.id}`, loggingMetadata); this.eventDispatcher.dispatchEvent('sessionStatus', { - status: 'disconnect', + status: 'closing', session: session, }); const to = session.to; session.close(); this.sessions.delete(to); + this.eventDispatcher.dispatchEvent('sessionStatus', { + status: 'closed', + session: { id: session.id, to: to }, + }); } // common listeners