Skip to content

Commit 3f6383b

Browse files
committed
Add concurrent result sets in db.Query.Query(...)
1 parent 8f235f3 commit 3f6383b

File tree

8 files changed

+170
-11
lines changed

8 files changed

+170
-11
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
* Supported `sql.Null*` from `database/sql` as query params in `toValue` func
2+
* Added `WithConcurrentResultSets` option for `db.Query().Query()`
23

34
## v3.118.0
45
* Added support for nullable `Date32`, `Datetime64`, `Timestamp64`, and `Interval64` types in the `optional` parameter builder

internal/query/client.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -404,11 +404,12 @@ func clientQuery(ctx context.Context, pool sessionPool, q string, opts ...option
404404
if err != nil {
405405
return xerrors.WithStackTrace(err)
406406
}
407+
407408
defer func() {
408409
_ = streamResult.Close(ctx)
409410
}()
410411

411-
r, err = resultToMaterializedResult(ctx, streamResult)
412+
r, err = concurrentResultToMaterializedResult(ctx, streamResult)
412413
if err != nil {
413414
return xerrors.WithStackTrace(err)
414415
}

internal/query/client_test.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -922,6 +922,24 @@ func TestClient(t *testing.T) {
922922
Status: Ydb.StatusIds_SUCCESS,
923923
ResultSetIndex: 0,
924924
ResultSet: &Ydb.ResultSet{
925+
Columns: []*Ydb.Column{
926+
{
927+
Name: "a",
928+
Type: &Ydb.Type{
929+
Type: &Ydb.Type_TypeId{
930+
TypeId: Ydb.Type_UINT64,
931+
},
932+
},
933+
},
934+
{
935+
Name: "b",
936+
Type: &Ydb.Type{
937+
Type: &Ydb.Type_TypeId{
938+
TypeId: Ydb.Type_UTF8,
939+
},
940+
},
941+
},
942+
},
925943
Rows: []*Ydb.Value{
926944
{
927945
Items: []*Ydb.Value{{

internal/query/result.go

Lines changed: 55 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"github.com/ydb-platform/ydb-go-sdk/v3/internal/query/result"
1515
"github.com/ydb-platform/ydb-go-sdk/v3/internal/stack"
1616
"github.com/ydb-platform/ydb-go-sdk/v3/internal/stats"
17+
"github.com/ydb-platform/ydb-go-sdk/v3/internal/types"
1718
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
1819
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xiter"
1920
"github.com/ydb-platform/ydb-go-sdk/v3/query"
@@ -349,6 +350,33 @@ func (r *streamResult) nextPartFunc(
349350
}
350351
}
351352

353+
func (r *streamResult) NextPart(ctx context.Context) (_ result.Part, err error) {
354+
if r.lastPart == nil {
355+
return nil, xerrors.WithStackTrace(io.EOF)
356+
}
357+
358+
select {
359+
case <-r.closer.Done():
360+
return nil, xerrors.WithStackTrace(r.closer.Err())
361+
case <-ctx.Done():
362+
return nil, xerrors.WithStackTrace(ctx.Err())
363+
default:
364+
part, err := r.nextPart(ctx)
365+
if err != nil && !xerrors.Is(err, io.EOF) {
366+
return nil, xerrors.WithStackTrace(err)
367+
}
368+
if part.GetExecStats() != nil && r.statsCallback != nil {
369+
r.statsCallback(stats.FromQueryStats(part.GetExecStats()))
370+
}
371+
defer func() {
372+
r.lastPart = part
373+
r.resultSetIndex = part.GetResultSetIndex()
374+
}()
375+
376+
return newResultPart(r.lastPart), nil
377+
}
378+
}
379+
352380
func (r *streamResult) NextResultSet(ctx context.Context) (_ result.Set, err error) {
353381
if r.trace != nil {
354382
onDone := trace.QueryOnResultNextResultSet(r.trace, &ctx,
@@ -433,11 +461,20 @@ func exactlyOneResultSetFromResult(ctx context.Context, r result.Result) (rs res
433461
return MaterializedResultSet(rs.Index(), rs.Columns(), rs.ColumnTypes(), rows), nil
434462
}
435463

436-
func resultToMaterializedResult(ctx context.Context, r result.Result) (result.Result, error) {
437-
var resultSets []result.Set
464+
func concurrentResultToMaterializedResult(ctx context.Context, r result.ConcurrentResult) (result.Result, error) {
465+
type resultSet struct {
466+
rows []query.Row
467+
columnNames []string
468+
columnTypes []types.Type
469+
}
470+
resultSetByIndex := make(map[int64]resultSet)
438471

439472
for {
440-
rs, err := r.NextResultSet(ctx)
473+
if ctx.Err() != nil {
474+
return nil, xerrors.WithStackTrace(ctx.Err())
475+
}
476+
477+
part, err := r.NextPart(ctx)
441478
if err != nil {
442479
if xerrors.Is(err, io.EOF) {
443480
break
@@ -446,21 +483,32 @@ func resultToMaterializedResult(ctx context.Context, r result.Result) (result.Re
446483
return nil, xerrors.WithStackTrace(err)
447484
}
448485

449-
var rows []query.Row
486+
rs := resultSetByIndex[part.ResultSetIndex()]
487+
if len(rs.columnNames) == 0 {
488+
rs.columnTypes = part.ColumnTypes()
489+
rs.columnNames = part.ColumnNames()
490+
}
491+
492+
rows := make([]query.Row, 0)
450493
for {
451-
row, err := rs.NextRow(ctx)
494+
row, err := part.NextRow(ctx)
452495
if err != nil {
453496
if xerrors.Is(err, io.EOF) {
454497
break
455498
}
456499

457500
return nil, xerrors.WithStackTrace(err)
458501
}
459-
460502
rows = append(rows, row)
461503
}
504+
rs.rows = append(rs.rows, rows...)
505+
506+
resultSetByIndex[part.ResultSetIndex()] = rs
507+
}
462508

463-
resultSets = append(resultSets, MaterializedResultSet(rs.Index(), rs.Columns(), rs.ColumnTypes(), rows))
509+
resultSets := make([]result.Set, len(resultSetByIndex))
510+
for rsIndex, rs := range resultSetByIndex {
511+
resultSets[rsIndex] = MaterializedResultSet(int(rsIndex), rs.columnNames, rs.columnTypes, rs.rows)
464512
}
465513

466514
return &materializedResult{

internal/query/result/result.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@ type (
2121
// with Go version 1.23+
2222
ResultSets(ctx context.Context) xiter.Seq2[Set, error]
2323
}
24+
ConcurrentResult interface {
25+
closer.Closer
26+
27+
NextPart(ctx context.Context) (Part, error)
28+
}
2429
Set interface {
2530
Index() int
2631
Columns() []string
@@ -34,6 +39,13 @@ type (
3439
Set
3540
closer.Closer
3641
}
42+
Part interface {
43+
ResultSetIndex() int64
44+
ColumnNames() []string
45+
ColumnTypes() []types.Type
46+
47+
NextRow(ctx context.Context) (Row, error)
48+
}
3749
Row interface {
3850
Values() []value.Value
3951

internal/query/result_part.go

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
package query
2+
3+
import (
4+
"context"
5+
"io"
6+
7+
"github.com/ydb-platform/ydb-go-genproto/protos/Ydb"
8+
"github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Query"
9+
10+
"github.com/ydb-platform/ydb-go-sdk/v3/internal/types"
11+
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
12+
"github.com/ydb-platform/ydb-go-sdk/v3/query"
13+
)
14+
15+
var _ query.Part = (*resultPart)(nil)
16+
17+
type (
18+
resultPart struct {
19+
resultSetIndex int64
20+
columns []*Ydb.Column
21+
rows []*Ydb.Value
22+
columnNames []string
23+
columnTypes []types.Type
24+
rowIndex int
25+
}
26+
)
27+
28+
func (p *resultPart) ResultSetIndex() int64 {
29+
return p.resultSetIndex
30+
}
31+
32+
func (p *resultPart) ColumnNames() []string {
33+
if len(p.columnNames) != 0 {
34+
return p.columnNames
35+
}
36+
names := make([]string, len(p.columns))
37+
for i, col := range p.columns {
38+
names[i] = col.GetName()
39+
}
40+
p.columnNames = names
41+
42+
return names
43+
}
44+
45+
func (p *resultPart) ColumnTypes() []types.Type {
46+
if len(p.columnTypes) != 0 {
47+
return p.columnTypes
48+
}
49+
colTypes := make([]types.Type, len(p.columns))
50+
for i, col := range p.columns {
51+
colTypes[i] = types.TypeFromYDB(col.GetType())
52+
}
53+
p.columnTypes = colTypes
54+
55+
return colTypes
56+
}
57+
58+
func (p *resultPart) NextRow(ctx context.Context) (query.Row, error) {
59+
if p.rowIndex == len(p.rows) {
60+
return nil, xerrors.WithStackTrace(io.EOF)
61+
}
62+
63+
defer func() {
64+
p.rowIndex++
65+
}()
66+
67+
return NewRow(p.columns, p.rows[p.rowIndex]), nil
68+
}
69+
70+
func newResultPart(part *Ydb_Query.ExecuteQueryResponsePart) *resultPart {
71+
return &resultPart{
72+
resultSetIndex: part.GetResultSetIndex(),
73+
columns: part.GetResultSet().GetColumns(),
74+
rows: part.GetResultSet().GetRows(),
75+
rowIndex: 0,
76+
}
77+
}

query/execute_options.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,9 @@ func WithResponsePartLimitSizeBytes(size int64) ExecuteOption {
7070
return options.WithResponsePartLimitSizeBytes(size)
7171
}
7272

73-
//func WithConcurrentResultSets(isEnabled bool) ExecuteOption {
74-
// return options.WithConcurrentResultSets(isEnabled)
75-
//}
73+
func WithConcurrentResultSets(isEnabled bool) ExecuteOption {
74+
return options.WithConcurrentResultSets(isEnabled)
75+
}
7676

7777
func WithCallOptions(opts ...grpc.CallOption) ExecuteOption {
7878
return options.WithCallOptions(opts...)

query/result.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@ import (
88

99
type (
1010
Result = result.Result
11+
ConcurrentResult = result.ConcurrentResult
1112
ResultSet = result.Set
1213
ClosableResultSet = result.ClosableResultSet
14+
Part = result.Part
1315
Row = result.Row
1416
Type = types.Type
1517
NamedDestination = scanner.NamedDestination

0 commit comments

Comments
 (0)