Skip to content

Commit c147080

Browse files
authored
Merge branch '0.9' into circleci-fix
2 parents c9b8da8 + de336fb commit c147080

File tree

6 files changed

+134
-15
lines changed

6 files changed

+134
-15
lines changed

recipe/session/middleware.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,15 @@ import (
2626
func VerifySessionHelper(recipeInstance Recipe, options *sessmodels.VerifySessionOptions, otherHandler http.HandlerFunc) http.HandlerFunc {
2727
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2828
dw := supertokens.MakeDoneWriter(w)
29+
userContext := supertokens.MakeDefaultUserContextFromAPI(r)
2930
session, err := (*recipeInstance.APIImpl.VerifySession)(options, sessmodels.APIOptions{
3031
Config: recipeInstance.Config,
3132
OtherHandler: otherHandler,
3233
Req: r,
3334
Res: dw,
3435
RecipeID: recipeInstance.RecipeModule.GetRecipeID(),
3536
RecipeImplementation: recipeInstance.RecipeImpl,
36-
}, &map[string]interface{}{})
37+
}, userContext)
3738
if err != nil {
3839
err = supertokens.ErrorHandler(err, r, dw)
3940
if err != nil {

recipe/thirdparty/recipe.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ func (r *Recipe) getAllCORSHeaders() []string {
148148
}
149149

150150
func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter) (bool, error) {
151-
return false, err
151+
return false, nil
152152
}
153153

154154
func (r *Recipe) getEmailForUserId(userID string, userContext supertokens.UserContext) (evmodels.TypeEmailInfo, error) {

recipe/thirdpartyemailpassword/recipe.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWrit
205205
return handleError, err
206206
}
207207
}
208-
return false, err
208+
return false, nil
209209
}
210210

211211
func ResetForTest() {

recipe/thirdpartypasswordless/recipe.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWrit
213213
return handleError, err
214214
}
215215
}
216-
return false, err
216+
return false, nil
217217
}
218218

219219
func ResetForTest() {

recipe/userroles/claims_test.go

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -313,11 +313,12 @@ func TestShouldValidatePermissions(t *testing.T) {
313313
invalidClaimErr := err.(sessErrors.InvalidClaimError)
314314
assert.Equal(t, 1, len(invalidClaimErr.InvalidClaims))
315315
assert.Equal(t, "st-perm", invalidClaimErr.InvalidClaims[0].ID)
316-
assert.Equal(t, map[string]interface{}{
317-
"actualValue": []interface{}{"a", "b"},
318-
"expectedToInclude": "nope",
319-
"message": "wrong value",
320-
}, invalidClaimErr.InvalidClaims[0].Reason)
316+
reason := invalidClaimErr.InvalidClaims[0].Reason.(map[string]interface{})
317+
assert.Equal(t, "wrong value", reason["message"])
318+
assert.Equal(t, "nope", reason["expectedToInclude"])
319+
assert.Equal(t, 2, len(reason["actualValue"].([]interface{})))
320+
assert.Contains(t, reason["actualValue"], "a")
321+
assert.Contains(t, reason["actualValue"], "b")
321322
}
322323

323324
func TestShouldValidatePermissionsAfterRefetching(t *testing.T) {
@@ -370,9 +371,10 @@ func TestShouldValidatePermissionsAfterRefetching(t *testing.T) {
370371
invalidClaimErr := err.(sessErrors.InvalidClaimError)
371372
assert.Equal(t, 1, len(invalidClaimErr.InvalidClaims))
372373
assert.Equal(t, "st-perm", invalidClaimErr.InvalidClaims[0].ID)
373-
assert.Equal(t, map[string]interface{}{
374-
"actualValue": []interface{}{"a", "b"},
375-
"expectedToInclude": "nope",
376-
"message": "wrong value",
377-
}, invalidClaimErr.InvalidClaims[0].Reason)
374+
reason := invalidClaimErr.InvalidClaims[0].Reason.(map[string]interface{})
375+
assert.Equal(t, "wrong value", reason["message"])
376+
assert.Equal(t, "nope", reason["expectedToInclude"])
377+
assert.Equal(t, 2, len(reason["actualValue"].([]interface{})))
378+
assert.Contains(t, reason["actualValue"], "a")
379+
assert.Contains(t, reason["actualValue"], "b")
378380
}

test/auth-react-server/main.go

Lines changed: 117 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,18 +28,22 @@ import (
2828
"github.com/supertokens/supertokens-golang/recipe/emailpassword"
2929
"github.com/supertokens/supertokens-golang/recipe/emailpassword/epmodels"
3030
"github.com/supertokens/supertokens-golang/recipe/emailverification"
31+
"github.com/supertokens/supertokens-golang/recipe/emailverification/evclaims"
3132
"github.com/supertokens/supertokens-golang/recipe/emailverification/evmodels"
3233
"github.com/supertokens/supertokens-golang/recipe/jwt"
3334
"github.com/supertokens/supertokens-golang/recipe/passwordless"
3435
"github.com/supertokens/supertokens-golang/recipe/passwordless/plessmodels"
3536
"github.com/supertokens/supertokens-golang/recipe/session"
37+
"github.com/supertokens/supertokens-golang/recipe/session/claims"
3638
"github.com/supertokens/supertokens-golang/recipe/session/sessmodels"
3739
"github.com/supertokens/supertokens-golang/recipe/thirdparty"
3840
"github.com/supertokens/supertokens-golang/recipe/thirdparty/tpmodels"
3941
"github.com/supertokens/supertokens-golang/recipe/thirdpartyemailpassword"
4042
"github.com/supertokens/supertokens-golang/recipe/thirdpartyemailpassword/tpepmodels"
4143
"github.com/supertokens/supertokens-golang/recipe/thirdpartypasswordless"
4244
"github.com/supertokens/supertokens-golang/recipe/thirdpartypasswordless/tplmodels"
45+
"github.com/supertokens/supertokens-golang/recipe/userroles"
46+
"github.com/supertokens/supertokens-golang/recipe/userroles/userrolesclaims"
4347
"github.com/supertokens/supertokens-golang/supertokens"
4448
)
4549

@@ -86,6 +90,7 @@ func callSTInit(passwordlessConfig *plessmodels.TypeInput) {
8690
thirdparty.ResetForTest()
8791
thirdpartyemailpassword.ResetForTest()
8892
thirdpartypasswordless.ResetForTest()
93+
userroles.ResetForTest()
8994

9095
if passwordlessConfig == nil {
9196
passwordlessConfig = &plessmodels.TypeInput{
@@ -525,6 +530,7 @@ func callSTInit(passwordlessConfig *plessmodels.TypeInput) {
525530
},
526531
},
527532
}),
533+
userroles.Init(nil),
528534
},
529535
})
530536

@@ -556,9 +562,119 @@ func callSTInit(passwordlessConfig *plessmodels.TypeInput) {
556562
rw.WriteHeader(200)
557563
rw.Header().Add("content-type", "application/json")
558564
bytes, _ := json.Marshal(map[string]interface{}{
559-
"available": []string{"passwordless", "thirdpartypasswordless", "generalerror"},
565+
"available": []string{"passwordless", "thirdpartypasswordless", "generalerror", "userroles"},
560566
})
561567
rw.Write(bytes)
568+
569+
} else if r.URL.Path == "/unverifyEmail" && r.Method == "GET" {
570+
session.VerifySession(nil, func(w http.ResponseWriter, r *http.Request) {
571+
sessionContainer := session.GetSessionFromRequestContext(r.Context())
572+
emailverification.UnverifyEmail(sessionContainer.GetUserID(), nil)
573+
sessionContainer.FetchAndSetClaim(evclaims.EmailVerificationClaim)
574+
rw.Header().Add("content-type", "application/json")
575+
rw.WriteHeader(200)
576+
rw.Write([]byte("{\"status\": \"OK\"}"))
577+
}).ServeHTTP(rw, r)
578+
579+
} else if r.URL.Path == "/setRole" && r.Method == "POST" {
580+
session.VerifySession(nil, func(w http.ResponseWriter, r *http.Request) {
581+
sessionContainer := session.GetSessionFromRequestContext(r.Context())
582+
bodyBytes, err := ioutil.ReadAll(r.Body)
583+
if err != nil {
584+
return
585+
}
586+
var body map[string]interface{}
587+
err = json.Unmarshal(bodyBytes, &body)
588+
if err != nil {
589+
return
590+
}
591+
role := body["role"].(string)
592+
permissions := body["permissions"].([]interface{})
593+
permissionsStr := make([]string, len(permissions))
594+
for i, p := range permissions {
595+
permissionsStr[i] = p.(string)
596+
}
597+
_, err = userroles.CreateNewRoleOrAddPermissions(role, permissionsStr, &map[string]interface{}{})
598+
if err != nil {
599+
return
600+
}
601+
_, err = userroles.AddRoleToUser(sessionContainer.GetUserID(), role, &map[string]interface{}{})
602+
if err != nil {
603+
return
604+
}
605+
err = sessionContainer.FetchAndSetClaim(userrolesclaims.UserRoleClaim)
606+
if err != nil {
607+
return
608+
}
609+
err = sessionContainer.FetchAndSetClaim(userrolesclaims.PermissionClaim)
610+
if err != nil {
611+
return
612+
}
613+
rw.Header().Add("content-type", "application/json")
614+
rw.WriteHeader(200)
615+
rw.Write([]byte("{\"status\": \"OK\"}"))
616+
}).ServeHTTP(rw, r)
617+
618+
} else if r.URL.Path == "/checkRole" && r.Method == "POST" {
619+
session.VerifySession(&sessmodels.VerifySessionOptions{
620+
OverrideGlobalClaimValidators: func(globalClaimValidators []claims.SessionClaimValidator, sessionContainer sessmodels.SessionContainer, userContext supertokens.UserContext) ([]claims.SessionClaimValidator, error) {
621+
req := (*userContext)["_default"].(map[string]interface{})["request"].(*http.Request)
622+
bodyBytes, err := ioutil.ReadAll(req.Body)
623+
if err != nil {
624+
return nil, err
625+
}
626+
var body map[string]interface{}
627+
err = json.Unmarshal(bodyBytes, &body)
628+
if err != nil {
629+
return nil, err
630+
}
631+
632+
getValidator := func(validator claims.PrimitiveArrayClaimValidators, validatorStr string, args []interface{}) claims.SessionClaimValidator {
633+
var maxAge *int64 = nil
634+
var id *string = nil
635+
if len(args) > 1 {
636+
maxAgeFloat := args[1].(float64)
637+
maxAgeInt := int64(maxAgeFloat)
638+
maxAge = &maxAgeInt
639+
}
640+
if len(args) > 2 {
641+
idStr := args[2].(string)
642+
id = &idStr
643+
}
644+
645+
switch validatorStr {
646+
case "includes":
647+
return validator.Includes(args[0].(string), maxAge, id)
648+
case "excludes":
649+
return validator.Excludes(args[0].(string), maxAge, id)
650+
case "includesAll":
651+
return validator.IncludesAll(args[0].([]interface{}), maxAge, id)
652+
case "excludesAll":
653+
return validator.ExcludesAll(args[0].([]interface{}), maxAge, id)
654+
}
655+
656+
return claims.SessionClaimValidator{}
657+
}
658+
659+
if role, ok := body["role"].(map[string]interface{}); ok {
660+
validatorStr := role["validator"].(string)
661+
args := role["args"].([]interface{})
662+
globalClaimValidators = append(globalClaimValidators, getValidator(userrolesclaims.UserRoleClaimValidators, validatorStr, args))
663+
}
664+
665+
if permission, ok := body["permission"].(map[string]interface{}); ok {
666+
validatorStr := permission["validator"].(string)
667+
args := permission["args"].([]interface{})
668+
globalClaimValidators = append(globalClaimValidators, getValidator(userrolesclaims.PermissionClaimValidators, validatorStr, args))
669+
}
670+
671+
return globalClaimValidators, nil
672+
},
673+
}, func(w http.ResponseWriter, r *http.Request) {
674+
rw.Header().Add("content-type", "application/json")
675+
rw.WriteHeader(200)
676+
rw.Write([]byte("{\"status\": \"OK\"}"))
677+
}).ServeHTTP(rw, r)
562678
}
563679
}))
564680

0 commit comments

Comments
 (0)