Skip to content

Commit dea101d

Browse files
authored
Lazily register JWKS endpoint with a 5 second timeout instead of on s… (#1472)
1 parent 7c28fe8 commit dea101d

File tree

2 files changed

+56
-7
lines changed

2 files changed

+56
-7
lines changed

pkg/auth/token.go

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"net/http"
1111
"net/url"
1212
"strings"
13+
"sync"
1314
"time"
1415

1516
"github.com/golang-jwt/jwt/v5"
@@ -57,7 +58,10 @@ type TokenValidator struct {
5758
introspectURL string // Optional introspection endpoint
5859
client *http.Client // HTTP client for making requests
5960

60-
// No need for additional caching as jwk.Cache handles it
61+
// Lazy JWKS registration
62+
jwksRegistered bool
63+
jwksRegistrationMu sync.Mutex
64+
jwksRegistrationErr error
6165
}
6266

6367
// TokenValidatorConfig contains configuration for the token validator.
@@ -176,7 +180,6 @@ func NewTokenValidator(ctx context.Context, config TokenValidatorConfig) (*Token
176180
return nil, fmt.Errorf("%w: %v", ErrFailedToDiscoverOIDC, err)
177181
}
178182
jwksURL = doc.JWKSURI
179-
180183
}
181184

182185
// Ensure we have a JWKS URL either provided or discovered
@@ -203,11 +206,7 @@ func NewTokenValidator(ctx context.Context, config TokenValidatorConfig) (*Token
203206
return nil, fmt.Errorf("failed to create JWKS cache: %w", err)
204207
}
205208

206-
// Register the JWKS URL with the cache
207-
err = cache.Register(ctx, jwksURL)
208-
if err != nil {
209-
return nil, fmt.Errorf("failed to register JWKS URL: %w", err)
210-
}
209+
// Skip synchronous JWKS registration - will be done lazily on first use
211210

212211
return &TokenValidator{
213212
issuer: config.Issuer,
@@ -221,8 +220,40 @@ func NewTokenValidator(ctx context.Context, config TokenValidatorConfig) (*Token
221220
}, nil
222221
}
223222

223+
// ensureJWKSRegistered ensures that the JWKS URL is registered with the cache.
224+
// This is called lazily on first use to avoid blocking startup.
225+
func (v *TokenValidator) ensureJWKSRegistered(ctx context.Context) error {
226+
v.jwksRegistrationMu.Lock()
227+
defer v.jwksRegistrationMu.Unlock()
228+
229+
// Check if already registered or failed
230+
if v.jwksRegistered {
231+
return v.jwksRegistrationErr
232+
}
233+
234+
// Create context with 5-second timeout for JWKS registration
235+
registrationCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
236+
defer cancel()
237+
238+
// Attempt registration
239+
err := v.jwksClient.Register(registrationCtx, v.jwksURL)
240+
if err != nil {
241+
v.jwksRegistrationErr = fmt.Errorf("failed to register JWKS URL: %w", err)
242+
} else {
243+
v.jwksRegistrationErr = nil
244+
}
245+
246+
v.jwksRegistered = true
247+
return v.jwksRegistrationErr
248+
}
249+
224250
// getKeyFromJWKS gets the key from the JWKS.
225251
func (v *TokenValidator) getKeyFromJWKS(ctx context.Context, token *jwt.Token) (interface{}, error) {
252+
// Ensure JWKS is registered before attempting to use it
253+
if err := v.ensureJWKSRegistered(ctx); err != nil {
254+
return nil, fmt.Errorf("JWKS registration failed: %w", err)
255+
}
256+
226257
// Validate the signing method
227258
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
228259
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])

pkg/auth/token_test.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,12 @@ func TestTokenValidator(t *testing.T) {
7575
t.Fatalf("Failed to create token validator: %v", err)
7676
}
7777

78+
// Ensure JWKS is registered before lookup
79+
err = validator.ensureJWKSRegistered(ctx)
80+
if err != nil {
81+
t.Fatalf("Failed to register JWKS: %v", err)
82+
}
83+
7884
// Force a refresh of the JWKS cache
7985
_, err = validator.jwksClient.Lookup(ctx, jwksServer.URL)
8086
if err != nil {
@@ -217,6 +223,12 @@ func TestTokenValidatorMiddleware(t *testing.T) {
217223
t.Fatalf("Failed to create token validator: %v", err)
218224
}
219225

226+
// Ensure JWKS is registered before lookup
227+
err = validator.ensureJWKSRegistered(ctx)
228+
if err != nil {
229+
t.Fatalf("Failed to register JWKS: %v", err)
230+
}
231+
220232
// Force a refresh of the JWKS cache
221233
_, err = validator.jwksClient.Lookup(ctx, jwksServer.URL)
222234
if err != nil {
@@ -607,6 +619,12 @@ func TestNewTokenValidatorWithOIDCDiscovery(t *testing.T) {
607619
t.Fatalf("Failed to sign token: %v", err)
608620
}
609621

622+
// Ensure JWKS is registered before lookup
623+
err = validator.ensureJWKSRegistered(ctx)
624+
if err != nil {
625+
t.Fatalf("Failed to register JWKS: %v", err)
626+
}
627+
610628
// Force a refresh of the JWKS cache
611629
_, err = validator.jwksClient.Lookup(ctx, validator.jwksURL)
612630
if err != nil {

0 commit comments

Comments
 (0)