internal/impl/aws/input_kinesis.go (11 additions, 4 deletions)

@@ -657,9 +657,9 @@ func (k *kinesisReader) runBalancedShards() {
for {
for _, info := range k.streams {
allShards, err := collectShards(k.ctx, info.arn, k.svc)
var clientClaims map[string][]awsKinesisClientClaim
var checkpointData *awsKinesisCheckpointData
if err == nil {
clientClaims, err = k.checkpointer.AllClaims(k.ctx, info.id)
checkpointData, err = k.checkpointer.GetCheckpointsAndClaims(k.ctx, info.id)
}
if err != nil {
if k.ctx.Err() != nil {
@@ -669,11 +669,18 @@
continue
}

clientClaims := checkpointData.ClientClaims
shardsWithCheckpoints := checkpointData.ShardsWithCheckpoints

totalShards := len(allShards)
unclaimedShards := make(map[string]string, totalShards)
for _, s := range allShards {
if !isShardFinished(s) {
unclaimedShards[*s.ShardId] = ""
// Include shard if:
// 1. It's not finished (still open), OR
// 2. It's finished but has a checkpoint (meaning it hasn't been fully consumed yet)
shardID := *s.ShardId
if !isShardFinished(s) || shardsWithCheckpoints[shardID] {
unclaimedShards[shardID] = ""
}
}
for clientID, claims := range clientClaims {
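Sketch only, not part of the change: the inclusion rule spelled out in the comments above, applied to three illustrative shards. The shard IDs, the stand-in finished() closure, and the shardsWithCheckpoints contents are made up for the example; the actual code uses isShardFinished and the data returned by GetCheckpointsAndClaims.

package main

import (
	"fmt"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/service/kinesis/types"
)

func main() {
	// Simplified stand-in for isShardFinished (see the tests further down):
	// a shard is treated as closed once it has an ending sequence number.
	finished := func(s types.Shard) bool {
		return s.SequenceNumberRange != nil && s.SequenceNumberRange.EndingSequenceNumber != nil
	}

	shards := []types.Shard{
		{ShardId: aws.String("shard-open"),
			SequenceNumberRange: &types.SequenceNumberRange{
				StartingSequenceNumber: aws.String("1")}},
		{ShardId: aws.String("shard-closed-consumed"),
			SequenceNumberRange: &types.SequenceNumberRange{
				StartingSequenceNumber: aws.String("1"), EndingSequenceNumber: aws.String("2")}},
		{ShardId: aws.String("shard-closed-pending"),
			SequenceNumberRange: &types.SequenceNumberRange{
				StartingSequenceNumber: aws.String("1"), EndingSequenceNumber: aws.String("2")}},
	}

	// Only the third shard still has a checkpoint record in DynamoDB.
	shardsWithCheckpoints := map[string]bool{"shard-closed-pending": true}

	for _, s := range shards {
		include := !finished(s) || shardsWithCheckpoints[*s.ShardId]
		fmt.Printf("%-22s include=%v\n", *s.ShardId, include)
	}
	// shard-open and shard-closed-pending stay in the balancing set; under the
	// old rule shard-closed-pending would have been dropped even though its
	// checkpoint shows it has not been fully consumed yet.
}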
internal/impl/aws/input_kinesis_checkpointer.go (66 additions, 34 deletions)

@@ -180,54 +180,86 @@ type awsKinesisClientClaim struct {
LeaseTimeout time.Time
}

// AllClaims returns a map of client IDs to shards claimed by that client,
// including the lease timeout of the claim.
func (k *awsKinesisCheckpointer) AllClaims(ctx context.Context, streamID string) (map[string][]awsKinesisClientClaim, error) {
clientClaims := make(map[string][]awsKinesisClientClaim)
var scanErr error

scanRes, err := k.svc.Scan(ctx, &dynamodb.ScanInput{
TableName: aws.String(k.conf.Table),
FilterExpression: aws.String("StreamID = :stream_id"),
// awsKinesisCheckpointData contains both the set of all shards with checkpoints
// and the map of client claims, retrieved in a single DynamoDB query.
type awsKinesisCheckpointData struct {
// ShardsWithCheckpoints is a set of all shard IDs that have checkpoint records
ShardsWithCheckpoints map[string]bool
// ClientClaims is a map of client IDs to shards claimed by that client
ClientClaims map[string][]awsKinesisClientClaim
}

// GetCheckpointsAndClaims retrieves all checkpoint data for a stream.
//
// Returns:
// - ShardsWithCheckpoints: set of all shard IDs that have checkpoint records
// - ClientClaims: map of client IDs to their claimed shards (excludes entries without ClientID)
func (k *awsKinesisCheckpointer) GetCheckpointsAndClaims(ctx context.Context, streamID string) (*awsKinesisCheckpointData, error) {
result := &awsKinesisCheckpointData{
ShardsWithCheckpoints: make(map[string]bool),
ClientClaims: make(map[string][]awsKinesisClientClaim),
}

input := &dynamodb.QueryInput{
TableName: aws.String(k.conf.Table),
KeyConditionExpression: aws.String("StreamID = :stream_id"),
ExpressionAttributeValues: map[string]types.AttributeValue{
":stream_id": &types.AttributeValueMemberS{
Value: streamID,
},
},
})
if err != nil {
return nil, err
}

for _, i := range scanRes.Items {
var clientID string
if s, ok := i["ClientID"].(*types.AttributeValueMemberS); ok {
clientID = s.Value
} else {
continue
}
paginator := dynamodb.NewQueryPaginator(k.svc, input)

var claim awsKinesisClientClaim
if s, ok := i["ShardID"].(*types.AttributeValueMemberS); ok {
claim.ShardID = s.Value
}
if claim.ShardID == "" {
return nil, errors.New("failed to extract shard id from claim")
for paginator.HasMorePages() {
page, err := paginator.NextPage(ctx)
if err != nil {
return nil, fmt.Errorf("failed to query checkpoints: %w", err)
}

if s, ok := i["LeaseTimeout"].(*types.AttributeValueMemberS); ok {
if claim.LeaseTimeout, scanErr = time.Parse(time.RFC3339Nano, s.Value); scanErr != nil {
return nil, fmt.Errorf("failed to parse claim lease: %w", scanErr)
for _, item := range page.Items {
// Extract ShardID - required for all checkpoint entries
var shardID string
if s, ok := item["ShardID"].(*types.AttributeValueMemberS); ok {
shardID = s.Value
}
if shardID == "" {
continue
}

// Track all shards with checkpoints
result.ShardsWithCheckpoints[shardID] = true

// Extract client claim if ClientID exists
var clientID string
if s, ok := item["ClientID"].(*types.AttributeValueMemberS); ok {
clientID = s.Value
}
if clientID == "" {
// No client ID means this is an orphaned checkpoint (from final=true)
continue
}
}
if claim.LeaseTimeout.IsZero() {
return nil, errors.New("failed to extract lease timeout from claim")
}

clientClaims[clientID] = append(clientClaims[clientID], claim)
// Extract lease timeout for claims
var claim awsKinesisClientClaim
claim.ShardID = shardID

if s, ok := item["LeaseTimeout"].(*types.AttributeValueMemberS); ok {
var parseErr error
if claim.LeaseTimeout, parseErr = time.Parse(time.RFC3339Nano, s.Value); parseErr != nil {
return nil, fmt.Errorf("failed to parse claim lease for shard %s: %w", shardID, parseErr)
}
}
if claim.LeaseTimeout.IsZero() {
return nil, fmt.Errorf("failed to extract lease timeout from claim for shard %s", shardID)
}

result.ClientClaims[clientID] = append(result.ClientClaims[clientID], claim)
}
}

return clientClaims, scanErr
return result, nil
}

// Claim attempts to claim a shard for a particular stream ID. If fromClientID
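Sketch only, not part of the change: GetCheckpointsAndClaims replaces the previous table Scan with a paginated Query on the StreamID key and distinguishes two shapes of checkpoint item. The attribute names below match what the code above reads; the values are illustrative, and attributes the function does not read are omitted.

package main

import "github.com/aws/aws-sdk-go-v2/service/dynamodb/types"

func main() {
	// A shard currently claimed by a client: contributes an entry to
	// ShardsWithCheckpoints and a claim under ClientClaims["client-a"],
	// with LeaseTimeout parsed as RFC3339Nano.
	claimed := map[string]types.AttributeValue{
		"StreamID":     &types.AttributeValueMemberS{Value: "my-stream"},
		"ShardID":      &types.AttributeValueMemberS{Value: "shardId-000000000001"},
		"ClientID":     &types.AttributeValueMemberS{Value: "client-a"},
		"LeaseTimeout": &types.AttributeValueMemberS{Value: "2024-01-01T00:00:00.000000000Z"},
	}

	// An orphaned checkpoint, written with final=true and carrying no
	// ClientID: it still marks the shard in ShardsWithCheckpoints (which is
	// what keeps a closed but not fully consumed shard visible to the
	// balancer in input_kinesis.go) but adds nothing to ClientClaims.
	orphaned := map[string]types.AttributeValue{
		"StreamID": &types.AttributeValueMemberS{Value: "my-stream"},
		"ShardID":  &types.AttributeValueMemberS{Value: "shardId-000000000002"},
	}

	_, _ = claimed, orphaned
}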
internal/impl/aws/input_kinesis_test.go (58 additions, 0 deletions)

@@ -3,6 +3,8 @@ package aws
import (
"testing"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/kinesis/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -64,3 +66,59 @@ func TestStreamIDParser(t *testing.T) {
})
}
}

func TestIsShardFinished(t *testing.T) {
tests := []struct {
name string
shard types.Shard
expected bool
}{
{
name: "open shard - no ending sequence",
shard: types.Shard{
ShardId: aws.String("shardId-000000000001"),
SequenceNumberRange: &types.SequenceNumberRange{
StartingSequenceNumber: aws.String("49671246667567228643283430150187087032206582658"),
},
},
expected: false,
},
{
name: "closed shard - has ending sequence",
shard: types.Shard{
ShardId: aws.String("shardId-000000000001"),
SequenceNumberRange: &types.SequenceNumberRange{
StartingSequenceNumber: aws.String("49671246667567228643283430150187087032206582658"),
EndingSequenceNumber: aws.String("49671246667589458717803282320587893555896035326582658"),
},
},
expected: true,
},
{
name: "closed shard - ending sequence is null string",
shard: types.Shard{
ShardId: aws.String("shardId-000000000001"),
SequenceNumberRange: &types.SequenceNumberRange{
StartingSequenceNumber: aws.String("49671246667567228643283430150187087032206582658"),
EndingSequenceNumber: aws.String("null"),
},
},
expected: false,
},
{
name: "shard with no sequence number range",
shard: types.Shard{
ShardId: aws.String("shardId-000000000001"),
},
expected: false,
},
}

for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
result := isShardFinished(test.shard)
assert.Equal(t, test.expected, result)
})
}
}
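Not part of the change: taken together, the four cases above pin down the expected behaviour of isShardFinished. A sketch consistent with them follows; the real implementation lives in input_kinesis.go and may differ in detail.

package main

import (
	"fmt"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/service/kinesis/types"
)

// isShardFinishedSketch mirrors what the cases above imply: a shard counts as
// finished only when it carries a real ending sequence number; a missing
// range, a missing ending sequence, or the literal string "null" all mean the
// shard is still open.
func isShardFinishedSketch(s types.Shard) bool {
	if s.SequenceNumberRange == nil || s.SequenceNumberRange.EndingSequenceNumber == nil {
		return false
	}
	return *s.SequenceNumberRange.EndingSequenceNumber != "null"
}

func main() {
	closed := types.Shard{
		ShardId: aws.String("shardId-000000000001"),
		SequenceNumberRange: &types.SequenceNumberRange{
			StartingSequenceNumber: aws.String("1"),
			EndingSequenceNumber:   aws.String("2"),
		},
	}
	fmt.Println(isShardFinishedSketch(closed))        // true
	fmt.Println(isShardFinishedSketch(types.Shard{})) // false
}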