Skip to content

Commit 230de7d

Browse files
findleyryasomaru
authored andcommitted
mcp: improvements for 'stateless' streamable servers; 'distributed' mode
Several improvements for the stateless streamable mode, plus support for a 'distributed' (or rather, distributable) version of the stateless server. - Add a 'Stateless' option to StreamableHTTPOptions and StreamableServerTransport, which controls stateless behavior. GetSessionID may still return a non-empty session ID. - Audit validation of stateless mode to allow requests with a session id. Propagate this session ID to the temporary connection. - Peek at requests to allow 'initialize' requests to go through to the session, so that version negotiation can occur (FIXME: add tests). Fixes modelcontextprotocol#284 For modelcontextprotocol#148
1 parent 755f644 commit 230de7d

File tree

5 files changed

+203
-37
lines changed

5 files changed

+203
-37
lines changed

internal/jsonrpc2/conn.go

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -739,7 +739,23 @@ func (c *Connection) processResult(from any, req *incomingRequest, result any, e
739739
// write is used by all things that write outgoing messages, including replies.
740740
// it makes sure that writes are atomic
741741
func (c *Connection) write(ctx context.Context, msg Message) error {
742-
err := c.writer.Write(ctx, msg)
742+
var err error
743+
// Fail writes immediately if the connection is shutting down.
744+
//
745+
// TODO(rfindley): should we allow cancellation notifations through? It could
746+
// be the case that writes can still succeed.
747+
c.updateInFlight(func(s *inFlightState) {
748+
err = s.shuttingDown(ErrServerClosing)
749+
})
750+
if err == nil {
751+
err = c.writer.Write(ctx, msg)
752+
}
753+
754+
// For rejected requests, we don't set the writeErr (which would break the
755+
// connection). They can just be returned to the caller.
756+
if errors.Is(err, ErrRejected) {
757+
return err
758+
}
743759

744760
if err != nil && ctx.Err() == nil {
745761
// The call to Write failed, and since ctx.Err() is nil we can't attribute

internal/jsonrpc2/wire.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,17 @@ var (
3737
ErrServerClosing = NewError(-32004, "server is closing")
3838
// ErrClientClosing is a dummy error returned for calls initiated while the client is closing.
3939
ErrClientClosing = NewError(-32003, "client is closing")
40+
41+
// The following errors have special semantics for MCP transports
42+
43+
// ErrRejected may be wrapped to return errors from calls to Writer.Write
44+
// that signal that the request was rejected by the transport layer as
45+
// invalid.
46+
//
47+
// Such failures do not indicate that the connection is broken, but rather
48+
// should be returned to the caller to indicate that the specific request is
49+
// invalid in the current context.
50+
ErrRejected = NewError(-32004, "rejected by transport")
4051
)
4152

4253
const wireVersion = "2.0"

mcp/streamable.go

Lines changed: 99 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -38,18 +38,29 @@ type StreamableHTTPHandler struct {
3838
getServer func(*http.Request) *Server
3939
opts StreamableHTTPOptions
4040

41-
mu sync.Mutex
41+
mu sync.Mutex
42+
// TODO: we should store the ServerSession along with the transport, because
43+
// we need to cancel keepalive requests when closing the transport.
4244
transports map[string]*StreamableServerTransport // keyed by IDs (from Mcp-Session-Id header)
4345
}
4446

4547
// StreamableHTTPOptions configures the StreamableHTTPHandler.
4648
type StreamableHTTPOptions struct {
4749
// GetSessionID provides the next session ID to use for an incoming request.
50+
// If nil, a default randomly generated ID will be used.
4851
//
49-
// If GetSessionID returns an empty string, the session is 'stateless',
50-
// meaning it is not persisted and no session validation is performed.
52+
// As a special case, if GetSessionID returns the empty string, the
53+
// Mcp-Session-Id header will not be set.
5154
GetSessionID func() string
5255

56+
// Stateless controls whether the session is 'stateless'.
57+
//
58+
// A stateless server does not validate the Mcp-Session-Id header, and uses a
59+
// temporary session with default initialization parameters. Any
60+
// server->client request is rejected immediately as there's no way for the
61+
// client to respond.
62+
Stateless bool
63+
5364
// TODO: support session retention (?)
5465

5566
// jsonResponse is forwarded to StreamableServerTransport.jsonResponse.
@@ -118,36 +129,39 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
118129
return
119130
}
120131

132+
sessionID := req.Header.Get(sessionIDHeader)
121133
var transport *StreamableServerTransport
122-
if id := req.Header.Get(sessionIDHeader); id != "" {
134+
if sessionID != "" {
123135
h.mu.Lock()
124-
transport = h.transports[id]
136+
transport, _ = h.transports[sessionID]
125137
h.mu.Unlock()
126-
if transport == nil {
138+
if transport == nil && !h.opts.Stateless {
139+
// In stateless mode we allow a missing transport.
140+
//
141+
// A synthetic transport will be created below for the transient session.
127142
http.Error(w, "session not found", http.StatusNotFound)
128143
return
129144
}
130145
}
131146

132-
// TODO(rfindley): simplify the locking so that each request has only one
133-
// critical section.
134147
if req.Method == http.MethodDelete {
135-
if transport == nil {
136-
// => Mcp-Session-Id was not set; else we'd have returned NotFound above.
148+
if sessionID == "" {
137149
http.Error(w, "DELETE requires an Mcp-Session-Id header", http.StatusBadRequest)
138150
return
139151
}
140-
h.mu.Lock()
141-
delete(h.transports, transport.SessionID)
142-
h.mu.Unlock()
143-
transport.connection.Close()
152+
if transport != nil { // transport may be nil in stateless mode
153+
h.mu.Lock()
154+
delete(h.transports, transport.SessionID)
155+
h.mu.Unlock()
156+
transport.connection.Close()
157+
}
144158
w.WriteHeader(http.StatusNoContent)
145159
return
146160
}
147161

148162
switch req.Method {
149163
case http.MethodPost, http.MethodGet:
150-
if req.Method == http.MethodGet && transport == nil {
164+
if req.Method == http.MethodGet && sessionID == "" {
151165
http.Error(w, "GET requires an active session", http.StatusMethodNotAllowed)
152166
return
153167
}
@@ -164,37 +178,83 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
164178
http.Error(w, "no server available", http.StatusBadRequest)
165179
return
166180
}
167-
sessionID := h.opts.GetSessionID()
168-
s := &StreamableServerTransport{SessionID: sessionID, jsonResponse: h.opts.jsonResponse}
181+
if sessionID == "" {
182+
// In stateless mode, sessionID may be nonempty even if there's no
183+
// existing transport.
184+
sessionID = h.opts.GetSessionID()
185+
}
186+
transport = &StreamableServerTransport{
187+
SessionID: sessionID,
188+
Stateless: h.opts.Stateless,
189+
jsonResponse: h.opts.jsonResponse,
190+
}
169191

170192
// To support stateless mode, we initialize the session with a default
171193
// state, so that it doesn't reject subsequent requests.
172194
var connectOpts *ServerSessionOptions
173-
if sessionID == "" {
195+
if h.opts.Stateless {
196+
// Peek at the body to see if it is initialize or initialized.
197+
// We want those to be handled as usual.
198+
var hasInitialize, hasInitialized bool
199+
{
200+
// TODO: verify that this allows protocol version negotiation for
201+
// stateless servers.
202+
body, err := io.ReadAll(req.Body)
203+
if err != nil {
204+
http.Error(w, "failed to read body", http.StatusBadRequest)
205+
return
206+
}
207+
req.Body.Close()
208+
209+
// Reset the body so that it can be read later.
210+
req.Body = io.NopCloser(bytes.NewBuffer(body))
211+
212+
msgs, _, err := readBatch(body)
213+
if err == nil {
214+
for _, msg := range msgs {
215+
if req, ok := msg.(*jsonrpc.Request); ok {
216+
switch req.Method {
217+
case methodInitialize:
218+
hasInitialize = true
219+
case notificationInitialized:
220+
hasInitialized = true
221+
}
222+
}
223+
}
224+
}
225+
}
226+
227+
// If we don't have InitializeParams or InitializedParams in the request,
228+
// set the initial state to a default value.
229+
state := new(ServerSessionState)
230+
if !hasInitialize {
231+
state.InitializeParams = new(InitializeParams)
232+
}
233+
if !hasInitialized {
234+
state.InitializedParams = new(InitializedParams)
235+
}
174236
connectOpts = &ServerSessionOptions{
175-
State: &ServerSessionState{
176-
InitializeParams: new(InitializeParams),
177-
InitializedParams: new(InitializedParams),
178-
},
237+
State: state,
179238
}
180239
}
240+
181241
// Pass req.Context() here, to allow middleware to add context values.
182242
// The context is detached in the jsonrpc2 library when handling the
183243
// long-running stream.
184-
ss, err := server.Connect(req.Context(), s, connectOpts)
244+
ss, err := server.Connect(req.Context(), transport, connectOpts)
185245
if err != nil {
186246
http.Error(w, "failed connection", http.StatusInternalServerError)
187247
return
188248
}
189-
if sessionID == "" {
249+
if h.opts.Stateless {
190250
// Stateless mode: close the session when the request exits.
191251
defer ss.Close() // close the fake session after handling the request
192252
} else {
253+
// Otherwise, save the transport so that it can be reused
193254
h.mu.Lock()
194-
h.transports[s.SessionID] = s
255+
h.transports[transport.SessionID] = transport
195256
h.mu.Unlock()
196257
}
197-
transport = s
198258
}
199259

200260
transport.ServeHTTP(w, req)
@@ -225,6 +285,13 @@ type StreamableServerTransport struct {
225285
// generator to produce one, as with [crypto/rand.Text].)
226286
SessionID string
227287

288+
// Stateless controls whether the eventstore is 'Stateless'. Servers sessions
289+
// connected to a stateless transport are disallowed from making outgoing
290+
// requests.
291+
//
292+
// See also [StreamableHTTPOptions.Stateless].
293+
Stateless bool
294+
228295
// Storage for events, to enable stream resumption.
229296
// If nil, a [MemoryEventStore] with the default maximum size will be used.
230297
EventStore EventStore
@@ -265,6 +332,7 @@ func (t *StreamableServerTransport) Connect(context.Context) (Connection, error)
265332
}
266333
t.connection = &streamableServerConn{
267334
sessionID: t.SessionID,
335+
stateless: t.Stateless,
268336
eventStore: t.EventStore,
269337
jsonResponse: t.jsonResponse,
270338
incoming: make(chan jsonrpc.Message, 10),
@@ -285,6 +353,7 @@ func (t *StreamableServerTransport) Connect(context.Context) (Connection, error)
285353

286354
type streamableServerConn struct {
287355
sessionID string
356+
stateless bool
288357
jsonResponse bool
289358
eventStore EventStore
290359

@@ -755,6 +824,10 @@ func (c *streamableServerConn) Read(ctx context.Context) (jsonrpc.Message, error
755824

756825
// Write implements the [Connection] interface.
757826
func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) error {
827+
if req, ok := msg.(*jsonrpc.Request); ok && req.ID.IsValid() && (c.stateless || c.sessionID == "") {
828+
// Requests aren't possible with stateless servers, or when there's no session ID.
829+
return fmt.Errorf("%w: stateless servers cannot make requests", jsonrpc2.ErrRejected)
830+
}
758831
// Find the incoming request that this write relates to, if any.
759832
var forRequest jsonrpc.ID
760833
isResponse := false

mcp/streamable_test.go

Lines changed: 74 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -748,7 +748,7 @@ func testStreamableHandler(t *testing.T, handler http.Handler, requests []stream
748748
if !request.ignoreResponse {
749749
transform := cmpopts.AcyclicTransformer("jsonrpcid", func(id jsonrpc.ID) any { return id.Raw() })
750750
if diff := cmp.Diff(request.wantMessages, got, transform); diff != "" {
751-
t.Errorf("received unexpected messages (-want +got):\n%s", diff)
751+
t.Errorf("request #%d: received unexpected messages (-want +got):\n%s", i, diff)
752752
}
753753
}
754754
sessionID.CompareAndSwap("", gotSessionID)
@@ -996,19 +996,18 @@ func TestEventID(t *testing.T) {
996996
}
997997

998998
func TestStreamableStateless(t *testing.T) {
999-
// This version of sayHi doesn't make a ping request (we can't respond to
999+
// This version of sayHi expects
10001000
// that request from our client).
10011001
sayHi := func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[hiParams]]) (*CallToolResult, error) {
1002+
if err := req.Session.Ping(ctx, nil); err == nil {
1003+
// ping should fail, but not break the connection
1004+
t.Errorf("ping succeeded unexpectedly")
1005+
}
10021006
return &CallToolResult{Content: []Content{&TextContent{Text: "hi " + req.Params.Arguments.Name}}}, nil
10031007
}
10041008
server := NewServer(testImpl, nil)
10051009
AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi)
10061010

1007-
// Test stateless mode.
1008-
handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{
1009-
GetSessionID: func() string { return "" },
1010-
})
1011-
10121011
requests := []streamableRequest{
10131012
{
10141013
method: "POST",
@@ -1028,7 +1027,74 @@ func TestStreamableStateless(t *testing.T) {
10281027
},
10291028
wantSessionID: false,
10301029
},
1030+
{
1031+
method: "POST",
1032+
wantStatusCode: http.StatusOK,
1033+
messages: []jsonrpc.Message{
1034+
req(2, "tools/call", &CallToolParams{Name: "greet", Arguments: hiParams{Name: "foo"}}),
1035+
},
1036+
wantMessages: []jsonrpc.Message{
1037+
resp(2, &CallToolResult{Content: []Content{&TextContent{Text: "hi foo"}}}, nil),
1038+
},
1039+
wantSessionID: false,
1040+
},
1041+
}
1042+
1043+
testClientCompatibility := func(t *testing.T, handler http.Handler) {
1044+
ctx := context.Background()
1045+
httpServer := httptest.NewServer(handler)
1046+
defer httpServer.Close()
1047+
cs, err := NewClient(testImpl, nil).Connect(ctx, &StreamableClientTransport{Endpoint: httpServer.URL}, nil)
1048+
if err != nil {
1049+
t.Fatal(err)
1050+
}
1051+
res, err := cs.CallTool(ctx, &CallToolParams{Name: "greet", Arguments: hiParams{Name: "bar"}})
1052+
if err != nil {
1053+
t.Fatal(err)
1054+
}
1055+
if got, want := textContent(t, res), "hi bar"; got != want {
1056+
t.Errorf("Result = %q, want %q", got, want)
1057+
}
10311058
}
10321059

1033-
testStreamableHandler(t, handler, requests)
1060+
handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{
1061+
GetSessionID: func() string { return "" },
1062+
Stateless: true,
1063+
})
1064+
1065+
// Test the default stateless mode.
1066+
t.Run("stateless", func(t *testing.T) {
1067+
testStreamableHandler(t, handler, requests)
1068+
testClientCompatibility(t, handler)
1069+
})
1070+
1071+
// Test a "distributed" variant of stateless mode, where it has non-empty
1072+
// session IDs, but is otherwise stateless.
1073+
//
1074+
// This can be used by tools to look up application state preserved across
1075+
// subsequent requests.
1076+
for i, req := range requests {
1077+
// Now, we want a session for all requests.
1078+
req.wantSessionID = true
1079+
requests[i] = req
1080+
}
1081+
distributableHandler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{
1082+
Stateless: true,
1083+
})
1084+
t.Run("distributed", func(t *testing.T) {
1085+
testStreamableHandler(t, distributableHandler, requests)
1086+
testClientCompatibility(t, handler)
1087+
})
1088+
}
1089+
1090+
func textContent(t *testing.T, res *CallToolResult) string {
1091+
t.Helper()
1092+
if len(res.Content) != 1 {
1093+
t.Fatalf("len(Content) = %d, want 1", len(res.Content))
1094+
}
1095+
text, ok := res.Content[0].(*TextContent)
1096+
if !ok {
1097+
t.Fatalf("Content[0] is %T, want *TextContent", res.Content[0])
1098+
}
1099+
return text.Text
10341100
}

mcp/transport.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ type Transport interface {
4040
type Connection interface {
4141
// Read reads the next message to process off the connection.
4242
//
43-
// Read need not be safe for concurrent use: Read is called in a
44-
// concurrency-safe manner by the JSON-RPC library.
43+
// Connections must allow Read to be called concurrently with Close. In
44+
// particular, calling Close should unblock a Read waiting for input.
4545
Read(context.Context) (jsonrpc.Message, error)
4646

4747
// Write writes a new message to the connection.

0 commit comments

Comments
 (0)