Skip to content

Commit e815415

Browse files
authored
Topics: extend connectTimeout to cover InitRequest + InitResponse (#1922)
1 parent 98f0612 commit e815415

File tree

4 files changed

+118
-44
lines changed

4 files changed

+118
-44
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
* Fixed the topic writer connection timeout: it now also covers the `InitRequest`/`InitResponse` exchange
12
* Supported `sql.Null*` from `database/sql` as query params in `toValue` func
23

34
## v3.118.0

internal/topic/topicwriterinternal/writer_reconnector.go

Lines changed: 23 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,7 @@ func (w *WriterReconnector) connectionLoop(ctx context.Context) {
440440
attempt,
441441
)
442442

443-
writer, err := w.startWriteStream(ctx, streamCtx)
443+
writer, err := w.startWriteStream(streamCtx)
444444
w.onWriterChange(writer)
445445
onStreamError := onWriterStarted(err)
446446
if err == nil {
@@ -485,18 +485,26 @@ func (w *WriterReconnector) handleReconnectRetry(
485485
return false
486486
}
487487

488-
func (w *WriterReconnector) startWriteStream(ctx, streamCtx context.Context) (
489-
writer *SingleStreamWriter,
490-
err error,
491-
) {
492-
stream, err := w.connectWithTimeout(streamCtx)
488+
// startWriteStream opens a new raw topic stream and wraps it in a SingleStreamWriter.
//
// ctx is the context passed in by the caller (connectionLoop passes its streamCtx).
// A child context, connectCtx, is derived from it with w.cfg.connectTimeout and
// errConnTimeout as the cancellation cause. The same connectCtx is handed to both
// connectWithTimeout and NewSingleStreamWriter, so the timeout is armed for the
// whole body of this method and is disarmed by the deferred stopConnectCtx call.
//
// Returns the writer and a nil error on success. If connectWithTimeout fails, its
// error is returned with a nil writer. If stopConnectCtx reports false while err is
// still nil, err is replaced with context.Cause(connectCtx).
//
// NOTE(review): in that last path the named result writer is left as returned by
// NewSingleStreamWriter, so the caller may receive a non-nil writer together with
// a non-nil error - confirm the caller handles this.
func (w *WriterReconnector) startWriteStream(ctx context.Context) (writer *SingleStreamWriter, err error) {
489+
// connectCtx carries the connect timeout only while this method is running;
490+
// once the timeout is stopped, the context stays usable for the stream itself.
491+
connectCtx, stopConnectCtx := xcontext.WithStoppableTimeoutCause(ctx, w.cfg.connectTimeout, errConnTimeout)
492+
defer func() {
493+
// stopConnectCtx returns false when the timeout could not be stopped any more
494+
// (it has already been triggered); in that case report the cancellation cause
// of connectCtx instead of returning a nil error.
// NOTE(review): the stop function is built on context.AfterFunc, whose callback
// runs in its own goroutine, so context.Cause may presumably still be nil here
// right after stop returns false - confirm whether waiting on connectCtx.Done()
// is needed.
495+
if !stopConnectCtx() && err == nil {
496+
err = context.Cause(connectCtx)
497+
}
498+
}()
499+
500+
stream, err := w.connectWithTimeout(connectCtx)
493501
if err != nil {
494502
return nil, err
495503
}
496504

497505
// A fresh stream was opened, so the queue's sent progress is reset before the writer is created.
w.queue.ResetSentProgress()
498506

499-
return NewSingleStreamWriter(ctx, w.createWriterStreamConfig(stream))
507+
// The writer is created with connectCtx (not the parent ctx), so the connect timeout
// is still armed while NewSingleStreamWriter runs.
return NewSingleStreamWriter(connectCtx, w.createWriterStreamConfig(stream))
500508
}
501509

502510
func (w *WriterReconnector) needReceiveLastSeqNo() bool {
@@ -505,45 +513,16 @@ func (w *WriterReconnector) needReceiveLastSeqNo() bool {
505513
return res
506514
}
507515

508-
func (w *WriterReconnector) connectWithTimeout(streamLifetimeContext context.Context) (RawTopicWriterStream, error) {
509-
connectCtx, connectCancel := xcontext.WithCancel(streamLifetimeContext)
510-
511-
type resT struct {
512-
stream RawTopicWriterStream
513-
err error
514-
}
515-
resCh := make(chan resT, 1)
516-
517-
go func() {
518-
defer func() {
519-
p := recover()
520-
if p != nil {
521-
resCh <- resT{
522-
stream: nil,
523-
err: xerrors.WithStackTrace(xerrors.Wrap(fmt.Errorf("ydb: panic while connect to topic writer: %+v", p))),
524-
}
525-
}
526-
}()
527-
528-
stream, err := w.cfg.Connect(connectCtx, w.cfg.Tracer)
529-
resCh <- resT{stream: stream, err: err}
516+
// connectWithTimeout calls w.cfg.Connect with ctx and w.cfg.Tracer and returns its result.
//
// The call is made synchronously on the caller's goroutine. A panic raised inside it
// is recovered and converted into an error (with a nil stream) instead of propagating.
//
// In this version the method no longer starts a timer itself: any time limit has to
// come from ctx. The removed lines below show the previous timer/select implementation.
func (w *WriterReconnector) connectWithTimeout(ctx context.Context) (stream RawTopicWriterStream, err error) {
517+
defer func() {
518+
// recover must be called directly in the deferred function to intercept the panic.
p := recover()
519+
if p != nil {
520+
// Overwrite the named results so the caller sees an error and no stream.
stream = nil
521+
err = xerrors.WithStackTrace(xerrors.Wrap(fmt.Errorf("ydb: panic while connect to topic writer: %+v", p)))
522+
}
530523
}()
531524

532-
timer := time.NewTimer(w.cfg.connectTimeout)
533-
defer timer.Stop()
534-
535-
select {
536-
case <-timer.C:
537-
connectCancel()
538-
539-
return nil, xerrors.WithStackTrace(errConnTimeout)
540-
case res := <-resCh:
541-
// Deliberately do not cancel the connect context here: cancelling it would break the stream.
542-
// The context is cancelled through streamLifetimeContext on reconnect or when the connection is stopped.
543-
_ = connectCancel
544-
545-
return res.stream, res.err
546-
}
525+
return w.cfg.Connect(ctx, w.cfg.Tracer)
547526
}
548527

549528
func (w *WriterReconnector) onAckReceived(count int) {
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
package xcontext
2+
3+
import (
4+
"context"
5+
"time"
6+
)
7+
8+
// WithStoppableTimeoutCause returns a context derived from the parent ctx that is cancelled with
9+
// the specified cause once timeout elapses, together with a stop function. Calling the stop function
10+
// disarms the timeout and releases the resources of the internal timeout context; it does not
11+
// cancel the returned context, which then lives as long as its parent.
12+
//
13+
// The returned stop function reports the result of the underlying context.AfterFunc stop:
14+
//   - true if this call prevented the timeout callback from running (context was not cancelled by timeout)
15+
//   - false if the callback has already been started, or if stop was already called before
//
// NOTE(review): the internal timeout context is derived from ctx, so cancelling the parent
// presumably also triggers the callback and makes stop return false - confirm against WithTimeout.
16+
func WithStoppableTimeoutCause(ctx context.Context, timeout time.Duration, cause error) (context.Context, func() bool) {
17+
// ctxWithCancel is the context handed back to the caller; here it is only cancelled via cancel(cause).
ctxWithCancel, cancel := context.WithCancelCause(ctx)
18+
// timeoutCtx is a sibling context used purely as the timeout trigger; it is never returned.
timeoutCtx, cancelTimeout := WithTimeout(ctx, timeout)
19+
20+
// When timeoutCtx is done, cancel the returned context with the caller-supplied cause.
// context.AfterFunc runs the callback in its own goroutine.
stop := context.AfterFunc(timeoutCtx, func() { cancel(cause) })
21+
22+
return ctxWithCancel, func() bool {
23+
// Release timeoutCtx after the callback has been detached, so that releasing it
// cannot itself trigger the callback.
defer cancelTimeout()
24+
25+
return stop()
26+
}
27+
}
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
//go:build go1.25
2+
3+
package xcontext_test
4+
5+
import (
6+
"context"
7+
"errors"
8+
"testing"
9+
"testing/synctest"
10+
"time"
11+
12+
"github.com/stretchr/testify/assert"
13+
14+
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext"
15+
)
16+
17+
// TestWithStoppableTimeoutCause checks the cancellation behaviour of the context and the
// return value of the stop function produced by xcontext.WithStoppableTimeoutCause.
// Each scenario runs inside synctest.Test, so the sleeps and timers below use the
// synctest fake clock rather than real wall-clock time.
func TestWithStoppableTimeoutCause(t *testing.T) {
18+
wantErr := errors.New("some error")
19+
20+
// Scenario 1: stop is never called, so the context must be cancelled with wantErr as its cause.
synctest.Test(t, func(t *testing.T) {
21+
ctx, _ := xcontext.WithStoppableTimeoutCause(context.Background(), 10*time.Second, wantErr)
22+
select {
23+
case <-time.After(100500 * time.Second):
24+
t.Fatal("context should be done")
25+
case <-ctx.Done():
26+
assert.ErrorIs(t, context.Cause(ctx), wantErr)
27+
}
28+
})
29+
30+
// Scenario 2: stop is called immediately, so the context must stay alive well past the timeout.
synctest.Test(t, func(t *testing.T) {
31+
ctx, stop := xcontext.WithStoppableTimeoutCause(context.Background(), 10*time.Second, wantErr)
32+
33+
stop()
34+
35+
select {
36+
case <-time.After(100500 * time.Second):
37+
case <-ctx.Done():
38+
t.Fatal("context shouldn't be canceled")
39+
}
40+
})
41+
42+
// Scenario 3: stop called before the timeout elapses reports true.
synctest.Test(t, func(t *testing.T) {
43+
_, stop := xcontext.WithStoppableTimeoutCause(context.Background(), 10*time.Second, wantErr)
44+
45+
time.Sleep(1 * time.Second)
46+
47+
assert.True(t, stop())
48+
})
49+
50+
// Scenario 4: a second call to stop reports false, because the timeout was already stopped.
synctest.Test(t, func(t *testing.T) {
51+
_, stop := xcontext.WithStoppableTimeoutCause(context.Background(), 10*time.Second, wantErr)
52+
53+
time.Sleep(1 * time.Second)
54+
55+
stop()
56+
57+
assert.False(t, stop())
58+
})
59+
60+
// Scenario 5: stop called after the timeout has elapsed reports false.
synctest.Test(t, func(t *testing.T) {
61+
_, stop := xcontext.WithStoppableTimeoutCause(context.Background(), 10*time.Second, wantErr)
62+
63+
time.Sleep(15 * time.Second)
64+
65+
assert.False(t, stop())
66+
})
67+
}

0 commit comments

Comments
 (0)