Commit 7b3d98d

[feat] - S3 metrics (#3577)
Squashed commits:

* add config option for s3 resumption
* updates
* initial progress tracking logic
* more testing
* revert s3 source file
* UpdateScanProgress tests
* adjust
* updates
* invert
* updates
* updates
* fix
* update
* adjust test
* fix
* remove progress tracking
* cleanup
* cleanup
* remove dupe
* add metrics to s3 scan
* make collector a singleton
* address comments
* fix
* remove
1 parent 33879e4 commit 7b3d98d

File tree:

* pkg/sources/s3/metrics.go
* pkg/sources/s3/s3.go
* pkg/sources/s3/s3_integration_test.go

3 files changed: +156 −5 lines

pkg/sources/s3/metrics.go

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
package s3

import (
	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promauto"

	"github.com/trufflesecurity/trufflehog/v3/pkg/common"
)

// metricsCollector defines the interface for recording S3 scan metrics.
type metricsCollector interface {
	// Object metrics.

	RecordObjectScanned(bucket string)
	RecordObjectSkipped(bucket, reason string)
	RecordObjectError(bucket string)

	// Role metrics.

	RecordRoleScanned(roleArn string)
	RecordBucketForRole(roleArn string)
}

type collector struct {
	objectsScanned *prometheus.CounterVec
	objectsSkipped *prometheus.CounterVec
	objectsErrors  *prometheus.CounterVec
	rolesScanned   *prometheus.GaugeVec
	bucketsPerRole *prometheus.GaugeVec
}

var metricsInstance metricsCollector

func init() {
	metricsInstance = &collector{
		objectsScanned: promauto.NewCounterVec(prometheus.CounterOpts{
			Namespace: common.MetricsNamespace,
			Subsystem: common.MetricsSubsystem,
			Name:      "objects_scanned_total",
			Help:      "Total number of S3 objects successfully scanned",
		}, []string{"bucket"}),

		objectsSkipped: promauto.NewCounterVec(prometheus.CounterOpts{
			Namespace: common.MetricsNamespace,
			Subsystem: common.MetricsSubsystem,
			Name:      "objects_skipped_total",
			Help:      "Total number of S3 objects skipped during scan",
		}, []string{"bucket", "reason"}),

		objectsErrors: promauto.NewCounterVec(prometheus.CounterOpts{
			Namespace: common.MetricsNamespace,
			Subsystem: common.MetricsSubsystem,
			Name:      "objects_errors_total",
			Help:      "Total number of errors encountered during S3 scan",
		}, []string{"bucket"}),

		rolesScanned: promauto.NewGaugeVec(prometheus.GaugeOpts{
			Namespace: common.MetricsNamespace,
			Subsystem: common.MetricsSubsystem,
			Name:      "roles_scanned",
			Help:      "Number of AWS roles being scanned",
		}, []string{"role_arn"}),

		bucketsPerRole: promauto.NewGaugeVec(prometheus.GaugeOpts{
			Namespace: common.MetricsNamespace,
			Subsystem: common.MetricsSubsystem,
			Name:      "buckets_per_role",
			Help:      "Number of buckets accessible per AWS role",
		}, []string{"role_arn"}),
	}
}

func (c *collector) RecordObjectScanned(bucket string) {
	c.objectsScanned.WithLabelValues(bucket).Inc()
}

func (c *collector) RecordObjectSkipped(bucket, reason string) {
	c.objectsSkipped.WithLabelValues(bucket, reason).Inc()
}

func (c *collector) RecordObjectError(bucket string) {
	c.objectsErrors.WithLabelValues(bucket).Inc()
}

const defaultRoleARN = "default"

func (c *collector) RecordRoleScanned(roleArn string) {
	if roleArn == "" {
		roleArn = defaultRoleARN
	}
	c.rolesScanned.WithLabelValues(roleArn).Set(1)
}

func (c *collector) RecordBucketForRole(roleArn string) {
	if roleArn == "" {
		roleArn = defaultRoleARN
	}
	c.bucketsPerRole.WithLabelValues(roleArn).Inc()
}
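Because the collectors are built with promauto inside init, they register with Prometheus's default registry as soon as this package is linked in. A minimal sketch of exposing them over HTTP (not part of this commit; the port and blank import are illustrative assumptions):

package main

import (
	"log"
	"net/http"

	"github.com/prometheus/client_golang/prometheus/promhttp"

	// Blank-importing the s3 source package runs its init, which
	// registers the collectors above with the default registry.
	_ "github.com/trufflesecurity/trufflehog/v3/pkg/sources/s3"
)

func main() {
	// promhttp.Handler serves every collector in the default registry,
	// including the objects_scanned_total / objects_skipped_total families.
	http.Handle("/metrics", promhttp.Handler())
	log.Fatal(http.ListenAndServe(":2112", nil)) // port is an arbitrary choice
}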

pkg/sources/s3/s3.go

Lines changed: 26 additions & 5 deletions
@@ -48,6 +48,7 @@ type Source struct {

 	checkpointer *Checkpointer
 	sources.Progress
+	metricsCollector metricsCollector

 	errorCount *sync.Map
 	jobPool *errgroup.Group
@@ -94,6 +95,7 @@ func (s *Source) Init(
 	s.conn = &conn

 	s.checkpointer = NewCheckpointer(ctx, conn.GetEnableResumption(), &s.Progress)
+	s.metricsCollector = metricsInstance

 	s.setMaxObjectSize(conn.GetMaxObjectSize())

@@ -106,11 +108,12 @@

 func (s *Source) Validate(ctx context.Context) []error {
 	var errs []error
-	visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) {
+	visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) error {
 		roleErrs := s.validateBucketAccess(c, defaultRegionClient, roleArn, buckets)
 		if len(roleErrs) > 0 {
 			errs = append(errs, roleErrs...)
 		}
+		return nil
 	}

 	if err := s.visitRoles(ctx, visitor); err != nil {
@@ -307,6 +310,7 @@ func (s *Source) scanBuckets(

 	bucketsToScanCount := len(bucketsToScan)
 	for bucketIdx := pos.index; bucketIdx < bucketsToScanCount; bucketIdx++ {
+		s.metricsCollector.RecordBucketForRole(role)
 		bucket := bucketsToScan[bucketIdx]
 		ctx := context.WithValue(ctx, "bucket", bucket)

@@ -385,8 +389,9 @@

 // Chunks emits chunks of bytes over a channel.
 func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, _ ...sources.ChunkingTarget) error {
-	visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) {
+	visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) error {
 		s.scanBuckets(c, defaultRegionClient, roleArn, buckets, chunksChan)
+		return nil
 	}

 	return s.visitRoles(ctx, visitor)
@@ -427,6 +432,7 @@ func (s *Source) pageChunker(

 	for objIdx, obj := range metadata.page.Contents {
 		if obj == nil {
+			s.metricsCollector.RecordObjectSkipped(metadata.bucket, "nil_object")
 			if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
 				ctx.Logger().Error(err, "could not update progress for nil object")
 			}
@@ -442,6 +448,7 @@
 		// Skip GLACIER and GLACIER_IR objects.
 		if obj.StorageClass == nil || strings.Contains(*obj.StorageClass, "GLACIER") {
 			ctx.Logger().V(5).Info("Skipping object in storage class", "storage_class", *obj.StorageClass)
+			s.metricsCollector.RecordObjectSkipped(metadata.bucket, "storage_class")
 			if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
 				ctx.Logger().Error(err, "could not update progress for glacier object")
 			}
@@ -451,6 +458,7 @@
 		// Ignore large files.
 		if *obj.Size > s.maxObjectSize {
 			ctx.Logger().V(5).Info("Skipping %d byte file (over maxObjectSize limit)")
+			s.metricsCollector.RecordObjectSkipped(metadata.bucket, "size_limit")
 			if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
 				ctx.Logger().Error(err, "could not update progress for large file")
 			}
@@ -460,6 +468,7 @@
 		// File empty file.
 		if *obj.Size == 0 {
 			ctx.Logger().V(5).Info("Skipping empty file")
+			s.metricsCollector.RecordObjectSkipped(metadata.bucket, "empty_file")
 			if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
 				ctx.Logger().Error(err, "could not update progress for empty file")
 			}
@@ -469,6 +478,7 @@
 		// Skip incompatible extensions.
 		if common.SkipFile(*obj.Key) {
 			ctx.Logger().V(5).Info("Skipping file with incompatible extension")
+			s.metricsCollector.RecordObjectSkipped(metadata.bucket, "incompatible_extension")
 			if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
 				ctx.Logger().Error(err, "could not update progress for incompatible file")
 			}
@@ -483,6 +493,7 @@

 		if strings.HasSuffix(*obj.Key, "/") {
 			ctx.Logger().V(5).Info("Skipping directory")
+			s.metricsCollector.RecordObjectSkipped(metadata.bucket, "directory")
 			return nil
 		}

@@ -508,8 +519,12 @@
 			Key: obj.Key,
 		})
 		if err != nil {
-			if !strings.Contains(err.Error(), "AccessDenied") {
+			if strings.Contains(err.Error(), "AccessDenied") {
+				ctx.Logger().Error(err, "could not get S3 object; access denied")
+				s.metricsCollector.RecordObjectSkipped(metadata.bucket, "access_denied")
+			} else {
 				ctx.Logger().Error(err, "could not get S3 object")
+				s.metricsCollector.RecordObjectError(metadata.bucket)
 			}
 			// According to the documentation for GetObjectWithContext,
 			// the response can be non-nil even if there was an error.
@@ -563,6 +578,7 @@

 		if err := handlers.HandleFile(ctx, res.Body, chunkSkel, sources.ChanReporter{Ch: chunksChan}); err != nil {
 			ctx.Logger().Error(err, "error handling file")
+			s.metricsCollector.RecordObjectError(metadata.bucket)
 			return nil
 		}

@@ -580,6 +596,7 @@
 		if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
 			ctx.Logger().Error(err, "could not update progress for scanned object")
 		}
+		s.metricsCollector.RecordObjectScanned(metadata.bucket)

 		return nil
 	})
@@ -633,14 +650,16 @@ func (s *Source) validateBucketAccess(ctx context.Context, client *s3.S3, roleAr
 // If no roles are configured, it will call the function with an empty role ARN.
 func (s *Source) visitRoles(
 	ctx context.Context,
-	f func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string),
+	f func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) error,
 ) error {
 	roles := s.conn.GetRoles()
 	if len(roles) == 0 {
 		roles = []string{""}
 	}

 	for _, role := range roles {
+		s.metricsCollector.RecordRoleScanned(role)
+
 		client, err := s.newClient(defaultAWSRegion, role)
 		if err != nil {
 			return fmt.Errorf("could not create s3 client: %w", err)
@@ -651,7 +670,9 @@
 			return fmt.Errorf("role %q could not list any s3 buckets for scanning: %w", role, err)
 		}

-		f(ctx, client, role, bucketsToScan)
+		if err := f(ctx, client, role, bucketsToScan); err != nil {
+			return err
+		}
 	}

 	return nil
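The signature change threaded through Validate, Chunks, and visitRoles is what lets a per-role failure abort the scan: visitRoles now returns the first error a visitor reports instead of discarding it. A stripped-down sketch of that control flow (hypothetical names, independent of the repo's AWS types):

package main

import (
	"errors"
	"fmt"
)

// visitFunc mirrors the new error-returning visitor signature.
type visitFunc func(roleArn string, buckets []string) error

// visitRoles invokes f once per role and propagates the first error,
// matching the control flow introduced in this commit.
func visitRoles(roles []string, f visitFunc) error {
	if len(roles) == 0 {
		roles = []string{""} // empty ARN stands in for "no role configured"
	}
	for _, role := range roles {
		buckets := []string{"bucket-a", "bucket-b"} // stand-in for a ListBuckets call
		if err := f(role, buckets); err != nil {
			return err // stop at the first failing role
		}
	}
	return nil
}

func main() {
	err := visitRoles([]string{"role-1", "role-2"}, func(roleArn string, buckets []string) error {
		if roleArn == "role-2" {
			return errors.New("access denied")
		}
		fmt.Println("scanned", roleArn, buckets)
		return nil
	})
	fmt.Println("err:", err) // prints: err: access denied
}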

pkg/sources/s3/s3_integration_test.go

Lines changed: 31 additions & 0 deletions
@@ -82,6 +82,37 @@ func TestSource_ChunksLarge(t *testing.T) {
 	assert.Equal(t, got, wantChunkCount)
 }

+func TestSourceChunksNoResumption(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
+	defer cancel()
+
+	s := Source{}
+	connection := &sourcespb.S3{
+		Credential: &sourcespb.S3_Unauthenticated{},
+		Buckets: []string{"trufflesec-ahrav-test-2"},
+	}
+	conn, err := anypb.New(connection)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = s.Init(ctx, "test name", 0, 0, false, conn, 1)
+	chunksCh := make(chan *sources.Chunk)
+	go func() {
+		defer close(chunksCh)
+		err = s.Chunks(ctx, chunksCh)
+		assert.Nil(t, err)
+	}()
+
+	wantChunkCount := 19787
+	got := 0
+
+	for range chunksCh {
+		got++
+	}
+	assert.Equal(t, got, wantChunkCount)
+}
+
 func TestSource_Validate(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
 	defer cancel()
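To exercise the new test locally, something like go test ./pkg/sources/s3 -run TestSourceChunksNoResumption should work, assuming network access to the public trufflesec-ahrav-test-2 bucket; the expected chunk count (19787) is pinned to that bucket's contents at commit time, so it will drift if the bucket changes.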
