Skip to content

Commit 85deba3

Browse files
authored
YDB: more proper work with contexts and SDK objects lifetime (#333)
1 parent cf4dcda commit 85deba3

File tree

5 files changed

+169
-140
lines changed

5 files changed

+169
-140
lines changed

app/server/datasource/rdbms/ydb/connection_manager.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,6 @@ func (c *connectionManager) Make(
103103

104104
formatter := NewSQLFormatter(config.TYdbConfig_MODE_QUERY_SERVICE_NATIVE, c.cfg.Pushdown)
105105
ydbConn = newConnectionNative(
106-
ctx,
107106
logger,
108107
c.QueryLoggerFactory,
109108
dsi,

app/server/datasource/rdbms/ydb/connection_native.go

Lines changed: 165 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ type rowsNative struct {
3030
streamResult ydb_sdk_query.Result
3131
lastResultSet ydb_sdk_query.ResultSet
3232
lastRow ydb_sdk_query.Row
33+
34+
closeChan chan struct{}
3335
}
3436

3537
func (r *rowsNative) Next() bool {
@@ -100,9 +102,7 @@ func (r *rowsNative) Err() error {
100102
}
101103

102104
func (r *rowsNative) Close() error {
103-
if err := r.streamResult.Close(r.ctx); err != nil {
104-
return fmt.Errorf("stream result close: %w", err)
105-
}
105+
close(r.closeChan)
106106

107107
return nil
108108
}
@@ -113,158 +113,185 @@ type connectionNative struct {
113113
dsi *api_common.TGenericDataSourceInstance
114114
logger *zap.Logger
115115
queryLoggerFactory common.QueryLoggerFactory
116-
ctx context.Context
117116
driver *ydb_sdk.Driver
118117
tableName string
119118
formatter rdbms_utils.SQLFormatter
120119
resourcePool string
121120
}
122121

123-
// nolint: gocyclo
122+
// nolint: gocyclo,funlen
124123
func (c *connectionNative) Query(params *rdbms_utils.QueryParams) (rdbms_utils.Rows, error) {
125-
rowsChan := make(chan rdbms_utils.Rows, 1)
126-
127-
finalErr := c.driver.Query().Do(
128-
params.Ctx,
129-
func(ctx context.Context, session ydb_sdk_query.Session) (err error) {
130-
// modify query with args
131-
queryRewritten, err := c.rewriteQuery(params)
132-
if err != nil {
133-
return fmt.Errorf("rewrite query: %w", err)
124+
// prepare parameter list
125+
paramsBuilder := ydb_sdk.ParamsBuilder()
126+
127+
for i, arg := range params.QueryArgs.Values() {
128+
placeholder := c.formatter.GetPlaceholder(i)
129+
130+
switch t := arg.(type) {
131+
case bool:
132+
paramsBuilder = paramsBuilder.Param(placeholder).Bool(t)
133+
case *bool:
134+
paramsBuilder = paramsBuilder.Param(placeholder).BeginOptional().Bool(t).EndOptional()
135+
case int8:
136+
paramsBuilder = paramsBuilder.Param(placeholder).Int8(t)
137+
case *int8:
138+
paramsBuilder = paramsBuilder.Param(placeholder).BeginOptional().Int8(t).EndOptional()
139+
case int16:
140+
paramsBuilder = paramsBuilder.Param(placeholder).Int16(t)
141+
case *int16:
142+
paramsBuilder = paramsBuilder.Param(placeholder).BeginOptional().Int16(t).EndOptional()
143+
case int32:
144+
paramsBuilder = paramsBuilder.Param(placeholder).Int32(t)
145+
case *int32:
146+
paramsBuilder = paramsBuilder.Param(placeholder).BeginOptional().Int32(t).EndOptional()
147+
case int64:
148+
paramsBuilder = paramsBuilder.Param(placeholder).Int64(t)
149+
case *int64:
150+
paramsBuilder = paramsBuilder.Param(placeholder).BeginOptional().Int64(t).EndOptional()
151+
case uint8:
152+
paramsBuilder = paramsBuilder.Param(placeholder).Uint8(t)
153+
case *uint8:
154+
paramsBuilder = paramsBuilder.Param(placeholder).BeginOptional().Uint8(t).EndOptional()
155+
case uint16:
156+
paramsBuilder = paramsBuilder.Param(placeholder).Uint16(t)
157+
case *uint16:
158+
paramsBuilder = paramsBuilder.Param(placeholder).BeginOptional().Uint16(t).EndOptional()
159+
case uint32:
160+
paramsBuilder = paramsBuilder.Param(placeholder).Uint32(t)
161+
case *uint32:
162+
paramsBuilder = paramsBuilder.Param(placeholder).BeginOptional().Uint32(t).EndOptional()
163+
case uint64:
164+
paramsBuilder = paramsBuilder.Param(placeholder).Uint64(t)
165+
case *uint64:
166+
paramsBuilder = paramsBuilder.Param(placeholder).BeginOptional().Uint64(t).EndOptional()
167+
case float32:
168+
paramsBuilder = paramsBuilder.Param(placeholder).Float(t)
169+
case *float32:
170+
paramsBuilder = paramsBuilder.Param(placeholder).BeginOptional().Float(t).EndOptional()
171+
case float64:
172+
paramsBuilder = paramsBuilder.Param(placeholder).Double(t)
173+
case *float64:
174+
paramsBuilder = paramsBuilder.Param(placeholder).BeginOptional().Double(t).EndOptional()
175+
case string:
176+
paramsBuilder = paramsBuilder.Param(placeholder).Text(t)
177+
case *string:
178+
paramsBuilder = paramsBuilder.Param(placeholder).BeginOptional().Text(t).EndOptional()
179+
case []byte:
180+
paramsBuilder = paramsBuilder.Param(placeholder).Bytes(t)
181+
case *[]byte:
182+
paramsBuilder = paramsBuilder.Param(placeholder).BeginOptional().Bytes(t).EndOptional()
183+
case time.Time:
184+
switch params.QueryArgs.Get(i).YdbType.GetTypeId() {
185+
case Ydb.Type_TIMESTAMP:
186+
paramsBuilder = paramsBuilder.Param(placeholder).Timestamp(t)
187+
default:
188+
return nil, fmt.Errorf("unsupported type: %v (%T): %w", arg, arg, common.ErrUnimplementedPredicateType)
134189
}
190+
case *time.Time:
191+
switch params.QueryArgs.Get(i).YdbType.GetOptionalType().GetItem().GetTypeId() {
192+
case Ydb.Type_TIMESTAMP:
193+
paramsBuilder = paramsBuilder.Param(placeholder).BeginOptional().Timestamp(t).EndOptional()
194+
default:
195+
return nil, fmt.Errorf("unsupported type: %v (%T): %w", arg, arg, common.ErrUnimplementedPredicateType)
196+
}
197+
default:
198+
return nil, fmt.Errorf("unsupported type: %v (%T): %w", arg, arg, common.ErrUnimplementedPredicateType)
199+
}
200+
}
135201

136-
// prepare parameter list
137-
paramsBuilder := ydb_sdk.ParamsBuilder()
138-
139-
for i, arg := range params.QueryArgs.Values() {
140-
placeholder := c.formatter.GetPlaceholder(i)
141-
142-
switch t := arg.(type) {
143-
case bool:
144-
paramsBuilder = paramsBuilder.Param(placeholder).Bool(t)
145-
case *bool:
146-
paramsBuilder = paramsBuilder.Param(placeholder).BeginOptional().Bool(t).EndOptional()
147-
case int8:
148-
paramsBuilder = paramsBuilder.Param(placeholder).Int8(t)
149-
case *int8:
150-
paramsBuilder = paramsBuilder.Param(placeholder).BeginOptional().Int8(t).EndOptional()
151-
case int16:
152-
paramsBuilder = paramsBuilder.Param(placeholder).Int16(t)
153-
case *int16:
154-
paramsBuilder = paramsBuilder.Param(placeholder).BeginOptional().Int16(t).EndOptional()
155-
case int32:
156-
paramsBuilder = paramsBuilder.Param(placeholder).Int32(t)
157-
case *int32:
158-
paramsBuilder = paramsBuilder.Param(placeholder).BeginOptional().Int32(t).EndOptional()
159-
case int64:
160-
paramsBuilder = paramsBuilder.Param(placeholder).Int64(t)
161-
case *int64:
162-
paramsBuilder = paramsBuilder.Param(placeholder).BeginOptional().Int64(t).EndOptional()
163-
case uint8:
164-
paramsBuilder = paramsBuilder.Param(placeholder).Uint8(t)
165-
case *uint8:
166-
paramsBuilder = paramsBuilder.Param(placeholder).BeginOptional().Uint8(t).EndOptional()
167-
case uint16:
168-
paramsBuilder = paramsBuilder.Param(placeholder).Uint16(t)
169-
case *uint16:
170-
paramsBuilder = paramsBuilder.Param(placeholder).BeginOptional().Uint16(t).EndOptional()
171-
case uint32:
172-
paramsBuilder = paramsBuilder.Param(placeholder).Uint32(t)
173-
case *uint32:
174-
paramsBuilder = paramsBuilder.Param(placeholder).BeginOptional().Uint32(t).EndOptional()
175-
case uint64:
176-
paramsBuilder = paramsBuilder.Param(placeholder).Uint64(t)
177-
case *uint64:
178-
paramsBuilder = paramsBuilder.Param(placeholder).BeginOptional().Uint64(t).EndOptional()
179-
case float32:
180-
paramsBuilder = paramsBuilder.Param(placeholder).Float(t)
181-
case *float32:
182-
paramsBuilder = paramsBuilder.Param(placeholder).BeginOptional().Float(t).EndOptional()
183-
case float64:
184-
paramsBuilder = paramsBuilder.Param(placeholder).Double(t)
185-
case *float64:
186-
paramsBuilder = paramsBuilder.Param(placeholder).BeginOptional().Double(t).EndOptional()
187-
case string:
188-
paramsBuilder = paramsBuilder.Param(placeholder).Text(t)
189-
case *string:
190-
paramsBuilder = paramsBuilder.Param(placeholder).BeginOptional().Text(t).EndOptional()
191-
case []byte:
192-
paramsBuilder = paramsBuilder.Param(placeholder).Bytes(t)
193-
case *[]byte:
194-
paramsBuilder = paramsBuilder.Param(placeholder).BeginOptional().Bytes(t).EndOptional()
195-
case time.Time:
196-
switch params.QueryArgs.Get(i).YdbType.GetTypeId() {
197-
case Ydb.Type_TIMESTAMP:
198-
paramsBuilder = paramsBuilder.Param(placeholder).Timestamp(t)
199-
default:
200-
return fmt.Errorf("unsupported type: %v (%T): %w", arg, arg, common.ErrUnimplementedPredicateType)
201-
}
202-
case *time.Time:
203-
switch params.QueryArgs.Get(i).YdbType.GetOptionalType().GetItem().GetTypeId() {
204-
case Ydb.Type_TIMESTAMP:
205-
paramsBuilder = paramsBuilder.Param(placeholder).BeginOptional().Timestamp(t).EndOptional()
206-
default:
207-
return fmt.Errorf("unsupported type: %v (%T): %w", arg, arg, common.ErrUnimplementedPredicateType)
208-
}
209-
default:
210-
return fmt.Errorf("unsupported type: %v (%T): %w", arg, arg, common.ErrUnimplementedPredicateType)
202+
type result struct {
203+
rows rdbms_utils.Rows
204+
err error
205+
}
206+
207+
// We cannot use the results of a query from outside of the SDK callback.
208+
// See https://github.com/ydb-platform/ydb-go-sdk/issues/1862 for details.
209+
resultChan := make(chan result)
210+
// context coming from the connector's clien (federated YDB)
211+
parentCtx := params.Ctx
212+
213+
go func() {
214+
finalErr := c.driver.Query().Do(
215+
parentCtx,
216+
func(ctx context.Context, session ydb_sdk_query.Session) (err error) {
217+
// modify query with args
218+
queryRewritten, err := c.rewriteQuery(params)
219+
if err != nil {
220+
return fmt.Errorf("rewrite query: %w", err)
211221
}
212-
}
213222

214-
queryLogger := c.queryLoggerFactory.Make(params.Logger, zap.String("resource_pool", c.resourcePool))
215-
queryLogger.Dump(queryRewritten, params.QueryArgs.Values()...)
216-
217-
// execute query
218-
streamResult, err := session.Query(
219-
ctx,
220-
queryRewritten,
221-
ydb_sdk_query.WithParameters(paramsBuilder.Build()),
222-
ydb_sdk_query.WithResourcePool(c.resourcePool),
223-
)
224-
if err != nil {
225-
return fmt.Errorf("session query: %w", err)
226-
}
223+
queryLogger := c.queryLoggerFactory.Make(params.Logger, zap.String("resource_pool", c.resourcePool))
224+
queryLogger.Dump(queryRewritten, params.QueryArgs.Values()...)
225+
226+
// execute query
227+
streamResult, err := session.Query(
228+
ctx,
229+
queryRewritten,
230+
ydb_sdk_query.WithParameters(paramsBuilder.Build()),
231+
ydb_sdk_query.WithResourcePool(c.resourcePool),
232+
)
233+
if err != nil {
234+
return fmt.Errorf("session query: %w", err)
235+
}
227236

228-
// obtain first result set because it's necessary
229-
// to create type transformers
230-
resultSet, err := streamResult.NextResultSet(ctx)
231-
if err != nil {
232-
if closeErr := streamResult.Close(ctx); closeErr != nil {
233-
params.Logger.Error("close stream result", zap.Error(closeErr))
237+
defer func() {
238+
if closeErr := streamResult.Close(ctx); closeErr != nil {
239+
params.Logger.Error("close stream result", zap.Error(closeErr))
240+
}
241+
}()
242+
243+
// obtain first result set because it's necessary
244+
// to create type transformers
245+
resultSet, err := streamResult.NextResultSet(ctx)
246+
if err != nil {
247+
return fmt.Errorf("next result set: %w", err)
234248
}
235249

236-
return fmt.Errorf("next result set: %w", err)
237-
}
250+
rows := &rowsNative{
251+
ctx: parentCtx,
252+
streamResult: streamResult,
253+
lastResultSet: resultSet,
254+
closeChan: make(chan struct{}),
255+
}
238256

239-
rows := &rowsNative{
240-
ctx: c.ctx,
241-
streamResult: streamResult,
242-
lastResultSet: resultSet,
243-
}
257+
// push iterator over GRPC stream into the outer space
258+
select {
259+
case resultChan <- result{rows: rows}:
260+
case <-ctx.Done():
261+
return ctx.Err()
262+
}
244263

245-
select {
246-
case rowsChan <- rows:
247-
return nil
248-
case <-ctx.Done():
249-
if closeErr := streamResult.Close(ctx); closeErr != nil {
250-
params.Logger.Error("close stream result", zap.Error(closeErr))
264+
// Keep waiting until the rowsNative object is closed by a caller.
265+
// The context (and the rowsNative object) will be invalidated otherwise.
266+
select {
267+
case <-rows.closeChan:
268+
return nil
269+
case <-ctx.Done():
270+
return ctx.Err()
251271
}
272+
},
273+
ydb_sdk_query.WithIdempotent(),
274+
)
252275

253-
return ctx.Err()
276+
// If the error is not nil, that means that callback didn't return the result via channel,
277+
// so we need to write the error into the channel here.
278+
if finalErr != nil {
279+
select {
280+
case resultChan <- result{err: fmt.Errorf("query do: %w", finalErr)}:
281+
case <-parentCtx.Done():
254282
}
255-
},
256-
ydb_sdk_query.WithIdempotent(),
257-
)
258-
259-
if finalErr != nil {
260-
return nil, fmt.Errorf("query do: %w", finalErr)
261-
}
283+
}
284+
}()
262285

263286
select {
264-
case rows := <-rowsChan:
265-
return rows, nil
266-
case <-params.Ctx.Done():
267-
return nil, params.Ctx.Err()
287+
case r := <-resultChan:
288+
if r.err != nil {
289+
return nil, r.err
290+
}
291+
292+
return r.rows, nil
293+
case <-parentCtx.Done():
294+
return nil, parentCtx.Err()
268295
}
269296
}
270297

@@ -281,7 +308,10 @@ func (c *connectionNative) TableName() string {
281308
}
282309

283310
func (c *connectionNative) Close() error {
284-
if err := c.driver.Close(c.ctx); err != nil {
311+
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
312+
defer cancel()
313+
314+
if err := c.driver.Close(ctx); err != nil {
285315
return fmt.Errorf("driver close: %w", err)
286316
}
287317

@@ -331,7 +361,6 @@ func (c *connectionNative) Logger() *zap.Logger {
331361
}
332362

333363
func newConnectionNative(
334-
ctx context.Context,
335364
logger *zap.Logger,
336365
queryLoggerFactory common.QueryLoggerFactory,
337366
dsi *api_common.TGenericDataSourceInstance,
@@ -341,7 +370,6 @@ func newConnectionNative(
341370
resourcePool string,
342371
) Connection {
343372
return &connectionNative{
344-
ctx: ctx,
345373
driver: driver,
346374
logger: logger,
347375
queryLoggerFactory: queryLoggerFactory,

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ require (
3535
github.com/stretchr/testify v1.10.0
3636
github.com/ydb-platform/ydb-go-genproto v0.0.0-20250911135631-b3beddd517d9
3737
// never update to version v3.113.1 or higher: this will break reading from YDB
38-
github.com/ydb-platform/ydb-go-sdk/v3 v3.108.0
38+
github.com/ydb-platform/ydb-go-sdk/v3 v3.113.1
3939
github.com/ydb-platform/ydb-go-yc v0.11.0
4040
go.mongodb.org/mongo-driver v1.17.1
4141
go.uber.org/atomic v1.11.0

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,8 @@ github.com/ydb-platform/ydb-go-genproto v0.0.0-20250911135631-b3beddd517d9/go.mo
349349
github.com/ydb-platform/ydb-go-sdk/v3 v3.25.3/go.mod h1:PFizF/vJsdAgEwjK3DVSBD52kdmRkWfSIS2q2pA+e88=
350350
github.com/ydb-platform/ydb-go-sdk/v3 v3.108.0 h1:TwWSp3gRMcja/hRpOofncLvgxAXCmzpz5cGtmdaoITw=
351351
github.com/ydb-platform/ydb-go-sdk/v3 v3.108.0/go.mod h1:l5sSv153E18VvYcsmr51hok9Sjc16tEC8AXGbwrk+ho=
352+
github.com/ydb-platform/ydb-go-sdk/v3 v3.113.1 h1:VRRUtl0JlovbiZOEwqpreVYJNixY7IdgGvEkXRO2mK0=
353+
github.com/ydb-platform/ydb-go-sdk/v3 v3.113.1/go.mod h1:Pp1w2xxUoLQ3NCNAwV7pvDq0TVQOdtAqs+ZiC+i8r14=
352354
github.com/ydb-platform/ydb-go-yc v0.11.0 h1:DwrjZ+yCUqWhhCQOHKk4HnIt1CiWKgVYXKMiDNi5QUY=
353355
github.com/ydb-platform/ydb-go-yc v0.11.0/go.mod h1:uZ5l31+K3rnIeJAi6pzSkEQYT83Ozgxvr3UY/AV1L4w=
354356
github.com/ydb-platform/ydb-go-yc-metadata v0.5.2/go.mod h1:82SQ4L3PewiEmFW4oTMc1sfPjODasIYxD/SKGsbK74s=

tools/ydb/dump_tablet_id_data/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ func executeQuery(parentCtx context.Context, logger *zap.Logger, ydbDriver *ydb.
207207
}
208208

209209
// Get the first result set to initialize the iterator
210-
resultSet, err := result.NextResultSet(ctx)
210+
resultSet, err := result.NextResultSet(parentCtx)
211211
if err != nil && !errors.Is(err, io.EOF) {
212212
if closeErr := result.Close(ctx); closeErr != nil {
213213
logger.Error("close stream result", zap.Error(closeErr))

0 commit comments

Comments
 (0)