Skip to content

Commit 64548be

Browse files
Merge pull request #172 from supertokens/session-container-to-pointer
fix: Session container to pointer
2 parents 3d0d227 + eba4d9d commit 64548be

26 files changed

+270
-211
lines changed

examples/with-labstack-echo/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ func verifySession(options *sessmodels.VerifySessionOptions) echo.MiddlewareFunc
144144
}
145145

146146
func sessioninfo(c echo.Context) error {
147-
sessionContainer := c.Get("session").(*sessmodels.SessionContainer)
147+
sessionContainer := c.Get("session").(sessmodels.SessionContainer)
148148

149149
if sessionContainer == nil {
150150
return errors.New("no session found")

recipe/emailpassword/emailExistsAndVerificationCheck_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -931,7 +931,7 @@ func TestThatTheHandlePostEmailVerificationCallBackIsCalledOnSuccessFullVerifica
931931
Override: &evmodels.OverrideStruct{
932932
APIs: func(originalImplementation evmodels.APIInterface) evmodels.APIInterface {
933933
originalVerifyEmailPost := *originalImplementation.VerifyEmailPOST
934-
*originalImplementation.VerifyEmailPOST = func(token string, sessionContainer *sessmodels.SessionContainer, options evmodels.APIOptions, userContext supertokens.UserContext) (evmodels.VerifyEmailPOSTResponse, error) {
934+
*originalImplementation.VerifyEmailPOST = func(token string, sessionContainer sessmodels.SessionContainer, options evmodels.APIOptions, userContext supertokens.UserContext) (evmodels.VerifyEmailPOSTResponse, error) {
935935
res, err := originalVerifyEmailPost(token, sessionContainer, options, userContext)
936936
if err != nil {
937937
log.Fatal(err.Error())
@@ -1191,7 +1191,7 @@ func TestTheEmailVerifyAPIwithValidInputOverridingAPIs(t *testing.T) {
11911191
Override: &evmodels.OverrideStruct{
11921192
APIs: func(originalImplementation evmodels.APIInterface) evmodels.APIInterface {
11931193
originalVerifyEmailPost := *originalImplementation.VerifyEmailPOST
1194-
*originalImplementation.VerifyEmailPOST = func(token string, sessionContainer *sessmodels.SessionContainer, options evmodels.APIOptions, userContext supertokens.UserContext) (evmodels.VerifyEmailPOSTResponse, error) {
1194+
*originalImplementation.VerifyEmailPOST = func(token string, sessionContainer sessmodels.SessionContainer, options evmodels.APIOptions, userContext supertokens.UserContext) (evmodels.VerifyEmailPOSTResponse, error) {
11951195
res, err := originalVerifyEmailPost(token, sessionContainer, options, userContext)
11961196
if err != nil {
11971197
log.Fatal(err.Error())
@@ -1397,7 +1397,7 @@ func TestTheEmailVerifyAPIwithValidInputThrowsErrorOnSuchOverriding(t *testing.T
13971397
Override: &evmodels.OverrideStruct{
13981398
APIs: func(originalImplementation evmodels.APIInterface) evmodels.APIInterface {
13991399
originalVerifyEmailPost := *originalImplementation.VerifyEmailPOST
1400-
*originalImplementation.VerifyEmailPOST = func(token string, sessionContainer *sessmodels.SessionContainer, options evmodels.APIOptions, userContext supertokens.UserContext) (evmodels.VerifyEmailPOSTResponse, error) {
1400+
*originalImplementation.VerifyEmailPOST = func(token string, sessionContainer sessmodels.SessionContainer, options evmodels.APIOptions, userContext supertokens.UserContext) (evmodels.VerifyEmailPOSTResponse, error) {
14011401
res, err := originalVerifyEmailPost(token, sessionContainer, options, userContext)
14021402
if err != nil {
14031403
log.Fatal(err.Error())

recipe/emailverification/api/emailverify.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ func EmailVerify(apiImplementation evmodels.APIInterface, options evmodels.APIOp
4242
options.Req, options.Res,
4343
&sessmodels.VerifySessionOptions{
4444
SessionRequired: &sessionRequired,
45-
OverrideGlobalClaimValidators: func(globalClaimValidators []claims.SessionClaimValidator, sessionContainer *sessmodels.SessionContainer, userContext supertokens.UserContext) ([]claims.SessionClaimValidator, error) {
45+
OverrideGlobalClaimValidators: func(globalClaimValidators []claims.SessionClaimValidator, sessionContainer sessmodels.SessionContainer, userContext supertokens.UserContext) ([]claims.SessionClaimValidator, error) {
4646
validators := []claims.SessionClaimValidator{}
4747
return validators, nil
4848
},
@@ -100,7 +100,7 @@ func EmailVerify(apiImplementation evmodels.APIInterface, options evmodels.APIOp
100100
options.Req,
101101
options.Res,
102102
&sessmodels.VerifySessionOptions{
103-
OverrideGlobalClaimValidators: func(globalClaimValidators []claims.SessionClaimValidator, sessionContainer *sessmodels.SessionContainer, userContext supertokens.UserContext) ([]claims.SessionClaimValidator, error) {
103+
OverrideGlobalClaimValidators: func(globalClaimValidators []claims.SessionClaimValidator, sessionContainer sessmodels.SessionContainer, userContext supertokens.UserContext) ([]claims.SessionClaimValidator, error) {
104104
validators := []claims.SessionClaimValidator{}
105105
return validators, nil
106106
},

recipe/emailverification/api/generateEmailVerifyToken.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ func GenerateEmailVerifyToken(apiImplementation evmodels.APIInterface, options e
3535
sessionContainer, err := session.GetSessionWithContext(
3636
options.Req, options.Res,
3737
&sessmodels.VerifySessionOptions{
38-
OverrideGlobalClaimValidators: func(globalClaimValidators []claims.SessionClaimValidator, sessionContainer *sessmodels.SessionContainer, userContext supertokens.UserContext) ([]claims.SessionClaimValidator, error) {
38+
OverrideGlobalClaimValidators: func(globalClaimValidators []claims.SessionClaimValidator, sessionContainer sessmodels.SessionContainer, userContext supertokens.UserContext) ([]claims.SessionClaimValidator, error) {
3939
validators := []claims.SessionClaimValidator{}
4040
return validators, nil
4141
},

recipe/emailverification/api/implementation.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import (
2727
)
2828

2929
func MakeAPIImplementation() evmodels.APIInterface {
30-
verifyEmailPOST := func(token string, sessionContainer *sessmodels.SessionContainer, options evmodels.APIOptions, userContext supertokens.UserContext) (evmodels.VerifyEmailPOSTResponse, error) {
30+
verifyEmailPOST := func(token string, sessionContainer sessmodels.SessionContainer, options evmodels.APIOptions, userContext supertokens.UserContext) (evmodels.VerifyEmailPOSTResponse, error) {
3131
resp, err := (*options.RecipeImplementation.VerifyEmailUsingToken)(token, userContext)
3232
if err != nil {
3333
return evmodels.VerifyEmailPOSTResponse{}, err
@@ -49,7 +49,7 @@ func MakeAPIImplementation() evmodels.APIInterface {
4949
}
5050
}
5151

52-
isEmailVerifiedGET := func(sessionContainer *sessmodels.SessionContainer, options evmodels.APIOptions, userContext supertokens.UserContext) (evmodels.IsEmailVerifiedGETResponse, error) {
52+
isEmailVerifiedGET := func(sessionContainer sessmodels.SessionContainer, options evmodels.APIOptions, userContext supertokens.UserContext) (evmodels.IsEmailVerifiedGETResponse, error) {
5353
if sessionContainer == nil {
5454
return evmodels.IsEmailVerifiedGETResponse{}, supertokens.BadInputError{Msg: "Session is undefined. Should not come here."}
5555
}
@@ -70,7 +70,7 @@ func MakeAPIImplementation() evmodels.APIInterface {
7070
}, nil
7171
}
7272

73-
generateEmailVerifyTokenPOST := func(sessionContainer *sessmodels.SessionContainer, options evmodels.APIOptions, userContext supertokens.UserContext) (evmodels.GenerateEmailVerifyTokenPOSTResponse, error) {
73+
generateEmailVerifyTokenPOST := func(sessionContainer sessmodels.SessionContainer, options evmodels.APIOptions, userContext supertokens.UserContext) (evmodels.GenerateEmailVerifyTokenPOSTResponse, error) {
7474
if sessionContainer == nil {
7575
return evmodels.GenerateEmailVerifyTokenPOSTResponse{}, supertokens.BadInputError{Msg: "Session is undefined. Should not come here."}
7676
}

recipe/emailverification/evmodels/apiInterface.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@ type APIOptions struct {
3636
}
3737

3838
type APIInterface struct {
39-
VerifyEmailPOST *func(token string, sessionContainer *sessmodels.SessionContainer, options APIOptions, userContext supertokens.UserContext) (VerifyEmailPOSTResponse, error)
40-
IsEmailVerifiedGET *func(sessionContainer *sessmodels.SessionContainer, options APIOptions, userContext supertokens.UserContext) (IsEmailVerifiedGETResponse, error)
41-
GenerateEmailVerifyTokenPOST *func(sessionContainer *sessmodels.SessionContainer, options APIOptions, userContext supertokens.UserContext) (GenerateEmailVerifyTokenPOSTResponse, error)
39+
VerifyEmailPOST *func(token string, sessionContainer sessmodels.SessionContainer, options APIOptions, userContext supertokens.UserContext) (VerifyEmailPOSTResponse, error)
40+
IsEmailVerifiedGET *func(sessionContainer sessmodels.SessionContainer, options APIOptions, userContext supertokens.UserContext) (IsEmailVerifiedGETResponse, error)
41+
GenerateEmailVerifyTokenPOST *func(sessionContainer sessmodels.SessionContainer, options APIOptions, userContext supertokens.UserContext) (GenerateEmailVerifyTokenPOSTResponse, error)
4242
}
4343

4444
type VerifyEmailPOSTResponse struct {

recipe/session/api/implementation.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ func MakeAPIImplementation() sessmodels.APIInterface {
2828
return (*options.RecipeImplementation.RefreshSession)(options.Req, options.Res, userContext)
2929
}
3030

31-
verifySession := func(verifySessionOptions *sessmodels.VerifySessionOptions, options sessmodels.APIOptions, userContext supertokens.UserContext) (*sessmodels.SessionContainer, error) {
31+
verifySession := func(verifySessionOptions *sessmodels.VerifySessionOptions, options sessmodels.APIOptions, userContext supertokens.UserContext) (sessmodels.SessionContainer, error) {
3232
method := options.Req.Method
3333
if method == http.MethodOptions || method == http.MethodTrace {
3434
return nil, nil
@@ -42,7 +42,7 @@ func MakeAPIImplementation() sessmodels.APIInterface {
4242
refreshTokenPath := options.Config.RefreshTokenPath
4343
if incomingPath.Equals(refreshTokenPath) && method == http.MethodPost {
4444
session, err := (*options.RecipeImplementation.RefreshSession)(options.Req, options.Res, userContext)
45-
return &session, err
45+
return session, err
4646
} else {
4747
sessionContainer, err := (*options.RecipeImplementation.GetSession)(options.Req, options.Res, verifySessionOptions, userContext)
4848
if err != nil {
@@ -53,7 +53,7 @@ func MakeAPIImplementation() sessmodels.APIInterface {
5353
return nil, nil
5454
}
5555

56-
var overrideGlobalClaimValidators func(globalClaimValidators []claims.SessionClaimValidator, sessionContainer *sessmodels.SessionContainer, userContext supertokens.UserContext) ([]claims.SessionClaimValidator, error) = nil
56+
var overrideGlobalClaimValidators func(globalClaimValidators []claims.SessionClaimValidator, sessionContainer sessmodels.SessionContainer, userContext supertokens.UserContext) ([]claims.SessionClaimValidator, error) = nil
5757
if verifySessionOptions != nil {
5858
overrideGlobalClaimValidators = verifySessionOptions.OverrideGlobalClaimValidators
5959
}
@@ -81,7 +81,7 @@ func MakeAPIImplementation() sessmodels.APIInterface {
8181
}
8282
}
8383

84-
signOutPOST := func(sessionContainer *sessmodels.SessionContainer, options sessmodels.APIOptions, userContext supertokens.UserContext) (sessmodels.SignOutPOSTResponse, error) {
84+
signOutPOST := func(sessionContainer sessmodels.SessionContainer, options sessmodels.APIOptions, userContext supertokens.UserContext) (sessmodels.SignOutPOSTResponse, error) {
8585
if sessionContainer != nil {
8686
err := sessionContainer.RevokeSessionWithContext(userContext)
8787
if err != nil {

recipe/session/api/signout.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ func SignOutAPI(apiImplementation sessmodels.APIInterface, options sessmodels.AP
3232
False := false
3333
sessionContainer, err := (*options.RecipeImplementation.GetSession)(options.Req, options.Res, &sessmodels.VerifySessionOptions{
3434
SessionRequired: &False,
35-
OverrideGlobalClaimValidators: func(globalClaimValidators []claims.SessionClaimValidator, sessionContainer *sessmodels.SessionContainer, userContext supertokens.UserContext) ([]claims.SessionClaimValidator, error) {
35+
OverrideGlobalClaimValidators: func(globalClaimValidators []claims.SessionClaimValidator, sessionContainer sessmodels.SessionContainer, userContext supertokens.UserContext) ([]claims.SessionClaimValidator, error) {
3636
return []claims.SessionClaimValidator{}, nil
3737
},
3838
}, userContext)

recipe/session/claimsWithJWT_test.go

Lines changed: 89 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77

88
"github.com/golang-jwt/jwt/v4"
99
"github.com/stretchr/testify/assert"
10+
"github.com/supertokens/supertokens-golang/recipe/session/claims"
1011
"github.com/supertokens/supertokens-golang/recipe/session/sessmodels"
1112
"github.com/supertokens/supertokens-golang/supertokens"
1213
"github.com/supertokens/supertokens-golang/test/unittesting"
@@ -37,7 +38,7 @@ func TestJWTShouldCreateRightAccessTokenPayloadWithClaims(t *testing.T) {
3738
claim, _ := TrueClaim()
3839
accessTokenPayload, err := claim.Build(userID, accessTokenPayload, userContext)
3940
if err != nil {
40-
return sessmodels.SessionContainer{}, err
41+
return nil, err
4142
}
4243
return oCreateNewSession(res, userID, accessTokenPayload, sessionData, userContext)
4344
}
@@ -79,10 +80,93 @@ func TestJWTShouldCreateRightAccessTokenPayloadWithClaims(t *testing.T) {
7980
sessInfo, err := GetSessionInformation(sessionContainer.GetHandle())
8081
assert.NoError(t, err)
8182
jwtPayloadStr := sessInfo.AccessTokenPayload["jwt"].(string)
82-
token, _ := jwt.Parse(jwtPayloadStr, func(t *jwt.Token) (interface{}, error) {
83-
return nil, nil
84-
})
85-
jwtPayload := token.Claims.(jwt.MapClaims)
83+
jwtPayload := jwt.MapClaims{}
84+
85+
_, _, err = (&jwt.Parser{}).ParseUnverified(jwtPayloadStr, jwtPayload)
86+
assert.NoError(t, err)
87+
8688
assert.Equal(t, true, jwtPayload["st-true"].(map[string]interface{})["v"])
8789
assert.Equal(t, "rope", jwtPayload["sub"])
8890
}
91+
92+
func TestAssertClaimsWithPayloadWithJWTAndCallRightUpdateAccessTokenPayload(t *testing.T) {
93+
configValue := supertokens.TypeInput{
94+
Supertokens: &supertokens.ConnectionInfo{
95+
ConnectionURI: "http://localhost:8080",
96+
},
97+
AppInfo: supertokens.AppInfo{
98+
AppName: "SuperTokens",
99+
WebsiteDomain: "supertokens.io",
100+
APIDomain: "api.supertokens.io",
101+
},
102+
RecipeList: []supertokens.Recipe{
103+
Init(&sessmodels.TypeInput{
104+
Jwt: &sessmodels.JWTInputConfig{Enable: true},
105+
}),
106+
},
107+
}
108+
BeforeEach()
109+
unittesting.StartUpST("localhost", "8080")
110+
defer AfterEach()
111+
err := supertokens.Init(configValue)
112+
if err != nil {
113+
t.Error(err.Error())
114+
}
115+
116+
mux := http.NewServeMux()
117+
var sessionContainer sessmodels.SessionContainer
118+
accessTokenPayload := map[string]interface{}{
119+
"hello": "world",
120+
}
121+
122+
mux.HandleFunc("/create", func(rw http.ResponseWriter, r *http.Request) {
123+
var err error
124+
sessionContainer, err = CreateNewSession(rw, "rope", accessTokenPayload, map[string]interface{}{})
125+
assert.NoError(t, err)
126+
})
127+
128+
testServer := httptest.NewServer(supertokens.Middleware(mux))
129+
defer func() {
130+
testServer.Close()
131+
}()
132+
req, err := http.NewRequest(http.MethodGet, testServer.URL+"/create", nil)
133+
assert.NoError(t, err)
134+
res, err := http.DefaultClient.Do(req)
135+
assert.NoError(t, err)
136+
assert.Equal(t, 200, res.StatusCode)
137+
138+
validateCallCount := 0
139+
var validationPayload map[string]interface{}
140+
141+
validate := func(payload map[string]interface{}, userContext supertokens.UserContext) claims.ClaimValidationResult {
142+
validateCallCount += 1
143+
144+
validationPayload = payload
145+
146+
return claims.ClaimValidationResult{
147+
IsValid: true,
148+
}
149+
}
150+
151+
_, validators := StubClaimWithRefetch(validate)
152+
err = sessionContainer.AssertClaims([]claims.SessionClaimValidator{
153+
validators.Stub(),
154+
})
155+
assert.NoError(t, err)
156+
assert.Equal(t, 1, validateCallCount)
157+
assert.Equal(t, "world", validationPayload["hello"])
158+
assert.NotNil(t, validationPayload, "st-stub")
159+
assert.Equal(t, "stub", validationPayload["st-stub"].(map[string]interface{})["v"])
160+
161+
// Check if claim was updated in jwt
162+
sessInfo, err := GetSessionInformation(sessionContainer.GetHandle())
163+
assert.NoError(t, err)
164+
jwtPayloadStr := sessInfo.AccessTokenPayload["jwt"].(string)
165+
jwtPayload := jwt.MapClaims{}
166+
167+
_, _, err = (&jwt.Parser{}).ParseUnverified(jwtPayloadStr, jwtPayload)
168+
assert.NoError(t, err)
169+
170+
assert.Equal(t, "stub", jwtPayload["st-stub"].(map[string]interface{})["v"])
171+
assert.Equal(t, "rope", jwtPayload["sub"])
172+
}

recipe/session/claimsutils_test.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,32 @@ func StubClaim(validate func(payload map[string]interface{}, userContext superto
5050
}
5151
}
5252

53+
func StubClaimWithRefetch(validate func(payload map[string]interface{}, userContext supertokens.UserContext) claims.ClaimValidationResult) (*claims.TypeSessionClaim, StubValidator) {
54+
claim, validators := claims.PrimitiveClaim(
55+
"st-stub",
56+
func(userId string, userContext supertokens.UserContext) (interface{}, error) {
57+
return "stub", nil
58+
},
59+
nil,
60+
)
61+
62+
return claim, StubValidator{
63+
PrimitiveClaimValidators: validators,
64+
Stub: func() claims.SessionClaimValidator {
65+
return claims.SessionClaimValidator{
66+
ID: claim.Key,
67+
Claim: claim,
68+
Validate: func(payload map[string]interface{}, userContext supertokens.UserContext) claims.ClaimValidationResult {
69+
return validate(payload, userContext)
70+
},
71+
ShouldRefetch: func(payload map[string]interface{}, userContext supertokens.UserContext) bool {
72+
return true
73+
},
74+
}
75+
},
76+
}
77+
}
78+
5379
type StubValidator struct {
5480
claims.PrimitiveClaimValidators
5581
Stub func() claims.SessionClaimValidator

0 commit comments

Comments
 (0)