Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 117 additions & 0 deletions router-tests/authentication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4016,6 +4016,123 @@ func TestIntrospectionAuthentication(t *testing.T) {
})
}

func TestUseCustomization(t *testing.T) {
t.Parallel()

authHeader := func(token string) http.Header {
return http.Header{
"Authorization": []string{"Bearer " + token},
}
}

testRequest := func(t *testing.T, xEnv *testenv.Environment, header http.Header, expectSuccess bool) string {
t.Helper()

res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery))
require.NoError(t, err)
defer res.Body.Close()

if expectSuccess {
require.Equal(t, http.StatusOK, res.StatusCode)
require.Equal(t, JwksName, res.Header.Get(xAuthenticatedByHeader))
} else {
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
}

data, err := io.ReadAll(res.Body)
require.NoError(t, err)
return string(data)
}

testSetup := func(t *testing.T, crypto jwks.Crypto, allowedUse ...string) (string, []authentication.Authenticator) {
t.Helper()

authServer, err := jwks.NewServerWithOptions(t, jwks.WithProviders(crypto), jwks.WithUse(""))
require.NoError(t, err)
t.Cleanup(authServer.Close)

cfg := toJWKSConfig(authServer.JWKSURL(), time.Second*5)
cfg.AllowedUse = allowedUse

tokenDecoder, err := authentication.NewJwksTokenDecoder(
NewContextWithCancel(t),
zap.NewNop(),
[]authentication.JWKSConfig{cfg},
)
require.NoError(t, err)

authOptions := authentication.HttpHeaderAuthenticatorOptions{
Name: JwksName,
TokenDecoder: tokenDecoder,
}
authenticator, err := authentication.NewHttpHeaderAuthenticator(authOptions)
require.NoError(t, err)

authenticators := []authentication.Authenticator{authenticator}

token, err := authServer.TokenForKID(crypto.KID(), nil, false)
require.NoError(t, err)

return token, authenticators
}

t.Run("Use option", func(t *testing.T) {
t.Parallel()

t.Run("Test authentication with empty use should fail by default", func(t *testing.T) {
t.Parallel()

rsaCrypto, err := jwks.NewRSACrypto("test", jwkset.AlgRS256, 2048)
require.NoError(t, err)

token, authenticators := testSetup(t, rsaCrypto)

accessController, err := core.NewAccessController(core.AccessControllerOptions{
Authenticators: authenticators,
AuthenticationRequired: true,
SkipIntrospectionQueries: false,
IntrospectionSkipSecret: "",
})
require.NoError(t, err)

testenv.Run(t, &testenv.Config{
RouterOptions: []core.Option{
core.WithAccessController(accessController),
},
}, func(t *testing.T, xEnv *testenv.Environment) {
body := testRequest(t, xEnv, authHeader(token), false)
require.Equal(t, unauthorizedExpectedData, string(body))
})
})

t.Run("Test authentication with empty use should succeed if allowed", func(t *testing.T) {
t.Parallel()

rsaCrypto, err := jwks.NewRSACrypto("test", jwkset.AlgRS256, 2048)
require.NoError(t, err)

token, authenticators := testSetup(t, rsaCrypto, "")

accessController, err := core.NewAccessController(core.AccessControllerOptions{
Authenticators: authenticators,
AuthenticationRequired: true,
SkipIntrospectionQueries: false,
IntrospectionSkipSecret: "",
})
require.NoError(t, err)

testenv.Run(t, &testenv.Config{
RouterOptions: []core.Option{
core.WithAccessController(accessController),
},
}, func(t *testing.T, xEnv *testenv.Environment) {
body := testRequest(t, xEnv, authHeader(token), true)
require.Equal(t, employeesExpectedData, string(body))
})
})
})
}

func toJWKSConfig(url string, refresh time.Duration, allowedAlgorithms ...string) authentication.JWKSConfig {
return authentication.JWKSConfig{
URL: url,
Expand Down
7 changes: 4 additions & 3 deletions router-tests/cmd/jwks-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,17 @@ import (
"errors"
"flag"
"fmt"
"github.com/MicahParks/jwkset"
"github.com/golang-jwt/jwt/v5"
"github.com/wundergraph/cosmo/router-tests/jwks"
"log"
"net/http"
"os"
"os/signal"
"strings"
"syscall"
"time"

"github.com/MicahParks/jwkset"
"github.com/golang-jwt/jwt/v5"
"github.com/wundergraph/cosmo/router-tests/jwks"
)

var (
Expand Down
16 changes: 13 additions & 3 deletions router-tests/jwks/crypto.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type Crypto interface {
SigningMethod() jwt.SigningMethod
PrivateKey() privateKey
MarshalJWK() (jwkset.JWK, error)
MarshalJWKWithUse(use jwkset.USE) (jwkset.JWK, error)
KID() string
}

Expand All @@ -37,15 +38,15 @@ func (b *baseCrypto) SigningMethod() jwt.SigningMethod {
return jwt.GetSigningMethod(b.alg.String())
}

func (b *baseCrypto) MarshalJWK() (jwkset.JWK, error) {
func (b *baseCrypto) MarshalJWKWithUse(use jwkset.USE) (jwkset.JWK, error) {
marshalOptions := jwkset.JWKMarshalOptions{
Private: false,
}

meta := jwkset.JWKMetadataOptions{
ALG: b.alg,
KID: b.kID,
USE: jwkset.UseSig,
USE: use,
}

options := jwkset.JWKOptions{
Expand All @@ -56,6 +57,11 @@ func (b *baseCrypto) MarshalJWK() (jwkset.JWK, error) {
return jwkset.NewJWKFromKey(b.pk, options)
}

func (b *baseCrypto) MarshalJWK() (jwkset.JWK, error) {
// Delegate to the new method with default signature use.
return b.MarshalJWKWithUse(jwkset.UseSig)
}

func (b *baseCrypto) KID() string {
return b.kID
}
Expand Down Expand Up @@ -112,14 +118,18 @@ func NewHMACCrypto(kID string, alg jwkset.ALG) (Crypto, error) {
}

func (b *hmacCrypto) MarshalJWK() (jwkset.JWK, error) {
return b.MarshalJWKWithUse(jwkset.UseSig)
}

func (b *hmacCrypto) MarshalJWKWithUse(use jwkset.USE) (jwkset.JWK, error) {
marshalOptions := jwkset.JWKMarshalOptions{
Private: true,
}

meta := jwkset.JWKMetadataOptions{
ALG: b.alg,
KID: b.kID,
USE: jwkset.UseSig,
USE: use,
}

options := jwkset.JWKOptions{
Expand Down
43 changes: 40 additions & 3 deletions router-tests/jwks/jwks.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,46 @@ func (s *Server) SetRespondTime(d time.Duration) {
s.respondTime = d
}

// ServerOption represents a configuration option for the test JWKS server.
type ServerOption func(*serverConfig)

// serverConfig holds configurable parameters for server initialization.
type serverConfig struct {
use jwkset.USE
providers []Crypto
}

// WithUse sets the JWK "use" metadata value for keys written to storage.
func WithUse(use jwkset.USE) ServerOption {
return func(cfg *serverConfig) {
cfg.use = use
}
}

func WithProviders(providers ...Crypto) ServerOption {
return func(cfg *serverConfig) {
cfg.providers = providers
}
}

func NewServerWithCrypto(t *testing.T, providers ...Crypto) (*Server, error) {
return NewServerWithOptions(t, WithProviders(providers...))
}

func NewServerWithOptions(t *testing.T, opts ...ServerOption) (*Server, error) {
t.Helper()
if len(providers) == 0 {

// Default configuration
cfg := &serverConfig{
use: jwkset.UseSig,
}

// Apply options
for _, opt := range opts {
opt(cfg)
}

if len(cfg.providers) == 0 {
t.Fatalf("At least one crypto provider is required.")
}

Expand All @@ -146,10 +183,10 @@ func NewServerWithCrypto(t *testing.T, providers ...Crypto) (*Server, error) {

ctx := context.Background()

for _, p := range providers {
for _, p := range cfg.providers {
kid := p.KID()

jwk, err := p.MarshalJWK()
jwk, err := p.MarshalJWKWithUse(cfg.use)
if err != nil {
t.Fatalf("Failed to marshal the JWK.\nError: %s", err)
}
Expand Down
2 changes: 1 addition & 1 deletion router/core/access_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func NewAccessController(opts AccessControllerOptions) (*AccessController, error
func (a *AccessController) Access(w http.ResponseWriter, r *http.Request) (*http.Request, error) {
auth, err := authentication.AuthenticateHTTPRequest(r.Context(), a.authenticators, r)
if err != nil {
return nil, ErrUnauthorized
return nil, errors.Join(err, ErrUnauthorized)
}
if auth != nil {
w.Header().Set("X-Authenticated-By", auth.Authenticator())
Expand Down
8 changes: 7 additions & 1 deletion router/core/graphql_prehandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -1146,8 +1146,14 @@ func (h *PreHandler) handleAuthenticationFailure(requestContext *requestContext,
rtrace.AttachErrToSpan(routerSpan, err)
rtrace.AttachErrToSpan(authenticateSpan, err)

graphqlErr := err
// If the error is an unauthorized error, we want to hide details from the graphql error
if errors.Is(err, ErrUnauthorized) {
graphqlErr = ErrUnauthorized
}

writeOperationError(r, w, requestLogger, &httpGraphqlError{
message: err.Error(),
message: graphqlErr.Error(),
statusCode: http.StatusUnauthorized,
})
}
Expand Down
1 change: 1 addition & 0 deletions router/core/supervisor_instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ func setupAuthenticators(ctx context.Context, logger *zap.Logger, cfg *config.Co
URL: jwks.URL,
RefreshInterval: jwks.RefreshInterval,
AllowedAlgorithms: jwks.Algorithms,
AllowedUse: jwks.AllowedUse,

Secret: jwks.Secret,
Algorithm: jwks.Algorithm,
Expand Down
4 changes: 3 additions & 1 deletion router/core/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -398,11 +398,13 @@ func (h *WebsocketHandler) handleUpgradeRequest(w http.ResponseWriter, r *http.R
handler.request, err = h.accessController.Access(w, r)
if err != nil {
statusCode := http.StatusForbidden
errorMessage := err
if errors.Is(err, ErrUnauthorized) {
statusCode = http.StatusUnauthorized
errorMessage = ErrUnauthorized
}
http.Error(handler.w, http.StatusText(statusCode), statusCode)
_ = handler.writeErrorMessage(requestID, err)
_ = handler.writeErrorMessage(requestID, errorMessage)
handler.Close(false)
return
}
Expand Down
28 changes: 22 additions & 6 deletions router/pkg/authentication/jwks_token_decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ type JWKSConfig struct {
URL string
RefreshInterval time.Duration
AllowedAlgorithms []string
AllowedUse []string

Secret string
Algorithm string
Expand Down Expand Up @@ -73,6 +74,7 @@ type keyFuncEntry struct {
jwks keyfunc.Keyfunc
aud audienceSet
allowedAlgorithms []string
allowedUse []string
}

func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKSConfig) (TokenDecoder, error) {
Expand Down Expand Up @@ -122,14 +124,15 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS
jwksetHTTPClientOptions.RateLimitWaitMax = c.RefreshUnknownKID.MaxWait
}

jwks, err := createKeyFunc(ctx, jwksetHTTPClientOptions)
jwks, err := createKeyFunc(ctx, jwksetHTTPClientOptions, toJwksetUseType(c.AllowedUse))
if err != nil {
return nil, err
}
entries = append(entries, keyFuncEntry{
jwks: jwks,
aud: audiencesMap[key],
allowedAlgorithms: c.AllowedAlgorithms,
allowedUse: c.AllowedUse,
})

} else if c.Secret != "" {
Expand Down Expand Up @@ -176,13 +179,14 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS
PrioritizeHTTP: false,
}

jwks, err := createKeyFunc(ctx, jwksetHTTPClientOptions)
jwks, err := createKeyFunc(ctx, jwksetHTTPClientOptions, toJwksetUseType(c.AllowedUse))
if err != nil {
return nil, err
}
entries = append(entries, keyFuncEntry{
jwks: jwks,
aud: audiencesMap[key],
jwks: jwks,
aud: audiencesMap[key],
allowedUse: c.AllowedUse,
})
}
}
Expand Down Expand Up @@ -248,7 +252,19 @@ func getAudienceSet(audiences []string) audienceSet {
return audSet
}

func createKeyFunc(ctx context.Context, options jwkset.HTTPClientOptions) (keyfunc.Keyfunc, error) {
func toJwksetUseType(allowedUse []string) []jwkset.USE {
if len(allowedUse) == 0 {
return []jwkset.USE{jwkset.UseSig}
}

useWhitelist := make([]jwkset.USE, len(allowedUse))
for i, u := range allowedUse {
useWhitelist[i] = jwkset.USE(u)
}
return useWhitelist
}

func createKeyFunc(ctx context.Context, options jwkset.HTTPClientOptions, useWhitelist []jwkset.USE) (keyfunc.Keyfunc, error) {
combined, err := jwkset.NewHTTPClient(options)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client storage for JWK provider: %w", err)
Expand All @@ -257,7 +273,7 @@ func createKeyFunc(ctx context.Context, options jwkset.HTTPClientOptions) (keyfu
keyfuncOptions := keyfunc.Options{
Ctx: ctx,
Storage: combined,
UseWhitelist: []jwkset.USE{jwkset.UseSig},
UseWhitelist: useWhitelist,
}

jwks, err := keyfunc.New(keyfuncOptions)
Expand Down
Loading
Loading