@@ -381,7 +381,7 @@ func (s *Source) Enumerate(ctx context.Context, reporter sources.UnitReporter) e
381
381
// See: https://github.com/trufflesecurity/trufflehog/pull/2379#discussion_r1487454788
382
382
for _ , name := range s .filteredRepoCache .Keys () {
383
383
url , _ := s .filteredRepoCache .Get (name )
384
- url , err := s .ensureRepoInfoCache (ctx , url )
384
+ url , err := s .ensureRepoInfoCache (ctx , url , & unitErrorReporter { reporter } )
385
385
if err != nil {
386
386
if err := dedupeReporter .UnitErr (ctx , err ); err != nil {
387
387
return err
@@ -417,9 +417,10 @@ func (s *Source) Enumerate(ctx context.Context, reporter sources.UnitReporter) e
417
417
for _ , repo := range s .filteredRepoCache .Values () {
418
418
ctx := context .WithValue (ctx , "repo" , repo )
419
419
420
- repo , err := s .ensureRepoInfoCache (ctx , repo )
420
+ repo , err := s .ensureRepoInfoCache (ctx , repo , & unitErrorReporter { reporter } )
421
421
if err != nil {
422
422
ctx .Logger ().Error (err , "error caching repo info" )
423
+ _ = dedupeReporter .UnitErr (ctx , fmt .Errorf ("error caching repo info: %w" , err ))
423
424
}
424
425
s .repos = append (s .repos , repo )
425
426
}
@@ -434,7 +435,7 @@ func (s *Source) Enumerate(ctx context.Context, reporter sources.UnitReporter) e
434
435
// provided repository URL. If not, it fetches and stores the metadata for the
435
436
// repository. In some cases, the gist URL needs to be normalized, which is
436
437
// returned by this function.
437
- func (s * Source ) ensureRepoInfoCache (ctx context.Context , repo string ) (string , error ) {
438
+ func (s * Source ) ensureRepoInfoCache (ctx context.Context , repo string , reporter errorReporter ) (string , error ) {
438
439
if _ , ok := s .repoInfoCache .get (repo ); ok {
439
440
return repo , nil
440
441
}
@@ -453,20 +454,23 @@ func (s *Source) ensureRepoInfoCache(ctx context.Context, repo string) (string,
453
454
// Normalize the URL to the Gist's pull URL.
454
455
// See https://github.com/trufflesecurity/trufflehog/pull/2625#issuecomment-2025507937
455
456
repo = gist .GetGitPullURL ()
456
- if s .handleRateLimit (ctx , err ) {
457
+
458
+ if s .handleRateLimit (ctx , err , reporter ) {
457
459
continue
458
460
}
461
+
459
462
if err != nil {
460
463
return repo , fmt .Errorf ("failed to fetch gist" )
461
464
}
465
+
462
466
s .cacheGistInfo (gist )
463
467
break
464
468
}
465
469
} else {
466
470
// Cache repository info.
467
471
for {
468
472
ghRepo , _ , err := s .connector .APIClient ().Repositories .Get (ctx , urlParts [1 ], urlParts [2 ])
469
- if s .handleRateLimit (ctx , err ) {
473
+ if s .handleRateLimit (ctx , err , reporter ) {
470
474
continue
471
475
}
472
476
if err != nil {
@@ -491,7 +495,7 @@ func (s *Source) enumerateBasicAuth(ctx context.Context, reporter sources.UnitRe
491
495
// TODO: This modifies s.memberCache but it doesn't look like
492
496
// we do anything with it.
493
497
if userType == organization && s .conn .ScanUsers {
494
- if err := s .addMembersByOrg (ctx , org ); err != nil {
498
+ if err := s .addMembersByOrg (ctx , org , reporter ); err != nil {
495
499
orgCtx .Logger ().Error (err , "Unable to add members by org" )
496
500
}
497
501
}
@@ -526,7 +530,7 @@ func (s *Source) enumerateWithToken(ctx context.Context, isGithubEnterprise bool
526
530
var err error
527
531
for {
528
532
ghUser , _ , err = s .connector .APIClient ().Users .Get (ctx , "" )
529
- if s .handleRateLimit (ctx , err ) {
533
+ if s .handleRateLimitWithUnitReporter (ctx , reporter , err ) {
530
534
continue
531
535
}
532
536
if err != nil {
@@ -546,11 +550,11 @@ func (s *Source) enumerateWithToken(ctx context.Context, isGithubEnterprise bool
546
550
}
547
551
548
552
if isGithubEnterprise {
549
- s .addAllVisibleOrgs (ctx )
553
+ s .addAllVisibleOrgs (ctx , reporter )
550
554
} else {
551
555
// Scan for orgs is default with a token.
552
556
// GitHub App enumerates the repos that were assigned to it in GitHub App settings.
553
- s .addOrgsByUser (ctx , ghUser .GetLogin ())
557
+ s .addOrgsByUser (ctx , ghUser .GetLogin (), reporter )
554
558
}
555
559
}
556
560
@@ -564,7 +568,7 @@ func (s *Source) enumerateWithToken(ctx context.Context, isGithubEnterprise bool
564
568
}
565
569
566
570
if userType == organization && s .conn .ScanUsers {
567
- if err := s .addMembersByOrg (ctx , org ); err != nil {
571
+ if err := s .addMembersByOrg (ctx , org , reporter ); err != nil {
568
572
orgCtx .Logger ().Error (err , "Unable to add members for org" )
569
573
}
570
574
}
@@ -588,7 +592,7 @@ func (s *Source) enumerateWithApp(ctx context.Context, installationClient *githu
588
592
589
593
// Check if we need to find user repos.
590
594
if s .conn .ScanUsers {
591
- err := s .addMembersByApp (ctx , installationClient )
595
+ err := s .addMembersByApp (ctx , installationClient , reporter )
592
596
if err != nil {
593
597
return err
594
598
}
@@ -739,13 +743,37 @@ var (
739
743
rateLimitResumeTime time.Time
740
744
)
741
745
742
- // handleRateLimit returns true if a rate limit was handled
746
+ // errorReporter is an interface that captures just the error reporting functionality
747
+ type errorReporter interface {
748
+ Err (ctx context.Context , err error ) error
749
+ }
750
+
751
+ // wrapper to adapt UnitReporter to errorReporter
752
+ type unitErrorReporter struct {
753
+ reporter sources.UnitReporter
754
+ }
755
+
756
+ func (u unitErrorReporter ) Err (ctx context.Context , err error ) error {
757
+ return u .reporter .UnitErr (ctx , err )
758
+ }
759
+
760
+ // wrapper to adapt ChunkReporter to errorReporter
761
+ type chunkErrorReporter struct {
762
+ reporter sources.ChunkReporter
763
+ }
764
+
765
+ func (c chunkErrorReporter ) Err (ctx context.Context , err error ) error {
766
+ return c .reporter .ChunkErr (ctx , err )
767
+ }
768
+
769
+ // handleRateLimit handles GitHub API rate limiting with an optional error reporter.
770
+ // Returns true if a rate limit was handled.
743
771
//
744
772
// Unauthenticated users have a rate limit of 60 requests per hour.
745
773
// Authenticated users have a rate limit of 5,000 requests per hour,
746
774
// however, certain actions are subject to a stricter "secondary" limit.
747
775
// https://docs.github.com/en/rest/overview/rate-limits-for-the-rest-api
748
- func (s * Source ) handleRateLimit (ctx context.Context , errIn error ) bool {
776
+ func (s * Source ) handleRateLimit (ctx context.Context , errIn error , reporters ... errorReporter ) bool {
749
777
if errIn == nil {
750
778
return false
751
779
}
@@ -757,7 +785,6 @@ func (s *Source) handleRateLimit(ctx context.Context, errIn error) bool {
757
785
var retryAfter time.Duration
758
786
if resumeTime .IsZero () || time .Now ().After (resumeTime ) {
759
787
rateLimitMu .Lock ()
760
-
761
788
var (
762
789
now = time .Now ()
763
790
@@ -785,6 +812,10 @@ func (s *Source) handleRateLimit(ctx context.Context, errIn error) bool {
785
812
retryAfter = retryAfter + jitter
786
813
rateLimitResumeTime = now .Add (retryAfter )
787
814
ctx .Logger ().Info (fmt .Sprintf ("exceeded %s rate limit" , limitType ), "retry_after" , retryAfter .String (), "resume_time" , rateLimitResumeTime .Format (time .RFC3339 ))
815
+ // Only report the error if a reporter was provided
816
+ for _ , reporter := range reporters {
817
+ _ = reporter .Err (ctx , fmt .Errorf ("exceeded %s rate limit" , limitType ))
818
+ }
788
819
} else {
789
820
retryAfter = (5 * time .Minute ) + jitter
790
821
rateLimitResumeTime = now .Add (retryAfter )
@@ -803,6 +834,16 @@ func (s *Source) handleRateLimit(ctx context.Context, errIn error) bool {
803
834
return true
804
835
}
805
836
837
+ // handleRateLimitWithUnitReporter is a wrapper around handleRateLimit that includes unit reporting
838
+ func (s * Source ) handleRateLimitWithUnitReporter (ctx context.Context , reporter sources.UnitReporter , errIn error ) bool {
839
+ return s .handleRateLimit (ctx , errIn , & unitErrorReporter {reporter : reporter })
840
+ }
841
+
842
+ // handleRateLimitWithChunkReporter is a wrapper around handleRateLimit that includes chunk reporting
843
+ func (s * Source ) handleRateLimitWithChunkReporter (ctx context.Context , reporter sources.ChunkReporter , errIn error ) bool {
844
+ return s .handleRateLimit (ctx , errIn , & chunkErrorReporter {reporter : reporter })
845
+ }
846
+
806
847
func (s * Source ) addReposForMembers (ctx context.Context , reporter sources.UnitReporter ) {
807
848
ctx .Logger ().Info ("Fetching repos from members" , "members" , len (s .memberCache ))
808
849
for member := range s .memberCache {
@@ -823,7 +864,7 @@ func (s *Source) addUserGistsToCache(ctx context.Context, user string, reporter
823
864
824
865
for {
825
866
gists , res , err := s .connector .APIClient ().Gists .List (ctx , user , gistOpts )
826
- if s .handleRateLimit (ctx , err ) {
867
+ if s .handleRateLimitWithUnitReporter (ctx , reporter , err ) {
827
868
continue
828
869
}
829
870
if err != nil {
@@ -847,7 +888,7 @@ func (s *Source) addUserGistsToCache(ctx context.Context, user string, reporter
847
888
return nil
848
889
}
849
890
850
- func (s * Source ) addMembersByApp (ctx context.Context , installationClient * github.Client ) error {
891
+ func (s * Source ) addMembersByApp (ctx context.Context , installationClient * github.Client , reporter sources. UnitReporter ) error {
851
892
opts := & github.ListOptions {
852
893
PerPage : membersAppPagination ,
853
894
}
@@ -862,15 +903,15 @@ func (s *Source) addMembersByApp(ctx context.Context, installationClient *github
862
903
if org .Account .GetType () != "Organization" {
863
904
continue
864
905
}
865
- if err := s .addMembersByOrg (ctx , * org .Account .Login ); err != nil {
906
+ if err := s .addMembersByOrg (ctx , * org .Account .Login , reporter ); err != nil {
866
907
return err
867
908
}
868
909
}
869
910
870
911
return nil
871
912
}
872
913
873
- func (s * Source ) addAllVisibleOrgs (ctx context.Context ) {
914
+ func (s * Source ) addAllVisibleOrgs (ctx context.Context , reporter sources. UnitReporter ) {
874
915
ctx .Logger ().V (2 ).Info ("enumerating all visible organizations on GHE" )
875
916
// Enumeration on this endpoint does not use pages it uses a since ID.
876
917
// The endpoint will return organizations with an ID greater than the given since ID.
@@ -883,7 +924,7 @@ func (s *Source) addAllVisibleOrgs(ctx context.Context) {
883
924
}
884
925
for {
885
926
orgs , _ , err := s .connector .APIClient ().Organizations .ListAll (ctx , orgOpts )
886
- if s .handleRateLimit (ctx , err ) {
927
+ if s .handleRateLimitWithUnitReporter (ctx , reporter , err ) {
887
928
continue
888
929
}
889
930
if err != nil {
@@ -915,14 +956,14 @@ func (s *Source) addAllVisibleOrgs(ctx context.Context) {
915
956
}
916
957
}
917
958
918
- func (s * Source ) addOrgsByUser (ctx context.Context , user string ) {
959
+ func (s * Source ) addOrgsByUser (ctx context.Context , user string , reporter sources. UnitReporter ) {
919
960
orgOpts := & github.ListOptions {
920
961
PerPage : defaultPagination ,
921
962
}
922
963
logger := ctx .Logger ().WithValues ("user" , user )
923
964
for {
924
965
orgs , resp , err := s .connector .APIClient ().Organizations .List (ctx , "" , orgOpts )
925
- if s .handleRateLimit (ctx , err ) {
966
+ if s .handleRateLimitWithUnitReporter (ctx , reporter , err ) {
926
967
continue
927
968
}
928
969
if err != nil {
@@ -944,7 +985,7 @@ func (s *Source) addOrgsByUser(ctx context.Context, user string) {
944
985
}
945
986
}
946
987
947
- func (s * Source ) addMembersByOrg (ctx context.Context , org string ) error {
988
+ func (s * Source ) addMembersByOrg (ctx context.Context , org string , reporter sources. UnitReporter ) error {
948
989
opts := & github.ListMembersOptions {
949
990
PublicOnly : false ,
950
991
ListOptions : github.ListOptions {
@@ -955,7 +996,7 @@ func (s *Source) addMembersByOrg(ctx context.Context, org string) error {
955
996
logger := ctx .Logger ().WithValues ("org" , org )
956
997
for {
957
998
members , res , err := s .connector .APIClient ().Organizations .ListMembers (ctx , org , opts )
958
- if s .handleRateLimit (ctx , err ) {
999
+ if s .handleRateLimitWithUnitReporter (ctx , reporter , err ) {
959
1000
continue
960
1001
}
961
1002
if err != nil {
@@ -1087,7 +1128,7 @@ func (s *Source) processGistComments(ctx context.Context, gistURL string, urlPar
1087
1128
}
1088
1129
for {
1089
1130
comments , _ , err := s .connector .APIClient ().Gists .ListComments (ctx , gistID , options )
1090
- if s .handleRateLimit (ctx , err ) {
1131
+ if s .handleRateLimitWithChunkReporter (ctx , reporter , err ) {
1091
1132
continue
1092
1133
}
1093
1134
if err != nil {
@@ -1187,7 +1228,6 @@ func (s *Source) processRepoComments(ctx context.Context, repoInfo repoInfo, rep
1187
1228
}
1188
1229
1189
1230
return nil
1190
-
1191
1231
}
1192
1232
1193
1233
func (s * Source ) processIssues (ctx context.Context , repoInfo repoInfo , reporter sources.ChunkReporter ) error {
@@ -1203,7 +1243,7 @@ func (s *Source) processIssues(ctx context.Context, repoInfo repoInfo, reporter
1203
1243
1204
1244
for {
1205
1245
issues , _ , err := s .connector .APIClient ().Issues .ListByRepo (ctx , repoInfo .owner , repoInfo .name , bodyTextsOpts )
1206
- if s .handleRateLimit (ctx , err ) {
1246
+ if s .handleRateLimitWithChunkReporter (ctx , reporter , err ) {
1207
1247
continue
1208
1248
}
1209
1249
@@ -1272,7 +1312,7 @@ func (s *Source) processIssueComments(ctx context.Context, repoInfo repoInfo, re
1272
1312
1273
1313
for {
1274
1314
issueComments , _ , err := s .connector .APIClient ().Issues .ListComments (ctx , repoInfo .owner , repoInfo .name , allComments , issueOpts )
1275
- if s .handleRateLimit (ctx , err ) {
1315
+ if s .handleRateLimitWithChunkReporter (ctx , reporter , err ) {
1276
1316
continue
1277
1317
}
1278
1318
if err != nil {
@@ -1340,7 +1380,7 @@ func (s *Source) processPRs(ctx context.Context, repoInfo repoInfo, reporter sou
1340
1380
1341
1381
for {
1342
1382
prs , _ , err := s .connector .APIClient ().PullRequests .List (ctx , repoInfo .owner , repoInfo .name , prOpts )
1343
- if s .handleRateLimit (ctx , err ) {
1383
+ if s .handleRateLimitWithChunkReporter (ctx , reporter , err ) {
1344
1384
continue
1345
1385
}
1346
1386
if err != nil {
@@ -1372,7 +1412,7 @@ func (s *Source) processPRComments(ctx context.Context, repoInfo repoInfo, repor
1372
1412
1373
1413
for {
1374
1414
prComments , _ , err := s .connector .APIClient ().PullRequests .ListComments (ctx , repoInfo .owner , repoInfo .name , allComments , prOpts )
1375
- if s .handleRateLimit (ctx , err ) {
1415
+ if s .handleRateLimitWithChunkReporter (ctx , reporter , err ) {
1376
1416
continue
1377
1417
}
1378
1418
if err != nil {
@@ -1528,7 +1568,7 @@ func (s *Source) ChunkUnit(ctx context.Context, unit sources.SourceUnit, reporte
1528
1568
ctx = context .WithValue (ctx , "repo" , repoURL )
1529
1569
// ChunkUnit is not guaranteed to be called from Enumerate, so we must
1530
1570
// check and fetch the repoInfoCache for this repo.
1531
- repoURL , err := s .ensureRepoInfoCache (ctx , repoURL )
1571
+ repoURL , err := s .ensureRepoInfoCache (ctx , repoURL , & chunkErrorReporter { reporter : reporter } )
1532
1572
if err != nil {
1533
1573
return err
1534
1574
}
0 commit comments