
Commit e495661

[feat] - Support S3 Source Resumption (#3570)
* add config option for s3 resumption
* updates
* initial progress tracking logic
* more testing
* revert s3 source file
* UpdateScanProgress tests
* adjust
* updates
* invert
* updates
* updates
* fix
* update
* adjust test
* fix
* remove progress tracking
* cleanup
* cleanup
* remove dupe
* remove context cancellation logic
* fix comment format
* make resumption logic more clear
* rename
* fixes
* update
* add edge case test
* remove dupe mu
* add comment
* fix comment
1 parent 9a6cad9 commit e495661
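Resumption is opt-in via a new field on the S3 connection config (the diff below reads it through conn.GetEnableResumption()). A minimal, hypothetical sketch of enabling it when constructing the connection; the bucket name is a placeholder and the import path assumes the repo's generated sourcespb package:

```go
package main

import (
	"fmt"

	"github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
)

func main() {
	conn := &sourcespb.S3{
		Buckets:          []string{"example-bucket"}, // placeholder bucket
		EnableResumption: true,                       // new option added by this commit
	}
	// The source's Init reads this via conn.GetEnableResumption() and wires
	// up the Checkpointer accordingly.
	fmt.Println("resumption enabled:", conn.GetEnableResumption())
}
```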

3 files changed (+343, -37 lines)

pkg/sources/s3/s3.go

Lines changed: 180 additions & 35 deletions
@@ -2,6 +2,7 @@ package s3
 
 import (
     "fmt"
+    "slices"
     "strings"
     "sync"
     "sync/atomic"
@@ -43,8 +44,10 @@ type Source struct {
     jobID       sources.JobID
     verify      bool
     concurrency int
+    conn        *sourcespb.S3
+
+    checkpointer *Checkpointer
     sources.Progress
-    conn *sourcespb.S3
 
     errorCount *sync.Map
     jobPool    *errgroup.Group
@@ -67,7 +70,7 @@ func (s *Source) JobID() sources.JobID { return s.jobID }
 
 // Init returns an initialized AWS source
 func (s *Source) Init(
-    _ context.Context,
+    ctx context.Context,
     name string,
     jobID sources.JobID,
     sourceID sources.SourceID,
@@ -90,6 +93,8 @@ func (s *Source) Init(
     }
     s.conn = &conn
 
+    s.checkpointer = NewCheckpointer(ctx, conn.GetEnableResumption(), &s.Progress)
+
     s.setMaxObjectSize(conn.GetMaxObjectSize())
 
     if len(conn.GetBuckets()) > 0 && len(conn.GetIgnoreBuckets()) > 0 {
@@ -173,9 +178,16 @@ func (s *Source) newClient(region, roleArn string) (*s3.S3, error) {
     return s3.New(sess), nil
 }
 
-// IAM identity needs s3:ListBuckets permission
+// getBucketsToScan returns a list of S3 buckets to scan.
+// If the connection has a list of buckets specified, those are returned.
+// Otherwise, it lists all buckets the client has access to and filters out the ignored ones.
+// The list of buckets is sorted lexicographically to ensure consistent ordering,
+// which allows resuming scanning from the same place if the scan is interrupted.
+//
+// Note: The IAM identity needs the s3:ListBuckets permission.
 func (s *Source) getBucketsToScan(client *s3.S3) ([]string, error) {
     if buckets := s.conn.GetBuckets(); len(buckets) > 0 {
+        slices.Sort(buckets)
         return buckets, nil
     }
 
@@ -196,32 +208,122 @@ func (s *Source) getBucketsToScan(client *s3.S3) ([]string, error) {
             bucketsToScan = append(bucketsToScan, name)
         }
     }
+    slices.Sort(bucketsToScan)
+
     return bucketsToScan, nil
 }
 
+// pageMetadata contains metadata about a single page of S3 objects being scanned.
+type pageMetadata struct {
+    bucket     string                  // The name of the S3 bucket being scanned
+    pageNumber int                     // Current page number in the pagination sequence
+    client     *s3.S3                  // AWS S3 client configured for the appropriate region
+    page       *s3.ListObjectsV2Output // Contains the list of S3 objects in this page
+}
+
+// processingState tracks the state of concurrent S3 object processing.
+type processingState struct {
+    errorCount  *sync.Map // Thread-safe map tracking errors per prefix
+    objectCount *uint64   // Total number of objects processed
+}
+
+// resumePosition tracks where to restart scanning S3 buckets and objects after an interruption.
+// It encapsulates all the information needed to resume a scan from its last known position.
+type resumePosition struct {
+    bucket     string // The bucket name we were processing
+    index      int    // Index in the buckets slice where we should resume
+    startAfter string // The last processed object key within the bucket
+    isNewScan  bool   // True if we're starting a fresh scan
+    exactMatch bool   // True if we found the exact bucket we were previously processing
+}
+
+// determineResumePosition calculates where to resume scanning from based on the last saved checkpoint
+// and the current list of available buckets to scan. It handles several scenarios:
+//
+// 1. If getting the resume point fails or there is no previous bucket saved (CurrentBucket is empty),
+//    we start a new scan from the beginning, this is the safest option.
+//
+// 2. If the previous bucket exists in our current scan list (exactMatch=true),
+//    we resume from that exact position and use the StartAfter value
+//    to continue from the last processed object within that bucket.
+//
+// 3. If the previous bucket is not found in our current scan list (exactMatch=false), this typically means:
+//    - The bucket was deleted since our last scan
+//    - The bucket was explicitly excluded from this scan's configuration
+//    - The IAM role no longer has access to the bucket
+//    - The bucket name changed due to a configuration update
+//    In this case, we use binary search to find the closest position where the bucket would have been,
+//    allowing us to resume from the nearest available point in our sorted bucket list rather than
+//    restarting the entire scan.
+func determineResumePosition(ctx context.Context, tracker *Checkpointer, buckets []string) resumePosition {
+    resumePoint, err := tracker.ResumePoint(ctx)
+    if err != nil {
+        ctx.Logger().Error(err, "failed to get resume point; starting from the beginning")
+        return resumePosition{isNewScan: true}
+    }
+
+    if resumePoint.CurrentBucket == "" {
+        return resumePosition{isNewScan: true}
+    }
+
+    startIdx, found := slices.BinarySearch(buckets, resumePoint.CurrentBucket)
+    return resumePosition{
+        bucket:     resumePoint.CurrentBucket,
+        startAfter: resumePoint.StartAfter,
+        index:      startIdx,
+        exactMatch: found,
+    }
+}
+
 func (s *Source) scanBuckets(
     ctx context.Context,
     client *s3.S3,
     role string,
     bucketsToScan []string,
     chunksChan chan *sources.Chunk,
 ) {
-    var objectCount uint64
-
     if role != "" {
         ctx = context.WithValue(ctx, "role", role)
     }
+    var objectCount uint64
 
-    for i, bucket := range bucketsToScan {
+    pos := determineResumePosition(ctx, s.checkpointer, bucketsToScan)
+    switch {
+    case pos.isNewScan:
+        ctx.Logger().Info("Starting new scan from beginning")
+    case !pos.exactMatch:
+        ctx.Logger().Info(
+            "Resume bucket no longer available, starting from closest position",
+            "original_bucket", pos.bucket,
+            "position", pos.index,
+        )
+    default:
+        ctx.Logger().Info(
+            "Resuming scan from previous scan's bucket",
+            "bucket", pos.bucket,
+            "position", pos.index,
+        )
+    }
+
+    bucketsToScanCount := len(bucketsToScan)
+    for bucketIdx := pos.index; bucketIdx < bucketsToScanCount; bucketIdx++ {
+        bucket := bucketsToScan[bucketIdx]
         ctx := context.WithValue(ctx, "bucket", bucket)
 
         if common.IsDone(ctx) {
+            ctx.Logger().Error(ctx.Err(), "context done, while scanning bucket")
             return
         }
 
-        s.SetProgressComplete(i, len(bucketsToScan), fmt.Sprintf("Bucket: %s", bucket), "")
         ctx.Logger().V(3).Info("Scanning bucket")
 
+        s.SetProgressComplete(
+            bucketIdx,
+            len(bucketsToScan),
+            fmt.Sprintf("Bucket: %s", bucket),
+            s.Progress.EncodedResumeInfo,
+        )
+
         regionalClient, err := s.getRegionalClientForBucket(ctx, client, role, bucket)
         if err != nil {
            ctx.Logger().Error(err, "could not get regional client for bucket")
@@ -230,10 +332,33 @@ func (s *Source) scanBuckets(
 
         errorCount := sync.Map{}
 
+        input := &s3.ListObjectsV2Input{Bucket: &bucket}
+        if bucket == pos.bucket && pos.startAfter != "" {
+            input.StartAfter = &pos.startAfter
+            ctx.Logger().V(3).Info(
+                "Resuming bucket scan",
+                "start_after", pos.startAfter,
+            )
+        }
+
+        pageNumber := 1
         err = regionalClient.ListObjectsV2PagesWithContext(
-            ctx, &s3.ListObjectsV2Input{Bucket: &bucket},
+            ctx,
+            input,
             func(page *s3.ListObjectsV2Output, _ bool) bool {
-                s.pageChunker(ctx, regionalClient, chunksChan, bucket, page, &errorCount, i+1, &objectCount)
+                pageMetadata := pageMetadata{
+                    bucket:     bucket,
+                    pageNumber: pageNumber,
+                    client:     regionalClient,
+                    page:       page,
+                }
+                processingState := processingState{
+                    errorCount:  &errorCount,
+                    objectCount: &objectCount,
+                }
+                s.pageChunker(ctx, pageMetadata, processingState, chunksChan)
+
+                pageNumber++
                 return true
             })
 
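For reference, ListObjectsV2's StartAfter makes the listing begin strictly after the given key, which is why the checkpoint stores the last processed object key. A standalone sketch against aws-sdk-go v1; the bucket, region, and key are placeholders, not values from the commit:

```go
package main

import (
	"fmt"
	"log"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/s3"
)

func main() {
	sess, err := session.NewSession(aws.NewConfig().WithRegion("us-east-1"))
	if err != nil {
		log.Fatal(err)
	}
	client := s3.New(sess)

	// Listing resumes strictly after the checkpointed key, so objects that
	// were already processed are not listed (or scanned) again.
	input := &s3.ListObjectsV2Input{
		Bucket:     aws.String("example-bucket"),       // placeholder
		StartAfter: aws.String("logs/2024/03/14.json"), // last processed key
	}
	err = client.ListObjectsV2Pages(input, func(page *s3.ListObjectsV2Output, lastPage bool) bool {
		for _, obj := range page.Contents {
			fmt.Println(*obj.Key)
		}
		return true // keep paging
	})
	if err != nil {
		log.Fatal(err)
	}
}
```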
@@ -249,6 +374,7 @@ func (s *Source) scanBuckets(
             }
         }
     }
+
     s.SetProgressComplete(
         len(bucketsToScan),
         len(bucketsToScan),
@@ -289,29 +415,25 @@ func (s *Source) getRegionalClientForBucket(
     return regionalClient, nil
 }
 
-// pageChunker emits chunks onto the given channel from a page
+// pageChunker emits chunks onto the given channel from a page.
 func (s *Source) pageChunker(
     ctx context.Context,
-    client *s3.S3,
+    metadata pageMetadata,
+    state processingState,
     chunksChan chan *sources.Chunk,
-    bucket string,
-    page *s3.ListObjectsV2Output,
-    errorCount *sync.Map,
-    pageNumber int,
-    objectCount *uint64,
 ) {
-    for _, obj := range page.Contents {
+    s.checkpointer.Reset() // Reset the checkpointer for each PAGE
+    ctx = context.WithValues(ctx, "bucket", metadata.bucket, "page_number", metadata.pageNumber)
+
+    for objIdx, obj := range metadata.page.Contents {
         if obj == nil {
+            if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
+                ctx.Logger().Error(err, "could not update progress for nil object")
+            }
             continue
         }
 
-        ctx = context.WithValues(
-            ctx,
-            "key", *obj.Key,
-            "bucket", bucket,
-            "page", pageNumber,
-            "size", *obj.Size,
-        )
+        ctx = context.WithValues(ctx, "key", *obj.Key, "size", *obj.Size)
 
         if common.IsDone(ctx) {
             return
@@ -320,29 +442,44 @@ func (s *Source) pageChunker(
         // Skip GLACIER and GLACIER_IR objects.
         if obj.StorageClass == nil || strings.Contains(*obj.StorageClass, "GLACIER") {
            ctx.Logger().V(5).Info("Skipping object in storage class", "storage_class", *obj.StorageClass)
+            if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
+                ctx.Logger().Error(err, "could not update progress for glacier object")
+            }
             continue
         }
 
         // Ignore large files.
         if *obj.Size > s.maxObjectSize {
             ctx.Logger().V(5).Info("Skipping %d byte file (over maxObjectSize limit)")
+            if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
+                ctx.Logger().Error(err, "could not update progress for large file")
+            }
             continue
         }
 
         // File empty file.
         if *obj.Size == 0 {
             ctx.Logger().V(5).Info("Skipping empty file")
+            if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
+                ctx.Logger().Error(err, "could not update progress for empty file")
+            }
             continue
         }
 
         // Skip incompatible extensions.
         if common.SkipFile(*obj.Key) {
             ctx.Logger().V(5).Info("Skipping file with incompatible extension")
+            if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
+                ctx.Logger().Error(err, "could not update progress for incompatible file")
+            }
             continue
         }
 
         s.jobPool.Go(func() error {
             defer common.RecoverWithExit(ctx)
+            if common.IsDone(ctx) {
+                return ctx.Err()
+            }
 
             if strings.HasSuffix(*obj.Key, "/") {
                 ctx.Logger().V(5).Info("Skipping directory")
@@ -352,7 +489,7 @@ func (s *Source) pageChunker(
             path := strings.Split(*obj.Key, "/")
             prefix := strings.Join(path[:len(path)-1], "/")
 
-            nErr, ok := errorCount.Load(prefix)
+            nErr, ok := state.errorCount.Load(prefix)
             if !ok {
                 nErr = 0
             }
@@ -366,8 +503,8 @@ func (s *Source) pageChunker(
             objCtx, cancel := context.WithTimeout(ctx, getObjectTimeout)
             defer cancel()
 
-            res, err := client.GetObjectWithContext(objCtx, &s3.GetObjectInput{
-                Bucket: &bucket,
+            res, err := metadata.client.GetObjectWithContext(objCtx, &s3.GetObjectInput{
+                Bucket: &metadata.bucket,
                 Key:    obj.Key,
             })
             if err != nil {
@@ -382,7 +519,7 @@ func (s *Source) pageChunker(
                 res.Body.Close()
             }
 
-            nErr, ok := errorCount.Load(prefix)
+            nErr, ok := state.errorCount.Load(prefix)
             if !ok {
                 nErr = 0
             }
@@ -391,7 +528,7 @@ func (s *Source) pageChunker(
                 return nil
             }
             nErr = nErr.(int) + 1
-            errorCount.Store(prefix, nErr)
+            state.errorCount.Store(prefix, nErr)
             // too many consecutive errors on this page
             if nErr.(int) > 3 {
                 ctx.Logger().V(2).Info("Too many consecutive errors, excluding prefix", "prefix", prefix)
@@ -413,9 +550,9 @@ func (s *Source) pageChunker(
                 SourceMetadata: &source_metadatapb.MetaData{
                     Data: &source_metadatapb.MetaData_S3{
                         S3: &source_metadatapb.S3{
-                            Bucket:    bucket,
+                            Bucket:    metadata.bucket,
                             File:      sanitizer.UTF8(*obj.Key),
-                            Link:      sanitizer.UTF8(makeS3Link(bucket, *client.Config.Region, *obj.Key)),
+                            Link:      sanitizer.UTF8(makeS3Link(metadata.bucket, *metadata.client.Config.Region, *obj.Key)),
                             Email:     sanitizer.UTF8(email),
                             Timestamp: sanitizer.UTF8(modified),
                         },
@@ -429,14 +566,19 @@ func (s *Source) pageChunker(
                 return nil
             }
 
-            atomic.AddUint64(objectCount, 1)
-            ctx.Logger().V(5).Info("S3 object scanned.", "object_count", objectCount)
-            nErr, ok = errorCount.Load(prefix)
+            atomic.AddUint64(state.objectCount, 1)
+            ctx.Logger().V(5).Info("S3 object scanned.", "object_count", state.objectCount)
+            nErr, ok = state.errorCount.Load(prefix)
             if !ok {
                 nErr = 0
             }
             if nErr.(int) > 0 {
-                errorCount.Store(prefix, 0)
+                state.errorCount.Store(prefix, 0)
+            }
+
+            // Update progress after successful processing.
+            if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
+                ctx.Logger().Error(err, "could not update progress for scanned object")
             }
 
             return nil
@@ -485,6 +627,9 @@ func (s *Source) validateBucketAccess(ctx context.Context, client *s3.S3, roleAr
 // for each role, passing in the default S3 client, the role ARN, and the list of
 // buckets to scan.
 //
+// The provided function parameter typically implements the core scanning logic
+// and must handle context cancellation appropriately.
+//
 // If no roles are configured, it will call the function with an empty role ARN.
 func (s *Source) visitRoles(
     ctx context.Context,

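The Checkpointer referenced throughout s3.go is defined in the commit's other two changed files, which are not shown on this page. A rough sketch of its surface, inferred only from the call sites above; the ResumeInfo name, the field layout, and the stub bodies are assumptions, not the commit's actual implementation:

```go
// Sketch inferred from the s3.go call sites; not the real checkpointer code.
package s3

import (
	"github.com/aws/aws-sdk-go/service/s3"

	"github.com/trufflesecurity/trufflehog/v3/pkg/context"
	"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
)

// ResumeInfo is what ResumePoint returns: the bucket being scanned and the
// last fully processed object key within it (used as ListObjectsV2 StartAfter).
type ResumeInfo struct {
	CurrentBucket string
	StartAfter    string
}

// Checkpointer persists scan position into Progress so an interrupted scan
// can pick up where it left off when resumption is enabled.
type Checkpointer struct {
	enabled  bool
	progress *sources.Progress
	// per-page completion bookkeeping elided
}

func NewCheckpointer(ctx context.Context, enabled bool, p *sources.Progress) *Checkpointer {
	return &Checkpointer{enabled: enabled, progress: p}
}

// Reset clears per-page state; pageChunker calls it once per page.
func (c *Checkpointer) Reset() {}

// UpdateObjectCompletion marks the object at objIdx as handled (scanned or
// skipped) and persists a resume point via Progress.EncodedResumeInfo once a
// contiguous prefix of the page is complete.
func (c *Checkpointer) UpdateObjectCompletion(ctx context.Context, objIdx int, bucket string, page []*s3.Object) error {
	return nil
}

// ResumePoint decodes the previously persisted checkpoint, if any.
func (c *Checkpointer) ResumePoint(ctx context.Context) (ResumeInfo, error) {
	return ResumeInfo{}, nil
}
```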