@@ -329,6 +329,7 @@ func (c *streamableServerConn) SessionID() string {
329329// A stream is a single logical stream of SSE events within a server session.
330330// A stream begins with a client request, or with a client GET that has
331331// no Last-Event-ID header.
332+ //
332333// A stream ends only when its session ends; we cannot determine its end otherwise,
333334// since a client may send a GET with a Last-Event-ID that references the stream
334335// at any time.
@@ -529,6 +530,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
529530 }
530531 c .mu .Unlock ()
531532 stream .signal .Store (signalChanPtr ())
533+ defer stream .signal .Store (nil )
532534 }
533535
534536 // Publish incoming messages.
@@ -857,27 +859,27 @@ type StreamableReconnectOptions struct {
857859 // MaxRetries is the maximum number of times to attempt a reconnect before giving up.
858860 // A value of 0 or less means never retry.
859861 MaxRetries int
860-
861- // growFactor is the multiplicative factor by which the delay increases after each attempt.
862- // A value of 1.0 results in a constant delay, while a value of 2.0 would double it each time.
863- // It must be 1.0 or greater if MaxRetries is greater than 0.
864- growFactor float64
865-
866- // initialDelay is the base delay for the first reconnect attempt.
867- initialDelay time.Duration
868-
869- // maxDelay caps the backoff delay, preventing it from growing indefinitely.
870- maxDelay time.Duration
871862}
872863
873864// DefaultReconnectOptions provides sensible defaults for reconnect logic.
874865var DefaultReconnectOptions = & StreamableReconnectOptions {
875- MaxRetries : 5 ,
876- growFactor : 1.5 ,
877- initialDelay : 1 * time .Second ,
878- maxDelay : 30 * time .Second ,
866+ MaxRetries : 5 ,
879867}
880868
869+ // These settings are not (yet) exposed to the user in
870+ // StreamableReconnectOptions. Since they're invisible, keep them const rather
871+ // than requiring the user to start from DefaultReconnectOptions and mutate.
872+ const (
873+ // reconnectGrowFactor is the multiplicative factor by which the delay increases after each attempt.
874+ // A value of 1.0 results in a constant delay, while a value of 2.0 would double it each time.
875+ // It must be 1.0 or greater if MaxRetries is greater than 0.
876+ reconnectGrowFactor = 1.5
877+ // reconnectInitialDelay is the base delay for the first reconnect attempt.
878+ reconnectInitialDelay = 1 * time .Second
879+ // reconnectMaxDelay caps the backoff delay, preventing it from growing indefinitely.
880+ reconnectMaxDelay = 30 * time .Second
881+ )
882+
881883// StreamableClientTransportOptions provides options for the
882884// [NewStreamableClientTransport] constructor.
883885//
@@ -928,7 +930,7 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er
928930 conn := & streamableClientConn {
929931 url : t .Endpoint ,
930932 client : client ,
931- incoming : make (chan [] byte , 100 ),
933+ incoming : make (chan jsonrpc. Message , 10 ),
932934 done : make (chan struct {}),
933935 ReconnectOptions : reconnOpts ,
934936 ctx : connCtx ,
@@ -944,7 +946,7 @@ type streamableClientConn struct {
944946 client * http.Client
945947 ctx context.Context
946948 cancel context.CancelFunc
947- incoming chan [] byte
949+ incoming chan jsonrpc. Message
948950
949951 // Guard calls to Close, as it may be called multiple times.
950952 closeOnce sync.Once
@@ -988,7 +990,7 @@ func (c *streamableClientConn) sessionUpdated(state clientSessionState) {
988990 // § 2.5: A server using the Streamable HTTP transport MAY assign a session
989991 // ID at initialization time, by including it in an Mcp-Session-Id header
990992 // on the HTTP response containing the InitializeResult.
991- go c .handleSSE (nil , true )
993+ go c .handleSSE (nil , true , nil )
992994}
993995
994996// fail handles an asynchronous error while reading.
@@ -1031,8 +1033,8 @@ func (c *streamableClientConn) Read(ctx context.Context) (jsonrpc.Message, error
10311033 return nil , c .failure ()
10321034 case <- c .done :
10331035 return nil , io .EOF
1034- case data := <- c .incoming :
1035- return jsonrpc2 . DecodeMessage ( data )
1036+ case msg := <- c .incoming :
1037+ return msg , nil
10361038 }
10371039}
10381040
@@ -1042,7 +1044,7 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
10421044 return err
10431045 }
10441046
1045- data , err := jsonrpc2 .EncodeMessage (msg )
1047+ data , err := jsonrpc .EncodeMessage (msg )
10461048 if err != nil {
10471049 return err
10481050 }
@@ -1088,7 +1090,8 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
10881090 go c .handleJSON (resp )
10891091
10901092 case "text/event-stream" :
1091- go c .handleSSE (resp , false )
1093+ jsonReq , _ := msg .(* jsonrpc.Request )
1094+ go c .handleSSE (resp , false , jsonReq )
10921095
10931096 default :
10941097 resp .Body .Close ()
@@ -1116,30 +1119,40 @@ func (c *streamableClientConn) handleJSON(resp *http.Response) {
11161119 c .fail (err )
11171120 return
11181121 }
1122+ msg , err := jsonrpc .DecodeMessage (body )
1123+ if err != nil {
1124+ c .fail (fmt .Errorf ("failed to decode response: %v" , err ))
1125+ return
1126+ }
11191127 select {
1120- case c .incoming <- body :
1128+ case c .incoming <- msg :
11211129 case <- c .done :
11221130 // The connection was closed by the client; exit gracefully.
11231131 }
11241132}
11251133
11261134// handleSSE manages the lifecycle of an SSE connection. It can be either
11271135// persistent (for the main GET listener) or temporary (for a POST response).
1128- func (c * streamableClientConn ) handleSSE (initialResp * http.Response , persistent bool ) {
1136+ //
1137+ // If forReq is set, it is the request that initiated the stream, and the
1138+ // stream is complete when we receive its response.
1139+ func (c * streamableClientConn ) handleSSE (initialResp * http.Response , persistent bool , forReq * jsonrpc2.Request ) {
11291140 resp := initialResp
11301141 var lastEventID string
11311142 for {
1132- eventID , clientClosed := c .processStream (resp )
1133- lastEventID = eventID
1143+ if resp != nil {
1144+ eventID , clientClosed := c .processStream (resp , forReq )
1145+ lastEventID = eventID
11341146
1135- // If the connection was closed by the client, we're done.
1136- if clientClosed {
1137- return
1138- }
1139- // If the stream has ended, then do not reconnect if the stream is
1140- // temporary (POST initiated SSE).
1141- if lastEventID == "" && ! persistent {
1142- return
1147+ // If the connection was closed by the client, we're done.
1148+ if clientClosed {
1149+ return
1150+ }
1151+ // If the stream has ended, then do not reconnect if the stream is
1152+ // temporary (POST initiated SSE).
1153+ if lastEventID == "" && ! persistent {
1154+ return
1155+ }
11431156 }
11441157
11451158 // The stream was interrupted or ended by the server. Attempt to reconnect.
@@ -1159,12 +1172,7 @@ func (c *streamableClientConn) handleSSE(initialResp *http.Response, persistent
11591172// incoming channel. It returns the ID of the last processed event and a flag
11601173// indicating if the connection was closed by the client. If resp is nil, it
11611174// returns "", false.
1162- func (c * streamableClientConn ) processStream (resp * http.Response ) (lastEventID string , clientClosed bool ) {
1163- if resp == nil {
1164- // TODO(rfindley): avoid this special handling.
1165- return "" , false
1166- }
1167-
1175+ func (c * streamableClientConn ) processStream (resp * http.Response , forReq * jsonrpc.Request ) (lastEventID string , clientClosed bool ) {
11681176 defer resp .Body .Close ()
11691177 for evt , err := range scanEvents (resp .Body ) {
11701178 if err != nil {
@@ -1175,8 +1183,21 @@ func (c *streamableClientConn) processStream(resp *http.Response) (lastEventID s
11751183 lastEventID = evt .ID
11761184 }
11771185
1186+ msg , err := jsonrpc .DecodeMessage (evt .Data )
1187+ if err != nil {
1188+ c .fail (fmt .Errorf ("failed to decode event: %v" , err ))
1189+ return "" , true
1190+ }
1191+
11781192 select {
1179- case c .incoming <- evt .Data :
1193+ case c .incoming <- msg :
1194+ if jsonResp , ok := msg .(* jsonrpc.Response ); ok && forReq != nil {
1195+ // TODO: we should never get a response when forReq is nil (the hanging GET).
1196+ // We should detect this case, and eliminate the 'persistent' flag arguments.
1197+ if jsonResp .ID == forReq .ID {
1198+ return "" , true
1199+ }
1200+ }
11801201 case <- c .done :
11811202 // The connection was closed by the client; exit gracefully.
11821203 return "" , true
@@ -1192,11 +1213,20 @@ func (c *streamableClientConn) processStream(resp *http.Response) (lastEventID s
11921213func (c * streamableClientConn ) reconnect (lastEventID string ) (* http.Response , error ) {
11931214 var finalErr error
11941215
1195- for attempt := 0 ; attempt < c .ReconnectOptions .MaxRetries ; attempt ++ {
1216+ // We can reach the 'reconnect' path through the hanging GET, in which case
1217+ // lastEventID will be "".
1218+ //
1219+ // In this case, we need an initial attempt.
1220+ attempt := 0
1221+ if lastEventID != "" {
1222+ attempt = 1
1223+ }
1224+
1225+ for ; attempt <= c .ReconnectOptions .MaxRetries ; attempt ++ {
11961226 select {
11971227 case <- c .done :
11981228 return nil , fmt .Errorf ("connection closed by client during reconnect" )
1199- case <- time .After (calculateReconnectDelay (c . ReconnectOptions , attempt )):
1229+ case <- time .After (calculateReconnectDelay (attempt )):
12001230 resp , err := c .establishSSE (lastEventID )
12011231 if err != nil {
12021232 finalErr = err // Store the error and try again.
@@ -1267,11 +1297,11 @@ func (c *streamableClientConn) establishSSE(lastEventID string) (*http.Response,
12671297}
12681298
12691299// calculateReconnectDelay calculates a delay using exponential backoff with full jitter.
1270- func calculateReconnectDelay (opts * StreamableReconnectOptions , attempt int ) time.Duration {
1300+ func calculateReconnectDelay (attempt int ) time.Duration {
12711301 // Calculate the exponential backoff using the grow factor.
1272- backoffDuration := time .Duration (float64 (opts . initialDelay ) * math .Pow (opts . growFactor , float64 (attempt )))
1302+ backoffDuration := time .Duration (float64 (reconnectInitialDelay ) * math .Pow (reconnectGrowFactor , float64 (attempt )))
12731303 // Cap the backoffDuration at maxDelay.
1274- backoffDuration = min (backoffDuration , opts . maxDelay )
1304+ backoffDuration = min (backoffDuration , reconnectMaxDelay )
12751305
12761306 // Use a full jitter using backoffDuration
12771307 jitter := rand .N (backoffDuration )
0 commit comments