|
4 | 4 | package golink
|
5 | 5 |
|
6 | 6 | import (
|
| 7 | + "errors" |
| 8 | + "net/http" |
| 9 | + "net/http/httptest" |
| 10 | + "net/url" |
| 11 | + "strings" |
7 | 12 | "testing"
|
8 | 13 | "time"
|
9 | 14 | )
|
10 | 15 |
|
| 16 | +func init() { |
| 17 | + // tests always need golink to be run in dev mode |
| 18 | + *dev = ":8080" |
| 19 | +} |
| 20 | + |
| 21 | +func TestServeGo(t *testing.T) { |
| 22 | + var err error |
| 23 | + db, err = NewSQLiteDB(":memory:") |
| 24 | + if err != nil { |
| 25 | + t.Fatal(err) |
| 26 | + } |
| 27 | + db.Save(&Link{Short: "who", Long: "http://who/"}) |
| 28 | + db.Save(&Link{Short: "me", Long: "/who/{{.User}}"}) |
| 29 | + db.Save(&Link{Short: "invalid-var", Long: "/who/{{.Invalid}}"}) |
| 30 | + |
| 31 | + tests := []struct { |
| 32 | + name string |
| 33 | + link string |
| 34 | + currentUser func(*http.Request) (string, error) |
| 35 | + wantStatus int |
| 36 | + wantLink string |
| 37 | + }{ |
| 38 | + { |
| 39 | + name: "simple link", |
| 40 | + link: "/who", |
| 41 | + wantStatus: http.StatusFound, |
| 42 | + wantLink: "http://who/", |
| 43 | + }, |
| 44 | + { |
| 45 | + name: "simple link, anonymous request", |
| 46 | + link: "/who", |
| 47 | + currentUser: func(*http.Request) (string, error) { return "", nil }, |
| 48 | + wantStatus: http.StatusFound, |
| 49 | + wantLink: "http://who/", |
| 50 | + }, |
| 51 | + { |
| 52 | + name: "user link", |
| 53 | + link: "/me", |
| 54 | + wantStatus: http.StatusFound, |
| 55 | + wantLink: "/who/[email protected]", |
| 56 | + }, |
| 57 | + { |
| 58 | + name: "unknown link", |
| 59 | + link: "/does-not-exist", |
| 60 | + wantStatus: http.StatusNotFound, |
| 61 | + }, |
| 62 | + { |
| 63 | + name: "unknown variable", |
| 64 | + link: "/invalid-var", |
| 65 | + wantStatus: http.StatusInternalServerError, |
| 66 | + }, |
| 67 | + } |
| 68 | + |
| 69 | + for _, tt := range tests { |
| 70 | + t.Run(tt.name, func(t *testing.T) { |
| 71 | + if tt.currentUser != nil { |
| 72 | + oldCurrentUser := currentUser |
| 73 | + currentUser = tt.currentUser |
| 74 | + t.Cleanup(func() { |
| 75 | + currentUser = oldCurrentUser |
| 76 | + }) |
| 77 | + } |
| 78 | + |
| 79 | + r := httptest.NewRequest("GET", tt.link, nil) |
| 80 | + w := httptest.NewRecorder() |
| 81 | + serveGo(w, r) |
| 82 | + |
| 83 | + if w.Code != tt.wantStatus { |
| 84 | + t.Errorf("serveGo(%q) = %d; want %d", tt.link, w.Code, tt.wantStatus) |
| 85 | + } |
| 86 | + if gotLink := w.Header().Get("Location"); gotLink != tt.wantLink { |
| 87 | + t.Errorf("serveGo(%q) = %q; want %q", tt.link, gotLink, tt.wantLink) |
| 88 | + } |
| 89 | + }) |
| 90 | + } |
| 91 | +} |
| 92 | + |
| 93 | +func TestServeSave(t *testing.T) { |
| 94 | + var err error |
| 95 | + db, err = NewSQLiteDB(":memory:") |
| 96 | + if err != nil { |
| 97 | + t.Fatal(err) |
| 98 | + } |
| 99 | + |
| 100 | + tests := []struct { |
| 101 | + name string |
| 102 | + short string |
| 103 | + long string |
| 104 | + currentUser func(*http.Request) (string, error) |
| 105 | + wantStatus int |
| 106 | + }{ |
| 107 | + { |
| 108 | + name: "missing short", |
| 109 | + short: "", |
| 110 | + long: "http://who/", |
| 111 | + wantStatus: http.StatusBadRequest, |
| 112 | + }, |
| 113 | + { |
| 114 | + name: "missing long", |
| 115 | + short: "", |
| 116 | + long: "http://who/", |
| 117 | + wantStatus: http.StatusBadRequest, |
| 118 | + }, |
| 119 | + { |
| 120 | + name: "save simple link", |
| 121 | + short: "who", |
| 122 | + long: "http://who/", |
| 123 | + wantStatus: http.StatusOK, |
| 124 | + }, |
| 125 | + { |
| 126 | + name: "disallow editing another's link", |
| 127 | + short: "who", |
| 128 | + long: "http://who/", |
| 129 | + currentUser: func( *http. Request) ( string, error) { return "[email protected]", nil }, |
| 130 | + wantStatus: http.StatusForbidden, |
| 131 | + }, |
| 132 | + { |
| 133 | + name: "disallow unknown users", |
| 134 | + short: "who2", |
| 135 | + long: "http://who/", |
| 136 | + currentUser: func(*http.Request) (string, error) { return "", errors.New("") }, |
| 137 | + wantStatus: http.StatusInternalServerError, |
| 138 | + }, |
| 139 | + } |
| 140 | + |
| 141 | + for _, tt := range tests { |
| 142 | + t.Run(tt.name, func(t *testing.T) { |
| 143 | + if tt.currentUser != nil { |
| 144 | + oldCurrentUser := currentUser |
| 145 | + currentUser = tt.currentUser |
| 146 | + t.Cleanup(func() { |
| 147 | + currentUser = oldCurrentUser |
| 148 | + }) |
| 149 | + } |
| 150 | + |
| 151 | + r := httptest.NewRequest("POST", "/", strings.NewReader(url.Values{ |
| 152 | + "short": {tt.short}, |
| 153 | + "long": {tt.long}, |
| 154 | + }.Encode())) |
| 155 | + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") |
| 156 | + w := httptest.NewRecorder() |
| 157 | + serveSave(w, r) |
| 158 | + |
| 159 | + if w.Code != tt.wantStatus { |
| 160 | + t.Errorf("serveSave(%q, %q) = %d; want %d", tt.short, tt.long, w.Code, tt.wantStatus) |
| 161 | + } |
| 162 | + }) |
| 163 | + } |
| 164 | +} |
| 165 | + |
11 | 166 | func TestExpandLink(t *testing.T) {
|
12 | 167 | tests := []struct {
|
13 | 168 | name string // test name
|
14 | 169 | long string // long URL for golink
|
15 | 170 | now time.Time // current time
|
16 | 171 | user string // current user resolving link
|
17 | 172 | remainder string // remainder of URL path after golink name
|
| 173 | + wantErr bool // whether we expect an error |
18 | 174 | want string // expected redirect URL
|
19 | 175 | }{
|
20 | 176 | {
|
@@ -52,6 +208,11 @@ func TestExpandLink(t *testing.T) {
|
52 | 208 |
|
53 | 209 | want: "http://host.com/[email protected]",
|
54 | 210 | },
|
| 211 | + { |
| 212 | + name: "unknown-field", |
| 213 | + long: `http://host.com/{{.Foo}}`, |
| 214 | + wantErr: true, |
| 215 | + }, |
55 | 216 | {
|
56 | 217 | name: "template-no-path",
|
57 | 218 | long: "https://calendar.google.com/{{with .Path}}calendar/embed?mode=week&src={{.}}@tailscale.com{{end}}",
|
@@ -85,8 +246,8 @@ func TestExpandLink(t *testing.T) {
|
85 | 246 | for _, tt := range tests {
|
86 | 247 | t.Run(tt.name, func(t *testing.T) {
|
87 | 248 | got, err := expandLink(tt.long, expandEnv{Now: tt.now, Path: tt.remainder, User: tt.user})
|
88 |
| - if err != nil { |
89 |
| - t.Fatalf("expandLink(%q): %v", tt.long, err) |
| 249 | + if (err != nil) != tt.wantErr { |
| 250 | + t.Fatalf("expandLink(%q) returned error %v; want %v", tt.long, err, tt.wantErr) |
90 | 251 | }
|
91 | 252 | if got != tt.want {
|
92 | 253 | t.Errorf("expandLink(%q) = %q; want %q", tt.long, got, tt.want)
|
|
0 commit comments