Skip to content
This repository was archived by the owner on Jan 20, 2026. It is now read-only.

Commit c434f27

Browse files
committed
fix: Move getRecords stop condition checking to top of loop
Added tests to validate coordination of the lease renewal with the record-retrieval and processing routines. This highlighted the value of checking the loop's "stop conditions" at the top of the loop instead of at the end.
1 parent ec7c0b6 commit c434f27

File tree

2 files changed

+175
-10
lines changed

2 files changed

+175
-10
lines changed

clientlibrary/worker/polling-shard-consumer.go

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,16 @@ func (sc *PollingShardConsumer) getRecords() error {
145145
leaseRenewalErrChan <- sc.renewLease(ctx)
146146
}()
147147
for {
148+
select {
149+
case <-*sc.stop:
150+
shutdownInput := &kcl.ShutdownInput{ShutdownReason: kcl.REQUESTED, Checkpointer: recordCheckpointer}
151+
sc.recordProcessor.Shutdown(shutdownInput)
152+
return nil
153+
case leaseRenewalErr := <-leaseRenewalErrChan:
154+
return leaseRenewalErr
155+
default:
156+
}
157+
148158
getRecordsStartTime := time.Now()
149159

150160
log.Debugf("Trying to read %d record from iterator: %v", sc.kclConfig.MaxRecords, aws.ToString(shardIterator))
@@ -226,15 +236,6 @@ func (sc *PollingShardConsumer) getRecords() error {
226236
time.Sleep(time.Duration(sc.kclConfig.IdleTimeBetweenReadsInMillis) * time.Millisecond)
227237
}
228238

229-
select {
230-
case <-*sc.stop:
231-
shutdownInput := &kcl.ShutdownInput{ShutdownReason: kcl.REQUESTED, Checkpointer: recordCheckpointer}
232-
sc.recordProcessor.Shutdown(shutdownInput)
233-
return nil
234-
case leaseRenewalErr := <-leaseRenewalErrChan:
235-
return leaseRenewalErr
236-
default:
237-
}
238239
}
239240
}
240241

clientlibrary/worker/polling-shard-consumer_test.go

Lines changed: 165 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,21 @@ package worker
2222
import (
2323
"context"
2424
"errors"
25+
"sync"
2526
"testing"
2627
"time"
2728

2829
"github.com/aws/aws-sdk-go-v2/aws"
2930
"github.com/aws/aws-sdk-go-v2/service/kinesis"
31+
"github.com/aws/aws-sdk-go-v2/service/kinesis/types"
3032
"github.com/stretchr/testify/assert"
3133
"github.com/stretchr/testify/mock"
3234
chk "github.com/vmware/vmware-go-kcl-v2/clientlibrary/checkpoint"
3335
"github.com/vmware/vmware-go-kcl-v2/clientlibrary/config"
36+
kcl "github.com/vmware/vmware-go-kcl-v2/clientlibrary/interfaces"
3437
"github.com/vmware/vmware-go-kcl-v2/clientlibrary/metrics"
3538
par "github.com/vmware/vmware-go-kcl-v2/clientlibrary/partition"
39+
"github.com/vmware/vmware-go-kcl-v2/logger"
3640
)
3741

3842
var (
@@ -199,7 +203,8 @@ func (m *MockKinesisSubscriberGetter) GetRecords(ctx context.Context, params *ki
199203
}
200204

201205
func (m *MockKinesisSubscriberGetter) GetShardIterator(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
202-
return nil, nil
206+
ret := m.Called(ctx, params, optFns)
207+
return ret.Get(0).(*kinesis.GetShardIteratorOutput), ret.Error(1)
203208
}
204209

205210
func (m *MockKinesisSubscriberGetter) SubscribeToShard(ctx context.Context, params *kinesis.SubscribeToShardInput, optFns ...func(*kinesis.Options)) (*kinesis.SubscribeToShardOutput, error) {
@@ -455,6 +460,155 @@ func TestPollingShardConsumer_renewLease(t *testing.T) {
455460
}
456461
}
457462

463+
func TestPollingShardConsumer_getRecordsRenewLease(t *testing.T) {
464+
log := logger.GetDefaultLogger()
465+
type fields struct {
466+
checkpointer chk.Checkpointer
467+
kclConfig *config.KinesisClientLibConfiguration
468+
mService metrics.MonitoringService
469+
}
470+
tests := []struct {
471+
name string
472+
fields fields
473+
474+
// testMillis must be at least 200ms or you'll trigger the localTPSExceededError
475+
testMillis time.Duration
476+
expRenewalCalls int
477+
expRenewals int
478+
shardClosed bool
479+
expErr error
480+
}{
481+
{
482+
"renew once",
483+
fields{
484+
&mockCheckpointer{},
485+
&config.KinesisClientLibConfiguration{
486+
LeaseRefreshWaitTime: 200,
487+
Logger: log,
488+
InitialPositionInStream: config.LATEST,
489+
},
490+
&mockMetrics{},
491+
},
492+
250,
493+
1,
494+
1,
495+
false,
496+
nil,
497+
},
498+
{
499+
"renew some",
500+
fields{
501+
&mockCheckpointer{},
502+
&config.KinesisClientLibConfiguration{
503+
LeaseRefreshWaitTime: 50,
504+
Logger: log,
505+
InitialPositionInStream: config.LATEST,
506+
},
507+
&mockMetrics{},
508+
},
509+
50*5 + 10,
510+
5,
511+
5,
512+
false,
513+
nil,
514+
},
515+
{
516+
"renew twice every 2.5 seconds",
517+
fields{
518+
&mockCheckpointer{},
519+
&config.KinesisClientLibConfiguration{
520+
LeaseRefreshWaitTime: 2500,
521+
Logger: log,
522+
InitialPositionInStream: config.LATEST,
523+
},
524+
&mockMetrics{},
525+
},
526+
5100,
527+
2,
528+
2,
529+
false,
530+
nil,
531+
},
532+
{
533+
"lease error",
534+
fields{
535+
&mockCheckpointer{fail: true},
536+
&config.KinesisClientLibConfiguration{
537+
LeaseRefreshWaitTime: 500,
538+
Logger: log,
539+
InitialPositionInStream: config.LATEST,
540+
},
541+
&mockMetrics{},
542+
},
543+
1100,
544+
1,
545+
0,
546+
false,
547+
getLeaseTestFailure,
548+
},
549+
}
550+
iterator := "test-iterator"
551+
nextIt := "test-next-iterator"
552+
millisBehind := int64(0)
553+
stopChan := make(chan struct{})
554+
for _, tt := range tests {
555+
t.Run(tt.name, func(t *testing.T) {
556+
mk := MockKinesisSubscriberGetter{}
557+
gro := kinesis.GetRecordsOutput{
558+
Records: []types.Record{
559+
{
560+
Data: []byte{},
561+
PartitionKey: new(string),
562+
SequenceNumber: new(string),
563+
ApproximateArrivalTimestamp: &time.Time{},
564+
EncryptionType: "",
565+
},
566+
},
567+
MillisBehindLatest: &millisBehind,
568+
}
569+
if !tt.shardClosed {
570+
gro.NextShardIterator = &nextIt
571+
}
572+
mk.On("GetRecords", mock.Anything, mock.Anything, mock.Anything).Return(&gro, nil)
573+
mk.On("GetShardIterator", mock.Anything, mock.Anything, mock.Anything).Return(&kinesis.GetShardIteratorOutput{ShardIterator: &iterator}, nil)
574+
rp := mockRecordProcessor{
575+
processDurationMillis: tt.testMillis,
576+
}
577+
sc := &PollingShardConsumer{
578+
commonShardConsumer: commonShardConsumer{
579+
shard: &par.ShardStatus{
580+
ID: "test-shard-id",
581+
Mux: &sync.RWMutex{},
582+
},
583+
checkpointer: tt.fields.checkpointer,
584+
kclConfig: tt.fields.kclConfig,
585+
kc: &mk,
586+
recordProcessor: &rp,
587+
mService: tt.fields.mService,
588+
},
589+
stop: &stopChan,
590+
mService: tt.fields.mService,
591+
}
592+
593+
// Send the stop signal a little before the total time it should
594+
// take to get records and process them. This prevents test time
595+
// errors due to the threads running longer than the test case
596+
// expects.
597+
go func() {
598+
time.Sleep((tt.testMillis - 1) * time.Millisecond)
599+
stopChan <- struct{}{}
600+
}()
601+
602+
err := sc.getRecords()
603+
604+
assert.Equal(t, tt.expErr, err)
605+
assert.Equal(t, tt.expRenewalCalls, sc.checkpointer.(*mockCheckpointer).getLeaseCalledTimes)
606+
assert.Equal(t, tt.expRenewals, sc.mService.(*mockMetrics).leaseRenewedCalledTimes)
607+
mk.AssertExpectations(t)
608+
})
609+
}
610+
}
611+
458612
type mockCheckpointer struct {
459613
getLeaseCalledTimes int
460614
fail bool
@@ -478,6 +632,16 @@ func (m mockCheckpointer) ListActiveWorkers(map[string]*par.ShardStatus) (map[st
478632
}
479633
func (m mockCheckpointer) ClaimShard(*par.ShardStatus, string) error { return nil }
480634

635+
type mockRecordProcessor struct {
636+
processDurationMillis time.Duration
637+
}
638+
639+
func (m mockRecordProcessor) Initialize(initializationInput *kcl.InitializationInput) {}
640+
func (m mockRecordProcessor) ProcessRecords(processRecordsInput *kcl.ProcessRecordsInput) {
641+
time.Sleep(time.Millisecond * m.processDurationMillis)
642+
}
643+
func (m mockRecordProcessor) Shutdown(shutdownInput *kcl.ShutdownInput) {}
644+
481645
type mockMetrics struct {
482646
leaseRenewedCalledTimes int
483647
}

0 commit comments

Comments
 (0)