197 changes: 187 additions & 10 deletions pkg/transport/stdio.go
@@ -1,29 +1,53 @@
// Package transport provides utilities for handling different transport modes
// for communication between the client and MCP server, including stdio transport
// with automatic re-attachment on Docker/container restarts.
package transport

import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"strings"
"sync"
"time"
"unicode"

"github.com/cenkalti/backoff/v5"
"golang.org/x/exp/jsonrpc2"

"github.com/stacklok/toolhive/pkg/container"
rt "github.com/stacklok/toolhive/pkg/container/runtime"
"github.com/stacklok/toolhive/pkg/ignore"
"github.com/stacklok/toolhive/pkg/logger"
"github.com/stacklok/toolhive/pkg/permissions"
"github.com/stacklok/toolhive/pkg/transport/errors"
transporterrors "github.com/stacklok/toolhive/pkg/transport/errors"
"github.com/stacklok/toolhive/pkg/transport/proxy/httpsse"
"github.com/stacklok/toolhive/pkg/transport/proxy/streamable"
"github.com/stacklok/toolhive/pkg/transport/types"
)

const (
// Retry configuration constants
// defaultMaxRetries is the maximum number of re-attachment attempts after a connection loss.
// Set to 10 to allow sufficient time for Docker/Rancher Desktop to restart (~5 minutes with backoff).
defaultMaxRetries = 10

// defaultInitialRetryDelay is the starting delay for exponential backoff.
// Starts at 2 seconds to quickly recover from transient issues without overwhelming the system.
defaultInitialRetryDelay = 2 * time.Second

// defaultMaxRetryDelay caps the maximum delay between retry attempts.
// Set to 30 seconds to balance between responsiveness and resource usage during extended outages.
defaultMaxRetryDelay = 30 * time.Second

// shutdownTimeout is the maximum time to wait for graceful shutdown operations.
shutdownTimeout = 30 * time.Second
)
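
As a sanity check on the timing comments above, here is a minimal standalone sketch (assuming backoff/v5's default 1.5 multiplier and ignoring its randomization factor, neither of which this diff sets explicitly) of the wait schedule these constants produce:

```go
package main

import (
	"fmt"
	"time"
)

// Sketch only: prints the nominal wait schedule implied by the retry
// constants above, assuming backoff/v5's default 1.5 multiplier and
// no jitter.
func main() {
	const (
		maxRetries   = 10               // defaultMaxRetries
		initialDelay = 2 * time.Second  // defaultInitialRetryDelay
		maxDelay     = 30 * time.Second // defaultMaxRetryDelay
	)
	delay, total := initialDelay, time.Duration(0)
	for attempt := 1; attempt <= maxRetries; attempt++ {
		total += delay
		fmt.Printf("attempt %2d: wait %-5v (cumulative %v)\n", attempt, delay, total)
		delay = time.Duration(float64(delay) * 1.5)
		if delay > maxDelay {
			delay = maxDelay
		}
	}
}
```

The nominal waits alone sum to roughly 2.5 minutes; jitter plus the time each attempt spends probing the runtime stretches that toward the ~5 minutes quoted above.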

// StdioTransport implements the Transport interface using standard input/output.
// It acts as a proxy between the MCP client and the container's stdin/stdout.
type StdioTransport struct {
@@ -53,6 +77,25 @@ type StdioTransport struct {

// Container monitor
monitor rt.Monitor

// Retry configuration (overridable in tests)
retryConfig *retryConfig
}

// retryConfig holds configuration for retry behavior
type retryConfig struct {
maxRetries int
initialDelay time.Duration
maxDelay time.Duration
}

// defaultRetryConfig returns the default retry configuration
func defaultRetryConfig() *retryConfig {
return &retryConfig{
maxRetries: defaultMaxRetries,
initialDelay: defaultInitialRetryDelay,
maxDelay: defaultMaxRetryDelay,
}
}

// NewStdioTransport creates a new stdio transport.
@@ -75,6 +118,7 @@ func NewStdioTransport(
prometheusHandler: prometheusHandler,
shutdownCh: make(chan struct{}),
proxyMode: types.ProxyModeSSE, // default to SSE for backward compatibility
retryConfig: defaultRetryConfig(),
}
}

@@ -150,7 +194,7 @@ func (t *StdioTransport) Start(ctx context.Context) error {
defer t.mutex.Unlock()

if t.containerName == "" {
return errors.ErrContainerNameNotSet
return transporterrors.ErrContainerNameNotSet
}

if t.deployer == nil {
@@ -291,8 +335,34 @@ func (t *StdioTransport) IsRunning(_ context.Context) (bool, error) {
}
}

// isDockerSocketError reports whether an error indicates Docker socket unavailability,
// preferring typed error detection over string matching
func isDockerSocketError(err error) bool {
if err == nil {
return false
}

// Check for EOF errors
if errors.Is(err, io.EOF) {
return true
}

// Check for network-related errors
var netErr *net.OpError
if errors.As(err, &netErr) {
// Any network operation error (e.g. connection refused) typically indicates the Docker daemon is unreachable
return true
}

// Fall back to string matching for errors that don't implement standard interfaces.
// This handles Docker SDK errors that may not wrap standard error types.
errStr := err.Error()
return strings.Contains(errStr, "EOF") ||
strings.Contains(errStr, "connection refused") ||
strings.Contains(errStr, "Cannot connect to the Docker daemon")
}
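
For readers less familiar with Go 1.13 error wrapping, a minimal standalone sketch (hypothetical errors, not taken from this PR) of the detection order used above:

```go
package main

import (
	"errors"
	"fmt"
	"io"
	"net"
	"syscall"
)

// Sketch: shows why the typed checks in isDockerSocketError fire even
// when the error is wrapped, before the string fallback is consulted.
func main() {
	// errors.Is walks the wrap chain, so a decorated EOF still matches.
	wrapped := fmt.Errorf("reading container stdout: %w", io.EOF)
	fmt.Println(errors.Is(wrapped, io.EOF)) // true

	// errors.As finds a *net.OpError anywhere in the chain, e.g. a
	// connection-refused dial against the Docker socket.
	dialErr := &net.OpError{Op: "dial", Net: "unix", Err: syscall.ECONNREFUSED}
	chain := fmt.Errorf("docker client: %w", dialErr)
	var opErr *net.OpError
	fmt.Println(errors.As(chain, &opErr)) // true
}
```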

// processMessages handles the message exchange between the client and container.
func (t *StdioTransport) processMessages(ctx context.Context, stdin io.WriteCloser, stdout io.ReadCloser) {
func (t *StdioTransport) processMessages(ctx context.Context, _ io.WriteCloser, stdout io.ReadCloser) {
// Create a context that will be canceled when shutdown is signaled
ctx, cancel := context.WithCancel(ctx)
defer cancel()
@@ -317,15 +387,113 @@ func (t *StdioTransport) processMessages(ctx context.Context, stdin io.WriteClos
case <-ctx.Done():
return
case msg := <-messageCh:
logger.Info("Process incoming messages and sending message to container")
if err := t.sendMessageToContainer(ctx, stdin, msg); err != nil {
logger.Debug("Processing incoming message and sending to container")
// Use t.stdin instead of parameter so it uses the current stdin after re-attachment
t.mutex.Lock()
currentStdin := t.stdin
t.mutex.Unlock()
if err := t.sendMessageToContainer(ctx, currentStdin, msg); err != nil {
logger.Errorf("Error sending message to container: %v", err)
}
logger.Info("Messages processed")
logger.Debug("Message processed")
}
}
}

// attemptReattachment tries to re-attach to a container that has lost its stdout connection.
// Returns true if re-attachment was successful, false otherwise.
func (t *StdioTransport) attemptReattachment(ctx context.Context, stdout io.ReadCloser) bool {
if t.deployer == nil || t.containerName == "" {
return false
}

// Create an exponential backoff with the configured parameters
expBackoff := backoff.NewExponentialBackOff()
expBackoff.InitialInterval = t.retryConfig.initialDelay
expBackoff.MaxInterval = t.retryConfig.maxDelay
// Reset the backoff state so the first wait starts at InitialInterval - total attempts are bounded via WithMaxTries below
expBackoff.Reset()

var attemptCount int
maxRetries := t.retryConfig.maxRetries

operation := func() (any, error) {
attemptCount++

// Check if context is cancelled
select {
case <-ctx.Done():
return nil, backoff.Permanent(ctx.Err())
default:
}

running, checkErr := t.deployer.IsWorkloadRunning(ctx, t.containerName)
if checkErr != nil {
// Check if error is due to Docker being unavailable
if isDockerSocketError(checkErr) {
logger.Warnf("Docker socket unavailable (attempt %d/%d), will retry: %v", attemptCount, maxRetries, checkErr)
return nil, checkErr // Retry
}
logger.Warnf("Error checking if container is running (attempt %d/%d): %v", attemptCount, maxRetries, checkErr)
return nil, checkErr // Retry
}

if !running {
logger.Infof("Container not running (attempt %d/%d)", attemptCount, maxRetries)
return nil, backoff.Permanent(fmt.Errorf("container not running"))
}

logger.Warn("Container is still running after stdout EOF - attempting to re-attach")

// Try to re-attach to the container
newStdin, newStdout, attachErr := t.deployer.AttachToWorkload(ctx, t.containerName)
if attachErr != nil {
logger.Errorf("Failed to re-attach to container (attempt %d/%d): %v", attemptCount, maxRetries, attachErr)
return nil, attachErr // Retry
}

logger.Info("Successfully re-attached to container - restarting message processing")

// Close old stdout and log any errors
if closeErr := stdout.Close(); closeErr != nil {
logger.Warnf("Error closing old stdout during re-attachment: %v", closeErr)
}

// Update stdio references with proper synchronization
t.mutex.Lock()
t.stdin = newStdin
t.stdout = newStdout
t.mutex.Unlock()

// Start ONLY the stdout reader, not the full processMessages
// The existing processMessages goroutine is still running and handling stdin
go t.processStdout(ctx, newStdout)
logger.Info("Restarted stdout processing with new pipe")
return nil, nil // Success
}

// Execute the operation with retry
// Safe conversion: maxRetries is constrained by defaultMaxRetries constant (10)
_, err := backoff.Retry(ctx, operation,
backoff.WithBackOff(expBackoff),
backoff.WithMaxTries(uint(maxRetries)), // #nosec G115
backoff.WithNotify(func(_ error, duration time.Duration) {
logger.Infof("Retry attempt %d/%d after %v", attemptCount+1, maxRetries, duration)
}),
)

if err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
logger.Warnf("Re-attachment cancelled or timed out: %v", err)
} else {
logger.Warn("Failed to re-attach after all retry attempts")
}
return false
}

return true
}
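
Since the success/abort semantics above hinge on backoff.Permanent, here is a minimal runnable sketch (hypothetical operation, using the same v5 API calls as this PR) of the retry control flow: a plain error return schedules another attempt, while an error wrapped in backoff.Permanent aborts immediately.

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/cenkalti/backoff/v5"
)

func main() {
	attempts := 0
	operation := func() (any, error) {
		attempts++
		if attempts < 3 {
			return nil, errors.New("transient") // retried with backoff
		}
		return nil, backoff.Permanent(errors.New("fatal")) // not retried
	}

	expBackoff := backoff.NewExponentialBackOff()
	expBackoff.InitialInterval = 10 * time.Millisecond

	_, err := backoff.Retry(context.Background(), operation,
		backoff.WithBackOff(expBackoff),
		backoff.WithMaxTries(5),
	)
	fmt.Println(attempts, err) // 3 attempts; err is the permanent "fatal" error
}
```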

// processStdout reads from the container's stdout and processes JSON-RPC messages.
func (t *StdioTransport) processStdout(ctx context.Context, stdout io.ReadCloser) {
// Create a buffer for accumulating data
@@ -343,7 +511,14 @@ func (t *StdioTransport) processStdout(ctx context.Context, stdout io.ReadCloser
n, err := stdout.Read(readBuffer)
if err != nil {
if err == io.EOF {
logger.Info("Container stdout closed")
logger.Warn("Container stdout closed - checking if container is still running")

// Try to re-attach to the container
if t.attemptReattachment(ctx, stdout) {
return
}

logger.Info("Container stdout closed - exiting read loop")
} else {
logger.Errorf("Error reading from container stdout: %v", err)
}
@@ -418,11 +593,13 @@ func sanitizeBinaryString(input string) string {
}

// isSpace reports whether r is a space character as defined by JSON.
// These are the valid space characters in this implementation:
// These are the valid space characters in JSON:
// - ' ' (U+0020, SPACE)
// - '\t' (U+0009, HORIZONTAL TAB)
// - '\n' (U+000A, LINE FEED)
// - '\r' (U+000D, CARRIAGE RETURN)
func isSpace(r rune) bool {
return r == ' ' || r == '\n'
return r == ' ' || r == '\t' || r == '\n' || r == '\r'
}
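
The four characters in the fixed isSpace are exactly the insignificant whitespace RFC 8259 permits between JSON tokens (ws = *( %x20 / %x09 / %x0A / %x0D )). A quick sketch contrasting it with the broader unicode.IsSpace, which would be too permissive here:

```go
package main

import (
	"fmt"
	"unicode"
)

func isSpace(r rune) bool {
	return r == ' ' || r == '\t' || r == '\n' || r == '\r'
}

func main() {
	// '\v' and '\f' are Unicode whitespace but NOT legal JSON whitespace.
	for _, r := range " \t\n\r\v\f" {
		fmt.Printf("%q  json:%-5v unicode:%v\n", r, isSpace(r), unicode.IsSpace(r))
	}
}
```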

// parseAndForwardJSONRPC parses a JSON-RPC message and forwards it.
@@ -499,7 +676,7 @@ func (t *StdioTransport) handleContainerExit(ctx context.Context) {
default:
// Transport is still running, stop it
// Create a context with timeout for stopping the transport
stopCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
stopCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
defer cancel()

if stopErr := t.Stop(stopCtx); stopErr != nil {