Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 117 additions & 0 deletions router-tests/authentication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4016,6 +4016,123 @@ func TestIntrospectionAuthentication(t *testing.T) {
})
}

func TestUseCustomization(t *testing.T) {
t.Parallel()

authHeader := func(token string) http.Header {
return http.Header{
"Authorization": []string{"Bearer " + token},
}
}

testRequest := func(t *testing.T, xEnv *testenv.Environment, header http.Header, expectSuccess bool) string {
t.Helper()

res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery))
require.NoError(t, err)
defer res.Body.Close()

if expectSuccess {
require.Equal(t, http.StatusOK, res.StatusCode)
require.Equal(t, JwksName, res.Header.Get(xAuthenticatedByHeader))
} else {
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
}

data, err := io.ReadAll(res.Body)
require.NoError(t, err)
return string(data)
}

testSetup := func(t *testing.T, crypto jwks.Crypto, allowedUse ...string) (string, []authentication.Authenticator) {
t.Helper()

authServer, err := jwks.NewServerWithOptions(t, jwks.WithProviders(crypto), jwks.WithUse(""))
require.NoError(t, err)
t.Cleanup(authServer.Close)

cfg := toJWKSConfig(authServer.JWKSURL(), time.Second*5)
cfg.AllowedUse = allowedUse

tokenDecoder, err := authentication.NewJwksTokenDecoder(
NewContextWithCancel(t),
zap.NewNop(),
[]authentication.JWKSConfig{cfg},
)
require.NoError(t, err)

authOptions := authentication.HttpHeaderAuthenticatorOptions{
Name: JwksName,
TokenDecoder: tokenDecoder,
}
authenticator, err := authentication.NewHttpHeaderAuthenticator(authOptions)
require.NoError(t, err)

authenticators := []authentication.Authenticator{authenticator}

token, err := authServer.TokenForKID(crypto.KID(), nil, false)
require.NoError(t, err)

return token, authenticators
}

t.Run("Use option", func(t *testing.T) {
t.Parallel()

t.Run("Test authentication with empty use should fail by default", func(t *testing.T) {
t.Parallel()

rsaCrypto, err := jwks.NewRSACrypto("test", jwkset.AlgRS256, 2048)
require.NoError(t, err)

token, authenticators := testSetup(t, rsaCrypto)

accessController, err := core.NewAccessController(core.AccessControllerOptions{
Authenticators: authenticators,
AuthenticationRequired: true,
SkipIntrospectionQueries: false,
IntrospectionSkipSecret: "",
})
require.NoError(t, err)

testenv.Run(t, &testenv.Config{
RouterOptions: []core.Option{
core.WithAccessController(accessController),
},
}, func(t *testing.T, xEnv *testenv.Environment) {
body := testRequest(t, xEnv, authHeader(token), false)
require.Equal(t, unauthorizedExpectedData, string(body))
})
})

t.Run("Test authentication with empty use should succeed if allowed", func(t *testing.T) {
t.Parallel()

rsaCrypto, err := jwks.NewRSACrypto("test", jwkset.AlgRS256, 2048)
require.NoError(t, err)

token, authenticators := testSetup(t, rsaCrypto, "")

accessController, err := core.NewAccessController(core.AccessControllerOptions{
Authenticators: authenticators,
AuthenticationRequired: true,
SkipIntrospectionQueries: false,
IntrospectionSkipSecret: "",
})
require.NoError(t, err)

testenv.Run(t, &testenv.Config{
RouterOptions: []core.Option{
core.WithAccessController(accessController),
},
}, func(t *testing.T, xEnv *testenv.Environment) {
body := testRequest(t, xEnv, authHeader(token), true)
require.Equal(t, employeesExpectedData, string(body))
})
})
})
}

func toJWKSConfig(url string, refresh time.Duration, allowedAlgorithms ...string) authentication.JWKSConfig {
return authentication.JWKSConfig{
URL: url,
Expand Down
7 changes: 4 additions & 3 deletions router-tests/cmd/jwks-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,17 @@ import (
"errors"
"flag"
"fmt"
"github.com/MicahParks/jwkset"
"github.com/golang-jwt/jwt/v5"
"github.com/wundergraph/cosmo/router-tests/jwks"
"log"
"net/http"
"os"
"os/signal"
"strings"
"syscall"
"time"

"github.com/MicahParks/jwkset"
"github.com/golang-jwt/jwt/v5"
"github.com/wundergraph/cosmo/router-tests/jwks"
)

var (
Expand Down
16 changes: 13 additions & 3 deletions router-tests/jwks/crypto.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type Crypto interface {
SigningMethod() jwt.SigningMethod
PrivateKey() privateKey
MarshalJWK() (jwkset.JWK, error)
MarshalJWKWithUse(use jwkset.USE) (jwkset.JWK, error)
KID() string
}

Expand All @@ -37,15 +38,15 @@ func (b *baseCrypto) SigningMethod() jwt.SigningMethod {
return jwt.GetSigningMethod(b.alg.String())
}

func (b *baseCrypto) MarshalJWK() (jwkset.JWK, error) {
func (b *baseCrypto) MarshalJWKWithUse(use jwkset.USE) (jwkset.JWK, error) {
marshalOptions := jwkset.JWKMarshalOptions{
Private: false,
}

meta := jwkset.JWKMetadataOptions{
ALG: b.alg,
KID: b.kID,
USE: jwkset.UseSig,
USE: use,
}

options := jwkset.JWKOptions{
Expand All @@ -56,6 +57,11 @@ func (b *baseCrypto) MarshalJWK() (jwkset.JWK, error) {
return jwkset.NewJWKFromKey(b.pk, options)
}

func (b *baseCrypto) MarshalJWK() (jwkset.JWK, error) {
// Delegate to the new method with default signature use.
return b.MarshalJWKWithUse(jwkset.UseSig)
}

func (b *baseCrypto) KID() string {
return b.kID
}
Expand Down Expand Up @@ -112,14 +118,18 @@ func NewHMACCrypto(kID string, alg jwkset.ALG) (Crypto, error) {
}

func (b *hmacCrypto) MarshalJWK() (jwkset.JWK, error) {
return b.MarshalJWKWithUse(jwkset.UseSig)
}

func (b *hmacCrypto) MarshalJWKWithUse(use jwkset.USE) (jwkset.JWK, error) {
marshalOptions := jwkset.JWKMarshalOptions{
Private: true,
}

meta := jwkset.JWKMetadataOptions{
ALG: b.alg,
KID: b.kID,
USE: jwkset.UseSig,
USE: use,
}

options := jwkset.JWKOptions{
Expand Down
43 changes: 40 additions & 3 deletions router-tests/jwks/jwks.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,46 @@ func (s *Server) SetRespondTime(d time.Duration) {
s.respondTime = d
}

// ServerOption represents a configuration option for the test JWKS server.
type ServerOption func(*serverConfig)

// serverConfig holds configurable parameters for server initialization.
type serverConfig struct {
use jwkset.USE
providers []Crypto
}

// WithUse sets the JWK "use" metadata value for keys written to storage.
func WithUse(use jwkset.USE) ServerOption {
return func(cfg *serverConfig) {
cfg.use = use
}
}

func WithProviders(providers ...Crypto) ServerOption {
return func(cfg *serverConfig) {
cfg.providers = providers
}
}

func NewServerWithCrypto(t *testing.T, providers ...Crypto) (*Server, error) {
return NewServerWithOptions(t, WithProviders(providers...))
}

func NewServerWithOptions(t *testing.T, opts ...ServerOption) (*Server, error) {
t.Helper()
if len(providers) == 0 {

// Default configuration
cfg := &serverConfig{
use: jwkset.UseSig,
}

// Apply options
for _, opt := range opts {
opt(cfg)
}

if len(cfg.providers) == 0 {
t.Fatalf("At least one crypto provider is required.")
}

Expand All @@ -146,10 +183,10 @@ func NewServerWithCrypto(t *testing.T, providers ...Crypto) (*Server, error) {

ctx := context.Background()

for _, p := range providers {
for _, p := range cfg.providers {
kid := p.KID()

jwk, err := p.MarshalJWK()
jwk, err := p.MarshalJWKWithUse(cfg.use)
if err != nil {
t.Fatalf("Failed to marshal the JWK.\nError: %s", err)
}
Expand Down
2 changes: 1 addition & 1 deletion router/core/access_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func NewAccessController(opts AccessControllerOptions) (*AccessController, error
func (a *AccessController) Access(w http.ResponseWriter, r *http.Request) (*http.Request, error) {
auth, err := authentication.AuthenticateHTTPRequest(r.Context(), a.authenticators, r)
if err != nil {
return nil, ErrUnauthorized
return nil, errors.Join(err, ErrUnauthorized)
}
if auth != nil {
w.Header().Set("X-Authenticated-By", auth.Authenticator())
Expand Down
8 changes: 7 additions & 1 deletion router/core/graphql_prehandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -1146,8 +1146,14 @@ func (h *PreHandler) handleAuthenticationFailure(requestContext *requestContext,
rtrace.AttachErrToSpan(routerSpan, err)
rtrace.AttachErrToSpan(authenticateSpan, err)

graphqlErr := err
// If the error is an unauthorized error, we want to hide details from the graphql error
if errors.Is(err, ErrUnauthorized) {
graphqlErr = ErrUnauthorized
}

writeOperationError(r, w, requestLogger, &httpGraphqlError{
message: err.Error(),
message: graphqlErr.Error(),
statusCode: http.StatusUnauthorized,
})
}
Expand Down
1 change: 1 addition & 0 deletions router/core/supervisor_instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ func setupAuthenticators(ctx context.Context, logger *zap.Logger, cfg *config.Co
URL: jwks.URL,
RefreshInterval: jwks.RefreshInterval,
AllowedAlgorithms: jwks.Algorithms,
AllowedUse: jwks.AllowedUse,

Secret: jwks.Secret,
Algorithm: jwks.Algorithm,
Expand Down
4 changes: 3 additions & 1 deletion router/core/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -398,11 +398,13 @@ func (h *WebsocketHandler) handleUpgradeRequest(w http.ResponseWriter, r *http.R
handler.request, err = h.accessController.Access(w, r)
if err != nil {
statusCode := http.StatusForbidden
errorMessage := err
if errors.Is(err, ErrUnauthorized) {
statusCode = http.StatusUnauthorized
errorMessage = ErrUnauthorized
}
http.Error(handler.w, http.StatusText(statusCode), statusCode)
_ = handler.writeErrorMessage(requestID, err)
_ = handler.writeErrorMessage(requestID, errorMessage)
handler.Close(false)
return
}
Expand Down
Loading
Loading