55 "errors"
66 "fmt"
77 "io"
8- "sync"
98
109 "github.com/ydb-platform/ydb-go-genproto/Ydb_Query_V1"
1110 "github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Query"
@@ -15,6 +14,7 @@ import (
1514 "github.com/ydb-platform/ydb-go-sdk/v3/internal/stats"
1615 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
1716 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xiter"
17+ "github.com/ydb-platform/ydb-go-sdk/v3/internal/xsync"
1818 "github.com/ydb-platform/ydb-go-sdk/v3/query"
1919 "github.com/ydb-platform/ydb-go-sdk/v3/trace"
2020)
@@ -33,13 +33,13 @@ type (
3333 }
3434 streamResult struct {
3535 stream Ydb_Query_V1.QueryService_ExecuteQueryClient
36- closeOnce func ()
36+ closeOnce func (ctx context. Context ) error
3737 lastPart * Ydb_Query.ExecuteQueryResponsePart
3838 resultSetIndex int64
3939 closed chan struct {}
4040 trace * trace.Query
4141 statsCallback func (queryStats stats.QueryStats )
42- onClose []func ()
42+ onClose []func (ctx context. Context ) error
4343 onNextPartErr []func (err error )
4444 onTxMeta []func (txMeta * Ydb_Query.TransactionMeta )
4545 }
@@ -99,7 +99,7 @@ func withStatsCallback(callback func(queryStats stats.QueryStats)) resultOption
9999 }
100100}
101101
102- func withOnClose (onClose func () ) resultOption {
102+ func withOnClose (onClose func (ctx context. Context ) error ) resultOption {
103103 return func (s * streamResult ) {
104104 s .onClose = append (s .onClose , onClose )
105105 }
@@ -117,20 +117,15 @@ func onTxMeta(callback func(txMeta *Ydb_Query.TransactionMeta)) resultOption {
117117 }
118118}
119119
120- func newResult (
120+ func newResult ( //nolint:funlen
121121 ctx context.Context ,
122122 stream Ydb_Query_V1.QueryService_ExecuteQueryClient ,
123123 opts ... resultOption ,
124124) (_ * streamResult , finalErr error ) {
125125 var (
126126 closed = make (chan struct {})
127127 r = streamResult {
128- stream : stream ,
129- onClose : []func (){
130- func () {
131- close (closed )
132- },
133- },
128+ stream : stream ,
134129 closed : closed ,
135130 resultSetIndex : - 1 ,
136131 }
@@ -142,11 +137,20 @@ func newResult(
142137 }
143138 }
144139
145- r .closeOnce = sync .OnceFunc (func () {
146- for _ , onClose := range r .onClose {
147- onClose ()
140+ r .closeOnce = xsync .OnceFunc (func (ctx context.Context ) error {
141+ defer func () {
142+ close (closed )
143+
144+ r .stream = nil
145+ }()
146+
147+ for i := range r .onClose {
148+ if err := r .onClose [len (r .onClose )- i - 1 ](ctx ); err != nil {
149+ return xerrors .WithStackTrace (err )
150+ }
148151 }
149- r .stream = nil
152+
153+ return nil
150154 })
151155
152156 if r .trace != nil {
@@ -195,7 +199,7 @@ func (r *streamResult) nextPart(ctx context.Context) (
195199 default :
196200 part , err = nextPart (r .stream )
197201 if err != nil {
198- r .closeOnce ()
202+ _ = r .closeOnce (ctx )
199203
200204 for _ , callback := range r .onNextPartErr {
201205 callback (err )
@@ -226,8 +230,6 @@ func nextPart(stream Ydb_Query_V1.QueryService_ExecuteQueryClient) (
226230}
227231
228232func (r * streamResult ) Close (ctx context.Context ) (finalErr error ) {
229- defer r .closeOnce ()
230-
231233 if r .trace != nil {
232234 onDone := trace .QueryOnResultClose (r .trace , & ctx ,
233235 stack .FunctionID ("github.com/ydb-platform/ydb-go-sdk/v3/internal/query.(*streamResult).Close" ),
@@ -237,21 +239,11 @@ func (r *streamResult) Close(ctx context.Context) (finalErr error) {
237239 }()
238240 }
239241
240- for {
241- select {
242- case <- r .closed :
243- return nil
244- default :
245- _ , err := r .nextPart (ctx )
246- if err != nil {
247- if xerrors .Is (err , io .EOF ) {
248- return nil
249- }
250-
251- return xerrors .WithStackTrace (err )
252- }
253- }
242+ if err := r .closeOnce (ctx ); err != nil {
243+ return xerrors .WithStackTrace (err )
254244 }
245+
246+ return nil
255247}
256248
257249func (r * streamResult ) nextResultSet (ctx context.Context ) (_ * resultSet , err error ) {
@@ -279,7 +271,8 @@ func (r *streamResult) nextResultSet(ctx context.Context) (_ *resultSet, err err
279271 r .statsCallback (stats .FromQueryStats (part .GetExecStats ()))
280272 }
281273 if part .GetResultSetIndex () < r .resultSetIndex {
282- r .closeOnce ()
274+ _ = r .closeOnce (ctx )
275+
283276 if part .GetResultSetIndex () <= 0 && r .resultSetIndex > 0 {
284277 return nil , xerrors .WithStackTrace (io .EOF )
285278 }
0 commit comments