@@ -22,17 +22,21 @@ package worker
2222import (
2323 "context"
2424 "errors"
25+ "sync"
2526 "testing"
2627 "time"
2728
2829 "github.com/aws/aws-sdk-go-v2/aws"
2930 "github.com/aws/aws-sdk-go-v2/service/kinesis"
31+ "github.com/aws/aws-sdk-go-v2/service/kinesis/types"
3032 "github.com/stretchr/testify/assert"
3133 "github.com/stretchr/testify/mock"
3234 chk "github.com/vmware/vmware-go-kcl-v2/clientlibrary/checkpoint"
3335 "github.com/vmware/vmware-go-kcl-v2/clientlibrary/config"
36+ kcl "github.com/vmware/vmware-go-kcl-v2/clientlibrary/interfaces"
3437 "github.com/vmware/vmware-go-kcl-v2/clientlibrary/metrics"
3538 par "github.com/vmware/vmware-go-kcl-v2/clientlibrary/partition"
39+ "github.com/vmware/vmware-go-kcl-v2/logger"
3640)
3741
3842var (
@@ -199,7 +203,8 @@ func (m *MockKinesisSubscriberGetter) GetRecords(ctx context.Context, params *ki
199203}
200204
201205func (m * MockKinesisSubscriberGetter ) GetShardIterator (ctx context.Context , params * kinesis.GetShardIteratorInput , optFns ... func (* kinesis.Options )) (* kinesis.GetShardIteratorOutput , error ) {
202- return nil , nil
206+ ret := m .Called (ctx , params , optFns )
207+ return ret .Get (0 ).(* kinesis.GetShardIteratorOutput ), ret .Error (1 )
203208}
204209
205210func (m * MockKinesisSubscriberGetter ) SubscribeToShard (ctx context.Context , params * kinesis.SubscribeToShardInput , optFns ... func (* kinesis.Options )) (* kinesis.SubscribeToShardOutput , error ) {
@@ -455,6 +460,155 @@ func TestPollingShardConsumer_renewLease(t *testing.T) {
455460 }
456461}
457462
463+ func TestPollingShardConsumer_getRecordsRenewLease (t * testing.T ) {
464+ log := logger .GetDefaultLogger ()
465+ type fields struct {
466+ checkpointer chk.Checkpointer
467+ kclConfig * config.KinesisClientLibConfiguration
468+ mService metrics.MonitoringService
469+ }
470+ tests := []struct {
471+ name string
472+ fields fields
473+
474+ // testMillis must be at least 200ms or you'll trigger the localTPSExceededError
475+ testMillis time.Duration
476+ expRenewalCalls int
477+ expRenewals int
478+ shardClosed bool
479+ expErr error
480+ }{
481+ {
482+ "renew once" ,
483+ fields {
484+ & mockCheckpointer {},
485+ & config.KinesisClientLibConfiguration {
486+ LeaseRefreshWaitTime : 200 ,
487+ Logger : log ,
488+ InitialPositionInStream : config .LATEST ,
489+ },
490+ & mockMetrics {},
491+ },
492+ 250 ,
493+ 1 ,
494+ 1 ,
495+ false ,
496+ nil ,
497+ },
498+ {
499+ "renew some" ,
500+ fields {
501+ & mockCheckpointer {},
502+ & config.KinesisClientLibConfiguration {
503+ LeaseRefreshWaitTime : 50 ,
504+ Logger : log ,
505+ InitialPositionInStream : config .LATEST ,
506+ },
507+ & mockMetrics {},
508+ },
509+ 50 * 5 + 10 ,
510+ 5 ,
511+ 5 ,
512+ false ,
513+ nil ,
514+ },
515+ {
516+ "renew twice every 2.5 seconds" ,
517+ fields {
518+ & mockCheckpointer {},
519+ & config.KinesisClientLibConfiguration {
520+ LeaseRefreshWaitTime : 2500 ,
521+ Logger : log ,
522+ InitialPositionInStream : config .LATEST ,
523+ },
524+ & mockMetrics {},
525+ },
526+ 5100 ,
527+ 2 ,
528+ 2 ,
529+ false ,
530+ nil ,
531+ },
532+ {
533+ "lease error" ,
534+ fields {
535+ & mockCheckpointer {fail : true },
536+ & config.KinesisClientLibConfiguration {
537+ LeaseRefreshWaitTime : 500 ,
538+ Logger : log ,
539+ InitialPositionInStream : config .LATEST ,
540+ },
541+ & mockMetrics {},
542+ },
543+ 1100 ,
544+ 1 ,
545+ 0 ,
546+ false ,
547+ getLeaseTestFailure ,
548+ },
549+ }
550+ iterator := "test-iterator"
551+ nextIt := "test-next-iterator"
552+ millisBehind := int64 (0 )
553+ stopChan := make (chan struct {})
554+ for _ , tt := range tests {
555+ t .Run (tt .name , func (t * testing.T ) {
556+ mk := MockKinesisSubscriberGetter {}
557+ gro := kinesis.GetRecordsOutput {
558+ Records : []types.Record {
559+ {
560+ Data : []byte {},
561+ PartitionKey : new (string ),
562+ SequenceNumber : new (string ),
563+ ApproximateArrivalTimestamp : & time.Time {},
564+ EncryptionType : "" ,
565+ },
566+ },
567+ MillisBehindLatest : & millisBehind ,
568+ }
569+ if ! tt .shardClosed {
570+ gro .NextShardIterator = & nextIt
571+ }
572+ mk .On ("GetRecords" , mock .Anything , mock .Anything , mock .Anything ).Return (& gro , nil )
573+ mk .On ("GetShardIterator" , mock .Anything , mock .Anything , mock .Anything ).Return (& kinesis.GetShardIteratorOutput {ShardIterator : & iterator }, nil )
574+ rp := mockRecordProcessor {
575+ processDurationMillis : tt .testMillis ,
576+ }
577+ sc := & PollingShardConsumer {
578+ commonShardConsumer : commonShardConsumer {
579+ shard : & par.ShardStatus {
580+ ID : "test-shard-id" ,
581+ Mux : & sync.RWMutex {},
582+ },
583+ checkpointer : tt .fields .checkpointer ,
584+ kclConfig : tt .fields .kclConfig ,
585+ kc : & mk ,
586+ recordProcessor : & rp ,
587+ mService : tt .fields .mService ,
588+ },
589+ stop : & stopChan ,
590+ mService : tt .fields .mService ,
591+ }
592+
593+ // Send the stop signal a little before the total time it should
594+ // take to get records and process them. This prevents test time
595+ // errors due to the threads running longer than the test case
596+ // expects.
597+ go func () {
598+ time .Sleep ((tt .testMillis - 1 ) * time .Millisecond )
599+ stopChan <- struct {}{}
600+ }()
601+
602+ err := sc .getRecords ()
603+
604+ assert .Equal (t , tt .expErr , err )
605+ assert .Equal (t , tt .expRenewalCalls , sc .checkpointer .(* mockCheckpointer ).getLeaseCalledTimes )
606+ assert .Equal (t , tt .expRenewals , sc .mService .(* mockMetrics ).leaseRenewedCalledTimes )
607+ mk .AssertExpectations (t )
608+ })
609+ }
610+ }
611+
458612type mockCheckpointer struct {
459613 getLeaseCalledTimes int
460614 fail bool
@@ -478,6 +632,16 @@ func (m mockCheckpointer) ListActiveWorkers(map[string]*par.ShardStatus) (map[st
478632}
479633func (m mockCheckpointer ) ClaimShard (* par.ShardStatus , string ) error { return nil }
480634
635+ type mockRecordProcessor struct {
636+ processDurationMillis time.Duration
637+ }
638+
639+ func (m mockRecordProcessor ) Initialize (initializationInput * kcl.InitializationInput ) {}
640+ func (m mockRecordProcessor ) ProcessRecords (processRecordsInput * kcl.ProcessRecordsInput ) {
641+ time .Sleep (time .Millisecond * m .processDurationMillis )
642+ }
643+ func (m mockRecordProcessor ) Shutdown (shutdownInput * kcl.ShutdownInput ) {}
644+
481645type mockMetrics struct {
482646 leaseRenewedCalledTimes int
483647}
0 commit comments