Skip to content

Commit 70b6863

Browse files
committed
streamResult.Close works only when context is alive
1 parent 1f0dc26 commit 70b6863

File tree

6 files changed

+43
-36
lines changed

6 files changed

+43
-36
lines changed

internal/query/client_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1386,6 +1386,7 @@ func TestClient(t *testing.T) {
13861386
},
13871387
},
13881388
}, nil)
1389+
stream.EXPECT().Recv().Return(nil, io.EOF)
13891390
client := NewMockQueryServiceClient(ctrl)
13901391
client.EXPECT().ExecuteQuery(gomock.Any(), gomock.Any()).Return(stream, nil)
13911392

@@ -1525,6 +1526,7 @@ func TestClient(t *testing.T) {
15251526
},
15261527
},
15271528
}, nil)
1529+
stream.EXPECT().Recv().Return(nil, io.EOF)
15281530
client := NewMockQueryServiceClient(ctrl)
15291531
client.EXPECT().ExecuteQuery(gomock.Any(), gomock.Any()).Return(stream, nil)
15301532

internal/query/errors.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,5 @@ var (
1818
ErrOptionNotForTxExecute = errors.New("option is not for execute on transaction")
1919
errExecuteOnCompletedTx = errors.New("execute on completed transaction")
2020
errSessionClosed = errors.New("session is closed")
21+
errStreamResultClosed = errors.New("stream result is closed")
2122
)

internal/query/execute_query.go

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -140,11 +140,7 @@ func execute(
140140

141141
r, err := newResult(ctx, stream, append(opts,
142142
withStatsCallback(settings.StatsCallback()),
143-
withOnClose(func(ctx context.Context) error {
144-
executeCancel()
145-
146-
return nil
147-
}),
143+
withOnClose(executeCancel),
148144
)...)
149145
if err != nil {
150146
return nil, xerrors.WithStackTrace(err)

internal/query/result.go

Lines changed: 37 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"errors"
66
"fmt"
77
"io"
8+
"sync"
89

910
"github.com/ydb-platform/ydb-go-genproto/Ydb_Query_V1"
1011
"github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Query"
@@ -14,7 +15,6 @@ import (
1415
"github.com/ydb-platform/ydb-go-sdk/v3/internal/stats"
1516
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
1617
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xiter"
17-
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xsync"
1818
"github.com/ydb-platform/ydb-go-sdk/v3/query"
1919
"github.com/ydb-platform/ydb-go-sdk/v3/trace"
2020
)
@@ -33,13 +33,13 @@ type (
3333
}
3434
streamResult struct {
3535
stream Ydb_Query_V1.QueryService_ExecuteQueryClient
36-
closeOnce func(ctx context.Context) error
36+
closeOnce func()
3737
lastPart *Ydb_Query.ExecuteQueryResponsePart
3838
resultSetIndex int64
3939
closed chan struct{}
4040
trace *trace.Query
4141
statsCallback func(queryStats stats.QueryStats)
42-
onClose []func(ctx context.Context) error
42+
onClose []func()
4343
onNextPartErr []func(err error)
4444
onTxMeta []func(txMeta *Ydb_Query.TransactionMeta)
4545
}
@@ -99,7 +99,7 @@ func withStatsCallback(callback func(queryStats stats.QueryStats)) resultOption
9999
}
100100
}
101101

102-
func withOnClose(onClose func(ctx context.Context) error) resultOption {
102+
func withOnClose(onClose func()) resultOption {
103103
return func(s *streamResult) {
104104
s.onClose = append(s.onClose, onClose)
105105
}
@@ -117,7 +117,7 @@ func onTxMeta(callback func(txMeta *Ydb_Query.TransactionMeta)) resultOption {
117117
}
118118
}
119119

120-
func newResult( //nolint:funlen
120+
func newResult(
121121
ctx context.Context,
122122
stream Ydb_Query_V1.QueryService_ExecuteQueryClient,
123123
opts ...resultOption,
@@ -137,20 +137,8 @@ func newResult( //nolint:funlen
137137
}
138138
}
139139

140-
r.closeOnce = xsync.OnceFunc(func(ctx context.Context) error {
141-
defer func() {
142-
close(closed)
143-
144-
r.stream = nil
145-
}()
146-
147-
for i := range r.onClose {
148-
if err := r.onClose[len(r.onClose)-i-1](ctx); err != nil {
149-
return xerrors.WithStackTrace(err)
150-
}
151-
}
152-
153-
return nil
140+
r.closeOnce = sync.OnceFunc(func() {
141+
close(closed)
154142
})
155143

156144
if r.trace != nil {
@@ -197,9 +185,13 @@ func (r *streamResult) nextPart(ctx context.Context) (
197185
case <-r.closed:
198186
return nil, xerrors.WithStackTrace(io.EOF)
199187
default:
188+
if r.stream == nil {
189+
return nil, xerrors.WithStackTrace(io.EOF)
190+
}
191+
200192
part, err = nextPart(r.stream)
201193
if err != nil {
202-
_ = r.closeOnce(ctx)
194+
r.stream = nil
203195

204196
for _, callback := range r.onNextPartErr {
205197
callback(err)
@@ -239,11 +231,31 @@ func (r *streamResult) Close(ctx context.Context) (finalErr error) {
239231
}()
240232
}
241233

242-
if err := r.closeOnce(ctx); err != nil {
243-
return xerrors.WithStackTrace(err)
244-
}
234+
defer func() {
235+
r.closeOnce()
245236

246-
return nil
237+
for i := range r.onClose {
238+
r.onClose[len(r.onClose)-i-1]()
239+
}
240+
}()
241+
242+
for {
243+
select {
244+
case <-ctx.Done():
245+
return xerrors.WithStackTrace(ctx.Err())
246+
case <-r.closed:
247+
return xerrors.WithStackTrace(errStreamResultClosed)
248+
default:
249+
_, err := r.nextPart(ctx)
250+
if err != nil {
251+
if xerrors.Is(err, io.EOF) {
252+
return nil
253+
}
254+
255+
return xerrors.WithStackTrace(err)
256+
}
257+
}
258+
}
247259
}
248260

249261
func (r *streamResult) nextResultSet(ctx context.Context) (_ *resultSet, err error) {
@@ -271,7 +283,7 @@ func (r *streamResult) nextResultSet(ctx context.Context) (_ *resultSet, err err
271283
r.statsCallback(stats.FromQueryStats(part.GetExecStats()))
272284
}
273285
if part.GetResultSetIndex() < r.resultSetIndex {
274-
_ = r.closeOnce(ctx)
286+
r.stream = nil
275287

276288
if part.GetResultSetIndex() <= 0 && r.resultSetIndex > 0 {
277289
return nil, xerrors.WithStackTrace(io.EOF)

internal/query/result_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,7 @@ func TestResultNextResultSet(t *testing.T) {
539539
require.EqualValues(t, 1, rs.rowIndex)
540540
}
541541
t.Log("explicit interrupt stream")
542-
_ = r.closeOnce(ctx)
542+
r.closeOnce()
543543
{
544544
t.Log("next (row=3)")
545545
_, err := rs.nextRow(context.Background())

internal/query/session.go

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -144,11 +144,7 @@ func (s *Session) execute(
144144
}
145145
}()
146146

147-
r, err := execute(ctx, s.ID(), s.client, q, settings, append(opts, withOnClose(func(ctx context.Context) error {
148-
cancel()
149-
150-
return nil
151-
}))...)
147+
r, err := execute(ctx, s.ID(), s.client, q, settings, append(opts, withOnClose(cancel))...)
152148
if err != nil {
153149
return nil, xerrors.WithStackTrace(err)
154150
}

0 commit comments

Comments
 (0)