Skip to content

Commit df179f8

Browse files
authored
Merge pull request #40 from snyk/feat(reachability)/pass-local-policy-to-test
feat(reachability): pass local policy to test
2 parents 9aa2ea7 + ddbf9f5 commit df179f8

File tree

5 files changed

+94
-59
lines changed

5 files changed

+94
-59
lines changed

internal/commands/ostest/depgraph_flow.go

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,11 @@ func RunUnifiedTestFlow(
2525
ctx context.Context,
2626
ictx workflow.InvocationContext,
2727
testClient testapi.TestClient,
28-
riskScoreThreshold *uint16,
29-
severityThreshold *testapi.Severity,
3028
orgID string,
3129
orgSlugOrID string,
3230
errFactory *errors.ErrorFactory,
3331
logger *zerolog.Logger,
32+
localPolicy *testapi.LocalPolicy,
3433
) ([]workflow.Data, error) {
3534
logger.Info().Msg("Starting open source test")
3635

@@ -40,8 +39,6 @@ func RunUnifiedTestFlow(
4039
return nil, err
4140
}
4241

43-
localPolicy := createLocalPolicy(riskScoreThreshold, severityThreshold)
44-
4542
allLegacyFindings, allOutputData, err := testAllDepGraphs(
4643
ctx,
4744
ictx,
@@ -230,18 +227,6 @@ func prepareJSONOutput(
230227
return bytes.TrimRight(buffer.Bytes(), "\n"), nil
231228
}
232229

233-
// Create local policy only if risk score or severity threshold are specified.
234-
func createLocalPolicy(riskScoreThreshold *uint16, severityThreshold *testapi.Severity) *testapi.LocalPolicy {
235-
if riskScoreThreshold == nil && severityThreshold == nil {
236-
return nil
237-
}
238-
239-
return &testapi.LocalPolicy{
240-
RiskScoreThreshold: riskScoreThreshold,
241-
SeverityThreshold: severityThreshold,
242-
}
243-
}
244-
245230
// createDepGraphs creates depgraphs from the file parameter in the context.
246231
func createDepGraphs(ictx workflow.InvocationContext) ([]*testapi.IoSnykApiV1testdepgraphRequestDepGraph, []string, error) {
247232
depGraphResult, err := service.GetDepGraph(ictx)

internal/commands/ostest/sbom_reachability_flow.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ func RunSbomReachabilityFlow(
2626
bsClient bundlestore.Client,
2727
orgID string,
2828
orgSlugOrID string,
29+
localPolicy *testapi.LocalPolicy,
2930
) ([]workflow.Data, error) {
3031
if sourceCodePath == "" {
3132
sourceCodePath = "."
@@ -68,7 +69,7 @@ func RunSbomReachabilityFlow(
6869
return nil, fmt.Errorf("failed to create sbom test reachability subject: %w", err)
6970
}
7071

71-
findings, summary, err := RunTest(ctx, ictx, testClient, subject, "", "", int(0), sbomPath, orgID, orgSlugOrID, errFactory, logger, nil)
72+
findings, summary, err := RunTest(ctx, ictx, testClient, subject, "", "", int(0), sbomPath, orgID, orgSlugOrID, errFactory, logger, localPolicy)
7273
if err != nil {
7374
return nil, err
7475
}

internal/commands/ostest/sbom_reachability_flow_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ func Test_RunSbomReachabilityFlow_JSON(t *testing.T) {
4242
mockIctx, mockTestClient, mockBsClient, orgID, orgSlug, sbomPath, sourceCodePath := setupTest(ctx, t, ctrl, true)
4343

4444
// This should now succeed with proper finding data
45-
result, err := ostest.RunSbomReachabilityFlow(ctx, mockIctx, mockTestClient, ef, &nopLogger, sbomPath, sourceCodePath, mockBsClient, orgID, orgSlug)
45+
result, err := ostest.RunSbomReachabilityFlow(ctx, mockIctx, mockTestClient, ef, &nopLogger, sbomPath, sourceCodePath, mockBsClient, orgID, orgSlug, nil)
4646

4747
require.NoError(t, err)
4848
require.NotNil(t, result)
@@ -67,7 +67,7 @@ func Test_RunSbomReachabilityFlow_HumanReadable(t *testing.T) {
6767
mockIctx, mockTestClient, mockBsClient, orgID, orgSlug, sbomPath, sourceCodePath := setupTest(ctx, t, ctrl, false)
6868

6969
// This should now succeed with proper finding data
70-
result, err := ostest.RunSbomReachabilityFlow(ctx, mockIctx, mockTestClient, ef, &nopLogger, sbomPath, sourceCodePath, mockBsClient, orgID, orgSlug)
70+
result, err := ostest.RunSbomReachabilityFlow(ctx, mockIctx, mockTestClient, ef, &nopLogger, sbomPath, sourceCodePath, mockBsClient, orgID, orgSlug, nil)
7171

7272
require.NoError(t, err)
7373
require.NotNil(t, result)

internal/commands/ostest/workflow.go

Lines changed: 35 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ func setupSBOMReachabilityFlow(
8282
errFactory *errors.ErrorFactory,
8383
logger *zerolog.Logger,
8484
sbom, sourceDir string,
85+
localPolicy *testapi.LocalPolicy,
8586
) ([]workflow.Data, error) {
8687
config := ictx.GetConfiguration()
8788

@@ -105,55 +106,38 @@ func setupSBOMReachabilityFlow(
105106
)
106107

107108
bsClient := bundlestore.NewClient(ictx.GetNetworkAccess().GetHttpClient(), codeScannerConfig, cScanner, logger)
108-
return RunSbomReachabilityFlow(ctx, ictx, testClient, errFactory, logger, sbom, sourceDir, bsClient, orgID, orgSlugOrID)
109+
return RunSbomReachabilityFlow(ctx, ictx, testClient, errFactory, logger, sbom, sourceDir, bsClient, orgID, orgSlugOrID, localPolicy)
109110
}
110111

111-
// setupDefaultTestFlow sets up and runs the default test flow with risk score and severity thresholds.
112-
func setupDefaultTestFlow(
113-
ctx context.Context,
114-
ictx workflow.InvocationContext,
115-
testClient testapi.TestClient,
116-
orgID string,
117-
orgSlugOrID string,
118-
errFactory *errors.ErrorFactory,
119-
logger *zerolog.Logger,
120-
riskScoreThreshold int,
121-
) ([]workflow.Data, error) {
122-
config := ictx.GetConfiguration()
123-
124-
// Risk Score FFs
125-
ffRiskScore := config.GetBool(FeatureFlagRiskScore)
126-
ffRiskScoreInCLI := config.GetBool(FeatureFlagRiskScoreInCLI)
127-
riskScoreFFsEnabled := ffRiskScore && ffRiskScoreInCLI
128-
129-
if riskScoreThreshold != -1 && !riskScoreFFsEnabled {
130-
// The user tried to use a risk score threshold without the required feature flags.
131-
// Return a specific error for the first missing flag found.
132-
if !ffRiskScore {
133-
return nil, errFactory.NewFeatureNotPermittedError(FeatureFlagRiskScore)
134-
}
135-
return nil, errFactory.NewFeatureNotPermittedError(FeatureFlagRiskScoreInCLI)
136-
}
137-
138-
var riskScorePtr *uint16
139-
if riskScoreThreshold >= math.MaxUint16 {
112+
// CreateLocalPolicy will create a local policy only if risk score or severity threshold are specified in the config.
113+
func CreateLocalPolicy(config configuration.Configuration, logger *zerolog.Logger) *testapi.LocalPolicy {
114+
var riskScoreThreshold *uint16
115+
riskScoreThresholdInt := config.GetInt(flags.FlagRiskScoreThreshold)
116+
if riskScoreThresholdInt >= math.MaxUint16 {
140117
// the API will enforce a range from the test spec
141-
logger.Warn().Msgf("Risk score threshold %d exceeds maximum uint16 value. Setting to maximum.", riskScoreThreshold)
118+
logger.Warn().Msgf("Risk score threshold %d exceeds maximum uint16 value. Setting to maximum.", riskScoreThresholdInt)
142119
maxVal := uint16(math.MaxUint16)
143-
riskScorePtr = &maxVal
144-
} else if riskScoreThreshold >= 0 {
145-
rs := uint16(riskScoreThreshold)
146-
riskScorePtr = &rs
120+
riskScoreThreshold = &maxVal
121+
} else if riskScoreThresholdInt >= 0 {
122+
rs := uint16(riskScoreThresholdInt)
123+
riskScoreThreshold = &rs
147124
}
148125

149-
var severityThresholdPtr *testapi.Severity
126+
var severityThreshold *testapi.Severity
150127
severityThresholdStr := config.GetString(flags.FlagSeverityThreshold)
151128
if severityThresholdStr != "" {
152129
st := testapi.Severity(severityThresholdStr)
153-
severityThresholdPtr = &st
130+
severityThreshold = &st
154131
}
155132

156-
return RunUnifiedTestFlow(ctx, ictx, testClient, riskScorePtr, severityThresholdPtr, orgID, orgSlugOrID, errFactory, logger)
133+
if riskScoreThreshold == nil && severityThreshold == nil {
134+
return nil
135+
}
136+
137+
return &testapi.LocalPolicy{
138+
RiskScoreThreshold: riskScoreThreshold,
139+
SeverityThreshold: severityThreshold,
140+
}
157141
}
158142

159143
// OSWorkflow is the entry point for the Open Source Test workflow.
@@ -203,6 +187,17 @@ func OSWorkflow(
203187
orgSlugOrID = orgID
204188
}
205189

190+
if riskScoreThreshold != -1 && !riskScoreFFsEnabled {
191+
// The user tried to use a risk score threshold without the required feature flags.
192+
// Return a specific error for the first missing flag found.
193+
if !ffRiskScore {
194+
return nil, errFactory.NewFeatureNotPermittedError(FeatureFlagRiskScore)
195+
}
196+
return nil, errFactory.NewFeatureNotPermittedError(FeatureFlagRiskScoreInCLI)
197+
}
198+
199+
localPolicy := CreateLocalPolicy(config, logger)
200+
206201
// Create Snyk client
207202
httpClient := ictx.GetNetworkAccess().GetHttpClient()
208203
snykClient := snykclient.NewSnykClient(httpClient, ictx.GetConfiguration().GetString(configuration.API_URL), orgID)
@@ -220,8 +215,8 @@ func OSWorkflow(
220215
// Route to the appropriate flow based on flags
221216
switch {
222217
case sbomReachabilityTest:
223-
return setupSBOMReachabilityFlow(ctx, ictx, testClient, orgID, orgSlugOrID, errFactory, logger, sbom, sourceDir)
218+
return setupSBOMReachabilityFlow(ctx, ictx, testClient, orgID, orgSlugOrID, errFactory, logger, sbom, sourceDir, localPolicy)
224219
default:
225-
return setupDefaultTestFlow(ctx, ictx, testClient, orgID, orgSlugOrID, errFactory, logger, riskScoreThreshold)
220+
return RunUnifiedTestFlow(ctx, ictx, testClient, orgID, orgSlugOrID, errFactory, logger, localPolicy)
226221
}
227222
}

internal/commands/ostest/workflow_test.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package ostest_test
33
import (
44
"encoding/json"
55
"fmt"
6+
"math"
67
"net/http"
78
"net/http/httptest"
89
"strings"
@@ -30,6 +31,59 @@ import (
3031

3132
var legacyWorkflowID = workflow.NewWorkflowIdentifier("legacycli")
3233

34+
var logger = zerolog.Nop()
35+
36+
func TestOSWorkflow_CreateLocalPolicy(t *testing.T) {
37+
// Setup - No special flags set
38+
ctrl := gomock.NewController(t)
39+
defer ctrl.Finish()
40+
41+
mockEngine := mocks.NewMockEngine(ctrl)
42+
mockInvocationCtx := createMockInvocationCtxWithURL(t, ctrl, mockEngine, "")
43+
mockConfig := mockInvocationCtx.GetConfiguration()
44+
mockConfig.Set(flags.FlagRiskScoreThreshold, 100)
45+
mockConfig.Set(flags.FlagSeverityThreshold, "high")
46+
47+
localPolicy := ostest.CreateLocalPolicy(mockConfig, &logger)
48+
require.NotNil(t, localPolicy)
49+
50+
require.NotNil(t, localPolicy.RiskScoreThreshold)
51+
assert.Equal(t, uint16(100), *localPolicy.RiskScoreThreshold)
52+
require.NotNil(t, localPolicy.SeverityThreshold)
53+
assert.Equal(t, "high", string(*localPolicy.SeverityThreshold))
54+
}
55+
56+
func TestOSWorkflow_CreateLocalPolicy_NoValues(t *testing.T) {
57+
// Setup - No special flags set
58+
ctrl := gomock.NewController(t)
59+
defer ctrl.Finish()
60+
61+
mockEngine := mocks.NewMockEngine(ctrl)
62+
mockInvocationCtx := createMockInvocationCtxWithURL(t, ctrl, mockEngine, "")
63+
mockConfig := mockInvocationCtx.GetConfiguration()
64+
65+
localPolicy := ostest.CreateLocalPolicy(mockConfig, &logger)
66+
67+
assert.Nil(t, localPolicy)
68+
}
69+
70+
func TestOSWorkflow_CreateLocalPolicy_RiskScoreOverflow(t *testing.T) {
71+
// Setup - No special flags set
72+
ctrl := gomock.NewController(t)
73+
defer ctrl.Finish()
74+
75+
mockEngine := mocks.NewMockEngine(ctrl)
76+
mockInvocationCtx := createMockInvocationCtxWithURL(t, ctrl, mockEngine, "")
77+
mockConfig := mockInvocationCtx.GetConfiguration()
78+
mockConfig.Set(flags.FlagRiskScoreThreshold, math.MaxUint16+10)
79+
80+
localPolicy := ostest.CreateLocalPolicy(mockConfig, &logger)
81+
require.NotNil(t, localPolicy)
82+
83+
assert.NotNil(t, localPolicy.RiskScoreThreshold)
84+
assert.Equal(t, uint16(math.MaxUint16), *localPolicy.RiskScoreThreshold)
85+
}
86+
3387
func TestOSWorkflow_LegacyFlow(t *testing.T) {
3488
// Setup - No special flags set
3589
ctrl := gomock.NewController(t)

0 commit comments

Comments
 (0)