diff --git a/.gitignore b/.gitignore index e0b87691b8..fbeeb58672 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,5 @@ release_notes.md .task .vscode .op +.gomodcache __pycache__ diff --git a/docs/modules/components/pages/inputs/aws_dynamodb_cdc.adoc b/docs/modules/components/pages/inputs/aws_dynamodb_cdc.adoc new file mode 100644 index 0000000000..77fa0383da --- /dev/null +++ b/docs/modules/components/pages/inputs/aws_dynamodb_cdc.adoc @@ -0,0 +1,329 @@ += aws_dynamodb_cdc +:type: input +:status: beta +:categories: ["Services","AWS"] + + + +//// + THIS FILE IS AUTOGENERATED! + + To make changes, edit the corresponding source file under: + + https://github.com/redpanda-data/connect/tree/main/internal/impl/. + + And: + + https://github.com/redpanda-data/connect/tree/main/cmd/tools/docs_gen/templates/plugin.adoc.tmpl +//// + +// © 2024 Redpanda Data Inc. + + +component_type_dropdown::[] + + +Reads change data capture (CDC) events from DynamoDB Streams + +Introduced in version 1.0.0. + + +[tabs] +====== +Common:: ++ +-- + +```yml +# Common config fields, showing default values +input: + label: "" + aws_dynamodb_cdc: + table: "" # No default (required) + checkpoint_table: redpanda_dynamodb_checkpoints + start_from: trim_horizon +``` + +-- +Advanced:: ++ +-- + +```yml +# All config fields, showing default values +input: + label: "" + aws_dynamodb_cdc: + table: "" # No default (required) + checkpoint_table: redpanda_dynamodb_checkpoints + batch_size: 1000 + poll_interval: 1s + start_from: trim_horizon + checkpoint_limit: 1000 + max_tracked_shards: 10000 + region: "" # No default (optional) + endpoint: "" # No default (optional) + tcp: + connect_timeout: 0s + keep_alive: + idle: 15s + interval: 15s + count: 9 + tcp_user_timeout: 0s + credentials: + profile: "" # No default (optional) + id: "" # No default (optional) + secret: "" # No default (optional) + token: "" # No default (optional) + from_ec2_role: false # No default (optional) + role: "" # No default (optional) + role_external_id: "" # No default (optional) +``` + +-- +====== + +Consumes records from DynamoDB Streams with automatic checkpointing and shard management. + +DynamoDB Streams capture item-level changes in DynamoDB tables. This input supports: +- Automatic shard discovery and management +- Checkpoint-based resumption after crashes +- Multiple shard processing + +For better performance and longer retention, consider using Kinesis Data Streams for DynamoDB +with the `aws_kinesis` input instead. + +== Metadata + +This input adds the following metadata fields to each message: + +- `dynamodb_shard_id` - The shard ID from which the record was read +- `dynamodb_sequence_number` - The sequence number of the record in the stream +- `dynamodb_event_name` - The type of change: INSERT, MODIFY, or REMOVE +- `dynamodb_table` - The name of the DynamoDB table + +== Metrics + +This input exposes the following metrics: + +- `dynamodb_cdc_shards_tracked` - Total number of shards being tracked (gauge) +- `dynamodb_cdc_shards_active` - Number of active shards currently being read from (gauge) + + +== Fields + +=== `table` + +The name of the DynamoDB table to read streams from. + + +*Type*: `string` + + +=== `checkpoint_table` + +DynamoDB table name for storing checkpoints. Will be created if it doesn't exist. + + +*Type*: `string` + +*Default*: `"redpanda_dynamodb_checkpoints"` + +=== `batch_size` + +Maximum number of records to read in a single batch. 
+ + +*Type*: `int` + +*Default*: `1000` + +=== `poll_interval` + +Time to wait between polling attempts when no records are available. + + +*Type*: `string` + +*Default*: `"1s"` + +=== `start_from` + +Where to start reading when no checkpoint exists. `trim_horizon` starts from the oldest available record, `latest` starts from new records. + + +*Type*: `string` + +*Default*: `"trim_horizon"` + +Options: +`trim_horizon` +, `latest` +. + +=== `checkpoint_limit` + +Maximum number of messages to process before updating checkpoint. + + +*Type*: `int` + +*Default*: `1000` + +=== `max_tracked_shards` + +Maximum number of shards to track simultaneously. Prevents memory issues with extremely large tables. + + +*Type*: `int` + +*Default*: `10000` + +=== `region` + +The AWS region to target. + + +*Type*: `string` + + +=== `endpoint` + +Allows you to specify a custom endpoint for the AWS API. + + +*Type*: `string` + + +=== `tcp` + +TCP socket configuration. + + +*Type*: `object` + + +=== `tcp.connect_timeout` + +Maximum amount of time a dial will wait for a connect to complete. Zero disables. + + +*Type*: `string` + +*Default*: `"0s"` + +=== `tcp.keep_alive` + +TCP keep-alive probe configuration. + + +*Type*: `object` + + +=== `tcp.keep_alive.idle` + +Duration the connection must be idle before sending the first keep-alive probe. Zero defaults to 15s. Negative values disable keep-alive probes. + + +*Type*: `string` + +*Default*: `"15s"` + +=== `tcp.keep_alive.interval` + +Duration between keep-alive probes. Zero defaults to 15s. + + +*Type*: `string` + +*Default*: `"15s"` + +=== `tcp.keep_alive.count` + +Maximum unanswered keep-alive probes before dropping the connection. Zero defaults to 9. + + +*Type*: `int` + +*Default*: `9` + +=== `tcp.tcp_user_timeout` + +Maximum time to wait for acknowledgment of transmitted data before killing the connection. Linux-only (kernel 2.6.37+), ignored on other platforms. When enabled, keep_alive.idle must be greater than this value per RFC 5482. Zero disables. + + +*Type*: `string` + +*Default*: `"0s"` + +=== `credentials` + +Optional manual configuration of AWS credentials to use. More information can be found in xref:guides:cloud/aws.adoc[]. + + +*Type*: `object` + + +=== `credentials.profile` + +A profile from `~/.aws/credentials` to use. + + +*Type*: `string` + + +=== `credentials.id` + +The ID of credentials to use. + + +*Type*: `string` + + +=== `credentials.secret` + +The secret for the credentials being used. +[CAUTION] +==== +This field contains sensitive information that usually shouldn't be added to a config directly, read our xref:configuration:secrets.adoc[secrets page for more info]. +==== + + + +*Type*: `string` + + +=== `credentials.token` + +The token for the credentials being used, required when using short term credentials. + + +*Type*: `string` + + +=== `credentials.from_ec2_role` + +Use the credentials of a host EC2 machine configured to assume https://docs.aws.amazon.com/IAM/latest/UserGuide/id_roles_use_switch-role-ec2.html[an IAM role associated with the instance^]. + + +*Type*: `bool` + +Requires version 4.2.0 or newer + +=== `credentials.role` + +A role ARN to assume. + + +*Type*: `string` + + +=== `credentials.role_external_id` + +An external ID to provide when assuming a role. 
+ + +*Type*: `string` + + + diff --git a/go.mod b/go.mod index bd6a76b999..79f41a0e57 100644 --- a/go.mod +++ b/go.mod @@ -338,7 +338,7 @@ require ( github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.9 // indirect github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.9 // indirect - github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.31.0 // indirect + github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.31.0 github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.1 // indirect github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.0 // indirect github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.11.9 // indirect diff --git a/internal/impl/aws/input_dynamodb_cdc.go b/internal/impl/aws/input_dynamodb_cdc.go new file mode 100644 index 0000000000..75cc65e25b --- /dev/null +++ b/internal/impl/aws/input_dynamodb_cdc.go @@ -0,0 +1,751 @@ +// Copyright 2026 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/blob/main/licenses/rcl.md + +package aws + +import ( + "context" + "errors" + "fmt" + "maps" + "sync" + "sync/atomic" + "time" + + "github.com/Jeffail/shutdown" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodbstreams" + "github.com/aws/aws-sdk-go-v2/service/dynamodbstreams/types" + "github.com/cenkalti/backoff/v4" + + "github.com/redpanda-data/benthos/v4/public/service" + "github.com/redpanda-data/connect/v4/internal/impl/aws/config" +) + +const ( + defaultDynamoDBBatchSize = 1000 // AWS max limit + defaultDynamoDBPollInterval = "1s" + defaultShutdownTimeout = 10 * time.Second + + // Metrics + metricShardsTracked = "dynamodb_cdc_shards_tracked" + metricShardsActive = "dynamodb_cdc_shards_active" +) + +func dynamoDBCDCInputConfig() *service.ConfigSpec { + return service.NewConfigSpec(). + Beta(). + Version("1.0.0"). + Categories("Services", "AWS"). + Summary("Reads change data capture (CDC) events from DynamoDB Streams"). + Description(` +Consumes records from DynamoDB Streams with automatic checkpointing and shard management. + +DynamoDB Streams capture item-level changes in DynamoDB tables. This input supports: +- Automatic shard discovery and management +- Checkpoint-based resumption after crashes +- Multiple shard processing + +For better performance and longer retention, consider using Kinesis Data Streams for DynamoDB +with the `+"`aws_kinesis`"+` input instead. + +== Metadata + +This input adds the following metadata fields to each message: + +- `+"`dynamodb_shard_id`"+` - The shard ID from which the record was read +- `+"`dynamodb_sequence_number`"+` - The sequence number of the record in the stream +- `+"`dynamodb_event_name`"+` - The type of change: INSERT, MODIFY, or REMOVE +- `+"`dynamodb_table`"+` - The name of the DynamoDB table + +== Metrics + +This input exposes the following metrics: + +- `+"`dynamodb_cdc_shards_tracked`"+` - Total number of shards being tracked (gauge) +- `+"`dynamodb_cdc_shards_active`"+` - Number of active shards currently being read from (gauge) +`). + Fields( + service.NewStringField("table"). + Description("The name of the DynamoDB table to read streams from."). 
+ LintRule(`root = if this == "" { ["table name cannot be empty"] }`), + service.NewStringField("checkpoint_table"). + Description("DynamoDB table name for storing checkpoints. Will be created if it doesn't exist."). + Default("redpanda_dynamodb_checkpoints"), + service.NewIntField("batch_size"). + Description("Maximum number of records to read in a single batch."). + Default(defaultDynamoDBBatchSize). + Advanced(), + service.NewDurationField("poll_interval"). + Description("Time to wait between polling attempts when no records are available."). + Default(defaultDynamoDBPollInterval). + Advanced(), + service.NewStringEnumField("start_from", "trim_horizon", "latest"). + Description("Where to start reading when no checkpoint exists. `trim_horizon` starts from the oldest available record, `latest` starts from new records."). + Default("trim_horizon"), + service.NewIntField("checkpoint_limit"). + Description("Maximum number of messages to process before updating checkpoint."). + Default(1000). + Advanced(), + service.NewIntField("max_tracked_shards"). + Description("Maximum number of shards to track simultaneously. Prevents memory issues with extremely large tables."). + Default(10000). + Advanced(), + ). + Fields(config.SessionFields()...) +} + +func init() { + err := service.RegisterBatchInput( + "aws_dynamodb_cdc", dynamoDBCDCInputConfig(), + func(conf *service.ParsedConfig, mgr *service.Resources) (service.BatchInput, error) { + return newDynamoDBCDCInputFromConfig(conf, mgr) + }) + if err != nil { + panic(err) + } +} + +type dynamoDBCDCConfig struct { + table string + checkpointTable string + batchSize int + pollInterval time.Duration + startFrom string + checkpointLimit int + maxTrackedShards int +} + +type dynamoDBCDCInput struct { + conf dynamoDBCDCConfig + + awsConf aws.Config + dynamoClient *dynamodb.Client + streamsClient *dynamodbstreams.Client + streamArn *string + + checkpointer *dynamoDBCDCCheckpointer + recordBatcher *dynamoDBCDCRecordBatcher + + // Channel-based batch delivery + msgChan chan asyncMessage + shutSig *shutdown.Signaller + + // Shard management + shardReaders map[string]*dynamoDBShardReader + mu sync.RWMutex // Changed to RWMutex for better performance + + // Pending acknowledgments tracking + pendingAcks sync.WaitGroup + closed atomic.Bool + + log *service.Logger + + // Metrics + shardsTrackedMetric *service.MetricGauge + shardsActiveMetric *service.MetricGauge +} + +type dynamoDBShardReader struct { + shardID string + iterator *string + exhausted bool +} + +func dynamoCDCInputConfigFromParsed(pConf *service.ParsedConfig) (conf dynamoDBCDCConfig, err error) { + if conf.table, err = pConf.FieldString("table"); err != nil { + return + } + if conf.checkpointTable, err = pConf.FieldString("checkpoint_table"); err != nil { + return + } + if conf.batchSize, err = pConf.FieldInt("batch_size"); err != nil { + return + } + if conf.pollInterval, err = pConf.FieldDuration("poll_interval"); err != nil { + return + } + if conf.startFrom, err = pConf.FieldString("start_from"); err != nil { + return + } + if conf.checkpointLimit, err = pConf.FieldInt("checkpoint_limit"); err != nil { + return + } + if conf.maxTrackedShards, err = pConf.FieldInt("max_tracked_shards"); err != nil { + return + } + return +} + +func newDynamoDBCDCInputFromConfig(pConf *service.ParsedConfig, mgr *service.Resources) (*dynamoDBCDCInput, error) { + conf, err := dynamoCDCInputConfigFromParsed(pConf) + if err != nil { + return nil, err + } + + awsConf, err := GetSession(context.Background(), pConf) + if 
err != nil { + return nil, err + } + + return &dynamoDBCDCInput{ + conf: conf, + awsConf: awsConf, + shardReaders: make(map[string]*dynamoDBShardReader), + log: mgr.Logger(), + shardsTrackedMetric: mgr.Metrics().NewGauge(metricShardsTracked), + shardsActiveMetric: mgr.Metrics().NewGauge(metricShardsActive), + }, nil +} + +func (d *dynamoDBCDCInput) Connect(ctx context.Context) error { + d.dynamoClient = dynamodb.NewFromConfig(d.awsConf) + d.streamsClient = dynamodbstreams.NewFromConfig(d.awsConf) + + // Get stream ARN + descTable, err := d.dynamoClient.DescribeTable(ctx, &dynamodb.DescribeTableInput{ + TableName: &d.conf.table, + }) + if err != nil { + var aerr *types.ResourceNotFoundException + if errors.As(err, &aerr) { + return fmt.Errorf("table %s does not exist", d.conf.table) + } + return fmt.Errorf("failed to describe table %s: %w", d.conf.table, err) + } + + d.streamArn = descTable.Table.LatestStreamArn + if d.streamArn == nil { + return fmt.Errorf("no stream enabled on table %s", d.conf.table) + } + + // Initialize checkpointer + d.checkpointer, err = newDynamoDBCDCCheckpointer(ctx, d.dynamoClient, d.conf.checkpointTable, *d.streamArn, d.conf.checkpointLimit, d.log) + if err != nil { + return fmt.Errorf("failed to create checkpointer: %w", err) + } + + // Initialize record batcher + d.recordBatcher = newDynamoDBCDCRecordBatcher(d.conf.maxTrackedShards, d.conf.checkpointLimit, d.log) + + // Initialize channel and shutdown signaller + d.msgChan = make(chan asyncMessage) + d.shutSig = shutdown.NewSignaller() + + d.log.Infof("Connected to DynamoDB stream: %s", *d.streamArn) + + // Initialize shards + if err := d.refreshShards(ctx); err != nil { + return fmt.Errorf("failed to initialize shards: %w", err) + } + + // Verify at least one shard reader started successfully + d.mu.Lock() + activeCount := len(d.shardReaders) + d.mu.Unlock() + + if activeCount == 0 { + return errors.New("no active shard readers available - stream may have no shards or all failed to initialize") + } + + // Start background goroutine to coordinate shard readers + go d.runShardCoordinator() + + return nil +} + +func (d *dynamoDBCDCInput) refreshShards(ctx context.Context) error { + streamDesc, err := d.streamsClient.DescribeStream(ctx, &dynamodbstreams.DescribeStreamInput{ + StreamArn: d.streamArn, + }) + if err != nil { + return err + } + + // Collect new shards to add without holding locks during I/O operations + type shardToAdd struct { + shardID string + iterator *string + } + var newShards []shardToAdd + + for _, shard := range streamDesc.StreamDescription.Shards { + shardID := *shard.ShardId + + // Check if shard already exists (minimize lock hold time) + d.mu.RLock() + _, exists := d.shardReaders[shardID] + d.mu.RUnlock() + + if exists { + continue + } + + // Check checkpoint (I/O operation - do not hold lock) + checkpoint, err := d.checkpointer.Get(ctx, shardID) + if err != nil { + d.log.Warnf("Failed to get checkpoint for shard %s: %v", shardID, err) + } + + var ( + iteratorType types.ShardIteratorType + sequenceNumber *string + ) + + if checkpoint != "" { + iteratorType = types.ShardIteratorTypeAfterSequenceNumber + sequenceNumber = &checkpoint + d.log.Infof("Resuming shard %s from checkpoint: %s", shardID, checkpoint) + } else { + if d.conf.startFrom == "latest" { + iteratorType = types.ShardIteratorTypeLatest + } else { + iteratorType = types.ShardIteratorTypeTrimHorizon + } + d.log.Infof("Starting shard %s from %s", shardID, d.conf.startFrom) + } + + // Get shard iterator (I/O operation - do not hold 
lock) + iter, err := d.streamsClient.GetShardIterator(ctx, &dynamodbstreams.GetShardIteratorInput{ + StreamArn: d.streamArn, + ShardId: shard.ShardId, + ShardIteratorType: iteratorType, + SequenceNumber: sequenceNumber, + }) + if err != nil { + return fmt.Errorf("failed to get iterator for shard %s: %w", shardID, err) + } + + newShards = append(newShards, shardToAdd{ + shardID: shardID, + iterator: iter.ShardIterator, + }) + } + + // Add all new shard readers in a single critical section + if len(newShards) > 0 { + d.mu.Lock() + for _, s := range newShards { + // Double-check shard wasn't added by another goroutine + if _, exists := d.shardReaders[s.shardID]; !exists { + d.shardReaders[s.shardID] = &dynamoDBShardReader{ + shardID: s.shardID, + iterator: s.iterator, + exhausted: false, + } + } + } + totalShards := len(d.shardReaders) + d.mu.Unlock() + + d.log.Infof("Tracking %d shards", totalShards) + d.shardsTrackedMetric.Set(int64(totalShards)) + } + + return nil +} + +// runShardCoordinator spawns goroutines for each shard and manages shard refresh +func (d *dynamoDBCDCInput) runShardCoordinator() { + defer func() { + close(d.msgChan) + d.shutSig.TriggerHasStopped() + }() + + ctx, cancel := d.shutSig.SoftStopCtx(context.Background()) + defer cancel() + + // Track running shard readers + activeShards := make(map[string]context.CancelFunc) + defer func() { + // Cancel all active shard readers on shutdown + for _, cancelFn := range activeShards { + cancelFn() + } + }() + + refreshTicker := time.NewTicker(30 * time.Second) + defer refreshTicker.Stop() + + for { + // Get current shard readers + d.mu.RLock() + currentReaders := make(map[string]*dynamoDBShardReader) + maps.Copy(currentReaders, d.shardReaders) + d.mu.RUnlock() + + // Start new shard readers for any new shards + for shardID, reader := range currentReaders { + if _, exists := activeShards[shardID]; !exists && !reader.exhausted { + shardCtx, shardCancel := context.WithCancel(ctx) + activeShards[shardID] = shardCancel + go d.runShardReader(shardCtx, shardID) + } + } + + // Update active shards metric + activeCount := 0 + for shardID := range activeShards { + d.mu.RLock() + reader, exists := d.shardReaders[shardID] + d.mu.RUnlock() + if exists && !reader.exhausted { + activeCount++ + } + } + d.shardsActiveMetric.Set(int64(activeCount)) + + select { + case <-ctx.Done(): + return + case <-refreshTicker.C: + // Refresh shards periodically to discover new shards + // Use a timeout context to prevent blocking on shutdown + refreshCtx, refreshCancel := context.WithTimeout(ctx, 30*time.Second) + if err := d.refreshShards(refreshCtx); err != nil && !errors.Is(err, context.Canceled) { + d.log.Warnf("Failed to refresh shards: %v", err) + } + refreshCancel() + } + } +} + +// runShardReader continuously reads from a single shard and sends batches to the channel +func (d *dynamoDBCDCInput) runShardReader(ctx context.Context, shardID string) { + d.log.Debugf("Starting reader for shard %s", shardID) + defer d.log.Debugf("Stopped reader for shard %s", shardID) + + pollTicker := time.NewTicker(d.conf.pollInterval) + defer pollTicker.Stop() + + // Initialize backoff for throttling errors + boff := backoff.NewExponentialBackOff() + boff.InitialInterval = 200 * time.Millisecond + boff.MaxInterval = 2 * time.Second + boff.MaxElapsedTime = 0 // Never give up + + for { + select { + case <-ctx.Done(): + return + case <-pollTicker.C: + // Check for cancellation before expensive operations + select { + case <-ctx.Done(): + return + default: + } + + // Apply 
backpressure if too many messages are in flight + if d.recordBatcher != nil && d.recordBatcher.ShouldThrottle() { + d.log.Debugf("Throttling shard %s due to too many in-flight messages", shardID) + time.Sleep(100 * time.Millisecond) + continue + } + + // Get current reader state + d.mu.RLock() + reader, exists := d.shardReaders[shardID] + if !exists || reader.exhausted || reader.iterator == nil { + d.mu.RUnlock() + return + } + iterator := reader.iterator + d.mu.RUnlock() + + // Read records from the shard (I/O operation - no lock held) + getRecords, err := d.streamsClient.GetRecords(ctx, &dynamodbstreams.GetRecordsInput{ + ShardIterator: iterator, + Limit: aws.Int32(int32(d.conf.batchSize)), + }) + if err != nil { + if isThrottlingError(err) { + wait := boff.NextBackOff() + d.log.Debugf("Throttled on shard %s, backing off for %v", shardID, wait) + time.Sleep(wait) + continue + } + d.log.Errorf("Failed to get records from shard %s: %v", shardID, err) + // On error, wait and retry (don't mark as exhausted) + continue + } + + // Success - reset backoff + boff.Reset() + + // Update iterator + d.mu.Lock() + reader.iterator = getRecords.NextShardIterator + if reader.iterator == nil { + reader.exhausted = true + d.log.Infof("Shard %s exhausted", shardID) + d.mu.Unlock() + return + } + d.mu.Unlock() + + if len(getRecords.Records) == 0 { + continue + } + + // Convert records to messages + batch := d.convertRecordsToBatch(getRecords.Records, shardID) + if len(batch) == 0 { + continue + } + + // Track messages in batcher + batch = d.recordBatcher.AddMessages(batch, shardID) + + // Track pending ack + d.pendingAcks.Add(1) + + // Create ack function + checkpointer := d.checkpointer + recordBatcher := d.recordBatcher + ackFunc := func(ackCtx context.Context, err error) error { + defer d.pendingAcks.Done() + + // Check if already closed + if d.closed.Load() { + d.log.Warn("Received ack after close, dropping") + if err == nil && recordBatcher != nil { + recordBatcher.RemoveMessages(batch) + } + return nil + } + + if err != nil { + d.log.Warnf("Batch nacked from shard %s: %v", shardID, err) + if recordBatcher != nil { + recordBatcher.RemoveMessages(batch) + } + return err // Propagate nack error + } + + // Mark messages as acked and checkpoint if needed + if recordBatcher != nil && checkpointer != nil { + if ackErr := recordBatcher.AckMessages(ackCtx, checkpointer, batch); ackErr != nil { + d.log.Errorf("Failed to checkpoint shard %s after ack: %v", shardID, ackErr) + return ackErr // Propagate checkpoint failure + } + d.log.Debugf("Successfully checkpointed %d messages from shard %s", len(batch), shardID) + } + return nil + } + + // Send to channel + select { + case <-ctx.Done(): + return + case d.msgChan <- asyncMessage{msg: batch, ackFn: ackFunc}: + d.log.Debugf("Sent batch of %d records from shard %s", len(batch), shardID) + } + } + } +} + +// convertRecordsToBatch converts DynamoDB Stream records to Benthos messages +func (d *dynamoDBCDCInput) convertRecordsToBatch(records []types.Record, shardID string) service.MessageBatch { + batch := make(service.MessageBatch, 0, len(records)) + + for _, record := range records { + msg := service.NewMessage(nil) + + // Structure similar to Kinesis format for consistency + recordData := map[string]any{ + "tableName": d.conf.table, + "eventID": aws.ToString(record.EventID), + "eventName": string(record.EventName), + "eventVersion": aws.ToString(record.EventVersion), + "eventSource": aws.ToString(record.EventSource), + "awsRegion": aws.ToString(record.AwsRegion), + 
} + + var sequenceNumber string + if record.Dynamodb != nil { + dynamoData := map[string]any{ + "sequenceNumber": aws.ToString(record.Dynamodb.SequenceNumber), + "streamViewType": string(record.Dynamodb.StreamViewType), + } + + if record.Dynamodb.Keys != nil { + dynamoData["keys"] = convertAttributeMap(record.Dynamodb.Keys) + } + if record.Dynamodb.NewImage != nil { + dynamoData["newImage"] = convertAttributeMap(record.Dynamodb.NewImage) + } + if record.Dynamodb.OldImage != nil { + dynamoData["oldImage"] = convertAttributeMap(record.Dynamodb.OldImage) + } + if record.Dynamodb.SizeBytes != nil { + dynamoData["sizeBytes"] = *record.Dynamodb.SizeBytes + } + + recordData["dynamodb"] = dynamoData + sequenceNumber = aws.ToString(record.Dynamodb.SequenceNumber) + } + + msg.SetStructured(recordData) + + // Set metadata + msg.MetaSetMut("dynamodb_shard_id", shardID) + msg.MetaSetMut("dynamodb_sequence_number", sequenceNumber) + msg.MetaSetMut("dynamodb_event_name", string(record.EventName)) + msg.MetaSetMut("dynamodb_table", d.conf.table) + + batch = append(batch, msg) + } + + return batch +} + +func (d *dynamoDBCDCInput) ReadBatch(ctx context.Context) (service.MessageBatch, service.AckFunc, error) { + d.mu.RLock() + msgChan := d.msgChan + shutSig := d.shutSig + d.mu.RUnlock() + + if msgChan == nil || shutSig == nil { + return nil, nil, service.ErrNotConnected + } + + select { + case <-ctx.Done(): + return nil, nil, ctx.Err() + case <-shutSig.HasStoppedChan(): + return nil, nil, service.ErrNotConnected + case am, open := <-msgChan: + if !open { + return nil, nil, service.ErrNotConnected + } + return am.msg, am.ackFn, nil + } +} + +func (d *dynamoDBCDCInput) Close(ctx context.Context) error { + // Mark as closed to reject new acks + d.closed.Store(true) + + d.mu.RLock() + shutSig := d.shutSig + checkpointer := d.checkpointer + batcher := d.recordBatcher + d.mu.RUnlock() + + // Trigger graceful shutdown + if shutSig != nil { + d.log.Debug("Initiating graceful shutdown") + shutSig.TriggerSoftStop() + + // Wait for background goroutines to stop + select { + case <-shutSig.HasStoppedChan(): + d.log.Debug("Background goroutines stopped") + case <-time.After(defaultShutdownTimeout): + d.log.Warn("Timeout waiting for background goroutines to stop") + // Trigger hard stop if graceful shutdown times out + shutSig.TriggerHardStop() + } + } else { + d.log.Debug("Skipping shutdown signal - component not fully initialized") + } + + // Wait for pending acknowledgments with timeout + d.log.Debug("Waiting for pending acknowledgments") + acksDone := make(chan struct{}) + go func() { + d.pendingAcks.Wait() + close(acksDone) + }() + + select { + case <-acksDone: + d.log.Debug("All pending acks completed") + case <-time.After(defaultShutdownTimeout): + d.log.Warn("Timeout waiting for pending acks, proceeding with shutdown") + } + + // Flush any pending checkpoints + if checkpointer != nil && batcher != nil { + pendingCheckpoints := batcher.GetPendingCheckpoints() + if len(pendingCheckpoints) > 0 { + d.log.Infof("Flushing %d pending checkpoints on close", len(pendingCheckpoints)) + if err := checkpointer.FlushCheckpoints(ctx, pendingCheckpoints); err != nil { + d.log.Errorf("Failed to flush checkpoints: %v", err) + // Don't return error - continue cleanup to avoid resource leaks + } + } + } else { + d.log.Debug("Skipping checkpoint flush - components not initialized") + } + + // Clear references to help GC + d.mu.Lock() + d.dynamoClient = nil + d.streamsClient = nil + d.shardReaders = nil + d.checkpointer = nil + 
d.recordBatcher = nil + d.msgChan = nil + d.shutSig = nil + d.mu.Unlock() + + return nil +} + +// Helper to convert DynamoDB attribute values to Go types +func convertAttributeMap(attrs map[string]types.AttributeValue) map[string]any { + result := make(map[string]any) + for k, v := range attrs { + result[k] = convertAttributeValue(v) + } + return result +} + +// isThrottlingError checks if an error is due to AWS throttling. +func isThrottlingError(err error) bool { + if err == nil { + return false + } + var limitErr *types.LimitExceededException + var throttleErr *types.TrimmedDataAccessException + return errors.As(err, &limitErr) || errors.As(err, &throttleErr) +} + +func convertAttributeValue(attr types.AttributeValue) any { + switch v := attr.(type) { + case *types.AttributeValueMemberS: + return v.Value + case *types.AttributeValueMemberN: + return v.Value + case *types.AttributeValueMemberB: + return v.Value + case *types.AttributeValueMemberSS: + return v.Value + case *types.AttributeValueMemberNS: + return v.Value + case *types.AttributeValueMemberBS: + return v.Value + case *types.AttributeValueMemberM: + return convertAttributeMap(v.Value) + case *types.AttributeValueMemberL: + list := make([]any, len(v.Value)) + for i, item := range v.Value { + list[i] = convertAttributeValue(item) + } + return list + case *types.AttributeValueMemberNULL: + return nil + case *types.AttributeValueMemberBOOL: + return v.Value + default: + return nil + } +} diff --git a/internal/impl/aws/input_dynamodb_cdc_batcher.go b/internal/impl/aws/input_dynamodb_cdc_batcher.go new file mode 100644 index 0000000000..02bf9efd6e --- /dev/null +++ b/internal/impl/aws/input_dynamodb_cdc_batcher.go @@ -0,0 +1,177 @@ +// Copyright 2026 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/blob/main/licenses/rcl.md + +package aws + +import ( + "context" + "fmt" + "maps" + "sync" + + "github.com/redpanda-data/benthos/v4/public/service" +) + +// dynamoDBCDCRecordBatcher tracks messages and their checkpoints for DynamoDB CDC. +// +// This batcher implements a batched checkpointing strategy to optimize performance by +// checkpointing only after a configurable threshold of messages have been acknowledged +// per shard, rather than after every message. +// +// Message tracking +type dynamoDBCDCRecordBatcher struct { + mu sync.Mutex + messageTracker map[*service.Message]*messageCheckpoint + + // Checkpoint state per shard + pendingCount map[string]int // Count of acked but not-yet-checkpointed messages + lastCheckpoints map[string]string // Most recent sequence number per shard + + // Configuration + maxTrackedShards int // Memory safety limit for number of unique shards + maxTrackedMessages int // Memory safety limit for in-flight messages + + log *service.Logger +} + +type messageCheckpoint struct { + shardID string + sequenceNumber string +} + +// newDynamoDBCDCRecordBatcher creates a new record batcher for DynamoDB CDC. 
+func newDynamoDBCDCRecordBatcher(maxTrackedShards, checkpointLimit int, log *service.Logger) *dynamoDBCDCRecordBatcher { + // Set max tracked messages to 10x the checkpoint limit to allow for some buffering + // This prevents unbounded growth while allowing parallel processing + maxTrackedMessages := checkpointLimit * 10 + if maxTrackedMessages < 1000 { + maxTrackedMessages = 1000 // Minimum reasonable size + } + + return &dynamoDBCDCRecordBatcher{ + messageTracker: make(map[*service.Message]*messageCheckpoint), + log: log, + pendingCount: make(map[string]int), + lastCheckpoints: make(map[string]string), + maxTrackedShards: maxTrackedShards, + maxTrackedMessages: maxTrackedMessages, + } +} + +// AddMessages tracks a batch of messages with their shard and sequence information. +// Each message should have its sequence number in metadata under "dynamodb_sequence_number". +func (b *dynamoDBCDCRecordBatcher) AddMessages(batch service.MessageBatch, shardID string) service.MessageBatch { + b.mu.Lock() + defer b.mu.Unlock() + + // Check if we're approaching memory limits + if len(b.messageTracker)+len(batch) > b.maxTrackedMessages { + b.log.Warnf("Message tracker near capacity: %d/%d tracked messages (adding %d from shard %s)", + len(b.messageTracker), b.maxTrackedMessages, len(batch), shardID) + // Still add messages but warn - this indicates downstream is slow + } + + for _, msg := range batch { + // Extract sequence number from message metadata + sequenceNumber, _ := msg.MetaGet("dynamodb_sequence_number") + b.messageTracker[msg] = &messageCheckpoint{ + shardID: shardID, + sequenceNumber: sequenceNumber, + } + } + + return batch +} + +// RemoveMessages removes messages from tracking (used when messages are nacked). +func (b *dynamoDBCDCRecordBatcher) RemoveMessages(batch service.MessageBatch) { + b.mu.Lock() + defer b.mu.Unlock() + + for _, msg := range batch { + delete(b.messageTracker, msg) + } +} + +// checkpointerInterface defines the interface for checkpointing operations. +type checkpointerInterface interface { + Set(ctx context.Context, shardID, sequenceNumber string) error + GetCheckpointLimit() int +} + +// GetCheckpointLimit returns the checkpoint limit for the checkpointer. +func (c *dynamoDBCDCCheckpointer) GetCheckpointLimit() int { + return c.checkpointLimit +} + +// AckMessages marks messages as acknowledged and checkpoints if threshold is reached. 
+func (b *dynamoDBCDCRecordBatcher) AckMessages(ctx context.Context, checkpointer checkpointerInterface, batch service.MessageBatch) error { + b.mu.Lock() + defer b.mu.Unlock() + + // Track sequence numbers and message counts per shard + shardSequences := make(map[string]string) + shardMessageCounts := make(map[string]int) + + // Collect the highest sequence number and count messages for each shard in this batch + for _, msg := range batch { + if cp, exists := b.messageTracker[msg]; exists { + // Only update if this sequence is higher (lexicographic comparison works for DynamoDB sequence numbers) + if current, ok := shardSequences[cp.shardID]; !ok || cp.sequenceNumber > current { + shardSequences[cp.shardID] = cp.sequenceNumber + } + shardMessageCounts[cp.shardID]++ + delete(b.messageTracker, msg) + } + } + + // Update pending counts and checkpoint if needed + for shardID, seq := range shardSequences { + b.lastCheckpoints[shardID] = seq + + // Enforce memory bounds on checkpoint map + if len(b.lastCheckpoints) > b.maxTrackedShards { + return fmt.Errorf("checkpoint map exceeded maximum size (%d shards) - possible memory leak", b.maxTrackedShards) + } + + // Increment pending count with the number of messages acked for this shard + b.pendingCount[shardID] += shardMessageCounts[shardID] + + // Check if we should checkpoint + if b.pendingCount[shardID] >= checkpointer.GetCheckpointLimit() { + if err := checkpointer.Set(ctx, shardID, seq); err != nil { + return err + } + + b.log.Debugf("Checkpointed shard %s at sequence %s", shardID, seq) + // Reset counter after successful checkpoint + b.pendingCount[shardID] = 0 + } + } + + return nil +} + +// GetPendingCheckpoints returns a copy of all pending checkpoints that haven't been persisted yet. +func (b *dynamoDBCDCRecordBatcher) GetPendingCheckpoints() map[string]string { + b.mu.Lock() + defer b.mu.Unlock() + + checkpoints := make(map[string]string, len(b.lastCheckpoints)) + maps.Copy(checkpoints, b.lastCheckpoints) + return checkpoints +} + +// ShouldThrottle returns true if the message tracker is near capacity and backpressure should be applied. +func (b *dynamoDBCDCRecordBatcher) ShouldThrottle() bool { + b.mu.Lock() + defer b.mu.Unlock() + + // Throttle at 90% capacity to leave some headroom + return len(b.messageTracker) >= (b.maxTrackedMessages * 9 / 10) +} diff --git a/internal/impl/aws/input_dynamodb_cdc_batcher_test.go b/internal/impl/aws/input_dynamodb_cdc_batcher_test.go new file mode 100644 index 0000000000..7ef055bcd3 --- /dev/null +++ b/internal/impl/aws/input_dynamodb_cdc_batcher_test.go @@ -0,0 +1,427 @@ +// Copyright 2026 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/blob/main/licenses/rcl.md + +package aws + +import ( + "context" + "fmt" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/redpanda-data/benthos/v4/public/service" +) + +func createTestMessages(count int, shardID string, startSeq int) service.MessageBatch { + batch := make(service.MessageBatch, count) + for i := range count { + msg := service.NewMessage(nil) + msg.MetaSetMut("dynamodb_shard_id", shardID) + msg.MetaSetMut("dynamodb_sequence_number", string(rune('A'+startSeq+i))) + batch[i] = msg + } + return batch +} + +// Mock checkpointer for testing +type mockCheckpointer struct { + mu sync.Mutex + checkpoints map[string]string + checkpointLimit int + setCallCount int +} + +func (m *mockCheckpointer) Set(_ context.Context, shardID, sequenceNumber string) error { + m.mu.Lock() + defer m.mu.Unlock() + m.checkpoints[shardID] = sequenceNumber + m.setCallCount++ + return nil +} + +func (m *mockCheckpointer) GetCheckpointLimit() int { + return m.checkpointLimit +} + +func TestBatcherAddMessages(t *testing.T) { + logger := service.MockResources().Logger() + batcher := newDynamoDBCDCRecordBatcher(10000, 1000, logger) + + // Add messages for shard-001 + batch1 := createTestMessages(5, "shard-001", 0) + result1 := batcher.AddMessages(batch1, "shard-001") + + assert.Len(t, result1, 5) + // pendingCount should be 0 until messages are acked + assert.Equal(t, 0, batcher.pendingCount["shard-001"]) + assert.Len(t, batcher.messageTracker, 5) + + // Add more messages for same shard + batch2 := createTestMessages(3, "shard-001", 5) + result2 := batcher.AddMessages(batch2, "shard-001") + + assert.Len(t, result2, 3) + assert.Equal(t, 0, batcher.pendingCount["shard-001"]) + assert.Len(t, batcher.messageTracker, 8) + + // Add messages for different shard + batch3 := createTestMessages(4, "shard-002", 0) + result3 := batcher.AddMessages(batch3, "shard-002") + + assert.Len(t, result3, 4) + assert.Equal(t, 0, batcher.pendingCount["shard-001"]) + assert.Equal(t, 0, batcher.pendingCount["shard-002"]) + assert.Len(t, batcher.messageTracker, 12) +} + +func TestBatcherRemoveMessages(t *testing.T) { + logger := service.MockResources().Logger() + batcher := newDynamoDBCDCRecordBatcher(10000, 1000, logger) + + // Add messages + batch := createTestMessages(10, "shard-001", 0) + batcher.AddMessages(batch, "shard-001") + + // pendingCount should be 0 until messages are acked + assert.Equal(t, 0, batcher.pendingCount["shard-001"]) + assert.Len(t, batcher.messageTracker, 10) + + // Remove some messages (simulating nack) + toRemove := batch[:5] + batcher.RemoveMessages(toRemove) + + // pendingCount is still 0 since we never acked these messages + assert.Equal(t, 0, batcher.pendingCount["shard-001"]) + assert.Len(t, batcher.messageTracker, 5) + + // Remove remaining messages + batcher.RemoveMessages(batch[5:]) + + assert.Equal(t, 0, batcher.pendingCount["shard-001"]) + assert.Empty(t, batcher.messageTracker) +} + +func TestBatcherAckMessagesWithCheckpointing(t *testing.T) { + logger := service.MockResources().Logger() + batcher := newDynamoDBCDCRecordBatcher(10000, 1000, logger) + + mockCheckpointer := &mockCheckpointer{ + checkpoints: make(map[string]string), + checkpointLimit: 5, // Low threshold for testing + } + + // Add 10 messages + batch := createTestMessages(10, "shard-001", 0) + batcher.AddMessages(batch, "shard-001") + + // Ack first 3 messages - pending count increments to 3, no checkpoint yet (< 5) + 
toAck1 := batch[:3] + err := batcher.AckMessages(context.Background(), mockCheckpointer, toAck1) + assert.NoError(t, err) + + assert.Equal(t, 3, batcher.pendingCount["shard-001"], "Should have 3 pending after acking 3") + assert.Len(t, batcher.messageTracker, 7) + assert.Equal(t, 0, mockCheckpointer.setCallCount, "Should not checkpoint yet (3 < 5)") + + // Ack 3 more messages - pending count reaches 6 (>= 5), should checkpoint + toAck2 := batch[3:6] + err = batcher.AckMessages(context.Background(), mockCheckpointer, toAck2) + assert.NoError(t, err) + + assert.Equal(t, 0, batcher.pendingCount["shard-001"], "Should reset to 0 after checkpoint") + assert.Len(t, batcher.messageTracker, 4) + assert.Equal(t, 1, mockCheckpointer.setCallCount, "Should checkpoint once (6 >= 5)") +} + +func TestBatcherAckMessagesMultipleShards(t *testing.T) { + logger := service.MockResources().Logger() + batcher := newDynamoDBCDCRecordBatcher(10000, 1000, logger) + + // Add messages for multiple shards + batch1 := createTestMessages(6, "shard-001", 0) + batch2 := createTestMessages(6, "shard-002", 0) + + batcher.AddMessages(batch1, "shard-001") + batcher.AddMessages(batch2, "shard-002") + + mockCheckpointer := &mockCheckpointer{ + checkpointLimit: 100, // High limit so we don't checkpoint + } + + checkpointer := &dynamoDBCDCCheckpointer{ + checkpointLimit: mockCheckpointer.checkpointLimit, + } + + // Ack messages from both shards + err := batcher.AckMessages(context.Background(), checkpointer, batch1) + assert.NoError(t, err) + err = batcher.AckMessages(context.Background(), checkpointer, batch2) + assert.NoError(t, err) + + assert.Equal(t, 6, batcher.pendingCount["shard-001"]) + assert.Equal(t, 6, batcher.pendingCount["shard-002"]) + + // Test that both shards are tracked independently + batcher.mu.Lock() + assert.Contains(t, batcher.pendingCount, "shard-001") + assert.Contains(t, batcher.pendingCount, "shard-002") + batcher.mu.Unlock() +} + +// Regression test: Ensure sequence numbers are tracked per message, not per batch +func TestBatcherSequenceNumberPerMessage(t *testing.T) { + logger := service.MockResources().Logger() + batcher := newDynamoDBCDCRecordBatcher(10000, 1000, logger) + + // Create messages with different sequence numbers + batch := make(service.MessageBatch, 3) + for i := range 3 { + msg := service.NewMessage(nil) + msg.MetaSetMut("dynamodb_shard_id", "shard-001") + msg.MetaSetMut("dynamodb_sequence_number", string(rune('A'+i))) // A, B, C + batch[i] = msg + } + + batcher.AddMessages(batch, "shard-001") + + // Verify each message has its own sequence number + batcher.mu.Lock() + assert.Equal(t, "A", batcher.messageTracker[batch[0]].sequenceNumber) + assert.Equal(t, "B", batcher.messageTracker[batch[1]].sequenceNumber) + assert.Equal(t, "C", batcher.messageTracker[batch[2]].sequenceNumber) + batcher.mu.Unlock() +} + +// Regression test: Verify pending count increments on ack +func TestBatcherPendingCountDoesNotIncrementOnAck(t *testing.T) { + logger := service.MockResources().Logger() + batcher := newDynamoDBCDCRecordBatcher(10000, 1000, logger) + + mockCheckpointer := &mockCheckpointer{ + checkpointLimit: 100, // High limit so we don't checkpoint + } + + checkpointer := &dynamoDBCDCCheckpointer{ + checkpointLimit: mockCheckpointer.checkpointLimit, + } + + // Add 10 messages + batch := createTestMessages(10, "shard-001", 0) + batcher.AddMessages(batch, "shard-001") + assert.Equal(t, 0, batcher.pendingCount["shard-001"], "Should be 0 before ack") + + // Ack messages - pending count should 
increment + err := batcher.AckMessages(context.Background(), checkpointer, batch) + assert.NoError(t, err) + + // Pending count should be 10 after acking 10 messages + assert.Equal(t, 10, batcher.pendingCount["shard-001"]) +} + +// Regression test: Verify latest sequence number is used for checkpointing +func TestBatcherUsesLatestSequenceForCheckpoint(t *testing.T) { + logger := service.MockResources().Logger() + batcher := newDynamoDBCDCRecordBatcher(10000, 1000, logger) + + // Create messages with sequence numbers in order + batch := make(service.MessageBatch, 5) + seqNumbers := []string{"00001", "00002", "00003", "00004", "00005"} + for i := range 5 { + msg := service.NewMessage(nil) + msg.MetaSetMut("dynamodb_shard_id", "shard-001") + msg.MetaSetMut("dynamodb_sequence_number", seqNumbers[i]) + batch[i] = msg + } + + batcher.AddMessages(batch, "shard-001") + + // Process messages out of order + outOfOrder := service.MessageBatch{batch[2], batch[0], batch[4], batch[1]} + + batcher.mu.Lock() + latestSeq := "" + for _, msg := range outOfOrder { + if cp, exists := batcher.messageTracker[msg]; exists { + // Track the latest (highest) sequence number + if cp.sequenceNumber > latestSeq { + latestSeq = cp.sequenceNumber + } + delete(batcher.messageTracker, msg) + } + } + batcher.mu.Unlock() + + // The latest sequence should be "00005" (from batch[4]) + assert.Equal(t, "00005", latestSeq) +} + +// Test concurrent access to batcher +func TestBatcherConcurrentAccess(t *testing.T) { + logger := service.MockResources().Logger() + batcher := newDynamoDBCDCRecordBatcher(10000, 1000, logger) + + // Add messages concurrently + done := make(chan bool, 2) + + go func() { + for i := range 10 { + batch := createTestMessages(5, "shard-001", i*5) + batcher.AddMessages(batch, "shard-001") + batcher.RemoveMessages(batch) + } + done <- true + }() + + go func() { + for i := range 10 { + batch := createTestMessages(5, "shard-002", i*5) + batcher.AddMessages(batch, "shard-002") + batcher.RemoveMessages(batch) + } + done <- true + }() + + <-done + <-done + + // Verify no race conditions - all messages should be processed + assert.Empty(t, batcher.messageTracker, "All messages should be removed") +} + +func TestBatcherNackAndReAdd(t *testing.T) { + logger := service.MockResources().Logger() + batcher := newDynamoDBCDCRecordBatcher(10000, 1000, logger) + + // Add messages + batch := createTestMessages(5, "shard-001", 0) + batcher.AddMessages(batch, "shard-001") + + // pendingCount should be 0 until ack + assert.Equal(t, 0, batcher.pendingCount["shard-001"]) + + // Simulate nack by removing messages + batcher.RemoveMessages(batch) + + assert.Equal(t, 0, batcher.pendingCount["shard-001"]) + assert.Empty(t, batcher.messageTracker) + + // Re-add the same logical messages (new message objects) + newBatch := createTestMessages(5, "shard-001", 0) + batcher.AddMessages(newBatch, "shard-001") + + // Still 0 until ack + assert.Equal(t, 0, batcher.pendingCount["shard-001"]) + assert.Len(t, batcher.messageTracker, 5) +} + +// Test that last checkpoints are updated correctly +func TestBatcherLastCheckpointsTracking(t *testing.T) { + logger := service.MockResources().Logger() + batcher := newDynamoDBCDCRecordBatcher(10000, 1000, logger) + + // Add messages for two shards + batch1 := createTestMessages(3, "shard-001", 0) + batch2 := createTestMessages(3, "shard-002", 0) + + batcher.AddMessages(batch1, "shard-001") + batcher.AddMessages(batch2, "shard-002") + + // Manually update last checkpoints + batcher.mu.Lock() + 
batcher.lastCheckpoints["shard-001"] = "C" // Last message in batch1 + batcher.lastCheckpoints["shard-002"] = "C" // Last message in batch2 + batcher.mu.Unlock() + + assert.Equal(t, "C", batcher.lastCheckpoints["shard-001"]) + assert.Equal(t, "C", batcher.lastCheckpoints["shard-002"]) +} + +// Test that max tracked shards limit is enforced +func TestBatcherMaxTrackedShardsLimit(t *testing.T) { + logger := service.MockResources().Logger() + // Create batcher with small limit for testing + batcher := newDynamoDBCDCRecordBatcher(5, 1, logger) + + mockCheckpointer := &mockCheckpointer{ + checkpointLimit: 1, + } + + // Wrap mockCheckpointer with the real checkpointer struct + checkpointer := &dynamoDBCDCCheckpointer{ + tableName: "test-checkpoints", + streamArn: "test-stream", + checkpointLimit: mockCheckpointer.checkpointLimit, + log: logger, + } + + // Add messages for 5 shards (at the limit) + for i := range 5 { + shardID := fmt.Sprintf("shard-%03d", i) + batch := createTestMessages(2, shardID, 0) + batcher.AddMessages(batch, shardID) + + // Manually set pending count high enough to trigger checkpoint + batcher.mu.Lock() + batcher.pendingCount[shardID] = 2 + for _, msg := range batch { + if cp, exists := batcher.messageTracker[msg]; exists { + batcher.lastCheckpoints[shardID] = cp.sequenceNumber + } + } + batcher.mu.Unlock() + } + + // Verify we're tracking exactly 5 shards + assert.Len(t, batcher.lastCheckpoints, 5) + + // Now try to add and ack a 6th shard (should exceed limit) + batch := createTestMessages(2, "shard-006", 0) + batcher.AddMessages(batch, "shard-006") + + batcher.mu.Lock() + batcher.pendingCount["shard-006"] = 2 + batcher.mu.Unlock() + + err := batcher.AckMessages(context.Background(), checkpointer, batch) + assert.Error(t, err, "Should fail when exceeding max tracked shards") + assert.Contains(t, err.Error(), "exceeded maximum size") + assert.Contains(t, err.Error(), "5 shards") +} + +// Test that ShouldThrottle works correctly +func TestBatcherShouldThrottle(t *testing.T) { + logger := service.MockResources().Logger() + // Create batcher with small limit for testing (checkpointLimit=10 -> maxTrackedMessages=1000) + batcher := newDynamoDBCDCRecordBatcher(100, 10, logger) + + // Initially should not throttle + assert.False(t, batcher.ShouldThrottle(), "Should not throttle when empty") + + // Add messages up to 80% capacity (should not throttle) + for i := 0; i < 800; i++ { + batch := createTestMessages(1, "shard-001", i) + batcher.AddMessages(batch, "shard-001") + } + assert.False(t, batcher.ShouldThrottle(), "Should not throttle at 80% capacity") + + // Add more to reach 90% capacity (should throttle) + for i := 800; i < 900; i++ { + batch := createTestMessages(1, "shard-001", i) + batcher.AddMessages(batch, "shard-001") + } + assert.True(t, batcher.ShouldThrottle(), "Should throttle at 90% capacity") + + // Add even more to exceed 90% + for i := 900; i < 950; i++ { + batch := createTestMessages(1, "shard-001", i) + batcher.AddMessages(batch, "shard-001") + } + assert.True(t, batcher.ShouldThrottle(), "Should still throttle above 90% capacity") +} diff --git a/internal/impl/aws/input_dynamodb_cdc_checkpoint.go b/internal/impl/aws/input_dynamodb_cdc_checkpoint.go new file mode 100644 index 0000000000..37d76b1152 --- /dev/null +++ b/internal/impl/aws/input_dynamodb_cdc_checkpoint.go @@ -0,0 +1,148 @@ +// Copyright 2026 Redpanda Data, Inc. 
+// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/blob/main/licenses/rcl.md + +package aws + +import ( + "context" + "errors" + "fmt" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" + + "github.com/redpanda-data/benthos/v4/public/service" +) + +// dynamoDBCDCCheckpointer manages checkpoints for DynamoDB CDC shards. +type dynamoDBCDCCheckpointer struct { + tableName string + streamArn string + checkpointLimit int + svc *dynamodb.Client + log *service.Logger +} + +// newDynamoDBCDCCheckpointer creates a new checkpointer for DynamoDB CDC. +func newDynamoDBCDCCheckpointer( + ctx context.Context, + svc *dynamodb.Client, + tableName, + streamArn string, + checkpointLimit int, + log *service.Logger, +) (*dynamoDBCDCCheckpointer, error) { + c := &dynamoDBCDCCheckpointer{ + tableName: tableName, + streamArn: streamArn, + checkpointLimit: checkpointLimit, + svc: svc, + log: log, + } + + if err := c.ensureTableExists(ctx); err != nil { + return nil, err + } + + return c, nil +} + +func (c *dynamoDBCDCCheckpointer) ensureTableExists(ctx context.Context) error { + _, err := c.svc.DescribeTable(ctx, &dynamodb.DescribeTableInput{ + TableName: aws.String(c.tableName), + }) + + var aerr *types.ResourceNotFoundException + if err == nil || !errors.As(err, &aerr) { + return err + } + + // Table doesn't exist, create it + input := &dynamodb.CreateTableInput{ + AttributeDefinitions: []types.AttributeDefinition{ + {AttributeName: aws.String("StreamArn"), AttributeType: types.ScalarAttributeTypeS}, + {AttributeName: aws.String("ShardID"), AttributeType: types.ScalarAttributeTypeS}, + }, + BillingMode: types.BillingModePayPerRequest, + KeySchema: []types.KeySchemaElement{ + {AttributeName: aws.String("StreamArn"), KeyType: types.KeyTypeHash}, + {AttributeName: aws.String("ShardID"), KeyType: types.KeyTypeRange}, + }, + TableName: aws.String(c.tableName), + } + + if _, err = c.svc.CreateTable(ctx, input); err != nil { + return fmt.Errorf("failed to create checkpoint table: %w", err) + } + + c.log.Infof("Created checkpoint table: %s", c.tableName) + return nil +} + +// Get retrieves the checkpoint for a shard. +func (c *dynamoDBCDCCheckpointer) Get(ctx context.Context, shardID string) (string, error) { + result, err := c.svc.GetItem(ctx, &dynamodb.GetItemInput{ + TableName: aws.String(c.tableName), + Key: map[string]types.AttributeValue{ + "StreamArn": &types.AttributeValueMemberS{Value: c.streamArn}, + "ShardID": &types.AttributeValueMemberS{Value: shardID}, + }, + }) + if err != nil { + var aerr *types.ResourceNotFoundException + if errors.As(err, &aerr) { + return "", nil + } + return "", fmt.Errorf("failed to get checkpoint for table=%s stream=%s shard=%s: %w", + c.tableName, c.streamArn, shardID, err) + } + + if result.Item == nil { + return "", nil + } + + if s, ok := result.Item["SequenceNumber"].(*types.AttributeValueMemberS); ok { + return s.Value, nil + } + + return "", nil +} + +// Set stores a checkpoint for a shard. 
+func (c *dynamoDBCDCCheckpointer) Set(ctx context.Context, shardID, sequenceNumber string) error { + _, err := c.svc.PutItem(ctx, &dynamodb.PutItemInput{ + TableName: aws.String(c.tableName), + Item: map[string]types.AttributeValue{ + "StreamArn": &types.AttributeValueMemberS{Value: c.streamArn}, + "ShardID": &types.AttributeValueMemberS{Value: shardID}, + "SequenceNumber": &types.AttributeValueMemberS{Value: sequenceNumber}, + }, + }) + if err != nil { + return fmt.Errorf("failed to set checkpoint for table=%s stream=%s shard=%s seq=%s: %w", + c.tableName, c.streamArn, shardID, sequenceNumber, err) + } + return nil +} + +// FlushCheckpoints writes all pending checkpoints to DynamoDB. +func (c *dynamoDBCDCCheckpointer) FlushCheckpoints(ctx context.Context, checkpoints map[string]string) error { + // Flush all pending checkpoints + for shardID, seq := range checkpoints { + if seq == "" { + continue + } + if err := c.Set(ctx, shardID, seq); err != nil { + c.log.Errorf("Failed to flush checkpoint for shard %s: %v", shardID, err) + return err + } + c.log.Infof("Flushed checkpoint for shard %s at sequence %s", shardID, seq) + } + return nil +} diff --git a/internal/impl/aws/input_dynamodb_cdc_integration_test.go b/internal/impl/aws/input_dynamodb_cdc_integration_test.go new file mode 100644 index 0000000000..27930f6d92 --- /dev/null +++ b/internal/impl/aws/input_dynamodb_cdc_integration_test.go @@ -0,0 +1,544 @@ +// Copyright 2026 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/blob/main/licenses/rcl.md + +//go:build integration + +package aws + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" + "github.com/ory/dockertest/v3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/redpanda-data/benthos/v4/public/service" + "github.com/redpanda-data/benthos/v4/public/service/integration" +) + +// createTableWithStreams creates a DynamoDB table with streams enabled for testing +func createTableWithStreams(ctx context.Context, t testing.TB, dynamoPort, tableName string) (*dynamodb.Client, error) { + endpoint := fmt.Sprintf("http://localhost:%v", dynamoPort) + + conf, err := config.LoadDefaultConfig(ctx, + config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider("xxxxx", "xxxxx", "xxxxx")), + config.WithRegion("us-east-1"), + ) + require.NoError(t, err) + + conf.BaseEndpoint = &endpoint + client := dynamodb.NewFromConfig(conf) + + // Check if table already exists + ta, err := client.DescribeTable(ctx, &dynamodb.DescribeTableInput{ + TableName: &tableName, + }) + if err != nil { + var derr *types.ResourceNotFoundException + if !errors.As(err, &derr) { + return nil, err + } + } + + if ta != nil && ta.Table != nil && ta.Table.TableStatus == types.TableStatusActive { + return client, nil + } + + intPtr := func(i int64) *int64 { + return &i + } + + t.Logf("Creating table with streams: %v\n", tableName) + _, err = client.CreateTable(ctx, &dynamodb.CreateTableInput{ + AttributeDefinitions: []types.AttributeDefinition{ + { + AttributeName: aws.String("id"), + AttributeType: 
+			},
+		},
+		KeySchema: []types.KeySchemaElement{
+			{
+				AttributeName: aws.String("id"),
+				KeyType:       types.KeyTypeHash,
+			},
+		},
+		ProvisionedThroughput: &types.ProvisionedThroughput{
+			ReadCapacityUnits:  intPtr(5),
+			WriteCapacityUnits: intPtr(5),
+		},
+		TableName: &tableName,
+		StreamSpecification: &types.StreamSpecification{
+			StreamEnabled:  aws.Bool(true),
+			StreamViewType: types.StreamViewTypeNewAndOldImages,
+		},
+	})
+	if err != nil {
+		return nil, err
+	}
+
+	// Wait for table to be active
+	waiter := dynamodb.NewTableExistsWaiter(client)
+	err = waiter.Wait(ctx, &dynamodb.DescribeTableInput{
+		TableName: &tableName,
+	}, time.Minute)
+
+	return client, err
+}
+
+// putTestItem inserts a test item into DynamoDB
+func putTestItem(ctx context.Context, client *dynamodb.Client, tableName, id, value string) error {
+	_, err := client.PutItem(ctx, &dynamodb.PutItemInput{
+		TableName: &tableName,
+		Item: map[string]types.AttributeValue{
+			"id":    &types.AttributeValueMemberS{Value: id},
+			"value": &types.AttributeValueMemberS{Value: value},
+		},
+	})
+	return err
+}
+
+// updateTestItem updates a test item in DynamoDB
+func updateTestItem(ctx context.Context, client *dynamodb.Client, tableName, id, newValue string) error {
+	_, err := client.UpdateItem(ctx, &dynamodb.UpdateItemInput{
+		TableName: &tableName,
+		Key: map[string]types.AttributeValue{
+			"id": &types.AttributeValueMemberS{Value: id},
+		},
+		UpdateExpression: aws.String("SET #v = :val"),
+		ExpressionAttributeNames: map[string]string{
+			"#v": "value",
+		},
+		ExpressionAttributeValues: map[string]types.AttributeValue{
+			":val": &types.AttributeValueMemberS{Value: newValue},
+		},
+	})
+	return err
+}
+
+// deleteTestItem deletes a test item from DynamoDB
+func deleteTestItem(ctx context.Context, client *dynamodb.Client, tableName, id string) error {
+	_, err := client.DeleteItem(ctx, &dynamodb.DeleteItemInput{
+		TableName: &tableName,
+		Key: map[string]types.AttributeValue{
+			"id": &types.AttributeValueMemberS{Value: id},
+		},
+	})
+	return err
+}
+
+func TestIntegrationDynamoDBStreams(t *testing.T) {
+	integration.CheckSkip(t)
+	t.Parallel()
+
+	pool, err := dockertest.NewPool("")
+	require.NoError(t, err)
+
+	pool.MaxWait = time.Second * 60
+
+	// Start DynamoDB Local container
+	resource, err := pool.RunWithOptions(&dockertest.RunOptions{
+		Repository:   "amazon/dynamodb-local",
+		Tag:          "latest",
+		ExposedPorts: []string{"8000/tcp"},
+	})
+	require.NoError(t, err)
+	t.Cleanup(func() {
+		assert.NoError(t, pool.Purge(resource))
+	})
+
+	_ = resource.Expire(900)
+
+	var client *dynamodb.Client
+	tableName := "test-streams-table"
+
+	// Wait for DynamoDB to be ready and create table with streams
+	require.NoError(t, pool.Retry(func() error {
+		var err error
+		client, err = createTableWithStreams(context.Background(), t, resource.GetPort("8000/tcp"), tableName)
+		return err
+	}))
+
+	port := resource.GetPort("8000/tcp")
+
+	t.Run("ReadInsertEvents", func(t *testing.T) {
+		checkpointTable := "test-checkpoints-insert"
+		testReadInsertEvents(t, client, port, tableName, checkpointTable)
+	})
+
+	t.Run("ReadModifyEvents", func(t *testing.T) {
+		checkpointTable := "test-checkpoints-modify"
+		testReadModifyEvents(t, client, port, tableName, checkpointTable)
+	})
+
+	t.Run("ReadRemoveEvents", func(t *testing.T) {
+		checkpointTable := "test-checkpoints-remove"
+		testReadRemoveEvents(t, client, port, tableName, checkpointTable)
+	})
+
+	t.Run("CheckpointResumption", func(t *testing.T) {
+		checkpointTable := "test-checkpoints-resumption"
+		testCheckpointResumption(t, client, port, tableName, checkpointTable)
+	})
+
+	t.Run("VerifyRecordCount", func(t *testing.T) {
+		checkpointTable := "test-checkpoints-count"
+		testVerifyRecordCount(t, client, port, tableName, checkpointTable)
+	})
+}
+
+// testReadInsertEvents verifies that INSERT events are captured
+func testReadInsertEvents(t *testing.T, client *dynamodb.Client, port, tableName, checkpointTable string) {
+	ctx := context.Background()
+
+	// Create input configuration
+	confStr := fmt.Sprintf(`
+table: %s
+checkpoint_table: %s
+endpoint: http://localhost:%s
+region: us-east-1
+start_from: latest
+credentials:
+  id: xxxxx
+  secret: xxxxx
+  token: xxxxx
+`, tableName, checkpointTable, port)
+
+	spec := dynamoDBCDCInputConfig()
+	parsed, err := spec.ParseYAML(confStr, nil)
+	require.NoError(t, err)
+
+	input, err := newDynamoDBCDCInputFromConfig(parsed, service.MockResources())
+	require.NoError(t, err)
+
+	require.NoError(t, input.Connect(ctx))
+	t.Cleanup(func() {
+		_ = input.Close(ctx)
+	})
+
+	// Insert test items
+	require.NoError(t, putTestItem(ctx, client, tableName, "test-1", "value-1"))
+	require.NoError(t, putTestItem(ctx, client, tableName, "test-2", "value-2"))
+
+	// Read events
+	batch, _, err := input.ReadBatch(ctx)
+	require.NoError(t, err)
+	require.NotEmpty(t, batch)
+
+	// Verify we got INSERT events
+	foundInsert := false
+	for _, msg := range batch {
+		eventName, _ := msg.MetaGet("dynamodb_event_name")
+		if eventName == "INSERT" {
+			foundInsert = true
+			break
+		}
+	}
+	assert.True(t, foundInsert, "Should receive INSERT events")
+}
+
+// testReadModifyEvents verifies that MODIFY events are captured
+func testReadModifyEvents(t *testing.T, client *dynamodb.Client, port, tableName, checkpointTable string) {
+	ctx := context.Background()
+
+	// Create input configuration
+	confStr := fmt.Sprintf(`
+table: %s
+checkpoint_table: %s
+endpoint: http://localhost:%s
+region: us-east-1
+start_from: latest
+credentials:
+  id: xxxxx
+  secret: xxxxx
+  token: xxxxx
+`, tableName, checkpointTable, port)
+
+	spec := dynamoDBCDCInputConfig()
+	parsed, err := spec.ParseYAML(confStr, nil)
+	require.NoError(t, err)
+
+	input, err := newDynamoDBCDCInputFromConfig(parsed, service.MockResources())
+	require.NoError(t, err)
+
+	require.NoError(t, input.Connect(ctx))
+	t.Cleanup(func() {
+		_ = input.Close(ctx)
+	})
+
+	// Insert an item
+	itemID := "modify-test"
+	require.NoError(t, putTestItem(ctx, client, tableName, itemID, "original"))
+
+	// Wait briefly for stream propagation
+	time.Sleep(100 * time.Millisecond)
+
+	// Update the item
+	require.NoError(t, updateTestItem(ctx, client, tableName, itemID, "updated"))
+
+	// Read events (may need multiple batches)
+	foundModify := false
+	for i := 0; i < 5 && !foundModify; i++ {
+		batch, _, err := input.ReadBatch(ctx)
+		if err != nil {
+			time.Sleep(100 * time.Millisecond)
+			continue
+		}
+
+		for _, msg := range batch {
+			eventName, _ := msg.MetaGet("dynamodb_event_name")
+			if eventName == "MODIFY" {
+				foundModify = true
+				break
+			}
+		}
+
+		if !foundModify {
+			time.Sleep(100 * time.Millisecond)
+		}
+	}
+
+	assert.True(t, foundModify, "Should receive MODIFY events")
+}
+
+// testReadRemoveEvents verifies that REMOVE events are captured
+func testReadRemoveEvents(t *testing.T, client *dynamodb.Client, port, tableName, checkpointTable string) {
+	ctx := context.Background()
+
+	// Create input configuration
+	confStr := fmt.Sprintf(`
+table: %s
+checkpoint_table: %s
+endpoint: http://localhost:%s
+region: us-east-1
+start_from: latest
+credentials:
+  id: xxxxx
+  secret: xxxxx
+  token: xxxxx
+`, tableName, checkpointTable, port)
+
+	spec := dynamoDBCDCInputConfig()
+	parsed, err := spec.ParseYAML(confStr, nil)
+	require.NoError(t, err)
+
+	input, err := newDynamoDBCDCInputFromConfig(parsed, service.MockResources())
+	require.NoError(t, err)
+
+	require.NoError(t, input.Connect(ctx))
+	t.Cleanup(func() {
+		_ = input.Close(ctx)
+	})
+
+	// Insert an item
+	itemID := "delete-test"
+	require.NoError(t, putTestItem(ctx, client, tableName, itemID, "to-delete"))
+
+	// Wait briefly for stream propagation
+	time.Sleep(100 * time.Millisecond)
+
+	// Delete the item
+	require.NoError(t, deleteTestItem(ctx, client, tableName, itemID))
+
+	// Read events (may need multiple batches)
+	foundRemove := false
+	for i := 0; i < 5 && !foundRemove; i++ {
+		batch, _, err := input.ReadBatch(ctx)
+		if err != nil {
+			time.Sleep(100 * time.Millisecond)
+			continue
+		}
+
+		for _, msg := range batch {
+			eventName, _ := msg.MetaGet("dynamodb_event_name")
+			if eventName == "REMOVE" {
+				foundRemove = true
+				break
+			}
+		}
+
+		if !foundRemove {
+			time.Sleep(100 * time.Millisecond)
+		}
+	}
+
+	assert.True(t, foundRemove, "Should receive REMOVE events")
+}
+
+// testVerifyRecordCount verifies that the number of CDC events matches the number of operations performed
+func testVerifyRecordCount(t *testing.T, client *dynamodb.Client, port, tableName, checkpointTable string) {
+	ctx := context.Background()
+
+	// Create input configuration
+	confStr := fmt.Sprintf(`
+table: %s
+checkpoint_table: %s
+endpoint: http://localhost:%s
+region: us-east-1
+start_from: latest
+credentials:
+  id: xxxxx
+  secret: xxxxx
+  token: xxxxx
+`, tableName, checkpointTable, port)
+
+	spec := dynamoDBCDCInputConfig()
+	parsed, err := spec.ParseYAML(confStr, nil)
+	require.NoError(t, err)
+
+	input, err := newDynamoDBCDCInputFromConfig(parsed, service.MockResources())
+	require.NoError(t, err)
+
+	require.NoError(t, input.Connect(ctx))
+	t.Cleanup(func() {
+		_ = input.Close(ctx)
+	})
+
+	// Perform a known number of operations
+	numInserts := 100
+	numUpdates := 5
+	numDeletes := 3
+	expectedTotalEvents := numInserts + numUpdates + numDeletes
+
+	// Insert items
+	for i := 0; i < numInserts; i++ {
+		itemID := fmt.Sprintf("count-test-%d", i)
+		require.NoError(t, putTestItem(ctx, client, tableName, itemID, "initial"))
+	}
+
+	// Update some items
+	for i := 0; i < numUpdates; i++ {
+		itemID := fmt.Sprintf("count-test-%d", i)
+		require.NoError(t, updateTestItem(ctx, client, tableName, itemID, "updated"))
+	}
+
+	// Delete some items
+	for i := 0; i < numDeletes; i++ {
+		itemID := fmt.Sprintf("count-test-%d", i)
+		require.NoError(t, deleteTestItem(ctx, client, tableName, itemID))
+	}
+
+	// Read events until we get all expected events or timeout
+	receivedEvents := make([]string, 0, expectedTotalEvents)
+	eventCounts := map[string]int{
+		"INSERT": 0,
+		"MODIFY": 0,
+		"REMOVE": 0,
+	}
+
+	maxAttempts := 20
+	for attempt := 0; attempt < maxAttempts; attempt++ {
+		batch, _, err := input.ReadBatch(ctx)
+		if err != nil {
+			time.Sleep(100 * time.Millisecond)
+			continue
+		}
+
+		if len(batch) == 0 {
+			time.Sleep(100 * time.Millisecond)
+			continue
+		}
+
+		for _, msg := range batch {
+			eventName, exists := msg.MetaGet("dynamodb_event_name")
+			if exists {
+				receivedEvents = append(receivedEvents, eventName)
+				eventCounts[eventName]++
+			}
+		}
+
+		// Check if we've received all expected events
+		if len(receivedEvents) >= expectedTotalEvents {
+			break
+		}
+
+		time.Sleep(100 * time.Millisecond)
+	}
+
+	// Verify counts
+	assert.Len(t, receivedEvents, expectedTotalEvents,
+		"Should receive exactly %d events", expectedTotalEvents)
+	assert.Equal(t, numInserts, eventCounts["INSERT"],
+		"Should receive %d INSERT events", numInserts)
+	assert.Equal(t, numUpdates, eventCounts["MODIFY"],
+		"Should receive %d MODIFY events", numUpdates)
+	assert.Equal(t, numDeletes, eventCounts["REMOVE"],
+		"Should receive %d REMOVE events", numDeletes)
+
+	t.Logf("Received %d total events: %d INSERTs, %d MODIFYs, %d REMOVEs",
+		len(receivedEvents), eventCounts["INSERT"], eventCounts["MODIFY"], eventCounts["REMOVE"])
+}
+
+// testCheckpointResumption verifies that checkpoints work correctly
+func testCheckpointResumption(t *testing.T, client *dynamodb.Client, port, tableName, checkpointTable string) {
+	ctx := context.Background()
+
+	// Create input configuration
+	confStr := fmt.Sprintf(`
+table: %s
+checkpoint_table: %s
+endpoint: http://localhost:%s
+region: us-east-1
+start_from: trim_horizon
+checkpoint_limit: 2
+credentials:
+  id: xxxxx
+  secret: xxxxx
+  token: xxxxx
+`, tableName, checkpointTable, port)
+
+	spec := dynamoDBCDCInputConfig()
+	parsed, err := spec.ParseYAML(confStr, nil)
+	require.NoError(t, err)
+
+	// First input instance
+	input1, err := newDynamoDBCDCInputFromConfig(parsed, service.MockResources())
+	require.NoError(t, err)
+	require.NoError(t, input1.Connect(ctx))
+
+	// Insert some items
+	require.NoError(t, putTestItem(ctx, client, tableName, "checkpoint-1", "value-1"))
+	require.NoError(t, putTestItem(ctx, client, tableName, "checkpoint-2", "value-2"))
+
+	// Read and acknowledge messages
+	batch1, ackFn1, err := input1.ReadBatch(ctx)
+	require.NoError(t, err)
+	require.NotEmpty(t, batch1)
+
+	// Acknowledge to trigger checkpoint
+	require.NoError(t, ackFn1(ctx, nil))
+
+	// Close first input
+	require.NoError(t, input1.Close(ctx))
+
+	// Create second input instance (should resume from checkpoint)
+	input2, err := newDynamoDBCDCInputFromConfig(parsed, service.MockResources())
+	require.NoError(t, err)
+	require.NoError(t, input2.Connect(ctx))
+	t.Cleanup(func() {
+		_ = input2.Close(ctx)
+	})
+
+	// Insert new item after checkpoint
+	require.NoError(t, putTestItem(ctx, client, tableName, "checkpoint-3", "value-3"))
+
+	// Second input should read new events (not re-read old ones)
+	batch2, _, err := input2.ReadBatch(ctx)
+	require.NoError(t, err)
+
+	// The batch may include checkpoint-3 but should not re-process already checkpointed items
+	assert.NotEmpty(t, batch2, "Should read new events after resumption")
+}
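The unit test file that follows exercises `convertAttributeValue` and `convertAttributeMap`, which are defined in the input implementation and are not shown in this excerpt. A minimal sketch consistent with the behaviour those tests expect might look like the following; the names are taken from the tests, but the exact implementation in the connector may differ.

```go
package aws

import (
	streamstypes "github.com/aws/aws-sdk-go-v2/service/dynamodbstreams/types"
)

// convertAttributeValue maps a DynamoDB Streams attribute value onto plain Go
// values: numbers stay as strings, sets stay as string slices, and maps and
// lists recurse. (Sketch only, matching the expectations in the unit tests.)
func convertAttributeValue(v streamstypes.AttributeValue) any {
	switch t := v.(type) {
	case *streamstypes.AttributeValueMemberS:
		return t.Value
	case *streamstypes.AttributeValueMemberN:
		return t.Value
	case *streamstypes.AttributeValueMemberBOOL:
		return t.Value
	case *streamstypes.AttributeValueMemberNULL:
		return nil
	case *streamstypes.AttributeValueMemberSS:
		return t.Value
	case *streamstypes.AttributeValueMemberNS:
		return t.Value
	case *streamstypes.AttributeValueMemberM:
		return convertAttributeMap(t.Value)
	case *streamstypes.AttributeValueMemberL:
		out := make([]any, 0, len(t.Value))
		for _, item := range t.Value {
			out = append(out, convertAttributeValue(item))
		}
		return out
	default:
		return nil
	}
}

// convertAttributeMap converts a whole attribute map, such as a stream
// record's NewImage, into a map[string]any.
func convertAttributeMap(m map[string]streamstypes.AttributeValue) map[string]any {
	out := make(map[string]any, len(m))
	for k, v := range m {
		out[k] = convertAttributeValue(v)
	}
	return out
}
```

Keeping numbers and number sets as strings, as the tests expect, avoids precision loss for DynamoDB's arbitrary-precision numeric type.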
diff --git a/internal/impl/aws/input_dynamodb_cdc_test.go b/internal/impl/aws/input_dynamodb_cdc_test.go
new file mode 100644
index 0000000000..0c2aec2915
--- /dev/null
+++ b/internal/impl/aws/input_dynamodb_cdc_test.go
@@ -0,0 +1,188 @@
+// Copyright 2026 Redpanda Data, Inc.
+//
+// Licensed as a Redpanda Enterprise file under the Redpanda Community
+// License (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// https://github.com/redpanda-data/connect/blob/main/licenses/rcl.md
+
+package aws
+
+import (
+	"testing"
+
+	"github.com/aws/aws-sdk-go-v2/aws"
+	streamstypes "github.com/aws/aws-sdk-go-v2/service/dynamodbstreams/types"
+	"github.com/stretchr/testify/assert"
+
+	"github.com/redpanda-data/benthos/v4/public/service"
+)
+
+func TestConvertAttributeValue(t *testing.T) {
+	tests := []struct {
+		name     string
+		input    streamstypes.AttributeValue
+		expected any
+	}{
+		{
+			name:     "string value",
+			input:    &streamstypes.AttributeValueMemberS{Value: "test"},
+			expected: "test",
+		},
+		{
+			name:     "number value",
+			input:    &streamstypes.AttributeValueMemberN{Value: "123"},
+			expected: "123",
+		},
+		{
+			name:     "boolean true",
+			input:    &streamstypes.AttributeValueMemberBOOL{Value: true},
+			expected: true,
+		},
+		{
+			name:     "boolean false",
+			input:    &streamstypes.AttributeValueMemberBOOL{Value: false},
+			expected: false,
+		},
+		{
+			name:     "null value",
+			input:    &streamstypes.AttributeValueMemberNULL{Value: true},
+			expected: nil,
+		},
+		{
+			name:     "string set",
+			input:    &streamstypes.AttributeValueMemberSS{Value: []string{"a", "b", "c"}},
+			expected: []string{"a", "b", "c"},
+		},
+		{
+			name:     "number set",
+			input:    &streamstypes.AttributeValueMemberNS{Value: []string{"1", "2", "3"}},
+			expected: []string{"1", "2", "3"},
+		},
+		{
+			name: "map value",
+			input: &streamstypes.AttributeValueMemberM{Value: map[string]streamstypes.AttributeValue{
+				"key1": &streamstypes.AttributeValueMemberS{Value: "value1"},
+				"key2": &streamstypes.AttributeValueMemberN{Value: "42"},
+			}},
+			expected: map[string]any{
+				"key1": "value1",
+				"key2": "42",
+			},
+		},
+		{
+			name: "list value",
+			input: &streamstypes.AttributeValueMemberL{Value: []streamstypes.AttributeValue{
+				&streamstypes.AttributeValueMemberS{Value: "item1"},
+				&streamstypes.AttributeValueMemberN{Value: "100"},
+			}},
+			expected: []any{"item1", "100"},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result := convertAttributeValue(tt.input)
+			assert.Equal(t, tt.expected, result)
+		})
+	}
+}
+
+func TestConvertAttributeMap(t *testing.T) {
+	input := map[string]streamstypes.AttributeValue{
+		"id":     &streamstypes.AttributeValueMemberS{Value: "123"},
+		"count":  &streamstypes.AttributeValueMemberN{Value: "42"},
+		"active": &streamstypes.AttributeValueMemberBOOL{Value: true},
+		"metadata": &streamstypes.AttributeValueMemberM{Value: map[string]streamstypes.AttributeValue{
+			"created": &streamstypes.AttributeValueMemberS{Value: "2024-01-01"},
+		}},
+	}
+
+	result := convertAttributeMap(input)
+
+	assert.Equal(t, "123", result["id"])
+	assert.Equal(t, "42", result["count"])
+	assert.Equal(t, true, result["active"])
+	assert.IsType(t, map[string]any{}, result["metadata"])
+	metadata := result["metadata"].(map[string]any)
+	assert.Equal(t, "2024-01-01", metadata["created"])
+}
+
+func TestMinFunction(t *testing.T) {
+	tests := []struct {
+		a        int
+		b        int
+		expected int
+	}{
+		{1, 2, 1},
+		{5, 3, 3},
+		{10, 10, 10},
+		{-1, 5, -1},
+		{0, 0, 0},
+	}
+
+	for _, tt := range tests {
+		result := min(tt.a, tt.b)
+		assert.Equal(t, tt.expected, result)
+	}
+}
+
+// Regression test: Verify RWMutex allows concurrent reads
+func TestConcurrentShardReaderAccess(t *testing.T) {
+	logger := service.MockResources().Logger()
+
+	input := &dynamoDBCDCInput{
+		shardReaders: map[string]*dynamoDBShardReader{
+			"shard-001": {shardID: "shard-001", iterator: aws.String("iter-001"), exhausted: false},
+			"shard-002": {shardID: "shard-002", iterator: aws.String("iter-002"), exhausted: false},
+		},
+		log: logger,
+	}
+
+	// Multiple goroutines should be able to read concurrently
+	done := make(chan bool, 3)
+
+	for range 3 {
+		go func() {
+			input.mu.RLock()
+			count := len(input.shardReaders)
+			input.mu.RUnlock()
+			assert.Equal(t, 2, count)
+			done <- true
+		}()
+	}
+
+	for range 3 {
+		<-done
+	}
+}
+
+// Test that exhausted shards are properly handled
+func TestExhaustedShardHandling(t *testing.T) {
+	input := &dynamoDBCDCInput{
+		shardReaders: map[string]*dynamoDBShardReader{
+			"shard-001": {
+				shardID:   "shard-001",
+				iterator:  nil, // Exhausted - no iterator
+				exhausted: true,
+			},
+			"shard-002": {
+				shardID:   "shard-002",
+				iterator:  aws.String("iter-002"),
+				exhausted: false,
+			},
+		},
+	}
+
+	// Count active readers
+	input.mu.RLock()
+	activeCount := 0
+	for _, reader := range input.shardReaders {
+		if !reader.exhausted && reader.iterator != nil {
+			activeCount++
+		}
+	}
+	input.mu.RUnlock()
+
+	assert.Equal(t, 1, activeCount, "Only one shard should be active")
+}
diff --git a/internal/plugins/info.csv b/internal/plugins/info.csv
index dcf89a3fb1..47462c0194 100644
--- a/internal/plugins/info.csv
+++ b/internal/plugins/info.csv
@@ -13,6 +13,7 @@ aws_bedrock_embeddings ,processor ,aws_bedrock_embeddings ,4.37.0 ,certif
 aws_cloudwatch ,metric ,aws_cloudwatch ,3.36.0 ,community ,n ,n ,n
 aws_dynamodb ,cache ,AWS DynamoDB ,3.36.0 ,community ,n ,y ,y
 aws_dynamodb ,output ,AWS DynamoDB ,3.36.0 ,community ,n ,y ,y
+aws_dynamodb_cdc ,input ,aws_dynamodb_cdc ,1.0.0 ,enterprise ,n ,y ,n
 aws_dynamodb_partiql ,processor ,aws_dynamodb_partiql ,3.48.0 ,certified ,n ,y ,y
 aws_kinesis ,input ,AWS Kinesis ,3.36.0 ,certified ,n ,y ,y
 aws_kinesis ,output ,AWS Kinesis ,3.36.0 ,certified ,n ,y ,y