Skip to content

Commit 17e6791

Browse files
committed
correct use of keyflow.init() and refactoring
1 parent a406031 commit 17e6791

File tree

2 files changed

+102
-87
lines changed

2 files changed

+102
-87
lines changed

core/clients/key_flow_continuous_refresh_test.go

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -133,17 +133,22 @@ func TestContinuousRefreshToken(t *testing.T) {
133133
ctx, cancel := context.WithTimeout(ctx, tt.contextClosesIn)
134134
defer cancel()
135135

136-
keyFlow := &KeyFlow{
137-
config: &KeyFlowConfig{
138-
BackgroundTokenRefreshContext: ctx,
139-
},
140-
authClient: &http.Client{
136+
keyFlow := &KeyFlow{}
137+
keyFlowConfig := &KeyFlowConfig{
138+
BackgroundTokenRefreshContext: ctx,
139+
AuthHTTPClient: &http.Client{
141140
Transport: mockTransportFn{mockDo},
142141
},
143-
token: &TokenResponseBody{
144-
AccessToken: accessToken,
145-
RefreshToken: refreshToken,
146-
},
142+
}
143+
err = keyFlow.Init(keyFlowConfig)
144+
if err != nil {
145+
t.Fatalf("failed to initialize key flow: %v", err)
146+
}
147+
148+
// Set the token after initialization
149+
err = keyFlow.SetToken(accessToken, refreshToken)
150+
if err != nil {
151+
t.Fatalf("failed to set token: %v", err)
147152
}
148153

149154
refresher := &continuousTokenRefresher{
@@ -328,18 +333,23 @@ func TestContinuousRefreshTokenConcurrency(t *testing.T) {
328333
}
329334
}
330335

331-
keyFlow := &KeyFlow{
332-
config: &KeyFlowConfig{
333-
BackgroundTokenRefreshContext: ctx,
334-
},
335-
authClient: &http.Client{
336+
keyFlow := &KeyFlow{}
337+
keyFlowConfig := &KeyFlowConfig{
338+
BackgroundTokenRefreshContext: ctx,
339+
AuthHTTPClient: &http.Client{
336340
Transport: mockTransportFn{mockDo},
337341
},
338-
rt: mockTransportFn{mockDo},
339-
token: &TokenResponseBody{
340-
AccessToken: accessTokenFirst,
341-
RefreshToken: refreshToken,
342-
},
342+
HTTPTransport: mockTransportFn{mockDo},
343+
}
344+
err = keyFlow.Init(keyFlowConfig)
345+
if err != nil {
346+
t.Fatalf("failed to initialize key flow: %v", err)
347+
}
348+
349+
// Set the token after initialization
350+
err = keyFlow.SetToken(accessTokenFirst, refreshToken)
351+
if err != nil {
352+
t.Fatalf("failed to set token: %v", err)
343353
}
344354

345355
// TEST START

core/clients/key_flow_test.go

Lines changed: 73 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -105,25 +105,25 @@ func TestKeyFlowInit(t *testing.T) {
105105
}
106106
for _, tt := range tests {
107107
t.Run(tt.name, func(t *testing.T) {
108-
c := &KeyFlow{}
109-
cfg := &KeyFlowConfig{}
108+
keyFlow := &KeyFlow{}
109+
keyFlowConfig := &KeyFlowConfig{}
110110
t.Setenv("STACKIT_SERVICE_ACCOUNT_KEY", "")
111111
if tt.genPrivateKey {
112112
privateKeyBytes, err := generatePrivateKey()
113113
if err != nil {
114114
t.Fatalf("Error generating private key: %s", err)
115115
}
116-
cfg.PrivateKey = string(privateKeyBytes)
116+
keyFlowConfig.PrivateKey = string(privateKeyBytes)
117117
}
118118
if tt.invalidPrivateKey {
119-
cfg.PrivateKey = "invalid_key"
119+
keyFlowConfig.PrivateKey = "invalid_key"
120120
}
121121

122-
cfg.ServiceAccountKey = tt.serviceAccountKey
123-
if err := c.Init(cfg); (err != nil) != tt.wantErr {
122+
keyFlowConfig.ServiceAccountKey = tt.serviceAccountKey
123+
if err := keyFlow.Init(keyFlowConfig); (err != nil) != tt.wantErr {
124124
t.Errorf("KeyFlow.Init() error = %v, wantErr %v", err, tt.wantErr)
125125
}
126-
if c.config == nil {
126+
if keyFlow.config == nil {
127127
t.Error("config is nil")
128128
}
129129
})
@@ -167,8 +167,8 @@ func TestSetToken(t *testing.T) {
167167
}
168168
}
169169

170-
c := &KeyFlow{}
171-
err = c.SetToken(accessToken, tt.refreshToken)
170+
keyFlow := &KeyFlow{}
171+
err = keyFlow.SetToken(accessToken, tt.refreshToken)
172172

173173
if (err != nil) != tt.wantErr {
174174
t.Errorf("KeyFlow.SetToken() error = %v, wantErr %v", err, tt.wantErr)
@@ -181,8 +181,8 @@ func TestSetToken(t *testing.T) {
181181
Scope: defaultScope,
182182
TokenType: defaultTokenType,
183183
}
184-
if !cmp.Equal(expectedKeyFlowToken, c.token) {
185-
t.Errorf("The returned result is wrong. Expected %+v, got %+v", expectedKeyFlowToken, c.token)
184+
if !cmp.Equal(expectedKeyFlowToken, keyFlow.token) {
185+
t.Errorf("The returned result is wrong. Expected %+v, got %+v", expectedKeyFlowToken, keyFlow.token)
186186
}
187187
}
188188
})
@@ -282,17 +282,21 @@ func TestRequestToken(t *testing.T) {
282282

283283
for _, tt := range testCases {
284284
t.Run(tt.name, func(t *testing.T) {
285-
c := &KeyFlow{
286-
authClient: &http.Client{
285+
keyFlow := &KeyFlow{}
286+
keyFlowConfig := &KeyFlowConfig{
287+
AuthHTTPClient: &http.Client{
287288
Transport: mockTransportFn{func(_ *http.Request) (*http.Response, error) {
288289
return tt.mockResponse, tt.mockError
289290
}},
290291
},
291-
config: &KeyFlowConfig{},
292-
rt: http.DefaultTransport,
292+
HTTPTransport: http.DefaultTransport,
293+
}
294+
err := keyFlow.Init(keyFlowConfig)
295+
if err != nil {
296+
t.Fatalf("failed to initialize key flow: %v", err)
293297
}
294298

295-
res, err := c.requestToken(tt.grant, tt.assertion)
299+
res, err := keyFlow.requestToken(tt.grant, tt.assertion)
296300
defer func() {
297301
if res != nil {
298302
tempErr := res.Body.Close()
@@ -324,14 +328,12 @@ func TestKeyFlow_Do(t *testing.T) {
324328

325329
tests := []struct {
326330
name string
327-
keyFlow *KeyFlow
328331
handlerFn func(tb testing.TB) http.HandlerFunc
329332
want int
330333
wantErr bool
331334
}{
332335
{
333-
name: "success",
334-
keyFlow: &KeyFlow{rt: http.DefaultTransport, config: &KeyFlowConfig{}},
336+
name: "success",
335337
handlerFn: func(tb testing.TB) http.HandlerFunc {
336338
tb.Helper()
337339

@@ -349,8 +351,7 @@ func TestKeyFlow_Do(t *testing.T) {
349351
wantErr: false,
350352
},
351353
{
352-
name: "success with code 500",
353-
keyFlow: &KeyFlow{rt: http.DefaultTransport, config: &KeyFlowConfig{}},
354+
name: "success with code 500",
354355
handlerFn: func(_ testing.TB) http.HandlerFunc {
355356
return func(w http.ResponseWriter, _ *http.Request) {
356357
w.Header().Set("Content-Type", "text/html")
@@ -363,16 +364,6 @@ func TestKeyFlow_Do(t *testing.T) {
363364
},
364365
{
365366
name: "success with custom transport",
366-
keyFlow: &KeyFlow{
367-
rt: mockTransportFn{
368-
fn: func(req *http.Request) (*http.Response, error) {
369-
req.Header.Set("User-Agent", "custom_transport")
370-
371-
return http.DefaultTransport.RoundTrip(req)
372-
},
373-
},
374-
config: &KeyFlowConfig{},
375-
},
376367
handlerFn: func(tb testing.TB) http.HandlerFunc {
377368
tb.Helper()
378369

@@ -391,14 +382,6 @@ func TestKeyFlow_Do(t *testing.T) {
391382
},
392383
{
393384
name: "fail with custom proxy",
394-
keyFlow: &KeyFlow{
395-
rt: &http.Transport{
396-
Proxy: func(_ *http.Request) (*url.URL, error) {
397-
return nil, fmt.Errorf("proxy error")
398-
},
399-
},
400-
config: &KeyFlowConfig{},
401-
},
402385
handlerFn: func(testing.TB) http.HandlerFunc {
403386
return func(w http.ResponseWriter, _ *http.Request) {
404387
w.Header().Set("Content-Type", "application/json")
@@ -421,37 +404,59 @@ func TestKeyFlow_Do(t *testing.T) {
421404
t.Errorf("no error is expected, but got %v", err)
422405
}
423406

424-
tt.keyFlow.config.ServiceAccountKey = fixtureServiceAccountKey()
425-
tt.keyFlow.config.PrivateKey = string(privateKeyBytes)
426-
tt.keyFlow.config.BackgroundTokenRefreshContext = ctx
427-
tt.keyFlow.authClient = &http.Client{
428-
Transport: mockTransportFn{
429-
fn: func(_ *http.Request) (*http.Response, error) {
430-
res := httptest.NewRecorder()
431-
res.WriteHeader(http.StatusOK)
432-
res.Header().Set("Content-Type", "application/json")
433-
434-
token := &TokenResponseBody{
435-
AccessToken: testBearerToken,
436-
ExpiresIn: 2147483647,
437-
RefreshToken: testBearerToken,
438-
TokenType: "Bearer",
407+
keyFlow := &KeyFlow{}
408+
keyFlowConfig := &KeyFlowConfig{
409+
ServiceAccountKey: fixtureServiceAccountKey(),
410+
PrivateKey: string(privateKeyBytes),
411+
BackgroundTokenRefreshContext: ctx,
412+
HTTPTransport: func() http.RoundTripper {
413+
switch tt.name {
414+
case "success with custom transport":
415+
return mockTransportFn{
416+
fn: func(req *http.Request) (*http.Response, error) {
417+
req.Header.Set("User-Agent", "custom_transport")
418+
return http.DefaultTransport.RoundTrip(req)
419+
},
439420
}
440-
441-
if err := json.NewEncoder(res.Body).Encode(token); err != nil {
442-
t.Logf("no error is expected, but got %v", err)
421+
case "fail with custom proxy":
422+
return &http.Transport{
423+
Proxy: func(_ *http.Request) (*url.URL, error) {
424+
return nil, fmt.Errorf("proxy error")
425+
},
443426
}
444-
445-
return res.Result(), nil
427+
default:
428+
return http.DefaultTransport
429+
}
430+
}(),
431+
AuthHTTPClient: &http.Client{
432+
Transport: mockTransportFn{
433+
fn: func(_ *http.Request) (*http.Response, error) {
434+
res := httptest.NewRecorder()
435+
res.WriteHeader(http.StatusOK)
436+
res.Header().Set("Content-Type", "application/json")
437+
438+
token := &TokenResponseBody{
439+
AccessToken: testBearerToken,
440+
ExpiresIn: 2147483647,
441+
RefreshToken: testBearerToken,
442+
TokenType: "Bearer",
443+
}
444+
445+
if err := json.NewEncoder(res.Body).Encode(token); err != nil {
446+
t.Logf("no error is expected, but got %v", err)
447+
}
448+
449+
return res.Result(), nil
450+
},
446451
},
447452
},
448453
}
449-
450-
if err := tt.keyFlow.validate(); err != nil {
451-
t.Errorf("no error is expected, but got %v", err)
454+
err = keyFlow.Init(keyFlowConfig)
455+
if err != nil {
456+
t.Fatalf("failed to initialize key flow: %v", err)
452457
}
453458

454-
go continuousRefreshToken(tt.keyFlow)
459+
go continuousRefreshToken(keyFlow)
455460

456461
tokenCtx, tokenCancel := context.WithTimeout(context.Background(), 1*time.Second)
457462

@@ -461,14 +466,14 @@ func TestKeyFlow_Do(t *testing.T) {
461466
case <-tokenCtx.Done():
462467
t.Error(tokenCtx.Err())
463468
case <-time.After(50 * time.Millisecond):
464-
tt.keyFlow.tokenMutex.RLock()
465-
if tt.keyFlow.token != nil {
466-
tt.keyFlow.tokenMutex.RUnlock()
469+
keyFlow.tokenMutex.RLock()
470+
if keyFlow.token != nil {
471+
keyFlow.tokenMutex.RUnlock()
467472
tokenCancel()
468473
break token
469474
}
470475

471-
tt.keyFlow.tokenMutex.RUnlock()
476+
keyFlow.tokenMutex.RUnlock()
472477
}
473478
}
474479

@@ -486,7 +491,7 @@ func TestKeyFlow_Do(t *testing.T) {
486491
}
487492

488493
httpClient := &http.Client{
489-
Transport: tt.keyFlow,
494+
Transport: keyFlow,
490495
}
491496

492497
res, err := httpClient.Do(req)

0 commit comments

Comments
 (0)