Skip to content

Commit 1afdb1f

Browse files
committed
mcp: several fixes for streamable client reconnection
Do a pass through the streamable client reconnection logic, and fix several bugs:
- Establish the initial GET even if MaxRetries is 0 (modelcontextprotocol#256). This was broken because the GET bypassed the initial request and went straight to the SSE GET reconnection logic.
- Release the stream ownership when POST requests exit.
- Don't reconnect POST requests if we've received the expected response.
- Move unexported reconnection config to constants. Otherwise it is too hard to set ReconnectOptions (you have to start from DefaultReconnectOptions and mutate).

Fixes modelcontextprotocol#256
1 parent 0a8fe40 commit 1afdb1f

File tree

2 files changed

+160
-88
lines changed

2 files changed

+160
-88
lines changed

mcp/streamable.go

Lines changed: 76 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,7 @@ func (c *streamableServerConn) SessionID() string {
329329
// A stream is a single logical stream of SSE events within a server session.
330330
// A stream begins with a client request, or with a client GET that has
331331
// no Last-Event-ID header.
332+
//
332333
// A stream ends only when its session ends; we cannot determine its end otherwise,
333334
// since a client may send a GET with a Last-Event-ID that references the stream
334335
// at any time.
@@ -529,6 +530,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
529530
}
530531
c.mu.Unlock()
531532
stream.signal.Store(signalChanPtr())
533+
defer stream.signal.Store(nil)
532534
}
533535

534536
// Publish incoming messages.
@@ -857,27 +859,27 @@ type StreamableReconnectOptions struct {
857859
// MaxRetries is the maximum number of times to attempt a reconnect before giving up.
858860
// A value of 0 or less means never retry.
859861
MaxRetries int
860-
861-
// growFactor is the multiplicative factor by which the delay increases after each attempt.
862-
// A value of 1.0 results in a constant delay, while a value of 2.0 would double it each time.
863-
// It must be 1.0 or greater if MaxRetries is greater than 0.
864-
growFactor float64
865-
866-
// initialDelay is the base delay for the first reconnect attempt.
867-
initialDelay time.Duration
868-
869-
// maxDelay caps the backoff delay, preventing it from growing indefinitely.
870-
maxDelay time.Duration
871862
}
872863

873864
// DefaultReconnectOptions provides sensible defaults for reconnect logic.
874865
var DefaultReconnectOptions = &StreamableReconnectOptions{
875-
MaxRetries: 5,
876-
growFactor: 1.5,
877-
initialDelay: 1 * time.Second,
878-
maxDelay: 30 * time.Second,
866+
MaxRetries: 5,
879867
}
880868

869+
// These settings are not (yet) exposed to the user in
870+
// StreamableReconnectOptions. Since they're invisible, keep them const rather
871+
// than requiring the user to start from DefaultReconnectOptions and mutate.
872+
const (
873+
// reconnectGrowFactor is the multiplicative factor by which the delay increases after each attempt.
874+
// A value of 1.0 results in a constant delay, while a value of 2.0 would double it each time.
875+
// It must be 1.0 or greater if MaxRetries is greater than 0.
876+
reconnectGrowFactor = 1.5
877+
// reconnectInitialDelay is the base delay for the first reconnect attempt.
878+
reconnectInitialDelay = 1 * time.Second
879+
// reconnectMaxDelay caps the backoff delay, preventing it from growing indefinitely.
880+
reconnectMaxDelay = 30 * time.Second
881+
)
882+
881883
// StreamableClientTransportOptions provides options for the
882884
// [NewStreamableClientTransport] constructor.
883885
//
@@ -928,7 +930,7 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er
928930
conn := &streamableClientConn{
929931
url: t.Endpoint,
930932
client: client,
931-
incoming: make(chan []byte, 100),
933+
incoming: make(chan jsonrpc.Message, 10),
932934
done: make(chan struct{}),
933935
ReconnectOptions: reconnOpts,
934936
ctx: connCtx,
@@ -944,7 +946,7 @@ type streamableClientConn struct {
944946
client *http.Client
945947
ctx context.Context
946948
cancel context.CancelFunc
947-
incoming chan []byte
949+
incoming chan jsonrpc.Message
948950

949951
// Guard calls to Close, as it may be called multiple times.
950952
closeOnce sync.Once
@@ -988,7 +990,7 @@ func (c *streamableClientConn) sessionUpdated(state clientSessionState) {
988990
// § 2.5: A server using the Streamable HTTP transport MAY assign a session
989991
// ID at initialization time, by including it in an Mcp-Session-Id header
990992
// on the HTTP response containing the InitializeResult.
991-
go c.handleSSE(nil, true)
993+
go c.handleSSE(nil, true, nil)
992994
}
993995

994996
// fail handles an asynchronous error while reading.
@@ -1031,8 +1033,8 @@ func (c *streamableClientConn) Read(ctx context.Context) (jsonrpc.Message, error
10311033
return nil, c.failure()
10321034
case <-c.done:
10331035
return nil, io.EOF
1034-
case data := <-c.incoming:
1035-
return jsonrpc2.DecodeMessage(data)
1036+
case msg := <-c.incoming:
1037+
return msg, nil
10361038
}
10371039
}
10381040

@@ -1042,7 +1044,7 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
10421044
return err
10431045
}
10441046

1045-
data, err := jsonrpc2.EncodeMessage(msg)
1047+
data, err := jsonrpc.EncodeMessage(msg)
10461048
if err != nil {
10471049
return err
10481050
}
@@ -1088,7 +1090,8 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
10881090
go c.handleJSON(resp)
10891091

10901092
case "text/event-stream":
1091-
go c.handleSSE(resp, false)
1093+
jsonReq, _ := msg.(*jsonrpc.Request)
1094+
go c.handleSSE(resp, false, jsonReq)
10921095

10931096
default:
10941097
resp.Body.Close()
@@ -1116,30 +1119,40 @@ func (c *streamableClientConn) handleJSON(resp *http.Response) {
11161119
c.fail(err)
11171120
return
11181121
}
1122+
msg, err := jsonrpc.DecodeMessage(body)
1123+
if err != nil {
1124+
c.fail(fmt.Errorf("failed to decode response: %v", err))
1125+
return
1126+
}
11191127
select {
1120-
case c.incoming <- body:
1128+
case c.incoming <- msg:
11211129
case <-c.done:
11221130
// The connection was closed by the client; exit gracefully.
11231131
}
11241132
}
11251133

11261134
// handleSSE manages the lifecycle of an SSE connection. It can be either
11271135
// persistent (for the main GET listener) or temporary (for a POST response).
1128-
func (c *streamableClientConn) handleSSE(initialResp *http.Response, persistent bool) {
1136+
//
1137+
// If forReq is set, it is the request that initiated the stream, and the
1138+
// stream is complete when we receive its response.
1139+
func (c *streamableClientConn) handleSSE(initialResp *http.Response, persistent bool, forReq *jsonrpc2.Request) {
11291140
resp := initialResp
11301141
var lastEventID string
11311142
for {
1132-
eventID, clientClosed := c.processStream(resp)
1133-
lastEventID = eventID
1143+
if resp != nil {
1144+
eventID, clientClosed := c.processStream(resp, forReq)
1145+
lastEventID = eventID
11341146

1135-
// If the connection was closed by the client, we're done.
1136-
if clientClosed {
1137-
return
1138-
}
1139-
// If the stream has ended, then do not reconnect if the stream is
1140-
// temporary (POST initiated SSE).
1141-
if lastEventID == "" && !persistent {
1142-
return
1147+
// If the connection was closed by the client, we're done.
1148+
if clientClosed {
1149+
return
1150+
}
1151+
// If the stream has ended, then do not reconnect if the stream is
1152+
// temporary (POST initiated SSE).
1153+
if lastEventID == "" && !persistent {
1154+
return
1155+
}
11431156
}
11441157

11451158
// The stream was interrupted or ended by the server. Attempt to reconnect.
@@ -1159,12 +1172,7 @@ func (c *streamableClientConn) handleSSE(initialResp *http.Response, persistent
11591172
// incoming channel. It returns the ID of the last processed event and a flag
11601173
// indicating if the connection was closed by the client. resp must be non-nil;
11611174
// handleSSE skips this call when it has no response to process.
1162-
func (c *streamableClientConn) processStream(resp *http.Response) (lastEventID string, clientClosed bool) {
1163-
if resp == nil {
1164-
// TODO(rfindley): avoid this special handling.
1165-
return "", false
1166-
}
1167-
1175+
func (c *streamableClientConn) processStream(resp *http.Response, forReq *jsonrpc.Request) (lastEventID string, clientClosed bool) {
11681176
defer resp.Body.Close()
11691177
for evt, err := range scanEvents(resp.Body) {
11701178
if err != nil {
@@ -1175,8 +1183,21 @@ func (c *streamableClientConn) processStream(resp *http.Response) (lastEventID s
11751183
lastEventID = evt.ID
11761184
}
11771185

1186+
msg, err := jsonrpc.DecodeMessage(evt.Data)
1187+
if err != nil {
1188+
c.fail(fmt.Errorf("failed to decode event: %v", err))
1189+
return "", true
1190+
}
1191+
11781192
select {
1179-
case c.incoming <- evt.Data:
1193+
case c.incoming <- msg:
1194+
if jsonResp, ok := msg.(*jsonrpc.Response); ok && forReq != nil {
1195+
// TODO: we should never get a response when forReq is nil (the hanging GET).
1196+
// We should detect this case, and eliminate the 'persistent' flag arguments.
1197+
if jsonResp.ID == forReq.ID {
1198+
return "", true
1199+
}
1200+
}
11801201
case <-c.done:
11811202
// The connection was closed by the client; exit gracefully.
11821203
return "", true
@@ -1192,11 +1213,20 @@ func (c *streamableClientConn) processStream(resp *http.Response) (lastEventID s
11921213
func (c *streamableClientConn) reconnect(lastEventID string) (*http.Response, error) {
11931214
var finalErr error
11941215

1195-
for attempt := 0; attempt < c.ReconnectOptions.MaxRetries; attempt++ {
1216+
// We can reach the 'reconnect' path through the hanging GET, in which case
1217+
// lastEventID will be "".
1218+
//
1219+
// In this case, we need an initial attempt.
1220+
attempt := 0
1221+
if lastEventID != "" {
1222+
attempt = 1
1223+
}
1224+
1225+
for ; attempt <= c.ReconnectOptions.MaxRetries; attempt++ {
11961226
select {
11971227
case <-c.done:
11981228
return nil, fmt.Errorf("connection closed by client during reconnect")
1199-
case <-time.After(calculateReconnectDelay(c.ReconnectOptions, attempt)):
1229+
case <-time.After(calculateReconnectDelay(attempt)):
12001230
resp, err := c.establishSSE(lastEventID)
12011231
if err != nil {
12021232
finalErr = err // Store the error and try again.
@@ -1267,11 +1297,11 @@ func (c *streamableClientConn) establishSSE(lastEventID string) (*http.Response,
12671297
}
12681298

12691299
// calculateReconnectDelay calculates a delay using exponential backoff with full jitter.
1270-
func calculateReconnectDelay(opts *StreamableReconnectOptions, attempt int) time.Duration {
1300+
func calculateReconnectDelay(attempt int) time.Duration {
12711301
// Calculate the exponential backoff using the grow factor.
1272-
backoffDuration := time.Duration(float64(opts.initialDelay) * math.Pow(opts.growFactor, float64(attempt)))
1302+
backoffDuration := time.Duration(float64(reconnectInitialDelay) * math.Pow(reconnectGrowFactor, float64(attempt)))
12731303
// Cap the backoffDuration at reconnectMaxDelay.
1274-
backoffDuration = min(backoffDuration, opts.maxDelay)
1304+
backoffDuration = min(backoffDuration, reconnectMaxDelay)
12751305

12761306
// Use a full jitter using backoffDuration
12771307
jitter := rand.N(backoffDuration)

0 commit comments

Comments
 (0)