Skip to content

Commit cc796ce

Browse files
committed
Respect transport of custom http client
Signed-off-by: Jan-Otto Kröpke <[email protected]>
1 parent d8bffae commit cc796ce

File tree

9 files changed

+107
-73
lines changed

9 files changed

+107
-73
lines changed

core/auth/auth.go

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ func SetupAuth(cfg *config.Configuration) (rt http.RoundTripper, err error) {
4545
if cfg.CustomAuth != nil {
4646
return cfg.CustomAuth, nil
4747
} else if cfg.NoAuth {
48-
noAuthRoundTripper, err := NoAuth()
48+
noAuthRoundTripper, err := NoAuth(cfg)
4949
if err != nil {
5050
return nil, fmt.Errorf("configuring no auth client: %w", err)
5151
}
@@ -98,10 +98,17 @@ func DefaultAuth(cfg *config.Configuration) (rt http.RoundTripper, err error) {
9898

9999
// NoAuth configures a flow without authentication and returns an http.RoundTripper
100100
// that can be used to make unauthenticated requests
101-
func NoAuth() (rt http.RoundTripper, err error) {
101+
func NoAuth(cfg *config.Configuration) (rt http.RoundTripper, err error) {
102102
noAuthConfig := clients.NoAuthFlowConfig{}
103103
noAuthRoundTripper := &clients.NoAuthFlow{}
104-
if err := noAuthRoundTripper.Init(noAuthConfig); err != nil {
104+
105+
if cfg.HTTPClient == nil {
106+
cfg.HTTPClient = &http.Client{
107+
Timeout: clients.DefaultClientTimeout,
108+
}
109+
}
110+
111+
if err := noAuthRoundTripper.Init(noAuthConfig, cfg.HTTPClient.Transport); err != nil {
105112
return nil, fmt.Errorf("initializing client: %w", err)
106113
}
107114
return noAuthRoundTripper, nil
@@ -130,8 +137,14 @@ func TokenAuth(cfg *config.Configuration) (http.RoundTripper, error) {
130137
ServiceAccountToken: cfg.Token,
131138
}
132139

140+
if cfg.HTTPClient == nil {
141+
cfg.HTTPClient = &http.Client{
142+
Timeout: clients.DefaultClientTimeout,
143+
}
144+
}
145+
133146
client := &clients.TokenFlow{}
134-
if err := client.Init(&tokenCfg); err != nil {
147+
if err := client.Init(&tokenCfg, cfg.HTTPClient.Transport); err != nil {
135148
return nil, fmt.Errorf("error initializing client: %w", err)
136149
}
137150

@@ -187,8 +200,14 @@ func KeyAuth(cfg *config.Configuration) (http.RoundTripper, error) {
187200
BackgroundTokenRefreshContext: cfg.BackgroundTokenRefreshContext,
188201
}
189202

203+
if cfg.HTTPClient == nil {
204+
cfg.HTTPClient = &http.Client{
205+
Timeout: clients.DefaultClientTimeout,
206+
}
207+
}
208+
190209
client := &clients.KeyFlow{}
191-
if err := client.Init(&keyCfg); err != nil {
210+
if err := client.Init(&keyCfg, cfg.HTTPClient.Transport); err != nil {
192211
return nil, fmt.Errorf("error initializing client: %w", err)
193212
}
194213

core/auth/auth_test.go

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"crypto/x509"
77
"encoding/json"
88
"encoding/pem"
9+
"net/http"
910
"os"
1011
"reflect"
1112
"testing"
@@ -125,6 +126,7 @@ func TestSetupAuth(t *testing.T) {
125126
t.Fatalf("Creating temporary file: %s", err)
126127
}
127128
defer func() {
129+
_ = credentialsKeyFile.Close()
128130
err := os.Remove(credentialsKeyFile.Name())
129131
if err != nil {
130132
t.Fatalf("Removing temporary file: %s", err)
@@ -361,6 +363,7 @@ func TestDefaultAuth(t *testing.T) {
361363
t.Fatalf("Creating temporary file: %s", err)
362364
}
363365
defer func() {
366+
_ = saKeyFile.Close()
364367
err := os.Remove(saKeyFile.Name())
365368
if err != nil {
366369
t.Fatalf("Removing temporary file: %s", err)
@@ -377,19 +380,13 @@ func TestDefaultAuth(t *testing.T) {
377380
t.Fatalf("Writing private key to temporary file: %s", err)
378381
}
379382

380-
defer func() {
381-
err := saKeyFile.Close()
382-
if err != nil {
383-
t.Fatalf("Removing temporary file: %s", err)
384-
}
385-
}()
386-
387383
// create a credentials file with saKey and private key
388384
credentialsKeyFile, errs := os.CreateTemp("", "temp-*.txt")
389385
if errs != nil {
390386
t.Fatalf("Creating temporary file: %s", err)
391387
}
392388
defer func() {
389+
_ = credentialsKeyFile.Close()
393390
err := os.Remove(credentialsKeyFile.Name())
394391
if err != nil {
395392
t.Fatalf("Removing temporary file: %s", err)
@@ -681,7 +678,7 @@ func TestNoAuth(t *testing.T) {
681678
} {
682679
t.Run(test.desc, func(t *testing.T) {
683680
setTemporaryHome(t) // Get the default authentication client and ensure that it's not nil
684-
authClient, err := NoAuth()
681+
authClient, err := NoAuth(&config.Configuration{HTTPClient: http.DefaultClient})
685682
if err != nil {
686683
t.Fatalf("Test returned error on valid test case: %v", err)
687684
}

core/clients/key_flow.go

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ const (
3434

3535
// KeyFlow handles auth with SA key
3636
type KeyFlow struct {
37-
client *http.Client
37+
rt http.RoundTripper
38+
authClient *http.Client
3839
config *KeyFlowConfig
39-
doer func(req *http.Request) (resp *http.Response, err error)
4040
key *ServiceAccountKeyResponse
4141
privateKey *rsa.PrivateKey
4242
privateKeyPEM []byte
@@ -116,15 +116,25 @@ func (c *KeyFlow) GetToken() TokenResponseBody {
116116
return *c.token
117117
}
118118

119-
func (c *KeyFlow) Init(cfg *KeyFlowConfig) error {
119+
func (c *KeyFlow) Init(cfg *KeyFlowConfig, rt http.RoundTripper) error {
120120
// No concurrency at this point, so no mutex check needed
121121
c.token = &TokenResponseBody{}
122122
c.config = cfg
123123

124124
if c.config.TokenUrl == "" {
125125
c.config.TokenUrl = tokenAPI
126126
}
127-
c.configureHTTPClient()
127+
128+
if rt == nil {
129+
rt = http.DefaultTransport
130+
}
131+
132+
c.rt = rt
133+
c.authClient = &http.Client{
134+
Transport: rt,
135+
Timeout: DefaultClientTimeout,
136+
}
137+
128138
err := c.validate()
129139
if err != nil {
130140
return err
@@ -163,7 +173,7 @@ func (c *KeyFlow) SetToken(accessToken, refreshToken string) error {
163173

164174
// Roundtrip performs the request
165175
func (c *KeyFlow) RoundTrip(req *http.Request) (*http.Response, error) {
166-
if c.client == nil {
176+
if c.rt == nil {
167177
return nil, fmt.Errorf("please run Init()")
168178
}
169179

@@ -172,13 +182,13 @@ func (c *KeyFlow) RoundTrip(req *http.Request) (*http.Response, error) {
172182
return nil, err
173183
}
174184
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken))
175-
return c.doer(req)
185+
return c.rt.RoundTrip(req)
176186
}
177187

178188
// GetAccessToken returns a short-lived access token and saves the access and refresh tokens in the token field
179189
func (c *KeyFlow) GetAccessToken() (string, error) {
180-
if c.client == nil {
181-
return "", fmt.Errorf("nil http client, please run Init()")
190+
if c.rt == nil {
191+
return "", fmt.Errorf("nil http round tripper, please run Init()")
182192
}
183193

184194
c.tokenMutex.RLock()
@@ -203,14 +213,6 @@ func (c *KeyFlow) GetAccessToken() (string, error) {
203213
return accessToken, nil
204214
}
205215

206-
// configureHTTPClient configures the HTTP client
207-
func (c *KeyFlow) configureHTTPClient() {
208-
client := &http.Client{}
209-
client.Timeout = DefaultClientTimeout
210-
c.client = client
211-
c.doer = c.client.Do
212-
}
213-
214216
// validate the client is configured well
215217
func (c *KeyFlow) validate() error {
216218
if c.config.ServiceAccountKey == nil {
@@ -279,10 +281,6 @@ func (c *KeyFlow) createAccessToken() (err error) {
279281
// createAccessTokenWithRefreshToken creates an access token using
280282
// an existing pre-validated refresh token
281283
func (c *KeyFlow) createAccessTokenWithRefreshToken() (err error) {
282-
if c.client == nil {
283-
return fmt.Errorf("nil http client, please run Init()")
284-
}
285-
286284
c.tokenMutex.RLock()
287285
refreshToken := c.token.RefreshToken
288286
c.tokenMutex.RUnlock()
@@ -334,7 +332,8 @@ func (c *KeyFlow) requestToken(grant, assertion string) (*http.Response, error)
334332
return nil, err
335333
}
336334
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
337-
return c.doer(req)
335+
336+
return c.authClient.Do(req)
338337
}
339338

340339
// parseTokenResponse parses the response from the server

core/clients/key_flow_continuous_refresh_test.go

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,9 @@ func TestContinuousRefreshToken(t *testing.T) {
137137
config: &KeyFlowConfig{
138138
BackgroundTokenRefreshContext: ctx,
139139
},
140-
client: &http.Client{},
141-
doer: mockDo,
140+
authClient: &http.Client{
141+
Transport: mockTransportFn{mockDo},
142+
},
142143
token: &TokenResponseBody{
143144
AccessToken: accessToken,
144145
RefreshToken: refreshToken,
@@ -328,11 +329,13 @@ func TestContinuousRefreshTokenConcurrency(t *testing.T) {
328329
}
329330

330331
keyFlow := &KeyFlow{
331-
client: &http.Client{},
332332
config: &KeyFlowConfig{
333333
BackgroundTokenRefreshContext: ctx,
334334
},
335-
doer: mockDo,
335+
authClient: &http.Client{
336+
Transport: mockTransportFn{mockDo},
337+
},
338+
rt: mockTransportFn{mockDo},
336339
token: &TokenResponseBody{
337340
AccessToken: accessTokenFirst,
338341
RefreshToken: refreshToken,

core/clients/key_flow_test.go

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"crypto/rsa"
66
"crypto/x509"
77
"encoding/pem"
8+
"errors"
89
"fmt"
910
"io"
1011
"net/http"
@@ -113,7 +114,7 @@ func TestKeyFlowInit(t *testing.T) {
113114
}
114115

115116
cfg.ServiceAccountKey = tt.serviceAccountKey
116-
if err := c.Init(cfg); (err != nil) != tt.wantErr {
117+
if err := c.Init(cfg, http.DefaultTransport); (err != nil) != tt.wantErr {
117118
t.Errorf("KeyFlow.Init() error = %v, wantErr %v", err, tt.wantErr)
118119
}
119120
if c.config == nil {
@@ -268,13 +269,14 @@ func TestRequestToken(t *testing.T) {
268269

269270
for _, tt := range testCases {
270271
t.Run(tt.name, func(t *testing.T) {
271-
mockDo := func(_ *http.Request) (resp *http.Response, err error) {
272-
return tt.mockResponse, tt.mockError
273-
}
274-
275272
c := &KeyFlow{
273+
authClient: &http.Client{
274+
Transport: mockTransportFn{func(_ *http.Request) (*http.Response, error) {
275+
return tt.mockResponse, tt.mockError
276+
}},
277+
},
276278
config: &KeyFlowConfig{},
277-
doer: mockDo,
279+
rt: http.DefaultTransport,
278280
}
279281

280282
res, err := c.requestToken(tt.grant, tt.assertion)
@@ -289,7 +291,7 @@ func TestRequestToken(t *testing.T) {
289291
if tt.expectedError != nil {
290292
if err == nil {
291293
t.Errorf("Expected error '%v' but no error was returned", tt.expectedError)
292-
} else if tt.expectedError.Error() != err.Error() {
294+
} else if errors.Is(err, tt.expectedError) {
293295
t.Errorf("Error is not correct. Expected %v, got %v", tt.expectedError, err)
294296
}
295297
} else {
@@ -303,3 +305,11 @@ func TestRequestToken(t *testing.T) {
303305
})
304306
}
305307
}
308+
309+
type mockTransportFn struct {
310+
fn func(req *http.Request) (*http.Response, error)
311+
}
312+
313+
func (m mockTransportFn) RoundTrip(req *http.Request) (*http.Response, error) {
314+
return m.fn(req)
315+
}

core/clients/no_auth_flow.go

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import (
66
)
77

88
type NoAuthFlow struct {
9-
client *http.Client
9+
rt http.RoundTripper
1010
config *NoAuthFlowConfig
1111
}
1212

@@ -24,18 +24,23 @@ func (c *NoAuthFlow) GetConfig() NoAuthFlowConfig {
2424
return *c.config
2525
}
2626

27-
func (c *NoAuthFlow) Init(_ NoAuthFlowConfig) error {
27+
func (c *NoAuthFlow) Init(_ NoAuthFlowConfig, rt http.RoundTripper) error {
2828
c.config = &NoAuthFlowConfig{}
29-
c.client = &http.Client{
30-
Timeout: DefaultClientTimeout,
29+
30+
if rt == nil {
31+
rt = http.DefaultTransport
3132
}
33+
34+
c.rt = rt
35+
3236
return nil
3337
}
3438

3539
// Roundtrip performs the request
3640
func (c *NoAuthFlow) RoundTrip(req *http.Request) (*http.Response, error) {
37-
if c.client == nil {
41+
if c.rt == nil {
3842
return nil, fmt.Errorf("please run Init()")
3943
}
40-
return c.client.Do(req)
44+
45+
return c.rt.RoundTrip(req)
4146
}

core/clients/no_auth_flow_test.go

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ func TestNoAuthFlow_Init(t *testing.T) {
2525
for _, tt := range tests {
2626
t.Run(tt.name, func(t *testing.T) {
2727
c := &NoAuthFlow{}
28-
if err := c.Init(tt.args.cfg); (err != nil) != tt.wantErr {
28+
if err := c.Init(tt.args.cfg, http.DefaultTransport); (err != nil) != tt.wantErr {
2929
t.Errorf("NoAuthFlow.Init() error = %v, wantErr %v", err, tt.wantErr)
3030
}
3131
})
@@ -34,7 +34,7 @@ func TestNoAuthFlow_Init(t *testing.T) {
3434

3535
func TestNoAuthFlow_Do(t *testing.T) {
3636
type fields struct {
37-
client *http.Client
37+
rt http.RoundTripper
3838
}
3939
type args struct{}
4040
tests := []struct {
@@ -45,16 +45,18 @@ func TestNoAuthFlow_Do(t *testing.T) {
4545
wantErr bool
4646
}{
4747
{
48-
name: "fail",
49-
fields: fields{nil},
48+
name: "fail",
49+
fields: fields{
50+
nil,
51+
},
5052
args: args{},
5153
want: 0,
5254
wantErr: true,
5355
},
5456
{
5557
name: "success",
5658
fields: fields{
57-
&http.Client{},
59+
http.DefaultTransport,
5860
},
5961
args: args{},
6062
want: http.StatusOK,
@@ -64,7 +66,7 @@ func TestNoAuthFlow_Do(t *testing.T) {
6466
for _, tt := range tests {
6567
t.Run(tt.name, func(t *testing.T) {
6668
c := &NoAuthFlow{
67-
client: tt.fields.client,
69+
rt: tt.fields.rt,
6870
}
6971
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
7072
w.Header().Set("Content-Type", "application/json")

0 commit comments

Comments
 (0)