diff --git a/package-lock.json b/package-lock.json index 6cf9fcb2..8abad19b 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "@replit/river", - "version": "0.205.1", + "version": "0.205.2", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "@replit/river", - "version": "0.205.1", + "version": "0.205.2", "license": "MIT", "dependencies": { "@msgpack/msgpack": "^3.0.0-beta2", diff --git a/package.json b/package.json index 323378ea..6e3d2be2 100644 --- a/package.json +++ b/package.json @@ -1,7 +1,7 @@ { "name": "@replit/river", "description": "It's like tRPC but... with JSON Schema Support, duplex streaming and support for service multiplexing. Transport agnostic!", - "version": "0.205.1", + "version": "0.205.2", "type": "module", "exports": { ".": { diff --git a/testUtil/fixtures/mockTransport.ts b/testUtil/fixtures/mockTransport.ts index 73501716..af053d62 100644 --- a/testUtil/fixtures/mockTransport.ts +++ b/testUtil/fixtures/mockTransport.ts @@ -15,6 +15,7 @@ export class InMemoryConnection extends Connection { constructor(pipe: Duplex) { super(); this.conn = pipe; + this.conn.allowHalfOpen = false; this.conn.on('data', (data: Uint8Array) => { for (const cb of this.dataListeners) { @@ -22,7 +23,7 @@ export class InMemoryConnection extends Connection { } }); - this.conn.on('end', () => { + this.conn.on('close', () => { for (const cb of this.closeListeners) { cb(); } @@ -46,6 +47,7 @@ export class InMemoryConnection extends Connection { close(): void { setImmediate(() => { this.conn.end(); + this.conn.emit('close'); }); } } @@ -153,6 +155,7 @@ export function createMockTransportNetwork( simulatePhantomDisconnect() { for (const conn of Object.values(connections.get())) { conn.serverToClient.pause(); + conn.clientToServer.pause(); } }, async restartServer() { diff --git a/transport/events.ts b/transport/events.ts index c0c590a6..0dc1bf6d 100644 --- a/transport/events.ts +++ b/transport/events.ts @@ -2,6 +2,7 @@ import { type Static } from '@sinclair/typebox'; import { Connection } from './connection'; import { OpaqueTransportMessage, HandshakeErrorResponseCodes } from './message'; import { Session, SessionState } from './sessionStateMachine'; +import { SessionId } from './sessionStateMachine/common'; import { TransportStatus } from './transport'; export const ProtocolError = { @@ -26,11 +27,11 @@ export interface EventMap { session: Pick, 'id' | 'to'>; }; sessionTransition: - | { state: SessionState.Connected } - | { state: SessionState.Handshaking } - | { state: SessionState.Connecting } - | { state: SessionState.BackingOff } - | { state: SessionState.NoConnection }; + | { state: SessionState.Connected; id: SessionId } + | { state: SessionState.Handshaking; id: SessionId } + | { state: SessionState.Connecting; id: SessionId } + | { state: SessionState.BackingOff; id: SessionId } + | { state: SessionState.NoConnection; id: SessionId }; protocolError: | { type: (typeof ProtocolError)['HandshakeFailed']; diff --git a/transport/server.ts b/transport/server.ts index e26980cf..9e21011f 100644 --- a/transport/server.ts +++ b/transport/server.ts @@ -395,9 +395,8 @@ export abstract class ServerTransport< ); oldSession = noConnectionSession; + this.updateSession(oldSession); } - - this.updateSession(oldSession); } else if (oldSession) { connectCase = 'hard reconnection'; diff --git a/transport/sessionStateMachine/SessionConnected.ts b/transport/sessionStateMachine/SessionConnected.ts index 0de225f1..f224585f 100644 --- a/transport/sessionStateMachine/SessionConnected.ts +++ b/transport/sessionStateMachine/SessionConnected.ts @@ -4,6 +4,7 @@ import { ControlMessageAckSchema, OpaqueTransportMessage, PartialTransportMessage, + TransportMessage, isAck, } from '../message'; import { @@ -48,10 +49,25 @@ export class SessionConnected< this.heartbeatMisses = 0; } + private assertSendOrdering(constructedMsg: TransportMessage) { + if (constructedMsg.seq > this.seqSent + 1) { + const msg = `invariant violation: would have sent out of order msg (seq: ${constructedMsg.seq}, expected: ${this.seqSent} + 1)`; + this.log?.error(msg, { + ...this.loggingMetadata, + transportMessage: constructedMsg, + tags: ['invariant-violation'], + }); + + throw new Error(msg); + } + } + send(msg: PartialTransportMessage): string { const constructedMsg = this.constructMsg(msg); + this.assertSendOrdering(constructedMsg); this.sendBuffer.push(constructedMsg); this.conn.send(this.options.codec.toBuffer(constructedMsg)); + this.seqSent = constructedMsg.seq; return constructedMsg.id; } @@ -75,7 +91,9 @@ export class SessionConnected< ); for (const msg of this.sendBuffer) { + this.assertSendOrdering(msg); this.conn.send(this.options.codec.toBuffer(msg)); + this.seqSent = msg.seq; } } @@ -165,7 +183,7 @@ export class SessionConnected< ); } else { const reason = `received out-of-order msg, closing connection (got seq: ${parsedMsg.seq}, wanted seq: ${this.ack})`; - this.log?.warn(reason, { + this.log?.error(reason, { ...this.loggingMetadata, transportMessage: parsedMsg, tags: ['invariant-violation'], diff --git a/transport/sessionStateMachine/common.ts b/transport/sessionStateMachine/common.ts index 54019ea9..8a53d981 100644 --- a/transport/sessionStateMachine/common.ts +++ b/transport/sessionStateMachine/common.ts @@ -208,6 +208,7 @@ export interface IdentifiedSessionProps extends CommonSessionProps { to: TransportClientId; seq: number; ack: number; + seqSent: number; sendBuffer: Array; telemetry: TelemetryInfo; protocolVersion: ProtocolVersion; @@ -224,6 +225,11 @@ export abstract class IdentifiedSession extends CommonSession { */ seq: number; + /** + * Last seq we sent over the wire this session (excluding handshake) and retransmissions + */ + seqSent: number; + /** * Number of unique messages we've received this session (excluding handshake) */ @@ -231,8 +237,17 @@ export abstract class IdentifiedSession extends CommonSession { sendBuffer: Array; constructor(props: IdentifiedSessionProps) { - const { id, to, seq, ack, sendBuffer, telemetry, log, protocolVersion } = - props; + const { + id, + to, + seq, + ack, + sendBuffer, + telemetry, + log, + protocolVersion, + seqSent: messagesSent, + } = props; super(props); this.id = id; this.to = to; @@ -242,6 +257,7 @@ export abstract class IdentifiedSession extends CommonSession { this.telemetry = telemetry; this.log = log; this.protocolVersion = protocolVersion; + this.seqSent = messagesSent; } get loggingMetadata(): MessageMetadata { diff --git a/transport/sessionStateMachine/stateMachine.test.ts b/transport/sessionStateMachine/stateMachine.test.ts index f381edf5..1792d000 100644 --- a/transport/sessionStateMachine/stateMachine.test.ts +++ b/transport/sessionStateMachine/stateMachine.test.ts @@ -1942,16 +1942,19 @@ describe('session state machine', () => { expect(conn.send).toHaveBeenCalledTimes(0); // send a heartbeat - conn.emitData( - session.options.codec.toBuffer( - session.constructMsg({ - streamId: 'heartbeat', - controlFlags: ControlFlags.AckBit, - payload: { - type: 'ACK', - } satisfies Static, - }), - ), + conn.onData( + session.options.codec.toBuffer({ + id: 'msgid', + to: 'SERVER', + from: 'client', + seq: 0, + ack: 0, + streamId: 'heartbeat', + controlFlags: ControlFlags.AckBit, + payload: { + type: 'ACK', + } satisfies Static, + }), ); // make sure the session acks the heartbeat @@ -1964,16 +1967,19 @@ describe('session state machine', () => { const conn = session.conn; // send a heartbeat - conn.emitData( - session.options.codec.toBuffer( - session.constructMsg({ - streamId: 'heartbeat', - controlFlags: ControlFlags.AckBit, - payload: { - type: 'ACK', - } satisfies Static, - }), - ), + conn.onData( + session.options.codec.toBuffer({ + id: 'msgid', + to: 'SERVER', + from: 'client', + seq: 0, + ack: 0, + streamId: 'heartbeat', + controlFlags: ControlFlags.AckBit, + payload: { + type: 'ACK', + } satisfies Static, + }), ); expect(sessionHandle.onMessage).not.toHaveBeenCalled(); diff --git a/transport/sessionStateMachine/transitions.ts b/transport/sessionStateMachine/transitions.ts index 5cc03590..bf46b459 100644 --- a/transport/sessionStateMachine/transitions.ts +++ b/transport/sessionStateMachine/transitions.ts @@ -50,6 +50,7 @@ function inheritSharedSession( to: session.to, seq: session.seq, ack: session.ack, + seqSent: session.seqSent, sendBuffer: session.sendBuffer, telemetry: session.telemetry, options: session.options, @@ -90,6 +91,7 @@ export const SessionStateGraph = { to, seq: 0, ack: 0, + seqSent: 0, graceExpiryTime: Date.now() + options.sessionDisconnectGraceMs, sendBuffer, telemetry, @@ -251,12 +253,13 @@ export const SessionStateGraph = { ? // old session exists, inherit state inheritSharedSession(oldSession) : // old session does not exist, create new state - { + ({ id: sessionId, from, to, seq: 0, ack: 0, + seqSent: 0, sendBuffer: [], telemetry: createSessionTelemetryInfo( pendingSession.tracer, @@ -269,7 +272,7 @@ export const SessionStateGraph = { tracer: pendingSession.tracer, log: pendingSession.log, protocolVersion, - }; + } satisfies IdentifiedSessionProps); pendingSession._handleStateExit(); oldSession?._handleStateExit(); diff --git a/transport/transport.test.ts b/transport/transport.test.ts index 32c8c53f..7850e8ef 100644 --- a/transport/transport.test.ts +++ b/transport/transport.test.ts @@ -3,7 +3,6 @@ import { createDummyTransportMessage, payloadToTransportMessage, waitForMessage, - getTransportConnections, closeAllConnections, numberOfConnections, testingClientSessionOptions, @@ -314,9 +313,7 @@ describe.each(testMatrix())( ).resolves.toStrictEqual(first90.slice(0, 30).map((msg) => msg.payload)); clientTransport.reconnectOnConnectionDrop = false; - for (const conn of getTransportConnections(clientTransport)) { - conn.close(); - } + closeAllConnections(clientTransport); await waitFor(() => expect(numberOfConnections(clientTransport)).toBe(0)); await waitFor(() => expect(numberOfConnections(serverTransport)).toBe(0)); @@ -354,6 +351,125 @@ describe.each(testMatrix())( }); }); + test('buffering messages during reconnect doesnt cause a crash', async () => { + const clientTransport = getClientTransport('client'); + const serverTransport = getServerTransport(); + clientTransport.connect(serverTransport.clientId); + await waitFor(() => expect(numberOfConnections(clientTransport)).toBe(1)); + await waitFor(() => expect(numberOfConnections(serverTransport)).toBe(1)); + + const clientSendFn = getClientSendFn(clientTransport, serverTransport); + const serverSendFn = getServerSendFn(serverTransport, clientTransport); + + addPostTestCleanup(async () => { + await cleanupTransports([clientTransport, serverTransport]); + }); + + function sendToServer(num: number) { + const msgs = Array.from({ length: num }, () => + createDummyTransportMessage(), + ); + const ids = msgs.map((msg) => clientSendFn(msg)); + const promises = ids.map((id) => + waitForMessage(serverTransport, (recv) => recv.id === id), + ); + + return { msgs, promise: Promise.all(promises) }; + } + + function sendToClient(num: number) { + const msgs = Array.from({ length: num }, () => + createDummyTransportMessage(), + ); + const ids = msgs.map((msg) => serverSendFn(msg)); + const promises = ids.map((id) => + waitForMessage(clientTransport, (recv) => recv.id === id), + ); + + return { msgs, promise: Promise.all(promises) }; + } + + // Send initial messages to establish sequence numbers + const { promise: initialToServer } = sendToServer(5); + const { promise: initialToClient } = sendToClient(5); + await Promise.all([initialToServer, initialToClient]); + + // wait for one heartbeat to elapse + await advanceFakeTimersByHeartbeat(); + + // Disconnect client and prevent auto-reconnect + clientTransport.reconnectOnConnectionDrop = false; + closeAllConnections(clientTransport); + await waitFor(() => expect(numberOfConnections(clientTransport)).toBe(0)); + await waitFor(() => expect(numberOfConnections(serverTransport)).toBe(0)); + + // Buffer some messages while disconnected in both directions + const { + msgs: bufferedMsgsToServer, + promise: bufferedMsgsToServerPromises, + } = sendToServer(5); + const { + msgs: bufferedMsgsToClient, + promise: bufferedMsgsToClientPromises, + } = sendToClient(5); + + // Reconnect client + clientTransport.reconnectOnConnectionDrop = true; + clientTransport.connect(serverTransport.clientId); + + // Send some while still connecting in both directions + const { + msgs: connectingMsgsToServer, + promise: connectingMsgsToServerPromises, + } = sendToServer(5); + const { + msgs: connectingMsgsToClient, + promise: connectingMsgsToClientPromises, + } = sendToClient(5); + + // Wait for reconnection + await waitFor(() => expect(numberOfConnections(clientTransport)).toBe(1)); + await waitFor(() => expect(numberOfConnections(serverTransport)).toBe(1)); + + // Send some new messages after reconnection in both directions + const { msgs: newMsgsToServer, promise: newMsgPromiseToServer } = + sendToServer(5); + const { msgs: newMsgsToClient, promise: newMsgPromiseToClient } = + sendToClient(5); + + // Wait for all messages to be received in correct order + // First verify client->server messages + await expect( + Promise.all([ + bufferedMsgsToServerPromises, + connectingMsgsToServerPromises, + newMsgPromiseToServer, + ]), + ).resolves.toStrictEqual([ + bufferedMsgsToServer.map((msg) => msg.payload), + connectingMsgsToServer.map((msg) => msg.payload), + newMsgsToServer.map((msg) => msg.payload), + ]); + + // Then verify server->client messages + await expect( + Promise.all([ + bufferedMsgsToClientPromises, + connectingMsgsToClientPromises, + newMsgPromiseToClient, + ]), + ).resolves.toStrictEqual([ + bufferedMsgsToClient.map((msg) => msg.payload), + connectingMsgsToClient.map((msg) => msg.payload), + newMsgsToClient.map((msg) => msg.payload), + ]); + + await testFinishesCleanly({ + clientTransports: [clientTransport], + serverTransport, + }); + }); + test('both client and server transport get session created/closing/closed notifs', async () => { const clientTransport = getClientTransport('client'); const serverTransport = getServerTransport(); @@ -1307,8 +1423,20 @@ describe.each(testMatrix())( expect(clientSessStop).toHaveBeenCalledTimes(0); expect(serverSessStop).toHaveBeenCalledTimes(0); + // wait for one heartbeat to elapse + await advanceFakeTimersByHeartbeat(); + // now, let's wait until the connection is considered dead testHelpers.simulatePhantomDisconnect(); + + // sending messages here should eventually be received after recovering from disconnect grace + const msg2 = createDummyTransportMessage(); + const msg2Id = clientSendFn(msg2); + const msg2Promise = waitForMessage( + serverTransport, + (recv) => recv.id === msg2Id, + ); + await advanceFakeTimersByDisconnectGrace(); // should have reconnected by now @@ -1320,12 +1448,17 @@ describe.each(testMatrix())( await waitFor(() => expect(clientSessStop).toHaveBeenCalledTimes(0)); await waitFor(() => expect(serverSessStop).toHaveBeenCalledTimes(0)); + // we finally get the message + expect(await msg2Promise).toStrictEqual(msg2.payload); + // ensure sending across the connection still works - const msg2 = createDummyTransportMessage(); - const msg2Id = clientSendFn(msg2); - await expect( - waitForMessage(serverTransport, (recv) => recv.id === msg2Id), - ).resolves.toStrictEqual(msg2.payload); + const msg3 = createDummyTransportMessage(); + const msg3Id = clientSendFn(msg3); + const msg3Promise = waitForMessage( + serverTransport, + (recv) => recv.id === msg3Id, + ); + await expect(msg3Promise).resolves.toStrictEqual(msg3.payload); await testFinishesCleanly({ clientTransports: [clientTransport], diff --git a/transport/transport.ts b/transport/transport.ts index 4e4bf8e7..832a10f3 100644 --- a/transport/transport.ts +++ b/transport/transport.ts @@ -202,7 +202,7 @@ export abstract class Transport { this.eventDispatcher.dispatchEvent('sessionTransition', { state: session.state, - session: session, + id: session.id, } as EventMap['sessionTransition']); } @@ -229,7 +229,7 @@ export abstract class Transport { this.sessions.set(session.to, session); this.eventDispatcher.dispatchEvent('sessionTransition', { state: session.state, - session: session, + id: session.id, } as EventMap['sessionTransition']); }