Skip to content

Commit 81a5654

Browse files
yroblataskbotCopilotdmjb
authored
feat: properly handle interrupt signal for proxy (#1206)
Co-authored-by: taskbot <[email protected]> Co-authored-by: Copilot <[email protected]> Co-authored-by: Don Browne <[email protected]>
1 parent 071b049 commit 81a5654

File tree

2 files changed

+45
-22
lines changed

2 files changed

+45
-22
lines changed

cmd/thv/app/proxy.go

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,8 @@ func init() {
143143
}
144144

145145
func proxyCmdFunc(cmd *cobra.Command, args []string) error {
146-
ctx := cmd.Context()
146+
ctx, stopSignal := signal.NotifyContext(cmd.Context(), syscall.SIGINT, syscall.SIGTERM)
147+
defer stopSignal()
147148
// Get the server name
148149
serverName := args[0]
149150

@@ -223,21 +224,15 @@ func proxyCmdFunc(cmd *cobra.Command, args []string) error {
223224
serverName, port, proxyTargetURI)
224225
logger.Info("Press Ctrl+C to stop")
225226

226-
// Set up signal handling
227-
sigCh := make(chan os.Signal, 1)
228-
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
227+
<-ctx.Done()
228+
logger.Infof("Interrupt received, proxy is shutting down. Please wait for connections to close...")
229229

230-
// Wait for signal
231-
sig := <-sigCh
232-
logger.Infof("Received signal %s, stopping proxy...", sig)
233-
234-
// Stop the proxy
235-
if err := proxy.Stop(ctx); err != nil {
236-
logger.Warnf("Warning: Failed to stop proxy: %v", err)
230+
if err := proxy.CloseListener(); err != nil {
231+
logger.Warnf("Error closing proxy listener: %v", err)
237232
}
238-
239-
logger.Infof("Proxy for server %s stopped", serverName)
240-
return nil
233+
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
234+
defer cancel()
235+
return proxy.Stop(shutdownCtx)
241236
}
242237

243238
// AuthInfo contains authentication information extracted from WWW-Authenticate header

pkg/transport/proxy/transparent/transparent_proxy.go

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@ import (
77
"bytes"
88
"context"
99
"encoding/json"
10+
"errors"
1011
"fmt"
1112
"io"
1213
"mime"
14+
"net"
1315
"net/http"
1416
"net/http/httputil"
1517
"net/url"
@@ -61,6 +63,9 @@ type TransparentProxy struct {
6163

6264
// If mcp server has been initialized
6365
IsServerInitialized bool
66+
67+
// Listener for the HTTP server
68+
listener net.Listener
6469
}
6570

6671
// NewTransparentProxy creates a new transparent proxy with optional middlewares.
@@ -129,10 +134,13 @@ func (t *tracingTransport) RoundTrip(req *http.Request) (*http.Response, error)
129134

130135
resp, err := t.forward(req)
131136
if err != nil {
137+
if errors.Is(err, context.Canceled) {
138+
// Expected during shutdown or client disconnect—silently ignore
139+
return nil, err
140+
}
132141
logger.Errorf("Failed to forward request: %v", err)
133142
return nil, err
134143
}
135-
136144
if resp.StatusCode == http.StatusOK {
137145
// check if we saw a valid mcp header
138146
ct := resp.Header.Get("Mcp-Session-Id")
@@ -286,6 +294,11 @@ func (p *TransparentProxy) Start(ctx context.Context) error {
286294
mux.Handle("/metrics", p.prometheusHandler)
287295
logger.Info("Prometheus metrics endpoint enabled at /metrics")
288296
}
297+
ln, err := net.Listen("tcp", fmt.Sprintf("%s:%d", p.host, p.port))
298+
if err != nil {
299+
return fmt.Errorf("failed to listen: %w", err)
300+
}
301+
p.listener = ln
289302

290303
// Create the server
291304
p.server = &http.Server{
@@ -294,16 +307,17 @@ func (p *TransparentProxy) Start(ctx context.Context) error {
294307
ReadHeaderTimeout: 10 * time.Second, // Prevent Slowloris attacks
295308
}
296309

297-
// Start the server in a goroutine
298310
go func() {
299-
logger.Infof("Transparent proxy started for container %s on %s:%d -> %s",
300-
p.containerName, p.host, p.port, p.targetURI)
301-
302-
if err := p.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
311+
err := p.server.Serve(ln)
312+
if err != nil && err != http.ErrServerClosed {
313+
var opErr *net.OpError
314+
if errors.As(err, &opErr) && opErr.Op == "accept" {
315+
// Expected when listener is closed—silently return
316+
return
317+
}
303318
logger.Errorf("Transparent proxy error: %v", err)
304319
}
305320
}()
306-
307321
// Start health-check monitoring only if health checker is enabled
308322
if p.healthChecker != nil {
309323
go p.monitorHealth(ctx)
@@ -312,6 +326,14 @@ func (p *TransparentProxy) Start(ctx context.Context) error {
312326
return nil
313327
}
314328

329+
// CloseListener closes the listener for the transparent proxy.
330+
func (p *TransparentProxy) CloseListener() error {
331+
if p.listener != nil {
332+
return p.listener.Close()
333+
}
334+
return nil
335+
}
336+
315337
func (p *TransparentProxy) monitorHealth(parentCtx context.Context) {
316338
ticker := time.NewTicker(10 * time.Second)
317339
defer ticker.Stop()
@@ -352,7 +374,13 @@ func (p *TransparentProxy) Stop(ctx context.Context) error {
352374

353375
// Stop the HTTP server
354376
if p.server != nil {
355-
return p.server.Shutdown(ctx)
377+
err := p.server.Shutdown(ctx)
378+
if err != nil && err != http.ErrServerClosed && err != context.DeadlineExceeded {
379+
logger.Warnf("Error during proxy shutdown: %v", err)
380+
return err
381+
}
382+
logger.Infof("Server for %s stopped successfully", p.containerName)
383+
p.server = nil
356384
}
357385

358386
return nil

0 commit comments

Comments
 (0)