diff --git a/go.mod b/go.mod index 74b2be0..f35987e 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,6 @@ go 1.24.0 require ( github.com/google/go-cmp v0.6.0 - golang.org/x/net v0.38.0 modernc.org/sqlite v1.19.4 tailscale.com v1.82.5 ) @@ -77,6 +76,7 @@ require ( golang.org/x/crypto v0.36.0 // indirect golang.org/x/exp v0.0.0-20250210185358-939b2ce775ac // indirect golang.org/x/mod v0.23.0 // indirect + golang.org/x/net v0.38.0 // indirect golang.org/x/sync v0.12.0 // indirect golang.org/x/sys v0.31.0 // indirect golang.org/x/term v0.30.0 // indirect diff --git a/golink.go b/golink.go index 0b98f66..cf8b943 100644 --- a/golink.go +++ b/golink.go @@ -10,7 +10,6 @@ import ( "context" "crypto/rand" "embed" - "encoding/base64" "encoding/json" "errors" "flag" @@ -30,7 +29,6 @@ import ( texttemplate "text/template" "time" - "golang.org/x/net/xsrftoken" "tailscale.com/client/tailscale" "tailscale.com/hostinfo" "tailscale.com/ipn" @@ -41,17 +39,8 @@ import ( const ( defaultHostname = "go" - - // Used as a placeholder short name for generating the XSRF defense token, - // when creating new links. - newShortName = ".new" - - // If the caller sends this header set to a non-empty value, we will allow - // them to make the call even without an XSRF token. JavaScript in browser - // cannot set this header, per the [Fetch Spec]. - // - // [Fetch Spec]: https://fetch.spec.whatwg.org - secHeaderName = "Sec-Golink" + secFetchSite = "Sec-Fetch-Site" + secGolink = "Sec-Golink" ) var ( @@ -211,6 +200,8 @@ out: fqdn := strings.TrimSuffix(status.Self.DNSName, ".") httpHandler := serveHandler() + httpHandler = EnforceSecFetchSiteOrSecGolink(httpHandler) + if enableTLS { httpsHandler := HSTS(httpHandler) httpHandler = redirectHandler(fqdn) @@ -275,7 +266,6 @@ type homeData struct { Short string Long string Clicks []visitData - XSRF string ReadOnly bool } @@ -283,11 +273,8 @@ type homeData struct { type deleteData struct { Short string Long string - XSRF string } -var xsrfKey string - func init() { homeTmpl = newTemplate("base.html", "home.html") detailTmpl = newTemplate("base.html", "detail.html") @@ -299,7 +286,6 @@ func init() { b := make([]byte, 24) rand.Read(b) - xsrfKey = base64.StdEncoding.EncodeToString(b) } var tmplFuncs = template.FuncMap{ @@ -416,6 +402,34 @@ func HSTS(h http.Handler) http.Handler { }) } +// EnforceSecFetchSiteOrSecGolink is a Cross-Site Request Forgery protection +// middleware that validates the Sec-Fetch-Site header for non-idempotent +// requests. It requires clients to send Sec-Fetch-Site set to "same-origin". +// +// It alternatively allows for clients to send the header "Sec-Golink" set to +// any value to maintain compatibility with clients developed against earlier +// versions of golink that relied on xsrf token based CSRF protection. +func EnforceSecFetchSiteOrSecGolink(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case "GET", "HEAD", "OPTIONS": // allow idempotent methods + h.ServeHTTP(w, r) + return + } + + // Check for Sec-Fetch-Site header set to "same-origin" + // or Sec-Golink header set to any value for backwards compatibility. + sameOrigin := r.Header.Get(secFetchSite) == "same-origin" + secGolink := r.Header.Get(secGolink) != "" + if sameOrigin || secGolink { + h.ServeHTTP(w, r) + return + } + + http.Error(w, "invalid non `Sec-Fetch-Site: same-origin` request", http.StatusBadRequest) + }) +} + // serverHandler returns the main http.Handler for serving all requests. func serveHandler() http.Handler { mux := http.NewServeMux() @@ -476,16 +490,10 @@ func serveHome(w http.ResponseWriter, r *http.Request, short string) { } } - cu, err := currentUser(r) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } homeTmpl.Execute(w, homeData{ Short: short, Long: long, Clicks: clicks, - XSRF: xsrftoken.Generate(xsrfKey, cu.login, newShortName), ReadOnly: *readonly, }) } @@ -597,7 +605,6 @@ type detailData struct { // Editable indicates whether the current user can edit the link. Editable bool Link *Link - XSRF string } func serveDetail(w http.ResponseWriter, r *http.Request) { @@ -641,7 +648,6 @@ func serveDetail(w http.ResponseWriter, r *http.Request) { data := detailData{ Link: link, Editable: canEdit, - XSRF: xsrftoken.Generate(xsrfKey, cu.login, link.Short), } if canEdit && !ownerExists { data.Link.Owner = cu.login @@ -829,16 +835,6 @@ func serveDelete(w http.ResponseWriter, r *http.Request) { return } - // Deletion by CLI has never worked because it has always required the XSRF - // token. (Refer to commit c7ac33d04c33743606f6224009a5c73aa0b8dec0.) If we - // want to enable deletion via CLI and to honor allowUnknownUsers for - // deletion, we could change the below to a call to isRequestAuthorized. For - // now, always require the XSRF token, thus maintaining the status quo. - if !xsrftoken.Valid(r.PostFormValue("xsrf"), xsrfKey, cu.login, link.Short) { - http.Error(w, "invalid XSRF token", http.StatusBadRequest) - return - } - if err := db.Delete(short); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -848,7 +844,6 @@ func serveDelete(w http.ResponseWriter, r *http.Request) { deleteTmpl.Execute(w, deleteData{ Short: link.Short, Long: link.Long, - XSRF: xsrftoken.Generate(xsrfKey, cu.login, newShortName), }) } @@ -891,18 +886,6 @@ func serveSave(w http.ResponseWriter, r *http.Request) { return } - // short name to use for XSRF token. - // For new link creation, the special newShortName value is used. - tokenShortName := newShortName - if link != nil { - tokenShortName = link.Short - } - - if !isRequestAuthorized(r, cu, tokenShortName) { - http.Error(w, "invalid XSRF token", http.StatusBadRequest) - return - } - // allow transferring ownership to valid users. If empty, set owner to current user. owner := r.FormValue("owner") if owner != "" { @@ -1077,14 +1060,3 @@ func resolveLink(link *url.URL) (*url.URL, error) { } return dst, err } - -func isRequestAuthorized(r *http.Request, u user, short string) bool { - if *allowUnknownUsers { - return true - } - if r.Header.Get(secHeaderName) != "" { - return true - } - - return xsrftoken.Valid(r.PostFormValue("xsrf"), xsrfKey, u.login, short) -} diff --git a/golink_test.go b/golink_test.go index 5842112..b71c752 100644 --- a/golink_test.go +++ b/golink_test.go @@ -12,7 +12,6 @@ import ( "testing" "time" - "golang.org/x/net/xsrftoken" "tailscale.com/tstest" "tailscale.com/types/ptr" "tailscale.com/util/must" @@ -23,6 +22,75 @@ func init() { *dev = ":8080" } +func TestEnforceSecFetchSiteOrSecGolink(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + handler := EnforceSecFetchSiteOrSecGolink(mux) + + tests := []struct { + name string + method string + withSameOrigin bool + withCrossSite bool + expectSuccess bool + withSecGolink bool + }{ + { + name: "GET without header succeeds", + method: http.MethodGet, + withSameOrigin: false, + expectSuccess: true, + }, + { + name: "POST without header fails", + method: http.MethodPost, + }, + { + name: "POST with same-origin header succeeds", + method: http.MethodPost, + withSameOrigin: true, + expectSuccess: true, + }, + { + name: "POST with cross-site header fails", + method: http.MethodPost, + withCrossSite: true, + }, + { + name: "POST with sec-golink header succeeds", + method: http.MethodPost, + withSecGolink: true, + expectSuccess: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := httptest.NewRequest(tt.method, "/", nil) + if tt.withSameOrigin { + r.Header.Set(secFetchSite, "same-origin") + } + if tt.withCrossSite { + r.Header.Set(secFetchSite, "cross-site") + } + if tt.withSecGolink { + r.Header.Set(secGolink, "true") + } + + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + + if w.Code != http.StatusOK && tt.expectSuccess { + t.Errorf("expected status OK, got %d", w.Code) + } else if w.Code == http.StatusOK && !tt.expectSuccess { + t.Errorf("expected non-OK status, got %d", w.Code) + } + }) + } +} + func TestServeGo(t *testing.T) { var err error db, err = NewSQLiteDB(":memory:") @@ -158,17 +226,9 @@ func TestServeSave(t *testing.T) { } db.Save(&Link{Short: "link-owned-by-tagged-devices", Long: "/before", Owner: "tagged-devices"}) - fooXSRF := func(short string) string { - return xsrftoken.Generate(xsrfKey, "foo@example.com", short) - } - barXSRF := func(short string) string { - return xsrftoken.Generate(xsrfKey, "bar@example.com", short) - } - tests := []struct { name string short string - xsrf string long string allowUnknownUsers bool currentUser func(*http.Request) (user, error) @@ -189,14 +249,12 @@ func TestServeSave(t *testing.T) { { name: "save simple link", short: "who", - xsrf: fooXSRF(newShortName), long: "http://who/", wantStatus: http.StatusOK, }, { name: "disallow editing another's link", short: "who", - xsrf: barXSRF("who"), long: "http://who/", currentUser: func(*http.Request) (user, error) { return user{login: "bar@example.com"}, nil }, wantStatus: http.StatusForbidden, @@ -204,7 +262,6 @@ func TestServeSave(t *testing.T) { { name: "allow editing link owned by tagged-devices", short: "link-owned-by-tagged-devices", - xsrf: barXSRF("link-owned-by-tagged-devices"), long: "/after", currentUser: func(*http.Request) (user, error) { return user{login: "bar@example.com"}, nil }, wantStatus: http.StatusOK, @@ -212,7 +269,6 @@ func TestServeSave(t *testing.T) { { name: "admins can edit any link", short: "who", - xsrf: barXSRF("who"), long: "http://who/", currentUser: func(*http.Request) (user, error) { return user{login: "bar@example.com", isAdmin: true}, nil }, wantStatus: http.StatusOK, @@ -220,7 +276,6 @@ func TestServeSave(t *testing.T) { { name: "disallow unknown users", short: "who2", - xsrf: fooXSRF("who2"), long: "http://who/", currentUser: func(*http.Request) (user, error) { return user{}, errors.New("") }, wantStatus: http.StatusInternalServerError, @@ -233,13 +288,6 @@ func TestServeSave(t *testing.T) { currentUser: func(*http.Request) (user, error) { return user{}, nil }, wantStatus: http.StatusOK, }, - { - name: "invalid xsrf", - short: "goat", - xsrf: fooXSRF("sheep"), - long: "https://goat.example.com/goat.php?goat=true", - wantStatus: http.StatusBadRequest, - }, } for _, tt := range tests { @@ -259,7 +307,6 @@ func TestServeSave(t *testing.T) { r := httptest.NewRequest("POST", "/", strings.NewReader(url.Values{ "short": {tt.short}, "long": {tt.long}, - "xsrf": {tt.xsrf}, }.Encode())) r.Header.Set("Content-Type", "application/x-www-form-urlencoded") w := httptest.NewRecorder() @@ -282,14 +329,9 @@ func TestServeDelete(t *testing.T) { db.Save(&Link{Short: "foo", Owner: "foo@example.com"}) db.Save(&Link{Short: "link-owned-by-tagged-devices", Long: "/before", Owner: "tagged-devices"}) - xsrf := func(short string) string { - return xsrftoken.Generate(xsrfKey, "foo@example.com", short) - } - tests := []struct { name string short string - xsrf string currentUser func(*http.Request) (user, error) wantStatus int }{ @@ -311,28 +353,14 @@ func TestServeDelete(t *testing.T) { { name: "allow deleting link owned by tagged-devices", short: "link-owned-by-tagged-devices", - xsrf: xsrf("link-owned-by-tagged-devices"), wantStatus: http.StatusOK, }, { name: "admin can delete unowned link", short: "a", currentUser: func(*http.Request) (user, error) { return user{login: "foo@example.com", isAdmin: true}, nil }, - xsrf: xsrf("a"), wantStatus: http.StatusOK, }, - { - name: "invalid xsrf", - short: "foo", - xsrf: xsrf("invalid"), - wantStatus: http.StatusBadRequest, - }, - { - name: "valid xsrf", - short: "foo", - xsrf: xsrf("foo"), - wantStatus: http.StatusOK, - }, } for _, tt := range tests { @@ -345,9 +373,7 @@ func TestServeDelete(t *testing.T) { }) } - r := httptest.NewRequest("POST", "/.delete/"+tt.short, strings.NewReader(url.Values{ - "xsrf": {tt.xsrf}, - }.Encode())) + r := httptest.NewRequest("POST", "/.delete/"+tt.short, nil) r.Header.Set("Content-Type", "application/x-www-form-urlencoded") w := httptest.NewRecorder() serveDelete(w, r)