Skip to content

Commit 1f0dc26

Browse files
committed
refactored closing of result stream
1 parent c476038 commit 1f0dc26

File tree

7 files changed

+54
-41
lines changed

7 files changed

+54
-41
lines changed

internal/query/client_test.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1386,7 +1386,6 @@ func TestClient(t *testing.T) {
13861386
},
13871387
},
13881388
}, nil)
1389-
stream.EXPECT().Recv().Return(nil, io.EOF)
13901389
client := NewMockQueryServiceClient(ctrl)
13911390
client.EXPECT().ExecuteQuery(gomock.Any(), gomock.Any()).Return(stream, nil)
13921391

@@ -1526,7 +1525,6 @@ func TestClient(t *testing.T) {
15261525
},
15271526
},
15281527
}, nil)
1529-
stream.EXPECT().Recv().Return(nil, io.EOF)
15301528
client := NewMockQueryServiceClient(ctrl)
15311529
client.EXPECT().ExecuteQuery(gomock.Any(), gomock.Any()).Return(stream, nil)
15321530

internal/query/execute_query.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,11 @@ func execute(
140140

141141
r, err := newResult(ctx, stream, append(opts,
142142
withStatsCallback(settings.StatsCallback()),
143-
withOnClose(executeCancel),
143+
withOnClose(func(ctx context.Context) error {
144+
executeCancel()
145+
146+
return nil
147+
}),
144148
)...)
145149
if err != nil {
146150
return nil, xerrors.WithStackTrace(err)

internal/query/result.go

Lines changed: 26 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"errors"
66
"fmt"
77
"io"
8-
"sync"
98

109
"github.com/ydb-platform/ydb-go-genproto/Ydb_Query_V1"
1110
"github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Query"
@@ -15,6 +14,7 @@ import (
1514
"github.com/ydb-platform/ydb-go-sdk/v3/internal/stats"
1615
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
1716
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xiter"
17+
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xsync"
1818
"github.com/ydb-platform/ydb-go-sdk/v3/query"
1919
"github.com/ydb-platform/ydb-go-sdk/v3/trace"
2020
)
@@ -33,13 +33,13 @@ type (
3333
}
3434
streamResult struct {
3535
stream Ydb_Query_V1.QueryService_ExecuteQueryClient
36-
closeOnce func()
36+
closeOnce func(ctx context.Context) error
3737
lastPart *Ydb_Query.ExecuteQueryResponsePart
3838
resultSetIndex int64
3939
closed chan struct{}
4040
trace *trace.Query
4141
statsCallback func(queryStats stats.QueryStats)
42-
onClose []func()
42+
onClose []func(ctx context.Context) error
4343
onNextPartErr []func(err error)
4444
onTxMeta []func(txMeta *Ydb_Query.TransactionMeta)
4545
}
@@ -99,7 +99,7 @@ func withStatsCallback(callback func(queryStats stats.QueryStats)) resultOption
9999
}
100100
}
101101

102-
func withOnClose(onClose func()) resultOption {
102+
func withOnClose(onClose func(ctx context.Context) error) resultOption {
103103
return func(s *streamResult) {
104104
s.onClose = append(s.onClose, onClose)
105105
}
@@ -117,20 +117,15 @@ func onTxMeta(callback func(txMeta *Ydb_Query.TransactionMeta)) resultOption {
117117
}
118118
}
119119

120-
func newResult(
120+
func newResult( //nolint:funlen
121121
ctx context.Context,
122122
stream Ydb_Query_V1.QueryService_ExecuteQueryClient,
123123
opts ...resultOption,
124124
) (_ *streamResult, finalErr error) {
125125
var (
126126
closed = make(chan struct{})
127127
r = streamResult{
128-
stream: stream,
129-
onClose: []func(){
130-
func() {
131-
close(closed)
132-
},
133-
},
128+
stream: stream,
134129
closed: closed,
135130
resultSetIndex: -1,
136131
}
@@ -142,11 +137,20 @@ func newResult(
142137
}
143138
}
144139

145-
r.closeOnce = sync.OnceFunc(func() {
146-
for _, onClose := range r.onClose {
147-
onClose()
140+
r.closeOnce = xsync.OnceFunc(func(ctx context.Context) error {
141+
defer func() {
142+
close(closed)
143+
144+
r.stream = nil
145+
}()
146+
147+
for i := range r.onClose {
148+
if err := r.onClose[len(r.onClose)-i-1](ctx); err != nil {
149+
return xerrors.WithStackTrace(err)
150+
}
148151
}
149-
r.stream = nil
152+
153+
return nil
150154
})
151155

152156
if r.trace != nil {
@@ -195,7 +199,7 @@ func (r *streamResult) nextPart(ctx context.Context) (
195199
default:
196200
part, err = nextPart(r.stream)
197201
if err != nil {
198-
r.closeOnce()
202+
_ = r.closeOnce(ctx)
199203

200204
for _, callback := range r.onNextPartErr {
201205
callback(err)
@@ -226,8 +230,6 @@ func nextPart(stream Ydb_Query_V1.QueryService_ExecuteQueryClient) (
226230
}
227231

228232
func (r *streamResult) Close(ctx context.Context) (finalErr error) {
229-
defer r.closeOnce()
230-
231233
if r.trace != nil {
232234
onDone := trace.QueryOnResultClose(r.trace, &ctx,
233235
stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/query.(*streamResult).Close"),
@@ -237,21 +239,11 @@ func (r *streamResult) Close(ctx context.Context) (finalErr error) {
237239
}()
238240
}
239241

240-
for {
241-
select {
242-
case <-r.closed:
243-
return nil
244-
default:
245-
_, err := r.nextPart(ctx)
246-
if err != nil {
247-
if xerrors.Is(err, io.EOF) {
248-
return nil
249-
}
250-
251-
return xerrors.WithStackTrace(err)
252-
}
253-
}
242+
if err := r.closeOnce(ctx); err != nil {
243+
return xerrors.WithStackTrace(err)
254244
}
245+
246+
return nil
255247
}
256248

257249
func (r *streamResult) nextResultSet(ctx context.Context) (_ *resultSet, err error) {
@@ -279,7 +271,8 @@ func (r *streamResult) nextResultSet(ctx context.Context) (_ *resultSet, err err
279271
r.statsCallback(stats.FromQueryStats(part.GetExecStats()))
280272
}
281273
if part.GetResultSetIndex() < r.resultSetIndex {
282-
r.closeOnce()
274+
_ = r.closeOnce(ctx)
275+
283276
if part.GetResultSetIndex() <= 0 && r.resultSetIndex > 0 {
284277
return nil, xerrors.WithStackTrace(io.EOF)
285278
}

internal/query/result_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,7 @@ func TestResultNextResultSet(t *testing.T) {
539539
require.EqualValues(t, 1, rs.rowIndex)
540540
}
541541
t.Log("explicit interrupt stream")
542-
r.closeOnce()
542+
_ = r.closeOnce(ctx)
543543
{
544544
t.Log("next (row=3)")
545545
_, err := rs.nextRow(context.Background())

internal/query/session.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,11 @@ func (s *Session) execute(
144144
}
145145
}()
146146

147-
r, err := execute(ctx, s.ID(), s.client, q, settings, append(opts, withOnClose(cancel))...)
147+
r, err := execute(ctx, s.ID(), s.client, q, settings, append(opts, withOnClose(func(ctx context.Context) error {
148+
cancel()
149+
150+
return nil
151+
}))...)
148152
if err != nil {
149153
return nil, xerrors.WithStackTrace(err)
150154
}

internal/query/transaction_test.go

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,10 @@ func TestTxOnCompleted(t *testing.T) {
274274

275275
res, err := tx.Query(sf.Context(e), "")
276276
require.NoError(t, err)
277+
_, err = res.NextResultSet(sf.Context(e))
278+
require.NoError(t, err)
279+
_, err = res.NextResultSet(sf.Context(e))
280+
require.ErrorIs(t, err, io.EOF)
277281
_ = res.Close(sf.Context(e))
278282
time.Sleep(time.Millisecond) // time for reaction for closing channel
279283
require.Empty(t, completed)
@@ -332,7 +336,12 @@ func TestTxOnCompleted(t *testing.T) {
332336
})
333337

334338
res, err := tx.Query(sf.Context(e), "", options.WithCommit())
335-
_ = res.Close(sf.Context(e))
339+
require.NoError(t, err)
340+
_, err = res.NextResultSet(sf.Context(e))
341+
require.NoError(t, err)
342+
_, err = res.NextResultSet(sf.Context(e))
343+
require.ErrorIs(t, err, io.EOF)
344+
err = res.Close(sf.Context(e))
336345
require.NoError(t, err)
337346
xtest.SpinWaitCondition(t, &completedMutex, func() bool {
338347
return len(completed) != 0
@@ -384,7 +393,10 @@ func TestTxOnCompleted(t *testing.T) {
384393
_, err = res.NextResultSet(sf.Context(e))
385394
require.NoError(t, err)
386395

387-
_ = res.Close(sf.Context(e))
396+
_, err = res.NextResultSet(sf.Context(e))
397+
require.ErrorIs(t, err, io.EOF)
398+
399+
err = res.Close(sf.Context(e))
388400
require.NoError(t, err)
389401
xtest.SpinWaitCondition(t, &completedMutex, func() bool {
390402
return len(completed) != 0

internal/xtest/leak.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@ func findGoroutinesLeak() error {
6262
unexpectedGoroutines = append(unexpectedGoroutines, g)
6363
}
6464
if l := len(unexpectedGoroutines); l > 0 {
65-
return fmt.Errorf("found %d unexpected goroutines:\n%s", len(goroutines), strings.Join(goroutines, "\n"))
65+
return fmt.Errorf("found %d unexpected goroutines:\n\n%s",
66+
len(goroutines), strings.Join(goroutines, "\n\n"),
67+
)
6668
}
6769

6870
return nil

0 commit comments

Comments
 (0)