diff --git a/internal/pool/dial_conn_retry_test.go b/internal/pool/dial_conn_retry_test.go new file mode 100644 index 000000000..61eba5b05 --- /dev/null +++ b/internal/pool/dial_conn_retry_test.go @@ -0,0 +1,208 @@ +package pool + +import ( + "context" + "errors" + "net" + "sync/atomic" + "testing" + "time" +) + +func TestDialConn_HangingDial_RetriesWithPerAttemptTimeout(t *testing.T) { + var calls atomic.Int32 + var sawDeadline atomic.Int32 + + const ( + dialTimeout = 50 * time.Millisecond + backoff = 10 * time.Millisecond + retries = 3 + ) + + p := NewConnPool(&Options{ + Dialer: func(ctx context.Context) (net.Conn, error) { + calls.Add(1) + + // Ensure each attempt has a deadline (pool applies DialTimeout per attempt). + if dl, ok := ctx.Deadline(); ok { + rem := time.Until(dl) + // Very generous bounds to avoid flakes. + if rem > 5*time.Millisecond && rem <= 2*dialTimeout { + sawDeadline.Add(1) + } + } + + // Simulate a TCP connect hang: block until the context cancels. + <-ctx.Done() + return nil, ctx.Err() + }, + PoolSize: 1, + MaxConcurrentDials: 1, + DialTimeout: dialTimeout, + DialerRetries: retries, + DialerRetryTimeout: backoff, + }) + defer p.Close() + + // Use a parent context with a bounded timeout so this test fails fast (instead of hanging) + // when dialConn does not apply per-attempt DialTimeout via context. + parentBudget := dialTimeout*time.Duration(retries) + backoff*time.Duration(retries-1) + 250*time.Millisecond + ctx, cancel := context.WithTimeout(context.Background(), parentBudget) + defer cancel() + + start := time.Now() + _, err := p.dialConn(ctx, true) + elapsed := time.Since(start) + + if err == nil { + t.Fatalf("expected error") + } + if got := calls.Load(); got != retries { + t.Fatalf("expected %d dial attempts, got %d", retries, got) + } + if got := sawDeadline.Load(); got != retries { + t.Fatalf("expected deadline on all attempts, got %d/%d", got, retries) + } + + // Each attempt should wait ~dialTimeout, plus backoff between attempts. + // Allow wide bounds for CI noise. + min := dialTimeout*time.Duration(retries) + backoff*time.Duration(retries-1) + if elapsed < min/2 { + t.Fatalf("dialConn returned too quickly (%v < %v), retries/backoff may not have occurred", elapsed, min/2) + } + if elapsed > 5*min { + t.Fatalf("dialConn took too long (%v > %v), likely hung beyond expected timeouts", elapsed, 5*min) + } +} + +func TestDialConn_DoesNotExtendEarlierParentDeadline(t *testing.T) { + var calls atomic.Int32 + + p := NewConnPool(&Options{ + Dialer: func(ctx context.Context) (net.Conn, error) { + calls.Add(1) + dl, ok := ctx.Deadline() + if !ok { + return nil, errors.New("expected deadline") + } + // Parent deadline should win (be soon). + if time.Until(dl) > 100*time.Millisecond { + return nil, errors.New("deadline was unexpectedly extended") + } + <-ctx.Done() + return nil, ctx.Err() + }, + PoolSize: 1, + MaxConcurrentDials: 1, + DialTimeout: 500 * time.Millisecond, + DialerRetries: 1, + }) + defer p.Close() + + parent, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond) + defer cancel() + + _, err := p.dialConn(parent, true) + if err == nil { + t.Fatalf("expected error") + } + if got := calls.Load(); got != 1 { + t.Fatalf("expected 1 dial attempt, got %d", got) + } +} + +func TestDialConn_ContextCancelStopsFurtherRetries(t *testing.T) { + var calls atomic.Int32 + + p := NewConnPool(&Options{ + Dialer: func(ctx context.Context) (net.Conn, error) { + n := calls.Add(1) + if n == 1 { + // First attempt fails immediately; test cancels context to stop retries. + return nil, errors.New("dial failed") + } + return nil, errors.New("unexpected extra attempt") + }, + PoolSize: 1, + MaxConcurrentDials: 1, + DialTimeout: 5 * time.Second, + DialerRetries: 5, + DialerRetryTimeout: 5 * time.Second, + }) + defer p.Close() + + ctx, cancel := context.WithCancel(context.Background()) + // Cancel immediately after the first attempt fails by wrapping dialConn call. + // We do it via a goroutine so dialConn has a chance to enter the backoff select. + go func() { + // Give dialConn a moment to start. This avoids a race where ctx is already canceled + // before the first attempt and we wouldn't be testing the retry stop path. + time.Sleep(5 * time.Millisecond) + cancel() + }() + + _, _ = p.dialConn(ctx, true) + + if got := calls.Load(); got != 1 { + t.Fatalf("expected dialer to be called once after cancel, got %d", got) + } +} + +func TestDialConn_DialTimeoutDisabled_DoesNotSetDeadline(t *testing.T) { + var calls atomic.Int32 + + p := NewConnPool(&Options{ + Dialer: func(ctx context.Context) (net.Conn, error) { + calls.Add(1) + if _, ok := ctx.Deadline(); ok { + return nil, errors.New("unexpected deadline when DialTimeout disabled") + } + return nil, errors.New("dial failed") + }, + PoolSize: 1, + MaxConcurrentDials: 1, + DialTimeout: 0, + DialerRetries: 1, + }) + defer p.Close() + + _, err := p.dialConn(context.Background(), true) + if err == nil { + t.Fatalf("expected error") + } + if got := calls.Load(); got != 1 { + t.Fatalf("expected 1 dial attempt, got %d", got) + } +} + +func TestDialConn_NoBackoffAfterLastAttempt(t *testing.T) { + var calls atomic.Int32 + backoff := 300 * time.Millisecond + + p := NewConnPool(&Options{ + Dialer: func(ctx context.Context) (net.Conn, error) { + calls.Add(1) + return nil, errors.New("dial failed") + }, + PoolSize: 1, + MaxConcurrentDials: 1, + DialTimeout: 5 * time.Second, + DialerRetries: 1, + DialerRetryTimeout: backoff, + }) + defer p.Close() + + start := time.Now() + _, err := p.dialConn(context.Background(), true) + elapsed := time.Since(start) + if err == nil { + t.Fatalf("expected error") + } + if got := calls.Load(); got != 1 { + t.Fatalf("expected 1 dial attempt, got %d", got) + } + // If we slept after the last attempt, this will be ~backoff. + if elapsed >= backoff/2 { + t.Fatalf("dialConn took too long (%v); likely slept after last attempt (backoff=%v)", elapsed, backoff) + } +} diff --git a/internal/pool/dial_context_timeout_test.go b/internal/pool/dial_context_timeout_test.go new file mode 100644 index 000000000..3c5e736ae --- /dev/null +++ b/internal/pool/dial_context_timeout_test.go @@ -0,0 +1,40 @@ +package pool + +import ( + "context" + "errors" + "net" + "testing" + "time" +) + +// Ensures ConnPool applies DialTimeout per attempt via context (so dialing doesn't hang +// when a custom dialer ignores timeouts). +func TestDialConn_AppliesDialTimeoutPerAttemptViaContext(t *testing.T) { + p := NewConnPool(&Options{ + Dialer: func(ctx context.Context) (net.Conn, error) { + // Pool should apply DialTimeout per attempt via context. + dl, ok := ctx.Deadline() + if !ok { + return nil, errors.New("expected context deadline") + } + remaining := time.Until(dl) + // Allow slack for scheduling jitter. + if remaining <= 50*time.Millisecond || remaining > 250*time.Millisecond { + return nil, errors.New("unexpected context deadline duration") + } + return nil, errors.New("dial failed") + }, + PoolSize: 1, + MaxConcurrentDials: 1, + DialTimeout: 200 * time.Millisecond, + PoolTimeout: 10 * time.Millisecond, + DialerRetries: 1, + }) + defer p.Close() + + _, err := p.newConn(context.Background(), true) + if err == nil { + t.Fatalf("expected error") + } +} diff --git a/internal/pool/pool.go b/internal/pool/pool.go index aaca530cb..d7d317954 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -456,10 +456,8 @@ func (p *ConnPool) checkMinIdleConns() { } func (p *ConnPool) addIdleConn() error { - ctx, cancel := context.WithTimeout(context.Background(), p.cfg.DialTimeout) - defer cancel() - - cn, err := p.dialConn(ctx, true) + // Do not apply DialTimeout via context here; dialConn applies DialTimeout per attempt. + cn, err := p.dialConn(context.Background(), true) if err != nil { return err } @@ -505,9 +503,9 @@ func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) { ctx = context.Background() } - dialCtx, cancel := context.WithTimeout(ctx, p.cfg.DialTimeout) - defer cancel() - cn, err := p.dialConn(dialCtx, pooled) + // Do not apply DialTimeout via context here; dialConn applies DialTimeout per attempt. + // We still propagate ctx so callers can cancel explicitly. + cn, err := p.dialConn(ctx, pooled) if err != nil { return nil, err } @@ -569,8 +567,8 @@ func (p *ConnPool) dialConn(ctx context.Context, pooled bool) (*Conn, error) { } // Retry dialing with backoff - // the context timeout is already handled by the context passed in - // so we may never reach the max retries, higher values don't hurt + // Dial timeout is applied per attempt (so retries/backoff don't eat into the next + // attempt's dial budget), while still honoring caller cancellation via ctx. maxRetries := p.cfg.DialerRetries if maxRetries <= 0 { maxRetries = 5 // Default value @@ -587,16 +585,31 @@ func (p *ConnPool) dialConn(ctx context.Context, pooled bool) (*Conn, error) { // instead of a generic context deadline exceeded error attempt := 0 for attempt = 0; (attempt < maxRetries) && shouldLoop; attempt++ { - netConn, err := p.cfg.Dialer(ctx) + attemptCtx := ctx + var cancel context.CancelFunc + if p.cfg.DialTimeout > 0 { + // Apply DialTimeout per attempt, but never extend an existing earlier deadline. + if deadline, ok := ctx.Deadline(); !ok || time.Until(deadline) > p.cfg.DialTimeout { + attemptCtx, cancel = context.WithTimeout(ctx, p.cfg.DialTimeout) + } + } + + netConn, err := p.cfg.Dialer(attemptCtx) + if cancel != nil { + cancel() + } if err != nil { lastErr = err // Add backoff delay for retry attempts // (not for the first attempt, do at least one) - select { - case <-ctx.Done(): - shouldLoop = false - case <-time.After(backoffDuration): - // Continue with retry + // Do not sleep after the last attempt. + if attempt+1 < maxRetries { + select { + case <-ctx.Done(): + shouldLoop = false + case <-time.After(backoffDuration): + // Continue with retry + } } continue } @@ -648,19 +661,26 @@ func (p *ConnPool) tryDial() { return } - ctx, cancel := context.WithTimeout(context.Background(), p.cfg.DialTimeout) + // Probe dialing even when dialErrorsNum is saturated. Apply DialTimeout per probe + // attempt so custom dialers can't hang indefinitely. + ctx := context.Background() + var cancel context.CancelFunc + if p.cfg.DialTimeout > 0 { + ctx, cancel = context.WithTimeout(ctx, p.cfg.DialTimeout) + } conn, err := p.cfg.Dialer(ctx) + if cancel != nil { + cancel() + } if err != nil { p.setLastDialError(err) time.Sleep(time.Second) - cancel() continue } atomic.StoreUint32(&p.dialErrorsNum, 0) _ = conn.Close() - cancel() return } } @@ -835,7 +855,8 @@ func (p *ConnPool) queuedNewConn(ctx context.Context) (*Conn, error) { return nil, ctx.Err() } - dialCtx, cancel := context.WithTimeout(context.Background(), p.cfg.DialTimeout) + // Don't apply DialTimeout via context here; dialConn applies DialTimeout per attempt. + dialCtx, cancel := context.WithCancel(context.Background()) w := &wantConn{ ctx: dialCtx, diff --git a/internal/pool/try_dial_test.go b/internal/pool/try_dial_test.go new file mode 100644 index 000000000..b20b62f1e --- /dev/null +++ b/internal/pool/try_dial_test.go @@ -0,0 +1,81 @@ +package pool + +import ( + "context" + "errors" + "net" + "testing" + "time" +) + +func TestTryDial_AppliesDialTimeoutWhenSet(t *testing.T) { + p := NewConnPool(&Options{ + Dialer: func(ctx context.Context) (net.Conn, error) { + if _, ok := ctx.Deadline(); !ok { + return nil, errors.New("expected deadline in tryDial") + } + c1, c2 := net.Pipe() + _ = c2.Close() + return c1, nil + }, + PoolSize: 1, + MaxConcurrentDials: 1, + DialTimeout: 200 * time.Millisecond, + }) + defer p.Close() + + p.tryDial() +} + +func TestTryDial_DoesNotApplyDialTimeoutWhenDisabled(t *testing.T) { + p := NewConnPool(&Options{ + Dialer: func(ctx context.Context) (net.Conn, error) { + if _, ok := ctx.Deadline(); ok { + return nil, errors.New("unexpected deadline in tryDial when DialTimeout disabled") + } + // Ensure context is still a real context. + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + c1, c2 := net.Pipe() + _ = c2.Close() + return c1, nil + }, + PoolSize: 1, + MaxConcurrentDials: 1, + DialTimeout: 0, + }) + defer p.Close() + + p.tryDial() +} + +func TestTryDial_RespectsPoolClose(t *testing.T) { + // If Dialer keeps failing, tryDial should exit once the pool is closed. + p := NewConnPool(&Options{ + Dialer: func(ctx context.Context) (net.Conn, error) { + return nil, errors.New("dial failed") + }, + PoolSize: 1, + MaxConcurrentDials: 1, + DialTimeout: 10 * time.Millisecond, + }) + + done := make(chan struct{}) + go func() { + p.tryDial() + close(done) + }() + + time.Sleep(20 * time.Millisecond) + _ = p.Close() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatalf("tryDial did not exit after pool close") + } +}