Skip to content

Commit c4e2ce8

Browse files
committed
chore: re-use same test logic to avoid duplicated code
1 parent 28a8916 commit c4e2ce8

File tree

2 files changed

+57
-90
lines changed

2 files changed

+57
-90
lines changed

internal/analysis/analysis.go

Lines changed: 55 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -485,24 +485,63 @@ func (a *analysisOrchestrator) host(isHidden bool) string {
485485
return fmt.Sprintf("%s/%s", apiUrl, path)
486486
}
487487

488-
func (a *analysisOrchestrator) RunTest(ctx context.Context, orgId string, b bundle.Bundle, target scan.Target, reportingConfig AnalysisConfig) (*sarif.SarifResponse, error) {
488+
func (a *analysisOrchestrator) createTestAndGetResults(ctx context.Context, orgId string, body *testApi.CreateTestApplicationVndAPIPlusJSONRequestBody, progressString string) (*sarif.SarifResponse, error) {
489489
tracker := a.trackerFactory.GenerateTracker()
490-
tracker.Begin("Snyk Code analysis for "+target.GetPath(), "Retrieving results...")
490+
tracker.Begin(progressString, "Retrieving results...")
491491

492-
orgUuid := uuid.MustParse(orgId)
493-
host := a.host(true)
492+
innerFunction := func() (*sarif.SarifResponse, error) {
493+
params := testApi.CreateTestParams{Version: testApi.ApiVersion}
494+
orgUuid := uuid.MustParse(orgId)
495+
host := a.host(true)
496+
497+
client, err := testApi.NewClient(host, testApi.WithHTTPClient(a.httpClient))
498+
if err != nil {
499+
return nil, err
500+
}
501+
502+
// create test
503+
resp, err := client.CreateTestWithApplicationVndAPIPlusJSONBody(ctx, orgUuid, &params, *body)
504+
if err != nil {
505+
return nil, err
506+
}
507+
508+
parsedResponse, err := testApi.ParseCreateTestResponse(resp)
509+
defer func() {
510+
closeErr := resp.Body.Close()
511+
if closeErr != nil {
512+
a.logger.Err(closeErr).Msg("failed to close response body")
513+
}
514+
}()
515+
if err != nil {
516+
a.logger.Debug().Msg(err.Error())
517+
return nil, err
518+
}
519+
520+
switch parsedResponse.StatusCode() {
521+
case http.StatusCreated:
522+
// poll results
523+
return a.pollTestForFindings(ctx, client, orgUuid, parsedResponse.ApplicationvndApiJSON201.Data.Id)
524+
}
525+
return nil, nil
526+
}
527+
528+
result, err := innerFunction()
529+
if err != nil {
530+
tracker.End("Analysis failed.")
531+
} else {
532+
tracker.End("Analysis completed.")
533+
}
534+
535+
return result, err
536+
}
537+
538+
func (a *analysisOrchestrator) RunTest(ctx context.Context, orgId string, b bundle.Bundle, target scan.Target, reportingConfig AnalysisConfig) (*sarif.SarifResponse, error) {
494539
var repoUrl *string = nil
495540
if repoTarget, ok := target.(*scan.RepositoryTarget); ok {
496541
tmp := repoTarget.GetRepositoryUrl()
497542
repoUrl = &tmp
498543
}
499544

500-
client, err := testApi.NewClient(host, testApi.WithHTTPClient(a.httpClient))
501-
if err != nil {
502-
return nil, err
503-
}
504-
505-
params := testApi.CreateTestParams{Version: testApi.ApiVersion}
506545
body := testApi.NewCreateTestApplicationBody(
507546
testApi.WithInputBundle(b.GetBundleHash(), target.GetPath(), repoUrl, b.GetLimitToFiles()),
508547
testApi.WithScanType(a.testType),
@@ -511,92 +550,23 @@ func (a *analysisOrchestrator) RunTest(ctx context.Context, orgId string, b bund
511550
testApi.WithReporting(&reportingConfig.Report),
512551
)
513552

514-
// create test
515-
resp, err := client.CreateTestWithApplicationVndAPIPlusJSONBody(ctx, orgUuid, &params, *body)
516-
if err != nil {
517-
return nil, err
518-
}
519-
520-
parsedResponse, err := testApi.ParseCreateTestResponse(resp)
521-
defer func() {
522-
closeErr := resp.Body.Close()
523-
if closeErr != nil {
524-
a.logger.Err(closeErr).Msg("failed to close response body")
525-
}
526-
}()
527-
if err != nil {
528-
a.logger.Debug().Msg(err.Error())
529-
return nil, err
530-
}
531-
532-
switch parsedResponse.StatusCode() {
533-
case http.StatusCreated:
534-
// poll results
535-
result, pollErr := a.pollTestForFindings(ctx, client, orgUuid, parsedResponse.ApplicationvndApiJSON201.Data.Id)
536-
tracker.End("Analysis complete.")
537-
return result, pollErr
538-
default:
539-
return nil, fmt.Errorf("failed to create test: %s", parsedResponse.Status())
540-
}
553+
return a.createTestAndGetResults(ctx, orgId, body, "Snyk Code analysis for "+target.GetPath())
541554
}
542555

543556
func (a *analysisOrchestrator) RunTestRemote(ctx context.Context, orgId string, interactionId string, cfg AnalysisConfig) (*sarif.SarifResponse, error) {
544-
tracker := a.trackerFactory.GenerateTracker()
545-
tracker.Begin("Snyk Code analysis for remote project", "Retrieving results...")
546-
547-
orgUuid := uuid.MustParse(orgId)
548-
host := a.host(true)
549-
550-
client, err := testApi.NewClient(host, testApi.WithHTTPClient(a.httpClient))
551-
if err != nil {
552-
return nil, err
553-
}
554-
555-
params := testApi.CreateTestParams{Version: testApi.ApiVersion}
556-
projectId := cfg.ProjectId
557-
commitId := cfg.CommitId
558-
559-
if projectId == nil || commitId == nil {
557+
if cfg.ProjectId == nil || cfg.CommitId == nil {
560558
return nil, errors.New("projectId and commitId are required")
561559
}
562-
legacyScmProject := testApi.NewTestInputLegacyScmProject(*projectId, *commitId)
560+
561+
legacyScmProject := testApi.NewTestInputLegacyScmProject(*cfg.ProjectId, *cfg.CommitId)
563562
body := testApi.NewCreateTestApplicationBody(
564563
testApi.WithInputLegacyScmProject(legacyScmProject),
565564
testApi.WithReporting(&cfg.Report),
566565
testApi.WithScanType(a.testType),
567-
testApi.WithProjectId(*projectId),
566+
testApi.WithProjectId(*cfg.ProjectId),
568567
)
569-
// create test
570-
bodyBytes, err := json.Marshal(body)
571-
if err != nil {
572-
return nil, err
573-
}
574-
resp, err := client.CreateTestWithBody(ctx, orgUuid, &params, "application/json", strings.NewReader(string(bodyBytes)))
575-
if err != nil {
576-
return nil, err
577-
}
578-
579-
parsedResponse, err := testApi.ParseGetTestResultResponse(resp)
580-
defer func() {
581-
closeErr := resp.Body.Close()
582-
if closeErr != nil {
583-
a.logger.Err(closeErr).Msg("failed to close response body")
584-
}
585-
}()
586-
if err != nil {
587-
a.logger.Debug().Msg(err.Error())
588-
return nil, err
589-
}
590568

591-
switch parsedResponse.StatusCode() {
592-
case http.StatusCreated:
593-
// poll results
594-
result, pollErr := a.pollTestForFindings(ctx, client, orgUuid, parsedResponse.ApplicationvndApiJSON200.Data.Id)
595-
tracker.End("Analysis complete.")
596-
return result, pollErr
597-
default:
598-
return nil, fmt.Errorf("failed to analyze project: %s", parsedResponse.Status())
599-
}
569+
return a.createTestAndGetResults(ctx, orgId, body, "Snyk Code analysis for remote project")
600570
}
601571

602572
func (a *analysisOrchestrator) pollTestForFindings(ctx context.Context, client *testApi.Client, org uuid.UUID, testId openapi_types.UUID) (*sarif.SarifResponse, error) {

internal/analysis/analysis_test.go

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ func setup(t *testing.T, timeout *time.Duration) (*confMocks.MockConfig, *httpmo
6868
mockErrorReporter := mocks.NewMockErrorReporter(ctrl)
6969
mockTracker := trackerMocks.NewMockTracker(ctrl)
7070
mockTrackerFactory := trackerMocks.NewMockTrackerFactory(ctrl)
71-
mockTrackerFactory.EXPECT().GenerateTracker().Return(mockTracker)
71+
mockTrackerFactory.EXPECT().GenerateTracker().Return(mockTracker).AnyTimes()
7272

7373
logger := zerolog.Nop()
7474
return mockConfig, mockHTTPClient, mockInstrumentor, mockErrorReporter, mockTracker, mockTrackerFactory, logger
@@ -796,10 +796,7 @@ func TestAnalysis_RunTestRemote(t *testing.T) {
796796
}
797797

798798
func TestAnalysis_RunTestRemote_MissingRequiredParams(t *testing.T) {
799-
mockConfig, mockHTTPClient, mockInstrumentor, mockErrorReporter, mockTracker, mockTrackerFactory, logger := setup(t, nil)
800-
mockTrackerFactory.EXPECT().GenerateTracker().Return(mockTracker).AnyTimes()
801-
802-
mockTracker.EXPECT().Begin(gomock.Eq("Snyk Code analysis for remote project"), gomock.Eq("Retrieving results...")).Return().AnyTimes()
799+
mockConfig, mockHTTPClient, mockInstrumentor, mockErrorReporter, _, mockTrackerFactory, logger := setup(t, nil)
803800
mockHTTPClient.EXPECT().Do(gomock.Any()).Times(0)
804801

805802
analysisOrchestrator := analysis.NewAnalysisOrchestrator(

0 commit comments

Comments
 (0)