Skip to content

Commit 3ee9c2f

Browse files
authored
[SCAN-165] Use Err Reporting (#3862)
* nit * update rate limit handler to use reporter * update process repos to use rate limit handler with unit reporter * update getReposByOrgOrUser to report err * update dedupreporter to report err * add errReporter interface to handle both types of reporters * convert to error reporter types * update handleRateLimit signature * use reporters for all rate limit handlers in github * get repo url before err check * use iterator * remove err log * nit * pluralize * remove err log * update tests * make linter happy
1 parent 69b6d01 commit 3ee9c2f

File tree

5 files changed

+83
-40
lines changed

5 files changed

+83
-40
lines changed

pkg/sources/github/connector_token.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@ type tokenConnector struct {
1717
apiClient *github.Client
1818
token string
1919
isGitHubEnterprise bool
20-
handleRateLimit func(context.Context, error) bool
20+
handleRateLimit func(context.Context, error, ...errorReporter) bool
2121
user string
2222
userMu sync.Mutex
2323
}
2424

2525
var _ connector = (*tokenConnector)(nil)
2626

27-
func newTokenConnector(apiEndpoint string, token string, handleRateLimit func(context.Context, error) bool) (*tokenConnector, error) {
27+
func newTokenConnector(apiEndpoint string, token string, handleRateLimit func(context.Context, error, ...errorReporter) bool) (*tokenConnector, error) {
2828
const httpTimeoutSeconds = 60
2929
httpClient := common.RetryableHTTPClientTimeout(int64(httpTimeoutSeconds))
3030
tokenSource := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: token})

pkg/sources/github/github.go

Lines changed: 70 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ func (s *Source) Enumerate(ctx context.Context, reporter sources.UnitReporter) e
381381
// See: https://github.com/trufflesecurity/trufflehog/pull/2379#discussion_r1487454788
382382
for _, name := range s.filteredRepoCache.Keys() {
383383
url, _ := s.filteredRepoCache.Get(name)
384-
url, err := s.ensureRepoInfoCache(ctx, url)
384+
url, err := s.ensureRepoInfoCache(ctx, url, &unitErrorReporter{reporter})
385385
if err != nil {
386386
if err := dedupeReporter.UnitErr(ctx, err); err != nil {
387387
return err
@@ -417,9 +417,10 @@ func (s *Source) Enumerate(ctx context.Context, reporter sources.UnitReporter) e
417417
for _, repo := range s.filteredRepoCache.Values() {
418418
ctx := context.WithValue(ctx, "repo", repo)
419419

420-
repo, err := s.ensureRepoInfoCache(ctx, repo)
420+
repo, err := s.ensureRepoInfoCache(ctx, repo, &unitErrorReporter{reporter})
421421
if err != nil {
422422
ctx.Logger().Error(err, "error caching repo info")
423+
_ = dedupeReporter.UnitErr(ctx, fmt.Errorf("error caching repo info: %w", err))
423424
}
424425
s.repos = append(s.repos, repo)
425426
}
@@ -434,7 +435,7 @@ func (s *Source) Enumerate(ctx context.Context, reporter sources.UnitReporter) e
434435
// provided repository URL. If not, it fetches and stores the metadata for the
435436
// repository. In some cases, the gist URL needs to be normalized, which is
436437
// returned by this function.
437-
func (s *Source) ensureRepoInfoCache(ctx context.Context, repo string) (string, error) {
438+
func (s *Source) ensureRepoInfoCache(ctx context.Context, repo string, reporter errorReporter) (string, error) {
438439
if _, ok := s.repoInfoCache.get(repo); ok {
439440
return repo, nil
440441
}
@@ -453,20 +454,23 @@ func (s *Source) ensureRepoInfoCache(ctx context.Context, repo string) (string,
453454
// Normalize the URL to the Gist's pull URL.
454455
// See https://github.com/trufflesecurity/trufflehog/pull/2625#issuecomment-2025507937
455456
repo = gist.GetGitPullURL()
456-
if s.handleRateLimit(ctx, err) {
457+
458+
if s.handleRateLimit(ctx, err, reporter) {
457459
continue
458460
}
461+
459462
if err != nil {
460463
return repo, fmt.Errorf("failed to fetch gist")
461464
}
465+
462466
s.cacheGistInfo(gist)
463467
break
464468
}
465469
} else {
466470
// Cache repository info.
467471
for {
468472
ghRepo, _, err := s.connector.APIClient().Repositories.Get(ctx, urlParts[1], urlParts[2])
469-
if s.handleRateLimit(ctx, err) {
473+
if s.handleRateLimit(ctx, err, reporter) {
470474
continue
471475
}
472476
if err != nil {
@@ -491,7 +495,7 @@ func (s *Source) enumerateBasicAuth(ctx context.Context, reporter sources.UnitRe
491495
// TODO: This modifies s.memberCache but it doesn't look like
492496
// we do anything with it.
493497
if userType == organization && s.conn.ScanUsers {
494-
if err := s.addMembersByOrg(ctx, org); err != nil {
498+
if err := s.addMembersByOrg(ctx, org, reporter); err != nil {
495499
orgCtx.Logger().Error(err, "Unable to add members by org")
496500
}
497501
}
@@ -526,7 +530,7 @@ func (s *Source) enumerateWithToken(ctx context.Context, isGithubEnterprise bool
526530
var err error
527531
for {
528532
ghUser, _, err = s.connector.APIClient().Users.Get(ctx, "")
529-
if s.handleRateLimit(ctx, err) {
533+
if s.handleRateLimitWithUnitReporter(ctx, reporter, err) {
530534
continue
531535
}
532536
if err != nil {
@@ -546,11 +550,11 @@ func (s *Source) enumerateWithToken(ctx context.Context, isGithubEnterprise bool
546550
}
547551

548552
if isGithubEnterprise {
549-
s.addAllVisibleOrgs(ctx)
553+
s.addAllVisibleOrgs(ctx, reporter)
550554
} else {
551555
// Scan for orgs is default with a token.
552556
// GitHub App enumerates the repos that were assigned to it in GitHub App settings.
553-
s.addOrgsByUser(ctx, ghUser.GetLogin())
557+
s.addOrgsByUser(ctx, ghUser.GetLogin(), reporter)
554558
}
555559
}
556560

@@ -564,7 +568,7 @@ func (s *Source) enumerateWithToken(ctx context.Context, isGithubEnterprise bool
564568
}
565569

566570
if userType == organization && s.conn.ScanUsers {
567-
if err := s.addMembersByOrg(ctx, org); err != nil {
571+
if err := s.addMembersByOrg(ctx, org, reporter); err != nil {
568572
orgCtx.Logger().Error(err, "Unable to add members for org")
569573
}
570574
}
@@ -588,7 +592,7 @@ func (s *Source) enumerateWithApp(ctx context.Context, installationClient *githu
588592

589593
// Check if we need to find user repos.
590594
if s.conn.ScanUsers {
591-
err := s.addMembersByApp(ctx, installationClient)
595+
err := s.addMembersByApp(ctx, installationClient, reporter)
592596
if err != nil {
593597
return err
594598
}
@@ -739,13 +743,37 @@ var (
739743
rateLimitResumeTime time.Time
740744
)
741745

742-
// handleRateLimit returns true if a rate limit was handled
746+
// errorReporter is an interface that captures just the error reporting functionality
747+
type errorReporter interface {
748+
Err(ctx context.Context, err error) error
749+
}
750+
751+
// wrapper to adapt UnitReporter to errorReporter
752+
type unitErrorReporter struct {
753+
reporter sources.UnitReporter
754+
}
755+
756+
func (u unitErrorReporter) Err(ctx context.Context, err error) error {
757+
return u.reporter.UnitErr(ctx, err)
758+
}
759+
760+
// wrapper to adapt ChunkReporter to errorReporter
761+
type chunkErrorReporter struct {
762+
reporter sources.ChunkReporter
763+
}
764+
765+
func (c chunkErrorReporter) Err(ctx context.Context, err error) error {
766+
return c.reporter.ChunkErr(ctx, err)
767+
}
768+
769+
// handleRateLimit handles GitHub API rate limiting with an optional error reporter.
770+
// Returns true if a rate limit was handled.
743771
//
744772
// Unauthenticated users have a rate limit of 60 requests per hour.
745773
// Authenticated users have a rate limit of 5,000 requests per hour,
746774
// however, certain actions are subject to a stricter "secondary" limit.
747775
// https://docs.github.com/en/rest/overview/rate-limits-for-the-rest-api
748-
func (s *Source) handleRateLimit(ctx context.Context, errIn error) bool {
776+
func (s *Source) handleRateLimit(ctx context.Context, errIn error, reporters ...errorReporter) bool {
749777
if errIn == nil {
750778
return false
751779
}
@@ -757,7 +785,6 @@ func (s *Source) handleRateLimit(ctx context.Context, errIn error) bool {
757785
var retryAfter time.Duration
758786
if resumeTime.IsZero() || time.Now().After(resumeTime) {
759787
rateLimitMu.Lock()
760-
761788
var (
762789
now = time.Now()
763790

@@ -785,6 +812,10 @@ func (s *Source) handleRateLimit(ctx context.Context, errIn error) bool {
785812
retryAfter = retryAfter + jitter
786813
rateLimitResumeTime = now.Add(retryAfter)
787814
ctx.Logger().Info(fmt.Sprintf("exceeded %s rate limit", limitType), "retry_after", retryAfter.String(), "resume_time", rateLimitResumeTime.Format(time.RFC3339))
815+
// Only report the error if a reporter was provided
816+
for _, reporter := range reporters {
817+
_ = reporter.Err(ctx, fmt.Errorf("exceeded %s rate limit", limitType))
818+
}
788819
} else {
789820
retryAfter = (5 * time.Minute) + jitter
790821
rateLimitResumeTime = now.Add(retryAfter)
@@ -803,6 +834,16 @@ func (s *Source) handleRateLimit(ctx context.Context, errIn error) bool {
803834
return true
804835
}
805836

837+
// handleRateLimitWithUnitReporter is a wrapper around handleRateLimit that includes unit reporting
838+
func (s *Source) handleRateLimitWithUnitReporter(ctx context.Context, reporter sources.UnitReporter, errIn error) bool {
839+
return s.handleRateLimit(ctx, errIn, &unitErrorReporter{reporter: reporter})
840+
}
841+
842+
// handleRateLimitWithChunkReporter is a wrapper around handleRateLimit that includes chunk reporting
843+
func (s *Source) handleRateLimitWithChunkReporter(ctx context.Context, reporter sources.ChunkReporter, errIn error) bool {
844+
return s.handleRateLimit(ctx, errIn, &chunkErrorReporter{reporter: reporter})
845+
}
846+
806847
func (s *Source) addReposForMembers(ctx context.Context, reporter sources.UnitReporter) {
807848
ctx.Logger().Info("Fetching repos from members", "members", len(s.memberCache))
808849
for member := range s.memberCache {
@@ -823,7 +864,7 @@ func (s *Source) addUserGistsToCache(ctx context.Context, user string, reporter
823864

824865
for {
825866
gists, res, err := s.connector.APIClient().Gists.List(ctx, user, gistOpts)
826-
if s.handleRateLimit(ctx, err) {
867+
if s.handleRateLimitWithUnitReporter(ctx, reporter, err) {
827868
continue
828869
}
829870
if err != nil {
@@ -847,7 +888,7 @@ func (s *Source) addUserGistsToCache(ctx context.Context, user string, reporter
847888
return nil
848889
}
849890

850-
func (s *Source) addMembersByApp(ctx context.Context, installationClient *github.Client) error {
891+
func (s *Source) addMembersByApp(ctx context.Context, installationClient *github.Client, reporter sources.UnitReporter) error {
851892
opts := &github.ListOptions{
852893
PerPage: membersAppPagination,
853894
}
@@ -862,15 +903,15 @@ func (s *Source) addMembersByApp(ctx context.Context, installationClient *github
862903
if org.Account.GetType() != "Organization" {
863904
continue
864905
}
865-
if err := s.addMembersByOrg(ctx, *org.Account.Login); err != nil {
906+
if err := s.addMembersByOrg(ctx, *org.Account.Login, reporter); err != nil {
866907
return err
867908
}
868909
}
869910

870911
return nil
871912
}
872913

873-
func (s *Source) addAllVisibleOrgs(ctx context.Context) {
914+
func (s *Source) addAllVisibleOrgs(ctx context.Context, reporter sources.UnitReporter) {
874915
ctx.Logger().V(2).Info("enumerating all visible organizations on GHE")
875916
// Enumeration on this endpoint does not use pages it uses a since ID.
876917
// The endpoint will return organizations with an ID greater than the given since ID.
@@ -883,7 +924,7 @@ func (s *Source) addAllVisibleOrgs(ctx context.Context) {
883924
}
884925
for {
885926
orgs, _, err := s.connector.APIClient().Organizations.ListAll(ctx, orgOpts)
886-
if s.handleRateLimit(ctx, err) {
927+
if s.handleRateLimitWithUnitReporter(ctx, reporter, err) {
887928
continue
888929
}
889930
if err != nil {
@@ -915,14 +956,14 @@ func (s *Source) addAllVisibleOrgs(ctx context.Context) {
915956
}
916957
}
917958

918-
func (s *Source) addOrgsByUser(ctx context.Context, user string) {
959+
func (s *Source) addOrgsByUser(ctx context.Context, user string, reporter sources.UnitReporter) {
919960
orgOpts := &github.ListOptions{
920961
PerPage: defaultPagination,
921962
}
922963
logger := ctx.Logger().WithValues("user", user)
923964
for {
924965
orgs, resp, err := s.connector.APIClient().Organizations.List(ctx, "", orgOpts)
925-
if s.handleRateLimit(ctx, err) {
966+
if s.handleRateLimitWithUnitReporter(ctx, reporter, err) {
926967
continue
927968
}
928969
if err != nil {
@@ -944,7 +985,7 @@ func (s *Source) addOrgsByUser(ctx context.Context, user string) {
944985
}
945986
}
946987

947-
func (s *Source) addMembersByOrg(ctx context.Context, org string) error {
988+
func (s *Source) addMembersByOrg(ctx context.Context, org string, reporter sources.UnitReporter) error {
948989
opts := &github.ListMembersOptions{
949990
PublicOnly: false,
950991
ListOptions: github.ListOptions{
@@ -955,7 +996,7 @@ func (s *Source) addMembersByOrg(ctx context.Context, org string) error {
955996
logger := ctx.Logger().WithValues("org", org)
956997
for {
957998
members, res, err := s.connector.APIClient().Organizations.ListMembers(ctx, org, opts)
958-
if s.handleRateLimit(ctx, err) {
999+
if s.handleRateLimitWithUnitReporter(ctx, reporter, err) {
9591000
continue
9601001
}
9611002
if err != nil {
@@ -1087,7 +1128,7 @@ func (s *Source) processGistComments(ctx context.Context, gistURL string, urlPar
10871128
}
10881129
for {
10891130
comments, _, err := s.connector.APIClient().Gists.ListComments(ctx, gistID, options)
1090-
if s.handleRateLimit(ctx, err) {
1131+
if s.handleRateLimitWithChunkReporter(ctx, reporter, err) {
10911132
continue
10921133
}
10931134
if err != nil {
@@ -1187,7 +1228,6 @@ func (s *Source) processRepoComments(ctx context.Context, repoInfo repoInfo, rep
11871228
}
11881229

11891230
return nil
1190-
11911231
}
11921232

11931233
func (s *Source) processIssues(ctx context.Context, repoInfo repoInfo, reporter sources.ChunkReporter) error {
@@ -1203,7 +1243,7 @@ func (s *Source) processIssues(ctx context.Context, repoInfo repoInfo, reporter
12031243

12041244
for {
12051245
issues, _, err := s.connector.APIClient().Issues.ListByRepo(ctx, repoInfo.owner, repoInfo.name, bodyTextsOpts)
1206-
if s.handleRateLimit(ctx, err) {
1246+
if s.handleRateLimitWithChunkReporter(ctx, reporter, err) {
12071247
continue
12081248
}
12091249

@@ -1272,7 +1312,7 @@ func (s *Source) processIssueComments(ctx context.Context, repoInfo repoInfo, re
12721312

12731313
for {
12741314
issueComments, _, err := s.connector.APIClient().Issues.ListComments(ctx, repoInfo.owner, repoInfo.name, allComments, issueOpts)
1275-
if s.handleRateLimit(ctx, err) {
1315+
if s.handleRateLimitWithChunkReporter(ctx, reporter, err) {
12761316
continue
12771317
}
12781318
if err != nil {
@@ -1340,7 +1380,7 @@ func (s *Source) processPRs(ctx context.Context, repoInfo repoInfo, reporter sou
13401380

13411381
for {
13421382
prs, _, err := s.connector.APIClient().PullRequests.List(ctx, repoInfo.owner, repoInfo.name, prOpts)
1343-
if s.handleRateLimit(ctx, err) {
1383+
if s.handleRateLimitWithChunkReporter(ctx, reporter, err) {
13441384
continue
13451385
}
13461386
if err != nil {
@@ -1372,7 +1412,7 @@ func (s *Source) processPRComments(ctx context.Context, repoInfo repoInfo, repor
13721412

13731413
for {
13741414
prComments, _, err := s.connector.APIClient().PullRequests.ListComments(ctx, repoInfo.owner, repoInfo.name, allComments, prOpts)
1375-
if s.handleRateLimit(ctx, err) {
1415+
if s.handleRateLimitWithChunkReporter(ctx, reporter, err) {
13761416
continue
13771417
}
13781418
if err != nil {
@@ -1528,7 +1568,7 @@ func (s *Source) ChunkUnit(ctx context.Context, unit sources.SourceUnit, reporte
15281568
ctx = context.WithValue(ctx, "repo", repoURL)
15291569
// ChunkUnit is not guaranteed to be called from Enumerate, so we must
15301570
// check and fetch the repoInfoCache for this repo.
1531-
repoURL, err := s.ensureRepoInfoCache(ctx, repoURL)
1571+
repoURL, err := s.ensureRepoInfoCache(ctx, repoURL, &chunkErrorReporter{reporter: reporter})
15321572
if err != nil {
15331573
return err
15341574
}

pkg/sources/github/github_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ func TestAddMembersByOrg(t *testing.T) {
197197
})
198198

199199
s := initTestSource(&sourcespb.GitHub{Credential: &sourcespb.GitHub_Unauthenticated{}})
200-
err := s.addMembersByOrg(context.Background(), "org1")
200+
err := s.addMembersByOrg(context.Background(), "org1", noopReporter())
201201
assert.Nil(t, err)
202202
assert.Equal(t, 2, len(s.memberCache))
203203
_, ok := s.memberCache["testman1"]
@@ -221,7 +221,7 @@ func TestAddMembersByOrg_AuthFailure(t *testing.T) {
221221
}})
222222

223223
s := initTestSource(&sourcespb.GitHub{Credential: &sourcespb.GitHub_Unauthenticated{}})
224-
err := s.addMembersByOrg(context.Background(), "org1")
224+
err := s.addMembersByOrg(context.Background(), "org1", noopReporter())
225225
assert.True(t, strings.HasPrefix(err.Error(), "could not list organization"))
226226
assert.False(t, gock.HasUnmatchedRequest())
227227
assert.True(t, gock.IsDone())
@@ -236,7 +236,7 @@ func TestAddMembersByOrg_NoMembers(t *testing.T) {
236236
JSON([]map[string]string{})
237237

238238
s := initTestSource(&sourcespb.GitHub{Credential: &sourcespb.GitHub_Unauthenticated{}})
239-
err := s.addMembersByOrg(context.Background(), "org1")
239+
err := s.addMembersByOrg(context.Background(), "org1", noopReporter())
240240

241241
assert.Equal(t, fmt.Sprintf("organization (%q) had 0 members: account may not have access to list organization members", "org1"), err.Error())
242242
assert.False(t, gock.HasUnmatchedRequest())
@@ -276,7 +276,7 @@ func TestAddMembersByApp(t *testing.T) {
276276
AppId: "4141",
277277
},
278278
}})
279-
err := s.addMembersByApp(context.Background(), s.connector.(*appConnector).InstallationClient())
279+
err := s.addMembersByApp(context.Background(), s.connector.(*appConnector).InstallationClient(), noopReporter())
280280
assert.Nil(t, err)
281281
assert.Equal(t, 3, len(s.memberCache))
282282
_, ok := s.memberCache["ssm1"]
@@ -327,7 +327,7 @@ func TestAddOrgsByUser(t *testing.T) {
327327
})
328328

329329
s := initTestSource(&sourcespb.GitHub{Credential: &sourcespb.GitHub_Unauthenticated{}})
330-
s.addOrgsByUser(context.Background(), "super-secret-user")
330+
s.addOrgsByUser(context.Background(), "super-secret-user", noopReporter())
331331
assert.Equal(t, 1, s.orgsCache.Count())
332332
ok := s.orgsCache.Exists("sso2")
333333
assert.True(t, ok)

0 commit comments

Comments
 (0)