@@ -3,10 +3,14 @@ package scanner
33import (
44 "context"
55 "fmt"
6+ "io"
67 "reflect"
78 "testing"
9+ "time"
810
11+ "github.com/stretchr/testify/require"
912 "github.com/ydb-platform/ydb-go-genproto/protos/Ydb"
13+ "github.com/ydb-platform/ydb-go-genproto/protos/Ydb_TableStats"
1014
1115 "github.com/ydb-platform/ydb-go-sdk/v3/internal/allocator"
1216 "github.com/ydb-platform/ydb-go-sdk/v3/internal/value"
@@ -199,3 +203,75 @@ func NewResultSet(a *allocator.Allocator, opts ...ResultSetOption) *Ydb.ResultSe
199203 }
200204 return (* Ydb .ResultSet )(& d )
201205}
206+
207+ func TestNewStreamWithRecvFirstResultSet (t * testing.T ) {
208+ for _ , tt := range []struct {
209+ ctx context.Context
210+ recvCounter int
211+ err error
212+ }{
213+ {
214+ ctx : context .Background (),
215+ err : nil ,
216+ },
217+ {
218+ ctx : func () context.Context {
219+ ctx , cancel := context .WithCancel (context .Background ())
220+ cancel ()
221+ return ctx
222+ }(),
223+ err : context .Canceled ,
224+ },
225+ {
226+ ctx : func () context.Context {
227+ ctx , cancel := context .WithTimeout (context .Background (), 0 )
228+ cancel ()
229+ return ctx
230+ }(),
231+ err : context .DeadlineExceeded ,
232+ },
233+ {
234+ ctx : func () context.Context {
235+ ctx , cancel := context .WithTimeout (context .Background (), time .Hour )
236+ cancel ()
237+ return ctx
238+ }(),
239+ err : context .Canceled ,
240+ },
241+ } {
242+ t .Run ("" , func (t * testing.T ) {
243+ result , err := NewStream (tt .ctx ,
244+ func (ctx context.Context ) (* Ydb.ResultSet , * Ydb_TableStats.QueryStats , error ) {
245+ tt .recvCounter ++
246+ if tt .recvCounter > 1000 {
247+ return nil , nil , io .EOF
248+ }
249+ return & Ydb.ResultSet {}, nil , ctx .Err ()
250+ },
251+ func (err error ) error {
252+ return err
253+ },
254+ )
255+ if tt .err != nil {
256+ require .ErrorIs (t , err , tt .err )
257+ require .Nil (t , result )
258+ } else {
259+ require .NoError (t , err )
260+ require .NotNil (t , result )
261+ require .EqualValues (t , 1 , tt .recvCounter )
262+ require .EqualValues (t , 1 , result .(* streamResult ).nextResultSetCounter .Load ())
263+ for i := range make ([]struct {}, 1000 ) {
264+ err = result .NextResultSetErr (tt .ctx )
265+ require .NoError (t , err )
266+ require .Equal (t , i + 1 , tt .recvCounter )
267+ require .Equal (t , i + 2 , int (result .(* streamResult ).nextResultSetCounter .Load ()))
268+ }
269+ err = result .NextResultSetErr (tt .ctx )
270+ require .ErrorIs (t , err , io .EOF )
271+ require .True (t , err == io .EOF ) //nolint:errorlint
272+ require .Equal (t , 1001 , tt .recvCounter )
273+ require .Equal (t , 1002 , int (result .(* streamResult ).nextResultSetCounter .Load ()))
274+ }
275+ })
276+ }
277+ }
0 commit comments