Skip to content

Commit 24c21a1

Browse files
committed
add tests for invalid tokens
1 parent f9c322c commit 24c21a1

File tree

2 files changed

+82
-1
lines changed

2 files changed

+82
-1
lines changed

internal/tokens/service.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ func (s *Service) RefreshTokenGrant(ctx context.Context, db *storage.Connection,
355355

356356
if counterDifference < 0 {
357357
// refresh token was not issued by this server
358-
apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Invalid Refresh Token: Not Issued By This Server").WithInternalMessage("Refresh token for session %s has a counter that's ahead %d of the database state", session.ID.String(), counterDifference)
358+
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Invalid Refresh Token: Not Issued By This Server").WithInternalMessage("Refresh token for session %s has a counter that's ahead %d of the database state", session.ID.String(), -counterDifference)
359359
} else if counterDifference == 0 || config.Security.RefreshTokenAllowReuse {
360360
// normal refresh token use
361361
counter := *session.RefreshTokenCounter + 1

internal/tokens/service_test.go

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -684,3 +684,84 @@ func (ts *RefreshTokenV2Suite) TestDBEncryption() {
684684
require.True(ts.T(), strings.Contains(encryptedStrings[0], "\"key_id\":\"B\""))
685685
require.True(ts.T(), strings.Contains(encryptedStrings[1], "\"key_id\":\"A\""))
686686
}
687+
688+
func (ts *RefreshTokenV2Suite) TestInvalidRefreshTokens() {
689+
config := ts.config()
690+
require.Equal(ts.T(), 2, config.Security.RefreshTokenAlgorithmVersion)
691+
692+
config.Security.RefreshTokenRotationEnabled = false
693+
config.Security.RefreshTokenReuseInterval = 1
694+
config.Security.RefreshTokenAllowReuse = false
695+
696+
clock := time.Now()
697+
698+
srv := NewService(config, &panicHookManager{})
699+
srv.SetTimeFunc(func() time.Time {
700+
return clock
701+
})
702+
703+
req, err := http.NewRequest("POST", "https://example.com/", nil)
704+
require.NoError(ts.T(), err)
705+
706+
req = req.WithContext(context.Background())
707+
responseHeaders := make(http.Header)
708+
709+
at, err := srv.IssueRefreshToken(
710+
req,
711+
responseHeaders,
712+
ts.Conn,
713+
ts.User,
714+
models.PasswordGrant,
715+
models.GrantParams{},
716+
)
717+
require.NoError(ts.T(), err)
718+
require.NotNil(ts.T(), at)
719+
720+
prt, err := crypto.ParseRefreshToken(at.RefreshToken)
721+
require.NoError(ts.T(), err)
722+
723+
session, err := models.FindSessionByID(ts.Conn, prt.SessionID, false)
724+
require.NoError(ts.T(), err)
725+
726+
key, _, err := session.GetRefreshTokenHmacKey(config.Security.DBEncryption)
727+
require.NoError(ts.T(), err)
728+
729+
// tamper with counter
730+
prt.Counter += 1
731+
tamperedRefreshToken := prt.Encode(key)
732+
733+
responseHeaders = make(http.Header)
734+
nrt, err := srv.RefreshTokenGrant(context.Background(), ts.Conn, req, responseHeaders, RefreshTokenGrantParams{
735+
RefreshToken: tamperedRefreshToken,
736+
})
737+
require.Error(ts.T(), err)
738+
require.Nil(ts.T(), nrt)
739+
740+
require.Equal(ts.T(), prt.SessionID.String(), responseHeaders.Get("sb-auth-session-id"))
741+
742+
// tamper with signature
743+
prt.Counter = 0
744+
tamperedRefreshToken = prt.Encode(make([]byte, 32))
745+
746+
responseHeaders = make(http.Header)
747+
nrt, err = srv.RefreshTokenGrant(context.Background(), ts.Conn, req, responseHeaders, RefreshTokenGrantParams{
748+
RefreshToken: tamperedRefreshToken,
749+
})
750+
require.Error(ts.T(), err)
751+
require.Nil(ts.T(), nrt)
752+
753+
require.Equal(ts.T(), "", responseHeaders.Get("sb-auth-session-id"))
754+
755+
// remove the session
756+
err = models.LogoutSession(ts.Conn, prt.SessionID)
757+
require.NoError(ts.T(), err)
758+
759+
responseHeaders = make(http.Header)
760+
nrt, err = srv.RefreshTokenGrant(context.Background(), ts.Conn, req, responseHeaders, RefreshTokenGrantParams{
761+
RefreshToken: at.RefreshToken,
762+
})
763+
require.Error(ts.T(), err)
764+
require.Nil(ts.T(), nrt)
765+
766+
require.Equal(ts.T(), "", responseHeaders.Get("sb-auth-session-id"))
767+
}

0 commit comments

Comments
 (0)