Skip to content

Commit 5e41dc3

Browse files
committed
replace wif assertion options with a func
Signed-off-by: Jorge Turrado <[email protected]>
1 parent 675a8b5 commit 5e41dc3

File tree

6 files changed

+150
-74
lines changed

6 files changed

+150
-74
lines changed

core/auth/auth.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,9 +238,8 @@ func WorkloadIdentityFederationAuth(cfg *config.Configuration) (http.RoundTrippe
238238
TokenUrl: cfg.TokenCustomUrl,
239239
BackgroundTokenRefreshContext: cfg.BackgroundTokenRefreshContext,
240240
ClientID: cfg.ServiceAccountEmail,
241-
FederatedTokenFilePath: cfg.ServiceAccountFederatedTokenPath,
242241
TokenExpiration: cfg.ServiceAccountFederatedTokenExpiration,
243-
FederatedToken: cfg.ServiceAccountFederatedToken,
242+
FederatedTokenFunction: cfg.ServiceAccountFederatedTokenFunc,
244243
}
245244

246245
if cfg.HTTPClient != nil && cfg.HTTPClient.Transport != nil {

core/clients/workload_identity_flow.go

Lines changed: 11 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import (
1010
"sync"
1111
"time"
1212

13-
"github.com/golang-jwt/jwt/v5"
13+
"github.com/stackitcloud/stackit-sdk-go/core/utils"
1414
)
1515

1616
const (
@@ -48,8 +48,6 @@ type WorkloadIdentityFederationFlow struct {
4848
tokenMutex sync.RWMutex
4949
token *TokenResponseBody
5050

51-
parser *jwt.Parser
52-
5351
// If the current access token would expire in less than TokenExpirationLeeway,
5452
// the client will refresh it early to prevent clock skew or other timing issues.
5553
tokenExpirationLeeway time.Duration
@@ -59,12 +57,11 @@ type WorkloadIdentityFederationFlow struct {
5957
type WorkloadIdentityFederationFlowConfig struct {
6058
TokenUrl string
6159
ClientID string
62-
FederatedToken string // Static token string. This is optional, if not set the token will be read from file.
63-
FederatedTokenFilePath string
6460
TokenExpiration string // Not supported yet
6561
BackgroundTokenRefreshContext context.Context // Functionality is enabled if this isn't nil
6662
HTTPTransport http.RoundTripper
6763
AuthHTTPClient *http.Client
64+
FederatedTokenFunction func() (string, error) // Function to get the federated token
6865
}
6966

7067
// GetConfig returns the flow configuration
@@ -130,7 +127,6 @@ func (c *WorkloadIdentityFederationFlow) Init(cfg *WorkloadIdentityFederationFlo
130127
// No concurrency at this point, so no mutex check needed
131128
c.token = &TokenResponseBody{}
132129
c.config = cfg
133-
c.parser = jwt.NewParser()
134130

135131
if c.config.TokenUrl == "" {
136132
c.config.TokenUrl = getEnvOrDefault(wifTokenEndpointEnv, defaultWifTokenEndpoint)
@@ -140,8 +136,10 @@ func (c *WorkloadIdentityFederationFlow) Init(cfg *WorkloadIdentityFederationFlo
140136
c.config.ClientID = getEnvOrDefault(clientIDEnv, "")
141137
}
142138

143-
if c.config.FederatedToken == "" && c.config.FederatedTokenFilePath == "" {
144-
c.config.FederatedTokenFilePath = getEnvOrDefault(FederatedTokenFileEnv, defaultFederatedTokenPath)
139+
if c.config.FederatedTokenFunction == nil {
140+
c.config.FederatedTokenFunction = func() (string, error) {
141+
return utils.ReadJWTFromFileSystem(getEnvOrDefault(FederatedTokenFileEnv, defaultFederatedTokenPath))
142+
}
145143
}
146144

147145
c.tokenExpirationLeeway = defaultTokenExpirationLeeway
@@ -176,10 +174,8 @@ func (c *WorkloadIdentityFederationFlow) validate() error {
176174
if c.config.TokenUrl == "" {
177175
return fmt.Errorf("token URL cannot be empty")
178176
}
179-
if c.config.FederatedToken == "" {
180-
if _, err := c.readJWTFromFileSystem(c.config.FederatedTokenFilePath); err != nil {
181-
return fmt.Errorf("error reading federated token file - %w", err)
182-
}
177+
if _, err := c.config.FederatedTokenFunction(); err != nil {
178+
return fmt.Errorf("error reading federated token file - %w", err)
183179
}
184180
if c.tokenExpirationLeeway < 0 {
185181
return fmt.Errorf("token expiration leeway cannot be negative")
@@ -190,15 +186,10 @@ func (c *WorkloadIdentityFederationFlow) validate() error {
190186

191187
// createAccessToken creates an access token using self signed JWT
192188
func (c *WorkloadIdentityFederationFlow) createAccessToken() error {
193-
clientAssertion := c.config.FederatedToken
194-
if clientAssertion == "" {
195-
var err error
196-
clientAssertion, err = c.readJWTFromFileSystem(c.config.FederatedTokenFilePath)
197-
if err != nil {
198-
return fmt.Errorf("error reading service account assertion - %w", err)
199-
}
189+
clientAssertion, err := c.config.FederatedTokenFunction()
190+
if err != nil {
191+
return err
200192
}
201-
202193
res, err := c.requestToken(c.config.ClientID, clientAssertion)
203194
if err != nil {
204195
return err
@@ -235,16 +226,3 @@ func (c *WorkloadIdentityFederationFlow) requestToken(clientID, assertion string
235226

236227
return c.authClient.Do(req)
237228
}
238-
239-
func (c *WorkloadIdentityFederationFlow) readJWTFromFileSystem(tokenFilePath string) (string, error) {
240-
token, err := os.ReadFile(tokenFilePath)
241-
if err != nil {
242-
return "", err
243-
}
244-
tokenStr := string(token)
245-
_, _, err = c.parser.ParseUnverified(tokenStr, jwt.MapClaims{})
246-
if err != nil {
247-
return "", err
248-
}
249-
return tokenStr, nil
250-
}

core/clients/workload_identity_flow_test.go

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"time"
1212

1313
"github.com/golang-jwt/jwt/v5"
14+
"github.com/stackitcloud/stackit-sdk-go/core/utils"
1415
)
1516

1617
func TestWorkloadIdentityFlowInit(t *testing.T) {
@@ -55,12 +56,6 @@ func TestWorkloadIdentityFlowInit(t *testing.T) {
5556
validAssertion: true,
5657
wantErr: true,
5758
},
58-
{
59-
name: "missing assertion",
60-
clientID: "[email protected]",
61-
missingTokenFilePath: true,
62-
wantErr: true,
63-
},
6459
{
6560
name: "invalid assertion",
6661
clientID: "[email protected]",
@@ -114,7 +109,9 @@ func TestWorkloadIdentityFlowInit(t *testing.T) {
114109
if tt.tokenFilePathAsEnv {
115110
t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", file.Name())
116111
} else {
117-
flowConfig.FederatedTokenFilePath = file.Name()
112+
flowConfig.FederatedTokenFunction = func() (string, error) {
113+
return utils.ReadJWTFromFileSystem(file.Name())
114+
}
118115
}
119116
}
120117

@@ -137,14 +134,6 @@ func TestWorkloadIdentityFlowInit(t *testing.T) {
137134
t.Errorf("tokenUrl mismatch, want %s, got %s", "https://accounts.stackit.cloud/oauth/v2/token", flow.config.TokenUrl)
138135
}
139136

140-
if tt.missingTokenFilePath && flow.config.FederatedTokenFilePath != "/var/run/secrets/stackit.cloud/serviceaccount/token" {
141-
t.Errorf("clientID mismatch, want %s, got %s", "/var/run/secrets/stackit.cloud/serviceaccount/token", flow.config.FederatedTokenFilePath)
142-
}
143-
144-
if !tt.missingTokenFilePath && flow.config.FederatedTokenFilePath == "/var/run/secrets/stackit.cloud/serviceaccount/token" {
145-
t.Errorf("clientID mismatch, want different from %s", flow.config.FederatedTokenFilePath)
146-
}
147-
148137
if tt.tokenExpiration != "" && flow.config.TokenExpiration != tt.tokenExpiration {
149138
t.Errorf("tokenExpiration mismatch, want %s, got %s", tt.tokenExpiration, flow.config.TokenExpiration)
150139
}
@@ -276,7 +265,9 @@ func TestWorkloadIdentityFlowRoundTrip(t *testing.T) {
276265
}
277266

278267
if tt.injectToken {
279-
flowConfig.FederatedToken = token
268+
flowConfig.FederatedTokenFunction = func() (string, error) {
269+
return token, nil
270+
}
280271
} else {
281272
file, err := os.CreateTemp("", "*.token")
282273
if err != nil {
@@ -288,7 +279,9 @@ func TestWorkloadIdentityFlowRoundTrip(t *testing.T) {
288279
t.Fatalf("Removing temporary file: %s", err)
289280
}
290281
}()
291-
flowConfig.FederatedTokenFilePath = file.Name()
282+
flowConfig.FederatedTokenFunction = func() (string, error) {
283+
return utils.ReadJWTFromFileSystem(file.Name())
284+
}
292285
err = os.WriteFile(file.Name(), []byte(token), os.ModeAppend)
293286
if err != nil {
294287
t.Fatalf("writing temporary file: %s", err)

core/config/config.go

Lines changed: 42 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"time"
1212

1313
"github.com/stackitcloud/stackit-sdk-go/core/clients"
14+
"github.com/stackitcloud/stackit-sdk-go/core/utils"
1415
)
1516

1617
const (
@@ -75,25 +76,24 @@ type Middleware func(http.RoundTripper) http.RoundTripper
7576

7677
// Configuration stores the configuration of the API client
7778
type Configuration struct {
78-
Host string `json:"host,omitempty"`
79-
Scheme string `json:"scheme,omitempty"`
80-
DefaultHeader map[string]string `json:"defaultHeader,omitempty"`
81-
UserAgent string `json:"userAgent,omitempty"`
82-
Debug bool `json:"debug,omitempty"`
83-
NoAuth bool `json:"noAuth,omitempty"`
84-
WorkloadIdentityFederation bool `json:"workloadIdentityFederation,omitempty"`
85-
ServiceAccountFederatedTokenExpiration string `json:"serviceAccountFederatedTokenExpiration,omitempty"`
86-
ServiceAccountFederatedToken string `json:"serviceAccountFederatedToken,omitempty"`
87-
ServiceAccountFederatedTokenPath string `json:"serviceAccountFederatedTokenPath,omitempty"`
88-
ServiceAccountEmail string `json:"serviceAccountEmail,omitempty"`
89-
Token string `json:"token,omitempty"`
90-
ServiceAccountKey string `json:"serviceAccountKey,omitempty"`
91-
PrivateKey string `json:"privateKey,omitempty"`
92-
ServiceAccountKeyPath string `json:"serviceAccountKeyPath,omitempty"`
93-
PrivateKeyPath string `json:"privateKeyPath,omitempty"`
94-
CredentialsFilePath string `json:"credentialsFilePath,omitempty"`
95-
TokenCustomUrl string `json:"tokenCustomUrl,omitempty"`
96-
Region string `json:"region,omitempty"`
79+
Host string `json:"host,omitempty"`
80+
Scheme string `json:"scheme,omitempty"`
81+
DefaultHeader map[string]string `json:"defaultHeader,omitempty"`
82+
UserAgent string `json:"userAgent,omitempty"`
83+
Debug bool `json:"debug,omitempty"`
84+
NoAuth bool `json:"noAuth,omitempty"`
85+
WorkloadIdentityFederation bool `json:"workloadIdentityFederation,omitempty"`
86+
ServiceAccountFederatedTokenExpiration string `json:"serviceAccountFederatedTokenExpiration,omitempty"`
87+
ServiceAccountFederatedTokenFunc func() (string, error) `json:"serviceAccountFederatedTokenFunc,omitempty"`
88+
ServiceAccountEmail string `json:"serviceAccountEmail,omitempty"`
89+
Token string `json:"token,omitempty"`
90+
ServiceAccountKey string `json:"serviceAccountKey,omitempty"`
91+
PrivateKey string `json:"privateKey,omitempty"`
92+
ServiceAccountKeyPath string `json:"serviceAccountKeyPath,omitempty"`
93+
PrivateKeyPath string `json:"privateKeyPath,omitempty"`
94+
CredentialsFilePath string `json:"credentialsFilePath,omitempty"`
95+
TokenCustomUrl string `json:"tokenCustomUrl,omitempty"`
96+
Region string `json:"region,omitempty"`
9797
CustomAuth http.RoundTripper
9898
Servers ServerConfigurations
9999
OperationServers map[string]ServerConfigurations
@@ -247,10 +247,30 @@ func WithWorkloadIdentityFederationAuth() ConfigurationOption {
247247
}
248248
}
249249

250-
// WithWorkloadIdentityFederation returns a ConfigurationOption that sets workload identity flow to be used for authentication in API calls
251-
func WithWorkloadIdentityFederationTokenPath(path string) ConfigurationOption {
250+
// WithWorkloadIdentityFederationFunc returns a ConfigurationOption that sets the function to get the federated token for workload identity federation flow
251+
func WithWorkloadIdentityFederationFunc(function func() (string, error)) ConfigurationOption {
252252
return func(config *Configuration) error {
253-
config.ServiceAccountFederatedTokenPath = path
253+
config.ServiceAccountFederatedTokenFunc = function
254+
return nil
255+
}
256+
}
257+
258+
// WithWorkloadIdentityFederationPath returns a ConfigurationOption that sets the custom path to the federated token file for workload identity federation flow
259+
func WithWorkloadIdentityFederationPath(path string) ConfigurationOption {
260+
return func(config *Configuration) error {
261+
config.ServiceAccountFederatedTokenFunc = func() (string, error) {
262+
return utils.ReadJWTFromFileSystem(path)
263+
}
264+
return nil
265+
}
266+
}
267+
268+
// WithWorkloadIdentityFederationFunc returns a ConfigurationOption that sets the id token for workload identity federation flow
269+
func WithWorkloadIdentityFederationToken(token string) ConfigurationOption {
270+
return func(config *Configuration) error {
271+
config.ServiceAccountFederatedTokenFunc = func() (string, error) {
272+
return token, nil
273+
}
254274
return nil
255275
}
256276
}

core/utils/filesystem.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
package utils
2+
3+
import (
4+
"os"
5+
6+
"github.com/golang-jwt/jwt/v5"
7+
)
8+
9+
var (
10+
parser *jwt.Parser = jwt.NewParser()
11+
)
12+
13+
func ReadJWTFromFileSystem(tokenFilePath string) (string, error) {
14+
token, err := os.ReadFile(tokenFilePath)
15+
if err != nil {
16+
return "", err
17+
}
18+
tokenStr := string(token)
19+
_, _, err = parser.ParseUnverified(tokenStr, jwt.MapClaims{})
20+
if err != nil {
21+
return "", err
22+
}
23+
return tokenStr, nil
24+
}

core/utils/filesystem_test.go

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
package utils
2+
3+
import (
4+
"log"
5+
"os"
6+
"testing"
7+
)
8+
9+
func TestReadJWTFromFileSystem(t *testing.T) {
10+
file, err := os.CreateTemp("", "*.token")
11+
if err != nil {
12+
log.Fatal(err)
13+
}
14+
defer func() {
15+
err := os.Remove(file.Name())
16+
if err != nil {
17+
t.Fatalf("Removing temporary file: %s", err)
18+
}
19+
}()
20+
21+
token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.KMUFsIDTnFmyG3nMiGM6H9FNFUROf3wh7SmqJp-QV30" // nolint:gosec // This is a fake token for testing purposes only
22+
err = os.WriteFile(file.Name(), []byte(token), os.ModeAppend)
23+
if err != nil {
24+
t.Fatalf("Writing temporary file: %s", err)
25+
}
26+
27+
_, err = ReadJWTFromFileSystem(file.Name())
28+
if err != nil {
29+
t.Fatalf("Reading JWT from file system: %s", err)
30+
}
31+
}
32+
33+
func TestReadRandomContentFromFileSystem(t *testing.T) {
34+
file, err := os.CreateTemp("", "*.token")
35+
if err != nil {
36+
log.Fatal(err)
37+
}
38+
defer func() {
39+
err := os.Remove(file.Name())
40+
if err != nil {
41+
t.Fatalf("Removing temporary file: %s", err)
42+
}
43+
}()
44+
45+
token := "invalid random content"
46+
err = os.WriteFile(file.Name(), []byte(token), os.ModeAppend)
47+
if err != nil {
48+
t.Fatalf("Writing temporary file: %s", err)
49+
}
50+
51+
_, err = ReadJWTFromFileSystem(file.Name())
52+
if err == nil {
53+
t.Fatalf("Reading JWT from file system must fail")
54+
}
55+
}
56+
57+
func TestReadMissingFileFromFileSystem(t *testing.T) {
58+
_, err := ReadJWTFromFileSystem("/path/to/nonexistent/file.token")
59+
if err == nil {
60+
t.Fatalf("Reading JWT from file system must fail")
61+
}
62+
}

0 commit comments

Comments
 (0)