Skip to content

Commit 15a2669

Browse files
authored
feat: validate OIDC discovery document on custom provider create/update (#2390)
## Summary - Fetch and validate OIDC discovery documents when creating or updating custom OIDC providers, catching misconfigurations at admin time - Validate required fields (issuer, authorization_endpoint, token_endpoint, jwks_uri) and verify the discovery issuer matches the configured issuer per OpenID Connect Discovery 1.0, Section 4.3 - Store validated discovery in the existing `cached_discovery` DB column - Invalidate in-memory OIDC cache on provider update (issuer change) and delete - Keep network calls outside DB transactions to avoid holding connections open
1 parent 045001b commit 15a2669

File tree

4 files changed

+275
-63
lines changed

4 files changed

+275
-63
lines changed

internal/api/custom_oauth_admin.go

Lines changed: 145 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
package api
22

33
import (
4+
"context"
5+
"encoding/json"
6+
"io"
47
"net/http"
58
"slices"
69
"strings"
10+
"time"
711

812
"github.com/go-chi/chi/v5"
913
popslices "github.com/gobuffalo/pop/v6/slices"
@@ -208,6 +212,17 @@ func (a *API) adminCustomOAuthProviderCreate(w http.ResponseWriter, r *http.Requ
208212
// Create provider model
209213
provider := buildProviderFromParams(params, providerType)
210214

215+
// For OIDC providers, fetch and validate the discovery document before persisting.
216+
// This catches misconfigurations (bad issuer URL, missing endpoints) at admin time
217+
// rather than failing silently at user login time.
218+
if providerType == models.ProviderTypeOIDC {
219+
discovery, err := fetchAndValidateDiscovery(ctx, provider.GetDiscoveryURL(), params.Issuer)
220+
if err != nil {
221+
return err
222+
}
223+
provider.SetDiscoveryCache(discovery)
224+
}
225+
211226
// Encrypt and store client secret
212227
if err := provider.SetClientSecret(params.ClientSecret, config.Security.DBEncryption); err != nil {
213228
return apierrors.NewInternalServerError("Error encrypting custom OAuth provider client secret").WithInternalError(err)
@@ -271,29 +286,46 @@ func (a *API) adminCustomOAuthProviderUpdate(w http.ResponseWriter, r *http.Requ
271286
}
272287
}
273288

274-
var provider *models.CustomOAuthProvider
275-
err := db.Transaction(func(tx *storage.Connection) error {
276-
var terr error
277-
provider, terr = models.FindCustomOAuthProviderByIdentifier(tx, identifier)
278-
if terr != nil {
279-
if models.IsNotFoundError(terr) {
280-
return apierrors.NewNotFoundError(apierrors.ErrorCodeCustomProviderNotFound, "Custom OAuth provider not found")
281-
}
282-
return apierrors.NewInternalServerError("Error retrieving custom OAuth provider").WithInternalError(terr)
289+
// Read the existing provider outside the write transaction so the
290+
// network call (discovery fetch) doesn't hold a transaction open.
291+
provider, err := models.FindCustomOAuthProviderByIdentifier(db, identifier)
292+
if err != nil {
293+
if models.IsNotFoundError(err) {
294+
return apierrors.NewNotFoundError(apierrors.ErrorCodeCustomProviderNotFound, "Custom OAuth provider not found")
283295
}
296+
return apierrors.NewInternalServerError("Error retrieving custom OAuth provider").WithInternalError(err)
297+
}
284298

285-
// Update provider with new non-secret values
286-
if terr := updateProviderFromParams(provider, params); terr != nil {
287-
return terr
299+
// Capture the current issuer before applying updates so we can
300+
// invalidate the in-memory cache if it changes.
301+
var oldIssuer string
302+
if provider.IsOIDC() && provider.Issuer != nil {
303+
oldIssuer = *provider.Issuer
304+
}
305+
306+
// Update provider with new non-secret values
307+
if err := updateProviderFromParams(provider, params); err != nil {
308+
return err
309+
}
310+
311+
// For OIDC providers, re-validate discovery when the issuer or discovery URL changes.
312+
// This network call happens outside the transaction to avoid holding it open.
313+
if provider.IsOIDC() && (params.Issuer != "" || params.DiscoveryURL != nil) {
314+
discovery, err := fetchAndValidateDiscovery(ctx, provider.GetDiscoveryURL(), *provider.Issuer)
315+
if err != nil {
316+
return err
288317
}
318+
provider.SetDiscoveryCache(discovery)
319+
}
289320

290-
// If a new client secret is provided, encrypt and store it (likely move to out of the transaction)
291-
if params.ClientSecret != "" {
292-
if terr := provider.SetClientSecret(params.ClientSecret, config.Security.DBEncryption); terr != nil {
293-
return apierrors.NewInternalServerError("Error encrypting custom OAuth provider client secret").WithInternalError(terr)
294-
}
321+
// If a new client secret is provided, encrypt and store it
322+
if params.ClientSecret != "" {
323+
if err := provider.SetClientSecret(params.ClientSecret, config.Security.DBEncryption); err != nil {
324+
return apierrors.NewInternalServerError("Error encrypting custom OAuth provider client secret").WithInternalError(err)
295325
}
326+
}
296327

328+
err = db.Transaction(func(tx *storage.Connection) error {
297329
if terr := models.UpdateCustomOAuthProvider(tx, provider); terr != nil {
298330
return apierrors.NewInternalServerError("Error updating custom OAuth provider").WithInternalError(terr)
299331
}
@@ -307,6 +339,17 @@ func (a *API) adminCustomOAuthProviderUpdate(w http.ResponseWriter, r *http.Requ
307339
return err
308340
}
309341

342+
// Invalidate in-memory OIDC cache if the issuer changed or discovery was refreshed,
343+
// so the next auth request picks up the new configuration.
344+
if provider.IsOIDC() && provider.Issuer != nil {
345+
if oldIssuer != "" && oldIssuer != *provider.Issuer {
346+
a.oidcCache.Invalidate(oldIssuer)
347+
}
348+
if params.Issuer != "" || params.DiscoveryURL != nil {
349+
a.oidcCache.Invalidate(*provider.Issuer)
350+
}
351+
}
352+
310353
return sendJSON(w, http.StatusOK, provider)
311354
}
312355

@@ -326,6 +369,7 @@ func (a *API) adminCustomOAuthProviderDelete(w http.ResponseWriter, r *http.Requ
326369

327370
observability.LogEntrySetField(r, "identifier", identifier)
328371

372+
var issuerToInvalidate string
329373
err := db.Transaction(func(tx *storage.Connection) error {
330374
provider, terr := models.FindCustomOAuthProviderByIdentifier(tx, identifier)
331375
if terr != nil {
@@ -335,6 +379,10 @@ func (a *API) adminCustomOAuthProviderDelete(w http.ResponseWriter, r *http.Requ
335379
return apierrors.NewInternalServerError("Error retrieving custom OAuth provider").WithInternalError(terr)
336380
}
337381

382+
if provider.IsOIDC() && provider.Issuer != nil {
383+
issuerToInvalidate = *provider.Issuer
384+
}
385+
338386
// TODO: Add admin audit logging here (see create endpoint for details)
339387

340388
if terr := models.DeleteCustomOAuthProvider(tx, provider.ID); terr != nil {
@@ -348,6 +396,10 @@ func (a *API) adminCustomOAuthProviderDelete(w http.ResponseWriter, r *http.Requ
348396
return err
349397
}
350398

399+
if issuerToInvalidate != "" {
400+
a.oidcCache.Invalidate(issuerToInvalidate)
401+
}
402+
351403
w.WriteHeader(http.StatusNoContent)
352404
return nil
353405
}
@@ -619,6 +671,82 @@ func validateAuthorizationParams(params map[string]interface{}) error {
619671
return nil
620672
}
621673

674+
// maxDiscoveryResponseSize is the maximum size of an OIDC discovery response body (1 MB).
675+
const maxDiscoveryResponseSize = 1 << 20
676+
677+
// discoveryFetchTimeout is the timeout for fetching an OIDC discovery document.
678+
const discoveryFetchTimeout = 10 * time.Second
679+
680+
// fetchAndValidateDiscovery fetches the OIDC discovery document from the
681+
// provider's discovery URL and validates that it contains the required fields
682+
// per the OpenID Connect Discovery 1.0 specification. It also verifies that
683+
// the issuer in the discovery document matches the expected issuer.
684+
func fetchAndValidateDiscovery(ctx context.Context, discoveryURL, expectedIssuer string) (*models.OIDCDiscovery, error) {
685+
resp, err := utilities.FetchURLWithTimeout(ctx, discoveryURL, discoveryFetchTimeout)
686+
if err != nil {
687+
return nil, apierrors.NewBadRequestError(
688+
apierrors.ErrorCodeValidationFailed,
689+
"Failed to fetch OIDC discovery document from %q: %v", discoveryURL, err,
690+
)
691+
}
692+
defer resp.Body.Close()
693+
694+
if resp.StatusCode != http.StatusOK {
695+
return nil, apierrors.NewBadRequestError(
696+
apierrors.ErrorCodeValidationFailed,
697+
"OIDC discovery endpoint %q returned HTTP %d, expected 200", discoveryURL, resp.StatusCode,
698+
)
699+
}
700+
701+
body, err := io.ReadAll(io.LimitReader(resp.Body, maxDiscoveryResponseSize))
702+
if err != nil {
703+
return nil, apierrors.NewBadRequestError(
704+
apierrors.ErrorCodeValidationFailed,
705+
"Failed to read OIDC discovery response from %q", discoveryURL,
706+
)
707+
}
708+
709+
var discovery models.OIDCDiscovery
710+
if err := json.Unmarshal(body, &discovery); err != nil {
711+
return nil, apierrors.NewBadRequestError(
712+
apierrors.ErrorCodeValidationFailed,
713+
"OIDC discovery document from %q is not valid JSON", discoveryURL,
714+
)
715+
}
716+
717+
// Validate required fields per OpenID Connect Discovery 1.0 spec
718+
var missing []string
719+
if discovery.Issuer == "" {
720+
missing = append(missing, "issuer")
721+
}
722+
if discovery.AuthorizationEndpoint == "" {
723+
missing = append(missing, "authorization_endpoint")
724+
}
725+
if discovery.TokenEndpoint == "" {
726+
missing = append(missing, "token_endpoint")
727+
}
728+
if discovery.JwksURI == "" {
729+
missing = append(missing, "jwks_uri")
730+
}
731+
if len(missing) > 0 {
732+
return nil, apierrors.NewBadRequestError(
733+
apierrors.ErrorCodeValidationFailed,
734+
"OIDC discovery document is missing required fields: %s", strings.Join(missing, ", "),
735+
)
736+
}
737+
738+
// The issuer in the discovery document MUST exactly match the expected issuer
739+
// per OpenID Connect Discovery 1.0, Section 4.3.
740+
if discovery.Issuer != expectedIssuer {
741+
return nil, apierrors.NewBadRequestError(
742+
apierrors.ErrorCodeValidationFailed,
743+
"OIDC discovery issuer mismatch: discovery document reports %q but expected %q", discovery.Issuer, expectedIssuer,
744+
)
745+
}
746+
747+
return &discovery, nil
748+
}
749+
622750
// validateAttributeMapping ensures no sensitive system fields are targeted
623751
func validateAttributeMapping(mapping map[string]interface{}) error {
624752
if mapping == nil {

internal/api/custom_oauth_admin_test.go

Lines changed: 80 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111

1212
popslices "github.com/gobuffalo/pop/v6/slices"
1313
jwt "github.com/golang-jwt/jwt/v5"
14+
"github.com/gofrs/uuid"
1415
"github.com/stretchr/testify/assert"
1516
"github.com/stretchr/testify/require"
1617
"github.com/stretchr/testify/suite"
@@ -107,40 +108,68 @@ func (ts *CustomOAuthAdminTestSuite) TestCreateOAuth2Provider() {
107108
assert.Empty(ts.T(), provider.ClientSecret)
108109
}
109110

110-
func (ts *CustomOAuthAdminTestSuite) TestCreateOIDCProvider() {
111-
payload := map[string]interface{}{
112-
"provider_type": "oidc",
113-
"identifier": "custom:self-keycloak",
114-
"name": "Keycloak",
115-
"client_id": "test-client-id",
116-
"client_secret": "test-client-secret",
117-
"issuer": "https://example.com/realms/myrealm",
118-
"scopes": []string{"profile", "email"},
119-
"pkce_enabled": true,
120-
"enabled": true,
121-
}
111+
func (ts *CustomOAuthAdminTestSuite) TestCreateOIDCProviderValidatesDiscovery() {
112+
ts.Run("Unreachable issuer rejected", func() {
113+
// An OIDC provider with an unresolvable issuer is caught by URL validation
114+
payload := map[string]interface{}{
115+
"provider_type": "oidc",
116+
"identifier": "custom:bad-issuer",
117+
"name": "Bad Issuer",
118+
"client_id": "test-client-id",
119+
"client_secret": "test-client-secret",
120+
"issuer": "https://unreachable.example.com/realms/myrealm",
121+
"scopes": []string{"profile", "email"},
122+
}
122123

123-
var body bytes.Buffer
124-
require.NoError(ts.T(), json.NewEncoder(&body).Encode(payload))
124+
var body bytes.Buffer
125+
require.NoError(ts.T(), json.NewEncoder(&body).Encode(payload))
125126

126-
req := httptest.NewRequest(http.MethodPost, "/admin/custom-providers", &body)
127-
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token))
127+
req := httptest.NewRequest(http.MethodPost, "/admin/custom-providers", &body)
128+
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token))
128129

129-
w := httptest.NewRecorder()
130-
ts.API.handler.ServeHTTP(w, req)
130+
w := httptest.NewRecorder()
131+
ts.API.handler.ServeHTTP(w, req)
131132

132-
require.Equal(ts.T(), http.StatusCreated, w.Code)
133+
require.Equal(ts.T(), http.StatusBadRequest, w.Code)
133134

134-
var provider models.CustomOAuthProvider
135-
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&provider))
135+
var apiErr apierrors.HTTPError
136+
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&apiErr))
137+
assert.Equal(ts.T(), apierrors.ErrorCodeValidationFailed, apiErr.ErrorCode)
138+
})
136139

137-
assert.Equal(ts.T(), models.ProviderTypeOIDC, provider.ProviderType)
138-
assert.Equal(ts.T(), "custom:self-keycloak", provider.Identifier)
139-
assert.Contains(ts.T(), provider.Scopes, "openid") // Auto-added for OIDC
140-
assert.Contains(ts.T(), provider.Scopes, "profile")
140+
ts.Run("Invalid discovery document rejected", func() {
141+
// Use a real resolvable domain whose discovery endpoint returns non-JSON.
142+
// This passes URL validation but fails at discovery fetch/parse.
143+
payload := map[string]interface{}{
144+
"provider_type": "oidc",
145+
"identifier": "custom:bad-discovery",
146+
"name": "Bad Discovery",
147+
"client_id": "test-client-id",
148+
"client_secret": "test-client-secret",
149+
"issuer": "https://example.com",
150+
"scopes": []string{"profile", "email"},
151+
}
141152

142-
// Ensure client secret is not exposed in JSON
143-
assert.Empty(ts.T(), provider.ClientSecret)
153+
var body bytes.Buffer
154+
require.NoError(ts.T(), json.NewEncoder(&body).Encode(payload))
155+
156+
req := httptest.NewRequest(http.MethodPost, "/admin/custom-providers", &body)
157+
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token))
158+
159+
w := httptest.NewRecorder()
160+
ts.API.handler.ServeHTTP(w, req)
161+
162+
require.Equal(ts.T(), http.StatusBadRequest, w.Code)
163+
164+
var apiErr apierrors.HTTPError
165+
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&apiErr))
166+
assert.Equal(ts.T(), apierrors.ErrorCodeValidationFailed, apiErr.ErrorCode)
167+
// Should fail at discovery fetch or parse stage
168+
assert.True(ts.T(),
169+
strings.Contains(apiErr.Message, "OIDC discovery") ||
170+
strings.Contains(apiErr.Message, "Failed to fetch"),
171+
"Expected discovery-related error, got: %s", apiErr.Message)
172+
})
144173
}
145174

146175
func (ts *CustomOAuthAdminTestSuite) TestCreateProviderValidation() {
@@ -429,7 +458,7 @@ func (ts *CustomOAuthAdminTestSuite) TestListProviders() {
429458
// Create some providers
430459
ts.createProvider(ts.createTestOAuth2Payload("oauth2-1"), http.StatusCreated)
431460
ts.createProvider(ts.createTestOAuth2Payload("oauth2-2"), http.StatusCreated)
432-
ts.createProvider(ts.createTestOIDCPayload("oidc-1", "https://oidc1.example.com"), http.StatusCreated)
461+
ts.createOIDCProviderInDB("oidc-1", "https://oidc1.example.com")
433462

434463
req := httptest.NewRequest(http.MethodGet, "/admin/custom-providers", nil)
435464
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token))
@@ -463,8 +492,8 @@ func (ts *CustomOAuthAdminTestSuite) TestListProvidersEmptyReturnsArray() {
463492
func (ts *CustomOAuthAdminTestSuite) TestListProvidersWithTypeFilter() {
464493
// Create mixed providers
465494
ts.createProvider(ts.createTestOAuth2Payload("oauth2-1"), http.StatusCreated)
466-
ts.createProvider(ts.createTestOIDCPayload("oidc-1", "https://oidc1.example.com"), http.StatusCreated)
467-
ts.createProvider(ts.createTestOIDCPayload("oidc-2", "https://oidc2.example.com"), http.StatusCreated)
495+
ts.createOIDCProviderInDB("oidc-1", "https://oidc1.example.com")
496+
ts.createOIDCProviderInDB("oidc-2", "https://oidc2.example.com")
468497

469498
// Filter by OAuth2
470499
req := httptest.NewRequest(http.MethodGet, "/admin/custom-providers?type=oauth2", nil)
@@ -617,25 +646,31 @@ func (ts *CustomOAuthAdminTestSuite) createTestOAuth2Payload(identifier string)
617646
}
618647
}
619648

620-
func (ts *CustomOAuthAdminTestSuite) createTestOIDCPayload(identifier, issuer string) map[string]interface{} {
649+
// createOIDCProviderInDB inserts an OIDC provider directly into the database,
650+
// bypassing the admin handler (and its discovery validation). Use this for tests
651+
// that need an OIDC provider to exist but aren't testing the create flow.
652+
func (ts *CustomOAuthAdminTestSuite) createOIDCProviderInDB(identifier, issuer string) *models.CustomOAuthProvider {
621653
if !strings.HasPrefix(identifier, "custom:") {
622654
identifier = "custom:" + identifier
623655
}
624-
// If issuer is not provided or uses non-resolvable domain, use example.com
625-
if issuer == "" || strings.Contains(issuer, "oidc1.example.com") || strings.Contains(issuer, "oidc2.example.com") {
626-
issuer = "https://example.com/realms/" + identifier
627-
}
628-
return map[string]interface{}{
629-
"provider_type": "oidc",
630-
"identifier": identifier,
631-
"name": "Test OIDC Provider",
632-
"client_id": "test-client-id",
633-
"client_secret": "test-client-secret",
634-
"issuer": issuer,
635-
"scopes": []string{"profile", "email"},
636-
"pkce_enabled": true,
637-
"enabled": true,
656+
id, err := uuid.NewV4()
657+
require.NoError(ts.T(), err)
658+
659+
provider := &models.CustomOAuthProvider{
660+
ID: id,
661+
ProviderType: models.ProviderTypeOIDC,
662+
Identifier: identifier,
663+
Name: "Test OIDC Provider",
664+
ClientID: "test-client-id",
665+
ClientSecret: "test-client-secret",
666+
Issuer: &issuer,
667+
Scopes: []string{"openid", "profile", "email"},
668+
PKCEEnabled: true,
669+
Enabled: true,
638670
}
671+
672+
require.NoError(ts.T(), models.CreateCustomOAuthProvider(ts.API.db, provider))
673+
return provider
639674
}
640675

641676
func (ts *CustomOAuthAdminTestSuite) createProvider(payload map[string]interface{}, expectedStatus int) *httptest.ResponseRecorder {

0 commit comments

Comments
 (0)