From 44d706ec35b4d1a1d1a9501831095a10eaa69abe Mon Sep 17 00:00:00 2001
From: jbeemster
Date: Tue, 27 Jan 2026 17:19:22 +0100
Subject: [PATCH 1/2] fix unclaimed shards not being completed

---
 internal/impl/aws/input_kinesis.go          | 12 +++-
 .../impl/aws/input_kinesis_checkpointer.go  | 27 +++++++++
 internal/impl/aws/input_kinesis_test.go     | 58 +++++++++++++++++++
 3 files changed, 95 insertions(+), 2 deletions(-)

diff --git a/internal/impl/aws/input_kinesis.go b/internal/impl/aws/input_kinesis.go
index aaa21998de..df42d2a16b 100644
--- a/internal/impl/aws/input_kinesis.go
+++ b/internal/impl/aws/input_kinesis.go
@@ -658,9 +658,13 @@ func (k *kinesisReader) runBalancedShards() {
 		for _, info := range k.streams {
 			allShards, err := collectShards(k.ctx, info.arn, k.svc)
 			var clientClaims map[string][]awsKinesisClientClaim
+			var shardsWithCheckpoints map[string]bool
 			if err == nil {
 				clientClaims, err = k.checkpointer.AllClaims(k.ctx, info.id)
 			}
+			if err == nil {
+				shardsWithCheckpoints, err = k.checkpointer.AllCheckpoints(k.ctx, info.id)
+			}
 			if err != nil {
 				if k.ctx.Err() != nil {
 					return
@@ -672,8 +676,12 @@ func (k *kinesisReader) runBalancedShards() {
 			totalShards := len(allShards)
 			unclaimedShards := make(map[string]string, totalShards)
 			for _, s := range allShards {
-				if !isShardFinished(s) {
-					unclaimedShards[*s.ShardId] = ""
+				// Include shard if:
+				// 1. It's not finished (still open), OR
+				// 2. It's finished but has a checkpoint (meaning it hasn't been fully consumed yet)
+				shardID := *s.ShardId
+				if !isShardFinished(s) || shardsWithCheckpoints[shardID] {
+					unclaimedShards[shardID] = ""
 				}
 			}
 			for clientID, claims := range clientClaims {
diff --git a/internal/impl/aws/input_kinesis_checkpointer.go b/internal/impl/aws/input_kinesis_checkpointer.go
index 5dcf78cf2b..f836f37eef 100644
--- a/internal/impl/aws/input_kinesis_checkpointer.go
+++ b/internal/impl/aws/input_kinesis_checkpointer.go
@@ -180,6 +180,33 @@ type awsKinesisClientClaim struct {
 	LeaseTimeout time.Time
 }
 
+// AllCheckpoints returns a set of all shard IDs that have checkpoint records
+// in DynamoDB for the given stream, regardless of whether they are claimed or not.
+func (k *awsKinesisCheckpointer) AllCheckpoints(ctx context.Context, streamID string) (map[string]bool, error) {
+	checkpoints := make(map[string]bool)
+
+	scanRes, err := k.svc.Scan(ctx, &dynamodb.ScanInput{
+		TableName:        aws.String(k.conf.Table),
+		FilterExpression: aws.String("StreamID = :stream_id"),
+		ExpressionAttributeValues: map[string]types.AttributeValue{
+			":stream_id": &types.AttributeValueMemberS{
+				Value: streamID,
+			},
+		},
+	})
+	if err != nil {
+		return nil, err
+	}
+
+	for _, i := range scanRes.Items {
+		if s, ok := i["ShardID"].(*types.AttributeValueMemberS); ok {
+			checkpoints[s.Value] = true
+		}
+	}
+
+	return checkpoints, nil
+}
+
 // AllClaims returns a map of client IDs to shards claimed by that client,
 // including the lease timeout of the claim.
 func (k *awsKinesisCheckpointer) AllClaims(ctx context.Context, streamID string) (map[string][]awsKinesisClientClaim, error) {
diff --git a/internal/impl/aws/input_kinesis_test.go b/internal/impl/aws/input_kinesis_test.go
index 4f80a7c9e3..a44c8cc283 100644
--- a/internal/impl/aws/input_kinesis_test.go
+++ b/internal/impl/aws/input_kinesis_test.go
@@ -3,6 +3,8 @@ package aws
 import (
 	"testing"
 
+	"github.com/aws/aws-sdk-go-v2/aws"
+	"github.com/aws/aws-sdk-go-v2/service/kinesis/types"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )
@@ -64,3 +66,59 @@ func TestStreamIDParser(t *testing.T) {
 		})
 	}
 }
+
+func TestIsShardFinished(t *testing.T) {
+	tests := []struct {
+		name     string
+		shard    types.Shard
+		expected bool
+	}{
+		{
+			name: "open shard - no ending sequence",
+			shard: types.Shard{
+				ShardId: aws.String("shardId-000000000001"),
+				SequenceNumberRange: &types.SequenceNumberRange{
+					StartingSequenceNumber: aws.String("49671246667567228643283430150187087032206582658"),
+				},
+			},
+			expected: false,
+		},
+		{
+			name: "closed shard - has ending sequence",
+			shard: types.Shard{
+				ShardId: aws.String("shardId-000000000001"),
+				SequenceNumberRange: &types.SequenceNumberRange{
+					StartingSequenceNumber: aws.String("49671246667567228643283430150187087032206582658"),
+					EndingSequenceNumber:   aws.String("49671246667589458717803282320587893555896035326582658"),
+				},
+			},
+			expected: true,
+		},
+		{
+			name: "closed shard - ending sequence is null string",
+			shard: types.Shard{
+				ShardId: aws.String("shardId-000000000001"),
+				SequenceNumberRange: &types.SequenceNumberRange{
+					StartingSequenceNumber: aws.String("49671246667567228643283430150187087032206582658"),
+					EndingSequenceNumber:   aws.String("null"),
+				},
+			},
+			expected: false,
+		},
+		{
+			name: "shard with no sequence number range",
+			shard: types.Shard{
+				ShardId: aws.String("shardId-000000000001"),
+			},
+			expected: false,
+		},
+	}
+
+	for _, test := range tests {
+		test := test
+		t.Run(test.name, func(t *testing.T) {
+			result := isShardFinished(test.shard)
+			assert.Equal(t, test.expected, result)
+		})
+	}
+}

From d5dda26ff494c96e54257c00247ba5ba1df8dbf7 Mon Sep 17 00:00:00 2001
From: Matus Tomlein
Date: Wed, 28 Jan 2026 15:20:18 +0100
Subject: [PATCH 2/2] Use a single query to retrieve both claims and
 checkpoints from DynamoDB, and use pagination rather than a Scan

---
 internal/impl/aws/input_kinesis.go          |  11 +-
 .../impl/aws/input_kinesis_checkpointer.go  | 121 +++++++++---------
 2 files changed, 68 insertions(+), 64 deletions(-)

diff --git a/internal/impl/aws/input_kinesis.go b/internal/impl/aws/input_kinesis.go
index df42d2a16b..e5ef35bc80 100644
--- a/internal/impl/aws/input_kinesis.go
+++ b/internal/impl/aws/input_kinesis.go
@@ -657,13 +657,9 @@ func (k *kinesisReader) runBalancedShards() {
 	for {
 		for _, info := range k.streams {
 			allShards, err := collectShards(k.ctx, info.arn, k.svc)
-			var clientClaims map[string][]awsKinesisClientClaim
-			var shardsWithCheckpoints map[string]bool
+			var checkpointData *awsKinesisCheckpointData
 			if err == nil {
-				clientClaims, err = k.checkpointer.AllClaims(k.ctx, info.id)
-			}
-			if err == nil {
-				shardsWithCheckpoints, err = k.checkpointer.AllCheckpoints(k.ctx, info.id)
+				checkpointData, err = k.checkpointer.GetCheckpointsAndClaims(k.ctx, info.id)
 			}
 			if err != nil {
 				if k.ctx.Err() != nil {
@@ -673,6 +669,9 @@ func (k *kinesisReader) runBalancedShards() {
 				continue
 			}
 
+			clientClaims := checkpointData.ClientClaims
+			shardsWithCheckpoints := checkpointData.ShardsWithCheckpoints
+
 			totalShards := len(allShards)
 			unclaimedShards := make(map[string]string, totalShards)
 			for _, s := range allShards {
diff --git a/internal/impl/aws/input_kinesis_checkpointer.go b/internal/impl/aws/input_kinesis_checkpointer.go
index f836f37eef..626b634809 100644
--- a/internal/impl/aws/input_kinesis_checkpointer.go
+++ b/internal/impl/aws/input_kinesis_checkpointer.go
@@ -180,81 +180,86 @@ type awsKinesisClientClaim struct {
 	LeaseTimeout time.Time
 }
 
-// AllCheckpoints returns a set of all shard IDs that have checkpoint records
-// in DynamoDB for the given stream, regardless of whether they are claimed or not.
-func (k *awsKinesisCheckpointer) AllCheckpoints(ctx context.Context, streamID string) (map[string]bool, error) {
-	checkpoints := make(map[string]bool)
-
-	scanRes, err := k.svc.Scan(ctx, &dynamodb.ScanInput{
-		TableName:        aws.String(k.conf.Table),
-		FilterExpression: aws.String("StreamID = :stream_id"),
-		ExpressionAttributeValues: map[string]types.AttributeValue{
-			":stream_id": &types.AttributeValueMemberS{
-				Value: streamID,
-			},
-		},
-	})
-	if err != nil {
-		return nil, err
-	}
-
-	for _, i := range scanRes.Items {
-		if s, ok := i["ShardID"].(*types.AttributeValueMemberS); ok {
-			checkpoints[s.Value] = true
-		}
-	}
-
-	return checkpoints, nil
+// awsKinesisCheckpointData contains both the set of all shards with checkpoints
+// and the map of client claims, retrieved in a single DynamoDB query.
+type awsKinesisCheckpointData struct {
+	// ShardsWithCheckpoints is a set of all shard IDs that have checkpoint records
+	ShardsWithCheckpoints map[string]bool
+	// ClientClaims is a map of client IDs to shards claimed by that client
+	ClientClaims map[string][]awsKinesisClientClaim
 }
 
-// AllClaims returns a map of client IDs to shards claimed by that client,
-// including the lease timeout of the claim.
-func (k *awsKinesisCheckpointer) AllClaims(ctx context.Context, streamID string) (map[string][]awsKinesisClientClaim, error) {
-	clientClaims := make(map[string][]awsKinesisClientClaim)
-	var scanErr error
+// GetCheckpointsAndClaims retrieves all checkpoint data for a stream.
+//
+// Returns:
+//   - ShardsWithCheckpoints: set of all shard IDs that have checkpoint records
+//   - ClientClaims: map of client IDs to their claimed shards (excludes entries without ClientID)
+func (k *awsKinesisCheckpointer) GetCheckpointsAndClaims(ctx context.Context, streamID string) (*awsKinesisCheckpointData, error) {
+	result := &awsKinesisCheckpointData{
+		ShardsWithCheckpoints: make(map[string]bool),
+		ClientClaims:          make(map[string][]awsKinesisClientClaim),
+	}
 
-	scanRes, err := k.svc.Scan(ctx, &dynamodb.ScanInput{
-		TableName:        aws.String(k.conf.Table),
-		FilterExpression: aws.String("StreamID = :stream_id"),
+	input := &dynamodb.QueryInput{
+		TableName:              aws.String(k.conf.Table),
+		KeyConditionExpression: aws.String("StreamID = :stream_id"),
 		ExpressionAttributeValues: map[string]types.AttributeValue{
 			":stream_id": &types.AttributeValueMemberS{
 				Value: streamID,
 			},
 		},
-	})
-	if err != nil {
-		return nil, err
 	}
 
-	for _, i := range scanRes.Items {
-		var clientID string
-		if s, ok := i["ClientID"].(*types.AttributeValueMemberS); ok {
-			clientID = s.Value
-		} else {
-			continue
-		}
+	paginator := dynamodb.NewQueryPaginator(k.svc, input)
 
-		var claim awsKinesisClientClaim
-		if s, ok := i["ShardID"].(*types.AttributeValueMemberS); ok {
-			claim.ShardID = s.Value
-		}
-		if claim.ShardID == "" {
-			return nil, errors.New("failed to extract shard id from claim")
+	for paginator.HasMorePages() {
+		page, err := paginator.NextPage(ctx)
+		if err != nil {
+			return nil, fmt.Errorf("failed to query checkpoints: %w", err)
 		}
 
-		if s, ok := i["LeaseTimeout"].(*types.AttributeValueMemberS); ok {
-			if claim.LeaseTimeout, scanErr = time.Parse(time.RFC3339Nano, s.Value); scanErr != nil {
-				return nil, fmt.Errorf("failed to parse claim lease: %w", scanErr)
+		for _, item := range page.Items {
+			// Extract ShardID - required for all checkpoint entries
+			var shardID string
+			if s, ok := item["ShardID"].(*types.AttributeValueMemberS); ok {
+				shardID = s.Value
 			}
-		}
-		if claim.LeaseTimeout.IsZero() {
-			return nil, errors.New("failed to extract lease timeout from claim")
-		}
+			if shardID == "" {
+				continue
+			}
+
+			// Track all shards with checkpoints
+			result.ShardsWithCheckpoints[shardID] = true
 
-		clientClaims[clientID] = append(clientClaims[clientID], claim)
+			// Extract client claim if ClientID exists
+			var clientID string
+			if s, ok := item["ClientID"].(*types.AttributeValueMemberS); ok {
+				clientID = s.Value
+			}
+			if clientID == "" {
+				// No client ID means this is an orphaned checkpoint (from final=true)
+				continue
+			}
+
+			// Extract lease timeout for claims
+			var claim awsKinesisClientClaim
+			claim.ShardID = shardID
+
+			if s, ok := item["LeaseTimeout"].(*types.AttributeValueMemberS); ok {
+				var parseErr error
+				if claim.LeaseTimeout, parseErr = time.Parse(time.RFC3339Nano, s.Value); parseErr != nil {
+					return nil, fmt.Errorf("failed to parse claim lease for shard %s: %w", shardID, parseErr)
+				}
+			}
+			if claim.LeaseTimeout.IsZero() {
+				return nil, fmt.Errorf("failed to extract lease timeout from claim for shard %s", shardID)
+			}
+
+			result.ClientClaims[clientID] = append(result.ClientClaims[clientID], claim)
+		}
 	}
 
-	return clientClaims, scanErr
+	return result, nil
 }
 
 // Claim attempts to claim a shard for a particular stream ID. If fromClientID
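Note on the pattern adopted in PATCH 2/2: DynamoDB's Query reads only the items stored under one partition key, whereas Scan walks the whole table and filters afterwards, and both return at most 1 MB per response page, so an unpaginated call can silently drop records. The sketch below is a minimal, standalone illustration of the Query-plus-paginator idiom, not part of the patch itself: the table name "kinesis_checkpoints" and stream ID "my-stream" are hypothetical placeholders, and it assumes a checkpoint table keyed by StreamID (partition key) and ShardID (sort key), as implied by the KeyConditionExpression in the diff.

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/service/dynamodb"
	"github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
)

func main() {
	ctx := context.Background()

	cfg, err := config.LoadDefaultConfig(ctx)
	if err != nil {
		log.Fatalf("failed to load AWS config: %v", err)
	}
	svc := dynamodb.NewFromConfig(cfg)

	// Query touches only the items under one partition key (StreamID),
	// whereas a Scan with a FilterExpression reads the entire table and
	// discards non-matching items after paying for the read.
	input := &dynamodb.QueryInput{
		TableName:              aws.String("kinesis_checkpoints"), // hypothetical table name
		KeyConditionExpression: aws.String("StreamID = :stream_id"),
		ExpressionAttributeValues: map[string]types.AttributeValue{
			":stream_id": &types.AttributeValueMemberS{Value: "my-stream"}, // hypothetical stream ID
		},
	}

	// The paginator follows LastEvaluatedKey automatically, so shard IDs
	// beyond the first 1 MB response page are not silently dropped.
	shards := make(map[string]bool)
	paginator := dynamodb.NewQueryPaginator(svc, input)
	for paginator.HasMorePages() {
		page, err := paginator.NextPage(ctx)
		if err != nil {
			log.Fatalf("query failed: %v", err)
		}
		for _, item := range page.Items {
			if s, ok := item["ShardID"].(*types.AttributeValueMemberS); ok {
				shards[s.Value] = true
			}
		}
	}

	fmt.Println("shards with checkpoints:", shards)
}

Because the paginator keeps issuing requests while LastEvaluatedKey is set, the result is complete even once the checkpoint table outgrows a single response page, which is the failure mode the unpaginated Scan introduced in PATCH 1/2 was exposed to.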