@@ -30,6 +30,8 @@ type rowsNative struct {
3030 streamResult ydb_sdk_query.Result
3131 lastResultSet ydb_sdk_query.ResultSet
3232 lastRow ydb_sdk_query.Row
33+
34+ closeChan chan struct {}
3335}
3436
3537func (r * rowsNative ) Next () bool {
@@ -100,9 +102,7 @@ func (r *rowsNative) Err() error {
100102}
101103
102104func (r * rowsNative ) Close () error {
103- if err := r .streamResult .Close (r .ctx ); err != nil {
104- return fmt .Errorf ("stream result close: %w" , err )
105- }
105+ close (r .closeChan )
106106
107107 return nil
108108}
@@ -113,158 +113,185 @@ type connectionNative struct {
113113 dsi * api_common.TGenericDataSourceInstance
114114 logger * zap.Logger
115115 queryLoggerFactory common.QueryLoggerFactory
116- ctx context.Context
117116 driver * ydb_sdk.Driver
118117 tableName string
119118 formatter rdbms_utils.SQLFormatter
120119 resourcePool string
121120}
122121
123- // nolint: gocyclo
122+ // nolint: gocyclo,funlen
124123func (c * connectionNative ) Query (params * rdbms_utils.QueryParams ) (rdbms_utils.Rows , error ) {
125- rowsChan := make (chan rdbms_utils.Rows , 1 )
126-
127- finalErr := c .driver .Query ().Do (
128- params .Ctx ,
129- func (ctx context.Context , session ydb_sdk_query.Session ) (err error ) {
130- // modify query with args
131- queryRewritten , err := c .rewriteQuery (params )
132- if err != nil {
133- return fmt .Errorf ("rewrite query: %w" , err )
124+ // prepare parameter list
125+ paramsBuilder := ydb_sdk .ParamsBuilder ()
126+
127+ for i , arg := range params .QueryArgs .Values () {
128+ placeholder := c .formatter .GetPlaceholder (i )
129+
130+ switch t := arg .(type ) {
131+ case bool :
132+ paramsBuilder = paramsBuilder .Param (placeholder ).Bool (t )
133+ case * bool :
134+ paramsBuilder = paramsBuilder .Param (placeholder ).BeginOptional ().Bool (t ).EndOptional ()
135+ case int8 :
136+ paramsBuilder = paramsBuilder .Param (placeholder ).Int8 (t )
137+ case * int8 :
138+ paramsBuilder = paramsBuilder .Param (placeholder ).BeginOptional ().Int8 (t ).EndOptional ()
139+ case int16 :
140+ paramsBuilder = paramsBuilder .Param (placeholder ).Int16 (t )
141+ case * int16 :
142+ paramsBuilder = paramsBuilder .Param (placeholder ).BeginOptional ().Int16 (t ).EndOptional ()
143+ case int32 :
144+ paramsBuilder = paramsBuilder .Param (placeholder ).Int32 (t )
145+ case * int32 :
146+ paramsBuilder = paramsBuilder .Param (placeholder ).BeginOptional ().Int32 (t ).EndOptional ()
147+ case int64 :
148+ paramsBuilder = paramsBuilder .Param (placeholder ).Int64 (t )
149+ case * int64 :
150+ paramsBuilder = paramsBuilder .Param (placeholder ).BeginOptional ().Int64 (t ).EndOptional ()
151+ case uint8 :
152+ paramsBuilder = paramsBuilder .Param (placeholder ).Uint8 (t )
153+ case * uint8 :
154+ paramsBuilder = paramsBuilder .Param (placeholder ).BeginOptional ().Uint8 (t ).EndOptional ()
155+ case uint16 :
156+ paramsBuilder = paramsBuilder .Param (placeholder ).Uint16 (t )
157+ case * uint16 :
158+ paramsBuilder = paramsBuilder .Param (placeholder ).BeginOptional ().Uint16 (t ).EndOptional ()
159+ case uint32 :
160+ paramsBuilder = paramsBuilder .Param (placeholder ).Uint32 (t )
161+ case * uint32 :
162+ paramsBuilder = paramsBuilder .Param (placeholder ).BeginOptional ().Uint32 (t ).EndOptional ()
163+ case uint64 :
164+ paramsBuilder = paramsBuilder .Param (placeholder ).Uint64 (t )
165+ case * uint64 :
166+ paramsBuilder = paramsBuilder .Param (placeholder ).BeginOptional ().Uint64 (t ).EndOptional ()
167+ case float32 :
168+ paramsBuilder = paramsBuilder .Param (placeholder ).Float (t )
169+ case * float32 :
170+ paramsBuilder = paramsBuilder .Param (placeholder ).BeginOptional ().Float (t ).EndOptional ()
171+ case float64 :
172+ paramsBuilder = paramsBuilder .Param (placeholder ).Double (t )
173+ case * float64 :
174+ paramsBuilder = paramsBuilder .Param (placeholder ).BeginOptional ().Double (t ).EndOptional ()
175+ case string :
176+ paramsBuilder = paramsBuilder .Param (placeholder ).Text (t )
177+ case * string :
178+ paramsBuilder = paramsBuilder .Param (placeholder ).BeginOptional ().Text (t ).EndOptional ()
179+ case []byte :
180+ paramsBuilder = paramsBuilder .Param (placeholder ).Bytes (t )
181+ case * []byte :
182+ paramsBuilder = paramsBuilder .Param (placeholder ).BeginOptional ().Bytes (t ).EndOptional ()
183+ case time.Time :
184+ switch params .QueryArgs .Get (i ).YdbType .GetTypeId () {
185+ case Ydb .Type_TIMESTAMP :
186+ paramsBuilder = paramsBuilder .Param (placeholder ).Timestamp (t )
187+ default :
188+ return nil , fmt .Errorf ("unsupported type: %v (%T): %w" , arg , arg , common .ErrUnimplementedPredicateType )
134189 }
190+ case * time.Time :
191+ switch params .QueryArgs .Get (i ).YdbType .GetOptionalType ().GetItem ().GetTypeId () {
192+ case Ydb .Type_TIMESTAMP :
193+ paramsBuilder = paramsBuilder .Param (placeholder ).BeginOptional ().Timestamp (t ).EndOptional ()
194+ default :
195+ return nil , fmt .Errorf ("unsupported type: %v (%T): %w" , arg , arg , common .ErrUnimplementedPredicateType )
196+ }
197+ default :
198+ return nil , fmt .Errorf ("unsupported type: %v (%T): %w" , arg , arg , common .ErrUnimplementedPredicateType )
199+ }
200+ }
135201
136- // prepare parameter list
137- paramsBuilder := ydb_sdk .ParamsBuilder ()
138-
139- for i , arg := range params .QueryArgs .Values () {
140- placeholder := c .formatter .GetPlaceholder (i )
141-
142- switch t := arg .(type ) {
143- case bool :
144- paramsBuilder = paramsBuilder .Param (placeholder ).Bool (t )
145- case * bool :
146- paramsBuilder = paramsBuilder .Param (placeholder ).BeginOptional ().Bool (t ).EndOptional ()
147- case int8 :
148- paramsBuilder = paramsBuilder .Param (placeholder ).Int8 (t )
149- case * int8 :
150- paramsBuilder = paramsBuilder .Param (placeholder ).BeginOptional ().Int8 (t ).EndOptional ()
151- case int16 :
152- paramsBuilder = paramsBuilder .Param (placeholder ).Int16 (t )
153- case * int16 :
154- paramsBuilder = paramsBuilder .Param (placeholder ).BeginOptional ().Int16 (t ).EndOptional ()
155- case int32 :
156- paramsBuilder = paramsBuilder .Param (placeholder ).Int32 (t )
157- case * int32 :
158- paramsBuilder = paramsBuilder .Param (placeholder ).BeginOptional ().Int32 (t ).EndOptional ()
159- case int64 :
160- paramsBuilder = paramsBuilder .Param (placeholder ).Int64 (t )
161- case * int64 :
162- paramsBuilder = paramsBuilder .Param (placeholder ).BeginOptional ().Int64 (t ).EndOptional ()
163- case uint8 :
164- paramsBuilder = paramsBuilder .Param (placeholder ).Uint8 (t )
165- case * uint8 :
166- paramsBuilder = paramsBuilder .Param (placeholder ).BeginOptional ().Uint8 (t ).EndOptional ()
167- case uint16 :
168- paramsBuilder = paramsBuilder .Param (placeholder ).Uint16 (t )
169- case * uint16 :
170- paramsBuilder = paramsBuilder .Param (placeholder ).BeginOptional ().Uint16 (t ).EndOptional ()
171- case uint32 :
172- paramsBuilder = paramsBuilder .Param (placeholder ).Uint32 (t )
173- case * uint32 :
174- paramsBuilder = paramsBuilder .Param (placeholder ).BeginOptional ().Uint32 (t ).EndOptional ()
175- case uint64 :
176- paramsBuilder = paramsBuilder .Param (placeholder ).Uint64 (t )
177- case * uint64 :
178- paramsBuilder = paramsBuilder .Param (placeholder ).BeginOptional ().Uint64 (t ).EndOptional ()
179- case float32 :
180- paramsBuilder = paramsBuilder .Param (placeholder ).Float (t )
181- case * float32 :
182- paramsBuilder = paramsBuilder .Param (placeholder ).BeginOptional ().Float (t ).EndOptional ()
183- case float64 :
184- paramsBuilder = paramsBuilder .Param (placeholder ).Double (t )
185- case * float64 :
186- paramsBuilder = paramsBuilder .Param (placeholder ).BeginOptional ().Double (t ).EndOptional ()
187- case string :
188- paramsBuilder = paramsBuilder .Param (placeholder ).Text (t )
189- case * string :
190- paramsBuilder = paramsBuilder .Param (placeholder ).BeginOptional ().Text (t ).EndOptional ()
191- case []byte :
192- paramsBuilder = paramsBuilder .Param (placeholder ).Bytes (t )
193- case * []byte :
194- paramsBuilder = paramsBuilder .Param (placeholder ).BeginOptional ().Bytes (t ).EndOptional ()
195- case time.Time :
196- switch params .QueryArgs .Get (i ).YdbType .GetTypeId () {
197- case Ydb .Type_TIMESTAMP :
198- paramsBuilder = paramsBuilder .Param (placeholder ).Timestamp (t )
199- default :
200- return fmt .Errorf ("unsupported type: %v (%T): %w" , arg , arg , common .ErrUnimplementedPredicateType )
201- }
202- case * time.Time :
203- switch params .QueryArgs .Get (i ).YdbType .GetOptionalType ().GetItem ().GetTypeId () {
204- case Ydb .Type_TIMESTAMP :
205- paramsBuilder = paramsBuilder .Param (placeholder ).BeginOptional ().Timestamp (t ).EndOptional ()
206- default :
207- return fmt .Errorf ("unsupported type: %v (%T): %w" , arg , arg , common .ErrUnimplementedPredicateType )
208- }
209- default :
210- return fmt .Errorf ("unsupported type: %v (%T): %w" , arg , arg , common .ErrUnimplementedPredicateType )
202+ type result struct {
203+ rows rdbms_utils.Rows
204+ err error
205+ }
206+
207+ // We cannot use the results of a query from outside of the SDK callback.
208+ // See https://github.com/ydb-platform/ydb-go-sdk/issues/1862 for details.
209+ resultChan := make (chan result )
210+ // context coming from the connector's clien (federated YDB)
211+ parentCtx := params .Ctx
212+
213+ go func () {
214+ finalErr := c .driver .Query ().Do (
215+ parentCtx ,
216+ func (ctx context.Context , session ydb_sdk_query.Session ) (err error ) {
217+ // modify query with args
218+ queryRewritten , err := c .rewriteQuery (params )
219+ if err != nil {
220+ return fmt .Errorf ("rewrite query: %w" , err )
211221 }
212- }
213222
214- queryLogger := c .queryLoggerFactory .Make (params .Logger , zap .String ("resource_pool" , c .resourcePool ))
215- queryLogger .Dump (queryRewritten , params .QueryArgs .Values ()... )
216-
217- // execute query
218- streamResult , err := session .Query (
219- ctx ,
220- queryRewritten ,
221- ydb_sdk_query .WithParameters (paramsBuilder .Build ()),
222- ydb_sdk_query .WithResourcePool (c .resourcePool ),
223- )
224- if err != nil {
225- return fmt .Errorf ("session query: %w" , err )
226- }
223+ queryLogger := c .queryLoggerFactory .Make (params .Logger , zap .String ("resource_pool" , c .resourcePool ))
224+ queryLogger .Dump (queryRewritten , params .QueryArgs .Values ()... )
225+
226+ // execute query
227+ streamResult , err := session .Query (
228+ ctx ,
229+ queryRewritten ,
230+ ydb_sdk_query .WithParameters (paramsBuilder .Build ()),
231+ ydb_sdk_query .WithResourcePool (c .resourcePool ),
232+ )
233+ if err != nil {
234+ return fmt .Errorf ("session query: %w" , err )
235+ }
227236
228- // obtain first result set because it's necessary
229- // to create type transformers
230- resultSet , err := streamResult .NextResultSet (ctx )
231- if err != nil {
232- if closeErr := streamResult .Close (ctx ); closeErr != nil {
233- params .Logger .Error ("close stream result" , zap .Error (closeErr ))
237+ defer func () {
238+ if closeErr := streamResult .Close (ctx ); closeErr != nil {
239+ params .Logger .Error ("close stream result" , zap .Error (closeErr ))
240+ }
241+ }()
242+
243+ // obtain first result set because it's necessary
244+ // to create type transformers
245+ resultSet , err := streamResult .NextResultSet (ctx )
246+ if err != nil {
247+ return fmt .Errorf ("next result set: %w" , err )
234248 }
235249
236- return fmt .Errorf ("next result set: %w" , err )
237- }
250+ rows := & rowsNative {
251+ ctx : parentCtx ,
252+ streamResult : streamResult ,
253+ lastResultSet : resultSet ,
254+ closeChan : make (chan struct {}),
255+ }
238256
239- rows := & rowsNative {
240- ctx : c .ctx ,
241- streamResult : streamResult ,
242- lastResultSet : resultSet ,
243- }
257+ // push iterator over GRPC stream into the outer space
258+ select {
259+ case resultChan <- result {rows : rows }:
260+ case <- ctx .Done ():
261+ return ctx .Err ()
262+ }
244263
245- select {
246- case rowsChan <- rows :
247- return nil
248- case <- ctx .Done ():
249- if closeErr := streamResult .Close (ctx ); closeErr != nil {
250- params .Logger .Error ("close stream result" , zap .Error (closeErr ))
264+ // Keep waiting until the rowsNative object is closed by a caller.
265+ // The context (and the rowsNative object) will be invalidated otherwise.
266+ select {
267+ case <- rows .closeChan :
268+ return nil
269+ case <- ctx .Done ():
270+ return ctx .Err ()
251271 }
272+ },
273+ ydb_sdk_query .WithIdempotent (),
274+ )
252275
253- return ctx .Err ()
276+ // If the error is not nil, that means that callback didn't return the result via channel,
277+ // so we need to write the error into the channel here.
278+ if finalErr != nil {
279+ select {
280+ case resultChan <- result {err : fmt .Errorf ("query do: %w" , finalErr )}:
281+ case <- parentCtx .Done ():
254282 }
255- },
256- ydb_sdk_query .WithIdempotent (),
257- )
258-
259- if finalErr != nil {
260- return nil , fmt .Errorf ("query do: %w" , finalErr )
261- }
283+ }
284+ }()
262285
263286 select {
264- case rows := <- rowsChan :
265- return rows , nil
266- case <- params .Ctx .Done ():
267- return nil , params .Ctx .Err ()
287+ case r := <- resultChan :
288+ if r .err != nil {
289+ return nil , r .err
290+ }
291+
292+ return r .rows , nil
293+ case <- parentCtx .Done ():
294+ return nil , parentCtx .Err ()
268295 }
269296}
270297
@@ -281,7 +308,10 @@ func (c *connectionNative) TableName() string {
281308}
282309
283310func (c * connectionNative ) Close () error {
284- if err := c .driver .Close (c .ctx ); err != nil {
311+ ctx , cancel := context .WithTimeout (context .Background (), 10 * time .Second )
312+ defer cancel ()
313+
314+ if err := c .driver .Close (ctx ); err != nil {
285315 return fmt .Errorf ("driver close: %w" , err )
286316 }
287317
@@ -331,7 +361,6 @@ func (c *connectionNative) Logger() *zap.Logger {
331361}
332362
333363func newConnectionNative (
334- ctx context.Context ,
335364 logger * zap.Logger ,
336365 queryLoggerFactory common.QueryLoggerFactory ,
337366 dsi * api_common.TGenericDataSourceInstance ,
@@ -341,7 +370,6 @@ func newConnectionNative(
341370 resourcePool string ,
342371) Connection {
343372 return & connectionNative {
344- ctx : ctx ,
345373 driver : driver ,
346374 logger : logger ,
347375 queryLoggerFactory : queryLoggerFactory ,
0 commit comments