diff --git a/CHANGELOG.md b/CHANGELOG.md index bb0357e37..65b6d091b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ ## v3.118.0 * Added support for nullable `Date32`, `Datetime64`, `Timestamp64`, and `Interval64` types in the `optional` parameter builder * Added method `query.WithIssuesHandler` to get query issues +* Fixed connection timeout issue in topics writer ## v3.117.2 * Added support for `Result.RowsAffected()` for YDB `database/sql` driver diff --git a/internal/topic/topicwriterinternal/writer_reconnector.go b/internal/topic/topicwriterinternal/writer_reconnector.go index 443024db0..266c12e74 100644 --- a/internal/topic/topicwriterinternal/writer_reconnector.go +++ b/internal/topic/topicwriterinternal/writer_reconnector.go @@ -440,7 +440,7 @@ func (w *WriterReconnector) connectionLoop(ctx context.Context) { attempt, ) - writer, err := w.startWriteStream(ctx, streamCtx) + writer, err := w.startWriteStream(streamCtx) w.onWriterChange(writer) onStreamError := onWriterStarted(err) if err == nil { @@ -485,18 +485,26 @@ func (w *WriterReconnector) handleReconnectRetry( return false } -func (w *WriterReconnector) startWriteStream(ctx, streamCtx context.Context) ( - writer *SingleStreamWriter, - err error, -) { - stream, err := w.connectWithTimeout(streamCtx) +func (w *WriterReconnector) startWriteStream(ctx context.Context) (writer *SingleStreamWriter, err error) { + // connectCtx with timeout applies only to the connection phase, + // allowing the main stream context to remain active after exiting this method + connectCtx, stopConnectCtx := xcontext.WithStoppableTimeoutCause(ctx, w.cfg.connectTimeout, errConnTimeout) + defer func() { + // If the context was cancelled during connection (the stream was cancelled), + // we should return a timeout error + if !stopConnectCtx() && err == nil { + err = context.Cause(connectCtx) + } + }() + + stream, err := w.connectWithTimeout(connectCtx) if err != nil { return nil, err } w.queue.ResetSentProgress() - return NewSingleStreamWriter(ctx, w.createWriterStreamConfig(stream)) + return NewSingleStreamWriter(connectCtx, w.createWriterStreamConfig(stream)) } func (w *WriterReconnector) needReceiveLastSeqNo() bool { @@ -505,45 +513,16 @@ func (w *WriterReconnector) needReceiveLastSeqNo() bool { return res } -func (w *WriterReconnector) connectWithTimeout(streamLifetimeContext context.Context) (RawTopicWriterStream, error) { - connectCtx, connectCancel := xcontext.WithCancel(streamLifetimeContext) - - type resT struct { - stream RawTopicWriterStream - err error - } - resCh := make(chan resT, 1) - - go func() { - defer func() { - p := recover() - if p != nil { - resCh <- resT{ - stream: nil, - err: xerrors.WithStackTrace(xerrors.Wrap(fmt.Errorf("ydb: panic while connect to topic writer: %+v", p))), - } - } - }() - - stream, err := w.cfg.Connect(connectCtx, w.cfg.Tracer) - resCh <- resT{stream: stream, err: err} +func (w *WriterReconnector) connectWithTimeout(ctx context.Context) (stream RawTopicWriterStream, err error) { + defer func() { + p := recover() + if p != nil { + stream = nil + err = xerrors.WithStackTrace(xerrors.Wrap(fmt.Errorf("ydb: panic while connect to topic writer: %+v", p))) + } }() - timer := time.NewTimer(w.cfg.connectTimeout) - defer timer.Stop() - - select { - case <-timer.C: - connectCancel() - - return nil, xerrors.WithStackTrace(errConnTimeout) - case res := <-resCh: - // force no cancel connect context - because it will break stream - // context will cancel by cancel streamLifetimeContext while reconnect or stop connection - _ = connectCancel - - return res.stream, res.err - } + return w.cfg.Connect(ctx, w.cfg.Tracer) } func (w *WriterReconnector) onAckReceived(count int) { diff --git a/internal/xcontext/context_with_stoppable_timeout.go b/internal/xcontext/context_with_stoppable_timeout.go new file mode 100644 index 000000000..d3b9a940f --- /dev/null +++ b/internal/xcontext/context_with_stoppable_timeout.go @@ -0,0 +1,27 @@ +package xcontext + +import ( + "context" + "time" +) + +// WithStoppableTimeoutCause returns a copy of the parent context that is cancelled with +// the specified cause after timeout elapses, and a stop function. Calling the stop function +// prevents the timeout from canceling the context and releases resources associated with it. +// The cause error will be used when the timeout triggers cancellation. +// +// The returned stop function returns a boolean value: +// - true if the timeout was successfully stopped before it fired (context was not cancelled by timeout) +// - false if the timeout already fired and the context was cancelled with the specified cause +func WithStoppableTimeoutCause(ctx context.Context, timeout time.Duration, cause error) (context.Context, func() bool) { + ctxWithCancel, cancel := context.WithCancelCause(ctx) + timeoutCtx, cancelTimeout := WithTimeout(ctx, timeout) + + stop := context.AfterFunc(timeoutCtx, func() { cancel(cause) }) + + return ctxWithCancel, func() bool { + defer cancelTimeout() + + return stop() + } +} diff --git a/internal/xcontext/context_with_stoppable_timeout_test.go b/internal/xcontext/context_with_stoppable_timeout_test.go new file mode 100644 index 000000000..84ba9de8f --- /dev/null +++ b/internal/xcontext/context_with_stoppable_timeout_test.go @@ -0,0 +1,67 @@ +//go:build go1.25 + +package xcontext_test + +import ( + "context" + "errors" + "testing" + "testing/synctest" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext" +) + +func TestWithStoppableTimeoutCause(t *testing.T) { + wantErr := errors.New("some error") + + synctest.Test(t, func(t *testing.T) { + ctx, _ := xcontext.WithStoppableTimeoutCause(context.Background(), 10*time.Second, wantErr) + select { + case <-time.After(100500 * time.Second): + t.Fatal("context should be done") + case <-ctx.Done(): + assert.ErrorIs(t, context.Cause(ctx), wantErr) + } + }) + + synctest.Test(t, func(t *testing.T) { + ctx, stop := xcontext.WithStoppableTimeoutCause(context.Background(), 10*time.Second, wantErr) + + stop() + + select { + case <-time.After(100500 * time.Second): + case <-ctx.Done(): + t.Fatal("context shouldn't be canceled") + } + }) + + synctest.Test(t, func(t *testing.T) { + _, stop := xcontext.WithStoppableTimeoutCause(context.Background(), 10*time.Second, wantErr) + + time.Sleep(1 * time.Second) + + assert.True(t, stop()) + }) + + synctest.Test(t, func(t *testing.T) { + _, stop := xcontext.WithStoppableTimeoutCause(context.Background(), 10*time.Second, wantErr) + + time.Sleep(1 * time.Second) + + stop() + + assert.False(t, stop()) + }) + + synctest.Test(t, func(t *testing.T) { + _, stop := xcontext.WithStoppableTimeoutCause(context.Background(), 10*time.Second, wantErr) + + time.Sleep(15 * time.Second) + + assert.False(t, stop()) + }) +}