Skip to content

Commit 8ac87f5

Browse files
committed
golink: do not set HSTS headers when serving non-FQDN origins
Inspect the `Host` header to ensure that we do not return HSTS headers for short domains which can lead to some clients pinning short domains to endpoints with invalid certificates. Signed-off-by: Patrick O'Doherty <[email protected]>
1 parent 7917da1 commit 8ac87f5

File tree

2 files changed

+47
-2
lines changed

2 files changed

+47
-2
lines changed

golink.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -334,10 +334,17 @@ func redirectHandler(hostname string) http.Handler {
334334
}
335335

336336
// HSTS wraps the provided handler and sets Strict-Transport-Security header on
337-
// all responses.
337+
// responses. It inspects the Host header to ensure we do not specify HSTS
338+
// response on non fully qualified domain name origins.
338339
func HSTS(h http.Handler) http.Handler {
339340
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
340-
w.Header().Set("Strict-Transport-Security", "max-age=31536000")
341+
host, found := r.Header["Host"]
342+
if found {
343+
host := host[0]
344+
if strings.Contains(host, ".") {
345+
w.Header().Set("Strict-Transport-Security", "max-age=31536000")
346+
}
347+
}
341348
h.ServeHTTP(w, r)
342349
})
343350
}

golink_test.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,3 +524,41 @@ func TestResolveLink(t *testing.T) {
524524
})
525525
}
526526
}
527+
528+
func TestNoHSTSShortDomain(t *testing.T) {
529+
var err error
530+
db, err = NewSQLiteDB(":memory:")
531+
if err != nil {
532+
t.Fatal(err)
533+
}
534+
db.Save(&Link{Short: "foobar", Long: "http://foobar/"})
535+
536+
tests := []struct {
537+
host string
538+
expectHsts bool
539+
}{
540+
{
541+
host: "go",
542+
expectHsts: false,
543+
},
544+
{
545+
host: "go.prawn-universe.ts.net",
546+
expectHsts: true,
547+
},
548+
}
549+
for _, tt := range tests {
550+
name := "HSTS: " + tt.host
551+
t.Run(name, func(t *testing.T) {
552+
r := httptest.NewRequest("GET", "/foobar", nil)
553+
r.Header.Add("Host", tt.host)
554+
555+
w := httptest.NewRecorder()
556+
HSTS(serveHandler()).ServeHTTP(w, r)
557+
558+
_, found := w.Header()["Strict-Transport-Security"]
559+
if found != tt.expectHsts {
560+
t.Errorf("HSTS expectation: domain %s want: %t got: %t", tt.host, tt.expectHsts, found)
561+
}
562+
})
563+
}
564+
}

0 commit comments

Comments
 (0)