diff --git a/__tests__/e2e.test.ts b/__tests__/e2e.test.ts index 427f33d7..33e875a6 100644 --- a/__tests__/e2e.test.ts +++ b/__tests__/e2e.test.ts @@ -530,6 +530,43 @@ describe.each(testMatrix())( }); }); + test('upload server cancel', async () => { + // setup + const clientTransport = getClientTransport('client'); + const serverTransport = getServerTransport(); + const services = { + uploadable: UploadableServiceSchema, + }; + const server = createServer(serverTransport, services); + const client = createClient( + clientTransport, + serverTransport.clientId, + ); + + addPostTestCleanup(async () => { + await cleanupTransports([clientTransport, serverTransport]); + }); + + // test + const { reqWritable, finalize } = client.uploadable.cancellableAdd.upload( + {}, + ); + reqWritable.write({ n: 9 }); + reqWritable.write({ n: 1 }); + + const result = await finalize(); + expect(result).toStrictEqual({ + ok: false, + payload: { code: CANCEL_CODE, message: "can't add more than 10" }, + }); + + await testFinishesCleanly({ + clientTransports: [clientTransport], + serverTransport, + server, + }); + }); + test('upload with init message', async () => { // setup const clientTransport = getClientTransport('client'); diff --git a/package-lock.json b/package-lock.json index 8abad19b..55646f08 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "@replit/river", - "version": "0.205.2", + "version": "0.206.0", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "@replit/river", - "version": "0.205.2", + "version": "0.206.0", "license": "MIT", "dependencies": { "@msgpack/msgpack": "^3.0.0-beta2", diff --git a/package.json b/package.json index 6e3d2be2..874ed901 100644 --- a/package.json +++ b/package.json @@ -1,7 +1,7 @@ { "name": "@replit/river", "description": "It's like tRPC but... with JSON Schema Support, duplex streaming and support for service multiplexing. Transport agnostic!", - "version": "0.205.2", + "version": "0.206.0", "type": "module", "exports": { ".": { diff --git a/router/context.ts b/router/context.ts index 6b48e29e..69a8f770 100644 --- a/router/context.ts +++ b/router/context.ts @@ -1,6 +1,9 @@ import { Span } from '@opentelemetry/api'; import { TransportClientId } from '../transport/message'; import { SessionId } from '../transport/sessionStateMachine/common'; +import { ErrResult } from './result'; +import { CancelResultSchema } from './errors'; +import { Static } from '@sinclair/typebox'; /** * ServiceContext exist for the purpose of declaration merging @@ -75,7 +78,7 @@ export type ProcedureHandlerContext = ServiceContext & { * Cancelling is not the same as closing procedure calls gracefully, please refer to * the river documentation to understand the difference between the two concepts. */ - cancel: () => void; + cancel: (message?: string) => ErrResult>; /** * This signal is a standard [AbortSignal](https://developer.mozilla.org/en-US/docs/Web/API/AbortSignal) * triggered when the procedure invocation is done. This signal tracks the invocation/request finishing diff --git a/router/errors.ts b/router/errors.ts index b8dac347..ce6be85a 100644 --- a/router/errors.ts +++ b/router/errors.ts @@ -72,6 +72,14 @@ export function castTypeboxValueErrors( return result; } +/** + * A schema for cancel payloads sent from the client + */ +export const CancelResultSchema = Type.Object({ + code: Type.Literal(CANCEL_CODE), + message: Type.String(), +}); + /** * {@link ReaderErrorSchema} is the schema for all the built-in river errors that * can be emitted to a reader (request reader on the server, and response reader @@ -96,10 +104,7 @@ export const ReaderErrorSchema = Type.Union([ }), ), }), - Type.Object({ - code: Type.Literal(CANCEL_CODE), - message: Type.String(), - }), + CancelResultSchema, ]) satisfies ProcedureErrorSchemaType; /** diff --git a/router/procedures.ts b/router/procedures.ts index 7e49a706..60f25a79 100644 --- a/router/procedures.ts +++ b/router/procedures.ts @@ -3,7 +3,11 @@ import { Static, TNever, TSchema, Type } from '@sinclair/typebox'; import { ProcedureHandlerContext } from './context'; import { Result } from './result'; import { Readable, Writable } from './streams'; -import { ProcedureErrorSchemaType, ReaderErrorSchema } from './errors'; +import { + CancelResultSchema, + ProcedureErrorSchemaType, + ReaderErrorSchema, +} from './errors'; /** * Brands a type to prevent it from being directly constructed. @@ -35,6 +39,8 @@ export type ValidProcType = */ export type PayloadType = TSchema; +export type Cancellable = T | Static; + /** * Procedure for a single message in both directions (1:1). * @@ -57,7 +63,7 @@ export interface RpcProcedure< handler(param: { ctx: ProcedureHandlerContext; reqInit: Static; - }): Promise, Static>>; + }): Promise, Cancellable>>>; } /** @@ -90,7 +96,7 @@ export interface UploadProcedure< Static, Static >; - }): Promise, Static>>; + }): Promise, Cancellable>>>; } /** @@ -115,7 +121,9 @@ export interface SubscriptionProcedure< handler(param: { ctx: ProcedureHandlerContext; reqInit: Static; - resWritable: Writable, Static>>; + resWritable: Writable< + Result, Cancellable>> + >; }): Promise; } @@ -149,7 +157,9 @@ export interface StreamProcedure< Static, Static >; - resWritable: Writable, Static>>; + resWritable: Writable< + Result, Cancellable>> + >; }): Promise; } diff --git a/router/server.ts b/router/server.ts index 37efa7a1..90e4e098 100644 --- a/router/server.ts +++ b/router/server.ts @@ -1,4 +1,4 @@ -import { Static, Type } from '@sinclair/typebox'; +import { Static } from '@sinclair/typebox'; import { PayloadType, AnyProcedure } from './procedures'; import { ReaderErrorSchema, @@ -7,9 +7,9 @@ import { CANCEL_CODE, INVALID_REQUEST_CODE, BaseErrorSchemaType, - ErrResultSchema, ValidationErrors, castTypeboxValueErrors, + CancelResultSchema, } from './errors'; import { AnyService, @@ -53,16 +53,6 @@ import { SessionBoundSendFn } from '../transport/transport'; type StreamId = string; -/** - * A schema for cancel payloads sent from the client - */ -const CancelResultSchema = ErrResultSchema( - Type.Object({ - code: Type.Literal(CANCEL_CODE), - message: Type.String(), - }), -); - /** * Represents a server with a set of services. Use {@link createServer} to create it. * @template Services - The type of services provided by the server. @@ -293,9 +283,9 @@ class RiverServer } if (isStreamCancelBackwardsCompat(msg.controlFlags, protocolVersion)) { - let cancelResult: Static; + let cancelResult: ErrResult>; if (Value.Check(CancelResultSchema, msg.payload)) { - cancelResult = msg.payload; + cancelResult = Err(msg.payload); } else { // If the payload is unexpected, then we just construct our own cancel result cancelResult = Err({ @@ -576,11 +566,15 @@ class RiverServer sessionId, metadata: sessionMetadata, span, - cancel: () => { - onServerCancel({ + cancel: (message?: string) => { + const errRes = { code: CANCEL_CODE, - message: 'cancelled by server procedure handler', - }); + message: message ?? 'cancelled by server procedure handler', + } as const; + + onServerCancel(errRes); + + return Err(errRes); }, signal: finishedController.signal, }); diff --git a/testUtil/fixtures/services.ts b/testUtil/fixtures/services.ts index 40933a29..971ac964 100644 --- a/testUtil/fixtures/services.ts +++ b/testUtil/fixtures/services.ts @@ -287,6 +287,25 @@ export const UploadableServiceSchema = ServiceSchema.define({ return Ok({ result: `${reqInit.prefix} ${result}` }); }, }), + + cancellableAdd: Procedure.upload({ + requestInit: Type.Object({}), + requestData: Type.Object({ n: Type.Number() }), + responseData: Type.Object({ result: Type.Number() }), + async handler({ ctx, reqReadable }) { + let result = 0; + for await (const req of reqReadable) { + const n = unwrapOrThrow(req).n; + if (result + n >= 10) { + return ctx.cancel("can't add more than 10"); + } + + result += n; + } + + return Ok({ result: result }); + }, + }), }); const RecursivePayload = Type.Recursive((This) =>