Skip to content

Commit c9e8892

Browse files
committed
Refactor auth flows to support storage contexts
Update authentication flows to support multiple storage contexts, enabling context-aware token management and refresh. Key changes: - Add *WithContext() variants for auth functions - Update user login flow to accept storage context parameter - Store access token expiry (JWT exp claim) instead of session expiry - Update token refresh to write tokens back to correct context - Add getAccessTokenExpiresAtUnix() to parse JWT exp claim - Update tests to use new context-aware functions This enables proper token refresh and bidirectional sync for both CLI and API authentication contexts.
1 parent 7b4f43a commit c9e8892

File tree

6 files changed

+173
-30
lines changed

6 files changed

+173
-30
lines changed

internal/pkg/auth/auth.go

Lines changed: 100 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
package auth
22

33
import (
4+
"bytes"
45
"fmt"
6+
"io"
57
"net/http"
68
"os"
79
"strconv"
@@ -25,7 +27,10 @@ type tokenClaims struct {
2527
//
2628
// If the user was logged in and the user session expired, reauthorizeUserRoutine is called to reauthenticate the user again.
2729
// If the environment variable STACKIT_ACCESS_TOKEN is set this token is used instead.
28-
func AuthenticationConfig(p *print.Printer, reauthorizeUserRoutine func(p *print.Printer, _ bool) error) (authCfgOption sdkConfig.ConfigurationOption, err error) {
30+
func AuthenticationConfig(p *print.Printer, reauthorizeUserRoutine func(p *print.Printer, context StorageContext, _ bool) error) (authCfgOption sdkConfig.ConfigurationOption, err error) {
31+
// Set the storage printer so debug messages use the correct verbosity
32+
SetStoragePrinter(p)
33+
2934
// Get access token from env and use this if present
3035
accessToken := os.Getenv(envAccessTokenName)
3136
if accessToken != "" {
@@ -70,7 +75,7 @@ func AuthenticationConfig(p *print.Printer, reauthorizeUserRoutine func(p *print
7075
case AUTH_FLOW_USER_TOKEN:
7176
p.Debug(print.DebugLevel, "authenticating using user token")
7277
if userSessionExpired {
73-
err = reauthorizeUserRoutine(p, true)
78+
err = reauthorizeUserRoutine(p, StorageContextCLI, true)
7479
if err != nil {
7580
return nil, fmt.Errorf("user login: %w", err)
7681
}
@@ -84,7 +89,11 @@ func AuthenticationConfig(p *print.Printer, reauthorizeUserRoutine func(p *print
8489
}
8590

8691
func UserSessionExpired() (bool, error) {
87-
sessionExpiresAtString, err := GetAuthField(SESSION_EXPIRES_AT_UNIX)
92+
return UserSessionExpiredWithContext(StorageContextCLI)
93+
}
94+
95+
func UserSessionExpiredWithContext(context StorageContext) (bool, error) {
96+
sessionExpiresAtString, err := GetAuthFieldWithContext(context, SESSION_EXPIRES_AT_UNIX)
8897
if err != nil {
8998
return false, fmt.Errorf("get %s: %w", SESSION_EXPIRES_AT_UNIX, err)
9099
}
@@ -98,7 +107,11 @@ func UserSessionExpired() (bool, error) {
98107
}
99108

100109
func GetAccessToken() (string, error) {
101-
accessToken, err := GetAuthField(ACCESS_TOKEN)
110+
return GetAccessTokenWithContext(StorageContextCLI)
111+
}
112+
113+
func GetAccessTokenWithContext(context StorageContext) (string, error) {
114+
accessToken, err := GetAuthFieldWithContext(context, ACCESS_TOKEN)
102115
if err != nil {
103116
return "", fmt.Errorf("get %s: %w", ACCESS_TOKEN, err)
104117
}
@@ -142,18 +155,47 @@ func getEmailFromToken(token string) (string, error) {
142155
return claims.Email, nil
143156
}
144157

158+
func getAccessTokenExpiresAtUnix(accessToken string) (string, error) {
159+
// Parse the access token to get its expiration time
160+
parsedAccessToken, _, err := jwt.NewParser().ParseUnverified(accessToken, &jwt.RegisteredClaims{})
161+
if err != nil {
162+
return "", fmt.Errorf("parse access token: %w", err)
163+
}
164+
165+
claims, ok := parsedAccessToken.Claims.(*jwt.RegisteredClaims)
166+
if !ok {
167+
return "", fmt.Errorf("get claims from parsed token: unknown claims type")
168+
}
169+
170+
if claims.ExpiresAt == nil {
171+
return "", fmt.Errorf("access token has no expiration claim")
172+
}
173+
174+
return strconv.FormatInt(claims.ExpiresAt.Unix(), 10), nil
175+
}
176+
145177
// GetValidAccessToken returns a valid access token for the current authentication flow.
146178
// For user token flows, it refreshes the token if necessary.
147179
// For service account flows, it returns the current access token.
148180
func GetValidAccessToken(p *print.Printer) (string, error) {
149-
flow, err := GetAuthFlow()
181+
return GetValidAccessTokenWithContext(p, StorageContextCLI)
182+
}
183+
184+
// GetValidAccessTokenWithContext returns a valid access token for the specified storage context.
185+
// For user token flows, it refreshes the token if necessary.
186+
// For service account flows, it returns the current access token.
187+
func GetValidAccessTokenWithContext(p *print.Printer, context StorageContext) (string, error) {
188+
// Set the storage printer so debug messages use the correct verbosity
189+
SetStoragePrinter(p)
190+
191+
flow, err := GetAuthFlowWithContext(context)
150192
if err != nil {
151193
return "", fmt.Errorf("get authentication flow: %w", err)
152194
}
153195

154196
// For service account flows, just return the current token
155197
if flow == AUTH_FLOW_SERVICE_ACCOUNT_TOKEN || flow == AUTH_FLOW_SERVICE_ACCOUNT_KEY {
156-
return GetAccessToken()
198+
return GetAccessTokenWithContext(context)
157199
}
158200

159201
if flow != AUTH_FLOW_USER_TOKEN {
@@ -166,7 +208,7 @@ func GetValidAccessToken(p *print.Printer) (string, error) {
166208
REFRESH_TOKEN: "",
167209
IDP_TOKEN_ENDPOINT: "",
168210
}
169-
err = GetAuthFieldMap(authFields)
211+
err = GetAuthFieldMapWithContext(context, authFields)
170212
if err != nil {
171213
return "", fmt.Errorf("get tokens from auth storage: %w", err)
172214
}
@@ -201,6 +243,7 @@ func GetValidAccessToken(p *print.Printer) (string, error) {
201243
utf := &userTokenFlow{
202244
printer: p,
203245
client: &http.Client{},
246+
context: context,
204247
authFlow: flow,
205248
accessToken: accessToken,
206249
refreshToken: refreshToken,
@@ -216,3 +259,53 @@ func GetValidAccessToken(p *print.Printer) (string, error) {
216259
// Return the new access token
217260
return utf.accessToken, nil
218261
}
262+
263+
// debugHTTPRequest logs the raw HTTP request details for debugging purposes
264+
func debugHTTPRequest(p *print.Printer, req *http.Request) {
265+
if p == nil || req == nil {
266+
return
267+
}
268+
269+
p.Debug(print.DebugLevel, "=== HTTP REQUEST ===")
270+
p.Debug(print.DebugLevel, "Method: %s", req.Method)
271+
p.Debug(print.DebugLevel, "URL: %s", req.URL.String())
272+
p.Debug(print.DebugLevel, "Headers:")
273+
for name, values := range req.Header {
274+
for _, value := range values {
275+
p.Debug(print.DebugLevel, " %s: %s", name, value)
276+
}
277+
}
278+
p.Debug(print.DebugLevel, "===================")
279+
}
280+
281+
// debugHTTPResponse logs the raw HTTP response details for debugging purposes
282+
func debugHTTPResponse(p *print.Printer, resp *http.Response) {
283+
if p == nil || resp == nil {
284+
return
285+
}
286+
287+
p.Debug(print.DebugLevel, "=== HTTP RESPONSE ===")
288+
p.Debug(print.DebugLevel, "Status: %s", resp.Status)
289+
p.Debug(print.DebugLevel, "Status Code: %d", resp.StatusCode)
290+
p.Debug(print.DebugLevel, "Headers:")
291+
for name, values := range resp.Header {
292+
for _, value := range values {
293+
p.Debug(print.DebugLevel, " %s: %s", name, value)
294+
}
295+
}
296+
297+
// Read and log body (need to restore it for later use)
298+
if resp.Body != nil {
299+
bodyBytes, err := io.ReadAll(resp.Body)
300+
if err != nil {
301+
p.Debug(print.ErrorLevel, "Error reading response body: %v", err)
302+
} else {
303+
// Restore the body for later use
304+
resp.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
305+
306+
// Show raw body without sanitization
307+
p.Debug(print.DebugLevel, "Body: %s", string(bodyBytes))
308+
}
309+
}
310+
p.Debug(print.DebugLevel, "====================")
311+
}

internal/pkg/auth/auth_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ func TestAuthenticationConfig(t *testing.T) {
188188
}
189189

190190
reauthorizeUserCalled := false
191-
reauthenticateUser := func(_ *print.Printer, _ bool) error {
191+
reauthenticateUser := func(_ *print.Printer, _ StorageContext, _ bool) error {
192192
if reauthorizeUserCalled {
193193
t.Errorf("user reauthorized more than once")
194194
}

internal/pkg/auth/user_login.go

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,10 @@ type apiClient interface {
5050
}
5151

5252
// AuthorizeUser implements the PKCE OAuth2 flow.
53-
func AuthorizeUser(p *print.Printer, isReauthentication bool) error {
53+
func AuthorizeUser(p *print.Printer, context StorageContext, isReauthentication bool) error {
54+
// Set the storage printer so debug messages use the correct verbosity
55+
SetStoragePrinter(p)
56+
5457
idpWellKnownConfigURL, err := getIDPWellKnownConfigURL()
5558
if err != nil {
5659
return fmt.Errorf("get IDP well-known configuration: %w", err)
@@ -65,7 +68,7 @@ func AuthorizeUser(p *print.Printer, isReauthentication bool) error {
6568

6669
p.Debug(print.DebugLevel, "get IDP well-known configuration from %s", idpWellKnownConfigURL)
6770
httpClient := &http.Client{}
68-
idpWellKnownConfig, err := parseWellKnownConfiguration(httpClient, idpWellKnownConfigURL)
71+
idpWellKnownConfig, err := parseWellKnownConfiguration(p, httpClient, idpWellKnownConfigURL, context)
6972
if err != nil {
7073
return fmt.Errorf("parse IDP well-known configuration: %w", err)
7174
}
@@ -164,29 +167,30 @@ func AuthorizeUser(p *print.Printer, isReauthentication bool) error {
164167
p.Debug(print.DebugLevel, "trading authorization code for access and refresh tokens")
165168

166169
// Trade the authorization code and the code verifier for access and refresh tokens
167-
accessToken, refreshToken, err := getUserAccessAndRefreshTokens(idpWellKnownConfig, idpClientID, codeVerifier, code, redirectURL)
170+
accessToken, refreshToken, err := getUserAccessAndRefreshTokens(p, idpWellKnownConfig, idpClientID, codeVerifier, code, redirectURL)
168171
if err != nil {
169172
errServer = fmt.Errorf("retrieve tokens: %w", err)
170173
return
171174
}
172175

173176
p.Debug(print.DebugLevel, "received response from the authentication server")
174177

175-
sessionExpiresAtUnix, err := getStartingSessionExpiresAtUnix()
178+
// Get access token expiration from the token itself (not session time limit)
179+
sessionExpiresAtUnix, err := getAccessTokenExpiresAtUnix(accessToken)
176180
if err != nil {
177-
errServer = fmt.Errorf("compute session expiration timestamp: %w", err)
181+
errServer = fmt.Errorf("get access token expiration: %w", err)
178182
return
179183
}
180184

181185
sessionExpiresAtUnixInt, err := strconv.Atoi(sessionExpiresAtUnix)
182186
if err != nil {
183-
p.Debug(print.ErrorLevel, "parse session expiration value \"%s\": %s", sessionExpiresAtUnix, err)
187+
p.Debug(print.ErrorLevel, "parse access token expiration value \"%s\": %s", sessionExpiresAtUnix, err)
184188
} else {
185189
sessionExpiresAt := time.Unix(int64(sessionExpiresAtUnixInt), 0)
186-
p.Debug(print.DebugLevel, "session expires at %s", sessionExpiresAt)
190+
p.Debug(print.DebugLevel, "access token expires at %s", sessionExpiresAt)
187191
}
188192

189-
err = SetAuthFlow(AUTH_FLOW_USER_TOKEN)
193+
err = SetAuthFlowWithContext(context, AUTH_FLOW_USER_TOKEN)
190194
if err != nil {
191195
errServer = fmt.Errorf("set auth flow type: %w", err)
192196
return
@@ -200,7 +204,7 @@ func AuthorizeUser(p *print.Printer, isReauthentication bool) error {
200204

201205
p.Debug(print.DebugLevel, "user %s logged in successfully", email)
202206

203-
err = LoginUser(email, accessToken, refreshToken, sessionExpiresAtUnix)
207+
err = LoginUserWithContext(context, email, accessToken, refreshToken, sessionExpiresAtUnix)
204208
if err != nil {
205209
errServer = fmt.Errorf("set in auth storage: %w", err)
206210
return
@@ -216,7 +220,7 @@ func AuthorizeUser(p *print.Printer, isReauthentication bool) error {
216220
mux.HandleFunc(loginSuccessPath, func(w http.ResponseWriter, _ *http.Request) {
217221
defer cleanup(server)
218222

219-
email, err := GetAuthField(USER_EMAIL)
223+
email, err := GetAuthFieldWithContext(context, USER_EMAIL)
220224
if err != nil {
221225
errServer = fmt.Errorf("read user email: %w", err)
222226
}
@@ -270,7 +274,7 @@ func AuthorizeUser(p *print.Printer, isReauthentication bool) error {
270274
}
271275

272276
// getUserAccessAndRefreshTokens trades the authorization code retrieved from the first OAuth2 leg for an access token and a refresh token
273-
func getUserAccessAndRefreshTokens(idpWellKnownConfig *wellKnownConfig, clientID, codeVerifier, authorizationCode, callbackURL string) (accessToken, refreshToken string, err error) {
277+
func getUserAccessAndRefreshTokens(p *print.Printer, idpWellKnownConfig *wellKnownConfig, clientID, codeVerifier, authorizationCode, callbackURL string) (accessToken, refreshToken string, err error) {
274278
// Set form-encoded data for the POST to the access token endpoint
275279
data := fmt.Sprintf(
276280
"grant_type=authorization_code&client_id=%s"+
@@ -283,6 +287,10 @@ func getUserAccessAndRefreshTokens(idpWellKnownConfig *wellKnownConfig, clientID
283287
// Create the request and execute it
284288
req, _ := http.NewRequest("POST", idpWellKnownConfig.TokenEndpoint, payload)
285289
req.Header.Add("content-type", "application/x-www-form-urlencoded")
290+
291+
// Debug log the request
292+
debugHTTPRequest(p, req)
293+
286294
httpClient := &http.Client{}
287295
res, err := httpClient.Do(req)
288296
if err != nil {
@@ -296,6 +304,10 @@ func getUserAccessAndRefreshTokens(idpWellKnownConfig *wellKnownConfig, clientID
296304
err = fmt.Errorf("close response body: %w", closeErr)
297305
}
298306
}()
307+
308+
// Debug log the response
309+
debugHTTPResponse(p, res)
310+
299311
body, err := io.ReadAll(res.Body)
300312
if err != nil {
301313
return "", "", fmt.Errorf("read response body: %w", err)
@@ -355,8 +367,12 @@ func openBrowser(pageUrl string) error {
355367

356368
// parseWellKnownConfiguration gets the well-known OpenID configuration from the provided URL and returns it as a JSON
357369
// the method also stores the IDP token endpoint in the authentication storage
358-
func parseWellKnownConfiguration(httpClient apiClient, wellKnownConfigURL string) (wellKnownConfig *wellKnownConfig, err error) {
370+
func parseWellKnownConfiguration(p *print.Printer, httpClient apiClient, wellKnownConfigURL string, context StorageContext) (wellKnownConfig *wellKnownConfig, err error) {
359371
req, _ := http.NewRequest("GET", wellKnownConfigURL, http.NoBody)
372+
373+
// Debug log the request
374+
debugHTTPRequest(p, req)
375+
360376
res, err := httpClient.Do(req)
361377
if err != nil {
362378
return nil, fmt.Errorf("make the request: %w", err)
@@ -369,6 +385,10 @@ func parseWellKnownConfiguration(httpClient apiClient, wellKnownConfigURL string
369385
err = fmt.Errorf("close response body: %w", closeErr)
370386
}
371387
}()
388+
389+
// Debug log the response
390+
debugHTTPResponse(p, res)
391+
372392
body, err := io.ReadAll(res.Body)
373393
if err != nil {
374394
return nil, fmt.Errorf("read response body: %w", err)
@@ -391,7 +411,7 @@ func parseWellKnownConfiguration(httpClient apiClient, wellKnownConfigURL string
391411
return nil, fmt.Errorf("found no token endpoint")
392412
}
393413

394-
err = SetAuthField(IDP_TOKEN_ENDPOINT, wellKnownConfig.TokenEndpoint)
414+
err = SetAuthFieldWithContext(context, IDP_TOKEN_ENDPOINT, wellKnownConfig.TokenEndpoint)
395415
if err != nil {
396416
return nil, fmt.Errorf("set token endpoint in the authentication storage: %w", err)
397417
}

internal/pkg/auth/user_login_test.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"testing"
99

1010
"github.com/google/go-cmp/cmp"
11+
"github.com/stackitcloud/stackit-cli/internal/pkg/print"
1112
"github.com/zalando/go-keyring"
1213
)
1314

@@ -93,7 +94,9 @@ func TestParseWellKnownConfig(t *testing.T) {
9394
tt.getResponse,
9495
}
9596

96-
got, err := parseWellKnownConfiguration(&testClient, "")
97+
p := print.NewPrinter()
98+
99+
got, err := parseWellKnownConfiguration(p, &testClient, "", StorageContextCLI)
97100

98101
if tt.isValid && err != nil {
99102
t.Fatalf("expected no error, got %v", err)

0 commit comments

Comments
 (0)