diff --git a/internal/impl/aws/input_kinesis.go b/internal/impl/aws/input_kinesis.go
index aaa21998de..e5ef35bc80 100644
--- a/internal/impl/aws/input_kinesis.go
+++ b/internal/impl/aws/input_kinesis.go
@@ -657,9 +657,9 @@ func (k *kinesisReader) runBalancedShards() {
 	for {
 		for _, info := range k.streams {
 			allShards, err := collectShards(k.ctx, info.arn, k.svc)
-			var clientClaims map[string][]awsKinesisClientClaim
+			var checkpointData *awsKinesisCheckpointData
 			if err == nil {
-				clientClaims, err = k.checkpointer.AllClaims(k.ctx, info.id)
+				checkpointData, err = k.checkpointer.GetCheckpointsAndClaims(k.ctx, info.id)
 			}
 			if err != nil {
 				if k.ctx.Err() != nil {
@@ -669,11 +669,18 @@ func (k *kinesisReader) runBalancedShards() {
 				continue
 			}
 
+			clientClaims := checkpointData.ClientClaims
+			shardsWithCheckpoints := checkpointData.ShardsWithCheckpoints
+
 			totalShards := len(allShards)
 			unclaimedShards := make(map[string]string, totalShards)
 			for _, s := range allShards {
-				if !isShardFinished(s) {
-					unclaimedShards[*s.ShardId] = ""
+				// Include shard if:
+				// 1. It's not finished (still open), OR
+				// 2. It's finished but has a checkpoint (meaning it hasn't been fully consumed yet)
+				shardID := *s.ShardId
+				if !isShardFinished(s) || shardsWithCheckpoints[shardID] {
+					unclaimedShards[shardID] = ""
 				}
 			}
 			for clientID, claims := range clientClaims {
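The new filter above is the heart of the fix: previously a finished (resharded) shard was dropped outright, so records between its last checkpoint and its end could be skipped; now it stays claimable while a checkpoint record for it still exists. A self-contained sketch of the same predicate (claimableShardIDs and the inline finished check are illustrative stand-ins, not part of the patch):

package main

import (
	"fmt"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/service/kinesis/types"
)

// claimableShardIDs is a hypothetical standalone version of the shard filter
// in the patched loop. isFinished stands in for the package's isShardFinished
// helper so the sketch stays self-contained.
func claimableShardIDs(shards []types.Shard, withCheckpoints map[string]bool, isFinished func(types.Shard) bool) []string {
	ids := make([]string, 0, len(shards))
	for _, s := range shards {
		id := *s.ShardId
		// Keep every open shard, plus any closed shard whose checkpoint
		// record has not yet been deleted (i.e. not fully consumed).
		if !isFinished(s) || withCheckpoints[id] {
			ids = append(ids, id)
		}
	}
	return ids
}

func main() {
	closed := types.Shard{
		ShardId: aws.String("shardId-000000000001"),
		SequenceNumberRange: &types.SequenceNumberRange{
			StartingSequenceNumber: aws.String("1"),
			EndingSequenceNumber:   aws.String("2"),
		},
	}
	open := types.Shard{ShardId: aws.String("shardId-000000000002")}

	// Simplified stand-in for isShardFinished; the real helper also
	// special-cases a literal "null" ending sequence (see the tests below).
	finished := func(s types.Shard) bool {
		return s.SequenceNumberRange != nil && s.SequenceNumberRange.EndingSequenceNumber != nil
	}

	// Before the fix the closed shard would be dropped here; with a
	// checkpoint still present it remains claimable.
	fmt.Println(claimableShardIDs([]types.Shard{closed, open},
		map[string]bool{"shardId-000000000001": true}, finished))
	// Output: [shardId-000000000001 shardId-000000000002]
}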
diff --git a/internal/impl/aws/input_kinesis_checkpointer.go b/internal/impl/aws/input_kinesis_checkpointer.go
index 5dcf78cf2b..626b634809 100644
--- a/internal/impl/aws/input_kinesis_checkpointer.go
+++ b/internal/impl/aws/input_kinesis_checkpointer.go
@@ -180,54 +180,86 @@ type awsKinesisClientClaim struct {
 	LeaseTimeout time.Time
 }
 
-// AllClaims returns a map of client IDs to shards claimed by that client,
-// including the lease timeout of the claim.
-func (k *awsKinesisCheckpointer) AllClaims(ctx context.Context, streamID string) (map[string][]awsKinesisClientClaim, error) {
-	clientClaims := make(map[string][]awsKinesisClientClaim)
-	var scanErr error
-
-	scanRes, err := k.svc.Scan(ctx, &dynamodb.ScanInput{
-		TableName:        aws.String(k.conf.Table),
-		FilterExpression: aws.String("StreamID = :stream_id"),
+// awsKinesisCheckpointData contains both the set of all shards with checkpoints
+// and the map of client claims, retrieved in a single DynamoDB query.
+type awsKinesisCheckpointData struct {
+	// ShardsWithCheckpoints is a set of all shard IDs that have checkpoint records.
+	ShardsWithCheckpoints map[string]bool
+	// ClientClaims is a map of client IDs to shards claimed by that client.
+	ClientClaims map[string][]awsKinesisClientClaim
+}
+
+// GetCheckpointsAndClaims retrieves all checkpoint data for a stream.
+//
+// Returns:
+//   - ShardsWithCheckpoints: the set of all shard IDs that have checkpoint records
+//   - ClientClaims: a map of client IDs to their claimed shards (entries without a ClientID are excluded)
+func (k *awsKinesisCheckpointer) GetCheckpointsAndClaims(ctx context.Context, streamID string) (*awsKinesisCheckpointData, error) {
+	result := &awsKinesisCheckpointData{
+		ShardsWithCheckpoints: make(map[string]bool),
+		ClientClaims:          make(map[string][]awsKinesisClientClaim),
+	}
+
+	input := &dynamodb.QueryInput{
+		TableName:              aws.String(k.conf.Table),
+		KeyConditionExpression: aws.String("StreamID = :stream_id"),
 		ExpressionAttributeValues: map[string]types.AttributeValue{
 			":stream_id": &types.AttributeValueMemberS{
 				Value: streamID,
 			},
 		},
-	})
-	if err != nil {
-		return nil, err
 	}
 
-	for _, i := range scanRes.Items {
-		var clientID string
-		if s, ok := i["ClientID"].(*types.AttributeValueMemberS); ok {
-			clientID = s.Value
-		} else {
-			continue
-		}
+	paginator := dynamodb.NewQueryPaginator(k.svc, input)
 
-		var claim awsKinesisClientClaim
-		if s, ok := i["ShardID"].(*types.AttributeValueMemberS); ok {
-			claim.ShardID = s.Value
-		}
-		if claim.ShardID == "" {
-			return nil, errors.New("failed to extract shard id from claim")
+	for paginator.HasMorePages() {
+		page, err := paginator.NextPage(ctx)
+		if err != nil {
+			return nil, fmt.Errorf("failed to query checkpoints: %w", err)
 		}
 
-		if s, ok := i["LeaseTimeout"].(*types.AttributeValueMemberS); ok {
-			if claim.LeaseTimeout, scanErr = time.Parse(time.RFC3339Nano, s.Value); scanErr != nil {
-				return nil, fmt.Errorf("failed to parse claim lease: %w", scanErr)
+		for _, item := range page.Items {
+			// Extract the ShardID, which is required for all checkpoint entries.
+			var shardID string
+			if s, ok := item["ShardID"].(*types.AttributeValueMemberS); ok {
+				shardID = s.Value
+			}
+			if shardID == "" {
+				continue
+			}
+
+			// Track every shard that has a checkpoint record.
+			result.ShardsWithCheckpoints[shardID] = true
+
+			// Extract the client claim, if a ClientID exists.
+			var clientID string
+			if s, ok := item["ClientID"].(*types.AttributeValueMemberS); ok {
+				clientID = s.Value
+			}
+			if clientID == "" {
+				// No client ID means this is an orphaned checkpoint (from final=true).
+				continue
 			}
-		}
-		if claim.LeaseTimeout.IsZero() {
-			return nil, errors.New("failed to extract lease timeout from claim")
-		}
-		clientClaims[clientID] = append(clientClaims[clientID], claim)
+
+			// Extract the lease timeout for the claim.
+			var claim awsKinesisClientClaim
+			claim.ShardID = shardID
+
+			if s, ok := item["LeaseTimeout"].(*types.AttributeValueMemberS); ok {
+				var parseErr error
+				if claim.LeaseTimeout, parseErr = time.Parse(time.RFC3339Nano, s.Value); parseErr != nil {
+					return nil, fmt.Errorf("failed to parse claim lease for shard %s: %w", shardID, parseErr)
+				}
+			}
+			if claim.LeaseTimeout.IsZero() {
+				return nil, fmt.Errorf("failed to extract lease timeout from claim for shard %s", shardID)
+			}
+
+			result.ClientClaims[clientID] = append(result.ClientClaims[clientID], claim)
+		}
 	}
 
-	return clientClaims, scanErr
+	return result, nil
 }
 
 // Claim attempts to claim a shard for a particular stream ID. If fromClientID
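Two things change here: the table is now read with Query against the StreamID partition key rather than a full-table Scan with a filter, and results are paginated, where the old single Scan call would silently drop anything past DynamoDB's 1 MB per-response limit. A self-contained sketch of the same pattern, assuming a table whose partition key is StreamID as the patch implies (the table and stream names are made up for illustration):

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/service/dynamodb"
	"github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
)

func main() {
	cfg, err := config.LoadDefaultConfig(context.Background())
	if err != nil {
		log.Fatal(err)
	}
	svc := dynamodb.NewFromConfig(cfg)

	// Query reads a single partition; the paginator follows
	// LastEvaluatedKey until every page has been consumed.
	input := &dynamodb.QueryInput{
		TableName:              aws.String("kinesis_checkpoints"), // assumed table name
		KeyConditionExpression: aws.String("StreamID = :stream_id"),
		ExpressionAttributeValues: map[string]types.AttributeValue{
			":stream_id": &types.AttributeValueMemberS{Value: "my-stream"}, // assumed stream ID
		},
	}

	paginator := dynamodb.NewQueryPaginator(svc, input)
	for paginator.HasMorePages() {
		page, err := paginator.NextPage(context.Background())
		if err != nil {
			log.Fatalf("query checkpoints: %v", err)
		}
		fmt.Printf("got %d checkpoint items\n", len(page.Items))
	}
}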
diff --git a/internal/impl/aws/input_kinesis_test.go b/internal/impl/aws/input_kinesis_test.go
index 4f80a7c9e3..a44c8cc283 100644
--- a/internal/impl/aws/input_kinesis_test.go
+++ b/internal/impl/aws/input_kinesis_test.go
@@ -3,6 +3,8 @@ package aws
 import (
 	"testing"
 
+	"github.com/aws/aws-sdk-go-v2/aws"
+	"github.com/aws/aws-sdk-go-v2/service/kinesis/types"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )
@@ -64,3 +66,59 @@ func TestStreamIDParser(t *testing.T) {
 		})
 	}
 }
+
+func TestIsShardFinished(t *testing.T) {
+	tests := []struct {
+		name     string
+		shard    types.Shard
+		expected bool
+	}{
+		{
+			name: "open shard - no ending sequence",
+			shard: types.Shard{
+				ShardId: aws.String("shardId-000000000001"),
+				SequenceNumberRange: &types.SequenceNumberRange{
+					StartingSequenceNumber: aws.String("49671246667567228643283430150187087032206582658"),
+				},
+			},
+			expected: false,
+		},
+		{
+			name: "closed shard - has ending sequence",
+			shard: types.Shard{
+				ShardId: aws.String("shardId-000000000001"),
+				SequenceNumberRange: &types.SequenceNumberRange{
+					StartingSequenceNumber: aws.String("49671246667567228643283430150187087032206582658"),
+					EndingSequenceNumber:   aws.String("49671246667589458717803282320587893555896035326582658"),
+				},
+			},
+			expected: true,
+		},
+		{
+			name: "open shard - ending sequence is the literal string null",
+			shard: types.Shard{
+				ShardId: aws.String("shardId-000000000001"),
+				SequenceNumberRange: &types.SequenceNumberRange{
+					StartingSequenceNumber: aws.String("49671246667567228643283430150187087032206582658"),
+					EndingSequenceNumber:   aws.String("null"),
+				},
+			},
+			expected: false,
+		},
+		{
+			name: "shard with no sequence number range",
+			shard: types.Shard{
+				ShardId: aws.String("shardId-000000000001"),
+			},
+			expected: false,
+		},
+	}
+
+	for _, test := range tests {
+		test := test
+		t.Run(test.name, func(t *testing.T) {
+			result := isShardFinished(test.shard)
+			assert.Equal(t, test.expected, result)
+		})
+	}
+}
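The table above pins down the helper's contract, including that an EndingSequenceNumber holding the literal string "null" is treated as an open shard. For reference, an implementation consistent with these expectations would look roughly like this (a reconstruction from the tests, not necessarily the repository's exact code):

import "github.com/aws/aws-sdk-go-v2/service/kinesis/types"

// isShardFinished, reconstructed from the test table above: a shard counts as
// finished only when its sequence number range has a real upper bound; a nil
// range, a nil EndingSequenceNumber, or the literal string "null" all mean
// the shard is still open.
func isShardFinished(s types.Shard) bool {
	if s.SequenceNumberRange == nil || s.SequenceNumberRange.EndingSequenceNumber == nil {
		return false
	}
	return *s.SequenceNumberRange.EndingSequenceNumber != "null"
}

The cases can be run in isolation with: go test ./internal/impl/aws -run TestIsShardFinished -v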