Skip to content

Commit 31b1ce1

Browse files
committed
add tests for serve methods
Also check error returned from tmpl.Execute. Refactor currentUser to make the logic a little simpler, and make it a package var for easier testing. Signed-off-by: Will Norris <[email protected]>
1 parent b9fdc2d commit 31b1ce1

File tree

2 files changed

+177
-14
lines changed

2 files changed

+177
-14
lines changed

golink.go

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,7 @@ func serveGo(w http.ResponseWriter, r *http.Request) {
346346

347347
link, err := db.Load(short)
348348
if errors.Is(err, fs.ErrNotExist) {
349+
w.WriteHeader(http.StatusNotFound)
349350
serveHome(w, short)
350351
return
351352
}
@@ -467,7 +468,9 @@ func expandLink(long string, env expandEnv) (string, error) {
467468
return "", err
468469
}
469470
buf := new(bytes.Buffer)
470-
tmpl.Execute(buf, env)
471+
if err := tmpl.Execute(buf, env); err != nil {
472+
return "", err
473+
}
471474
long = buf.String()
472475

473476
_, err = url.Parse(long)
@@ -479,19 +482,18 @@ func expandLink(long string, env expandEnv) (string, error) {
479482

480483
func devMode() bool { return *dev != "" }
481484

482-
func currentUser(r *http.Request) (string, error) {
483-
login := ""
485+
// currentUser returns the Tailscale user associated with the request.
486+
// In most cases, this will be the user that owns the device that made the request.
487+
// For tagged devices, the value "tagged-devices" is returned.
488+
var currentUser = func(r *http.Request) (string, error) {
484489
if devMode() {
485-
486-
} else {
487-
res, err := localClient.WhoIs(r.Context(), r.RemoteAddr)
488-
if err != nil {
489-
return "", err
490-
}
491-
login = res.UserProfile.LoginName
490+
return "[email protected]", nil
492491
}
493-
return login, nil
494-
492+
whois, err := localClient.WhoIs(r.Context(), r.RemoteAddr)
493+
if err != nil {
494+
return "", err
495+
}
496+
return whois.UserProfile.LoginName, nil
495497
}
496498

497499
// userExists returns whether a user exists with the specified login in the current tailnet.

golink_test.go

Lines changed: 163 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,173 @@
44
package golink
55

66
import (
7+
"errors"
8+
"net/http"
9+
"net/http/httptest"
10+
"net/url"
11+
"strings"
712
"testing"
813
"time"
914
)
1015

16+
func init() {
17+
// tests always need golink to be run in dev mode
18+
*dev = ":8080"
19+
}
20+
21+
func TestServeGo(t *testing.T) {
22+
var err error
23+
db, err = NewSQLiteDB(":memory:")
24+
if err != nil {
25+
t.Fatal(err)
26+
}
27+
db.Save(&Link{Short: "who", Long: "http://who/"})
28+
db.Save(&Link{Short: "me", Long: "/who/{{.User}}"})
29+
db.Save(&Link{Short: "invalid-var", Long: "/who/{{.Invalid}}"})
30+
31+
tests := []struct {
32+
name string
33+
link string
34+
currentUser func(*http.Request) (string, error)
35+
wantStatus int
36+
wantLink string
37+
}{
38+
{
39+
name: "simple link",
40+
link: "/who",
41+
wantStatus: http.StatusFound,
42+
wantLink: "http://who/",
43+
},
44+
{
45+
name: "simple link, anonymous request",
46+
link: "/who",
47+
currentUser: func(*http.Request) (string, error) { return "", nil },
48+
wantStatus: http.StatusFound,
49+
wantLink: "http://who/",
50+
},
51+
{
52+
name: "user link",
53+
link: "/me",
54+
wantStatus: http.StatusFound,
55+
wantLink: "/who/[email protected]",
56+
},
57+
{
58+
name: "unknown link",
59+
link: "/does-not-exist",
60+
wantStatus: http.StatusNotFound,
61+
},
62+
{
63+
name: "unknown variable",
64+
link: "/invalid-var",
65+
wantStatus: http.StatusInternalServerError,
66+
},
67+
}
68+
69+
for _, tt := range tests {
70+
t.Run(tt.name, func(t *testing.T) {
71+
if tt.currentUser != nil {
72+
oldCurrentUser := currentUser
73+
currentUser = tt.currentUser
74+
t.Cleanup(func() {
75+
currentUser = oldCurrentUser
76+
})
77+
}
78+
79+
r := httptest.NewRequest("GET", tt.link, nil)
80+
w := httptest.NewRecorder()
81+
serveGo(w, r)
82+
83+
if w.Code != tt.wantStatus {
84+
t.Errorf("serveGo(%q) = %d; want %d", tt.link, w.Code, tt.wantStatus)
85+
}
86+
if gotLink := w.Header().Get("Location"); gotLink != tt.wantLink {
87+
t.Errorf("serveGo(%q) = %q; want %q", tt.link, gotLink, tt.wantLink)
88+
}
89+
})
90+
}
91+
}
92+
93+
func TestServeSave(t *testing.T) {
94+
var err error
95+
db, err = NewSQLiteDB(":memory:")
96+
if err != nil {
97+
t.Fatal(err)
98+
}
99+
100+
tests := []struct {
101+
name string
102+
short string
103+
long string
104+
currentUser func(*http.Request) (string, error)
105+
wantStatus int
106+
}{
107+
{
108+
name: "missing short",
109+
short: "",
110+
long: "http://who/",
111+
wantStatus: http.StatusBadRequest,
112+
},
113+
{
114+
name: "missing long",
115+
short: "",
116+
long: "http://who/",
117+
wantStatus: http.StatusBadRequest,
118+
},
119+
{
120+
name: "save simple link",
121+
short: "who",
122+
long: "http://who/",
123+
wantStatus: http.StatusOK,
124+
},
125+
{
126+
name: "disallow editing another's link",
127+
short: "who",
128+
long: "http://who/",
129+
currentUser: func(*http.Request) (string, error) { return "[email protected]", nil },
130+
wantStatus: http.StatusForbidden,
131+
},
132+
{
133+
name: "disallow unknown users",
134+
short: "who2",
135+
long: "http://who/",
136+
currentUser: func(*http.Request) (string, error) { return "", errors.New("") },
137+
wantStatus: http.StatusInternalServerError,
138+
},
139+
}
140+
141+
for _, tt := range tests {
142+
t.Run(tt.name, func(t *testing.T) {
143+
if tt.currentUser != nil {
144+
oldCurrentUser := currentUser
145+
currentUser = tt.currentUser
146+
t.Cleanup(func() {
147+
currentUser = oldCurrentUser
148+
})
149+
}
150+
151+
r := httptest.NewRequest("POST", "/", strings.NewReader(url.Values{
152+
"short": {tt.short},
153+
"long": {tt.long},
154+
}.Encode()))
155+
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
156+
w := httptest.NewRecorder()
157+
serveSave(w, r)
158+
159+
if w.Code != tt.wantStatus {
160+
t.Errorf("serveSave(%q, %q) = %d; want %d", tt.short, tt.long, w.Code, tt.wantStatus)
161+
}
162+
})
163+
}
164+
}
165+
11166
func TestExpandLink(t *testing.T) {
12167
tests := []struct {
13168
name string // test name
14169
long string // long URL for golink
15170
now time.Time // current time
16171
user string // current user resolving link
17172
remainder string // remainder of URL path after golink name
173+
wantErr bool // whether we expect an error
18174
want string // expected redirect URL
19175
}{
20176
{
@@ -52,6 +208,11 @@ func TestExpandLink(t *testing.T) {
52208
53209
want: "http://host.com/[email protected]",
54210
},
211+
{
212+
name: "unknown-field",
213+
long: `http://host.com/{{.Foo}}`,
214+
wantErr: true,
215+
},
55216
{
56217
name: "template-no-path",
57218
long: "https://calendar.google.com/{{with .Path}}calendar/embed?mode=week&src={{.}}@tailscale.com{{end}}",
@@ -85,8 +246,8 @@ func TestExpandLink(t *testing.T) {
85246
for _, tt := range tests {
86247
t.Run(tt.name, func(t *testing.T) {
87248
got, err := expandLink(tt.long, expandEnv{Now: tt.now, Path: tt.remainder, User: tt.user})
88-
if err != nil {
89-
t.Fatalf("expandLink(%q): %v", tt.long, err)
249+
if (err != nil) != tt.wantErr {
250+
t.Fatalf("expandLink(%q) returned error %v; want %v", tt.long, err, tt.wantErr)
90251
}
91252
if got != tt.want {
92253
t.Errorf("expandLink(%q) = %q; want %q", tt.long, got, tt.want)

0 commit comments

Comments
 (0)