Skip to content

Commit bf4ad09

Browse files
committed
Merge branch 'userroles-claims' into claims-emailpassword-changes
2 parents 3e21129 + b34a68c commit bf4ad09

File tree

8 files changed

+90
-74
lines changed

8 files changed

+90
-74
lines changed

recipe/emailverification/emailverificationClaim.go

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,21 @@ func NewEmailVerificationClaim() (claims.TypeSessionClaim, evclaims.TypeEmailVer
3636

3737
evClaim, booleanClaimValidators := claims.BooleanClaim("st-ev", fetchValue, nil)
3838

39+
getValueFromPayload := func(payload map[string]interface{}, userContext supertokens.UserContext) interface{} {
40+
if value, ok := evClaim.GetValueFromPayload(payload, userContext).(map[string]interface{}); ok {
41+
return value["v"]
42+
}
43+
return nil
44+
}
45+
46+
getLastRefetchTime := func(payload map[string]interface{}, userContext supertokens.UserContext) *int64 {
47+
if value, ok := evClaim.GetValueFromPayload(payload, userContext).(map[string]interface{}); ok {
48+
val := value["t"].(int64)
49+
return &val
50+
}
51+
return nil
52+
}
53+
3954
validators := evclaims.TypeEmailVerificationClaimValidators{
4055
BooleanClaimValidators: booleanClaimValidators,
4156
IsVerified: func(refetchTimeOnFalseInSeconds *int64) claims.SessionClaimValidator {
@@ -46,8 +61,8 @@ func NewEmailVerificationClaim() (claims.TypeSessionClaim, evclaims.TypeEmailVer
4661

4762
claimValidator := booleanClaimValidators.HasValue(true, nil, nil)
4863
claimValidator.ShouldRefetch = func(payload map[string]interface{}, userContext supertokens.UserContext) bool {
49-
value := evClaim.GetValueFromPayload(payload, userContext)
50-
return value == nil || (value == false && *evClaim.GetLastRefetchTime(payload, userContext) < time.Now().UnixMilli()-*refetchTimeOnFalseInSeconds*1000)
64+
value := getValueFromPayload(payload, userContext)
65+
return value == nil || (value == false && *getLastRefetchTime(payload, userContext) < time.Now().UnixMilli()-*refetchTimeOnFalseInSeconds*1000)
5166
}
5267
return claimValidator
5368
},

recipe/session/claims/claims.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ type TypeSessionClaim struct {
3333
RemoveFromPayloadByMerge_internal func(payload map[string]interface{}, userContext supertokens.UserContext) map[string]interface{}
3434
RemoveFromPayload func(payload map[string]interface{}, userContext supertokens.UserContext) map[string]interface{}
3535
GetValueFromPayload func(payload map[string]interface{}, userContext supertokens.UserContext) interface{}
36-
GetLastRefetchTime func(payload map[string]interface{}, userContext supertokens.UserContext) *int64
3736
Build func(userId string, userContext supertokens.UserContext) (map[string]interface{}, error)
3837
}
3938

recipe/session/claims/primitiveArrayClaim.go

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,20 @@ func PrimitiveArrayClaim(key string, fetchValue FetchValueFunc, defaultMaxAgeInS
3434

3535
sessionClaim.GetValueFromPayload = func(payload map[string]interface{}, userContext supertokens.UserContext) interface{} {
3636
if value, ok := payload[sessionClaim.Key].(map[string]interface{}); ok {
37+
return value
38+
}
39+
return nil
40+
}
41+
42+
getValueFromPayload := func(payload map[string]interface{}, userContext supertokens.UserContext) interface{} {
43+
if value, ok := sessionClaim.GetValueFromPayload(payload, userContext).(map[string]interface{}); ok {
3744
return value["v"]
3845
}
3946
return nil
4047
}
4148

42-
sessionClaim.GetLastRefetchTime = func(payload map[string]interface{}, userContext supertokens.UserContext) *int64 {
43-
if value, ok := payload[sessionClaim.Key].(map[string]interface{}); ok {
49+
getLastRefetchTime := func(payload map[string]interface{}, userContext supertokens.UserContext) *int64 {
50+
if value, ok := sessionClaim.GetValueFromPayload(payload, userContext).(map[string]interface{}); ok {
4451
val := value["t"].(int64)
4552
return &val
4653
}
@@ -60,17 +67,17 @@ func PrimitiveArrayClaim(key string, fetchValue FetchValueFunc, defaultMaxAgeInS
6067
ID: claimId,
6168
Claim: &sessionClaim,
6269
ShouldRefetch: func(payload map[string]interface{}, userContext supertokens.UserContext) bool {
63-
claimVal, ok := sessionClaim.GetValueFromPayload(payload, userContext).([]interface{})
70+
claimVal, ok := getValueFromPayload(payload, userContext).([]interface{})
6471
if !ok || claimVal == nil {
6572
return true
6673
}
6774
if maxAgeInSeconds != nil {
68-
return *sessionClaim.GetLastRefetchTime(payload, userContext) < time.Now().UnixMilli()-*maxAgeInSeconds*1000
75+
return *getLastRefetchTime(payload, userContext) < time.Now().UnixMilli()-*maxAgeInSeconds*1000
6976
}
7077
return false
7178
},
7279
Validate: func(payload map[string]interface{}, userContext supertokens.UserContext) ClaimValidationResult {
73-
claimVal := sessionClaim.GetValueFromPayload(payload, userContext).([]interface{})
80+
claimVal := getValueFromPayload(payload, userContext).([]interface{})
7481

7582
if claimVal == nil {
7683
return ClaimValidationResult{
@@ -82,7 +89,7 @@ func PrimitiveArrayClaim(key string, fetchValue FetchValueFunc, defaultMaxAgeInS
8289
},
8390
}
8491
}
85-
ageInSeconds := (time.Now().UnixMilli() - *sessionClaim.GetLastRefetchTime(payload, userContext)) / 1000
92+
ageInSeconds := (time.Now().UnixMilli() - *getLastRefetchTime(payload, userContext)) / 1000
8693
if maxAgeInSeconds != nil && ageInSeconds > *maxAgeInSeconds {
8794
return ClaimValidationResult{
8895
IsValid: false,
@@ -121,17 +128,17 @@ func PrimitiveArrayClaim(key string, fetchValue FetchValueFunc, defaultMaxAgeInS
121128
ID: claimId,
122129
Claim: &sessionClaim,
123130
ShouldRefetch: func(payload map[string]interface{}, userContext supertokens.UserContext) bool {
124-
val, ok := sessionClaim.GetValueFromPayload(payload, userContext).([]interface{})
131+
val, ok := getValueFromPayload(payload, userContext).([]interface{})
125132
if !ok || val == nil {
126133
return true
127134
}
128135
if maxAgeInSeconds != nil {
129-
return *sessionClaim.GetLastRefetchTime(payload, userContext) < time.Now().UnixMilli()-*maxAgeInSeconds*1000
136+
return *getLastRefetchTime(payload, userContext) < time.Now().UnixMilli()-*maxAgeInSeconds*1000
130137
}
131138
return false
132139
},
133140
Validate: func(payload map[string]interface{}, userContext supertokens.UserContext) ClaimValidationResult {
134-
claimVal := sessionClaim.GetValueFromPayload(payload, userContext).([]interface{})
141+
claimVal := getValueFromPayload(payload, userContext).([]interface{})
135142

136143
if claimVal == nil {
137144
return ClaimValidationResult{
@@ -143,7 +150,7 @@ func PrimitiveArrayClaim(key string, fetchValue FetchValueFunc, defaultMaxAgeInS
143150
},
144151
}
145152
}
146-
ageInSeconds := (time.Now().UnixMilli() - *sessionClaim.GetLastRefetchTime(payload, userContext)) / 1000
153+
ageInSeconds := (time.Now().UnixMilli() - *getLastRefetchTime(payload, userContext)) / 1000
147154
if maxAgeInSeconds != nil && ageInSeconds > *maxAgeInSeconds {
148155
return ClaimValidationResult{
149156
IsValid: false,
@@ -182,17 +189,17 @@ func PrimitiveArrayClaim(key string, fetchValue FetchValueFunc, defaultMaxAgeInS
182189
ID: claimId,
183190
Claim: &sessionClaim,
184191
ShouldRefetch: func(payload map[string]interface{}, userContext supertokens.UserContext) bool {
185-
val, ok := sessionClaim.GetValueFromPayload(payload, userContext).([]interface{})
192+
val, ok := getValueFromPayload(payload, userContext).([]interface{})
186193
if !ok || val == nil {
187194
return true
188195
}
189196
if maxAgeInSeconds != nil {
190-
return *sessionClaim.GetLastRefetchTime(payload, userContext) < time.Now().UnixMilli()-*maxAgeInSeconds*1000
197+
return *getLastRefetchTime(payload, userContext) < time.Now().UnixMilli()-*maxAgeInSeconds*1000
191198
}
192199
return false
193200
},
194201
Validate: func(payload map[string]interface{}, userContext supertokens.UserContext) ClaimValidationResult {
195-
claimVal := sessionClaim.GetValueFromPayload(payload, userContext).([]interface{})
202+
claimVal := getValueFromPayload(payload, userContext).([]interface{})
196203

197204
if claimVal == nil {
198205
return ClaimValidationResult{
@@ -204,7 +211,7 @@ func PrimitiveArrayClaim(key string, fetchValue FetchValueFunc, defaultMaxAgeInS
204211
},
205212
}
206213
}
207-
ageInSeconds := (time.Now().UnixMilli() - *sessionClaim.GetLastRefetchTime(payload, userContext)) / 1000
214+
ageInSeconds := (time.Now().UnixMilli() - *getLastRefetchTime(payload, userContext)) / 1000
208215
if maxAgeInSeconds != nil && ageInSeconds > *maxAgeInSeconds {
209216
return ClaimValidationResult{
210217
IsValid: false,
@@ -244,17 +251,17 @@ func PrimitiveArrayClaim(key string, fetchValue FetchValueFunc, defaultMaxAgeInS
244251
ID: claimId,
245252
Claim: &sessionClaim,
246253
ShouldRefetch: func(payload map[string]interface{}, userContext supertokens.UserContext) bool {
247-
val, ok := sessionClaim.GetValueFromPayload(payload, userContext).([]interface{})
254+
val, ok := getValueFromPayload(payload, userContext).([]interface{})
248255
if !ok || val == nil {
249256
return true
250257
}
251258
if maxAgeInSeconds != nil {
252-
return *sessionClaim.GetLastRefetchTime(payload, userContext) < time.Now().UnixMilli()-*maxAgeInSeconds*1000
259+
return *getLastRefetchTime(payload, userContext) < time.Now().UnixMilli()-*maxAgeInSeconds*1000
253260
}
254261
return false
255262
},
256263
Validate: func(payload map[string]interface{}, userContext supertokens.UserContext) ClaimValidationResult {
257-
claimVal := sessionClaim.GetValueFromPayload(payload, userContext).([]interface{})
264+
claimVal := getValueFromPayload(payload, userContext).([]interface{})
258265

259266
if claimVal == nil {
260267
return ClaimValidationResult{
@@ -266,7 +273,7 @@ func PrimitiveArrayClaim(key string, fetchValue FetchValueFunc, defaultMaxAgeInS
266273
},
267274
}
268275
}
269-
ageInSeconds := (time.Now().UnixMilli() - *sessionClaim.GetLastRefetchTime(payload, userContext)) / 1000
276+
ageInSeconds := (time.Now().UnixMilli() - *getLastRefetchTime(payload, userContext)) / 1000
270277
if maxAgeInSeconds != nil && ageInSeconds > *maxAgeInSeconds {
271278
return ClaimValidationResult{
272279
IsValid: false,

recipe/session/claims/primitiveClaim.go

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,20 @@ func PrimitiveClaim(key string, fetchValue FetchValueFunc, defaultMaxAgeInSecond
3434

3535
sessionClaim.GetValueFromPayload = func(payload map[string]interface{}, userContext supertokens.UserContext) interface{} {
3636
if value, ok := payload[sessionClaim.Key].(map[string]interface{}); ok {
37+
return value
38+
}
39+
return nil
40+
}
41+
42+
getValueFromPayload := func(payload map[string]interface{}, userContext supertokens.UserContext) interface{} {
43+
if value, ok := sessionClaim.GetValueFromPayload(payload, userContext).(map[string]interface{}); ok {
3744
return value["v"]
3845
}
3946
return nil
4047
}
4148

42-
sessionClaim.GetLastRefetchTime = func(payload map[string]interface{}, userContext supertokens.UserContext) *int64 {
43-
if value, ok := payload[sessionClaim.Key].(map[string]interface{}); ok {
49+
getLastRefetchTime := func(payload map[string]interface{}, userContext supertokens.UserContext) *int64 {
50+
if value, ok := sessionClaim.GetValueFromPayload(payload, userContext).(map[string]interface{}); ok {
4451
val := value["t"].(int64)
4552
return &val
4653
}
@@ -60,14 +67,14 @@ func PrimitiveClaim(key string, fetchValue FetchValueFunc, defaultMaxAgeInSecond
6067
ID: validatorId,
6168
Claim: &sessionClaim,
6269
ShouldRefetch: func(payload map[string]interface{}, userContext supertokens.UserContext) bool {
63-
val := sessionClaim.GetValueFromPayload(payload, userContext)
70+
val := getValueFromPayload(payload, userContext)
6471
if val == nil {
6572
return true
6673
}
67-
return maxAgeInSeconds != nil && *sessionClaim.GetLastRefetchTime(payload, userContext) < time.Now().UnixMilli()-*maxAgeInSeconds*1000
74+
return maxAgeInSeconds != nil && *getLastRefetchTime(payload, userContext) < time.Now().UnixMilli()-*maxAgeInSeconds*1000
6875
},
6976
Validate: func(payload map[string]interface{}, userContext supertokens.UserContext) ClaimValidationResult {
70-
claimVal := sessionClaim.GetValueFromPayload(payload, userContext)
77+
claimVal := getValueFromPayload(payload, userContext)
7178

7279
if claimVal == nil {
7380
return ClaimValidationResult{
@@ -79,7 +86,7 @@ func PrimitiveClaim(key string, fetchValue FetchValueFunc, defaultMaxAgeInSecond
7986
},
8087
}
8188
}
82-
ageInSeconds := (time.Now().UnixMilli() - *sessionClaim.GetLastRefetchTime(payload, userContext)) / 1000
89+
ageInSeconds := (time.Now().UnixMilli() - *getLastRefetchTime(payload, userContext)) / 1000
8390
if maxAgeInSeconds != nil && ageInSeconds > *maxAgeInSeconds {
8491
return ClaimValidationResult{
8592
IsValid: false,

recipe/userroles/claims.go

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,17 @@ package userroles
22

33
import (
44
"github.com/supertokens/supertokens-golang/recipe/session/claims"
5-
urclaims "github.com/supertokens/supertokens-golang/recipe/userroles/claims"
5+
"github.com/supertokens/supertokens-golang/recipe/userroles/userrolesclaims"
66
"github.com/supertokens/supertokens-golang/supertokens"
77
)
88

99
func init() {
10-
urclaims.UserRoleClaim = NewUserRoleClaim()
11-
urclaims.PermissionClaim = NewPermissionClaim()
10+
// automatically called when this package is imported
11+
userrolesclaims.UserRoleClaim, userrolesclaims.UserRoleClaimValidators = NewUserRoleClaim()
12+
userrolesclaims.PermissionClaim, userrolesclaims.PermissionClaimValidators = NewPermissionClaim()
1213
}
1314

14-
func NewUserRoleClaim() *urclaims.TypeUserRoleClaim {
15+
func NewUserRoleClaim() (claims.TypeSessionClaim, userrolesclaims.TypeUserRoleClaimValidators) {
1516
fetchValue := func(userId string, userContext supertokens.UserContext) (interface{}, error) {
1617
recipe, err := getRecipeInstanceOrThrowError()
1718
if err != nil {
@@ -29,16 +30,13 @@ func NewUserRoleClaim() *urclaims.TypeUserRoleClaim {
2930
return rolesArray, nil
3031
}
3132

32-
primitiveArrayClaim := claims.PrimitiveArrayClaim("st-role", fetchValue, nil)
33-
return &urclaims.TypeUserRoleClaim{
34-
TypePrimitiveArrayClaim: primitiveArrayClaim,
35-
Validators: &urclaims.TypeUserRoleClaimValidators{
36-
PrimitiveArrayClaimValidators: primitiveArrayClaim.Validators,
37-
},
33+
userRoleClaim, primitiveArrayClaimValidators := claims.PrimitiveArrayClaim("st-role", fetchValue, nil)
34+
return userRoleClaim, userrolesclaims.TypeUserRoleClaimValidators{
35+
PrimitiveArrayClaimValidators: primitiveArrayClaimValidators,
3836
}
3937
}
4038

41-
func NewPermissionClaim() *urclaims.TypePermissionClaim {
39+
func NewPermissionClaim() (claims.TypeSessionClaim, userrolesclaims.TypePermissionClaimValidators) {
4240
fetchValue := func(userId string, userContext supertokens.UserContext) (interface{}, error) {
4341
recipe, err := getRecipeInstanceOrThrowError()
4442
if err != nil {
@@ -69,11 +67,8 @@ func NewPermissionClaim() *urclaims.TypePermissionClaim {
6967
return result, nil
7068
}
7169

72-
primitiveArrayClaim := claims.PrimitiveArrayClaim("st-perm", fetchValue, nil)
73-
return &urclaims.TypePermissionClaim{
74-
TypePrimitiveArrayClaim: primitiveArrayClaim,
75-
Validators: &urclaims.TypePermissionClaimValidators{
76-
PrimitiveArrayClaimValidators: primitiveArrayClaim.Validators,
77-
},
70+
permissionClaim, primitiveArrayClaimValidators := claims.PrimitiveArrayClaim("st-perm", fetchValue, nil)
71+
return permissionClaim, userrolesclaims.TypePermissionClaimValidators{
72+
PrimitiveArrayClaimValidators: primitiveArrayClaimValidators,
7873
}
7974
}

recipe/userroles/claims/claims.go

Lines changed: 0 additions & 25 deletions
This file was deleted.

recipe/userroles/recipe.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ import (
2020
"net/http"
2121

2222
"github.com/supertokens/supertokens-golang/recipe/session"
23-
"github.com/supertokens/supertokens-golang/recipe/userroles/claims"
23+
"github.com/supertokens/supertokens-golang/recipe/userroles/userrolesclaims"
2424
"github.com/supertokens/supertokens-golang/recipe/userroles/userrolesmodels"
2525
"github.com/supertokens/supertokens-golang/supertokens"
2626
)
@@ -69,19 +69,20 @@ func recipeInit(config *userrolesmodels.TypeInput) supertokens.Recipe {
6969
}
7070
singletonInstance = &recipe
7171

72-
supertokens.AddPostInitCallback(func() {
72+
supertokens.AddPostInitCallback(func() error {
7373
sessionRecipe, err := session.GetRecipeInstanceOrThrowError()
7474
if err != nil {
75-
return
75+
return err
7676
}
7777

7878
if !config.SkipAddingRolesToAccessToken {
79-
sessionRecipe.AddClaimFromOtherRecipe(claims.UserRoleClaim.TypeSessionClaim)
79+
sessionRecipe.AddClaimFromOtherRecipe(userrolesclaims.UserRoleClaim)
8080
}
8181

8282
if !config.SkipAddingPermissionsToAccessToken {
83-
sessionRecipe.AddClaimFromOtherRecipe(claims.PermissionClaim.TypeSessionClaim)
83+
sessionRecipe.AddClaimFromOtherRecipe(userrolesclaims.PermissionClaim)
8484
}
85+
return nil
8586
})
8687

8788
return &singletonInstance.RecipeModule, nil
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
package userrolesclaims
2+
3+
import "github.com/supertokens/supertokens-golang/recipe/session/claims"
4+
5+
type TypeUserRoleClaimValidators struct {
6+
claims.PrimitiveArrayClaimValidators
7+
}
8+
9+
var UserRoleClaim claims.TypeSessionClaim
10+
var UserRoleClaimValidators TypeUserRoleClaimValidators
11+
12+
type TypePermissionClaimValidators struct {
13+
claims.PrimitiveArrayClaimValidators
14+
}
15+
16+
var PermissionClaim claims.TypeSessionClaim
17+
var PermissionClaimValidators TypePermissionClaimValidators

0 commit comments

Comments
 (0)