Skip to content

Commit a265818

Browse files
authored
feat: add PAT auto-region configuration (#395)
1 parent 03077df commit a265818

File tree

8 files changed

+269
-22
lines changed

8 files changed

+269
-22
lines changed

pkg/app/app.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ func defaultFuncOrganization(engine workflow.Engine, config configuration.Config
8181
func defaultFuncApiUrl(_ configuration.Configuration, logger *zerolog.Logger) configuration.DefaultValueFunction {
8282
callback := func(config configuration.Configuration, existingValue interface{}) (interface{}, error) {
8383
urlString := constants.SNYK_DEFAULT_API_URL
84+
authToken := config.GetString(configuration.AUTHENTICATION_TOKEN)
8485

8586
urlFromOauthToken, err := auth.GetAudienceClaimFromOauthToken(config.GetString(auth.CONFIG_KEY_OAUTH_TOKEN))
8687
if err != nil {
@@ -89,6 +90,14 @@ func defaultFuncApiUrl(_ configuration.Configuration, logger *zerolog.Logger) co
8990

9091
if len(urlFromOauthToken) > 0 && len(urlFromOauthToken[0]) > 0 {
9192
urlString = urlFromOauthToken[0]
93+
} else if auth.IsAuthTypePAT(authToken) {
94+
apiUrl, claimsErr := auth.GetApiUrlFromPAT(authToken)
95+
if claimsErr != nil {
96+
logger.Warn().Err(claimsErr).Msg("failed to get api url from pat")
97+
}
98+
if len(apiUrl) > 0 {
99+
urlString = apiUrl
100+
}
92101
} else if existingValue != nil { // try the configured value as last resort
93102
if temp, ok := existingValue.(string); ok {
94103
urlString = temp

pkg/app/app_test.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package app
22

33
import (
4+
"encoding/base64"
45
"errors"
56
"fmt"
67
"log"
@@ -89,6 +90,40 @@ func Test_CreateAppEngine_config_replaceV1inApi(t *testing.T) {
8990
assert.Equal(t, expectApiUrl, actualApiUrl)
9091
}
9192

93+
func Test_CreateAppEngine_config_PAT_autoRegionDetection(t *testing.T) {
94+
t.Run("default", func(t *testing.T) {
95+
apiUrl := "api.snyk.io"
96+
euPAT := createMockPAT(t, fmt.Sprintf(`{"h":"%s"}`, apiUrl))
97+
engine := CreateAppEngine()
98+
config := engine.GetConfiguration()
99+
config.Set(configuration.AUTHENTICATION_TOKEN, euPAT)
100+
101+
actualApiUrl := config.GetString(configuration.API_URL)
102+
assert.Equal(t, fmt.Sprintf("https://%s", apiUrl), actualApiUrl)
103+
})
104+
105+
t.Run("eu", func(t *testing.T) {
106+
apiUrl := "api.eu.snyk.io"
107+
euPAT := createMockPAT(t, fmt.Sprintf(`{"h":"%s"}`, apiUrl))
108+
engine := CreateAppEngine()
109+
config := engine.GetConfiguration()
110+
config.Set(configuration.AUTHENTICATION_TOKEN, euPAT)
111+
112+
actualApiUrl := config.GetString(configuration.API_URL)
113+
assert.Equal(t, fmt.Sprintf("https://%s", apiUrl), actualApiUrl)
114+
})
115+
116+
t.Run("invalid PAT reverts to default API URL", func(t *testing.T) {
117+
patWithExtraSegments := "snyk_uat.12345678.payload.signature.extra"
118+
engine := CreateAppEngine()
119+
config := engine.GetConfiguration()
120+
config.Set(configuration.AUTHENTICATION_TOKEN, patWithExtraSegments)
121+
122+
actualApiUrl := config.GetString(configuration.API_URL)
123+
assert.Equal(t, constants.SNYK_DEFAULT_API_URL, actualApiUrl)
124+
})
125+
}
126+
92127
func Test_CreateAppEngine_config_OauthAudHasPrecedence(t *testing.T) {
93128
config := configuration.New()
94129
config.Set(auth.CONFIG_KEY_OAUTH_TOKEN,
@@ -368,3 +403,11 @@ func Test_initConfiguration_DEFAULT_TEMP_DIRECTORY(t *testing.T) {
368403
assert.Equal(t, expected, actual)
369404
})
370405
}
406+
407+
func createMockPAT(t *testing.T, payload string) string {
408+
t.Helper()
409+
410+
encodedPayload := base64.RawURLEncoding.EncodeToString([]byte(payload))
411+
signature := "signature"
412+
return fmt.Sprintf("snyk_uat.12345678.%s.%s", encodedPayload, signature)
413+
}

pkg/auth/tokenauthenticator.go

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
package auth
22

33
import (
4+
"encoding/base64"
5+
"encoding/json"
46
"fmt"
57
"net/http"
68
"regexp"
9+
"strings"
710

811
"github.com/google/uuid"
912
)
@@ -15,8 +18,15 @@ const (
1518
CACHED_PAT_IS_VALID_KEY_PREFIX = "cached_pat_is_valid"
1619
CONFIG_KEY_TOKEN = "api" // the snyk config key for api token
1720
CONFIG_KEY_ENDPOINT = "endpoint" // the snyk config key for api endpoint
21+
PAT_REGEX = `(snyk_(?:uat|sat))\.([a-z0-9]{8}\.[a-zA-Z0-9-_]+\.[a-zA-Z0-9-_]+)`
1822
)
1923

24+
// Claims represents the structure of the PATs claims, it does not represent all the claims; only the ones we need
25+
type Claims struct {
26+
// Hostname PAT is valid for
27+
Hostname string `json:"h,omitempty"`
28+
}
29+
2030
var _ Authenticator = (*tokenAuthenticator)(nil)
2131

2232
type tokenAuthenticator struct {
@@ -60,9 +70,40 @@ func IsAuthTypeToken(token string) bool {
6070

6171
func IsAuthTypePAT(token string) bool {
6272
// e.g. snyk_uat.1a2b3c4d.mySuperSecret_Token-Value.aChecksum_123-Value
63-
patRegex := `^snyk_(?:uat|sat)\.[a-z0-9]{8}\.[a-zA-Z0-9-_]+\.[a-zA-Z0-9-_]+$`
64-
if matched, err := regexp.MatchString(patRegex, token); err == nil && matched {
65-
return matched
73+
return regexp.MustCompile(fmt.Sprintf("^%s$", PAT_REGEX)).MatchString(token)
74+
}
75+
76+
// extractClaimsFromPAT accepts a raw PAT string and returns the PAT claims
77+
// differs from the implementation in oauth.go as Snyk PATs do not strictly follow the JWT spec
78+
func extractClaimsFromPAT(raw string) (*Claims, error) {
79+
parts := strings.Split(raw, ".")
80+
if len(parts) != 4 {
81+
return nil, fmt.Errorf("invalid number of segments: %d", len(parts))
6682
}
67-
return false
83+
84+
payload, err := base64.RawURLEncoding.DecodeString(parts[2])
85+
if err != nil {
86+
return nil, fmt.Errorf("failed to decode payload: %w", err)
87+
}
88+
89+
var c Claims
90+
if err = json.Unmarshal(payload, &c); err != nil {
91+
return nil, fmt.Errorf("failed to unmarshal payload: %w", err)
92+
}
93+
94+
return &c, nil
95+
}
96+
97+
func GetApiUrlFromPAT(pat string) (string, error) {
98+
claims, err := extractClaimsFromPAT(pat)
99+
if err != nil {
100+
return "", err
101+
}
102+
103+
hostname := claims.Hostname
104+
if !strings.HasPrefix(hostname, "http") {
105+
hostname = fmt.Sprintf("https://%s", hostname)
106+
}
107+
108+
return hostname, nil
68109
}

pkg/auth/tokenauthenticator_test.go

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package auth
22

33
import (
4+
"encoding/base64"
5+
"fmt"
46
"testing"
57

68
"github.com/stretchr/testify/assert"
@@ -17,3 +19,95 @@ func TestIsAuthTypePAT(t *testing.T) {
1719
// legacy token format
1820
assert.False(t, IsAuthTypePAT("f47ac10b-58cc-4372-a567-0e02b2c3d479"))
1921
}
22+
23+
func TestExtractClaimsFromPAT(t *testing.T) {
24+
t.Run("Valid PAT", func(t *testing.T) {
25+
pat := createMockPAT(t, `{"h":"api.snyk.io"}`)
26+
27+
claims, err := extractClaimsFromPAT(pat)
28+
assert.NoError(t, err)
29+
assert.NotNil(t, claims)
30+
assert.Equal(t, "api.snyk.io", claims.Hostname)
31+
})
32+
33+
t.Run("Valid EU PAT", func(t *testing.T) {
34+
pat := createMockPAT(t, `{"h":"api.eu.snyk.io"}`)
35+
36+
claims, err := extractClaimsFromPAT(pat)
37+
assert.NoError(t, err)
38+
assert.NotNil(t, claims)
39+
assert.Equal(t, "api.eu.snyk.io", claims.Hostname)
40+
})
41+
42+
t.Run("PAT with fewer than 4 segments", func(t *testing.T) {
43+
pat := "snyk_test.12345678.payload"
44+
claims, err := extractClaimsFromPAT(pat)
45+
assert.Error(t, err)
46+
assert.Contains(t, err.Error(), "invalid number of segments: 3")
47+
assert.Nil(t, claims)
48+
})
49+
50+
t.Run("PAT with more than 4 segments", func(t *testing.T) {
51+
pat := "snyk_test.12345678.payload.signature.extra"
52+
claims, err := extractClaimsFromPAT(pat)
53+
assert.Error(t, err)
54+
assert.Contains(t, err.Error(), "invalid number of segments: 5")
55+
assert.Nil(t, claims)
56+
})
57+
58+
t.Run("PAT with invalid base64 payload", func(t *testing.T) {
59+
pat := "snyk_test.12345678.invalid-base64!@#$.signature"
60+
claims, err := extractClaimsFromPAT(pat)
61+
assert.Error(t, err)
62+
assert.Contains(t, err.Error(), "failed to decode payload")
63+
assert.Nil(t, claims)
64+
})
65+
66+
t.Run("PAT with invalid JSON payload", func(t *testing.T) {
67+
pat := createMockPAT(t, `{"j":"pat-id-123", "h":"api.snyk.io", "e":1678886400, "s":"sub-id-456`)
68+
69+
claims, err := extractClaimsFromPAT(pat)
70+
assert.Error(t, err)
71+
assert.Contains(t, err.Error(), "failed to unmarshal payload")
72+
assert.Nil(t, claims)
73+
})
74+
}
75+
76+
func TestGetApiUrlFromPAT(t *testing.T) {
77+
t.Run("Valid PAT", func(t *testing.T) {
78+
pat := createMockPAT(t, `{"h":"api.snyk.io"}`)
79+
apiUrl, err := GetApiUrlFromPAT(pat)
80+
assert.NoError(t, err)
81+
assert.Equal(t, "https://api.snyk.io", apiUrl)
82+
})
83+
84+
t.Run("Valid EU PAT", func(t *testing.T) {
85+
pat := createMockPAT(t, `{"h":"api.eu.snyk.io"}`)
86+
apiUrl, err := GetApiUrlFromPAT(pat)
87+
assert.NoError(t, err)
88+
assert.Equal(t, "https://api.eu.snyk.io", apiUrl)
89+
})
90+
91+
t.Run("PAT with scheme", func(t *testing.T) {
92+
pat := createMockPAT(t, `{"h":"http://api.snyk.io"}`)
93+
fmt.Println("pat", pat)
94+
apiUrl, err := GetApiUrlFromPAT(pat)
95+
assert.NoError(t, err)
96+
assert.Equal(t, "http://api.snyk.io", apiUrl)
97+
})
98+
99+
t.Run("Invalid PAT", func(t *testing.T) {
100+
patTooManySegments := "snyk_test.12345678.payload.signature.extra"
101+
apiUrl, err := GetApiUrlFromPAT(patTooManySegments)
102+
assert.Error(t, err)
103+
assert.Equal(t, "", apiUrl)
104+
})
105+
}
106+
107+
func createMockPAT(t *testing.T, payload string) string {
108+
t.Helper()
109+
110+
encodedPayload := base64.RawURLEncoding.EncodeToString([]byte(payload))
111+
signature := "signature"
112+
return fmt.Sprintf("snyk_uat.12345678.%s.%s", encodedPayload, signature)
113+
}

pkg/local_workflows/auth_workflow.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,8 @@ func entryPointDI(invocationCtx workflow.InvocationContext, logger *zerolog.Logg
135135

136136
logger.Print("Validating pat")
137137
whoamiConfig := config.Clone()
138+
// we don't want to use the cache here, so this is a workaround
139+
whoamiConfig.ClearCache()
138140
whoamiConfig.Set(configuration.FLAG_EXPERIMENTAL, true)
139141
whoamiConfig.Set(configuration.AUTHENTICATION_TOKEN, pat)
140142
_, whoamiErr := engine.InvokeWithConfig(workflow.NewWorkflowIdentifier("whoami"), whoamiConfig)
@@ -147,6 +149,8 @@ func entryPointDI(invocationCtx workflow.InvocationContext, logger *zerolog.Logg
147149
}
148150

149151
logger.Print("Validation successful; set pat credentials in config")
152+
// we don't want to use the cache here, so this is a workaround
153+
engine.GetConfiguration().ClearCache()
150154
engine.GetConfiguration().Set(auth.CONFIG_KEY_TOKEN, pat)
151155

152156
err = ui.DefaultUi().Output(auth.AUTHENTICATED_MESSAGE)

pkg/local_workflows/auth_workflow_test.go

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -103,12 +103,6 @@ func Test_auth_token(t *testing.T) {
103103
}
104104

105105
func Test_pat(t *testing.T) {
106-
const (
107-
testPAT = "snyk_pat.12345678.abcdefghijklmnopqrstuvwxyz123456"
108-
mockedPatEndpoint = "https://api.snyk.io"
109-
expectedAPIKeyStorage = auth.CONFIG_KEY_TOKEN
110-
)
111-
112106
mockCtl := gomock.NewController(t)
113107
defer mockCtl.Finish()
114108

@@ -118,23 +112,22 @@ func Test_pat(t *testing.T) {
118112

119113
engine := mocks.NewMockEngine(mockCtl)
120114
authenticator := mocks.NewMockAuthenticator(mockCtl)
115+
pat := "myPAT"
121116

122117
t.Run("happy", func(t *testing.T) {
123-
config := configuration.New()
118+
config := configuration.NewWithOpts()
124119
config.Set(authTypeParameter, auth.AUTH_TYPE_PAT)
125-
config.Set(ConfigurationNewAuthenticationToken, testPAT)
120+
config.Set(ConfigurationNewAuthenticationToken, pat)
121+
126122
config.Set(auth.CONFIG_KEY_OAUTH_TOKEN, "some-oauth-token")
127123
config.Set(configuration.AUTHENTICATION_TOKEN, "some-legacy-api-token")
128124

129-
config.Set(configuration.API_URL, []string{"https://api.snyk.io"})
130-
131125
mockInvocationContext := mocks.NewMockInvocationContext(mockCtl)
132126
mockInvocationContext.EXPECT().GetConfiguration().Return(config).AnyTimes()
133127
mockInvocationContext.EXPECT().GetEnhancedLogger().Return(&logger).AnyTimes()
134128
mockInvocationContext.EXPECT().GetAnalytics().Return(analytics).Times(1)
135129

136-
engineConfig := configuration.New()
137-
engine.EXPECT().GetConfiguration().Return(engineConfig).AnyTimes()
130+
engine.EXPECT().GetConfiguration().Return(config).AnyTimes()
138131
engine.EXPECT().InvokeWithConfig(gomock.Any(), gomock.Any())
139132

140133
err := entryPointDI(mockInvocationContext, &logger, engine, authenticator)
@@ -145,21 +138,20 @@ func Test_pat(t *testing.T) {
145138
})
146139

147140
t.Run("invalid pat should fail", func(t *testing.T) {
148-
config := configuration.New()
141+
config := configuration.NewWithOpts()
149142
config.Set(authTypeParameter, auth.AUTH_TYPE_PAT)
150-
config.Set(ConfigurationNewAuthenticationToken, testPAT)
143+
config.Set(ConfigurationNewAuthenticationToken, pat)
144+
151145
config.Set(auth.CONFIG_KEY_OAUTH_TOKEN, "some-oauth-token")
152146
config.Set(configuration.AUTHENTICATION_TOKEN, "some-legacy-api-token")
153147

154-
config.Set(configuration.API_URL, []string{"https://api.snyk.io"})
155-
156148
mockInvocationContext := mocks.NewMockInvocationContext(mockCtl)
157149
mockInvocationContext.EXPECT().GetConfiguration().Return(config).AnyTimes()
158150
mockInvocationContext.EXPECT().GetEnhancedLogger().Return(&logger).AnyTimes()
159151
mockInvocationContext.EXPECT().GetAnalytics().Return(analytics).Times(1)
160152

161-
engineConfig := configuration.New()
162-
engine.EXPECT().GetConfiguration().Return(engineConfig).AnyTimes()
153+
engine.EXPECT().GetConfiguration().Return(config).AnyTimes()
154+
163155
mockWhoAmIError := fmt.Errorf("mock whoami failure")
164156
engine.EXPECT().InvokeWithConfig(gomock.Any(), gomock.Any()).Return(nil, mockWhoAmIError)
165157

pkg/logging/scrubbingLogWriter.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,13 @@ func addMandatoryMasking(dict ScrubbingDict) ScrubbingDict {
194194
regex: regexp.MustCompile(s),
195195
}
196196

197+
// Snyk PATs
198+
s = auth.PAT_REGEX
199+
dict[s] = scrubStruct{
200+
groupToRedact: 2,
201+
regex: regexp.MustCompile(s),
202+
}
203+
197204
// github
198205
s = fmt.Sprintf(`(access_token[\\="\s:]+)(%s)&?`, charGroup)
199206
dict[s] = scrubStruct{

0 commit comments

Comments
 (0)