diff --git a/internal/auth.go b/internal/auth.go index 9b8f0b16..7aa0f28b 100644 --- a/internal/auth.go +++ b/internal/auth.go @@ -135,28 +135,43 @@ func returnUrl(r *http.Request) string { // Get oauth redirect uri func redirectUri(r *http.Request) string { - if use, _ := useAuthDomain(r); use { + if use, authHost, _ := useAuthDomain(r); use { p := r.Header.Get("X-Forwarded-Proto") - return fmt.Sprintf("%s://%s%s", p, config.AuthHost, config.Path) + return fmt.Sprintf("%s://%s%s", p, authHost, config.Path) } return fmt.Sprintf("%s%s", redirectBase(r), config.Path) } // Should we use auth host + what it is -func useAuthDomain(r *http.Request) (bool, string) { - if config.AuthHost == "" { - return false, "" +func useAuthDomain(r *http.Request) (bool, string, string) { + if len(config.AuthHosts) == 0 { + return false, "", "" } // Does the request match a given cookie domain? reqMatch, reqHost := matchCookieDomains(r.Host) // Do any of the auth hosts match a cookie domain? - authMatch, authHost := matchCookieDomains(config.AuthHost) + authMatch, authHost := matchAuthHosts(reqHost) // We need both to match the same domain - return reqMatch && authMatch && reqHost == authHost, reqHost + return reqMatch && authMatch, authHost, reqHost +} + +// Return matching auth host domain if exists +func matchAuthHosts(domain string) (bool, string) { + // Remove port + p := strings.Split(domain, ":") + + for _, d := range config.AuthHosts { + // Subdomain match? + if len(d) >= len(domain) && d[len(d)-len(domain):] == domain { + return true, d + } + } + + return false, p[0] } // Cookie methods @@ -287,7 +302,7 @@ func cookieDomain(r *http.Request) string { // Cookie domain func csrfCookieDomain(r *http.Request) string { var host string - if use, domain := useAuthDomain(r); use { + if use, _, domain := useAuthDomain(r); use { host = domain } else { host = r.Host diff --git a/internal/auth_test.go b/internal/auth_test.go index 74e8d2f2..3915f6df 100644 --- a/internal/auth_test.go +++ b/internal/auth_test.go @@ -215,7 +215,7 @@ func TestRedirectUri(t *testing.T) { // With Auth URL but no matching cookie domain // - will not use auth host // - config.AuthHost = "auth.example.com" + config.AuthHosts = CommaSeparatedList{"auth.example.com"} uri, err = url.Parse(redirectUri(r)) assert.Nil(err) @@ -226,7 +226,7 @@ func TestRedirectUri(t *testing.T) { // // With correct Auth URL + cookie domain // - config.AuthHost = "auth.example.com" + config.AuthHosts = CommaSeparatedList{"auth.example.com"} config.CookieDomains = []CookieDomain{*NewCookieDomain("example.com")} // Check url @@ -243,7 +243,7 @@ func TestRedirectUri(t *testing.T) { r = httptest.NewRequest("GET", "https://another.com/hello", nil) r.Header.Add("X-Forwarded-Proto", "https") - config.AuthHost = "auth.example.com" + config.AuthHosts = CommaSeparatedList{"auth.example.com"} config.CookieDomains = []CookieDomain{*NewCookieDomain("example.com")} // Check url @@ -252,6 +252,71 @@ func TestRedirectUri(t *testing.T) { assert.Equal("https", uri.Scheme) assert.Equal("another.com", uri.Host) assert.Equal("/_oauth", uri.Path) + + // + // With correct Auth URL + cookie domain, multiple AuthHosts and cookie domains + // - will use matching authHost + // + config.AuthHosts = CommaSeparatedList{"auth.example.com", "auth.another.com"} + config.CookieDomains = []CookieDomain{*NewCookieDomain("example.com"), *NewCookieDomain("another.com")} + + uri, err = url.Parse(redirectUri(r)) + assert.Nil(err) + assert.Equal("https", uri.Scheme) + assert.Equal("auth.another.com", uri.Host) + assert.Equal("/_oauth", uri.Path) + + // + // With correct Auth URL + no cookie domains + // - will not use authHost + // + config.AuthHosts = CommaSeparatedList{"auth.example.com", "auth.another.com"} + config.CookieDomains = []CookieDomain{} + + uri, err = url.Parse(redirectUri(r)) + assert.Nil(err) + assert.Equal("https", uri.Scheme) + assert.Equal("another.com", uri.Host) + assert.Equal("/_oauth", uri.Path) + + // + // With correct Auth URL + no matching cookie domains + // - will not use authHost + // + config.AuthHosts = CommaSeparatedList{"auth.example.com", "auth.another.com"} + config.CookieDomains = []CookieDomain{*NewCookieDomain("example.com"), *NewCookieDomain("another.example")} + + uri, err = url.Parse(redirectUri(r)) + assert.Nil(err) + assert.Equal("https", uri.Scheme) + assert.Equal("another.com", uri.Host) + assert.Equal("/_oauth", uri.Path) + + // + // With no matching Auth Host + matching cookie domains + // - will not use authHost + // + config.AuthHosts = CommaSeparatedList{"auth.example.com", "auth.another.example"} + config.CookieDomains = []CookieDomain{*NewCookieDomain("example.com"), *NewCookieDomain("another.com")} + + uri, err = url.Parse(redirectUri(r)) + assert.Nil(err) + assert.Equal("https", uri.Scheme) + assert.Equal("another.com", uri.Host) + assert.Equal("/_oauth", uri.Path) + + // + // With no matching Auth Host + no matching cookie domains + // - will not use authHost + // + config.AuthHosts = CommaSeparatedList{"auth.example.com", "auth.another.example"} + config.CookieDomains = []CookieDomain{*NewCookieDomain("example.com"), *NewCookieDomain("another.example")} + + uri, err = url.Parse(redirectUri(r)) + assert.Nil(err) + assert.Equal("https", uri.Scheme) + assert.Equal("another.com", uri.Host) + assert.Equal("/_oauth", uri.Path) } func TestAuthMakeCookie(t *testing.T) { @@ -298,7 +363,7 @@ func TestAuthMakeCSRFCookie(t *testing.T) { assert.Equal("app.example.com", c.Domain) // With cookie domain and auth url - config.AuthHost = "auth.example.com" + config.AuthHosts = CommaSeparatedList{"auth.example.com"} config.CookieDomains = []CookieDomain{*NewCookieDomain("example.com")} c = MakeCSRFCookie(r, "12333378901234567890123456789012") assert.Equal("_forward_auth_csrf_123333", c.Name) diff --git a/internal/config.go b/internal/config.go index 840fb6dc..9634e4a9 100644 --- a/internal/config.go +++ b/internal/config.go @@ -24,7 +24,7 @@ type Config struct { LogLevel string `long:"log-level" env:"LOG_LEVEL" default:"warn" choice:"trace" choice:"debug" choice:"info" choice:"warn" choice:"error" choice:"fatal" choice:"panic" description:"Log level"` LogFormat string `long:"log-format" env:"LOG_FORMAT" default:"text" choice:"text" choice:"json" choice:"pretty" description:"Log format"` - AuthHost string `long:"auth-host" env:"AUTH_HOST" description:"Single host to use when returning from 3rd party auth"` + AuthHosts CommaSeparatedList `long:"auth-host" env:"AUTH_HOST" env-delim:"," description:"Single host to use when returning from 3rd party auth"` Config func(s string) error `long:"config" env:"CONFIG" description:"Path to config file" json:"-"` CookieDomains []CookieDomain `long:"cookie-domain" env:"COOKIE_DOMAIN" env-delim:"," description:"Domain to set auth cookie on, can be set multiple times"` InsecureCookie bool `long:"insecure-cookie" env:"INSECURE_COOKIE" description:"Use insecure cookies"` diff --git a/internal/config_test.go b/internal/config_test.go index 27b8fdc8..a719033e 100644 --- a/internal/config_test.go +++ b/internal/config_test.go @@ -24,7 +24,7 @@ func TestConfigDefaults(t *testing.T) { assert.Equal("warn", c.LogLevel) assert.Equal("text", c.LogFormat) - assert.Equal("", c.AuthHost) + assert.Len(c.AuthHosts, 0) assert.Len(c.CookieDomains, 0) assert.False(c.InsecureCookie) assert.Equal("_forward_auth", c.CookieName) @@ -197,7 +197,7 @@ func TestConfigFileBackwardsCompatability(t *testing.T) { require.Nil(t, err) assert.Equal("/two", c.Path, "variable in legacy config file should be read") - assert.Equal("auth.legacy.com", c.AuthHost, "variable in legacy config file should be read") + assert.Equal(CommaSeparatedList{"auth.legacy.com"}, c.AuthHosts, "variable in legacy config file should be read") } func TestConfigParseEnvironment(t *testing.T) {