Skip to content

Commit 9381dc2

Browse files
authored
use go context instead of garilla
* refactor(keys): use subst in sql * project_vault_token * refactor(project): add project service * refactor(api): use helpers context * feat(api): use builtin context * feat: remove gorilla/context * feat(project): add unit tests for project service * test: fix test
1 parent 7f301a7 commit 9381dc2

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+823
-477
lines changed

api/api_test.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@ package api
22

33
import (
44
"github.com/semaphoreui/semaphore/util"
5-
//_ "github.com/snikch/goodman/hooks"
6-
//_ "github.com/snikch/goodman/transaction"
75
"net/http"
86
"net/http/httptest"
97
"testing"
@@ -17,7 +15,7 @@ func TestApiPing(t *testing.T) {
1715
req, _ := http.NewRequest("GET", "/api/ping", nil)
1816
rr := httptest.NewRecorder()
1917

20-
r := Route()
18+
r := Route(nil, nil)
2119

2220
r.ServeHTTP(rr, req)
2321

api/apps.go

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"encoding/json"
55
"errors"
66
"fmt"
7-
"github.com/gorilla/context"
87
"github.com/semaphoreui/semaphore/api/helpers"
98
"github.com/semaphoreui/semaphore/db"
109
"github.com/semaphoreui/semaphore/util"
@@ -80,7 +79,7 @@ func appMiddleware(next http.Handler) http.Handler {
8079
return
8180
}
8281

83-
context.Set(r, "app_id", appID)
82+
r = helpers.SetContextValue(r, "app_id", appID)
8483
next.ServeHTTP(w, r)
8584
})
8685
}
@@ -110,7 +109,7 @@ func getApps(w http.ResponseWriter, r *http.Request) {
110109
}
111110

112111
func getApp(w http.ResponseWriter, r *http.Request) {
113-
appID := context.Get(r, "app_id").(string)
112+
appID := helpers.GetFromContext(r, "app_id").(string)
114113

115114
app, ok := util.Config.Apps[appID]
116115
if !ok {
@@ -122,7 +121,7 @@ func getApp(w http.ResponseWriter, r *http.Request) {
122121
}
123122

124123
func deleteApp(w http.ResponseWriter, r *http.Request) {
125-
appID := context.Get(r, "app_id").(string)
124+
appID := helpers.GetFromContext(r, "app_id").(string)
126125

127126
store := helpers.Store(r)
128127

@@ -161,7 +160,7 @@ func setAppOption(store db.Store, appID string, field string, val any) error {
161160
}
162161

163162
func setApp(w http.ResponseWriter, r *http.Request) {
164-
appID := context.Get(r, "app_id").(string)
163+
appID := helpers.GetFromContext(r, "app_id").(string)
165164

166165
store := helpers.Store(r)
167166

@@ -206,7 +205,7 @@ func setApp(w http.ResponseWriter, r *http.Request) {
206205
}
207206

208207
func setAppActive(w http.ResponseWriter, r *http.Request) {
209-
appID := context.Get(r, "app_id").(string)
208+
appID := helpers.GetFromContext(r, "app_id").(string)
210209

211210
store := helpers.Store(r)
212211

api/auth.go

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package api
22

33
import (
44
"errors"
5-
"github.com/gorilla/context"
65
"github.com/pquerna/otp"
76
"github.com/semaphoreui/semaphore/api/helpers"
87
"github.com/semaphoreui/semaphore/db"
@@ -207,9 +206,11 @@ func verifySession(w http.ResponseWriter, r *http.Request) {
207206
}
208207
}
209208

210-
func authenticationHandler(w http.ResponseWriter, r *http.Request) bool {
209+
func authenticationHandler(w http.ResponseWriter, r *http.Request) (ok bool, req *http.Request) {
211210
var userID int
212211

212+
req = r
213+
213214
authHeader := strings.ToLower(r.Header.Get("authorization"))
214215

215216
if len(authHeader) > 0 && strings.Contains(authHeader, "bearer") {
@@ -221,29 +222,29 @@ func authenticationHandler(w http.ResponseWriter, r *http.Request) bool {
221222
}
222223

223224
w.WriteHeader(http.StatusUnauthorized)
224-
return false
225+
return
225226
}
226227

227228
userID = token.UserID
228229
} else {
229-
session, ok := getSession(r)
230+
session, found := getSession(r)
230231

231-
if !ok {
232+
if !found {
232233
w.WriteHeader(http.StatusUnauthorized)
233-
return false
234+
return
234235
}
235236

236237
if !session.IsVerified() {
237238
helpers.WriteErrorStatus(w, "TOTP_REQUIRED", http.StatusUnauthorized)
238-
return false
239+
return
239240
}
240241

241242
userID = session.UserID
242243

243244
if err := helpers.Store(r).TouchSession(userID, session.ID); err != nil {
244245
log.Error(err)
245246
w.WriteHeader(http.StatusUnauthorized)
246-
return false
247+
return
247248
}
248249
}
249250

@@ -254,17 +255,18 @@ func authenticationHandler(w http.ResponseWriter, r *http.Request) bool {
254255
log.Error(err)
255256
}
256257
w.WriteHeader(http.StatusUnauthorized)
257-
return false
258+
return
258259
}
259260

260-
context.Set(r, "user", &user)
261-
return true
261+
ok = true
262+
req = helpers.SetContextValue(r, "user", &user)
263+
return
262264
}
263265

264266
// nolint: gocyclo
265267
func authentication(next http.Handler) http.Handler {
266268
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
267-
ok := authenticationHandler(w, r)
269+
ok, r := authenticationHandler(w, r)
268270
if ok {
269271
next.ServeHTTP(w, r)
270272
}
@@ -279,7 +281,7 @@ func authenticationWithStore(next http.Handler) http.Handler {
279281
var ok bool
280282

281283
db.StoreSession(store, r.URL.String(), func() {
282-
ok = authenticationHandler(w, r)
284+
ok, r = authenticationHandler(w, r)
283285
})
284286

285287
if ok {
@@ -290,7 +292,7 @@ func authenticationWithStore(next http.Handler) http.Handler {
290292

291293
func adminMiddleware(next http.Handler) http.Handler {
292294
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
293-
user := context.Get(r, "user").(*db.User)
295+
user := helpers.GetFromContext(r, "user").(*db.User)
294296

295297
if !user.Admin {
296298
w.WriteHeader(http.StatusForbidden)

api/cache.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package api
22

33
import (
4-
"github.com/gorilla/context"
54
"github.com/semaphoreui/semaphore/api/helpers"
65
"github.com/semaphoreui/semaphore/db"
76
"github.com/semaphoreui/semaphore/util"
@@ -10,7 +9,7 @@ import (
109
)
1110

1211
func clearCache(w http.ResponseWriter, r *http.Request) {
13-
currentUser := context.Get(r, "user").(*db.User)
12+
currentUser := helpers.GetFromContext(r, "user").(*db.User)
1413

1514
if !currentUser.Admin {
1615
helpers.WriteJSON(w, http.StatusForbidden, map[string]string{

api/debug/gc.go

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,9 @@ package debug
33
import (
44
"net/http"
55
"runtime"
6-
7-
"github.com/gorilla/context"
86
)
97

108
func GC(w http.ResponseWriter, r *http.Request) {
11-
context.Purge(600)
129
runtime.GC()
1310
w.WriteHeader(http.StatusNoContent)
1411
}

api/events.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,12 @@ import (
44
"github.com/semaphoreui/semaphore/api/helpers"
55
"github.com/semaphoreui/semaphore/db"
66
"net/http"
7-
8-
"github.com/gorilla/context"
97
)
108

119
// nolint: gocyclo
1210
func getEvents(w http.ResponseWriter, r *http.Request, limit int) {
13-
user := context.Get(r, "user").(*db.User)
14-
projectObj, exists := context.GetOk(r, "project")
11+
user := helpers.GetFromContext(r, "user").(*db.User)
12+
projectObj, exists := helpers.GetOkFromContext(r, "project")
1513

1614
var err error
1715
var events []db.Event

api/helpers/context.go

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,27 @@
11
package helpers
22

33
import (
4+
"context"
45
"net/http"
56

6-
"github.com/gorilla/context"
77
"github.com/semaphoreui/semaphore/db"
88
)
99

10+
func GetFromContext(r *http.Request, key string) any {
11+
return r.Context().Value(key)
12+
}
13+
14+
func GetOkFromContext(r *http.Request, key string) (res any, ok bool) {
15+
res = r.Context().Value(key)
16+
return res, res != nil
17+
}
18+
19+
func SetContextValue(r *http.Request, key string, value any) *http.Request {
20+
ctx := r.Context()
21+
ctx = context.WithValue(ctx, key, value)
22+
return r.WithContext(ctx)
23+
}
24+
1025
func UserFromContext(r *http.Request) *db.User {
11-
return context.Get(r, "user").(*db.User)
26+
return GetFromContext(r, "user").(*db.User)
1227
}

0 commit comments

Comments
 (0)