Skip to content

Commit c97ebf3

Browse files
ShahParth12ParthAtStrato-Cloudio
authored andcommitted
feat: RequireRole function now accepts multiple roles (#143)
Co-authored-by: Parth Shah <parth@strato-cloud.io>
1 parent 4ac2042 commit c97ebf3

File tree

1 file changed

+13
-11
lines changed

1 file changed

+13
-11
lines changed

middlewares/auth.go

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -398,12 +398,12 @@ func UserRequestContext(next http.Handler) http.Handler {
398398
})
399399
}
400400

401-
// RequireRole is a middleware that checks if the user has the required role
402-
// If the user does not have the role, it returns a 403 Forbidden response
403-
// If the user is not authenticated, it returns a 401 Unauthorized response
404-
// Usage: router.Use(RequireRole("admin"))
405-
// If the user has the role, it calls the next handler in the chain
406-
func RequireRole(roleName string) func(http.Handler) http.Handler {
401+
// RequireRole is a middleware that checks if the user has at least one of the required roles.
402+
// If the user does not have any of the roles, it returns a 403 Forbidden response.
403+
// If the user is not authenticated, it returns a 401 Unauthorized response.
404+
// Usage: router.Use(RequireRole("admin")) or router.Use(RequireRole("admin", "editor"))
405+
// If the user has any of the roles, it calls the next handler in the chain
406+
func RequireRole(roleNames ...string) func(http.Handler) http.Handler {
407407
return func(next http.Handler) http.Handler {
408408
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
409409
claims := getValidatedClaims(r.Context())
@@ -414,11 +414,13 @@ func RequireRole(roleName string) func(http.Handler) http.Handler {
414414
return
415415
}
416416
claimsConfig := getClaimsConfig(claims.ClaimsConfig)
417-
roles := getRoles(claims.CustomClaims, claimsConfig)
418-
for _, role := range roles {
419-
if role == roleName {
420-
next.ServeHTTP(rw, r)
421-
return
417+
userRoles := getRoles(claims.CustomClaims, claimsConfig)
418+
for _, userRole := range userRoles {
419+
for _, allowedRole := range roleNames {
420+
if userRole == allowedRole {
421+
next.ServeHTTP(rw, r)
422+
return
423+
}
422424
}
423425
}
424426
rw.Header().Set("Content-Type", "application/json")

0 commit comments

Comments
 (0)