Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 208 additions & 0 deletions internal/pool/dial_conn_retry_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
package pool

import (
"context"
"errors"
"net"
"sync/atomic"
"testing"
"time"
)

func TestDialConn_HangingDial_RetriesWithPerAttemptTimeout(t *testing.T) {
var calls atomic.Int32
var sawDeadline atomic.Int32

const (
dialTimeout = 50 * time.Millisecond
backoff = 10 * time.Millisecond
retries = 3
)

p := NewConnPool(&Options{
Dialer: func(ctx context.Context) (net.Conn, error) {
calls.Add(1)

// Ensure each attempt has a deadline (pool applies DialTimeout per attempt).
if dl, ok := ctx.Deadline(); ok {
rem := time.Until(dl)
// Very generous bounds to avoid flakes.
if rem > 5*time.Millisecond && rem <= 2*dialTimeout {
sawDeadline.Add(1)
}
}

// Simulate a TCP connect hang: block until the context cancels.
<-ctx.Done()
return nil, ctx.Err()
},
PoolSize: 1,
MaxConcurrentDials: 1,
DialTimeout: dialTimeout,
DialerRetries: retries,
DialerRetryTimeout: backoff,
})
defer p.Close()

// Use a parent context with a bounded timeout so this test fails fast (instead of hanging)
// when dialConn does not apply per-attempt DialTimeout via context.
parentBudget := dialTimeout*time.Duration(retries) + backoff*time.Duration(retries-1) + 250*time.Millisecond
ctx, cancel := context.WithTimeout(context.Background(), parentBudget)
defer cancel()

start := time.Now()
_, err := p.dialConn(ctx, true)
elapsed := time.Since(start)

if err == nil {
t.Fatalf("expected error")
}
if got := calls.Load(); got != retries {
t.Fatalf("expected %d dial attempts, got %d", retries, got)
}
if got := sawDeadline.Load(); got != retries {
t.Fatalf("expected deadline on all attempts, got %d/%d", got, retries)
}

// Each attempt should wait ~dialTimeout, plus backoff between attempts.
// Allow wide bounds for CI noise.
min := dialTimeout*time.Duration(retries) + backoff*time.Duration(retries-1)
if elapsed < min/2 {
t.Fatalf("dialConn returned too quickly (%v < %v), retries/backoff may not have occurred", elapsed, min/2)
}
if elapsed > 5*min {
t.Fatalf("dialConn took too long (%v > %v), likely hung beyond expected timeouts", elapsed, 5*min)
}
}

func TestDialConn_DoesNotExtendEarlierParentDeadline(t *testing.T) {
var calls atomic.Int32

p := NewConnPool(&Options{
Dialer: func(ctx context.Context) (net.Conn, error) {
calls.Add(1)
dl, ok := ctx.Deadline()
if !ok {
return nil, errors.New("expected deadline")
}
// Parent deadline should win (be soon).
if time.Until(dl) > 100*time.Millisecond {
return nil, errors.New("deadline was unexpectedly extended")
}
<-ctx.Done()
return nil, ctx.Err()
},
PoolSize: 1,
MaxConcurrentDials: 1,
DialTimeout: 500 * time.Millisecond,
DialerRetries: 1,
})
defer p.Close()

parent, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond)
defer cancel()

_, err := p.dialConn(parent, true)
if err == nil {
t.Fatalf("expected error")
}
if got := calls.Load(); got != 1 {
t.Fatalf("expected 1 dial attempt, got %d", got)
}
}

func TestDialConn_ContextCancelStopsFurtherRetries(t *testing.T) {
var calls atomic.Int32

p := NewConnPool(&Options{
Dialer: func(ctx context.Context) (net.Conn, error) {
n := calls.Add(1)
if n == 1 {
// First attempt fails immediately; test cancels context to stop retries.
return nil, errors.New("dial failed")
}
return nil, errors.New("unexpected extra attempt")
},
PoolSize: 1,
MaxConcurrentDials: 1,
DialTimeout: 5 * time.Second,
DialerRetries: 5,
DialerRetryTimeout: 5 * time.Second,
})
defer p.Close()

ctx, cancel := context.WithCancel(context.Background())
// Cancel immediately after the first attempt fails by wrapping dialConn call.
// We do it via a goroutine so dialConn has a chance to enter the backoff select.
go func() {
// Give dialConn a moment to start. This avoids a race where ctx is already canceled
// before the first attempt and we wouldn't be testing the retry stop path.
time.Sleep(5 * time.Millisecond)
cancel()
}()

_, _ = p.dialConn(ctx, true)

if got := calls.Load(); got != 1 {
t.Fatalf("expected dialer to be called once after cancel, got %d", got)
}
}

func TestDialConn_DialTimeoutDisabled_DoesNotSetDeadline(t *testing.T) {
var calls atomic.Int32

p := NewConnPool(&Options{
Dialer: func(ctx context.Context) (net.Conn, error) {
calls.Add(1)
if _, ok := ctx.Deadline(); ok {
return nil, errors.New("unexpected deadline when DialTimeout disabled")
}
return nil, errors.New("dial failed")
},
PoolSize: 1,
MaxConcurrentDials: 1,
DialTimeout: 0,
DialerRetries: 1,
})
defer p.Close()

_, err := p.dialConn(context.Background(), true)
if err == nil {
t.Fatalf("expected error")
}
if got := calls.Load(); got != 1 {
t.Fatalf("expected 1 dial attempt, got %d", got)
}
}

func TestDialConn_NoBackoffAfterLastAttempt(t *testing.T) {
var calls atomic.Int32
backoff := 300 * time.Millisecond

p := NewConnPool(&Options{
Dialer: func(ctx context.Context) (net.Conn, error) {
calls.Add(1)
return nil, errors.New("dial failed")
},
PoolSize: 1,
MaxConcurrentDials: 1,
DialTimeout: 5 * time.Second,
DialerRetries: 1,
DialerRetryTimeout: backoff,
})
defer p.Close()

start := time.Now()
_, err := p.dialConn(context.Background(), true)
elapsed := time.Since(start)
if err == nil {
t.Fatalf("expected error")
}
if got := calls.Load(); got != 1 {
t.Fatalf("expected 1 dial attempt, got %d", got)
}
// If we slept after the last attempt, this will be ~backoff.
if elapsed >= backoff/2 {
t.Fatalf("dialConn took too long (%v); likely slept after last attempt (backoff=%v)", elapsed, backoff)
}
}
40 changes: 40 additions & 0 deletions internal/pool/dial_context_timeout_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package pool

import (
"context"
"errors"
"net"
"testing"
"time"
)

// Ensures ConnPool applies DialTimeout per attempt via context (so dialing doesn't hang
// when a custom dialer ignores timeouts).
func TestDialConn_AppliesDialTimeoutPerAttemptViaContext(t *testing.T) {
p := NewConnPool(&Options{
Dialer: func(ctx context.Context) (net.Conn, error) {
// Pool should apply DialTimeout per attempt via context.
dl, ok := ctx.Deadline()
if !ok {
return nil, errors.New("expected context deadline")
}
remaining := time.Until(dl)
// Allow slack for scheduling jitter.
if remaining <= 50*time.Millisecond || remaining > 250*time.Millisecond {
return nil, errors.New("unexpected context deadline duration")
}
return nil, errors.New("dial failed")
},
PoolSize: 1,
MaxConcurrentDials: 1,
DialTimeout: 200 * time.Millisecond,
PoolTimeout: 10 * time.Millisecond,
DialerRetries: 1,
})
defer p.Close()

_, err := p.newConn(context.Background(), true)
if err == nil {
t.Fatalf("expected error")
}
}
59 changes: 40 additions & 19 deletions internal/pool/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -456,10 +456,8 @@ func (p *ConnPool) checkMinIdleConns() {
}

func (p *ConnPool) addIdleConn() error {
ctx, cancel := context.WithTimeout(context.Background(), p.cfg.DialTimeout)
defer cancel()

cn, err := p.dialConn(ctx, true)
// Do not apply DialTimeout via context here; dialConn applies DialTimeout per attempt.
cn, err := p.dialConn(context.Background(), true)
if err != nil {
return err
}
Expand Down Expand Up @@ -505,9 +503,9 @@ func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) {
ctx = context.Background()
}

dialCtx, cancel := context.WithTimeout(ctx, p.cfg.DialTimeout)
defer cancel()
cn, err := p.dialConn(dialCtx, pooled)
// Do not apply DialTimeout via context here; dialConn applies DialTimeout per attempt.
// We still propagate ctx so callers can cancel explicitly.
cn, err := p.dialConn(ctx, pooled)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -569,8 +567,8 @@ func (p *ConnPool) dialConn(ctx context.Context, pooled bool) (*Conn, error) {
}

// Retry dialing with backoff
// the context timeout is already handled by the context passed in
// so we may never reach the max retries, higher values don't hurt
// Dial timeout is applied per attempt (so retries/backoff don't eat into the next
// attempt's dial budget), while still honoring caller cancellation via ctx.
maxRetries := p.cfg.DialerRetries
if maxRetries <= 0 {
maxRetries = 5 // Default value
Expand All @@ -587,16 +585,31 @@ func (p *ConnPool) dialConn(ctx context.Context, pooled bool) (*Conn, error) {
// instead of a generic context deadline exceeded error
attempt := 0
for attempt = 0; (attempt < maxRetries) && shouldLoop; attempt++ {
netConn, err := p.cfg.Dialer(ctx)
attemptCtx := ctx
var cancel context.CancelFunc
if p.cfg.DialTimeout > 0 {
// Apply DialTimeout per attempt, but never extend an existing earlier deadline.
if deadline, ok := ctx.Deadline(); !ok || time.Until(deadline) > p.cfg.DialTimeout {
attemptCtx, cancel = context.WithTimeout(ctx, p.cfg.DialTimeout)
}
}

netConn, err := p.cfg.Dialer(attemptCtx)
if cancel != nil {
cancel()
}
if err != nil {
lastErr = err
// Add backoff delay for retry attempts
// (not for the first attempt, do at least one)
select {
case <-ctx.Done():
shouldLoop = false
case <-time.After(backoffDuration):
// Continue with retry
// Do not sleep after the last attempt.
if attempt+1 < maxRetries {
select {
case <-ctx.Done():
shouldLoop = false
case <-time.After(backoffDuration):
// Continue with retry
}
}
continue
}
Expand Down Expand Up @@ -648,19 +661,26 @@ func (p *ConnPool) tryDial() {
return
}

ctx, cancel := context.WithTimeout(context.Background(), p.cfg.DialTimeout)
// Probe dialing even when dialErrorsNum is saturated. Apply DialTimeout per probe
// attempt so custom dialers can't hang indefinitely.
ctx := context.Background()
var cancel context.CancelFunc
if p.cfg.DialTimeout > 0 {
ctx, cancel = context.WithTimeout(ctx, p.cfg.DialTimeout)
}

conn, err := p.cfg.Dialer(ctx)
if cancel != nil {
cancel()
}
if err != nil {
p.setLastDialError(err)
time.Sleep(time.Second)
cancel()
continue
}

atomic.StoreUint32(&p.dialErrorsNum, 0)
_ = conn.Close()
cancel()
return
}
}
Expand Down Expand Up @@ -835,7 +855,8 @@ func (p *ConnPool) queuedNewConn(ctx context.Context) (*Conn, error) {
return nil, ctx.Err()
}

dialCtx, cancel := context.WithTimeout(context.Background(), p.cfg.DialTimeout)
// Don't apply DialTimeout via context here; dialConn applies DialTimeout per attempt.
dialCtx, cancel := context.WithCancel(context.Background())

w := &wantConn{
ctx: dialCtx,
Expand Down
Loading
Loading