@@ -11,6 +11,7 @@ import (
11
11
"net/http"
12
12
"net/http/httptest"
13
13
"os"
14
+ "strings"
14
15
"testing"
15
16
"time"
16
17
@@ -19,6 +20,7 @@ import (
19
20
)
20
21
21
22
const testKeyID = "test-key-1"
23
+ const issuer = "https://issuer.example.com"
22
24
23
25
//nolint:gocyclo // This test function is complex but manageable
24
26
func TestTokenValidator (t * testing.T ) {
@@ -964,3 +966,304 @@ func TestNewAuthInfoHandler(t *testing.T) {
964
966
})
965
967
}
966
968
}
969
+
970
+ func parseAuthParams (ch string ) map [string ]string {
971
+ out := map [string ]string {}
972
+ ch = strings .TrimSpace (ch )
973
+ if i := strings .IndexByte (ch , ' ' ); i >= 0 {
974
+ ch = strings .TrimSpace (ch [i + 1 :])
975
+ }
976
+ var parts []string
977
+ var b strings.Builder
978
+ inQ := false
979
+ for i := 0 ; i < len (ch ); i ++ {
980
+ c := ch [i ]
981
+ switch c {
982
+ case '"' :
983
+ inQ = ! inQ
984
+ b .WriteByte (c )
985
+ case ',' :
986
+ if inQ {
987
+ b .WriteByte (c )
988
+ } else {
989
+ parts = append (parts , strings .TrimSpace (b .String ()))
990
+ b .Reset ()
991
+ }
992
+ default :
993
+ b .WriteByte (c )
994
+ }
995
+ }
996
+ if b .Len () > 0 {
997
+ parts = append (parts , strings .TrimSpace (b .String ()))
998
+ }
999
+ for _ , p := range parts {
1000
+ if p == "" {
1001
+ continue
1002
+ }
1003
+ kv := strings .SplitN (p , "=" , 2 )
1004
+ if len (kv ) != 2 {
1005
+ continue
1006
+ }
1007
+ k := strings .ToLower (strings .TrimSpace (kv [0 ]))
1008
+ v := strings .TrimSpace (kv [1 ])
1009
+ if len (v ) >= 2 && v [0 ] == '"' && v [len (v )- 1 ] == '"' {
1010
+ v = strings .ReplaceAll (v [1 :len (v )- 1 ], `\"` , `"` )
1011
+ v = strings .ReplaceAll (v , `\\` , `\` )
1012
+ }
1013
+ out [k ] = v
1014
+ }
1015
+ return out
1016
+ }
1017
+ func TestMiddleware_WWWAuthenticate_NoHeader_And_WrongScheme (t * testing.T ) {
1018
+ t .Parallel ()
1019
+
1020
+ resourceMeta := "https://resource.example.com/.well-known/oauth-protected-resource"
1021
+
1022
+ tests := []struct {
1023
+ name string
1024
+ setHeader func (req * http.Request )
1025
+ }{
1026
+ {
1027
+ name : "missing Authorization" ,
1028
+ setHeader : func (_ * http.Request ) {},
1029
+ },
1030
+ {
1031
+ name : "wrong scheme Basic" ,
1032
+ setHeader : func (r * http.Request ) {
1033
+ r .Header .Set ("Authorization" , "Basic Zm9vOmJhcg==" )
1034
+ },
1035
+ },
1036
+ }
1037
+
1038
+ for _ , tt := range tests {
1039
+ tt := tt
1040
+ t .Run (tt .name , func (t * testing.T ) {
1041
+ t .Parallel ()
1042
+
1043
+ tv := & TokenValidator {
1044
+ issuer : issuer ,
1045
+ resourceURL : resourceMeta ,
1046
+ }
1047
+
1048
+ hitDownstream := false
1049
+ next := http .HandlerFunc (func (w http.ResponseWriter , _ * http.Request ) {
1050
+ hitDownstream = true
1051
+ w .WriteHeader (http .StatusOK )
1052
+ })
1053
+
1054
+ // Create a NEW server per subtest (so no cross-parallel sharing)
1055
+ srv := httptest .NewServer (tv .Middleware (next ))
1056
+ t .Cleanup (srv .Close )
1057
+
1058
+ req , _ := http .NewRequest ("GET" , srv .URL + "/" , nil )
1059
+ tt .setHeader (req )
1060
+
1061
+ res , err := http .DefaultClient .Do (req )
1062
+ if err != nil {
1063
+ t .Fatalf ("request failed: %v" , err )
1064
+ }
1065
+ defer res .Body .Close ()
1066
+
1067
+ if res .StatusCode != http .StatusUnauthorized {
1068
+ t .Fatalf ("expected 401, got %d" , res .StatusCode )
1069
+ }
1070
+ if hitDownstream {
1071
+ t .Fatalf ("downstream should not have been reached on 401" )
1072
+ }
1073
+
1074
+ h := res .Header .Get ("WWW-Authenticate" )
1075
+ if h == "" {
1076
+ t .Fatalf ("WWW-Authenticate header missing" )
1077
+ }
1078
+
1079
+ params := parseAuthParams (h )
1080
+ if got := params ["realm" ]; got != issuer {
1081
+ t .Fatalf ("realm mismatch: want %q, got %q" , issuer , got )
1082
+ }
1083
+ if v , ok := params ["resource_metadata" ]; ok && v == "" {
1084
+ t .Fatalf ("resource_metadata present but empty" )
1085
+ }
1086
+ if _ , ok := params ["error" ]; ok {
1087
+ t .Fatalf ("unexpected error param for %s" , tt .name )
1088
+ }
1089
+ if _ , ok := params ["error_description" ]; ok {
1090
+ t .Fatalf ("unexpected error_description for %s" , tt .name )
1091
+ }
1092
+ })
1093
+ }
1094
+ }
1095
+
1096
+ func TestMiddleware_WWWAuthenticate_InvalidOpaqueToken_NoIntrospectionConfigured (t * testing.T ) {
1097
+ t .Parallel ()
1098
+
1099
+ tv := & TokenValidator {
1100
+ issuer : issuer ,
1101
+ // introspectURL intentionally empty to force the error path
1102
+ }
1103
+
1104
+ next := http .HandlerFunc (func (w http.ResponseWriter , _ * http.Request ) {
1105
+ w .WriteHeader (http .StatusOK )
1106
+ })
1107
+
1108
+ srv := httptest .NewServer (tv .Middleware (next ))
1109
+ t .Cleanup (srv .Close )
1110
+
1111
+ req , _ := http .NewRequest ("GET" , srv .URL + "/" , nil )
1112
+ req .Header .Set ("Authorization" , "Bearer not-a-jwt" ) // triggers opaque → introspection path
1113
+
1114
+ res , err := http .DefaultClient .Do (req )
1115
+ if err != nil {
1116
+ t .Fatalf ("request failed: %v" , err )
1117
+ }
1118
+ defer res .Body .Close ()
1119
+
1120
+ if res .StatusCode != http .StatusUnauthorized {
1121
+ t .Fatalf ("expected 401, got %d" , res .StatusCode )
1122
+ }
1123
+ h := res .Header .Get ("WWW-Authenticate" )
1124
+ if h == "" {
1125
+ t .Fatalf ("WWW-Authenticate header missing" )
1126
+ }
1127
+ p := parseAuthParams (h )
1128
+ if p ["realm" ] != issuer {
1129
+ t .Fatalf ("realm mismatch: want %q got %q" , issuer , p ["realm" ])
1130
+ }
1131
+ if p ["error" ] != "invalid_token" {
1132
+ t .Fatalf ("expected error=invalid_token, got %q" , p ["error" ])
1133
+ }
1134
+ if p ["error_description" ] == "" {
1135
+ t .Fatalf ("expected non-empty error_description" )
1136
+ }
1137
+ }
1138
+
1139
+ func TestMiddleware_WWWAuthenticate_WithMockIntrospection (t * testing.T ) {
1140
+ t .Parallel ()
1141
+
1142
+ // Introspection mock that varies by token value
1143
+ mux := http .NewServeMux ()
1144
+ mux .HandleFunc ("/introspect" , func (w http.ResponseWriter , r * http.Request ) {
1145
+ _ = r .ParseForm ()
1146
+ switch r .Form .Get ("token" ) {
1147
+ case "good" :
1148
+ _ = json .NewEncoder (w ).Encode (map [string ]any {
1149
+ "active" : true ,
1150
+ "exp" : float64 (time .Now ().Add (60 * time .Second ).Unix ()),
1151
+ "iss" : issuer ,
1152
+ })
1153
+ case "inactive" :
1154
+ _ = json .NewEncoder (w ).Encode (map [string ]any {"active" : false })
1155
+ case "unauth" :
1156
+ w .WriteHeader (http .StatusUnauthorized )
1157
+ _ , _ = w .Write ([]byte (`{"error":"nope"}` ))
1158
+ default :
1159
+ _ = json .NewEncoder (w ).Encode (map [string ]any {"active" : false })
1160
+ }
1161
+ })
1162
+ introspectTS := httptest .NewServer (mux )
1163
+ t .Cleanup (introspectTS .Close )
1164
+
1165
+ type tc struct {
1166
+ name string
1167
+ auth string
1168
+ wantStatus int
1169
+ wantError bool
1170
+ errSubstr string
1171
+ hitNext bool
1172
+ }
1173
+ cases := []tc {
1174
+ {
1175
+ name : "inactive => 401" ,
1176
+ auth : "Bearer inactive" ,
1177
+ wantStatus : http .StatusUnauthorized ,
1178
+ wantError : true ,
1179
+ hitNext : false ,
1180
+ },
1181
+ {
1182
+ name : "unauth introspection => 401" ,
1183
+ auth : "Bearer unauth" ,
1184
+ wantStatus : http .StatusUnauthorized ,
1185
+ wantError : true ,
1186
+ errSubstr : "introspection unauthorized" ,
1187
+ hitNext : false ,
1188
+ },
1189
+ {
1190
+ name : "good => passes" ,
1191
+ auth : "Bearer good" ,
1192
+ wantStatus : http .StatusOK ,
1193
+ wantError : false ,
1194
+ hitNext : true ,
1195
+ },
1196
+ }
1197
+
1198
+ for _ , c := range cases {
1199
+ c := c
1200
+ t .Run (c .name , func (t * testing.T ) {
1201
+ t .Parallel ()
1202
+
1203
+ tv := & TokenValidator {
1204
+ issuer : issuer ,
1205
+ introspectURL : introspectTS .URL + "/introspect" ,
1206
+ clientID : "cid" ,
1207
+ clientSecret : "csecret" ,
1208
+ client : http .DefaultClient ,
1209
+ }
1210
+
1211
+ hit := false
1212
+ next := http .HandlerFunc (func (w http.ResponseWriter , _ * http.Request ) {
1213
+ hit = true
1214
+ w .WriteHeader (http .StatusOK )
1215
+ })
1216
+
1217
+ // NEW: server per subtest
1218
+ srv := httptest .NewServer (tv .Middleware (next ))
1219
+ t .Cleanup (srv .Close )
1220
+
1221
+ req , _ := http .NewRequest ("GET" , srv .URL + "/" , nil )
1222
+ req .Header .Set ("Authorization" , c .auth )
1223
+ res , err := http .DefaultClient .Do (req )
1224
+ if err != nil {
1225
+ t .Fatalf ("request failed: %v" , err )
1226
+ }
1227
+ defer res .Body .Close ()
1228
+
1229
+ if res .StatusCode != c .wantStatus {
1230
+ t .Fatalf ("status mismatch: want %d got %d" , c .wantStatus , res .StatusCode )
1231
+ }
1232
+ if hit != c .hitNext {
1233
+ t .Fatalf ("downstream hit mismatch: want %v got %v" , c .hitNext , hit )
1234
+ }
1235
+
1236
+ h := res .Header .Get ("WWW-Authenticate" )
1237
+ if c .wantStatus == http .StatusUnauthorized {
1238
+ if h == "" {
1239
+ t .Fatalf ("missing WWW-Authenticate header" )
1240
+ }
1241
+ p := parseAuthParams (h )
1242
+ if p ["realm" ] != issuer {
1243
+ t .Fatalf ("realm mismatch: %q" , p ["realm" ])
1244
+ }
1245
+ if c .wantError && p ["error" ] != "invalid_token" {
1246
+ t .Fatalf ("expected error=invalid_token, got %q" , p ["error" ])
1247
+ }
1248
+ if c .errSubstr != "" && ! strings .Contains (p ["error_description" ], c .errSubstr ) {
1249
+ t .Fatalf ("error_description %q missing %q" , p ["error_description" ], c .errSubstr )
1250
+ }
1251
+ } else if h != "" {
1252
+ t .Fatalf ("did not expect WWW-Authenticate header on success" )
1253
+ }
1254
+ })
1255
+ }
1256
+ }
1257
+
1258
+ func TestBuildWWWAuthenticate_Format (t * testing.T ) {
1259
+ t .Parallel ()
1260
+ tv := & TokenValidator {
1261
+ issuer : "https://issuer.example.com" ,
1262
+ resourceURL : "https://resource.example.com/.well-known/oauth-protected-resource" ,
1263
+ }
1264
+ got := tv .buildWWWAuthenticate (true , `failed to parse "token", reason` )
1265
+ want := `Bearer realm="https://issuer.example.com", resource_metadata="https://resource.example.com/.well-known/oauth-protected-resource", error="invalid_token", error_description="failed to parse \"token\", reason"`
1266
+ if got != want {
1267
+ t .Fatalf ("format mismatch:\n want: %s\n got: %s" , want , got )
1268
+ }
1269
+ }
0 commit comments