@@ -19,22 +19,32 @@ import (
19
19
"github.com/supabase/auth/internal/utilities"
20
20
)
21
21
22
- // loadSSOProvider looks for an idp_id parameter in the URL route and loads the SSO provider
23
- // with that ID (or resource ID) and adds it to the context.
22
+ // loadSSOProvider looks for an idp_id and first checks it for a "resource_"
23
+ // prefix, if present the provider is loaded by resource_id. Otherwise the
24
+ // provider is loaded by id.
24
25
func (a * API ) loadSSOProvider (w http.ResponseWriter , r * http.Request ) (context.Context , error ) {
25
26
ctx := r .Context ()
26
27
db := a .db .WithContext (ctx )
27
28
28
- idpParam := chi .URLParam (r , "idp_id" )
29
+ var (
30
+ provider * models.SSOProvider
31
+ err error
32
+ )
29
33
30
- idpID , err := uuid .FromString (idpParam )
31
- if err != nil {
32
- // idpParam is not UUIDv4
33
- return nil , apierrors .NewNotFoundError (apierrors .ErrorCodeSSOProviderNotFound , "SSO Identity Provider not found" )
34
+ const resourcePrefix = "resource_"
35
+ idpParam := chi .URLParam (r , "idp_id" )
36
+ switch {
37
+ case strings .HasPrefix (idpParam , resourcePrefix ):
38
+ resourceID := strings .TrimPrefix (idpParam , resourcePrefix )
39
+ provider , err = models .FindSSOProviderByResourceID (db , resourceID )
40
+ default :
41
+ idpID , idpErr := uuid .FromString (idpParam )
42
+ if idpErr != nil {
43
+ return nil , apierrors .NewNotFoundError (apierrors .ErrorCodeSSOProviderNotFound , "SSO Identity Provider not found" )
44
+ }
45
+ provider , err = models .FindSSOProviderByID (db , idpID )
34
46
}
35
47
36
- // idpParam is a UUIDv4
37
- provider , err := models .FindSSOProviderByID (db , idpID )
38
48
if err != nil {
39
49
if models .IsNotFoundError (err ) {
40
50
return nil , apierrors .NewNotFoundError (apierrors .ErrorCodeSSOProviderNotFound , "SSO Identity Provider not found" )
@@ -44,17 +54,16 @@ func (a *API) loadSSOProvider(w http.ResponseWriter, r *http.Request) (context.C
44
54
}
45
55
46
56
observability .LogEntrySetField (r , "sso_provider_id" , provider .ID .String ())
47
-
48
57
return withSSOProvider (r .Context (), provider ), nil
49
58
}
50
59
51
- // adminSSOProvidersList lists all SAML SSO Identity Providers in the system. Does
60
+ // adminSSOProvidersList lists all SSO Identity Providers in the system. Does
52
61
// not deal with pagination at this time.
53
62
func (a * API ) adminSSOProvidersList (w http.ResponseWriter , r * http.Request ) error {
54
63
ctx := r .Context ()
55
64
db := a .db .WithContext (ctx )
56
65
57
- providers , err := models .FindAllSAMLProviders (db )
66
+ providers , err := models .FindAllSSOProvidersByFilter (db , r . URL . Query () )
58
67
if err != nil {
59
68
return err
60
69
}
@@ -77,6 +86,9 @@ type CreateSSOProviderParams struct {
77
86
Domains []string `json:"domains"`
78
87
AttributeMapping models.SAMLAttributeMapping `json:"attribute_mapping"`
79
88
NameIDFormat string `json:"name_id_format"`
89
+
90
+ ResourceID * string `json:"resource_id,omitempty"`
91
+ Disabled * bool `json:"disabled,omitempty"`
80
92
}
81
93
82
94
func (p * CreateSSOProviderParams ) validate (forUpdate bool ) error {
@@ -223,17 +235,23 @@ func (a *API) adminSSOProvidersCreate(w http.ResponseWriter, r *http.Request) er
223
235
}
224
236
225
237
provider := & models.SSOProvider {
238
+
226
239
// TODO handle Name, Description, Attribute Mapping
227
240
SAMLProvider : models.SAMLProvider {
228
241
EntityID : metadata .EntityID ,
229
242
MetadataXML : string (rawMetadata ),
230
243
},
231
244
}
232
245
246
+ if params .ResourceID != nil {
247
+ provider .ResourceID = params .ResourceID
248
+ }
249
+ if params .Disabled != nil {
250
+ provider .Disabled = params .Disabled
251
+ }
233
252
if params .MetadataURL != "" {
234
253
provider .SAMLProvider .MetadataURL = & params .MetadataURL
235
254
}
236
-
237
255
if params .NameIDFormat != "" {
238
256
provider .SAMLProvider .NameIDFormat = & params .NameIDFormat
239
257
}
@@ -374,6 +392,28 @@ func (a *API) adminSSOProvidersUpdate(w http.ResponseWriter, r *http.Request) er
374
392
}
375
393
}
376
394
395
+ if params .ResourceID != nil {
396
+ resourceID := * params .ResourceID
397
+ switch {
398
+ case resourceID == "" && provider .ResourceID != nil :
399
+ provider .ResourceID = nil
400
+ modified = true
401
+ case resourceID != "" &&
402
+ (provider .ResourceID == nil ||
403
+ * provider .ResourceID != resourceID ):
404
+ provider .ResourceID = & resourceID
405
+ modified = true
406
+ }
407
+ }
408
+
409
+ if params .Disabled != nil {
410
+ disabled := * params .Disabled
411
+ if provider .Disabled == nil || * provider .Disabled != disabled {
412
+ provider .Disabled = & disabled
413
+ modified = true
414
+ }
415
+ }
416
+
377
417
if modified {
378
418
if err := db .Transaction (func (tx * storage.Connection ) error {
379
419
if terr := tx .Eager ().Update (provider ); terr != nil {
0 commit comments