Skip to content

Commit 3103ffc

Browse files
authored
Fix race conditions and improve stability in HTTPSSEProxy (#1580)
1 parent a91f470 commit 3103ffc

File tree

4 files changed

+1386
-21
lines changed

4 files changed

+1386
-21
lines changed

pkg/transport/proxy/httpsse/http_proxy.go

Lines changed: 79 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ import (
66
"context"
77
"fmt"
88
"io"
9+
"net"
910
"net/http"
11+
"strconv"
1012
"sync"
1113
"time"
1214

@@ -63,7 +65,7 @@ type HTTPSSEProxy struct {
6365

6466
// SSE clients
6567
sseClients map[string]*ssecommon.SSEClient
66-
sseClientsMutex sync.Mutex
68+
sseClientsMutex sync.RWMutex
6769

6870
// Pending messages for SSE clients
6971
pendingMessages []*ssecommon.PendingSSEMessage
@@ -74,6 +76,10 @@ type HTTPSSEProxy struct {
7476

7577
// Health checker
7678
healthChecker *healthcheck.HealthChecker
79+
80+
// Track closed clients to prevent double-close
81+
closedClients map[string]bool
82+
closedClientsMutex sync.Mutex
7783
}
7884

7985
// NewHTTPSSEProxy creates a new HTTP SSE proxy for transports.
@@ -90,6 +96,7 @@ func NewHTTPSSEProxy(
9096
sseClients: make(map[string]*ssecommon.SSEClient),
9197
pendingMessages: []*ssecommon.PendingSSEMessage{},
9298
prometheusHandler: prometheusHandler,
99+
closedClients: make(map[string]bool),
93100
}
94101

95102
// Create MCP pinger and health checker
@@ -138,24 +145,43 @@ func (p *HTTPSSEProxy) Start(_ context.Context) error {
138145
logger.Info("Prometheus metrics endpoint enabled at /metrics")
139146
}
140147

148+
// Create a listener to get the actual port when using port 0
149+
addr := fmt.Sprintf("%s:%d", p.host, p.port)
150+
listener, err := net.Listen("tcp", addr)
151+
if err != nil {
152+
return fmt.Errorf("failed to create listener: %w", err)
153+
}
154+
155+
// Update the server address with the actual address
156+
actualAddr := listener.Addr().String()
157+
141158
// Create the server
142159
p.server = &http.Server{
143-
Addr: fmt.Sprintf("%s:%d", p.host, p.port),
144160
Handler: mux,
145161
ReadHeaderTimeout: 10 * time.Second, // Prevent Slowloris attacks
146162
}
147163

164+
// Store the actual address
165+
p.server.Addr = actualAddr
166+
148167
// Start the server in a goroutine
149168
go func() {
150-
logger.Infof("HTTP proxy started for container %s on port %d", p.containerName, p.port)
151-
logger.Infof("SSE endpoint: http://%s:%d%s", p.host, p.port, ssecommon.HTTPSSEEndpoint)
152-
logger.Infof("JSON-RPC endpoint: http://%s:%d%s", p.host, p.port, ssecommon.HTTPMessagesEndpoint)
169+
// Parse the actual port for logging
170+
_, portStr, _ := net.SplitHostPort(actualAddr)
171+
actualPort, _ := strconv.Atoi(portStr)
172+
173+
logger.Infof("HTTP proxy started for container %s on port %d", p.containerName, actualPort)
174+
logger.Infof("SSE endpoint: http://%s%s", actualAddr, ssecommon.HTTPSSEEndpoint)
175+
logger.Infof("JSON-RPC endpoint: http://%s%s", actualAddr, ssecommon.HTTPMessagesEndpoint)
153176

154-
if err := p.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
177+
if err := p.server.Serve(listener); err != nil && err != http.ErrServerClosed {
155178
logger.Errorf("HTTP server error: %v", err)
156179
}
157180
}()
158181

182+
// Give the server a moment to start
183+
time.Sleep(10 * time.Millisecond)
184+
159185
return nil
160186
}
161187

@@ -201,9 +227,9 @@ func (p *HTTPSSEProxy) ForwardResponseToClients(_ context.Context, msg jsonrpc2.
201227
sseMsg := ssecommon.NewSSEMessage("message", string(data))
202228

203229
// Check if there are any connected clients
204-
p.sseClientsMutex.Lock()
230+
p.sseClientsMutex.RLock()
205231
hasClients := len(p.sseClients) > 0
206-
p.sseClientsMutex.Unlock()
232+
p.sseClientsMutex.RUnlock()
207233

208234
if hasClients {
209235
// Send the message to all connected clients
@@ -281,10 +307,7 @@ func (p *HTTPSSEProxy) handleSSEConnection(w http.ResponseWriter, r *http.Reques
281307
// Create a goroutine to monitor for client disconnection
282308
go func() {
283309
<-ctx.Done()
284-
p.sseClientsMutex.Lock()
285-
delete(p.sseClients, clientID)
286-
p.sseClientsMutex.Unlock()
287-
close(messageCh)
310+
p.removeClient(clientID)
288311
logger.Infof("Client %s disconnected", clientID)
289312
}()
290313

@@ -324,9 +347,9 @@ func (p *HTTPSSEProxy) handlePostRequest(w http.ResponseWriter, r *http.Request)
324347
}
325348

326349
// Check if the session exists
327-
p.sseClientsMutex.Lock()
350+
p.sseClientsMutex.RLock()
328351
_, exists := p.sseClients[sessionID]
329-
p.sseClientsMutex.Unlock()
352+
p.sseClientsMutex.RUnlock()
330353

331354
if !exists {
332355
http.Error(w, "Could not find session", http.StatusNotFound)
@@ -368,25 +391,60 @@ func (p *HTTPSSEProxy) sendSSEEvent(msg *ssecommon.SSEMessage) error {
368391
// Convert the message to an SSE-formatted string
369392
sseString := msg.ToSSEString()
370393

371-
// Send to all clients
372-
p.sseClientsMutex.Lock()
373-
defer p.sseClientsMutex.Unlock()
394+
// Hold the lock while sending to ensure channels aren't closed during send
395+
// This is a read lock, so multiple sends can happen concurrently
396+
p.sseClientsMutex.RLock()
397+
defer p.sseClientsMutex.RUnlock()
374398

375399
for clientID, client := range p.sseClients {
376400
select {
377401
case client.MessageCh <- sseString:
378402
// Message sent successfully
379403
default:
380-
// Channel is full or closed, remove the client
381-
delete(p.sseClients, clientID)
382-
close(client.MessageCh)
383-
logger.Infof("Client %s removed (channel full or closed)", clientID)
404+
// Channel is full, skip this client
405+
// Don't remove the client here - let the disconnect monitor handle it
406+
logger.Debugf("Client %s channel full, skipping message", clientID)
384407
}
385408
}
386409

387410
return nil
388411
}
389412

413+
// removeClient safely removes a client and closes its channel
414+
func (p *HTTPSSEProxy) removeClient(clientID string) {
415+
// Check if already closed
416+
p.closedClientsMutex.Lock()
417+
if p.closedClients[clientID] {
418+
p.closedClientsMutex.Unlock()
419+
return
420+
}
421+
p.closedClients[clientID] = true
422+
p.closedClientsMutex.Unlock()
423+
424+
// Remove from clients map and get the client
425+
// Use write lock to ensure no sends happen during removal
426+
p.sseClientsMutex.Lock()
427+
client, exists := p.sseClients[clientID]
428+
if exists {
429+
delete(p.sseClients, clientID)
430+
}
431+
p.sseClientsMutex.Unlock()
432+
433+
// Close the channel after removing from map
434+
// This ensures no goroutine will try to send to it
435+
if exists && client != nil {
436+
close(client.MessageCh)
437+
}
438+
439+
// Clean up closed clients map periodically (prevent memory leak)
440+
p.closedClientsMutex.Lock()
441+
if len(p.closedClients) > 1000 {
442+
// Reset the map when it gets too large
443+
p.closedClients = make(map[string]bool)
444+
}
445+
p.closedClientsMutex.Unlock()
446+
}
447+
390448
// processPendingMessages processes any pending messages for a new client.
391449
func (p *HTTPSSEProxy) processPendingMessages(clientID string, messageCh chan<- string) {
392450
p.pendingMutex.Lock()

0 commit comments

Comments
 (0)