diff --git a/internal/commands/ostest/depgraph_flow.go b/internal/commands/ostest/depgraph_flow.go index 2522df0..03c8e62 100644 --- a/internal/commands/ostest/depgraph_flow.go +++ b/internal/commands/ostest/depgraph_flow.go @@ -25,11 +25,10 @@ func RunUnifiedTestFlow( ctx context.Context, ictx workflow.InvocationContext, testClient testapi.TestClient, - riskScoreThreshold *uint16, - severityThreshold *testapi.Severity, orgID string, errFactory *errors.ErrorFactory, logger *zerolog.Logger, + localPolicy *testapi.LocalPolicy, ) ([]workflow.Data, error) { logger.Info().Msg("Starting open source test") @@ -39,8 +38,6 @@ func RunUnifiedTestFlow( return nil, err } - localPolicy := createLocalPolicy(riskScoreThreshold, severityThreshold) - allLegacyFindings, allOutputData, err := testAllDepGraphs( ctx, ictx, @@ -227,18 +224,6 @@ func prepareJSONOutput( return bytes.TrimRight(buffer.Bytes(), "\n"), nil } -// Create local policy only if risk score or severity threshold are specified. -func createLocalPolicy(riskScoreThreshold *uint16, severityThreshold *testapi.Severity) *testapi.LocalPolicy { - if riskScoreThreshold == nil && severityThreshold == nil { - return nil - } - - return &testapi.LocalPolicy{ - RiskScoreThreshold: riskScoreThreshold, - SeverityThreshold: severityThreshold, - } -} - // createDepGraphs creates depgraphs from the file parameter in the context. func createDepGraphs(ictx workflow.InvocationContext) ([]*testapi.IoSnykApiV1testdepgraphRequestDepGraph, []string, error) { depGraphResult, err := service.GetDepGraph(ictx) diff --git a/internal/commands/ostest/sbom_reachability_flow.go b/internal/commands/ostest/sbom_reachability_flow.go index 2a99dde..8f8587b 100644 --- a/internal/commands/ostest/sbom_reachability_flow.go +++ b/internal/commands/ostest/sbom_reachability_flow.go @@ -25,6 +25,7 @@ func RunSbomReachabilityFlow( sourceCodePath string, bsClient bundlestore.Client, orgID string, + localPolicy *testapi.LocalPolicy, ) ([]workflow.Data, error) { if sourceCodePath == "" { sourceCodePath = "." @@ -67,7 +68,7 @@ func RunSbomReachabilityFlow( return nil, fmt.Errorf("failed to create sbom test reachability subject: %w", err) } - findings, summary, err := RunTest(ctx, ictx, testClient, subject, "", "", int(0), sbomPath, orgID, errFactory, logger, nil) + findings, summary, err := RunTest(ctx, ictx, testClient, subject, "", "", int(0), sbomPath, orgID, errFactory, logger, localPolicy) if err != nil { return nil, err } diff --git a/internal/commands/ostest/sbom_reachability_flow_test.go b/internal/commands/ostest/sbom_reachability_flow_test.go index dbca452..8b47827 100644 --- a/internal/commands/ostest/sbom_reachability_flow_test.go +++ b/internal/commands/ostest/sbom_reachability_flow_test.go @@ -42,7 +42,7 @@ func Test_RunSbomReachabilityFlow_JSON(t *testing.T) { mockIctx, mockTestClient, mockBsClient, orgID, sbomPath, sourceCodePath := setupTest(ctx, t, ctrl, true) // This should now succeed with proper finding data - result, err := ostest.RunSbomReachabilityFlow(ctx, mockIctx, mockTestClient, ef, &nopLogger, sbomPath, sourceCodePath, mockBsClient, orgID) + result, err := ostest.RunSbomReachabilityFlow(ctx, mockIctx, mockTestClient, ef, &nopLogger, sbomPath, sourceCodePath, mockBsClient, orgID, nil) require.NoError(t, err) require.NotNil(t, result) @@ -67,7 +67,7 @@ func Test_RunSbomReachabilityFlow_HumanReadable(t *testing.T) { mockIctx, mockTestClient, mockBsClient, orgID, sbomPath, sourceCodePath := setupTest(ctx, t, ctrl, false) // This should now succeed with proper finding data - result, err := ostest.RunSbomReachabilityFlow(ctx, mockIctx, mockTestClient, ef, &nopLogger, sbomPath, sourceCodePath, mockBsClient, orgID) + result, err := ostest.RunSbomReachabilityFlow(ctx, mockIctx, mockTestClient, ef, &nopLogger, sbomPath, sourceCodePath, mockBsClient, orgID, nil) require.NoError(t, err) require.NotNil(t, result) diff --git a/internal/commands/ostest/workflow.go b/internal/commands/ostest/workflow.go index d56d2ac..9900282 100644 --- a/internal/commands/ostest/workflow.go +++ b/internal/commands/ostest/workflow.go @@ -81,6 +81,7 @@ func setupSBOMReachabilityFlow( errFactory *errors.ErrorFactory, logger *zerolog.Logger, sbom, sourceDir string, + localPolicy *testapi.LocalPolicy, ) ([]workflow.Data, error) { config := ictx.GetConfiguration() @@ -104,54 +105,38 @@ func setupSBOMReachabilityFlow( ) bsClient := bundlestore.NewClient(ictx.GetNetworkAccess().GetHttpClient(), codeScannerConfig, cScanner, logger) - return RunSbomReachabilityFlow(ctx, ictx, testClient, errFactory, logger, sbom, sourceDir, bsClient, orgID) + return RunSbomReachabilityFlow(ctx, ictx, testClient, errFactory, logger, sbom, sourceDir, bsClient, orgID, localPolicy) } -// setupDefaultTestFlow sets up and runs the default test flow with risk score and severity thresholds. -func setupDefaultTestFlow( - ctx context.Context, - ictx workflow.InvocationContext, - testClient testapi.TestClient, - orgID string, - errFactory *errors.ErrorFactory, - logger *zerolog.Logger, - riskScoreThreshold int, -) ([]workflow.Data, error) { - config := ictx.GetConfiguration() - - // Risk Score FFs - ffRiskScore := config.GetBool(FeatureFlagRiskScore) - ffRiskScoreInCLI := config.GetBool(FeatureFlagRiskScoreInCLI) - riskScoreFFsEnabled := ffRiskScore && ffRiskScoreInCLI - - if riskScoreThreshold != -1 && !riskScoreFFsEnabled { - // The user tried to use a risk score threshold without the required feature flags. - // Return a specific error for the first missing flag found. - if !ffRiskScore { - return nil, errFactory.NewFeatureNotPermittedError(FeatureFlagRiskScore) - } - return nil, errFactory.NewFeatureNotPermittedError(FeatureFlagRiskScoreInCLI) - } - - var riskScorePtr *uint16 - if riskScoreThreshold >= math.MaxUint16 { +// CreateLocalPolicy will create a local policy only if risk score or severity threshold are specified in the config. +func CreateLocalPolicy(config configuration.Configuration, logger *zerolog.Logger) *testapi.LocalPolicy { + var riskScoreThreshold *uint16 + riskScoreThresholdInt := config.GetInt(flags.FlagRiskScoreThreshold) + if riskScoreThresholdInt >= math.MaxUint16 { // the API will enforce a range from the test spec - logger.Warn().Msgf("Risk score threshold %d exceeds maximum uint16 value. Setting to maximum.", riskScoreThreshold) + logger.Warn().Msgf("Risk score threshold %d exceeds maximum uint16 value. Setting to maximum.", riskScoreThresholdInt) maxVal := uint16(math.MaxUint16) - riskScorePtr = &maxVal - } else if riskScoreThreshold >= 0 { - rs := uint16(riskScoreThreshold) - riskScorePtr = &rs + riskScoreThreshold = &maxVal + } else if riskScoreThresholdInt >= 0 { + rs := uint16(riskScoreThresholdInt) + riskScoreThreshold = &rs } - var severityThresholdPtr *testapi.Severity + var severityThreshold *testapi.Severity severityThresholdStr := config.GetString(flags.FlagSeverityThreshold) if severityThresholdStr != "" { st := testapi.Severity(severityThresholdStr) - severityThresholdPtr = &st + severityThreshold = &st } - return RunUnifiedTestFlow(ctx, ictx, testClient, riskScorePtr, severityThresholdPtr, orgID, errFactory, logger) + if riskScoreThreshold == nil && severityThreshold == nil { + return nil + } + + return &testapi.LocalPolicy{ + RiskScoreThreshold: riskScoreThreshold, + SeverityThreshold: severityThreshold, + } } // OSWorkflow is the entry point for the Open Source Test workflow. @@ -195,6 +180,17 @@ func OSWorkflow( return nil, errFactory.NewEmptyOrgError() } + if riskScoreThreshold != -1 && !riskScoreFFsEnabled { + // The user tried to use a risk score threshold without the required feature flags. + // Return a specific error for the first missing flag found. + if !ffRiskScore { + return nil, errFactory.NewFeatureNotPermittedError(FeatureFlagRiskScore) + } + return nil, errFactory.NewFeatureNotPermittedError(FeatureFlagRiskScoreInCLI) + } + + localPolicy := CreateLocalPolicy(config, logger) + // Create Snyk client httpClient := ictx.GetNetworkAccess().GetHttpClient() snykClient := snykclient.NewSnykClient(httpClient, ictx.GetConfiguration().GetString(configuration.API_URL), orgID) @@ -212,8 +208,8 @@ func OSWorkflow( // Route to the appropriate flow based on flags switch { case sbomReachabilityTest: - return setupSBOMReachabilityFlow(ctx, ictx, testClient, orgID, errFactory, logger, sbom, sourceDir) + return setupSBOMReachabilityFlow(ctx, ictx, testClient, orgID, errFactory, logger, sbom, sourceDir, localPolicy) default: - return setupDefaultTestFlow(ctx, ictx, testClient, orgID, errFactory, logger, riskScoreThreshold) + return RunUnifiedTestFlow(ctx, ictx, testClient, orgID, errFactory, logger, localPolicy) } } diff --git a/internal/commands/ostest/workflow_test.go b/internal/commands/ostest/workflow_test.go index 4bcf9c5..5dd6336 100644 --- a/internal/commands/ostest/workflow_test.go +++ b/internal/commands/ostest/workflow_test.go @@ -3,6 +3,7 @@ package ostest_test import ( "encoding/json" "fmt" + "math" "net/http" "net/http/httptest" "strings" @@ -30,6 +31,59 @@ import ( var legacyWorkflowID = workflow.NewWorkflowIdentifier("legacycli") +var logger = zerolog.Nop() + +func TestOSWorkflow_CreateLocalPolicy(t *testing.T) { + // Setup - No special flags set + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockEngine := mocks.NewMockEngine(ctrl) + mockInvocationCtx := createMockInvocationCtxWithURL(t, ctrl, mockEngine, "") + mockConfig := mockInvocationCtx.GetConfiguration() + mockConfig.Set(flags.FlagRiskScoreThreshold, 100) + mockConfig.Set(flags.FlagSeverityThreshold, "high") + + localPolicy := ostest.CreateLocalPolicy(mockConfig, &logger) + require.NotNil(t, localPolicy) + + require.NotNil(t, localPolicy.RiskScoreThreshold) + assert.Equal(t, uint16(100), *localPolicy.RiskScoreThreshold) + require.NotNil(t, localPolicy.SeverityThreshold) + assert.Equal(t, "high", string(*localPolicy.SeverityThreshold)) +} + +func TestOSWorkflow_CreateLocalPolicy_NoValues(t *testing.T) { + // Setup - No special flags set + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockEngine := mocks.NewMockEngine(ctrl) + mockInvocationCtx := createMockInvocationCtxWithURL(t, ctrl, mockEngine, "") + mockConfig := mockInvocationCtx.GetConfiguration() + + localPolicy := ostest.CreateLocalPolicy(mockConfig, &logger) + + assert.Nil(t, localPolicy) +} + +func TestOSWorkflow_CreateLocalPolicy_RiskScoreOverflow(t *testing.T) { + // Setup - No special flags set + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockEngine := mocks.NewMockEngine(ctrl) + mockInvocationCtx := createMockInvocationCtxWithURL(t, ctrl, mockEngine, "") + mockConfig := mockInvocationCtx.GetConfiguration() + mockConfig.Set(flags.FlagRiskScoreThreshold, math.MaxUint16+10) + + localPolicy := ostest.CreateLocalPolicy(mockConfig, &logger) + require.NotNil(t, localPolicy) + + assert.NotNil(t, localPolicy.RiskScoreThreshold) + assert.Equal(t, uint16(math.MaxUint16), *localPolicy.RiskScoreThreshold) +} + func TestOSWorkflow_LegacyFlow(t *testing.T) { // Setup - No special flags set ctrl := gomock.NewController(t)