Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
- **New:** API for application load balancer
- `cdn`: [v0.1.0](services/cdn/CHANGELOG.md#v010-2025-03-19)
- **New:** Introduce new API for content delivery
- `core`: [v0.16.2](core/CHANGELOG.md#v0162-2025-03-21)
- **New:** If a custom http.Client is provided, the http.Transport is respected. This allows customizing the http.Client with custom timeouts or instrumentation.
- `serverupdate`: [v1.0.0](services/serverupdate/CHANGELOG.md#v100-2025-03-19)
- **Breaking Change:** The region is no longer specified within the client configuration. Instead, the region must be passed as a parameter to any region-specific request.
- `serverbackup`: [v1.0.0](services/serverbackup/CHANGELOG.md#v100-2025-03-19)
Expand Down
3 changes: 3 additions & 0 deletions core/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
## v0.16.2 (2025-03-21)
- **New:** If a custom http.Client is provided, the http.Transport is respected. This allows customizing the http.Client with custom timeouts or instrumentation.

## v0.16.1 (2025-02-25)

- **Bugfix:** STACKIT_PRIVATE_KEY and STACKIT_SERVICE_ACCOUNT_KEY can be set via environment variable or via credentials file.
Expand Down
25 changes: 23 additions & 2 deletions core/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func SetupAuth(cfg *config.Configuration) (rt http.RoundTripper, err error) {
if cfg.CustomAuth != nil {
return cfg.CustomAuth, nil
} else if cfg.NoAuth {
noAuthRoundTripper, err := NoAuth()
noAuthRoundTripper, err := NoAuth(cfg)
if err != nil {
return nil, fmt.Errorf("configuring no auth client: %w", err)
}
Expand Down Expand Up @@ -98,9 +98,22 @@ func DefaultAuth(cfg *config.Configuration) (rt http.RoundTripper, err error) {

// NoAuth configures a flow without authentication and returns an http.RoundTripper
// that can be used to make unauthenticated requests
func NoAuth() (rt http.RoundTripper, err error) {
func NoAuth(cfgs ...*config.Configuration) (rt http.RoundTripper, err error) {
noAuthConfig := clients.NoAuthFlowConfig{}
noAuthRoundTripper := &clients.NoAuthFlow{}

var cfg *config.Configuration

if len(cfgs) > 0 {
cfg = cfgs[0]
} else {
cfg = &config.Configuration{}
}

if cfg.HTTPClient != nil && cfg.HTTPClient.Transport != nil {
noAuthConfig.HTTPTransport = cfg.HTTPClient.Transport
}

if err := noAuthRoundTripper.Init(noAuthConfig); err != nil {
return nil, fmt.Errorf("initializing client: %w", err)
}
Expand Down Expand Up @@ -130,6 +143,10 @@ func TokenAuth(cfg *config.Configuration) (http.RoundTripper, error) {
ServiceAccountToken: cfg.Token,
}

if cfg.HTTPClient != nil && cfg.HTTPClient.Transport != nil {
tokenCfg.HTTPTransport = cfg.HTTPClient.Transport
}

client := &clients.TokenFlow{}
if err := client.Init(&tokenCfg); err != nil {
return nil, fmt.Errorf("error initializing client: %w", err)
Expand Down Expand Up @@ -187,6 +204,10 @@ func KeyAuth(cfg *config.Configuration) (http.RoundTripper, error) {
BackgroundTokenRefreshContext: cfg.BackgroundTokenRefreshContext,
}

if cfg.HTTPClient != nil && cfg.HTTPClient.Transport != nil {
keyCfg.HTTPTransport = cfg.HTTPClient.Transport
}

client := &clients.KeyFlow{}
if err := client.Init(&keyCfg); err != nil {
return nil, fmt.Errorf("error initializing client: %w", err)
Expand Down
33 changes: 26 additions & 7 deletions core/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"crypto/x509"
"encoding/json"
"encoding/pem"
"net/http"
"os"
"reflect"
"testing"
Expand Down Expand Up @@ -125,6 +126,7 @@ func TestSetupAuth(t *testing.T) {
t.Fatalf("Creating temporary file: %s", err)
}
defer func() {
_ = credentialsKeyFile.Close()
err := os.Remove(credentialsKeyFile.Name())
if err != nil {
t.Fatalf("Removing temporary file: %s", err)
Expand Down Expand Up @@ -361,6 +363,7 @@ func TestDefaultAuth(t *testing.T) {
t.Fatalf("Creating temporary file: %s", err)
}
defer func() {
_ = saKeyFile.Close()
err := os.Remove(saKeyFile.Name())
if err != nil {
t.Fatalf("Removing temporary file: %s", err)
Expand All @@ -377,19 +380,13 @@ func TestDefaultAuth(t *testing.T) {
t.Fatalf("Writing private key to temporary file: %s", err)
}

defer func() {
err := saKeyFile.Close()
if err != nil {
t.Fatalf("Removing temporary file: %s", err)
}
}()

// create a credentials file with saKey and private key
credentialsKeyFile, errs := os.CreateTemp("", "temp-*.txt")
if errs != nil {
t.Fatalf("Creating temporary file: %s", err)
}
defer func() {
_ = credentialsKeyFile.Close()
err := os.Remove(credentialsKeyFile.Name())
if err != nil {
t.Fatalf("Removing temporary file: %s", err)
Expand Down Expand Up @@ -693,6 +690,28 @@ func TestNoAuth(t *testing.T) {
}
}

func TestNoAuthWithConfig(t *testing.T) {
for _, test := range []struct {
desc string
}{
{
desc: "valid_case",
},
} {
t.Run(test.desc, func(t *testing.T) {
setTemporaryHome(t) // Get the default authentication client and ensure that it's not nil
authClient, err := NoAuth(&config.Configuration{HTTPClient: http.DefaultClient})
if err != nil {
t.Fatalf("Test returned error on valid test case: %v", err)
}

if authClient == nil {
t.Fatalf("Client returned is nil for valid test case")
}
})
}
}

func TestGetServiceAccountEmail(t *testing.T) {
for _, test := range []struct {
description string
Expand Down
54 changes: 32 additions & 22 deletions core/clients/key_flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ const (

// KeyFlow handles auth with SA key
type KeyFlow struct {
client *http.Client
rt http.RoundTripper
authClient *http.Client
config *KeyFlowConfig
doer func(req *http.Request) (resp *http.Response, err error)
key *ServiceAccountKeyResponse
privateKey *rsa.PrivateKey
privateKeyPEM []byte
Expand All @@ -53,6 +53,8 @@ type KeyFlowConfig struct {
ClientRetry *RetryConfig
TokenUrl string
BackgroundTokenRefreshContext context.Context // Functionality is enabled if this isn't nil
HTTPTransport http.RoundTripper
AuthHTTPClient *http.Client
}

// TokenResponseBody is the API response
Expand Down Expand Up @@ -124,7 +126,18 @@ func (c *KeyFlow) Init(cfg *KeyFlowConfig) error {
if c.config.TokenUrl == "" {
c.config.TokenUrl = tokenAPI
}
c.configureHTTPClient()

if c.rt = cfg.HTTPTransport; c.rt == nil {
c.rt = http.DefaultTransport
}

if c.authClient = cfg.AuthHTTPClient; cfg.AuthHTTPClient == nil {
c.authClient = &http.Client{
Transport: c.rt,
Timeout: DefaultClientTimeout,
}
}

err := c.validate()
if err != nil {
return err
Expand Down Expand Up @@ -163,7 +176,7 @@ func (c *KeyFlow) SetToken(accessToken, refreshToken string) error {

// Roundtrip performs the request
func (c *KeyFlow) RoundTrip(req *http.Request) (*http.Response, error) {
if c.client == nil {
if c.rt == nil {
return nil, fmt.Errorf("please run Init()")
}

Expand All @@ -172,17 +185,21 @@ func (c *KeyFlow) RoundTrip(req *http.Request) (*http.Response, error) {
return nil, err
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken))
return c.doer(req)
return c.rt.RoundTrip(req)
}

// GetAccessToken returns a short-lived access token and saves the access and refresh tokens in the token field
func (c *KeyFlow) GetAccessToken() (string, error) {
if c.client == nil {
return "", fmt.Errorf("nil http client, please run Init()")
if c.rt == nil {
return "", fmt.Errorf("nil http round tripper, please run Init()")
}

var accessToken string

c.tokenMutex.RLock()
accessToken := c.token.AccessToken
if c.token != nil {
accessToken = c.token.AccessToken
}
c.tokenMutex.RUnlock()

accessTokenExpired, err := tokenExpired(accessToken)
Expand All @@ -203,14 +220,6 @@ func (c *KeyFlow) GetAccessToken() (string, error) {
return accessToken, nil
}

// configureHTTPClient configures the HTTP client
func (c *KeyFlow) configureHTTPClient() {
client := &http.Client{}
client.Timeout = DefaultClientTimeout
c.client = client
c.doer = c.client.Do
}

// validate the client is configured well
func (c *KeyFlow) validate() error {
if c.config.ServiceAccountKey == nil {
Expand Down Expand Up @@ -242,8 +251,12 @@ func (c *KeyFlow) validate() error {
// recreateAccessToken is used to create a new access token
// when the existing one isn't valid anymore
func (c *KeyFlow) recreateAccessToken() error {
var refreshToken string

c.tokenMutex.RLock()
refreshToken := c.token.RefreshToken
if c.token != nil {
refreshToken = c.token.RefreshToken
}
c.tokenMutex.RUnlock()

refreshTokenExpired, err := tokenExpired(refreshToken)
Expand Down Expand Up @@ -279,10 +292,6 @@ func (c *KeyFlow) createAccessToken() (err error) {
// createAccessTokenWithRefreshToken creates an access token using
// an existing pre-validated refresh token
func (c *KeyFlow) createAccessTokenWithRefreshToken() (err error) {
if c.client == nil {
return fmt.Errorf("nil http client, please run Init()")
}

c.tokenMutex.RLock()
refreshToken := c.token.RefreshToken
c.tokenMutex.RUnlock()
Expand Down Expand Up @@ -334,7 +343,8 @@ func (c *KeyFlow) requestToken(grant, assertion string) (*http.Response, error)
return nil, err
}
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
return c.doer(req)

return c.authClient.Do(req)
}

// parseTokenResponse parses the response from the server
Expand Down
5 changes: 4 additions & 1 deletion core/clients/key_flow_continuous_refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,12 @@ func (refresher *continuousTokenRefresher) continuousRefreshToken() error {
// Compute timestamp where we'll refresh token
// Access token may be empty at this point, we have to check it
var startRefreshTimestamp time.Time
var accessToken string

refresher.keyFlow.tokenMutex.RLock()
accessToken := refresher.keyFlow.token.AccessToken
if refresher.keyFlow.token != nil {
accessToken = refresher.keyFlow.token.AccessToken
}
refresher.keyFlow.tokenMutex.RUnlock()
if accessToken == "" {
startRefreshTimestamp = time.Now()
Expand Down
11 changes: 7 additions & 4 deletions core/clients/key_flow_continuous_refresh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,9 @@ func TestContinuousRefreshToken(t *testing.T) {
config: &KeyFlowConfig{
BackgroundTokenRefreshContext: ctx,
},
client: &http.Client{},
doer: mockDo,
authClient: &http.Client{
Transport: mockTransportFn{mockDo},
},
token: &TokenResponseBody{
AccessToken: accessToken,
RefreshToken: refreshToken,
Expand Down Expand Up @@ -328,11 +329,13 @@ func TestContinuousRefreshTokenConcurrency(t *testing.T) {
}

keyFlow := &KeyFlow{
client: &http.Client{},
config: &KeyFlowConfig{
BackgroundTokenRefreshContext: ctx,
},
doer: mockDo,
authClient: &http.Client{
Transport: mockTransportFn{mockDo},
},
rt: mockTransportFn{mockDo},
token: &TokenResponseBody{
AccessToken: accessTokenFirst,
RefreshToken: refreshToken,
Expand Down
Loading
Loading