diff --git a/pkg/transport/stdio.go b/pkg/transport/stdio.go
index 4f2a91493..599413814 100644
--- a/pkg/transport/stdio.go
+++ b/pkg/transport/stdio.go
@@ -1,16 +1,22 @@
+// Package transport provides utilities for handling different transport modes
+// for communication between the client and MCP server, including stdio transport
+// with automatic re-attachment on Docker/container restarts.
 package transport
 
 import (
 	"bytes"
 	"context"
+	"errors"
 	"fmt"
 	"io"
+	"net"
 	"net/http"
 	"strings"
 	"sync"
 	"time"
 	"unicode"
 
+	"github.com/cenkalti/backoff/v5"
 	"golang.org/x/exp/jsonrpc2"
 
 	"github.com/stacklok/toolhive/pkg/container"
@@ -18,12 +24,30 @@ import (
 	"github.com/stacklok/toolhive/pkg/ignore"
 	"github.com/stacklok/toolhive/pkg/logger"
 	"github.com/stacklok/toolhive/pkg/permissions"
-	"github.com/stacklok/toolhive/pkg/transport/errors"
+	transporterrors "github.com/stacklok/toolhive/pkg/transport/errors"
 	"github.com/stacklok/toolhive/pkg/transport/proxy/httpsse"
 	"github.com/stacklok/toolhive/pkg/transport/proxy/streamable"
 	"github.com/stacklok/toolhive/pkg/transport/types"
 )
 
+const (
+	// Retry configuration constants
+	// defaultMaxRetries is the maximum number of re-attachment attempts after a connection loss.
+	// Set to 10 to give Docker/Rancher Desktop time to restart (roughly two to three minutes of cumulative backoff).
+	defaultMaxRetries = 10
+
+	// defaultInitialRetryDelay is the starting delay for exponential backoff.
+	// Starts at 2 seconds to recover quickly from transient issues without overwhelming the system.
+	defaultInitialRetryDelay = 2 * time.Second
+
+	// defaultMaxRetryDelay caps the maximum delay between retry attempts.
+	// Set to 30 seconds to balance responsiveness against resource usage during extended outages.
+	defaultMaxRetryDelay = 30 * time.Second
+
+	// shutdownTimeout is the maximum time to wait for graceful shutdown operations.
+	shutdownTimeout = 30 * time.Second
+)
+
 // StdioTransport implements the Transport interface using standard input/output.
 // It acts as a proxy between the MCP client and the container's stdin/stdout.
 type StdioTransport struct {
@@ -53,6 +77,25 @@ type StdioTransport struct {
 
 	// Container monitor
 	monitor rt.Monitor
+
+	// Retry configuration for re-attachment; tests override this with faster settings
+	retryConfig *retryConfig
+}
+
+// retryConfig holds configuration for retry behavior
+type retryConfig struct {
+	maxRetries   int
+	initialDelay time.Duration
+	maxDelay     time.Duration
+}
+
+// defaultRetryConfig returns the default retry configuration
+func defaultRetryConfig() *retryConfig {
+	return &retryConfig{
+		maxRetries:   defaultMaxRetries,
+		initialDelay: defaultInitialRetryDelay,
+		maxDelay:     defaultMaxRetryDelay,
+	}
 }
 
 // NewStdioTransport creates a new stdio transport.
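Note: the cumulative-backoff estimate in the constant comments above can be sanity-checked. The sketch below is an illustration, not part of the patch; it assumes cenkalti/backoff's documented default multiplier of 1.5 and ignores the library's default ±50% jitter.

```go
// Sketch: nominal wait schedule implied by the retry constants above,
// assuming the backoff library's default multiplier (1.5) and no jitter.
package main

import "fmt"

func main() {
	const (
		initial    = 2.0  // defaultInitialRetryDelay, in seconds
		maxDelay   = 30.0 // defaultMaxRetryDelay, in seconds
		multiplier = 1.5  // cenkalti/backoff default Multiplier (assumed)
		maxRetries = 10   // defaultMaxRetries
	)
	total, delay := 0.0, initial
	// There are up to maxRetries attempts, so up to maxRetries-1 waits between them.
	for attempt := 1; attempt < maxRetries; attempt++ {
		fmt.Printf("wait before attempt %2d: %6.2fs\n", attempt+1, delay)
		total += delay
		delay *= multiplier
		if delay > maxDelay {
			delay = maxDelay // cap at defaultMaxRetryDelay
		}
	}
	// Prints ~124s nominal; ±50% randomization stretches the upper bound to ~3 minutes.
	fmt.Printf("nominal total wait: %.0fs\n", total)
}
```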
@@ -75,6 +118,7 @@ func NewStdioTransport(
 		prometheusHandler: prometheusHandler,
 		shutdownCh:        make(chan struct{}),
 		proxyMode:         types.ProxyModeSSE, // default to SSE for backward compatibility
+		retryConfig:       defaultRetryConfig(),
 	}
 }
 
@@ -150,7 +194,7 @@ func (t *StdioTransport) Start(ctx context.Context) error {
 	defer t.mutex.Unlock()
 
 	if t.containerName == "" {
-		return errors.ErrContainerNameNotSet
+		return transporterrors.ErrContainerNameNotSet
 	}
 
 	if t.deployer == nil {
@@ -291,8 +335,34 @@ func (t *StdioTransport) IsRunning(_ context.Context) (bool, error) {
 	}
 }
 
+// isDockerSocketError checks if an error indicates Docker socket unavailability using typed error detection
+func isDockerSocketError(err error) bool {
+	if err == nil {
+		return false
+	}
+
+	// Check for EOF errors
+	if errors.Is(err, io.EOF) {
+		return true
+	}
+
+	// Check for network-related errors
+	var netErr *net.OpError
+	if errors.As(err, &netErr) {
+		// Any network operation error (e.g. connection refused) suggests the Docker daemon is unreachable
+		return true
+	}
+
+	// Fall back to string matching for errors that don't implement standard interfaces
+	// This handles Docker SDK errors that may not wrap standard error types
+	errStr := err.Error()
+	return strings.Contains(errStr, "EOF") ||
+		strings.Contains(errStr, "connection refused") ||
+		strings.Contains(errStr, "Cannot connect to the Docker daemon")
+}
+
 // processMessages handles the message exchange between the client and container.
-func (t *StdioTransport) processMessages(ctx context.Context, stdin io.WriteCloser, stdout io.ReadCloser) {
+func (t *StdioTransport) processMessages(ctx context.Context, _ io.WriteCloser, stdout io.ReadCloser) {
 	// Create a context that will be canceled when shutdown is signaled
 	ctx, cancel := context.WithCancel(ctx)
 	defer cancel()
@@ -317,15 +387,113 @@
 		case <-ctx.Done():
 			return
 		case msg := <-messageCh:
-			logger.Info("Process incoming messages and sending message to container")
-			if err := t.sendMessageToContainer(ctx, stdin, msg); err != nil {
+			logger.Debug("Processing incoming message and sending to container")
+			// Use t.stdin rather than the parameter so the current stdin is used after re-attachment
+			t.mutex.Lock()
+			currentStdin := t.stdin
+			t.mutex.Unlock()
+			if err := t.sendMessageToContainer(ctx, currentStdin, msg); err != nil {
 				logger.Errorf("Error sending message to container: %v", err)
 			}
-			logger.Info("Messages processed")
+			logger.Debug("Message processed")
 		}
 	}
 }
 
+// attemptReattachment tries to re-attach to a container that has lost its stdout connection.
+// Returns true if re-attachment was successful, false otherwise.
+func (t *StdioTransport) attemptReattachment(ctx context.Context, stdout io.ReadCloser) bool {
+	if t.deployer == nil || t.containerName == "" {
+		return false
+	}
+
+	// Create an exponential backoff with the configured parameters
+	expBackoff := backoff.NewExponentialBackOff()
+	expBackoff.InitialInterval = t.retryConfig.initialDelay
+	expBackoff.MaxInterval = t.retryConfig.maxDelay
+	// Reset the backoff state before use; the number of attempts is bounded via WithMaxTries below
+	expBackoff.Reset()
+
+	var attemptCount int
+	maxRetries := t.retryConfig.maxRetries
+
+	operation := func() (any, error) {
+		attemptCount++
+
+		// Stop immediately if the context has been canceled
+		select {
+		case <-ctx.Done():
+			return nil, backoff.Permanent(ctx.Err())
+		default:
+		}
+
+		running, checkErr := t.deployer.IsWorkloadRunning(ctx, t.containerName)
+		if checkErr != nil {
+			// Check whether the error is due to Docker being unavailable
+			if isDockerSocketError(checkErr) {
+				logger.Warnf("Docker socket unavailable (attempt %d/%d), will retry: %v", attemptCount, maxRetries, checkErr)
+				return nil, checkErr // Retry
+			}
+			logger.Warnf("Error checking if container is running (attempt %d/%d): %v", attemptCount, maxRetries, checkErr)
+			return nil, checkErr // Retry
+		}
+
+		if !running {
+			logger.Infof("Container not running (attempt %d/%d)", attemptCount, maxRetries)
+			return nil, backoff.Permanent(errors.New("container not running"))
+		}
+
+		logger.Warn("Container is still running after stdout EOF - attempting to re-attach")
+
+		// Try to re-attach to the container
+		newStdin, newStdout, attachErr := t.deployer.AttachToWorkload(ctx, t.containerName)
+		if attachErr != nil {
+			logger.Errorf("Failed to re-attach to container (attempt %d/%d): %v", attemptCount, maxRetries, attachErr)
+			return nil, attachErr // Retry
+		}
+
+		logger.Info("Successfully re-attached to container - restarting message processing")
+
+		// Close old stdout and log any errors
+		if closeErr := stdout.Close(); closeErr != nil {
+			logger.Warnf("Error closing old stdout during re-attachment: %v", closeErr)
+		}
+
+		// Update stdio references with proper synchronization
+		t.mutex.Lock()
+		t.stdin = newStdin
+		t.stdout = newStdout
+		t.mutex.Unlock()
+
+		// Start ONLY the stdout reader, not the full processMessages loop:
+		// the existing processMessages goroutine is still running and handling stdin
+		go t.processStdout(ctx, newStdout)
+		logger.Info("Restarted stdout processing with new pipe")
+		return nil, nil // Success
+	}
+
+	// Execute the operation with retry
+	// Safe conversion: maxRetries is a small positive value (10 by default)
+	_, err := backoff.Retry(ctx, operation,
+		backoff.WithBackOff(expBackoff),
+		backoff.WithMaxTries(uint(maxRetries)), // #nosec G115
+		backoff.WithNotify(func(_ error, duration time.Duration) {
+			logger.Infof("Re-attachment attempt %d/%d failed; retrying in %v", attemptCount, maxRetries, duration)
+		}),
+	)
+
+	if err != nil {
+		if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
+			logger.Warnf("Re-attachment canceled or timed out: %v", err)
+		} else {
+			logger.Warnf("Failed to re-attach after %d attempts: %v", maxRetries, err)
+		}
+		return false
+	}
+
+	return true
+}
+
 // processStdout reads from the container's stdout and processes JSON-RPC messages.
func (t *StdioTransport) processStdout(ctx context.Context, stdout io.ReadCloser) { // Create a buffer for accumulating data @@ -343,7 +511,14 @@ func (t *StdioTransport) processStdout(ctx context.Context, stdout io.ReadCloser n, err := stdout.Read(readBuffer) if err != nil { if err == io.EOF { - logger.Info("Container stdout closed") + logger.Warn("Container stdout closed - checking if container is still running") + + // Try to re-attach to the container + if t.attemptReattachment(ctx, stdout) { + return + } + + logger.Info("Container stdout closed - exiting read loop") } else { logger.Errorf("Error reading from container stdout: %v", err) } @@ -418,11 +593,13 @@ func sanitizeBinaryString(input string) string { } // isSpace reports whether r is a space character as defined by JSON. -// These are the valid space characters in this implementation: +// These are the valid space characters in JSON: // - ' ' (U+0020, SPACE) +// - '\t' (U+0009, HORIZONTAL TAB) // - '\n' (U+000A, LINE FEED) +// - '\r' (U+000D, CARRIAGE RETURN) func isSpace(r rune) bool { - return r == ' ' || r == '\n' + return r == ' ' || r == '\t' || r == '\n' || r == '\r' } // parseAndForwardJSONRPC parses a JSON-RPC message and forwards it. @@ -499,7 +676,7 @@ func (t *StdioTransport) handleContainerExit(ctx context.Context) { default: // Transport is still running, stop it // Create a context with timeout for stopping the transport - stopCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + stopCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) defer cancel() if stopErr := t.Stop(stopCtx); stopErr != nil { diff --git a/pkg/transport/stdio_test.go b/pkg/transport/stdio_test.go index 9bd83f78f..a1acc3661 100644 --- a/pkg/transport/stdio_test.go +++ b/pkg/transport/stdio_test.go @@ -1,14 +1,22 @@ package transport import ( + "bytes" "context" + "errors" "fmt" + "io" + "sync" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" "golang.org/x/exp/jsonrpc2" + "github.com/stacklok/toolhive/pkg/container/runtime/mocks" "github.com/stacklok/toolhive/pkg/logger" ) @@ -184,12 +192,12 @@ func TestIsSpace(t *testing.T) { { name: "tab character", input: '\t', - expected: false, + expected: true, }, { name: "carriage return", input: '\r', - expected: false, + expected: true, }, { name: "regular character", @@ -216,3 +224,764 @@ func TestIsSpace(t *testing.T) { }) } } + +// mockReadCloser is a mock implementation of io.ReadCloser for testing +type mockReadCloser struct { + data []byte + readIndex int + closed bool + eofAfter int // return EOF after this many reads + readCount int +} + +//nolint:unparam // test helper designed to be flexible +func newMockReadCloser(data string) *mockReadCloser { + return &mockReadCloser{ + data: []byte(data), + eofAfter: -1, // Never EOF by default + } +} + +func newMockReadCloserWithEOF(data string) *mockReadCloser { + return &mockReadCloser{ + data: []byte(data), + eofAfter: 1, // Always EOF after first read for these tests + } +} + +func (m *mockReadCloser) Read(p []byte) (n int, err error) { + m.readCount++ + if m.eofAfter >= 0 && m.readCount > m.eofAfter { + return 0, io.EOF + } + + if m.closed { + return 0, errors.New("read from closed reader") + } + + if m.readIndex >= len(m.data) { + // If eofAfter is set, return EOF + if m.eofAfter >= 0 { + return 0, io.EOF + } + // Otherwise, block until closed + time.Sleep(10 * time.Millisecond) + return 0, nil + } + 
+ n = copy(p, m.data[m.readIndex:]) + m.readIndex += n + return n, nil +} + +func (m *mockReadCloser) Close() error { + m.closed = true + return nil +} + +// mockWriteCloser is a mock implementation of io.WriteCloser for testing +type mockWriteCloser struct { + buffer bytes.Buffer + closed bool +} + +func newMockWriteCloser() *mockWriteCloser { + return &mockWriteCloser{} +} + +func (m *mockWriteCloser) Write(p []byte) (n int, err error) { + if m.closed { + return 0, errors.New("write to closed writer") + } + return m.buffer.Write(p) +} + +func (m *mockWriteCloser) Close() error { + m.closed = true + return nil +} + +// testRetryConfig returns a fast retry configuration for testing +func testRetryConfig() *retryConfig { + return &retryConfig{ + maxRetries: 3, + initialDelay: 10 * time.Millisecond, + maxDelay: 50 * time.Millisecond, + } +} + +func TestProcessStdout_EOFWithSuccessfulReattachment(t *testing.T) { + t.Parallel() + + // Initialize logger + logger.Initialize() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + // Create mock deployer + mockDeployer := mocks.NewMockRuntime(ctrl) + + // Create mock stdout that will return EOF after first read + mockStdout := newMockReadCloserWithEOF(`{"jsonrpc": "2.0", "method": "test", "params": {}}`) + + // Create new stdio streams for re-attachment + newStdin := newMockWriteCloser() + newStdout := newMockReadCloser(`{"jsonrpc": "2.0", "method": "test2", "params": {}}`) + + // Set up expectations + mockDeployer.EXPECT(). + IsWorkloadRunning(gomock.Any(), "test-container"). + Return(true, nil). + Times(1) + + mockDeployer.EXPECT(). + AttachToWorkload(gomock.Any(), "test-container"). + Return(newStdin, newStdout, nil). + Times(1) + + // Create mock HTTP proxy + mockProxy := new(MockHTTPProxy) + mockProxy.On("ForwardResponseToClients", mock.Anything, mock.Anything).Return(nil).Maybe() + + // Create transport with fast retry config for testing + transport := &StdioTransport{ + containerName: "test-container", + deployer: mockDeployer, + httpProxy: mockProxy, + stdin: newMockWriteCloser(), + shutdownCh: make(chan struct{}), + retryConfig: testRetryConfig(), + } + + // Run processStdout in a goroutine + done := make(chan struct{}) + go func() { + transport.processStdout(ctx, mockStdout) + close(done) + }() + + // Wait for completion or timeout + select { + case <-done: + // Success - processStdout returned + case <-time.After(1 * time.Second): + t.Fatal("Test timed out waiting for processStdout to complete") + } + + // Verify that stdin and stdout were updated + transport.mutex.Lock() + assert.Equal(t, newStdin, transport.stdin) + assert.Equal(t, newStdout, transport.stdout) + transport.mutex.Unlock() +} + +func TestProcessStdout_EOFWithDockerUnavailable(t *testing.T) { + t.Parallel() + + // Initialize logger + logger.Initialize() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + // Create mock deployer + mockDeployer := mocks.NewMockRuntime(ctrl) + + // Create mock stdout that will return EOF + mockStdout := newMockReadCloserWithEOF(`{"jsonrpc": "2.0", "method": "test", "params": {}}`) + + // Simulate Docker being unavailable on first check, then available + callCount := 0 + mockDeployer.EXPECT(). + IsWorkloadRunning(gomock.Any(), "test-container"). 
+ DoAndReturn(func(_ context.Context, _ string) (bool, error) { + callCount++ + if callCount == 1 { + // First call: Docker socket unavailable + return false, errors.New("EOF") + } + // Second call: Docker is back, container is running + return true, nil + }). + MinTimes(2) + + // Create new stdio streams for re-attachment + newStdin := newMockWriteCloser() + newStdout := newMockReadCloser(`{"jsonrpc": "2.0", "method": "test2", "params": {}}`) + + mockDeployer.EXPECT(). + AttachToWorkload(gomock.Any(), "test-container"). + Return(newStdin, newStdout, nil). + Times(1) + + // Create mock HTTP proxy + mockProxy := new(MockHTTPProxy) + mockProxy.On("ForwardResponseToClients", mock.Anything, mock.Anything).Return(nil).Maybe() + + // Create transport with fast retry config for testing + transport := &StdioTransport{ + containerName: "test-container", + deployer: mockDeployer, + httpProxy: mockProxy, + stdin: newMockWriteCloser(), + shutdownCh: make(chan struct{}), + retryConfig: testRetryConfig(), + } + + // Run processStdout in a goroutine + done := make(chan struct{}) + go func() { + transport.processStdout(ctx, mockStdout) + close(done) + }() + + // Wait for completion or timeout + select { + case <-done: + // Success - processStdout returned + case <-time.After(1 * time.Second): + t.Fatal("Test timed out waiting for processStdout to handle Docker restart") + } + + // Verify that stdin and stdout were updated after re-attachment + transport.mutex.Lock() + assert.Equal(t, newStdin, transport.stdin) + assert.Equal(t, newStdout, transport.stdout) + transport.mutex.Unlock() +} + +func TestProcessStdout_EOFWithContainerNotRunning(t *testing.T) { + t.Parallel() + + // Initialize logger + logger.Initialize() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + // Create mock deployer + mockDeployer := mocks.NewMockRuntime(ctrl) + + // Create mock stdout that will return EOF + mockStdout := newMockReadCloserWithEOF(`{"jsonrpc": "2.0", "method": "test", "params": {}}`) + + // Set up expectations - container is not running + mockDeployer.EXPECT(). + IsWorkloadRunning(gomock.Any(), "test-container"). + Return(false, nil). 
+ Times(1) + + // Create mock HTTP proxy + mockProxy := new(MockHTTPProxy) + mockProxy.On("ForwardResponseToClients", mock.Anything, mock.Anything).Return(nil).Maybe() + + // Create transport with fast retry config for testing + transport := &StdioTransport{ + containerName: "test-container", + deployer: mockDeployer, + httpProxy: mockProxy, + stdin: newMockWriteCloser(), + shutdownCh: make(chan struct{}), + retryConfig: testRetryConfig(), + } + + // Run processStdout in a goroutine + done := make(chan struct{}) + go func() { + transport.processStdout(ctx, mockStdout) + close(done) + }() + + // Wait for completion or timeout + select { + case <-done: + // Success - processStdout returned + case <-time.After(500 * time.Millisecond): + t.Fatal("Test timed out") + } +} + +func TestProcessStdout_EOFWithFailedReattachment(t *testing.T) { + t.Parallel() + + // Initialize logger + logger.Initialize() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Use shorter timeout now that we have fast retries + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + // Create mock deployer + mockDeployer := mocks.NewMockRuntime(ctrl) + + // Create mock stdout that will return EOF + mockStdout := newMockReadCloserWithEOF(`{"jsonrpc": "2.0", "method": "test", "params": {}}`) + + retryCount := 0 + // Set up expectations - container is running but re-attachment fails + mockDeployer.EXPECT(). + IsWorkloadRunning(gomock.Any(), "test-container"). + DoAndReturn(func(_ context.Context, _ string) (bool, error) { + retryCount++ + return true, nil + }). + AnyTimes() + + mockDeployer.EXPECT(). + AttachToWorkload(gomock.Any(), "test-container"). + Return(nil, nil, errors.New("failed to attach")). + AnyTimes() + + // Create mock HTTP proxy + mockProxy := new(MockHTTPProxy) + mockProxy.On("ForwardResponseToClients", mock.Anything, mock.Anything).Return(nil).Maybe() + + // Create transport with fast retry config for testing + transport := &StdioTransport{ + containerName: "test-container", + deployer: mockDeployer, + httpProxy: mockProxy, + stdin: newMockWriteCloser(), + shutdownCh: make(chan struct{}), + retryConfig: testRetryConfig(), + } + + // Store original stdin/stdout + originalStdin := transport.stdin + + // Run processStdout in a goroutine + done := make(chan struct{}) + go func() { + transport.processStdout(ctx, mockStdout) + close(done) + }() + + // Wait for completion + select { + case <-done: + // Success - processStdout returned + case <-time.After(1 * time.Second): + t.Fatal("Test timed out waiting for context timeout") + } + + // Verify that we attempted at least one retry + assert.GreaterOrEqual(t, retryCount, 1, "Expected at least 1 retry attempt") + + // Verify that stdin/stdout were NOT updated since re-attachment failed + transport.mutex.Lock() + assert.Equal(t, originalStdin, transport.stdin) + transport.mutex.Unlock() +} + +func TestProcessStdout_EOFWithReattachmentRetryLogic(t *testing.T) { + t.Parallel() + + // Initialize logger + logger.Initialize() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + // Create mock deployer + mockDeployer := mocks.NewMockRuntime(ctrl) + + // Create mock stdout that will return EOF + mockStdout := newMockReadCloserWithEOF(`{"jsonrpc": "2.0", "method": "test", "params": {}}`) + + // Track retry attempts + attemptCount := 0 + + // Set up expectations - fail first 2 attempts, succeed on 3rd + 
+	mockDeployer.EXPECT().
+		IsWorkloadRunning(gomock.Any(), "test-container").
+		DoAndReturn(func(_ context.Context, _ string) (bool, error) {
+			attemptCount++
+			if attemptCount <= 2 {
+				// First 2 attempts: connection refused (Docker restarting)
+				return false, errors.New("connection refused")
+			}
+			// Third attempt: success
+			return true, nil
+		}).
+		MinTimes(3)
+
+	// Create new stdio streams for successful re-attachment
+	newStdin := newMockWriteCloser()
+	newStdout := newMockReadCloser(`{"jsonrpc": "2.0", "method": "test2", "params": {}}`)
+
+	mockDeployer.EXPECT().
+		AttachToWorkload(gomock.Any(), "test-container").
+		Return(newStdin, newStdout, nil).
+		Times(1)
+
+	// Create mock HTTP proxy
+	mockProxy := new(MockHTTPProxy)
+	mockProxy.On("ForwardResponseToClients", mock.Anything, mock.Anything).Return(nil).Maybe()
+
+	// Create transport with fast retry config for testing
+	transport := &StdioTransport{
+		containerName: "test-container",
+		deployer:      mockDeployer,
+		httpProxy:     mockProxy,
+		stdin:         newMockWriteCloser(),
+		shutdownCh:    make(chan struct{}),
+		retryConfig:   testRetryConfig(),
+	}
+
+	// Run processStdout in a goroutine
+	done := make(chan struct{})
+	go func() {
+		transport.processStdout(ctx, mockStdout)
+		close(done)
+	}()
+
+	// Wait for completion
+	select {
+	case <-done:
+		// Success - processStdout returned after retries
+	case <-time.After(1 * time.Second):
+		t.Fatal("Test timed out waiting for retry logic to complete")
+	}
+
+	// Verify that we had multiple retry attempts
+	require.GreaterOrEqual(t, attemptCount, 3, "Expected at least 3 retry attempts")
+
+	// Verify that stdin and stdout were eventually updated
+	transport.mutex.Lock()
+	assert.Equal(t, newStdin, transport.stdin)
+	assert.Equal(t, newStdout, transport.stdout)
+	transport.mutex.Unlock()
+}
+
+func TestProcessStdout_EOFCheckErrorTypes(t *testing.T) {
+	t.Parallel()
+
+	// Initialize logger
+	logger.Initialize()
+
+	tests := []struct {
+		name           string
+		checkError     error
+		contextTimeout time.Duration
+	}{
+		{
+			name:           "Docker socket EOF error triggers retry",
+			checkError:     errors.New("EOF"),
+			contextTimeout: 500 * time.Millisecond,
+		},
+		{
+			name:           "Connection refused triggers retry",
+			checkError:     errors.New("connection refused"),
+			contextTimeout: 500 * time.Millisecond,
+		},
+		{
+			name:           "Other errors still retry",
+			checkError:     errors.New("some other error"),
+			contextTimeout: 500 * time.Millisecond,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+			ctrl := gomock.NewController(t)
+			defer ctrl.Finish()
+
+			ctx, cancel := context.WithTimeout(context.Background(), tt.contextTimeout)
+			defer cancel()
+
+			// Create mock deployer
+			mockDeployer := mocks.NewMockRuntime(ctrl)
+
+			// Create mock stdout that will return EOF
+			mockStdout := newMockReadCloserWithEOF(`{"jsonrpc": "2.0", "method": "test"}`)
+
+			// Track how many times IsWorkloadRunning is called
+			callCount := 0
+
+			// Set up expectations - allow unlimited calls since we're testing retry behavior
+			mockDeployer.EXPECT().
+				IsWorkloadRunning(gomock.Any(), "test-container").
+				DoAndReturn(func(_ context.Context, _ string) (bool, error) {
+					callCount++
+					return false, tt.checkError
+				}).
+ AnyTimes() + + // Create mock HTTP proxy + mockProxy := new(MockHTTPProxy) + mockProxy.On("ForwardResponseToClients", mock.Anything, mock.Anything).Return(nil).Maybe() + + // Create transport with fast retry config for testing + transport := &StdioTransport{ + containerName: "test-container", + deployer: mockDeployer, + httpProxy: mockProxy, + stdin: newMockWriteCloser(), + shutdownCh: make(chan struct{}), + retryConfig: testRetryConfig(), + } + + // Run processStdout in a goroutine + done := make(chan struct{}) + go func() { + transport.processStdout(ctx, mockStdout) + close(done) + }() + + // Wait for completion + select { + case <-done: + // Success + case <-time.After(tt.contextTimeout + 500*time.Millisecond): + t.Fatal("Test timed out") + } + + // Verify we got at least one retry attempt + assert.GreaterOrEqual(t, callCount, 1, "Expected at least 1 retry attempt") + }) + } +} + +func TestConcurrentReattachment(t *testing.T) { + t.Parallel() + + // Initialize logger + logger.Initialize() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + // Create mock deployer + mockDeployer := mocks.NewMockRuntime(ctrl) + + // Create new stdio streams for re-attachment + newStdin := newMockWriteCloser() + newStdout := newMockReadCloser(`{"jsonrpc": "2.0", "method": "test2", "params": {}}`) + + // Track how many times IsWorkloadRunning is called + var workloadCheckCount int + workloadCheckMutex := sync.Mutex{} + + // Set up expectations - container is running + mockDeployer.EXPECT(). + IsWorkloadRunning(gomock.Any(), "test-container"). + DoAndReturn(func(_ context.Context, _ string) (bool, error) { + workloadCheckMutex.Lock() + workloadCheckCount++ + workloadCheckMutex.Unlock() + return true, nil + }). + AnyTimes() + + // Track how many times AttachToWorkload is called + var attachCount int + attachMutex := sync.Mutex{} + + mockDeployer.EXPECT(). + AttachToWorkload(gomock.Any(), "test-container"). + DoAndReturn(func(_ context.Context, _ string) (io.WriteCloser, io.ReadCloser, error) { + attachMutex.Lock() + attachCount++ + count := attachCount + attachMutex.Unlock() + + // Only succeed on the first call, fail subsequent concurrent calls + if count == 1 { + return newStdin, newStdout, nil + } + return nil, nil, errors.New("concurrent attachment in progress") + }). 
+ AnyTimes() + + // Create mock HTTP proxy + mockProxy := new(MockHTTPProxy) + mockProxy.On("ForwardResponseToClients", mock.Anything, mock.Anything).Return(nil).Maybe() + + // Create transport with fast retry config for testing + transport := &StdioTransport{ + containerName: "test-container", + deployer: mockDeployer, + httpProxy: mockProxy, + stdin: newMockWriteCloser(), + shutdownCh: make(chan struct{}), + retryConfig: testRetryConfig(), + } + + // Run processStdout in multiple goroutines to simulate concurrent re-attachment attempts + var wg sync.WaitGroup + for i := 0; i < 3; i++ { + wg.Add(1) + go func(index int) { + defer wg.Done() + // Each goroutine creates its own mock stdout that returns EOF + localStdout := newMockReadCloserWithEOF(fmt.Sprintf(`{"jsonrpc": "2.0", "method": "test%d", "params": {}}`, index)) + transport.processStdout(ctx, localStdout) + }(i) + } + + // Wait for all goroutines to complete + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + // Wait for completion or timeout + select { + case <-done: + // Success - all processStdout goroutines returned + case <-time.After(2 * time.Second): + t.Fatal("Test timed out waiting for concurrent re-attachment attempts") + } + + // Verify that stdin and stdout were updated + transport.mutex.Lock() + finalStdin := transport.stdin + finalStdout := transport.stdout + transport.mutex.Unlock() + + // Check that the transport was updated (at least one re-attachment succeeded) + assert.NotNil(t, finalStdin) + assert.NotNil(t, finalStdout) + + // Verify that multiple checks were made but only one successful attachment + workloadCheckMutex.Lock() + assert.GreaterOrEqual(t, workloadCheckCount, 1, "Expected at least 1 workload check") + workloadCheckMutex.Unlock() + + attachMutex.Lock() + // We expect at least one successful attachment + assert.GreaterOrEqual(t, attachCount, 1, "Expected at least 1 attachment attempt") + attachMutex.Unlock() +} + +func TestStdinRaceCondition(t *testing.T) { + t.Parallel() + + // Initialize logger + logger.Initialize() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + // Create mock deployer + mockDeployer := mocks.NewMockRuntime(ctrl) + + // Create initial stdin/stdout + initialStdin := newMockWriteCloser() + mockStdout := newMockReadCloserWithEOF(`{"jsonrpc": "2.0", "method": "test", "params": {}}`) + + // Create new stdio streams for re-attachment + newStdin := newMockWriteCloser() + newStdout := newMockReadCloser(`{"jsonrpc": "2.0", "method": "test2", "params": {}}`) + + // Set up expectations + mockDeployer.EXPECT(). + IsWorkloadRunning(gomock.Any(), "test-container"). + Return(true, nil). + AnyTimes() + + var attachCalled bool + mockDeployer.EXPECT(). + AttachToWorkload(gomock.Any(), "test-container"). + DoAndReturn(func(_ context.Context, _ string) (io.WriteCloser, io.ReadCloser, error) { + if attachCalled { + return nil, nil, errors.New("already attached") + } + attachCalled = true + // Add a small delay to increase chance of race condition + time.Sleep(10 * time.Millisecond) + return newStdin, newStdout, nil + }). 
+		AnyTimes()
+
+	// Create mock HTTP proxy with message channel
+	mockProxy := new(MockHTTPProxy)
+	mockProxy.On("ForwardResponseToClients", mock.Anything, mock.Anything).Return(nil).Maybe()
+
+	messageCh := make(chan jsonrpc2.Message, 10)
+	mockProxy.On("GetMessageChannel").Return(messageCh)
+
+	// Create transport with fast retry config for testing
+	transport := &StdioTransport{
+		containerName: "test-container",
+		deployer:      mockDeployer,
+		httpProxy:     mockProxy,
+		stdin:         initialStdin,
+		shutdownCh:    make(chan struct{}),
+		retryConfig:   testRetryConfig(),
+	}
+
+	// Start processMessages which will handle incoming messages
+	go transport.processMessages(ctx, initialStdin, mockStdout)
+
+	// Start processStdout which will trigger re-attachment
+	go transport.processStdout(ctx, mockStdout)
+
+	// Send messages concurrently while re-attachment is happening
+	var wg sync.WaitGroup
+	for i := 0; i < 10; i++ {
+		wg.Add(1)
+		go func(index int) {
+			defer wg.Done()
+			// Create a test message
+			msg, err := jsonrpc2.NewCall(jsonrpc2.StringID(fmt.Sprintf("msg-%d", index)), "test.method", nil)
+			if err != nil {
+				return
+			}
+			select {
+			case messageCh <- msg:
+				// Message sent successfully
+			case <-ctx.Done():
+				// Context canceled
+			case <-time.After(100 * time.Millisecond):
+				// Timeout
+			}
+		}(i)
+	}
+
+	// Wait for all messages to be sent
+	wg.Wait()
+
+	// Give some time for re-attachment to complete
+	time.Sleep(200 * time.Millisecond)
+
+	// Verify that stdin was updated safely
+	transport.mutex.Lock()
+	finalStdin := transport.stdin
+	transport.mutex.Unlock()
+
+	// Re-attachment timing is nondeterministic here, so we can't assert which
+	// writer is installed; just verify that stdin remains valid
+	assert.NotNil(t, finalStdin, "stdin should not be nil after re-attachment")
+
+	// Clean up
+	cancel()
+}
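Two illustrative sketches for context; neither is part of the patch. First, the retry pattern that attemptReattachment relies on, reduced to a standalone program. It assumes the cenkalti/backoff/v5 API used in the diff (Retry, Permanent, WithBackOff, WithMaxTries, WithNotify): transient errors are returned as-is and retried, while a backoff.Permanent error stops the loop immediately, which is how attemptReattachment bails out when the container is gone.

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/cenkalti/backoff/v5"
)

func main() {
	expBackoff := backoff.NewExponentialBackOff()
	expBackoff.InitialInterval = 10 * time.Millisecond
	expBackoff.MaxInterval = 50 * time.Millisecond

	attempts := 0
	operation := func() (string, error) {
		attempts++
		if attempts < 3 {
			// Transient failure: returned as-is, so the retry loop continues.
			return "", errors.New("connection refused")
		}
		// Returning backoff.Permanent(err) here instead would stop retrying at once.
		return "attached", nil
	}

	result, err := backoff.Retry(context.Background(), operation,
		backoff.WithBackOff(expBackoff),
		backoff.WithMaxTries(5),
		backoff.WithNotify(func(err error, next time.Duration) {
			fmt.Printf("attempt %d failed (%v); next try in %v\n", attempts, err, next)
		}),
	)
	fmt.Println(result, err) // "attached" <nil> after two retried failures
}
```

Second, isDockerSocketError is only exercised indirectly through the retry tests above. A direct table-driven test in the same package could look like the following hypothetical sketch, matching the test style already used in the file (it assumes "net" is added to the test file's imports alongside the existing errors, fmt, and io):

```go
func TestIsDockerSocketError(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name     string
		err      error
		expected bool
	}{
		{name: "nil error", err: nil, expected: false},
		{name: "bare io.EOF", err: io.EOF, expected: true},
		{name: "wrapped io.EOF", err: fmt.Errorf("read: %w", io.EOF), expected: true},
		{
			name:     "net.OpError (daemon unreachable)",
			err:      &net.OpError{Op: "dial", Net: "unix", Err: errors.New("connect: connection refused")},
			expected: true,
		},
		{
			name:     "docker daemon message",
			err:      errors.New("Cannot connect to the Docker daemon at unix:///var/run/docker.sock"),
			expected: true,
		},
		{name: "unrelated error", err: errors.New("image not found"), expected: false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			assert.Equal(t, tt.expected, isDockerSocketError(tt.err))
		})
	}
}
```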