Skip to content

Commit 27cf88a

Browse files
committed
fix: oninvalidclaim handler
1 parent 13f304b commit 27cf88a

File tree

4 files changed

+51
-2
lines changed

4 files changed

+51
-2
lines changed

recipe/session/claims/claims.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,6 @@ type ClaimValidationResult struct {
5050
}
5151

5252
type ClaimValidationError struct {
53-
ID string
54-
Reason interface{} // This can be nil, add checks when used
53+
ID string `json:"id"`
54+
Reason interface{} `json:"reason"` // This can be nil, add checks when used
5555
}

recipe/session/sessmodels/models.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ type TypeInput struct {
9191
CookieSecure *bool
9292
CookieSameSite *string
9393
SessionExpiredStatusCode *int
94+
InvalidClaimStatusCode *int
9495
CookieDomain *string
9596
AntiCsrf *string
9697
Override *OverrideStruct
@@ -113,6 +114,7 @@ type OverrideStruct struct {
113114
type ErrorHandlers struct {
114115
OnUnauthorised func(message string, req *http.Request, res http.ResponseWriter) error
115116
OnTokenTheftDetected func(sessionHandle string, userID string, req *http.Request, res http.ResponseWriter) error
117+
OnInvalidClaim func(validationErrors []claims.ClaimValidationError, req *http.Request, res http.ResponseWriter) error
116118
}
117119

118120
type TypeNormalisedInput struct {
@@ -121,6 +123,7 @@ type TypeNormalisedInput struct {
121123
CookieSameSite string
122124
CookieSecure bool
123125
SessionExpiredStatusCode int
126+
InvalidClaimStatusCode int
124127
AntiCsrf string
125128
Override OverrideStruct
126129
ErrorHandlers NormalisedErrorHandlers
@@ -154,6 +157,7 @@ type NormalisedErrorHandlers struct {
154157
OnUnauthorised func(message string, req *http.Request, res http.ResponseWriter) error
155158
OnTryRefreshToken func(message string, req *http.Request, res http.ResponseWriter) error
156159
OnTokenTheftDetected func(sessionHandle string, userID string, req *http.Request, res http.ResponseWriter) error
160+
OnInvalidClaim func(validationErrors []claims.ClaimValidationError, req *http.Request, res http.ResponseWriter) error
157161
}
158162

159163
type SessionContainer struct {

recipe/session/utils.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,11 @@ func validateAndNormaliseUserInput(appInfo supertokens.NormalisedAppinfo, config
8888
sessionExpiredStatusCode = *config.SessionExpiredStatusCode
8989
}
9090

91+
invalidClaimStatusCode := 403
92+
if config != nil && config.InvalidClaimStatusCode != nil {
93+
invalidClaimStatusCode = *config.InvalidClaimStatusCode
94+
}
95+
9196
if config != nil && config.AntiCsrf != nil {
9297
if *config.AntiCsrf != antiCSRF_NONE && *config.AntiCsrf != antiCSRF_VIA_CUSTOM_HEADER && *config.AntiCsrf != antiCSRF_VIA_TOKEN {
9398
return sessmodels.TypeNormalisedInput{}, errors.New("antiCsrf config must be one of 'NONE' or 'VIA_CUSTOM_HEADER' or 'VIA_TOKEN'")
@@ -127,6 +132,13 @@ func validateAndNormaliseUserInput(appInfo supertokens.NormalisedAppinfo, config
127132
}
128133
return sendUnauthorisedResponse(*recipeInstance, message, req, res)
129134
},
135+
OnInvalidClaim: func(validationErrors []claims.ClaimValidationError, req *http.Request, res http.ResponseWriter) error {
136+
recipeInstance, err := getRecipeInstanceOrThrowError()
137+
if err != nil {
138+
return err
139+
}
140+
return sendInvalidClaimResponse(*recipeInstance, validationErrors, req, res)
141+
},
130142
}
131143

132144
if config != nil && config.ErrorHandlers != nil {
@@ -136,6 +148,9 @@ func validateAndNormaliseUserInput(appInfo supertokens.NormalisedAppinfo, config
136148
if config.ErrorHandlers.OnUnauthorised != nil {
137149
errorHandlers.OnUnauthorised = config.ErrorHandlers.OnUnauthorised
138150
}
151+
if config.ErrorHandlers.OnInvalidClaim != nil {
152+
errorHandlers.OnInvalidClaim = config.ErrorHandlers.OnInvalidClaim
153+
}
139154
}
140155

141156
IsAnIPAPIDomain, err := supertokens.IsAnIPAddress(topLevelAPIDomain)
@@ -178,6 +193,7 @@ func validateAndNormaliseUserInput(appInfo supertokens.NormalisedAppinfo, config
178193
CookieSameSite: cookieSameSite,
179194
CookieSecure: cookieSecure,
180195
SessionExpiredStatusCode: sessionExpiredStatusCode,
196+
InvalidClaimStatusCode: invalidClaimStatusCode,
181197
AntiCsrf: antiCsrf,
182198
ErrorHandlers: errorHandlers,
183199
Jwt: Jwt,
@@ -297,6 +313,13 @@ func sendUnauthorisedResponse(recipeInstance Recipe, _ string, _ *http.Request,
297313
return supertokens.SendNon200Response(response, "unauthorised", recipeInstance.Config.SessionExpiredStatusCode)
298314
}
299315

316+
func sendInvalidClaimResponse(recipeInstance Recipe, claimValidationErrors []claims.ClaimValidationError, _ *http.Request, response http.ResponseWriter) error {
317+
return supertokens.SendNon200ResponseWithPayload(response, map[string]interface{}{
318+
"message": "invalid claim",
319+
"claimValidationErrors": claimValidationErrors,
320+
}, recipeInstance.Config.InvalidClaimStatusCode)
321+
}
322+
300323
func sendTokenTheftDetectedResponse(recipeInstance Recipe, sessionHandle string, _ string, _ *http.Request, response http.ResponseWriter) error {
301324
_, err := (*recipeInstance.RecipeImpl.RevokeSession)(sessionHandle, &map[string]interface{}{})
302325
if err != nil {

supertokens/utils.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,28 @@ func SendNon200Response(res http.ResponseWriter, message string, statusCode int)
181181
return nil
182182
}
183183

184+
func SendNon200ResponseWithPayload(res http.ResponseWriter, payload map[string]interface{}, statusCode int) error {
185+
dw := MakeDoneWriter(res)
186+
if !dw.IsDone() {
187+
if statusCode < 300 {
188+
return errors.New("calling SendNon200ResponseWithPayload with status code < 300")
189+
}
190+
191+
LogDebugMessage("Sending response to client with status code: " + strconv.Itoa(statusCode))
192+
193+
res.Header().Set("Content-Type", "application/json; charset=utf-8")
194+
res.WriteHeader(statusCode)
195+
196+
bytes, err := json.Marshal(payload)
197+
if err != nil {
198+
return err
199+
} else {
200+
res.Write(bytes)
201+
}
202+
}
203+
return nil
204+
}
205+
184206
func ReadFromRequest(r *http.Request) ([]byte, error) {
185207
f := r.Body
186208
buf, err := ioutil.ReadAll(f)

0 commit comments

Comments
 (0)