Skip to content

Commit 5ca4489

Browse files
cstocktonChris Stockton
andauthored
feat: add support for managing SSO providers by resource_id (#2081)
Some time ago a `resource_id` was added to the `sso_providers` table to support infrastructure as code use cases down the road. This change adds basic support for utilizing this field to manage SSO providers. Key changes: - Updated API for SSO providers to allow get, put, delete by `resource_id` - Extended `loadSSOProvider` to accept `resource_`-prefixed `idp_id` values - Added optional `resource_id` field to `SSOProvider` model - Implemented `FindSSOProviderByResourceID` in model layer - Renamed `FindAllSAMLProviders` to `FindAllSSOProviders` - Added filtering to the `/admin/sso/providers` via `?resource_id{,_prefix}=` - Included full E2E test coverage for SSO provider api --------- Co-authored-by: Chris Stockton <[email protected]>
1 parent f1b15ff commit 5ca4489

File tree

9 files changed

+961
-36
lines changed

9 files changed

+961
-36
lines changed

internal/api/api.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,6 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
293293
})
294294
})
295295
})
296-
297296
})
298297
})
299298

internal/api/apierrors/errorcode.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ const (
5757
ErrorCodeSAMLAssertionNoEmail ErrorCode = "saml_assertion_no_email"
5858
ErrorCodeUserAlreadyExists ErrorCode = "user_already_exists"
5959
ErrorCodeSSOProviderNotFound ErrorCode = "sso_provider_not_found"
60+
ErrorCodeSSOProviderDisabled ErrorCode = "sso_provider_disabled"
6061
ErrorCodeSAMLMetadataFetchFailed ErrorCode = "saml_metadata_fetch_failed"
6162
ErrorCodeSAMLIdPAlreadyExists ErrorCode = "saml_idp_already_exists"
6263
ErrorCodeSSODomainAlreadyExists ErrorCode = "sso_domain_already_exists"

internal/api/samlacs.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,12 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error {
103103
return apierrors.NewInternalServerError("Unable to find SSO Provider from SAML RelayState")
104104
}
105105

106+
if !ssoProvider.IsEnabled() {
107+
return apierrors.NewNotFoundError(
108+
apierrors.ErrorCodeSSOProviderDisabled,
109+
"SSO Provider assigned for this domain is currently disabled")
110+
}
111+
106112
initiatedBy = "sp"
107113
entityId = ssoProvider.SAMLProvider.EntityID
108114
redirectTo = relayState.RedirectTo
@@ -156,6 +162,12 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error {
156162
return err
157163
}
158164

165+
if !ssoProvider.IsEnabled() {
166+
return apierrors.NewNotFoundError(
167+
apierrors.ErrorCodeSSOProviderDisabled,
168+
"SSO Provider assigned for this domain is currently disabled")
169+
}
170+
159171
idpMetadata, err := ssoProvider.SAMLProvider.EntityDescriptor()
160172
if err != nil {
161173
return err

internal/api/sso.go

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -52,25 +52,8 @@ func (a *API) SingleSignOn(w http.ResponseWriter, r *http.Request) error {
5252
if hasProviderID, err = params.validate(); err != nil {
5353
return err
5454
}
55-
codeChallengeMethod := params.CodeChallengeMethod
56-
codeChallenge := params.CodeChallenge
57-
58-
if err := validatePKCEParams(codeChallengeMethod, codeChallenge); err != nil {
59-
return err
60-
}
61-
flowType := getFlowFromChallenge(params.CodeChallenge)
62-
var flowStateID *uuid.UUID
63-
flowStateID = nil
64-
if isPKCEFlow(flowType) {
65-
flowState, err := generateFlowState(db, models.SSOSAML.String(), models.SSOSAML, codeChallengeMethod, codeChallenge, nil)
66-
if err != nil {
67-
return err
68-
}
69-
flowStateID = &flowState.ID
70-
}
7155

7256
var ssoProvider *models.SSOProvider
73-
7457
if hasProviderID {
7558
ssoProvider, err = models.FindSSOProviderByID(db, params.ProviderID)
7659
if models.IsNotFoundError(err) {
@@ -87,6 +70,29 @@ func (a *API) SingleSignOn(w http.ResponseWriter, r *http.Request) error {
8770
}
8871
}
8972

73+
if !ssoProvider.IsEnabled() {
74+
return apierrors.NewNotFoundError(
75+
apierrors.ErrorCodeSSOProviderDisabled,
76+
"SSO Provider is currently disabled")
77+
}
78+
79+
codeChallengeMethod := params.CodeChallengeMethod
80+
codeChallenge := params.CodeChallenge
81+
82+
if err := validatePKCEParams(codeChallengeMethod, codeChallenge); err != nil {
83+
return err
84+
}
85+
flowType := getFlowFromChallenge(params.CodeChallenge)
86+
var flowStateID *uuid.UUID
87+
flowStateID = nil
88+
if isPKCEFlow(flowType) {
89+
flowState, err := generateFlowState(db, models.SSOSAML.String(), models.SSOSAML, codeChallengeMethod, codeChallenge, nil)
90+
if err != nil {
91+
return err
92+
}
93+
flowStateID = &flowState.ID
94+
}
95+
9096
entityDescriptor, err := ssoProvider.SAMLProvider.EntityDescriptor()
9197
if err != nil {
9298
return apierrors.NewInternalServerError("Error parsing SAML Metadata for SAML provider").WithInternalError(err)

internal/api/ssoadmin.go

Lines changed: 53 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,22 +19,32 @@ import (
1919
"github.com/supabase/auth/internal/utilities"
2020
)
2121

22-
// loadSSOProvider looks for an idp_id parameter in the URL route and loads the SSO provider
23-
// with that ID (or resource ID) and adds it to the context.
22+
// loadSSOProvider looks for an idp_id and first checks it for a "resource_"
23+
// prefix, if present the provider is loaded by resource_id. Otherwise the
24+
// provider is loaded by id.
2425
func (a *API) loadSSOProvider(w http.ResponseWriter, r *http.Request) (context.Context, error) {
2526
ctx := r.Context()
2627
db := a.db.WithContext(ctx)
2728

28-
idpParam := chi.URLParam(r, "idp_id")
29+
var (
30+
provider *models.SSOProvider
31+
err error
32+
)
2933

30-
idpID, err := uuid.FromString(idpParam)
31-
if err != nil {
32-
// idpParam is not UUIDv4
33-
return nil, apierrors.NewNotFoundError(apierrors.ErrorCodeSSOProviderNotFound, "SSO Identity Provider not found")
34+
const resourcePrefix = "resource_"
35+
idpParam := chi.URLParam(r, "idp_id")
36+
switch {
37+
case strings.HasPrefix(idpParam, resourcePrefix):
38+
resourceID := strings.TrimPrefix(idpParam, resourcePrefix)
39+
provider, err = models.FindSSOProviderByResourceID(db, resourceID)
40+
default:
41+
idpID, idpErr := uuid.FromString(idpParam)
42+
if idpErr != nil {
43+
return nil, apierrors.NewNotFoundError(apierrors.ErrorCodeSSOProviderNotFound, "SSO Identity Provider not found")
44+
}
45+
provider, err = models.FindSSOProviderByID(db, idpID)
3446
}
3547

36-
// idpParam is a UUIDv4
37-
provider, err := models.FindSSOProviderByID(db, idpID)
3848
if err != nil {
3949
if models.IsNotFoundError(err) {
4050
return nil, apierrors.NewNotFoundError(apierrors.ErrorCodeSSOProviderNotFound, "SSO Identity Provider not found")
@@ -44,17 +54,16 @@ func (a *API) loadSSOProvider(w http.ResponseWriter, r *http.Request) (context.C
4454
}
4555

4656
observability.LogEntrySetField(r, "sso_provider_id", provider.ID.String())
47-
4857
return withSSOProvider(r.Context(), provider), nil
4958
}
5059

51-
// adminSSOProvidersList lists all SAML SSO Identity Providers in the system. Does
60+
// adminSSOProvidersList lists all SSO Identity Providers in the system. Does
5261
// not deal with pagination at this time.
5362
func (a *API) adminSSOProvidersList(w http.ResponseWriter, r *http.Request) error {
5463
ctx := r.Context()
5564
db := a.db.WithContext(ctx)
5665

57-
providers, err := models.FindAllSAMLProviders(db)
66+
providers, err := models.FindAllSSOProvidersByFilter(db, r.URL.Query())
5867
if err != nil {
5968
return err
6069
}
@@ -77,6 +86,9 @@ type CreateSSOProviderParams struct {
7786
Domains []string `json:"domains"`
7887
AttributeMapping models.SAMLAttributeMapping `json:"attribute_mapping"`
7988
NameIDFormat string `json:"name_id_format"`
89+
90+
ResourceID *string `json:"resource_id,omitempty"`
91+
Disabled *bool `json:"disabled,omitempty"`
8092
}
8193

8294
func (p *CreateSSOProviderParams) validate(forUpdate bool) error {
@@ -223,17 +235,23 @@ func (a *API) adminSSOProvidersCreate(w http.ResponseWriter, r *http.Request) er
223235
}
224236

225237
provider := &models.SSOProvider{
238+
226239
// TODO handle Name, Description, Attribute Mapping
227240
SAMLProvider: models.SAMLProvider{
228241
EntityID: metadata.EntityID,
229242
MetadataXML: string(rawMetadata),
230243
},
231244
}
232245

246+
if params.ResourceID != nil {
247+
provider.ResourceID = params.ResourceID
248+
}
249+
if params.Disabled != nil {
250+
provider.Disabled = params.Disabled
251+
}
233252
if params.MetadataURL != "" {
234253
provider.SAMLProvider.MetadataURL = &params.MetadataURL
235254
}
236-
237255
if params.NameIDFormat != "" {
238256
provider.SAMLProvider.NameIDFormat = &params.NameIDFormat
239257
}
@@ -374,6 +392,28 @@ func (a *API) adminSSOProvidersUpdate(w http.ResponseWriter, r *http.Request) er
374392
}
375393
}
376394

395+
if params.ResourceID != nil {
396+
resourceID := *params.ResourceID
397+
switch {
398+
case resourceID == "" && provider.ResourceID != nil:
399+
provider.ResourceID = nil
400+
modified = true
401+
case resourceID != "" &&
402+
(provider.ResourceID == nil ||
403+
*provider.ResourceID != resourceID):
404+
provider.ResourceID = &resourceID
405+
modified = true
406+
}
407+
}
408+
409+
if params.Disabled != nil {
410+
disabled := *params.Disabled
411+
if provider.Disabled == nil || *provider.Disabled != disabled {
412+
provider.Disabled = &disabled
413+
modified = true
414+
}
415+
}
416+
377417
if modified {
378418
if err := db.Transaction(func(tx *storage.Connection) error {
379419
if terr := tx.Eager().Update(provider); terr != nil {

0 commit comments

Comments
 (0)