Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions __tests__/e2e.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,43 @@ describe.each(testMatrix())(
});
});

test('upload server cancel', async () => {
// setup
const clientTransport = getClientTransport('client');
const serverTransport = getServerTransport();
const services = {
uploadable: UploadableServiceSchema,
};
const server = createServer(serverTransport, services);
const client = createClient<typeof services>(
clientTransport,
serverTransport.clientId,
);

addPostTestCleanup(async () => {
await cleanupTransports([clientTransport, serverTransport]);
});

// test
const { reqWritable, finalize } = client.uploadable.cancellableAdd.upload(
{},
);
reqWritable.write({ n: 9 });
reqWritable.write({ n: 1 });

const result = await finalize();
expect(result).toStrictEqual({
ok: false,
payload: { code: CANCEL_CODE, message: "can't add more than 10" },
});

await testFinishesCleanly({
clientTransports: [clientTransport],
serverTransport,
server,
});
});

test('upload with init message', async () => {
// setup
const clientTransport = getClientTransport('client');
Expand Down
4 changes: 2 additions & 2 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"name": "@replit/river",
"description": "It's like tRPC but... with JSON Schema Support, duplex streaming and support for service multiplexing. Transport agnostic!",
"version": "0.205.2",
"version": "0.206.0",
"type": "module",
"exports": {
".": {
Expand Down
5 changes: 4 additions & 1 deletion router/context.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import { Span } from '@opentelemetry/api';
import { TransportClientId } from '../transport/message';
import { SessionId } from '../transport/sessionStateMachine/common';
import { ErrResult } from './result';
import { CancelResultSchema } from './errors';
import { Static } from '@sinclair/typebox';

/**
* ServiceContext exist for the purpose of declaration merging
Expand Down Expand Up @@ -75,7 +78,7 @@ export type ProcedureHandlerContext<State> = ServiceContext & {
* Cancelling is not the same as closing procedure calls gracefully, please refer to
* the river documentation to understand the difference between the two concepts.
*/
cancel: () => void;
cancel: (message?: string) => ErrResult<Static<typeof CancelResultSchema>>;
/**
* This signal is a standard [AbortSignal](https://developer.mozilla.org/en-US/docs/Web/API/AbortSignal)
* triggered when the procedure invocation is done. This signal tracks the invocation/request finishing
Expand Down
13 changes: 9 additions & 4 deletions router/errors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,14 @@ export function castTypeboxValueErrors(
return result;
}

/**
* A schema for cancel payloads sent from the client
*/
export const CancelResultSchema = Type.Object({
code: Type.Literal(CANCEL_CODE),
message: Type.String(),
});

/**
* {@link ReaderErrorSchema} is the schema for all the built-in river errors that
* can be emitted to a reader (request reader on the server, and response reader
Expand All @@ -96,10 +104,7 @@ export const ReaderErrorSchema = Type.Union([
}),
),
}),
Type.Object({
code: Type.Literal(CANCEL_CODE),
message: Type.String(),
}),
CancelResultSchema,
]) satisfies ProcedureErrorSchemaType;

/**
Expand Down
20 changes: 15 additions & 5 deletions router/procedures.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@ import { Static, TNever, TSchema, Type } from '@sinclair/typebox';
import { ProcedureHandlerContext } from './context';
import { Result } from './result';
import { Readable, Writable } from './streams';
import { ProcedureErrorSchemaType, ReaderErrorSchema } from './errors';
import {
CancelResultSchema,
ProcedureErrorSchemaType,
ReaderErrorSchema,
} from './errors';

/**
* Brands a type to prevent it from being directly constructed.
Expand Down Expand Up @@ -35,6 +39,8 @@ export type ValidProcType =
*/
export type PayloadType = TSchema;

export type Cancellable<T> = T | Static<typeof CancelResultSchema>;

/**
* Procedure for a single message in both directions (1:1).
*
Expand All @@ -57,7 +63,7 @@ export interface RpcProcedure<
handler(param: {
ctx: ProcedureHandlerContext<State>;
reqInit: Static<RequestInit>;
}): Promise<Result<Static<ResponseData>, Static<ResponseErr>>>;
}): Promise<Result<Static<ResponseData>, Cancellable<Static<ResponseErr>>>>;
}

/**
Expand Down Expand Up @@ -90,7 +96,7 @@ export interface UploadProcedure<
Static<RequestData>,
Static<typeof ReaderErrorSchema>
>;
}): Promise<Result<Static<ResponseData>, Static<ResponseErr>>>;
}): Promise<Result<Static<ResponseData>, Cancellable<Static<ResponseErr>>>>;
}

/**
Expand All @@ -115,7 +121,9 @@ export interface SubscriptionProcedure<
handler(param: {
ctx: ProcedureHandlerContext<State>;
reqInit: Static<RequestInit>;
resWritable: Writable<Result<Static<ResponseData>, Static<ResponseErr>>>;
resWritable: Writable<
Result<Static<ResponseData>, Cancellable<Static<ResponseErr>>>
>;
}): Promise<void | undefined>;
}

Expand Down Expand Up @@ -149,7 +157,9 @@ export interface StreamProcedure<
Static<RequestData>,
Static<typeof ReaderErrorSchema>
>;
resWritable: Writable<Result<Static<ResponseData>, Static<ResponseErr>>>;
resWritable: Writable<
Result<Static<ResponseData>, Cancellable<Static<ResponseErr>>>
>;
}): Promise<void | undefined>;
}

Expand Down
30 changes: 12 additions & 18 deletions router/server.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { Static, Type } from '@sinclair/typebox';
import { Static } from '@sinclair/typebox';
import { PayloadType, AnyProcedure } from './procedures';
import {
ReaderErrorSchema,
Expand All @@ -7,9 +7,9 @@ import {
CANCEL_CODE,
INVALID_REQUEST_CODE,
BaseErrorSchemaType,
ErrResultSchema,
ValidationErrors,
castTypeboxValueErrors,
CancelResultSchema,
} from './errors';
import {
AnyService,
Expand Down Expand Up @@ -53,16 +53,6 @@ import { SessionBoundSendFn } from '../transport/transport';

type StreamId = string;

/**
* A schema for cancel payloads sent from the client
*/
const CancelResultSchema = ErrResultSchema(
Type.Object({
code: Type.Literal(CANCEL_CODE),
message: Type.String(),
}),
);

/**
* Represents a server with a set of services. Use {@link createServer} to create it.
* @template Services - The type of services provided by the server.
Expand Down Expand Up @@ -293,9 +283,9 @@ class RiverServer<Services extends AnyServiceSchemaMap>
}

if (isStreamCancelBackwardsCompat(msg.controlFlags, protocolVersion)) {
let cancelResult: Static<typeof CancelResultSchema>;
let cancelResult: ErrResult<Static<typeof CancelResultSchema>>;
if (Value.Check(CancelResultSchema, msg.payload)) {
cancelResult = msg.payload;
cancelResult = Err(msg.payload);
} else {
// If the payload is unexpected, then we just construct our own cancel result
cancelResult = Err({
Expand Down Expand Up @@ -576,11 +566,15 @@ class RiverServer<Services extends AnyServiceSchemaMap>
sessionId,
metadata: sessionMetadata,
span,
cancel: () => {
onServerCancel({
cancel: (message?: string) => {
const errRes = {
code: CANCEL_CODE,
message: 'cancelled by server procedure handler',
});
message: message ?? 'cancelled by server procedure handler',
} as const;

onServerCancel(errRes);

return Err(errRes);
},
signal: finishedController.signal,
});
Expand Down
19 changes: 19 additions & 0 deletions testUtil/fixtures/services.ts
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,25 @@ export const UploadableServiceSchema = ServiceSchema.define({
return Ok({ result: `${reqInit.prefix} ${result}` });
},
}),

cancellableAdd: Procedure.upload({
requestInit: Type.Object({}),
requestData: Type.Object({ n: Type.Number() }),
responseData: Type.Object({ result: Type.Number() }),
async handler({ ctx, reqReadable }) {
let result = 0;
for await (const req of reqReadable) {
const n = unwrapOrThrow(req).n;
if (result + n >= 10) {
return ctx.cancel("can't add more than 10");
}

result += n;
}

return Ok({ result: result });
},
}),
});

const RecursivePayload = Type.Recursive((This) =>
Expand Down