Skip to content

Commit 4db8483

Browse files
committed
Add concurrent result sets in db.Query.Query(...)
1 parent 18cd427 commit 4db8483

File tree

4 files changed

+74
-8
lines changed

4 files changed

+74
-8
lines changed

internal/query/client.go

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -397,15 +397,23 @@ func clientQuery(ctx context.Context, pool sessionPool, q string, opts ...option
397397
) {
398398
settings := options.ExecuteSettings(opts...)
399399
err = do(ctx, pool, func(ctx context.Context, s *Session) (err error) {
400-
streamResult, err := s.execute(ctx, q, options.ExecuteSettings(opts...), withStreamResultTrace(s.trace))
400+
request, callOptions, err := executeQueryRequest(s.ID(), q, settings)
401401
if err != nil {
402402
return xerrors.WithStackTrace(err)
403403
}
404+
request.ConcurrentResultSets = settings.ConcurrentResultSets()
405+
406+
executeCtx, executeCancel := xcontext.WithCancel(xcontext.ValueOnly(ctx))
404407
defer func() {
405-
_ = streamResult.Close(ctx)
408+
executeCancel()
406409
}()
407410

408-
r, err = resultToMaterializedResult(ctx, streamResult)
411+
stream, err := s.client.ExecuteQuery(executeCtx, request, callOptions...)
412+
if err != nil {
413+
return xerrors.WithStackTrace(err)
414+
}
415+
416+
r, err = streamToMaterializedResult(ctx, stream)
409417
if err != nil {
410418
return xerrors.WithStackTrace(err)
411419
}

internal/query/execute_query.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ func executeQueryRequest(sessionID, q string, cfg executeSettings) (
9393
},
9494
Parameters: params,
9595
StatsMode: Ydb_Query.StatsMode(cfg.StatsMode()),
96-
ConcurrentResultSets: cfg.ConcurrentResultSets(),
96+
ConcurrentResultSets: false,
9797
PoolId: cfg.ResourcePool(),
9898
ResponsePartLimitBytes: cfg.ResponsePartLimitSizeBytes(),
9999
}

internal/query/result.go

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@ import (
88
"time"
99

1010
"github.com/ydb-platform/ydb-go-genproto/Ydb_Query_V1"
11+
"github.com/ydb-platform/ydb-go-genproto/protos/Ydb"
1112
"github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Query"
12-
1313
"github.com/ydb-platform/ydb-go-sdk/v3/internal/query/result"
1414
"github.com/ydb-platform/ydb-go-sdk/v3/internal/stack"
1515
"github.com/ydb-platform/ydb-go-sdk/v3/internal/stats"
16+
"github.com/ydb-platform/ydb-go-sdk/v3/internal/types"
1617
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
1718
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xiter"
1819
"github.com/ydb-platform/ydb-go-sdk/v3/query"
@@ -459,3 +460,60 @@ func resultToMaterializedResult(ctx context.Context, r result.Result) (result.Re
459460
resultSets: resultSets,
460461
}, nil
461462
}
463+
464+
func streamToMaterializedResult(ctx context.Context, stream Ydb_Query_V1.QueryService_ExecuteQueryClient) (result.Result, error) {
465+
type resultSet struct {
466+
rows []query.Row
467+
rawColumns []*Ydb.Column
468+
columnNames []string
469+
columnTypes []types.Type
470+
}
471+
resultSetByIndex := make(map[int64]resultSet)
472+
473+
for {
474+
if ctx.Err() != nil {
475+
return nil, xerrors.WithStackTrace(ctx.Err())
476+
}
477+
478+
part, err := stream.Recv()
479+
if err != nil {
480+
if xerrors.Is(err, io.EOF) {
481+
break
482+
}
483+
return nil, xerrors.WithStackTrace(err)
484+
}
485+
486+
if part.GetResultSetIndex() < 0 {
487+
break
488+
}
489+
490+
rs := resultSetByIndex[part.GetResultSetIndex()]
491+
if len(rs.columnNames) == 0 {
492+
rs.rawColumns = part.GetResultSet().GetColumns()
493+
rs.columnNames = make([]string, 0, len(rs.rawColumns))
494+
rs.columnTypes = make([]types.Type, 0, len(rs.rawColumns))
495+
496+
for _, column := range rs.rawColumns {
497+
rs.columnNames = append(rs.columnNames, column.GetName())
498+
rs.columnTypes = append(rs.columnTypes, types.TypeFromYDB(column.GetType()))
499+
}
500+
}
501+
502+
rows := make([]query.Row, 0, len(part.GetResultSet().GetRows()))
503+
for _, row := range part.GetResultSet().GetRows() {
504+
rows = append(rows, NewRow(rs.rawColumns, row))
505+
}
506+
rs.rows = append(rs.rows, rows...)
507+
508+
resultSetByIndex[part.GetResultSetIndex()] = rs
509+
}
510+
511+
resultSets := make([]result.Set, len(resultSetByIndex))
512+
for rsIndex, rs := range resultSetByIndex {
513+
resultSets[rsIndex] = MaterializedResultSet(int(rsIndex), rs.columnNames, rs.columnTypes, rs.rows)
514+
}
515+
516+
return &materializedResult{
517+
resultSets: resultSets,
518+
}, nil
519+
}

query/execute_options.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,9 @@ func WithResponsePartLimitSizeBytes(size int64) ExecuteOption {
6363
return options.WithResponsePartLimitSizeBytes(size)
6464
}
6565

66-
//func WithConcurrentResultSets(isEnabled bool) ExecuteOption {
67-
// return options.WithConcurrentResultSets(isEnabled)
68-
//}
66+
func WithConcurrentResultSets(isEnabled bool) ExecuteOption {
67+
return options.WithConcurrentResultSets(isEnabled)
68+
}
6969

7070
func WithCallOptions(opts ...grpc.CallOption) ExecuteOption {
7171
return options.WithCallOptions(opts...)

0 commit comments

Comments
 (0)