Skip to content

Commit 42bc9ab

Browse files
authored
fix(oauth2): switch to Origin header for request validation (#2174)
## Summary Replace `Referer` header validation with `Origin` header in `validateRequestOrigin` function. Referer was unintended at the beginning. Using `Origin` header as it's automatically set by browsers for cross-origin requests. This is a very basic attempt to prevent cross domain malicious calls. Key changes: - Check `Origin` header instead of `Referer` header - Allow empty `Origin` header (for backend/mobile app initiated requests) - Add comprehensive test coverage for new validation behavior
1 parent 060a992 commit 42bc9ab

File tree

2 files changed

+194
-5
lines changed

2 files changed

+194
-5
lines changed

internal/api/oauthserver/authorize.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -369,13 +369,15 @@ func (s *Server) OAuthServerConsent(w http.ResponseWriter, r *http.Request) erro
369369

370370
// validateRequestOrigin checks if the request is coming from an authorized origin
371371
func (s *Server) validateRequestOrigin(r *http.Request) error {
372-
// Check referer header
373-
referer := r.Referer()
374-
if referer == "" {
375-
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "request must originate from authorized domain")
372+
// Check Origin header
373+
// browsers add this header by default, we can at least prevent some basic cross-origin attacks
374+
origin := r.Header.Get("Origin")
375+
if origin == "" {
376+
// Empty Origin header is ok (e.g., for backend-originated requests or mobile apps)
377+
return nil
376378
}
377379

378-
if !utilities.IsRedirectURLValid(s.config, referer) {
380+
if !utilities.IsRedirectURLValid(s.config, origin) {
379381
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "unauthorized request origin")
380382
}
381383

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
package oauthserver
2+
3+
import (
4+
"net/http"
5+
"net/http/httptest"
6+
"testing"
7+
8+
"github.com/gobwas/glob"
9+
"github.com/stretchr/testify/assert"
10+
"github.com/stretchr/testify/require"
11+
"github.com/supabase/auth/internal/conf"
12+
"github.com/supabase/auth/internal/hooks/v0hooks"
13+
"github.com/supabase/auth/internal/storage/test"
14+
"github.com/supabase/auth/internal/tokens"
15+
)
16+
17+
func TestValidateRequestOrigin(t *testing.T) {
18+
// Setup test configuration
19+
globalConfig, err := conf.LoadGlobal(oauthServerTestConfig)
20+
require.NoError(t, err)
21+
22+
// Set up test site URL for validation
23+
globalConfig.SiteURL = "https://example.com"
24+
globalConfig.URIAllowList = []string{
25+
"http://localhost:3000",
26+
"https://app.example.com",
27+
}
28+
29+
// Set up URIAllowListMap manually for testing
30+
globalConfig.URIAllowListMap = make(map[string]glob.Glob)
31+
for _, uri := range globalConfig.URIAllowList {
32+
g := glob.MustCompile(uri, '.', '/')
33+
globalConfig.URIAllowListMap[uri] = g
34+
}
35+
36+
conn, err := test.SetupDBConnection(globalConfig)
37+
require.NoError(t, err)
38+
defer conn.Close()
39+
40+
hooksMgr := &v0hooks.Manager{}
41+
tokenService := tokens.NewService(globalConfig, hooksMgr)
42+
server := NewServer(globalConfig, conn, tokenService)
43+
44+
tests := []struct {
45+
name string
46+
originHeader string
47+
expectError bool
48+
errorMessage string
49+
}{
50+
{
51+
name: "Empty Origin header should be allowed",
52+
originHeader: "",
53+
expectError: false,
54+
},
55+
{
56+
name: "Valid Origin matching site URL should be allowed",
57+
originHeader: "https://example.com",
58+
expectError: false,
59+
},
60+
{
61+
name: "Valid Origin with different path should be allowed",
62+
originHeader: "https://example.com/some/path",
63+
expectError: false,
64+
},
65+
{
66+
name: "Valid Origin matching allow list should be allowed",
67+
originHeader: "https://app.example.com",
68+
expectError: false,
69+
},
70+
{
71+
name: "Valid Origin with localhost should be allowed",
72+
originHeader: "http://localhost:3000",
73+
expectError: false,
74+
},
75+
{
76+
name: "Invalid Origin should be rejected",
77+
originHeader: "https://malicious.com",
78+
expectError: true,
79+
errorMessage: "unauthorized request origin",
80+
},
81+
{
82+
name: "Invalid Origin with IP address should be rejected",
83+
originHeader: "https://192.168.1.1",
84+
expectError: true,
85+
errorMessage: "unauthorized request origin",
86+
},
87+
{
88+
name: "Valid loopback IP should be allowed",
89+
originHeader: "http://127.0.0.1:3000",
90+
expectError: false,
91+
},
92+
{
93+
name: "Invalid Origin format should be rejected",
94+
originHeader: "not-a-valid-url",
95+
expectError: true,
96+
errorMessage: "unauthorized request origin",
97+
},
98+
}
99+
100+
for _, tt := range tests {
101+
t.Run(tt.name, func(t *testing.T) {
102+
// Create a test request
103+
req := httptest.NewRequest(http.MethodGet, "/test", nil)
104+
105+
// Set Origin header if provided
106+
if tt.originHeader != "" {
107+
req.Header.Set("Origin", tt.originHeader)
108+
}
109+
110+
// Call validateRequestOrigin
111+
err := server.validateRequestOrigin(req)
112+
113+
if tt.expectError {
114+
assert.Error(t, err)
115+
if tt.errorMessage != "" {
116+
assert.Contains(t, err.Error(), tt.errorMessage)
117+
}
118+
} else {
119+
assert.NoError(t, err)
120+
}
121+
})
122+
}
123+
}
124+
125+
func TestValidateRequestOriginEdgeCases(t *testing.T) {
126+
globalConfig, err := conf.LoadGlobal(oauthServerTestConfig)
127+
require.NoError(t, err)
128+
129+
globalConfig.SiteURL = "https://example.com"
130+
131+
conn, err := test.SetupDBConnection(globalConfig)
132+
require.NoError(t, err)
133+
defer conn.Close()
134+
135+
hooksMgr := &v0hooks.Manager{}
136+
tokenService := tokens.NewService(globalConfig, hooksMgr)
137+
server := NewServer(globalConfig, conn, tokenService)
138+
139+
t.Run("Origin with different port should be allowed (hostname matching)", func(t *testing.T) {
140+
req := httptest.NewRequest(http.MethodGet, "/test", nil)
141+
req.Header.Set("Origin", "https://example.com:8080")
142+
143+
// Should pass because hostname matches (IsRedirectURLValid allows different ports)
144+
err := server.validateRequestOrigin(req)
145+
assert.NoError(t, err)
146+
})
147+
148+
t.Run("Case sensitivity in Origin header", func(t *testing.T) {
149+
req := httptest.NewRequest(http.MethodGet, "/test", nil)
150+
req.Header.Set("Origin", "https://EXAMPLE.COM")
151+
152+
// Should fail because hostname comparison is case-sensitive in URL parsing
153+
err := server.validateRequestOrigin(req)
154+
assert.Error(t, err)
155+
})
156+
157+
t.Run("Origin with trailing slash should be handled", func(t *testing.T) {
158+
req := httptest.NewRequest(http.MethodGet, "/test", nil)
159+
req.Header.Set("Origin", "https://example.com/")
160+
161+
// Should pass - URL parsing should handle trailing slash correctly
162+
err := server.validateRequestOrigin(req)
163+
assert.NoError(t, err)
164+
})
165+
166+
t.Run("Multiple Origin headers uses first one", func(t *testing.T) {
167+
req := httptest.NewRequest(http.MethodGet, "/test", nil)
168+
// Add multiple Origin headers
169+
req.Header.Add("Origin", "https://example.com") // First header (valid)
170+
req.Header.Add("Origin", "https://malicious.com") // Second header (invalid)
171+
172+
// Go's http.Header.Get() returns only the first header value
173+
// So this should pass because first Origin is valid
174+
err := server.validateRequestOrigin(req)
175+
assert.NoError(t, err)
176+
})
177+
178+
t.Run("Comma-separated origins in single header should fail", func(t *testing.T) {
179+
req := httptest.NewRequest(http.MethodGet, "/test", nil)
180+
// Manually create comma-separated Origin header (malformed)
181+
req.Header.Set("Origin", "https://example.com,https://malicious.com")
182+
183+
// This should fail because comma-separated origins is not a valid Origin header format
184+
err := server.validateRequestOrigin(req)
185+
assert.Error(t, err)
186+
})
187+
}

0 commit comments

Comments
 (0)