Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 160 additions & 0 deletions __tests__/unserializable.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
import { beforeEach, describe, expect, test } from 'vitest';
import { Type } from '@sinclair/typebox';
import {
Procedure,
createServiceSchema,
Ok,
createClient,
createServer,
UNEXPECTED_DISCONNECT_CODE,
} from '../router';
import { testMatrix } from '../testUtil/fixtures/matrix';
import {
advanceFakeTimersBySessionGrace,
cleanupTransports,
createPostTestCleanups,
} from '../testUtil/fixtures/cleanup';
import { TestSetupHelpers } from '../testUtil/fixtures/transports';
import { readNextResult } from '../testUtil';

const ServiceSchema = createServiceSchema();

const UnserializableServiceSchema = ServiceSchema.define({
returnSymbol: Procedure.rpc({
requestInit: Type.Object({}),
responseData: Type.Object({ id: Type.String() }),
async handler() {
return Ok({ id: 'test', extra: Symbol('unserializable') });
},
}),
streamSymbol: Procedure.subscription({
requestInit: Type.Object({}),
responseData: Type.Object({ id: Type.String() }),
async handler({ resWritable }) {
resWritable.write(Ok({ id: 'test', extra: Symbol('unserializable') }));
resWritable.close();
},
}),
});

describe('unserializable values in procedure handlers', () => {
// binary codec (msgpack) throws on Symbol, causing encode failure
// which kills the session -- only test with ws transport since mock
// transport's setImmediate chains conflict with fake timer flushing
describe.each(testMatrix(['ws', 'binary']))(
'binary codec ($transport.name transport)',
({ transport, codec }) => {
const opts = { codec: codec.codec };
const { addPostTestCleanup, postTestCleanup } = createPostTestCleanups();
let getClientTransport: TestSetupHelpers['getClientTransport'];
let getServerTransport: TestSetupHelpers['getServerTransport'];

beforeEach(async () => {
const setup = await transport.setup({ client: opts, server: opts });
getClientTransport = setup.getClientTransport;
getServerTransport = setup.getServerTransport;

return async () => {
await postTestCleanup();
await setup.cleanup();
};
});

test('rpc handler returning symbol causes client disconnect', async () => {
const clientTransport = getClientTransport('client');
const serverTransport = getServerTransport();
const services = { svc: UnserializableServiceSchema };
createServer(serverTransport, services);
const client = createClient<typeof services>(
clientTransport,
serverTransport.clientId,
);
addPostTestCleanup(() =>
cleanupTransports([clientTransport, serverTransport]),
);

const resultPromise = client.svc.returnSymbol.rpc({});
await advanceFakeTimersBySessionGrace();

const result = await resultPromise;
expect(result).toMatchObject({
ok: false,
payload: {
code: UNEXPECTED_DISCONNECT_CODE,
},
});
});

test('subscription handler writing symbol causes client disconnect', async () => {
const clientTransport = getClientTransport('client');
const serverTransport = getServerTransport();
const services = { svc: UnserializableServiceSchema };
createServer(serverTransport, services);
const client = createClient<typeof services>(
clientTransport,
serverTransport.clientId,
);
addPostTestCleanup(() =>
cleanupTransports([clientTransport, serverTransport]),
);

const { resReadable } = client.svc.streamSymbol.subscribe({});
await advanceFakeTimersBySessionGrace();

const result = await readNextResult(resReadable);
expect(result).toMatchObject({
ok: false,
payload: {
code: UNEXPECTED_DISCONNECT_CODE,
},
});
});
},
);

// json codec silently drops Symbol values via JSON.stringify
describe.each(testMatrix(['all', 'naive']))(
'json codec ($transport.name transport)',
({ transport, codec }) => {
const opts = { codec: codec.codec };
const { addPostTestCleanup, postTestCleanup } = createPostTestCleanups();
let getClientTransport: TestSetupHelpers['getClientTransport'];
let getServerTransport: TestSetupHelpers['getServerTransport'];

beforeEach(async () => {
const setup = await transport.setup({ client: opts, server: opts });
getClientTransport = setup.getClientTransport;
getServerTransport = setup.getServerTransport;

return async () => {
await postTestCleanup();
await setup.cleanup();
};
});

test('rpc handler returning symbol silently drops the value', async () => {
const clientTransport = getClientTransport('client');
const serverTransport = getServerTransport();
const services = { svc: UnserializableServiceSchema };
const server = createServer(serverTransport, services);
const client = createClient<typeof services>(
clientTransport,
serverTransport.clientId,
);
addPostTestCleanup(() =>
cleanupTransports([clientTransport, serverTransport]),
);

const result = await client.svc.returnSymbol.rpc({});
// JSON.stringify silently drops Symbol values, so the
// response arrives with the extra symbol field missing
expect(result).toStrictEqual({
ok: true,
payload: { id: 'test' },
});

await server.close();
});
},
);
});
4 changes: 2 additions & 2 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"name": "@replit/river",
"description": "It's like tRPC but... with JSON Schema Support, duplex streaming and support for service multiplexing. Transport agnostic!",
"version": "0.214.0",
"version": "0.214.1",
"type": "module",
"exports": {
".": {
Expand Down
18 changes: 10 additions & 8 deletions testUtil/fixtures/cleanup.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import { expect, vi } from 'vitest';
import { assert, expect, vi } from 'vitest';
import {
ClientTransport,
Connection,
OpaqueTransportMessage,
ServerTransport,
Transport,
} from '../../transport';
Expand Down Expand Up @@ -68,14 +67,17 @@ export async function ensureTransportBuffersAreEventuallyEmpty(
[...t.sessions]
.map(([client, sess]) => {
// get all messages that are not heartbeats
const buff = sess.sendBuffer.filter((msg) => {
return !Value.Check(ControlMessageAckSchema, msg.payload);
const buff = sess.sendBuffer.filter((encodedMsg) => {
const decoded = sess.codec.fromBuffer(encodedMsg.data);
assert(decoded.ok);

return !Value.Check(
ControlMessageAckSchema,
decoded.value.payload,
);
});

return [client, buff] as [
string,
ReadonlyArray<OpaqueTransportMessage>,
];
return [client, buff] as const;
})
.filter((entry) => entry[1].length > 0),
),
Expand Down
11 changes: 11 additions & 0 deletions transport/message.ts
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,17 @@ export function cancelMessage(
export type OpaqueTransportMessage = TransportMessage;
export type TransportClientId = string;

/**
* An encoded message that is ready to be sent over the transport.
* The seq number is kept to track which messages have been
* acked by the peer and can be dropped from the send buffer.
*/
export interface EncodedTransportMessage {
id: string;
seq: number;
data: Uint8Array;
}

/**
* Checks if the given control flag (usually found in msg.controlFlag) is an ack message.
* @param controlFlag - The control flag to check.
Expand Down
3 changes: 2 additions & 1 deletion transport/results.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { OpaqueTransportMessage } from './message';
import { EncodedTransportMessage, OpaqueTransportMessage } from './message';

// internal use only, not to be used in public API
type SessionApiResult<T> =
Expand All @@ -13,5 +13,6 @@ type SessionApiResult<T> =

export type SendResult = SessionApiResult<string>;
export type SendBufferResult = SessionApiResult<undefined>;
export type EncodeResult = SessionApiResult<EncodedTransportMessage>;
export type SerializeResult = SessionApiResult<Uint8Array>;
export type DeserializeResult = SessionApiResult<OpaqueTransportMessage>;
63 changes: 42 additions & 21 deletions transport/sessionStateMachine/SessionConnected.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@ import { Static } from '@sinclair/typebox';
import {
ControlFlags,
ControlMessageAckSchema,
EncodedTransportMessage,
OpaqueTransportMessage,
PartialTransportMessage,
TransportMessage,
isAck,
} from '../message';
import {
IdentifiedSession,
IdentifiedSessionProps,
sendMessage,
SessionState,
} from './common';
import { Connection } from '../connection';
Expand All @@ -21,7 +20,10 @@ export interface SessionConnectedListeners {
onConnectionErrored: (err: unknown) => void;
onConnectionClosed: () => void;
onMessage: (msg: OpaqueTransportMessage) => void;
onMessageSendFailure: (msg: PartialTransportMessage, reason: string) => void;
onMessageSendFailure: (
msg: Omit<EncodedTransportMessage, 'data'>,
reason: string,
) => void;
onInvalidMessage: (reason: string) => void;
}

Expand Down Expand Up @@ -57,12 +59,11 @@ export class SessionConnected<
this.startMissingHeartbeatTimeout();
}

private assertSendOrdering(constructedMsg: TransportMessage) {
if (constructedMsg.seq > this.seqSent + 1) {
const msg = `invariant violation: would have sent out of order msg (seq: ${constructedMsg.seq}, expected: ${this.seqSent} + 1)`;
private assertSendOrdering(encodedMsg: EncodedTransportMessage) {
if (encodedMsg.seq > this.seqSent + 1) {
const msg = `invariant violation: would have sent out of order msg (seq: ${encodedMsg.seq}, expected: ${this.seqSent} + 1)`;
this.log?.error(msg, {
...this.loggingMetadata,
transportMessage: constructedMsg,
tags: ['invariant-violation'],
});

Expand All @@ -71,19 +72,34 @@ export class SessionConnected<
}

send(msg: PartialTransportMessage): SendResult {
const constructedMsg = this.constructMsg(msg);
this.assertSendOrdering(constructedMsg);
this.sendBuffer.push(constructedMsg);
const res = sendMessage(this.conn, this.codec, constructedMsg);
if (!res.ok) {
this.listeners.onMessageSendFailure(constructedMsg, res.reason);

return res;
const encodeResult = this.encodeMsg(msg);
if (!encodeResult.ok) {
this.listeners.onMessageSendFailure(
{ id: 'unknown', seq: this.seq },
encodeResult.reason,
);

return encodeResult;
}

const encodedMsg = encodeResult.value;
this.assertSendOrdering(encodedMsg);
this.sendBuffer.push(encodedMsg);

const sent = this.conn.send(encodedMsg.data);
if (!sent) {
const reason = 'failed to send message';
this.listeners.onMessageSendFailure(
{ id: encodedMsg.id, seq: encodedMsg.seq },
reason,
);

return { ok: false, reason };
}

this.seqSent = constructedMsg.seq;
this.seqSent = encodedMsg.seq;

return res;
return { ok: true, value: encodedMsg.id };
}

constructor(props: SessionConnectedProps<ConnType>) {
Expand All @@ -110,11 +126,16 @@ export class SessionConnected<

for (const msg of this.sendBuffer) {
this.assertSendOrdering(msg);
const res = sendMessage(this.conn, this.codec, msg);
if (!res.ok) {
this.listeners.onMessageSendFailure(msg, res.reason);

return res;
const sent = this.conn.send(msg.data);
if (!sent) {
const reason = 'failed to send buffered message';
this.listeners.onMessageSendFailure(
{ id: msg.id, seq: msg.seq },
reason,
);

return { ok: false, reason };
}

this.seqSent = msg.seq;
Expand Down
Loading
Loading