diff --git a/__tests__/allocation.test.ts b/__tests__/allocation.test.ts new file mode 100644 index 00000000..03c78e64 --- /dev/null +++ b/__tests__/allocation.test.ts @@ -0,0 +1,177 @@ +import { beforeEach, describe, test, expect, vi, assert } from 'vitest'; +import { TestSetupHelpers, transports } from '../testUtil/fixtures/transports'; +import { BinaryCodec, Codec } from '../codec'; +import { + advanceFakeTimersByHeartbeat, + createPostTestCleanups, +} from '../testUtil/fixtures/cleanup'; +import { createServer } from '../router/server'; +import { createClient } from '../router/client'; +import { TestServiceSchema } from '../testUtil/fixtures/services'; +import { waitFor } from '../testUtil/fixtures/cleanup'; +import { numberOfConnections, closeAllConnections } from '../testUtil'; +import { cleanupTransports } from '../testUtil/fixtures/cleanup'; +import { testFinishesCleanly } from '../testUtil/fixtures/cleanup'; +import { ProtocolError } from '../transport/events'; + +let isOom = false; +// simulate RangeError: Array buffer allocation failed +const OomableCodec: Codec = { + toBuffer(obj) { + if (isOom) { + throw new RangeError('failed allocation'); + } + + return BinaryCodec.toBuffer(obj); + }, + fromBuffer: (buff: Uint8Array) => { + return BinaryCodec.fromBuffer(buff); + }, +}; + +describe.each(transports)( + 'failed allocation test ($name transport)', + async (transport) => { + const clientOpts = { codec: OomableCodec }; + const serverOpts = { codec: BinaryCodec }; + + const { addPostTestCleanup, postTestCleanup } = createPostTestCleanups(); + let getClientTransport: TestSetupHelpers['getClientTransport']; + let getServerTransport: TestSetupHelpers['getServerTransport']; + beforeEach(async () => { + // only allow client to oom, server has sane oom handling already + const setup = await transport.setup({ + client: clientOpts, + server: serverOpts, + }); + getClientTransport = setup.getClientTransport; + getServerTransport = setup.getServerTransport; + isOom = false; + + return async () => { + await postTestCleanup(); + await setup.cleanup(); + }; + }); + + test('oom during heartbeat kills the session, client starts new session', async () => { + // setup + const clientTransport = getClientTransport('client'); + const serverTransport = getServerTransport(); + const services = { test: TestServiceSchema }; + const server = createServer(serverTransport, services); + const client = createClient( + clientTransport, + serverTransport.clientId, + ); + + const errMock = vi.fn(); + clientTransport.addEventListener('protocolError', errMock); + addPostTestCleanup(async () => { + clientTransport.removeEventListener('protocolError', errMock); + await cleanupTransports([clientTransport, serverTransport]); + }); + + // establish initial connection + const result = await client.test.add.rpc({ n: 1 }); + expect(result).toStrictEqual({ ok: true, payload: { result: 1 } }); + + await waitFor(() => expect(numberOfConnections(serverTransport)).toBe(1)); + await waitFor(() => expect(numberOfConnections(clientTransport)).toBe(1)); + const oldClientSession = serverTransport.sessions.get('client'); + const oldServerSession = clientTransport.sessions.get('SERVER'); + assert(oldClientSession); + assert(oldServerSession); + + // simulate some OOM during heartbeat + for (let i = 0; i < 5; i++) { + isOom = i % 2 === 0; + await advanceFakeTimersByHeartbeat(); + } + + // verify session on client is dead + await waitFor(() => expect(clientTransport.sessions.size).toBe(0)); + + // verify we got MessageSendFailure errors + await waitFor(() => { + expect(errMock).toHaveBeenCalledWith( + expect.objectContaining({ + type: ProtocolError.MessageSendFailure, + }), + ); + }); + + // client should be able to reconnect and make new calls + isOom = false; + const result2 = await client.test.add.rpc({ n: 2 }); + expect(result2).toStrictEqual({ ok: true, payload: { result: 3 } }); + + // verify new session IDs are different from old ones + const newClientSession = serverTransport.sessions.get('client'); + const newServerSession = clientTransport.sessions.get('SERVER'); + assert(newClientSession); + assert(newServerSession); + expect(newClientSession.id).not.toBe(oldClientSession.id); + expect(newServerSession.id).not.toBe(oldServerSession.id); + + await testFinishesCleanly({ + clientTransports: [clientTransport], + serverTransport, + server, + }); + }); + + test('oom during handshake kills the session, client starts new session', async () => { + // setup + const clientTransport = getClientTransport('client'); + const serverTransport = getServerTransport(); + const services = { test: TestServiceSchema }; + const server = createServer(serverTransport, services); + const client = createClient( + clientTransport, + serverTransport.clientId, + ); + const errMock = vi.fn(); + clientTransport.addEventListener('protocolError', errMock); + addPostTestCleanup(async () => { + clientTransport.removeEventListener('protocolError', errMock); + await cleanupTransports([clientTransport, serverTransport]); + }); + + // establish initial connection + await client.test.add.rpc({ n: 1 }); + await waitFor(() => expect(numberOfConnections(serverTransport)).toBe(1)); + await waitFor(() => expect(numberOfConnections(clientTransport)).toBe(1)); + + // close connection to force reconnection + closeAllConnections(clientTransport); + await waitFor(() => expect(numberOfConnections(serverTransport)).toBe(0)); + await waitFor(() => expect(numberOfConnections(clientTransport)).toBe(0)); + + // simulate OOM during handshake + isOom = true; + clientTransport.connect('SERVER'); + await waitFor(() => expect(numberOfConnections(serverTransport)).toBe(0)); + await waitFor(() => expect(numberOfConnections(clientTransport)).toBe(0)); + + await waitFor(() => { + expect(errMock).toHaveBeenCalledWith( + expect.objectContaining({ + type: ProtocolError.MessageSendFailure, + }), + ); + }); + + // client should be able to reconnect and make new calls + isOom = false; + const result = await client.test.add.rpc({ n: 2 }); + expect(result).toStrictEqual({ ok: true, payload: { result: 3 } }); + + await testFinishesCleanly({ + clientTransports: [clientTransport], + serverTransport, + server, + }); + }); + }, +); diff --git a/codec/adapter.ts b/codec/adapter.ts new file mode 100644 index 00000000..450a371b --- /dev/null +++ b/codec/adapter.ts @@ -0,0 +1,53 @@ +import { Value } from '@sinclair/typebox/value'; +import { + OpaqueTransportMessage, + OpaqueTransportMessageSchema, +} from '../transport'; +import { Codec } from './types'; +import { DeserializeResult, SerializeResult } from '../transport/results'; +import { coerceErrorString } from '../transport/stringifyError'; + +/** + * Adapts a {@link Codec} to the {@link OpaqueTransportMessage} format, + * accounting for fallibility of toBuffer and fromBuffer and wrapping + * it with a Result type. + */ +export class CodecMessageAdapter { + constructor(private readonly codec: Codec) {} + + toBuffer(msg: OpaqueTransportMessage): SerializeResult { + try { + return { + ok: true, + value: this.codec.toBuffer(msg), + }; + } catch (e) { + return { + ok: false, + reason: coerceErrorString(e), + }; + } + } + + fromBuffer(buf: Uint8Array): DeserializeResult { + try { + const parsedMsg = this.codec.fromBuffer(buf); + if (!Value.Check(OpaqueTransportMessageSchema, parsedMsg)) { + return { + ok: false, + reason: 'transport message schema mismatch', + }; + } + + return { + ok: true, + value: parsedMsg, + }; + } catch (e) { + return { + ok: false, + reason: coerceErrorString(e), + }; + } + } +} diff --git a/codec/binary.ts b/codec/binary.ts index 69c9818e..09c6e339 100644 --- a/codec/binary.ts +++ b/codec/binary.ts @@ -10,15 +10,11 @@ export const BinaryCodec: Codec = { return encode(obj, { ignoreUndefined: true }); }, fromBuffer: (buff: Uint8Array) => { - try { - const res = decode(buff); - if (typeof res !== 'object') { - return null; - } - - return res; - } catch { - return null; + const res = decode(buff); + if (typeof res !== 'object' || res === null) { + throw new Error('unpacked msg is not an object'); } + + return res; }, }; diff --git a/codec/codec.test.ts b/codec/codec.test.ts index 763b8774..958cfe1d 100644 --- a/codec/codec.test.ts +++ b/codec/codec.test.ts @@ -41,10 +41,10 @@ describe.each(codecs)('codec -- $name', ({ codec }) => { expect(codec.fromBuffer(codec.toBuffer(msg))).toStrictEqual(msg); }); - test('invalid json returns null', () => { - expect(codec.fromBuffer(Buffer.from(''))).toBeNull(); - expect(codec.fromBuffer(Buffer.from('['))).toBeNull(); - expect(codec.fromBuffer(Buffer.from('[{}'))).toBeNull(); - expect(codec.fromBuffer(Buffer.from('{"a":1}[]'))).toBeNull(); + test('invalid json throws', () => { + expect(() => codec.fromBuffer(Buffer.from(''))).toThrow(); + expect(() => codec.fromBuffer(Buffer.from('['))).toThrow(); + expect(() => codec.fromBuffer(Buffer.from('[{}'))).toThrow(); + expect(() => codec.fromBuffer(Buffer.from('{"a":1}[]'))).toThrow(); }); }); diff --git a/codec/index.ts b/codec/index.ts index f6455901..f85c24c2 100644 --- a/codec/index.ts +++ b/codec/index.ts @@ -1,3 +1,4 @@ export { BinaryCodec } from './binary'; export { NaiveJsonCodec } from './json'; export type { Codec } from './types'; +export { CodecMessageAdapter } from './adapter'; diff --git a/codec/json.ts b/codec/json.ts index 78973989..9c6cb813 100644 --- a/codec/json.ts +++ b/codec/json.ts @@ -48,23 +48,21 @@ export const NaiveJsonCodec: Codec = { ); }, fromBuffer: (buff: Uint8Array) => { - try { - const parsed = JSON.parse( - decoder.decode(buff), - function reviver(_key, val: unknown) { - if ((val as Base64EncodedValue | undefined)?.$t) { - return base64ToUint8Array((val as Base64EncodedValue).$t); - } else { - return val; - } - }, - ) as unknown; - - if (typeof parsed === 'object') return parsed; + const parsed = JSON.parse( + decoder.decode(buff), + function reviver(_key, val: unknown) { + if ((val as Base64EncodedValue | undefined)?.$t) { + return base64ToUint8Array((val as Base64EncodedValue).$t); + } else { + return val; + } + }, + ) as unknown; - return null; - } catch { - return null; + if (typeof parsed !== 'object' || parsed === null) { + throw new Error('unpacked msg is not an object'); } + + return parsed; }, }; diff --git a/codec/types.ts b/codec/types.ts index 819e0dcb..0d74e812 100644 --- a/codec/types.ts +++ b/codec/types.ts @@ -14,5 +14,5 @@ export interface Codec { * @param buf - The Uint8 buffer to decode. * @returns The decoded object, or null if decoding failed. */ - fromBuffer(buf: Uint8Array): object | null; + fromBuffer(buf: Uint8Array): object; } diff --git a/package-lock.json b/package-lock.json index 9d105090..98157ecc 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "@replit/river", - "version": "0.207.2", + "version": "0.207.3", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "@replit/river", - "version": "0.207.2", + "version": "0.207.3", "license": "MIT", "dependencies": { "@msgpack/msgpack": "^3.0.0-beta2", diff --git a/package.json b/package.json index 6866a298..18f78d15 100644 --- a/package.json +++ b/package.json @@ -1,7 +1,7 @@ { "name": "@replit/river", "description": "It's like tRPC but... with JSON Schema Support, duplex streaming and support for service multiplexing. Transport agnostic!", - "version": "0.207.2", + "version": "0.207.3", "type": "module", "exports": { ".": { diff --git a/transport/client.ts b/transport/client.ts index 642dede6..ae4c12be 100644 --- a/transport/client.ts +++ b/transport/client.ts @@ -292,13 +292,46 @@ export abstract class ClientTransport< this.handleMsg(msg); }, onInvalidMessage: (reason) => { - this.deleteSession(connectedSession, { unhealthy: true }); + this.log?.error(`invalid message: ${reason}`, { + ...connectedSession.loggingMetadata, + transportMessage: msg, + }); + this.protocolError({ type: ProtocolError.InvalidMessage, message: reason, }); + this.deleteSession(connectedSession, { unhealthy: true }); }, + onMessageSendFailure: (msg, reason) => { + this.log?.error(`failed to send message: ${reason}`, { + ...connectedSession.loggingMetadata, + transportMessage: msg, + }); + + this.protocolError({ + type: ProtocolError.MessageSendFailure, + message: reason, + }); + this.deleteSession(connectedSession, { unhealthy: true }); + }, + }); + + const res = connectedSession.sendBufferedMessages(); + if (!res.ok) { + this.log?.error(`failed to send buffered messages: ${res.reason}`, { + ...connectedSession.loggingMetadata, + transportMessage: msg, + }); + + this.protocolError({ + type: ProtocolError.MessageSendFailure, + message: res.reason, }); + this.deleteSession(connectedSession, { unhealthy: true }); + + return; + } this.updateSession(connectedSession); this.retryBudget.startRestoringBudget(); @@ -471,7 +504,19 @@ export abstract class ClientTransport< transportMessage: requestMsg, }); - session.sendHandshake(requestMsg); + const res = session.sendHandshake(requestMsg); + if (!res.ok) { + this.log?.error(`failed to send handshake request: ${res.reason}`, { + ...session.loggingMetadata, + transportMessage: requestMsg, + }); + + this.protocolError({ + type: ProtocolError.MessageSendFailure, + message: res.reason, + }); + this.deleteSession(session, { unhealthy: true }); + } } close() { diff --git a/transport/events.ts b/transport/events.ts index 0dc1bf6d..41e832c4 100644 --- a/transport/events.ts +++ b/transport/events.ts @@ -10,6 +10,7 @@ export const ProtocolError = { HandshakeFailed: 'handshake_failed', MessageOrderingViolated: 'message_ordering_violated', InvalidMessage: 'invalid_message', + MessageSendFailure: 'message_send_failure', } as const; export type ProtocolErrorType = diff --git a/transport/impls/ws/connection.ts b/transport/impls/ws/connection.ts index c436c3ac..c1da4fbd 100644 --- a/transport/impls/ws/connection.ts +++ b/transport/impls/ws/connection.ts @@ -52,13 +52,13 @@ export class WebSocketConnection extends Connection { } send(payload: Uint8Array) { - if (this.ws.readyState !== this.ws.OPEN) { + try { + this.ws.send(payload); + + return true; + } catch { return false; } - - this.ws.send(payload); - - return true; } close() { diff --git a/transport/results.ts b/transport/results.ts new file mode 100644 index 00000000..ff8c12da --- /dev/null +++ b/transport/results.ts @@ -0,0 +1,17 @@ +import { OpaqueTransportMessage } from './message'; + +// internal use only, not to be used in public API +type SessionApiResult = + | { + ok: true; + value: T; + } + | { + ok: false; + reason: string; + }; + +export type SendResult = SessionApiResult; +export type SendBufferResult = SessionApiResult; +export type SerializeResult = SessionApiResult; +export type DeserializeResult = SessionApiResult; diff --git a/transport/server.ts b/transport/server.ts index 9e21011f..cf5a7d3d 100644 --- a/transport/server.ts +++ b/transport/server.ts @@ -185,17 +185,31 @@ export abstract class ServerTransport< this.log?.warn(reason, metadata); - session.sendHandshake( - handshakeResponseMessage({ - from: this.clientId, - to, - status: { - ok: false, - code, - reason, - }, - }), - ); + const responseMsg = handshakeResponseMessage({ + from: this.clientId, + to, + status: { + ok: false, + code, + reason, + }, + }); + + const res = session.sendHandshake(responseMsg); + if (!res.ok) { + this.log?.error(`failed to send handshake response: ${res.reason}`, { + ...session.loggingMetadata, + transportMessage: responseMsg, + }); + + this.protocolError({ + type: ProtocolError.MessageSendFailure, + message: res.reason, + }); + this.deletePendingSession(session); + + return; + } this.protocolError({ type: ProtocolError.HandshakeFailed, @@ -456,9 +470,24 @@ export abstract class ServerTransport< }, }); - session.sendHandshake(responseMsg); + const res = session.sendHandshake(responseMsg); + if (!res.ok) { + this.log?.error(`failed to send handshake response: ${res.reason}`, { + ...session.loggingMetadata, + transportMessage: responseMsg, + }); + + this.protocolError({ + type: ProtocolError.MessageSendFailure, + message: res.reason, + }); + this.deletePendingSession(session); + + return; + } // transition + this.pendingSessions.delete(session); const connectedSession = ServerSessionStateGraph.transition.WaitingForHandshakeToConnected( session, @@ -487,16 +516,52 @@ export abstract class ServerTransport< this.handleMsg(msg); }, onInvalidMessage: (reason) => { + this.log?.error(`invalid message: ${reason}`, { + ...connectedSession.loggingMetadata, + transportMessage: msg, + }); + this.protocolError({ type: ProtocolError.InvalidMessage, message: reason, }); this.deleteSession(connectedSession, { unhealthy: true }); }, + onMessageSendFailure: (msg, reason) => { + this.log?.error(`failed to send message: ${reason}`, { + ...connectedSession.loggingMetadata, + transportMessage: msg, + }); + + this.protocolError({ + type: ProtocolError.MessageSendFailure, + message: reason, + }); + this.deleteSession(connectedSession, { unhealthy: true }); + }, }, gotVersion, ); + const bufferSendRes = connectedSession.sendBufferedMessages(); + if (!bufferSendRes.ok) { + this.log?.error( + `failed to send buffered messages: ${bufferSendRes.reason}`, + { + ...connectedSession.loggingMetadata, + transportMessage: msg, + }, + ); + + this.protocolError({ + type: ProtocolError.MessageSendFailure, + message: bufferSendRes.reason, + }); + this.deleteSession(connectedSession, { unhealthy: true }); + + return; + } + this.sessionHandshakeMetadata.set(connectedSession.to, parsedMetadata); if (oldSession) { this.updateSession(connectedSession); @@ -504,7 +569,6 @@ export abstract class ServerTransport< this.createSession(connectedSession); } - this.pendingSessions.delete(session); connectedSession.startActiveHeartbeat(); } } diff --git a/transport/sessionStateMachine/SessionConnected.ts b/transport/sessionStateMachine/SessionConnected.ts index 83729f0a..1000d3a1 100644 --- a/transport/sessionStateMachine/SessionConnected.ts +++ b/transport/sessionStateMachine/SessionConnected.ts @@ -10,15 +10,18 @@ import { import { IdentifiedSession, IdentifiedSessionProps, + sendMessage, SessionState, } from './common'; import { Connection } from '../connection'; import { SpanStatusCode } from '@opentelemetry/api'; +import { SendBufferResult, SendResult } from '../results'; export interface SessionConnectedListeners { onConnectionErrored: (err: unknown) => void; onConnectionClosed: () => void; onMessage: (msg: OpaqueTransportMessage) => void; + onMessageSendFailure: (msg: PartialTransportMessage, reason: string) => void; onInvalidMessage: (reason: string) => void; } @@ -67,14 +70,20 @@ export class SessionConnected< } } - send(msg: PartialTransportMessage): string { + send(msg: PartialTransportMessage): SendResult { const constructedMsg = this.constructMsg(msg); this.assertSendOrdering(constructedMsg); this.sendBuffer.push(constructedMsg); - this.conn.send(this.options.codec.toBuffer(constructedMsg)); + const res = sendMessage(this.conn, this.codec, constructedMsg); + if (!res.ok) { + this.listeners.onMessageSendFailure(constructedMsg, res.reason); + + return res; + } + this.seqSent = constructedMsg.seq; - return constructedMsg.id; + return res; } constructor(props: SessionConnectedProps) { @@ -85,7 +94,9 @@ export class SessionConnected< this.conn.addDataListener(this.onMessageData); this.conn.addCloseListener(this.listeners.onConnectionClosed); this.conn.addErrorListener(this.listeners.onConnectionErrored); + } + sendBufferedMessages(): SendBufferResult { // send any buffered messages // dont explicity clear the buffer, we'll just filter out old messages // when we receive an ack @@ -99,12 +110,18 @@ export class SessionConnected< for (const msg of this.sendBuffer) { this.assertSendOrdering(msg); - this.conn.send(this.options.codec.toBuffer(msg)); + const res = sendMessage(this.conn, this.codec, msg); + if (!res.ok) { + this.listeners.onMessageSendFailure(msg, res.reason); + + return res; + } + this.seqSent = msg.seq; } } - this.startMissingHeartbeatTimeout(); + return { ok: true, value: undefined }; } get loggingMetadata() { @@ -137,25 +154,31 @@ export class SessionConnected< }, this.options.heartbeatIntervalMs); } - private sendHeartbeat() { + private sendHeartbeat(): void { this.log?.debug('sending heartbeat', this.loggingMetadata); - this.send({ + const heartbeat = { streamId: 'heartbeat', controlFlags: ControlFlags.AckBit, payload: { type: 'ACK', } satisfies Static, - }); + } satisfies PartialTransportMessage; + + this.send(heartbeat); } onMessageData = (msg: Uint8Array) => { - const parsedMsg = this.parseMsg(msg); - if (parsedMsg === null) { - this.listeners.onInvalidMessage('could not parse message'); + const parsedMsgRes = this.codec.fromBuffer(msg); + if (!parsedMsgRes.ok) { + this.listeners.onInvalidMessage( + `could not parse message: ${parsedMsgRes.reason}`, + ); return; } + const parsedMsg = parsedMsgRes.value; + // check message ordering here if (parsedMsg.seq !== this.ack) { if (parsedMsg.seq < this.ack) { diff --git a/transport/sessionStateMachine/SessionHandshaking.ts b/transport/sessionStateMachine/SessionHandshaking.ts index 3fe2e099..dbe3cef0 100644 --- a/transport/sessionStateMachine/SessionHandshaking.ts +++ b/transport/sessionStateMachine/SessionHandshaking.ts @@ -9,8 +9,10 @@ import { IdentifiedSessionWithGracePeriod, IdentifiedSessionWithGracePeriodListeners, IdentifiedSessionWithGracePeriodProps, + sendMessage, SessionState, } from './common'; +import { SendResult } from '../results'; export interface SessionHandshakingListeners extends IdentifiedSessionWithGracePeriodListeners { @@ -67,21 +69,21 @@ export class SessionHandshaking< } onHandshakeData = (msg: Uint8Array) => { - const parsedMsg = this.parseMsg(msg); - if (parsedMsg === null) { + const parsedMsgRes = this.codec.fromBuffer(msg); + if (!parsedMsgRes.ok) { this.listeners.onInvalidHandshake( - 'could not parse message', + `could not parse handshake message: ${parsedMsgRes.reason}`, 'MALFORMED_HANDSHAKE', ); return; } - this.listeners.onHandshake(parsedMsg); + this.listeners.onHandshake(parsedMsgRes.value); }; - sendHandshake(msg: TransportMessage): boolean { - return this.conn.send(this.options.codec.toBuffer(msg)); + sendHandshake(msg: TransportMessage): SendResult { + return sendMessage(this.conn, this.codec, msg); } _handleStateExit(): void { diff --git a/transport/sessionStateMachine/SessionWaitingForHandshake.ts b/transport/sessionStateMachine/SessionWaitingForHandshake.ts index 851f372c..0f5a2047 100644 --- a/transport/sessionStateMachine/SessionWaitingForHandshake.ts +++ b/transport/sessionStateMachine/SessionWaitingForHandshake.ts @@ -5,7 +5,13 @@ import { OpaqueTransportMessage, TransportMessage, } from '../message'; -import { CommonSession, CommonSessionProps, SessionState } from './common'; +import { + CommonSession, + CommonSessionProps, + sendMessage, + SessionState, +} from './common'; +import { SendResult } from '../results'; export interface SessionWaitingForHandshakeListeners { onConnectionErrored: (err: unknown) => void; @@ -62,10 +68,10 @@ export class SessionWaitingForHandshake< } onHandshakeData = (msg: Uint8Array) => { - const parsedMsg = this.parseMsg(msg); - if (parsedMsg === null) { + const parsedMsgRes = this.codec.fromBuffer(msg); + if (!parsedMsgRes.ok) { this.listeners.onInvalidHandshake( - 'could not parse message', + `could not parse handshake message: ${parsedMsgRes.reason}`, 'MALFORMED_HANDSHAKE', ); @@ -74,11 +80,11 @@ export class SessionWaitingForHandshake< // after this fires, the listener is responsible for transitioning the session // and thus removing the handshake timeout - this.listeners.onHandshake(parsedMsg); + this.listeners.onHandshake(parsedMsgRes.value); }; - sendHandshake(msg: TransportMessage): boolean { - return this.conn.send(this.options.codec.toBuffer(msg)); + sendHandshake(msg: TransportMessage): SendResult { + return sendMessage(this.conn, this.codec, msg); } _handleStateExit(): void { diff --git a/transport/sessionStateMachine/common.ts b/transport/sessionStateMachine/common.ts index 8a53d981..2db1499b 100644 --- a/transport/sessionStateMachine/common.ts +++ b/transport/sessionStateMachine/common.ts @@ -2,16 +2,16 @@ import { Logger, MessageMetadata } from '../../logging'; import { TelemetryInfo } from '../../tracing'; import { OpaqueTransportMessage, - OpaqueTransportMessageSchema, PartialTransportMessage, ProtocolVersion, TransportClientId, TransportMessage, } from '../message'; -import { Value } from '@sinclair/typebox/value'; -import { Codec } from '../../codec'; +import { Codec, CodecMessageAdapter } from '../../codec'; import { generateId } from '../id'; import { Tracer } from '@opentelemetry/api'; +import { SendResult } from '../results'; +import { Connection } from '../connection'; export const enum SessionState { NoConnection = 'NoConnection', @@ -148,6 +148,7 @@ export interface SessionOptions { export interface CommonSessionProps { from: TransportClientId; options: SessionOptions; + codec: CodecMessageAdapter; tracer: Tracer; log: Logger | undefined; } @@ -156,42 +157,18 @@ export abstract class CommonSession extends StateMachineState { readonly from: TransportClientId; readonly options: SessionOptions; + readonly codec: CodecMessageAdapter; tracer: Tracer; log?: Logger; abstract get loggingMetadata(): MessageMetadata; - constructor({ from, options, log, tracer }: CommonSessionProps) { + constructor({ from, options, log, tracer, codec }: CommonSessionProps) { super(); this.from = from; this.options = options; this.log = log; this.tracer = tracer; - } - - parseMsg(msg: Uint8Array): OpaqueTransportMessage | null { - const parsedMsg = this.options.codec.fromBuffer(msg); - - if (parsedMsg === null) { - this.log?.error( - `received malformed msg: ${Buffer.from(msg).toString('base64')}`, - this.loggingMetadata, - ); - - return null; - } - - if (!Value.Check(OpaqueTransportMessageSchema, parsedMsg)) { - this.log?.error(`received invalid msg: ${JSON.stringify(parsedMsg)}`, { - ...this.loggingMetadata, - validationErrors: [ - ...Value.Errors(OpaqueTransportMessageSchema, parsedMsg), - ], - }); - - return null; - } - - return parsedMsg; + this.codec = codec; } } @@ -299,11 +276,14 @@ export abstract class IdentifiedSession extends CommonSession { return this.sendBuffer.length > 0 ? this.sendBuffer[0].seq : this.seq; } - send(msg: PartialTransportMessage): string { + send(msg: PartialTransportMessage): SendResult { const constructedMsg = this.constructMsg(msg); this.sendBuffer.push(constructedMsg); - return constructedMsg.id; + return { + ok: true, + value: constructedMsg.id, + }; } _handleStateExit(): void { @@ -356,3 +336,27 @@ export abstract class IdentifiedSessionWithGracePeriod extends IdentifiedSession super._handleClose(); } } + +export function sendMessage( + conn: Connection, + codec: CodecMessageAdapter, + msg: TransportMessage, +): SendResult { + const buff = codec.toBuffer(msg); + if (!buff.ok) { + return buff; + } + + const sent = conn.send(buff.value); + if (!sent) { + return { + ok: false, + reason: 'failed to send message', + }; + } + + return { + ok: true, + value: msg.id, + }; +} diff --git a/transport/sessionStateMachine/stateMachine.test.ts b/transport/sessionStateMachine/stateMachine.test.ts index 894f41ab..cc94c831 100644 --- a/transport/sessionStateMachine/stateMachine.test.ts +++ b/transport/sessionStateMachine/stateMachine.test.ts @@ -51,7 +51,7 @@ function persistedSessionState(session: IdentifiedSession) { class MockConnection extends Connection { status: 'open' | 'closed' = 'open'; - send = vi.fn(); + send = vi.fn(() => true); close(): void { this.status = 'closed'; @@ -135,6 +135,7 @@ function createSessionConnectedListeners(): SessionConnectedListeners { onConnectionClosed: vi.fn(), onConnectionErrored: vi.fn(), onInvalidMessage: vi.fn(), + onMessageSendFailure: vi.fn(), }; } @@ -483,6 +484,9 @@ describe('session state machine', () => { currentProtocolVersion, ); + const res = session.sendBufferedMessages(); + expect(res.ok).toBe(true); + expect(session.sendBuffer.length).toBe(2); session.send(payloadToTransportMessage('foo')); expect(session.sendBuffer.length).toBe(3); expect(session.seq).toBe(3); diff --git a/transport/sessionStateMachine/transitions.ts b/transport/sessionStateMachine/transitions.ts index bf46b459..b14a9c34 100644 --- a/transport/sessionStateMachine/transitions.ts +++ b/transport/sessionStateMachine/transitions.ts @@ -40,6 +40,7 @@ import { } from './SessionBackingOff'; import { ProtocolVersion } from '../message'; import { Tracer } from '@opentelemetry/api'; +import { CodecMessageAdapter } from '../../codec'; function inheritSharedSession( session: IdentifiedSession, @@ -57,6 +58,7 @@ function inheritSharedSession( log: session.log, tracer: session.tracer, protocolVersion: session.protocolVersion, + codec: session.codec, }; } @@ -99,6 +101,7 @@ export const SessionStateGraph = { protocolVersion, tracer, log, + codec: new CodecMessageAdapter(options.codec), }); session.log?.info(`session ${session.id} created in NoConnection state`, { @@ -123,6 +126,7 @@ export const SessionStateGraph = { options, tracer, log, + codec: new CodecMessageAdapter(options.codec), }); session.log?.info(`session created in WaitingForHandshake state`, { @@ -228,6 +232,8 @@ export const SessionStateGraph = { ...carriedState, }); + session.startMissingHeartbeatTimeout(); + session.log?.info( `session ${session.id} transition from Handshaking to Connected`, { @@ -272,6 +278,7 @@ export const SessionStateGraph = { tracer: pendingSession.tracer, log: pendingSession.log, protocolVersion, + codec: new CodecMessageAdapter(options.codec), } satisfies IdentifiedSessionProps); pendingSession._handleStateExit(); @@ -283,6 +290,8 @@ export const SessionStateGraph = { ...carriedState, }); + session.startMissingHeartbeatTimeout(); + conn.telemetry = createConnectionTelemetryInfo( session.tracer, conn, diff --git a/transport/transforms/messageFraming.test.ts b/transport/transforms/messageFraming.test.ts deleted file mode 100644 index dde913ce..00000000 --- a/transport/transforms/messageFraming.test.ts +++ /dev/null @@ -1,124 +0,0 @@ -import { MessageFramer } from './messageFraming'; -import { describe, test, expect, vi } from 'vitest'; - -describe('MessageFramer', () => { - const encodeMessage = (message: string) => { - return MessageFramer.write(Buffer.from(message)); - }; - - test('basic transform', () => { - const spy = vi.fn(); - const parser = MessageFramer.createFramedStream(); - - parser.on('data', spy); - parser.write(encodeMessage('content 1')); - parser.write(encodeMessage('content 2')); - parser.write(encodeMessage('content 3')); - parser.write(encodeMessage('content 4')); - parser.end(); - - expect(spy).toHaveBeenNthCalledWith(1, Buffer.from('content 1')); - expect(spy).toHaveBeenNthCalledWith(2, Buffer.from('content 2')); - expect(spy).toHaveBeenNthCalledWith(3, Buffer.from('content 3')); - expect(spy).toHaveBeenNthCalledWith(4, Buffer.from('content 4')); - expect(spy).toHaveBeenCalledTimes(4); - }); - - test('handles partial messages across chunks', () => { - const spy = vi.fn(); - const parser = MessageFramer.createFramedStream(); - - const msg = encodeMessage('content 1'); - const part1 = msg.subarray(0, 5); // Split the encoded message - const part2 = msg.subarray(5); - - parser.on('data', spy); - parser.write(part1); - parser.write(part2); // Complete the first message - parser.write(encodeMessage('content 2')); // Second message - parser.end(); - - expect(spy).toHaveBeenNthCalledWith(1, Buffer.from('content 1')); - expect(spy).toHaveBeenNthCalledWith(2, Buffer.from('content 2')); - expect(spy).toHaveBeenCalledTimes(2); - }); - - test('multiple messages in a single chunk', () => { - const spy = vi.fn(); - const parser = MessageFramer.createFramedStream(); - - const message1 = encodeMessage('first message'); - const message2 = encodeMessage('second message'); - const combinedMessages = Buffer.concat([message1, message2]); - - parser.on('data', spy); - parser.write(combinedMessages); // Writing both messages in a single write operation - parser.end(); - - expect(spy).toHaveBeenNthCalledWith(1, Buffer.from('first message')); - expect(spy).toHaveBeenNthCalledWith(2, Buffer.from('second message')); - expect(spy).toHaveBeenCalledTimes(2); - }); - - test('max buffer size exceeded', () => { - const parser = MessageFramer.createFramedStream({ - maxBufferSizeBytes: 8, // Set a small max buffer size - }); - - const spy = vi.fn(); - const err = vi.fn(); - parser.on('data', spy); - parser.on('error', err); - - const msg = encodeMessage('long content'); - expect(msg.byteLength > 10); - parser.write(msg); - expect(spy).toHaveBeenCalledTimes(0); - expect(err).toHaveBeenCalledTimes(1); - parser.end(); - }); - - test('incomplete message at stream end', () => { - const spy = vi.fn(); - const err = vi.fn(); - const parser = MessageFramer.createFramedStream(); - - parser.on('data', spy); - parser.on('error', err); - - // say this is a 256B message - const lengthPrefix = Buffer.alloc(4); - lengthPrefix.writeUInt32BE(256, 0); - - // write a message that is clearly not 256B - const incompleteMessage = Buffer.concat([ - lengthPrefix, - Buffer.from('incomplete'), - ]); - parser.write(incompleteMessage); - - expect(spy).toHaveBeenCalledTimes(0); - expect(err).toHaveBeenCalledTimes(0); - - parser.end(); - expect(spy).toHaveBeenCalledTimes(0); - expect(err).toHaveBeenCalledTimes(0); - }); - - test('consistent byte length calculation with emojis and unicode', () => { - const parser = MessageFramer.createFramedStream(); - const spy = vi.fn(); - parser.on('data', spy); - - const emojiMessage = 'πŸ‡§πŸ‡ͺπŸ‡¨πŸ‡¦πŸ‘¨β€πŸ‘©β€πŸ‘§β€πŸ‘¦'; - const unicodeMessage = 'δ½ ε₯½οΌŒδΈ–η•Œ'; // "Hello, World" in Chinese - - parser.write(encodeMessage(emojiMessage)); - parser.write(encodeMessage(unicodeMessage)); - parser.end(); - - expect(spy).toHaveBeenNthCalledWith(1, Buffer.from(emojiMessage)); - expect(spy).toHaveBeenNthCalledWith(2, Buffer.from(unicodeMessage)); - expect(spy).toHaveBeenCalledTimes(2); - }); -}); diff --git a/transport/transforms/messageFraming.ts b/transport/transforms/messageFraming.ts deleted file mode 100644 index 93decf2e..00000000 --- a/transport/transforms/messageFraming.ts +++ /dev/null @@ -1,83 +0,0 @@ -import { Transform, TransformCallback, TransformOptions } from 'node:stream'; - -export interface LengthEncodedOptions extends TransformOptions { - /** Maximum in-memory buffer size before we throw */ - maxBufferSizeBytes: number; -} - -/** - * A transform stream that emits data each time a message with a network/BigEndian uint32 length prefix is received. - * @extends Transform - */ -export class Uint32LengthPrefixFraming extends Transform { - receivedBuffer: Buffer; - maxBufferSizeBytes: number; - - constructor({ maxBufferSizeBytes, ...options }: LengthEncodedOptions) { - super(options); - this.maxBufferSizeBytes = maxBufferSizeBytes; - this.receivedBuffer = Buffer.alloc(0); - } - - _transform(chunk: Buffer, _encoding: BufferEncoding, cb: TransformCallback) { - if ( - this.receivedBuffer.byteLength + chunk.byteLength > - this.maxBufferSizeBytes - ) { - const err = new Error( - `buffer overflow: ${this.receivedBuffer.byteLength}B > ${this.maxBufferSizeBytes}B`, - ); - - this.emit('error', err); - cb(err); - - return; - } - - this.receivedBuffer = Buffer.concat([this.receivedBuffer, chunk]); - - // ensure there's enough for a length prefix - while (this.receivedBuffer.length > 4) { - // read length from buffer (accounting for uint32 prefix) - const claimedMessageLength = this.receivedBuffer.readUInt32BE(0) + 4; - if (this.receivedBuffer.length >= claimedMessageLength) { - // slice the buffer to extract the message - const message = this.receivedBuffer.subarray(4, claimedMessageLength); - this.push(message); - this.receivedBuffer = - this.receivedBuffer.subarray(claimedMessageLength); - } else { - // not enough data for a complete message, wait for more data - break; - } - } - - cb(); - } - - _flush(cb: TransformCallback) { - this.receivedBuffer = Buffer.alloc(0); - cb(); - } - - _destroy(error: Error | null, callback: (error: Error | null) => void): void { - this.receivedBuffer = Buffer.alloc(0); - super._destroy(error, callback); - } -} - -function createLengthEncodedStream(options?: Partial) { - return new Uint32LengthPrefixFraming({ - maxBufferSizeBytes: options?.maxBufferSizeBytes ?? 16 * 1024 * 1024, // 16MB - }); -} - -export const MessageFramer = { - createFramedStream: createLengthEncodedStream, - write: (buf: Uint8Array) => { - const lengthPrefix = Buffer.alloc(4); - lengthPrefix.writeUInt32BE(buf.length, 0); - - return Buffer.concat([lengthPrefix, buf]); - }, -}; diff --git a/transport/transport.ts b/transport/transport.ts index 832a10f3..97ff8b01 100644 --- a/transport/transport.ts +++ b/transport/transport.ts @@ -40,9 +40,7 @@ export interface DeleteSessionOptions { unhealthy: boolean; } -export type SessionBoundSendFn = ( - msg: PartialTransportMessage, -) => string | undefined; +export type SessionBoundSendFn = (msg: PartialTransportMessage) => string; /** * Transports manage the lifecycle (creation/deletion) of sessions @@ -337,13 +335,18 @@ export abstract class Transport { } const sameSession = session.id === sessionId; - if (!sameSession) { + if (!sameSession || session._isConsumed) { throw new Error( `session scope for ${sessionId} has ended (transition), can't send`, ); } - return session.send(msg); + const res = session.send(msg); + if (!res.ok) { + throw new Error(res.reason); + } + + return res.value; }; } }