Skip to content

Commit 145bddd

Browse files
yroblataskbot
andauthored
Add resource metadata url to www-authenticate header (#1565)
Co-authored-by: taskbot <[email protected]>
1 parent b5bea23 commit 145bddd

File tree

3 files changed

+342
-6
lines changed

3 files changed

+342
-6
lines changed

pkg/auth/token.go

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ type TokenValidator struct {
5757
jwksClient *jwk.Cache
5858
introspectURL string // Optional introspection endpoint
5959
client *http.Client // HTTP client for making requests
60+
resourceURL string // (RFC 9728)
6061

6162
// Lazy JWKS registration
6263
jwksRegistered bool
@@ -217,6 +218,7 @@ func NewTokenValidator(ctx context.Context, config TokenValidatorConfig) (*Token
217218
clientSecret: config.ClientSecret,
218219
jwksClient: cache,
219220
client: config.httpClient,
221+
resourceURL: config.ResourceURL,
220222
}, nil
221223
}
222224

@@ -457,20 +459,46 @@ func (v *TokenValidator) ValidateToken(ctx context.Context, tokenString string)
457459
// ClaimsContextKey is the key used to store claims in the request context.
458460
type ClaimsContextKey struct{}
459461

462+
// buildWWWAuthenticate builds a RFC 6750 / RFC 9728 compliant value for the
463+
// WWW-Authenticate header. It always includes realm and, if set, resource_metadata.
464+
// If includeError is true, it appends error="invalid_token" and an optional description.
465+
func (v *TokenValidator) buildWWWAuthenticate(includeError bool, errDescription string) string {
466+
var parts []string
467+
468+
// realm (RFC 6750)
469+
if v.issuer != "" {
470+
parts = append(parts, fmt.Sprintf(`realm="%s"`, EscapeQuotes(v.issuer)))
471+
}
472+
473+
// resource_metadata (RFC 9728)
474+
if v.resourceURL != "" {
475+
parts = append(parts, fmt.Sprintf(`resource_metadata="%s"`, EscapeQuotes(v.resourceURL)))
476+
}
477+
478+
// error fields (RFC 6750 §3)
479+
if includeError {
480+
parts = append(parts, `error="invalid_token"`)
481+
if errDescription != "" {
482+
parts = append(parts, fmt.Sprintf(`error_description="%s"`, EscapeQuotes(errDescription)))
483+
}
484+
}
485+
return "Bearer " + strings.Join(parts, ", ")
486+
}
487+
460488
// Middleware creates an HTTP middleware that validates JWT tokens.
461489
func (v *TokenValidator) Middleware(next http.Handler) http.Handler {
462490
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
463491
// Get the token from the Authorization header
464492
authHeader := r.Header.Get("Authorization")
465493
if authHeader == "" {
466-
w.Header().Set("WWW-Authenticate", fmt.Sprintf("Bearer realm=\"%s\"", v.issuer))
494+
w.Header().Set("WWW-Authenticate", v.buildWWWAuthenticate(false, ""))
467495
http.Error(w, "Authorization header required", http.StatusUnauthorized)
468496
return
469497
}
470498

471499
// Check if the Authorization header has the Bearer prefix
472500
if !strings.HasPrefix(authHeader, "Bearer ") {
473-
w.Header().Set("WWW-Authenticate", fmt.Sprintf("Bearer realm=\"%s\"", v.issuer))
501+
w.Header().Set("WWW-Authenticate", v.buildWWWAuthenticate(false, ""))
474502
http.Error(w, "Invalid Authorization header format", http.StatusUnauthorized)
475503
return
476504
}
@@ -481,10 +509,7 @@ func (v *TokenValidator) Middleware(next http.Handler) http.Handler {
481509
// Validate the token
482510
claims, err := v.ValidateToken(r.Context(), tokenString)
483511
if err != nil {
484-
w.Header().Set("WWW-Authenticate", fmt.Sprintf(
485-
"Bearer realm=\"%s\", error=\"invalid_token\", error_description=\"%v\"",
486-
v.issuer, err,
487-
))
512+
w.Header().Set("WWW-Authenticate", v.buildWWWAuthenticate(true, err.Error()))
488513
http.Error(w, fmt.Sprintf("Invalid token: %v", err), http.StatusUnauthorized)
489514
return
490515
}

pkg/auth/token_test.go

Lines changed: 303 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"net/http"
1212
"net/http/httptest"
1313
"os"
14+
"strings"
1415
"testing"
1516
"time"
1617

@@ -19,6 +20,7 @@ import (
1920
)
2021

2122
const testKeyID = "test-key-1"
23+
const issuer = "https://issuer.example.com"
2224

2325
//nolint:gocyclo // This test function is complex but manageable
2426
func TestTokenValidator(t *testing.T) {
@@ -964,3 +966,304 @@ func TestNewAuthInfoHandler(t *testing.T) {
964966
})
965967
}
966968
}
969+
970+
func parseAuthParams(ch string) map[string]string {
971+
out := map[string]string{}
972+
ch = strings.TrimSpace(ch)
973+
if i := strings.IndexByte(ch, ' '); i >= 0 {
974+
ch = strings.TrimSpace(ch[i+1:])
975+
}
976+
var parts []string
977+
var b strings.Builder
978+
inQ := false
979+
for i := 0; i < len(ch); i++ {
980+
c := ch[i]
981+
switch c {
982+
case '"':
983+
inQ = !inQ
984+
b.WriteByte(c)
985+
case ',':
986+
if inQ {
987+
b.WriteByte(c)
988+
} else {
989+
parts = append(parts, strings.TrimSpace(b.String()))
990+
b.Reset()
991+
}
992+
default:
993+
b.WriteByte(c)
994+
}
995+
}
996+
if b.Len() > 0 {
997+
parts = append(parts, strings.TrimSpace(b.String()))
998+
}
999+
for _, p := range parts {
1000+
if p == "" {
1001+
continue
1002+
}
1003+
kv := strings.SplitN(p, "=", 2)
1004+
if len(kv) != 2 {
1005+
continue
1006+
}
1007+
k := strings.ToLower(strings.TrimSpace(kv[0]))
1008+
v := strings.TrimSpace(kv[1])
1009+
if len(v) >= 2 && v[0] == '"' && v[len(v)-1] == '"' {
1010+
v = strings.ReplaceAll(v[1:len(v)-1], `\"`, `"`)
1011+
v = strings.ReplaceAll(v, `\\`, `\`)
1012+
}
1013+
out[k] = v
1014+
}
1015+
return out
1016+
}
1017+
func TestMiddleware_WWWAuthenticate_NoHeader_And_WrongScheme(t *testing.T) {
1018+
t.Parallel()
1019+
1020+
resourceMeta := "https://resource.example.com/.well-known/oauth-protected-resource"
1021+
1022+
tests := []struct {
1023+
name string
1024+
setHeader func(req *http.Request)
1025+
}{
1026+
{
1027+
name: "missing Authorization",
1028+
setHeader: func(_ *http.Request) {},
1029+
},
1030+
{
1031+
name: "wrong scheme Basic",
1032+
setHeader: func(r *http.Request) {
1033+
r.Header.Set("Authorization", "Basic Zm9vOmJhcg==")
1034+
},
1035+
},
1036+
}
1037+
1038+
for _, tt := range tests {
1039+
tt := tt
1040+
t.Run(tt.name, func(t *testing.T) {
1041+
t.Parallel()
1042+
1043+
tv := &TokenValidator{
1044+
issuer: issuer,
1045+
resourceURL: resourceMeta,
1046+
}
1047+
1048+
hitDownstream := false
1049+
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
1050+
hitDownstream = true
1051+
w.WriteHeader(http.StatusOK)
1052+
})
1053+
1054+
// Create a NEW server per subtest (so no cross-parallel sharing)
1055+
srv := httptest.NewServer(tv.Middleware(next))
1056+
t.Cleanup(srv.Close)
1057+
1058+
req, _ := http.NewRequest("GET", srv.URL+"/", nil)
1059+
tt.setHeader(req)
1060+
1061+
res, err := http.DefaultClient.Do(req)
1062+
if err != nil {
1063+
t.Fatalf("request failed: %v", err)
1064+
}
1065+
defer res.Body.Close()
1066+
1067+
if res.StatusCode != http.StatusUnauthorized {
1068+
t.Fatalf("expected 401, got %d", res.StatusCode)
1069+
}
1070+
if hitDownstream {
1071+
t.Fatalf("downstream should not have been reached on 401")
1072+
}
1073+
1074+
h := res.Header.Get("WWW-Authenticate")
1075+
if h == "" {
1076+
t.Fatalf("WWW-Authenticate header missing")
1077+
}
1078+
1079+
params := parseAuthParams(h)
1080+
if got := params["realm"]; got != issuer {
1081+
t.Fatalf("realm mismatch: want %q, got %q", issuer, got)
1082+
}
1083+
if v, ok := params["resource_metadata"]; ok && v == "" {
1084+
t.Fatalf("resource_metadata present but empty")
1085+
}
1086+
if _, ok := params["error"]; ok {
1087+
t.Fatalf("unexpected error param for %s", tt.name)
1088+
}
1089+
if _, ok := params["error_description"]; ok {
1090+
t.Fatalf("unexpected error_description for %s", tt.name)
1091+
}
1092+
})
1093+
}
1094+
}
1095+
1096+
func TestMiddleware_WWWAuthenticate_InvalidOpaqueToken_NoIntrospectionConfigured(t *testing.T) {
1097+
t.Parallel()
1098+
1099+
tv := &TokenValidator{
1100+
issuer: issuer,
1101+
// introspectURL intentionally empty to force the error path
1102+
}
1103+
1104+
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
1105+
w.WriteHeader(http.StatusOK)
1106+
})
1107+
1108+
srv := httptest.NewServer(tv.Middleware(next))
1109+
t.Cleanup(srv.Close)
1110+
1111+
req, _ := http.NewRequest("GET", srv.URL+"/", nil)
1112+
req.Header.Set("Authorization", "Bearer not-a-jwt") // triggers opaque → introspection path
1113+
1114+
res, err := http.DefaultClient.Do(req)
1115+
if err != nil {
1116+
t.Fatalf("request failed: %v", err)
1117+
}
1118+
defer res.Body.Close()
1119+
1120+
if res.StatusCode != http.StatusUnauthorized {
1121+
t.Fatalf("expected 401, got %d", res.StatusCode)
1122+
}
1123+
h := res.Header.Get("WWW-Authenticate")
1124+
if h == "" {
1125+
t.Fatalf("WWW-Authenticate header missing")
1126+
}
1127+
p := parseAuthParams(h)
1128+
if p["realm"] != issuer {
1129+
t.Fatalf("realm mismatch: want %q got %q", issuer, p["realm"])
1130+
}
1131+
if p["error"] != "invalid_token" {
1132+
t.Fatalf("expected error=invalid_token, got %q", p["error"])
1133+
}
1134+
if p["error_description"] == "" {
1135+
t.Fatalf("expected non-empty error_description")
1136+
}
1137+
}
1138+
1139+
func TestMiddleware_WWWAuthenticate_WithMockIntrospection(t *testing.T) {
1140+
t.Parallel()
1141+
1142+
// Introspection mock that varies by token value
1143+
mux := http.NewServeMux()
1144+
mux.HandleFunc("/introspect", func(w http.ResponseWriter, r *http.Request) {
1145+
_ = r.ParseForm()
1146+
switch r.Form.Get("token") {
1147+
case "good":
1148+
_ = json.NewEncoder(w).Encode(map[string]any{
1149+
"active": true,
1150+
"exp": float64(time.Now().Add(60 * time.Second).Unix()),
1151+
"iss": issuer,
1152+
})
1153+
case "inactive":
1154+
_ = json.NewEncoder(w).Encode(map[string]any{"active": false})
1155+
case "unauth":
1156+
w.WriteHeader(http.StatusUnauthorized)
1157+
_, _ = w.Write([]byte(`{"error":"nope"}`))
1158+
default:
1159+
_ = json.NewEncoder(w).Encode(map[string]any{"active": false})
1160+
}
1161+
})
1162+
introspectTS := httptest.NewServer(mux)
1163+
t.Cleanup(introspectTS.Close)
1164+
1165+
type tc struct {
1166+
name string
1167+
auth string
1168+
wantStatus int
1169+
wantError bool
1170+
errSubstr string
1171+
hitNext bool
1172+
}
1173+
cases := []tc{
1174+
{
1175+
name: "inactive => 401",
1176+
auth: "Bearer inactive",
1177+
wantStatus: http.StatusUnauthorized,
1178+
wantError: true,
1179+
hitNext: false,
1180+
},
1181+
{
1182+
name: "unauth introspection => 401",
1183+
auth: "Bearer unauth",
1184+
wantStatus: http.StatusUnauthorized,
1185+
wantError: true,
1186+
errSubstr: "introspection unauthorized",
1187+
hitNext: false,
1188+
},
1189+
{
1190+
name: "good => passes",
1191+
auth: "Bearer good",
1192+
wantStatus: http.StatusOK,
1193+
wantError: false,
1194+
hitNext: true,
1195+
},
1196+
}
1197+
1198+
for _, c := range cases {
1199+
c := c
1200+
t.Run(c.name, func(t *testing.T) {
1201+
t.Parallel()
1202+
1203+
tv := &TokenValidator{
1204+
issuer: issuer,
1205+
introspectURL: introspectTS.URL + "/introspect",
1206+
clientID: "cid",
1207+
clientSecret: "csecret",
1208+
client: http.DefaultClient,
1209+
}
1210+
1211+
hit := false
1212+
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
1213+
hit = true
1214+
w.WriteHeader(http.StatusOK)
1215+
})
1216+
1217+
// NEW: server per subtest
1218+
srv := httptest.NewServer(tv.Middleware(next))
1219+
t.Cleanup(srv.Close)
1220+
1221+
req, _ := http.NewRequest("GET", srv.URL+"/", nil)
1222+
req.Header.Set("Authorization", c.auth)
1223+
res, err := http.DefaultClient.Do(req)
1224+
if err != nil {
1225+
t.Fatalf("request failed: %v", err)
1226+
}
1227+
defer res.Body.Close()
1228+
1229+
if res.StatusCode != c.wantStatus {
1230+
t.Fatalf("status mismatch: want %d got %d", c.wantStatus, res.StatusCode)
1231+
}
1232+
if hit != c.hitNext {
1233+
t.Fatalf("downstream hit mismatch: want %v got %v", c.hitNext, hit)
1234+
}
1235+
1236+
h := res.Header.Get("WWW-Authenticate")
1237+
if c.wantStatus == http.StatusUnauthorized {
1238+
if h == "" {
1239+
t.Fatalf("missing WWW-Authenticate header")
1240+
}
1241+
p := parseAuthParams(h)
1242+
if p["realm"] != issuer {
1243+
t.Fatalf("realm mismatch: %q", p["realm"])
1244+
}
1245+
if c.wantError && p["error"] != "invalid_token" {
1246+
t.Fatalf("expected error=invalid_token, got %q", p["error"])
1247+
}
1248+
if c.errSubstr != "" && !strings.Contains(p["error_description"], c.errSubstr) {
1249+
t.Fatalf("error_description %q missing %q", p["error_description"], c.errSubstr)
1250+
}
1251+
} else if h != "" {
1252+
t.Fatalf("did not expect WWW-Authenticate header on success")
1253+
}
1254+
})
1255+
}
1256+
}
1257+
1258+
func TestBuildWWWAuthenticate_Format(t *testing.T) {
1259+
t.Parallel()
1260+
tv := &TokenValidator{
1261+
issuer: "https://issuer.example.com",
1262+
resourceURL: "https://resource.example.com/.well-known/oauth-protected-resource",
1263+
}
1264+
got := tv.buildWWWAuthenticate(true, `failed to parse "token", reason`)
1265+
want := `Bearer realm="https://issuer.example.com", resource_metadata="https://resource.example.com/.well-known/oauth-protected-resource", error="invalid_token", error_description="failed to parse \"token\", reason"`
1266+
if got != want {
1267+
t.Fatalf("format mismatch:\nwant: %s\n got: %s", want, got)
1268+
}
1269+
}

0 commit comments

Comments
 (0)