Skip to content

Commit 076baf9

Browse files
findleyryasomaru
authored and committed
mcp: fix reconnect semantics for hanging GET
A few problems with reconnection cropped up in the review of PR modelcontextprotocol#307. We should allow for the hanging GET to fail with StatusMethodNotAllowed. This simply means that the server does not support sending notifications or requests over the GET, which is allowed in the spec. Also, we should fix the initial delay of the hanging GET request: it should start with 0 delay. Fix the math for this and subsequent attempts. Incidentally, this makes the tests take 3s on my machine, down from 9s. Also address some comments from modelcontextprotocol#307.
1 parent 230de7d commit 076baf9

File tree

2 files changed

+43
-29
lines changed

2 files changed

+43
-29
lines changed

internal/jsonrpc2/conn.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -742,8 +742,8 @@ func (c *Connection) write(ctx context.Context, msg Message) error {
742742
var err error
743743
// Fail writes immediately if the connection is shutting down.
744744
//
745-
// TODO(rfindley): should we allow cancellation notifations through? It could
746-
// be the case that writes can still succeed.
745+
// TODO(rfindley): should we allow cancellation notifications through? It
746+
// could be the case that writes can still succeed.
747747
c.updateInFlight(func(s *inFlightState) {
748748
err = s.shuttingDown(ErrServerClosing)
749749
})

mcp/streamable.go

Lines changed: 41 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ type StreamableHTTPOptions struct {
4949
// GetSessionID provides the next session ID to use for an incoming request.
5050
// If nil, a default randomly generated ID will be used.
5151
//
52+
// Session IDs should be globally unique across the scope of the server,
53+
// which may span multiple processes in the case of distributed servers.
54+
//
5255
// As a special case, if GetSessionID returns the empty string, the
5356
// Mcp-Session-Id header will not be set.
5457
GetSessionID func() string
@@ -58,7 +61,9 @@ type StreamableHTTPOptions struct {
5861
// A stateless server does not validate the Mcp-Session-Id header, and uses a
5962
// temporary session with default initialization parameters. Any
6063
// server->client request is rejected immediately as there's no way for the
61-
// client to respond.
64+
// client to respond. Server->client notifications may reach the client if
65+
// they are made in the context of an incoming request, as described in the
66+
// documentation for [StreamableServerTransport].
6267
Stateless bool
6368

6469
// TODO: support session retention (?)
@@ -133,12 +138,13 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
133138
var transport *StreamableServerTransport
134139
if sessionID != "" {
135140
h.mu.Lock()
136-
transport, _ = h.transports[sessionID]
141+
transport = h.transports[sessionID]
137142
h.mu.Unlock()
138143
if transport == nil && !h.opts.Stateless {
139-
// In stateless mode we allow a missing transport.
144+
// Unless we're in 'stateless' mode, which doesn't perform any Session-ID
145+
// validation, we require that the session ID matches a known session.
140146
//
141-
// A synthetic transport will be created below for the transient session.
147+
// In stateless mode, a temporary transport is created below.
142148
http.Error(w, "session not found", http.StatusNotFound)
143149
return
144150
}
@@ -201,7 +207,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
201207
// stateless servers.
202208
body, err := io.ReadAll(req.Body)
203209
if err != nil {
204-
http.Error(w, "failed to read body", http.StatusBadRequest)
210+
http.Error(w, "failed to read body", http.StatusInternalServerError)
205211
return
206212
}
207213
req.Body.Close()
@@ -272,9 +278,22 @@ type StreamableServerTransportOptions struct {
272278
// A StreamableServerTransport implements the server side of the MCP streamable
273279
// transport.
274280
//
275-
// Each StreamableServerTransport may be connected (via [Server.Connect]) at
281+
// Each StreamableServerTransport must be connected (via [Server.Connect]) at
276282
// most once, since [StreamableServerTransport.ServeHTTP] serves messages to
277283
// the connected session.
284+
//
285+
// Reads from the streamable server connection receive messages from http POST
286+
// requests from the client. Writes to the streamable server connection are
287+
// sent either to the hanging POST response, or to the hanging GET, according
288+
// to the following rules:
289+
// - JSON-RPC responses to incoming requests are always routed to the
290+
// appropriate HTTP response.
291+
// - Requests or notifications made with a context.Context value derived from
292+
// an incoming request handler are routed to the HTTP response
293+
// corresponding to that request, unless it has already terminated, in
294+
// which case they are routed to the hanging GET.
295+
// - Requests or notifications made with a detached context.Context value are
296+
// routed to the hanging GET.
278297
type StreamableServerTransport struct {
279298
// SessionID is the ID of this session.
280299
//
@@ -285,7 +304,7 @@ type StreamableServerTransport struct {
285304
// generator to produce one, as with [crypto/rand.Text].)
286305
SessionID string
287306

288-
// Stateless controls whether the eventstore is 'Stateless'. Servers sessions
307+
// Stateless controls whether the eventstore is 'Stateless'. Server sessions
289308
// connected to a stateless transport are disallowed from making outgoing
290309
// requests.
291310
//
@@ -1225,9 +1244,18 @@ func (c *streamableClientConn) handleSSE(initialResp *http.Response, persistent
12251244
c.fail(err)
12261245
return
12271246
}
1228-
1229-
// Reconnection was successful. Continue the loop with the new response.
12301247
resp = newResp
1248+
if resp.StatusCode == http.StatusMethodNotAllowed && persistent {
1249+
// The server doesn't support the hanging GET.
1250+
resp.Body.Close()
1251+
return
1252+
}
1253+
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
1254+
resp.Body.Close()
1255+
c.fail(fmt.Errorf("failed to reconnect: %v", http.StatusText(resp.StatusCode)))
1256+
return
1257+
}
1258+
// Reconnection was successful. Continue the loop with the new response.
12311259
}
12321260
}
12331261

@@ -1295,13 +1323,6 @@ func (c *streamableClientConn) reconnect(lastEventID string) (*http.Response, er
12951323
finalErr = err // Store the error and try again.
12961324
continue
12971325
}
1298-
1299-
if !isResumable(resp) {
1300-
// The server indicated we should not continue.
1301-
resp.Body.Close()
1302-
return nil, fmt.Errorf("reconnection failed with unresumable status: %s", resp.Status)
1303-
}
1304-
13051326
return resp, nil
13061327
}
13071328
}
@@ -1312,16 +1333,6 @@ func (c *streamableClientConn) reconnect(lastEventID string) (*http.Response, er
13121333
return nil, fmt.Errorf("connection failed after %d attempts", c.maxRetries)
13131334
}
13141335

1315-
// isResumable checks if an HTTP response indicates a valid SSE stream that can be processed.
1316-
func isResumable(resp *http.Response) bool {
1317-
// Per the spec, a 405 response means the server doesn't support SSE streams at this endpoint.
1318-
if resp.StatusCode == http.StatusMethodNotAllowed {
1319-
return false
1320-
}
1321-
1322-
return strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream")
1323-
}
1324-
13251336
// Close implements the [Connection] interface.
13261337
func (c *streamableClientConn) Close() error {
13271338
c.closeOnce.Do(func() {
@@ -1361,8 +1372,11 @@ func (c *streamableClientConn) establishSSE(lastEventID string) (*http.Response,
13611372

13621373
// calculateReconnectDelay calculates a delay using exponential backoff with full jitter.
13631374
func calculateReconnectDelay(attempt int) time.Duration {
1375+
if attempt == 0 {
1376+
return 0
1377+
}
13641378
// Calculate the exponential backoff using the grow factor.
1365-
backoffDuration := time.Duration(float64(reconnectInitialDelay) * math.Pow(reconnectGrowFactor, float64(attempt)))
1379+
backoffDuration := time.Duration(float64(reconnectInitialDelay) * math.Pow(reconnectGrowFactor, float64(attempt-1)))
13661380
// Cap the backoffDuration at maxDelay.
13671381
backoffDuration = min(backoffDuration, reconnectMaxDelay)
13681382

0 commit comments

Comments
 (0)