Skip to content

Commit f092c4d

Browse files
committed
Move allowedDomains to construction time for consistency
Previously allowedDomains was passed at call time to UserInfo(), while allowedOrgs was configured at construction. This inconsistency made the interface harder to understand and didn't follow the principle that access control should be configured when the provider is created. Changes: - Provider.UserInfo() now takes only context and token - All providers store allowedDomains internally - Factory accepts allowedDomains parameter - All tests updated to reflect new interface
1 parent c580b4e commit f092c4d

File tree

14 files changed

+69
-52
lines changed

14 files changed

+69
-52
lines changed

internal/idp/azure.go

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import "fmt"
44

55
// NewAzureProvider creates an Azure AD provider using OIDC discovery.
66
// Azure AD is OIDC-compliant, so we use the generic OIDC provider with Azure's tenant-specific discovery URL.
7-
func NewAzureProvider(tenantID, clientID, clientSecret, redirectURI string) (*OIDCProvider, error) {
7+
func NewAzureProvider(tenantID, clientID, clientSecret, redirectURI string, allowedDomains []string) (*OIDCProvider, error) {
88
if tenantID == "" {
99
return nil, fmt.Errorf("tenantId is required for Azure AD")
1010
}
@@ -15,11 +15,12 @@ func NewAzureProvider(tenantID, clientID, clientSecret, redirectURI string) (*OI
1515
)
1616

1717
return NewOIDCProvider(OIDCConfig{
18-
ProviderType: "azure",
19-
DiscoveryURL: discoveryURL,
20-
ClientID: clientID,
21-
ClientSecret: clientSecret,
22-
RedirectURI: redirectURI,
23-
Scopes: []string{"openid", "email", "profile"},
18+
ProviderType: "azure",
19+
DiscoveryURL: discoveryURL,
20+
ClientID: clientID,
21+
ClientSecret: clientSecret,
22+
RedirectURI: redirectURI,
23+
Scopes: []string{"openid", "email", "profile"},
24+
AllowedDomains: allowedDomains,
2425
})
2526
}

internal/idp/azure_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import (
88
)
99

1010
func TestNewAzureProvider_MissingTenantID(t *testing.T) {
11-
_, err := NewAzureProvider("", "client-id", "client-secret", "https://example.com/callback")
11+
_, err := NewAzureProvider("", "client-id", "client-secret", "https://example.com/callback", nil)
1212

1313
require.Error(t, err)
1414
assert.Contains(t, err.Error(), "tenantId is required")

internal/idp/factory.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,15 @@ import (
77
)
88

99
// NewProvider creates a Provider based on the IDPConfig.
10-
func NewProvider(cfg config.IDPConfig) (Provider, error) {
10+
// allowedDomains configures domain-based access control for all provider types.
11+
func NewProvider(cfg config.IDPConfig, allowedDomains []string) (Provider, error) {
1112
switch cfg.Provider {
1213
case "google":
1314
return NewGoogleProvider(
1415
cfg.ClientID,
1516
string(cfg.ClientSecret),
1617
cfg.RedirectURI,
18+
allowedDomains,
1719
), nil
1820

1921
case "azure":
@@ -22,13 +24,15 @@ func NewProvider(cfg config.IDPConfig) (Provider, error) {
2224
cfg.ClientID,
2325
string(cfg.ClientSecret),
2426
cfg.RedirectURI,
27+
allowedDomains,
2528
)
2629

2730
case "github":
2831
return NewGitHubProvider(
2932
cfg.ClientID,
3033
string(cfg.ClientSecret),
3134
cfg.RedirectURI,
35+
allowedDomains,
3236
cfg.AllowedOrgs,
3337
), nil
3438

@@ -43,6 +47,7 @@ func NewProvider(cfg config.IDPConfig) (Provider, error) {
4347
ClientSecret: string(cfg.ClientSecret),
4448
RedirectURI: cfg.RedirectURI,
4549
Scopes: cfg.Scopes,
50+
AllowedDomains: allowedDomains,
4651
})
4752

4853
default:

internal/idp/factory_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ func TestNewProvider(t *testing.T) {
8787

8888
for _, tt := range tests {
8989
t.Run(tt.name, func(t *testing.T) {
90-
provider, err := NewProvider(tt.cfg)
90+
provider, err := NewProvider(tt.cfg, nil)
9191

9292
if tt.wantErr {
9393
require.Error(t, err)

internal/idp/github.go

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,10 @@ import (
1414
// GitHubProvider implements the Provider interface for GitHub OAuth.
1515
// GitHub uses OAuth 2.0 (not OIDC) and has its own API for user info and org membership.
1616
type GitHubProvider struct {
17-
config oauth2.Config
18-
apiBaseURL string // defaults to https://api.github.com, can be overridden for testing
19-
allowedOrgs []string // organizations users must be members of (empty = no restriction)
17+
config oauth2.Config
18+
apiBaseURL string // defaults to https://api.github.com, can be overridden for testing
19+
allowedDomains []string // email domains users must belong to (empty = no restriction)
20+
allowedOrgs []string // organizations users must be members of (empty = no restriction)
2021
}
2122

2223
// githubUserResponse represents GitHub's user API response.
@@ -41,8 +42,7 @@ type githubOrgResponse struct {
4142
}
4243

4344
// NewGitHubProvider creates a new GitHub OAuth provider.
44-
// allowedOrgs specifies organizations users must be members of (empty = no restriction).
45-
func NewGitHubProvider(clientID, clientSecret, redirectURI string, allowedOrgs []string) *GitHubProvider {
45+
func NewGitHubProvider(clientID, clientSecret, redirectURI string, allowedDomains, allowedOrgs []string) *GitHubProvider {
4646
return &GitHubProvider{
4747
config: oauth2.Config{
4848
ClientID: clientID,
@@ -51,8 +51,9 @@ func NewGitHubProvider(clientID, clientSecret, redirectURI string, allowedOrgs [
5151
Scopes: []string{"user:email", "read:org"},
5252
Endpoint: github.Endpoint,
5353
},
54-
apiBaseURL: "https://api.github.com",
55-
allowedOrgs: allowedOrgs,
54+
apiBaseURL: "https://api.github.com",
55+
allowedDomains: allowedDomains,
56+
allowedOrgs: allowedOrgs,
5657
}
5758
}
5859

@@ -72,9 +73,9 @@ func (p *GitHubProvider) ExchangeCode(ctx context.Context, code string) (*oauth2
7273
}
7374

7475
// UserInfo fetches user information from GitHub's API.
75-
// Validates organization membership if allowedOrgs was configured at construction.
76+
// Validates domain and organization membership based on construction-time config.
7677
// TODO: Consider caching org membership to reduce API calls.
77-
func (p *GitHubProvider) UserInfo(ctx context.Context, token *oauth2.Token, allowedDomains []string) (*UserInfo, error) {
78+
func (p *GitHubProvider) UserInfo(ctx context.Context, token *oauth2.Token) (*UserInfo, error) {
7879
client := p.config.Client(ctx, token)
7980

8081
// Fetch user profile
@@ -99,7 +100,7 @@ func (p *GitHubProvider) UserInfo(ctx context.Context, token *oauth2.Token, allo
99100
domain := emailutil.ExtractDomain(email)
100101

101102
// Validate domain if configured
102-
if err := ValidateDomain(domain, allowedDomains); err != nil {
103+
if err := ValidateDomain(domain, p.allowedDomains); err != nil {
103104
return nil, err
104105
}
105106

internal/idp/github_test.go

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@ import (
1313
)
1414

1515
func TestGitHubProvider_Type(t *testing.T) {
16-
provider := NewGitHubProvider("client-id", "client-secret", "https://example.com/callback", nil)
16+
provider := NewGitHubProvider("client-id", "client-secret", "https://example.com/callback", nil, nil)
1717
assert.Equal(t, "github", provider.Type())
1818
}
1919

2020
func TestGitHubProvider_AuthURL(t *testing.T) {
21-
provider := NewGitHubProvider("client-id", "client-secret", "https://example.com/callback", nil)
21+
provider := NewGitHubProvider("client-id", "client-secret", "https://example.com/callback", nil, nil)
2222

2323
authURL := provider.AuthURL("test-state")
2424

@@ -184,12 +184,13 @@ func TestGitHubProvider_UserInfo(t *testing.T) {
184184
TokenURL: server.URL + "/token",
185185
},
186186
},
187-
apiBaseURL: server.URL,
188-
allowedOrgs: tt.allowedOrgs,
187+
apiBaseURL: server.URL,
188+
allowedDomains: tt.allowedDomains,
189+
allowedOrgs: tt.allowedOrgs,
189190
}
190191

191192
token := &oauth2.Token{AccessToken: "test-token"}
192-
userInfo, err := provider.UserInfo(context.Background(), token, tt.allowedDomains)
193+
userInfo, err := provider.UserInfo(context.Background(), token)
193194

194195
if tt.wantErr {
195196
require.Error(t, err)
@@ -244,7 +245,7 @@ func TestGitHubProvider_UserInfo_APIErrors(t *testing.T) {
244245
}
245246

246247
token := &oauth2.Token{AccessToken: "test-token"}
247-
_, err := provider.UserInfo(context.Background(), token, nil)
248+
_, err := provider.UserInfo(context.Background(), token)
248249

249250
require.Error(t, err)
250251
assert.Contains(t, err.Error(), tt.errContains)
@@ -277,7 +278,7 @@ func TestGitHubProvider_UserInfo_NoVerifiedEmail(t *testing.T) {
277278
}
278279

279280
token := &oauth2.Token{AccessToken: "test-token"}
280-
_, err := provider.UserInfo(context.Background(), token, nil)
281+
_, err := provider.UserInfo(context.Background(), token)
281282

282283
require.Error(t, err)
283284
assert.Contains(t, err.Error(), "no verified email")

internal/idp/google.go

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@ import (
1414
// GoogleProvider implements the Provider interface for Google OAuth.
1515
// Google has specific quirks like `hd` for hosted domain and `verified_email` field.
1616
type GoogleProvider struct {
17-
config oauth2.Config
18-
userInfoURL string
17+
config oauth2.Config
18+
userInfoURL string
19+
allowedDomains []string
1920
}
2021

2122
// googleUserInfoResponse represents Google's userinfo response.
@@ -30,7 +31,7 @@ type googleUserInfoResponse struct {
3031
}
3132

3233
// NewGoogleProvider creates a new Google OAuth provider.
33-
func NewGoogleProvider(clientID, clientSecret, redirectURI string) *GoogleProvider {
34+
func NewGoogleProvider(clientID, clientSecret, redirectURI string, allowedDomains []string) *GoogleProvider {
3435
return &GoogleProvider{
3536
config: oauth2.Config{
3637
ClientID: clientID,
@@ -39,7 +40,8 @@ func NewGoogleProvider(clientID, clientSecret, redirectURI string) *GoogleProvid
3940
Scopes: []string{"openid", "profile", "email"},
4041
Endpoint: google.Endpoint,
4142
},
42-
userInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo",
43+
userInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo",
44+
allowedDomains: allowedDomains,
4345
}
4446
}
4547

@@ -62,7 +64,7 @@ func (p *GoogleProvider) ExchangeCode(ctx context.Context, code string) (*oauth2
6264
}
6365

6466
// UserInfo fetches user information from Google's userinfo endpoint.
65-
func (p *GoogleProvider) UserInfo(ctx context.Context, token *oauth2.Token, allowedDomains []string) (*UserInfo, error) {
67+
func (p *GoogleProvider) UserInfo(ctx context.Context, token *oauth2.Token) (*UserInfo, error) {
6668
client := p.config.Client(ctx, token)
6769

6870
resp, err := client.Get(p.userInfoURL)
@@ -87,7 +89,7 @@ func (p *GoogleProvider) UserInfo(ctx context.Context, token *oauth2.Token, allo
8789
}
8890

8991
// Validate domain if configured
90-
if err := ValidateDomain(domain, allowedDomains); err != nil {
92+
if err := ValidateDomain(domain, p.allowedDomains); err != nil {
9193
return nil, err
9294
}
9395

internal/idp/google_test.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@ import (
1313
)
1414

1515
func TestGoogleProvider_Type(t *testing.T) {
16-
provider := NewGoogleProvider("client-id", "client-secret", "https://example.com/callback")
16+
provider := NewGoogleProvider("client-id", "client-secret", "https://example.com/callback", nil)
1717
assert.Equal(t, "google", provider.Type())
1818
}
1919

2020
func TestGoogleProvider_AuthURL(t *testing.T) {
21-
provider := NewGoogleProvider("client-id", "client-secret", "https://example.com/callback")
21+
provider := NewGoogleProvider("client-id", "client-secret", "https://example.com/callback", nil)
2222

2323
authURL := provider.AuthURL("test-state")
2424

@@ -102,11 +102,12 @@ func TestGoogleProvider_UserInfo(t *testing.T) {
102102
TokenURL: server.URL + "/token",
103103
},
104104
},
105-
userInfoURL: server.URL,
105+
userInfoURL: server.URL,
106+
allowedDomains: tt.allowedDomains,
106107
}
107108
token := &oauth2.Token{AccessToken: "test-token"}
108109

109-
userInfo, err := provider.UserInfo(context.Background(), token, tt.allowedDomains)
110+
userInfo, err := provider.UserInfo(context.Background(), token)
110111

111112
if tt.wantErr {
112113
require.Error(t, err)
@@ -142,7 +143,7 @@ func TestGoogleProvider_UserInfo_ServerError(t *testing.T) {
142143
}
143144
token := &oauth2.Token{AccessToken: "test-token"}
144145

145-
_, err := provider.UserInfo(context.Background(), token, nil)
146+
_, err := provider.UserInfo(context.Background(), token)
146147

147148
require.Error(t, err)
148149
assert.Contains(t, err.Error(), "status 500")

internal/idp/oidc.go

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,17 @@ type OIDCConfig struct {
2929
ClientSecret string
3030
RedirectURI string
3131
Scopes []string
32+
33+
// Access control.
34+
AllowedDomains []string
3235
}
3336

3437
// OIDCProvider implements the Provider interface for OIDC-compliant identity providers.
3538
type OIDCProvider struct {
36-
providerType string
37-
config oauth2.Config
38-
userInfoURL string
39+
providerType string
40+
config oauth2.Config
41+
userInfoURL string
42+
allowedDomains []string
3943
}
4044

4145
// oidcDiscoveryDocument represents the OIDC discovery document.
@@ -99,7 +103,8 @@ func NewOIDCProvider(cfg OIDCConfig) (*OIDCProvider, error) {
99103
TokenURL: tokenURL,
100104
},
101105
},
102-
userInfoURL: userInfoURL,
106+
userInfoURL: userInfoURL,
107+
allowedDomains: cfg.AllowedDomains,
103108
}, nil
104109
}
105110

@@ -147,7 +152,7 @@ func (p *OIDCProvider) ExchangeCode(ctx context.Context, code string) (*oauth2.T
147152

148153
// UserInfo fetches user information from the OIDC userinfo endpoint.
149154
// TODO: Add ID token validation as optimization (avoids network call).
150-
func (p *OIDCProvider) UserInfo(ctx context.Context, token *oauth2.Token, allowedDomains []string) (*UserInfo, error) {
155+
func (p *OIDCProvider) UserInfo(ctx context.Context, token *oauth2.Token) (*UserInfo, error) {
151156
client := p.config.Client(ctx, token)
152157
resp, err := client.Get(p.userInfoURL)
153158
if err != nil {
@@ -167,7 +172,7 @@ func (p *OIDCProvider) UserInfo(ctx context.Context, token *oauth2.Token, allowe
167172
domain := emailutil.ExtractDomain(userInfoResp.Email)
168173

169174
// Validate domain if configured
170-
if err := ValidateDomain(domain, allowedDomains); err != nil {
175+
if err := ValidateDomain(domain, p.allowedDomains); err != nil {
171176
return nil, err
172177
}
173178

internal/idp/oidc_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ func TestOIDCProvider_UserInfo(t *testing.T) {
145145
require.NoError(t, err)
146146

147147
token := &oauth2.Token{AccessToken: "test-token"}
148-
userInfo, err := provider.UserInfo(context.Background(), token, nil)
148+
userInfo, err := provider.UserInfo(context.Background(), token)
149149

150150
require.NoError(t, err)
151151
require.NotNil(t, userInfo)
@@ -175,11 +175,12 @@ func TestOIDCProvider_UserInfo_DomainValidation(t *testing.T) {
175175
ClientID: "client-id",
176176
ClientSecret: "client-secret",
177177
RedirectURI: "https://example.com/callback",
178+
AllowedDomains: []string{"example.com"},
178179
})
179180
require.NoError(t, err)
180181

181182
token := &oauth2.Token{AccessToken: "test-token"}
182-
_, err = provider.UserInfo(context.Background(), token, []string{"example.com"})
183+
_, err = provider.UserInfo(context.Background(), token)
183184

184185
require.Error(t, err)
185186
assert.Contains(t, err.Error(), "domain 'other.com' is not allowed")

0 commit comments

Comments
 (0)