Skip to content

Commit fde4e5d

Browse files
Reapply "Merge pull request #40 from snyk/feat(reachability)/pass-local-policy-to-test"
This reverts commit d96f1f4.
1 parent d4c1727 commit fde4e5d

File tree

5 files changed

+94
-58
lines changed

5 files changed

+94
-58
lines changed

internal/commands/ostest/depgraph_flow.go

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,10 @@ func RunUnifiedTestFlow(
2525
ctx context.Context,
2626
ictx workflow.InvocationContext,
2727
testClient testapi.TestClient,
28-
riskScoreThreshold *uint16,
29-
severityThreshold *testapi.Severity,
3028
orgID string,
3129
errFactory *errors.ErrorFactory,
3230
logger *zerolog.Logger,
31+
localPolicy *testapi.LocalPolicy,
3332
) ([]workflow.Data, error) {
3433
logger.Info().Msg("Starting open source test")
3534

@@ -39,8 +38,6 @@ func RunUnifiedTestFlow(
3938
return nil, err
4039
}
4140

42-
localPolicy := createLocalPolicy(riskScoreThreshold, severityThreshold)
43-
4441
allLegacyFindings, allOutputData, err := testAllDepGraphs(
4542
ctx,
4643
ictx,
@@ -227,18 +224,6 @@ func prepareJSONOutput(
227224
return bytes.TrimRight(buffer.Bytes(), "\n"), nil
228225
}
229226

230-
// Create local policy only if risk score or severity threshold are specified.
231-
func createLocalPolicy(riskScoreThreshold *uint16, severityThreshold *testapi.Severity) *testapi.LocalPolicy {
232-
if riskScoreThreshold == nil && severityThreshold == nil {
233-
return nil
234-
}
235-
236-
return &testapi.LocalPolicy{
237-
RiskScoreThreshold: riskScoreThreshold,
238-
SeverityThreshold: severityThreshold,
239-
}
240-
}
241-
242227
// createDepGraphs creates depgraphs from the file parameter in the context.
243228
func createDepGraphs(ictx workflow.InvocationContext) ([]*testapi.IoSnykApiV1testdepgraphRequestDepGraph, []string, error) {
244229
depGraphResult, err := service.GetDepGraph(ictx)

internal/commands/ostest/sbom_reachability_flow.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ func RunSbomReachabilityFlow(
2525
sourceCodePath string,
2626
bsClient bundlestore.Client,
2727
orgID string,
28+
localPolicy *testapi.LocalPolicy,
2829
) ([]workflow.Data, error) {
2930
if sourceCodePath == "" {
3031
sourceCodePath = "."
@@ -67,7 +68,7 @@ func RunSbomReachabilityFlow(
6768
return nil, fmt.Errorf("failed to create sbom test reachability subject: %w", err)
6869
}
6970

70-
findings, summary, err := RunTest(ctx, ictx, testClient, subject, "", "", int(0), sbomPath, orgID, errFactory, logger, nil)
71+
findings, summary, err := RunTest(ctx, ictx, testClient, subject, "", "", int(0), sbomPath, orgID, errFactory, logger, localPolicy)
7172
if err != nil {
7273
return nil, err
7374
}

internal/commands/ostest/sbom_reachability_flow_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ func Test_RunSbomReachabilityFlow_JSON(t *testing.T) {
4242
mockIctx, mockTestClient, mockBsClient, orgID, sbomPath, sourceCodePath := setupTest(ctx, t, ctrl, true)
4343

4444
// This should now succeed with proper finding data
45-
result, err := ostest.RunSbomReachabilityFlow(ctx, mockIctx, mockTestClient, ef, &nopLogger, sbomPath, sourceCodePath, mockBsClient, orgID)
45+
result, err := ostest.RunSbomReachabilityFlow(ctx, mockIctx, mockTestClient, ef, &nopLogger, sbomPath, sourceCodePath, mockBsClient, orgID, nil)
4646

4747
require.NoError(t, err)
4848
require.NotNil(t, result)
@@ -67,7 +67,7 @@ func Test_RunSbomReachabilityFlow_HumanReadable(t *testing.T) {
6767
mockIctx, mockTestClient, mockBsClient, orgID, sbomPath, sourceCodePath := setupTest(ctx, t, ctrl, false)
6868

6969
// This should now succeed with proper finding data
70-
result, err := ostest.RunSbomReachabilityFlow(ctx, mockIctx, mockTestClient, ef, &nopLogger, sbomPath, sourceCodePath, mockBsClient, orgID)
70+
result, err := ostest.RunSbomReachabilityFlow(ctx, mockIctx, mockTestClient, ef, &nopLogger, sbomPath, sourceCodePath, mockBsClient, orgID, nil)
7171

7272
require.NoError(t, err)
7373
require.NotNil(t, result)

internal/commands/ostest/workflow.go

Lines changed: 35 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ func setupSBOMReachabilityFlow(
8181
errFactory *errors.ErrorFactory,
8282
logger *zerolog.Logger,
8383
sbom, sourceDir string,
84+
localPolicy *testapi.LocalPolicy,
8485
) ([]workflow.Data, error) {
8586
config := ictx.GetConfiguration()
8687

@@ -104,54 +105,38 @@ func setupSBOMReachabilityFlow(
104105
)
105106

106107
bsClient := bundlestore.NewClient(ictx.GetNetworkAccess().GetHttpClient(), codeScannerConfig, cScanner, logger)
107-
return RunSbomReachabilityFlow(ctx, ictx, testClient, errFactory, logger, sbom, sourceDir, bsClient, orgID)
108+
return RunSbomReachabilityFlow(ctx, ictx, testClient, errFactory, logger, sbom, sourceDir, bsClient, orgID, localPolicy)
108109
}
109110

110-
// setupDefaultTestFlow sets up and runs the default test flow with risk score and severity thresholds.
111-
func setupDefaultTestFlow(
112-
ctx context.Context,
113-
ictx workflow.InvocationContext,
114-
testClient testapi.TestClient,
115-
orgID string,
116-
errFactory *errors.ErrorFactory,
117-
logger *zerolog.Logger,
118-
riskScoreThreshold int,
119-
) ([]workflow.Data, error) {
120-
config := ictx.GetConfiguration()
121-
122-
// Risk Score FFs
123-
ffRiskScore := config.GetBool(FeatureFlagRiskScore)
124-
ffRiskScoreInCLI := config.GetBool(FeatureFlagRiskScoreInCLI)
125-
riskScoreFFsEnabled := ffRiskScore && ffRiskScoreInCLI
126-
127-
if riskScoreThreshold != -1 && !riskScoreFFsEnabled {
128-
// The user tried to use a risk score threshold without the required feature flags.
129-
// Return a specific error for the first missing flag found.
130-
if !ffRiskScore {
131-
return nil, errFactory.NewFeatureNotPermittedError(FeatureFlagRiskScore)
132-
}
133-
return nil, errFactory.NewFeatureNotPermittedError(FeatureFlagRiskScoreInCLI)
134-
}
135-
136-
var riskScorePtr *uint16
137-
if riskScoreThreshold >= math.MaxUint16 {
111+
// CreateLocalPolicy will create a local policy only if risk score or severity threshold are specified in the config.
112+
func CreateLocalPolicy(config configuration.Configuration, logger *zerolog.Logger) *testapi.LocalPolicy {
113+
var riskScoreThreshold *uint16
114+
riskScoreThresholdInt := config.GetInt(flags.FlagRiskScoreThreshold)
115+
if riskScoreThresholdInt >= math.MaxUint16 {
138116
// the API will enforce a range from the test spec
139-
logger.Warn().Msgf("Risk score threshold %d exceeds maximum uint16 value. Setting to maximum.", riskScoreThreshold)
117+
logger.Warn().Msgf("Risk score threshold %d exceeds maximum uint16 value. Setting to maximum.", riskScoreThresholdInt)
140118
maxVal := uint16(math.MaxUint16)
141-
riskScorePtr = &maxVal
142-
} else if riskScoreThreshold >= 0 {
143-
rs := uint16(riskScoreThreshold)
144-
riskScorePtr = &rs
119+
riskScoreThreshold = &maxVal
120+
} else if riskScoreThresholdInt >= 0 {
121+
rs := uint16(riskScoreThresholdInt)
122+
riskScoreThreshold = &rs
145123
}
146124

147-
var severityThresholdPtr *testapi.Severity
125+
var severityThreshold *testapi.Severity
148126
severityThresholdStr := config.GetString(flags.FlagSeverityThreshold)
149127
if severityThresholdStr != "" {
150128
st := testapi.Severity(severityThresholdStr)
151-
severityThresholdPtr = &st
129+
severityThreshold = &st
152130
}
153131

154-
return RunUnifiedTestFlow(ctx, ictx, testClient, riskScorePtr, severityThresholdPtr, orgID, errFactory, logger)
132+
if riskScoreThreshold == nil && severityThreshold == nil {
133+
return nil
134+
}
135+
136+
return &testapi.LocalPolicy{
137+
RiskScoreThreshold: riskScoreThreshold,
138+
SeverityThreshold: severityThreshold,
139+
}
155140
}
156141

157142
// OSWorkflow is the entry point for the Open Source Test workflow.
@@ -195,6 +180,17 @@ func OSWorkflow(
195180
return nil, errFactory.NewEmptyOrgError()
196181
}
197182

183+
if riskScoreThreshold != -1 && !riskScoreFFsEnabled {
184+
// The user tried to use a risk score threshold without the required feature flags.
185+
// Return a specific error for the first missing flag found.
186+
if !ffRiskScore {
187+
return nil, errFactory.NewFeatureNotPermittedError(FeatureFlagRiskScore)
188+
}
189+
return nil, errFactory.NewFeatureNotPermittedError(FeatureFlagRiskScoreInCLI)
190+
}
191+
192+
localPolicy := CreateLocalPolicy(config, logger)
193+
198194
// Create Snyk client
199195
httpClient := ictx.GetNetworkAccess().GetHttpClient()
200196
snykClient := snykclient.NewSnykClient(httpClient, ictx.GetConfiguration().GetString(configuration.API_URL), orgID)
@@ -212,8 +208,8 @@ func OSWorkflow(
212208
// Route to the appropriate flow based on flags
213209
switch {
214210
case sbomReachabilityTest:
215-
return setupSBOMReachabilityFlow(ctx, ictx, testClient, orgID, errFactory, logger, sbom, sourceDir)
211+
return setupSBOMReachabilityFlow(ctx, ictx, testClient, orgID, errFactory, logger, sbom, sourceDir, localPolicy)
216212
default:
217-
return setupDefaultTestFlow(ctx, ictx, testClient, orgID, errFactory, logger, riskScoreThreshold)
213+
return RunUnifiedTestFlow(ctx, ictx, testClient, orgID, errFactory, logger, localPolicy)
218214
}
219215
}

internal/commands/ostest/workflow_test.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package ostest_test
33
import (
44
"encoding/json"
55
"fmt"
6+
"math"
67
"net/http"
78
"net/http/httptest"
89
"strings"
@@ -30,6 +31,59 @@ import (
3031

3132
var legacyWorkflowID = workflow.NewWorkflowIdentifier("legacycli")
3233

34+
var logger = zerolog.Nop()
35+
36+
func TestOSWorkflow_CreateLocalPolicy(t *testing.T) {
37+
// Setup - No special flags set
38+
ctrl := gomock.NewController(t)
39+
defer ctrl.Finish()
40+
41+
mockEngine := mocks.NewMockEngine(ctrl)
42+
mockInvocationCtx := createMockInvocationCtxWithURL(t, ctrl, mockEngine, "")
43+
mockConfig := mockInvocationCtx.GetConfiguration()
44+
mockConfig.Set(flags.FlagRiskScoreThreshold, 100)
45+
mockConfig.Set(flags.FlagSeverityThreshold, "high")
46+
47+
localPolicy := ostest.CreateLocalPolicy(mockConfig, &logger)
48+
require.NotNil(t, localPolicy)
49+
50+
require.NotNil(t, localPolicy.RiskScoreThreshold)
51+
assert.Equal(t, uint16(100), *localPolicy.RiskScoreThreshold)
52+
require.NotNil(t, localPolicy.SeverityThreshold)
53+
assert.Equal(t, "high", string(*localPolicy.SeverityThreshold))
54+
}
55+
56+
func TestOSWorkflow_CreateLocalPolicy_NoValues(t *testing.T) {
57+
// Setup - No special flags set
58+
ctrl := gomock.NewController(t)
59+
defer ctrl.Finish()
60+
61+
mockEngine := mocks.NewMockEngine(ctrl)
62+
mockInvocationCtx := createMockInvocationCtxWithURL(t, ctrl, mockEngine, "")
63+
mockConfig := mockInvocationCtx.GetConfiguration()
64+
65+
localPolicy := ostest.CreateLocalPolicy(mockConfig, &logger)
66+
67+
assert.Nil(t, localPolicy)
68+
}
69+
70+
func TestOSWorkflow_CreateLocalPolicy_RiskScoreOverflow(t *testing.T) {
71+
// Setup - No special flags set
72+
ctrl := gomock.NewController(t)
73+
defer ctrl.Finish()
74+
75+
mockEngine := mocks.NewMockEngine(ctrl)
76+
mockInvocationCtx := createMockInvocationCtxWithURL(t, ctrl, mockEngine, "")
77+
mockConfig := mockInvocationCtx.GetConfiguration()
78+
mockConfig.Set(flags.FlagRiskScoreThreshold, math.MaxUint16+10)
79+
80+
localPolicy := ostest.CreateLocalPolicy(mockConfig, &logger)
81+
require.NotNil(t, localPolicy)
82+
83+
assert.NotNil(t, localPolicy.RiskScoreThreshold)
84+
assert.Equal(t, uint16(math.MaxUint16), *localPolicy.RiskScoreThreshold)
85+
}
86+
3387
func TestOSWorkflow_LegacyFlow(t *testing.T) {
3488
// Setup - No special flags set
3589
ctrl := gomock.NewController(t)

0 commit comments

Comments
 (0)