Skip to content

Commit 0e099bd

Browse files
committed
Separate authentication from authorization in IDP providers
Providers now only fetch identity — access control (domain/org checks) is centralized in a single validateAccess method on the auth handler. This means IDP errors (unreachable provider) produce ErrServerError while policy rejections produce ErrAccessDenied, and adding new access rules no longer requires touching every provider. GitHub always fetches orgs now (scope was already requested unconditionally). UserInfo struct renamed to Identity to avoid collision with the method name. ParseClientRequest relocated to oauth package as ParseClientRegistration. Deleted deprecated ProtectedResourceMetadata and inlined the workaround logic into the handler that used it.
1 parent 9ccd1ff commit 0e099bd

20 files changed

+279
-438
lines changed

internal/idp/azure.go

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import "fmt"
44

55
// NewAzureProvider creates an Azure AD provider using OIDC discovery.
66
// Azure AD is OIDC-compliant, so we use the generic OIDC provider with Azure's tenant-specific discovery URL.
7-
func NewAzureProvider(tenantID, clientID, clientSecret, redirectURI string, allowedDomains []string) (*OIDCProvider, error) {
7+
func NewAzureProvider(tenantID, clientID, clientSecret, redirectURI string) (*OIDCProvider, error) {
88
if tenantID == "" {
99
return nil, fmt.Errorf("tenantId is required for Azure AD")
1010
}
@@ -15,12 +15,11 @@ func NewAzureProvider(tenantID, clientID, clientSecret, redirectURI string, allo
1515
)
1616

1717
return NewOIDCProvider(OIDCConfig{
18-
ProviderType: "azure",
19-
DiscoveryURL: discoveryURL,
20-
ClientID: clientID,
21-
ClientSecret: clientSecret,
22-
RedirectURI: redirectURI,
23-
Scopes: []string{"openid", "email", "profile"},
24-
AllowedDomains: allowedDomains,
18+
ProviderType: "azure",
19+
DiscoveryURL: discoveryURL,
20+
ClientID: clientID,
21+
ClientSecret: clientSecret,
22+
RedirectURI: redirectURI,
23+
Scopes: []string{"openid", "email", "profile"},
2524
})
2625
}

internal/idp/azure_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import (
88
)
99

1010
func TestNewAzureProvider_MissingTenantID(t *testing.T) {
11-
_, err := NewAzureProvider("", "client-id", "client-secret", "https://example.com/callback", nil)
11+
_, err := NewAzureProvider("", "client-id", "client-secret", "https://example.com/callback")
1212

1313
require.Error(t, err)
1414
assert.Contains(t, err.Error(), "tenantId is required")

internal/idp/factory.go

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,13 @@ import (
77
)
88

99
// NewProvider creates a Provider based on the IDPConfig.
10-
// allowedDomains configures domain-based access control for all provider types.
11-
func NewProvider(cfg config.IDPConfig, allowedDomains []string) (Provider, error) {
10+
func NewProvider(cfg config.IDPConfig) (Provider, error) {
1211
switch cfg.Provider {
1312
case "google":
1413
return NewGoogleProvider(
1514
cfg.ClientID,
1615
string(cfg.ClientSecret),
1716
cfg.RedirectURI,
18-
allowedDomains,
1917
), nil
2018

2119
case "azure":
@@ -24,16 +22,13 @@ func NewProvider(cfg config.IDPConfig, allowedDomains []string) (Provider, error
2422
cfg.ClientID,
2523
string(cfg.ClientSecret),
2624
cfg.RedirectURI,
27-
allowedDomains,
2825
)
2926

3027
case "github":
3128
return NewGitHubProvider(
3229
cfg.ClientID,
3330
string(cfg.ClientSecret),
3431
cfg.RedirectURI,
35-
allowedDomains,
36-
cfg.AllowedOrgs,
3732
), nil
3833

3934
case "oidc":
@@ -47,7 +42,6 @@ func NewProvider(cfg config.IDPConfig, allowedDomains []string) (Provider, error
4742
ClientSecret: string(cfg.ClientSecret),
4843
RedirectURI: cfg.RedirectURI,
4944
Scopes: cfg.Scopes,
50-
AllowedDomains: allowedDomains,
5145
})
5246

5347
default:

internal/idp/factory_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ func TestNewProvider(t *testing.T) {
8787

8888
for _, tt := range tests {
8989
t.Run(tt.name, func(t *testing.T) {
90-
provider, err := NewProvider(tt.cfg, nil)
90+
provider, err := NewProvider(tt.cfg)
9191

9292
if tt.wantErr {
9393
require.Error(t, err)

internal/idp/github.go

Lines changed: 11 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,8 @@ import (
1414
// GitHubProvider implements the Provider interface for GitHub OAuth.
1515
// GitHub uses OAuth 2.0 (not OIDC) and has its own API for user info and org membership.
1616
type GitHubProvider struct {
17-
config oauth2.Config
18-
apiBaseURL string // defaults to https://api.github.com, can be overridden for testing
19-
allowedDomains []string // email domains users must belong to (empty = no restriction)
20-
allowedOrgs []string // organizations users must be members of (empty = no restriction)
17+
config oauth2.Config
18+
apiBaseURL string // defaults to https://api.github.com, can be overridden for testing
2119
}
2220

2321
// githubUserResponse represents GitHub's user API response.
@@ -42,7 +40,7 @@ type githubOrgResponse struct {
4240
}
4341

4442
// NewGitHubProvider creates a new GitHub OAuth provider.
45-
func NewGitHubProvider(clientID, clientSecret, redirectURI string, allowedDomains, allowedOrgs []string) *GitHubProvider {
43+
func NewGitHubProvider(clientID, clientSecret, redirectURI string) *GitHubProvider {
4644
return &GitHubProvider{
4745
config: oauth2.Config{
4846
ClientID: clientID,
@@ -51,9 +49,7 @@ func NewGitHubProvider(clientID, clientSecret, redirectURI string, allowedDomain
5149
Scopes: []string{"user:email", "read:org"},
5250
Endpoint: github.Endpoint,
5351
},
54-
apiBaseURL: "https://api.github.com",
55-
allowedDomains: allowedDomains,
56-
allowedOrgs: allowedOrgs,
52+
apiBaseURL: "https://api.github.com",
5753
}
5854
}
5955

@@ -72,13 +68,11 @@ func (p *GitHubProvider) ExchangeCode(ctx context.Context, code string) (*oauth2
7268
return p.config.Exchange(ctx, code)
7369
}
7470

75-
// UserInfo fetches user information from GitHub's API.
76-
// Validates domain and organization membership based on construction-time config.
77-
// TODO: Consider caching org membership to reduce API calls.
78-
func (p *GitHubProvider) UserInfo(ctx context.Context, token *oauth2.Token) (*UserInfo, error) {
71+
// UserInfo fetches user identity from GitHub's API.
72+
// Always fetches organizations so the authorization layer can check membership.
73+
func (p *GitHubProvider) UserInfo(ctx context.Context, token *oauth2.Token) (*Identity, error) {
7974
client := p.config.Client(ctx, token)
8075

81-
// Fetch user profile
8276
user, err := p.fetchUser(client)
8377
if err != nil {
8478
return nil, err
@@ -99,38 +93,12 @@ func (p *GitHubProvider) UserInfo(ctx context.Context, token *oauth2.Token) (*Us
9993

10094
domain := emailutil.ExtractDomain(email)
10195

102-
// Validate domain if configured
103-
if err := ValidateDomain(domain, p.allowedDomains); err != nil {
104-
return nil, err
105-
}
106-
107-
// Fetch organizations only if org validation is configured
108-
var orgs []string
109-
if len(p.allowedOrgs) > 0 {
110-
orgs, err = p.fetchOrganizations(client)
111-
if err != nil {
112-
return nil, fmt.Errorf("failed to get user organizations: %w", err)
113-
}
114-
115-
// Validate org membership
116-
hasAllowedOrg := false
117-
for _, org := range orgs {
118-
for _, allowed := range p.allowedOrgs {
119-
if org == allowed {
120-
hasAllowedOrg = true
121-
break
122-
}
123-
}
124-
if hasAllowedOrg {
125-
break
126-
}
127-
}
128-
if !hasAllowedOrg {
129-
return nil, fmt.Errorf("user is not a member of any allowed organization. Contact your administrator")
130-
}
96+
orgs, err := p.fetchOrganizations(client)
97+
if err != nil {
98+
return nil, fmt.Errorf("failed to get user organizations: %w", err)
13199
}
132100

133-
return &UserInfo{
101+
return &Identity{
134102
ProviderType: "github",
135103
Subject: fmt.Sprintf("%d", user.ID),
136104
Email: email,

internal/idp/github_test.go

Lines changed: 20 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@ import (
1313
)
1414

1515
func TestGitHubProvider_Type(t *testing.T) {
16-
provider := NewGitHubProvider("client-id", "client-secret", "https://example.com/callback", nil, nil)
16+
provider := NewGitHubProvider("client-id", "client-secret", "https://example.com/callback")
1717
assert.Equal(t, "github", provider.Type())
1818
}
1919

2020
func TestGitHubProvider_AuthURL(t *testing.T) {
21-
provider := NewGitHubProvider("client-id", "client-secret", "https://example.com/callback", nil, nil)
21+
provider := NewGitHubProvider("client-id", "client-secret", "https://example.com/callback")
2222

2323
authURL := provider.AuthURL("test-state")
2424

@@ -33,10 +33,6 @@ func TestGitHubProvider_UserInfo(t *testing.T) {
3333
userResp githubUserResponse
3434
emailsResp []githubEmailResponse
3535
orgsResp []githubOrgResponse
36-
allowedDomains []string
37-
allowedOrgs []string
38-
wantErr bool
39-
errContains string
4036
expectedEmail string
4137
expectedEmailVerified bool
4238
expectedDomain string
@@ -51,10 +47,11 @@ func TestGitHubProvider_UserInfo(t *testing.T) {
5147
Name: "Test User",
5248
AvatarURL: "https://github.com/avatar.jpg",
5349
},
50+
orgsResp: []githubOrgResponse{{Login: "my-org"}},
5451
expectedEmail: "user@company.com",
55-
expectedEmailVerified: true, // Public emails in GitHub profile are verified
52+
expectedEmailVerified: true,
5653
expectedDomain: "company.com",
57-
expectedOrgs: nil, // Orgs not fetched when allowedOrgs is empty
54+
expectedOrgs: []string{"my-org"},
5855
},
5956
{
6057
name: "user_without_public_email_fetches_from_api",
@@ -67,10 +64,11 @@ func TestGitHubProvider_UserInfo(t *testing.T) {
6764
{Email: "secondary@other.com", Primary: false, Verified: true},
6865
{Email: "primary@company.com", Primary: true, Verified: true},
6966
},
67+
orgsResp: []githubOrgResponse{},
7068
expectedEmail: "primary@company.com",
7169
expectedEmailVerified: true,
7270
expectedDomain: "company.com",
73-
expectedOrgs: nil, // Orgs not fetched when allowedOrgs is empty
71+
expectedOrgs: []string{},
7472
},
7573
{
7674
name: "user_with_unverified_primary_falls_back_to_verified",
@@ -82,72 +80,24 @@ func TestGitHubProvider_UserInfo(t *testing.T) {
8280
{Email: "primary@company.com", Primary: true, Verified: false},
8381
{Email: "verified@company.com", Primary: false, Verified: true},
8482
},
83+
orgsResp: []githubOrgResponse{},
8584
expectedEmail: "verified@company.com",
8685
expectedEmailVerified: true,
8786
expectedDomain: "company.com",
88-
expectedOrgs: nil, // Orgs not fetched when allowedOrgs is empty
89-
},
90-
{
91-
name: "domain_validation_success",
92-
userResp: githubUserResponse{
93-
ID: 12345,
94-
Login: "testuser",
95-
Email: "user@company.com",
96-
},
97-
allowedDomains: []string{"company.com"},
98-
expectedEmail: "user@company.com",
99-
expectedEmailVerified: true,
100-
expectedDomain: "company.com",
101-
expectedOrgs: nil, // Orgs not fetched when allowedOrgs is empty
87+
expectedOrgs: []string{},
10288
},
10389
{
104-
name: "domain_validation_failure",
105-
userResp: githubUserResponse{
106-
ID: 12345,
107-
Login: "testuser",
108-
Email: "user@other.com",
109-
},
110-
allowedDomains: []string{"company.com"},
111-
wantErr: true,
112-
errContains: "domain 'other.com' is not allowed",
113-
},
114-
{
115-
name: "org_validation_success",
90+
name: "orgs_always_populated",
11691
userResp: githubUserResponse{
11792
ID: 12345,
11893
Login: "testuser",
11994
Email: "user@gmail.com",
12095
},
121-
orgsResp: []githubOrgResponse{{Login: "allowed-org"}, {Login: "other-org"}},
122-
allowedOrgs: []string{"allowed-org"},
96+
orgsResp: []githubOrgResponse{{Login: "org-a"}, {Login: "org-b"}},
12397
expectedEmail: "user@gmail.com",
12498
expectedEmailVerified: true,
12599
expectedDomain: "gmail.com",
126-
expectedOrgs: []string{"allowed-org", "other-org"},
127-
},
128-
{
129-
name: "org_validation_failure",
130-
userResp: githubUserResponse{
131-
ID: 12345,
132-
Login: "testuser",
133-
Email: "user@gmail.com",
134-
},
135-
orgsResp: []githubOrgResponse{{Login: "other-org"}},
136-
allowedOrgs: []string{"required-org"},
137-
wantErr: true,
138-
errContains: "not a member of any allowed organization",
139-
},
140-
{
141-
name: "user_with_no_orgs_restriction",
142-
userResp: githubUserResponse{
143-
ID: 12345,
144-
Login: "testuser",
145-
Email: "user@gmail.com",
146-
},
147-
expectedEmail: "user@gmail.com",
148-
expectedEmailVerified: true,
149-
expectedDomain: "gmail.com",
150-
expectedOrgs: nil, // Orgs not fetched when allowedOrgs is empty
100+
expectedOrgs: []string{"org-a", "org-b"},
151101
},
152102
}
153103

@@ -172,7 +122,6 @@ func TestGitHubProvider_UserInfo(t *testing.T) {
172122
}))
173123
defer server.Close()
174124

175-
// Create provider with test server endpoints and allowedOrgs
176125
provider := &GitHubProvider{
177126
config: oauth2.Config{
178127
ClientID: "test-client",
@@ -184,29 +133,19 @@ func TestGitHubProvider_UserInfo(t *testing.T) {
184133
TokenURL: server.URL + "/token",
185134
},
186135
},
187-
apiBaseURL: server.URL,
188-
allowedDomains: tt.allowedDomains,
189-
allowedOrgs: tt.allowedOrgs,
136+
apiBaseURL: server.URL,
190137
}
191138

192139
token := &oauth2.Token{AccessToken: "test-token"}
193-
userInfo, err := provider.UserInfo(context.Background(), token)
194-
195-
if tt.wantErr {
196-
require.Error(t, err)
197-
if tt.errContains != "" {
198-
assert.Contains(t, err.Error(), tt.errContains)
199-
}
200-
return
201-
}
140+
identity, err := provider.UserInfo(context.Background(), token)
202141

203142
require.NoError(t, err)
204-
require.NotNil(t, userInfo)
205-
assert.Equal(t, "github", userInfo.ProviderType)
206-
assert.Equal(t, tt.expectedEmail, userInfo.Email)
207-
assert.Equal(t, tt.expectedEmailVerified, userInfo.EmailVerified)
208-
assert.Equal(t, tt.expectedDomain, userInfo.Domain)
209-
assert.Equal(t, tt.expectedOrgs, userInfo.Organizations)
143+
require.NotNil(t, identity)
144+
assert.Equal(t, "github", identity.ProviderType)
145+
assert.Equal(t, tt.expectedEmail, identity.Email)
146+
assert.Equal(t, tt.expectedEmailVerified, identity.EmailVerified)
147+
assert.Equal(t, tt.expectedDomain, identity.Domain)
148+
assert.Equal(t, tt.expectedOrgs, identity.Organizations)
210149
})
211150
}
212151
}

0 commit comments

Comments
 (0)