Skip to content

Commit 70aa132

Browse files
committed
feat: introduce v2 refresh token algorithm
1 parent aa0ac5b commit 70aa132

File tree

12 files changed

+579
-82
lines changed

12 files changed

+579
-82
lines changed

internal/api/token_test.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -435,9 +435,11 @@ func (ts *TokenTestSuite) TestRefreshTokenReuseRevocation() {
435435

436436
// ensure that the 4 refresh tokens are setup correctly
437437
for i, refreshToken := range refreshTokens {
438-
_, token, _, err := models.FindUserWithRefreshToken(ts.API.db, refreshToken, false)
438+
_, anyToken, _, err := models.FindUserWithRefreshToken(ts.API.db, ts.Config.Security.DBEncryption, refreshToken, false)
439439
require.NoError(ts.T(), err)
440440

441+
token := anyToken.(*models.RefreshToken)
442+
441443
if i == len(refreshTokens)-1 {
442444
require.False(ts.T(), token.Revoked)
443445
} else {
@@ -470,9 +472,10 @@ func (ts *TokenTestSuite) TestRefreshTokenReuseRevocation() {
470472

471473
// ensure that the refresh tokens are marked as revoked in the database
472474
for _, refreshToken := range refreshTokens {
473-
_, token, _, err := models.FindUserWithRefreshToken(ts.API.db, refreshToken, false)
475+
_, anyToken, _, err := models.FindUserWithRefreshToken(ts.API.db, ts.Config.Security.DBEncryption, refreshToken, false)
474476
require.NoError(ts.T(), err)
475477

478+
token := anyToken.(*models.RefreshToken)
476479
require.True(ts.T(), token.Revoked)
477480
}
478481

internal/conf/configuration.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -723,8 +723,10 @@ func (c *DatabaseEncryptionConfiguration) Validate() error {
723723

724724
type SecurityConfiguration struct {
725725
Captcha CaptchaConfiguration `json:"captcha"`
726+
RefreshTokenAlgorithmVersion int `json:"refresh_token_algorithm_version" split_words:"true"`
726727
RefreshTokenRotationEnabled bool `json:"refresh_token_rotation_enabled" split_words:"true" default:"true"`
727728
RefreshTokenReuseInterval int `json:"refresh_token_reuse_interval" split_words:"true"`
729+
RefreshTokenAllowReuse bool `json:"refresh_token_allow_reuse" split_words:"true"`
728730
UpdatePasswordRequireReauthentication bool `json:"update_password_require_reauthentication" split_words:"true"`
729731
ManualLinkingEnabled bool `json:"manual_linking_enabled" split_words:"true" default:"false"`
730732

internal/crypto/crypto.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ func GenerateOtp(digits int) string {
2929

3030
return otp
3131
}
32+
3233
func GenerateTokenHash(emailOrPhone, otp string) string {
3334
return fmt.Sprintf("%x", sha256.Sum224([]byte(emailOrPhone+otp)))
3435
}

internal/crypto/refresh_tokens.go

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
package crypto
2+
3+
import (
4+
"crypto/hmac"
5+
"crypto/rand"
6+
"crypto/sha256"
7+
"crypto/subtle"
8+
"encoding/base64"
9+
"encoding/binary"
10+
"errors"
11+
12+
"github.com/gofrs/uuid"
13+
)
14+
15+
func GenerateRefreshTokenHmacKey() []byte {
16+
key := make([]byte, 32)
17+
must(rand.Read(key))
18+
19+
return key
20+
}
21+
22+
const refreshTokenChecksumLength = 4
23+
const refreshTokenSignatureLength = 16
24+
const minRefreshTokenLength = 1 + 16 + 1 + refreshTokenSignatureLength + refreshTokenChecksumLength
25+
const maxRefreshTokenLength = minRefreshTokenLength + 8
26+
27+
// RefreshToken is an object that encodes a cryptographically authenticated
28+
// (signed) message containing a version, session ID and monotonically
29+
// increasing non-negative counter.
30+
//
31+
// The signature is a truncated (first 128 bits) of HMAC-SHA-256, which saves
32+
// on encoded length without sacrificing security. The checksum of 4 bytes at
33+
// the end is to lessen the load on the server with invalid strings (those that
34+
// are not likely to be a proper refresh token).
35+
type RefreshToken struct {
36+
Raw []byte
37+
38+
Version byte
39+
SessionID uuid.UUID
40+
Counter int64
41+
Signature []byte
42+
}
43+
44+
func (RefreshToken) TableName() string {
45+
panic("crypto.RefreshToken is not meant to be saved in the database")
46+
}
47+
48+
func (r *RefreshToken) CheckSignature(hmacSha256Key []byte) bool {
49+
bytes := r.Raw[:len(r.Raw)-refreshTokenSignatureLength-refreshTokenChecksumLength]
50+
51+
h := hmac.New(sha256.New, hmacSha256Key)
52+
h.Write(bytes)
53+
signature := h.Sum(nil)[:refreshTokenSignatureLength]
54+
55+
return hmac.Equal(signature, r.Signature)
56+
}
57+
58+
func (r *RefreshToken) Encode(hmacSha256Key []byte) string {
59+
result := make([]byte, 0, maxRefreshTokenLength)
60+
61+
result = append(result, 0)
62+
result = append(result, r.SessionID.Bytes()...)
63+
result = binary.AppendUvarint(result, uint64(r.Counter))
64+
65+
// Note on truncating the HMAC-SHA-256 output:
66+
// This does not impact security as the brute-force space is 2^128 and
67+
// the collision space is 2^64, both unattainable in practice.
68+
69+
h := hmac.New(sha256.New, hmacSha256Key)
70+
h.Write(result)
71+
signature := h.Sum(nil)[:refreshTokenSignatureLength]
72+
73+
result = append(result, signature...)
74+
75+
checksum := sha256.Sum256(result)
76+
result = append(result, checksum[:refreshTokenChecksumLength]...)
77+
78+
r.Version = 0
79+
r.Raw = result
80+
r.Signature = signature
81+
82+
return base64.RawURLEncoding.EncodeToString(result)
83+
}
84+
85+
var (
86+
ErrRefreshTokenLength = errors.New("crypto: refresh token length is not valid")
87+
ErrRefreshTokenUnknownVersion = errors.New("crypto: refresh token version is not 0")
88+
ErrRefreshTokenChecksumInvalid = errors.New("crypto: refresh token checksum is not valid")
89+
ErrRefreshTokenCounterInvalid = errors.New("crypto: refresh token's counter is not valid")
90+
)
91+
92+
func ParseRefreshToken(token string) (*RefreshToken, error) {
93+
bytes, err := base64.RawURLEncoding.DecodeString(token)
94+
if err != nil {
95+
return nil, err
96+
}
97+
98+
if len(bytes) < minRefreshTokenLength {
99+
return nil, ErrRefreshTokenLength
100+
}
101+
102+
if bytes[0] != 0 {
103+
return nil, ErrRefreshTokenUnknownVersion
104+
}
105+
106+
parseFrom := bytes[1 : len(bytes)-refreshTokenChecksumLength]
107+
108+
checksum256 := sha256.Sum256(bytes[:len(bytes)-refreshTokenChecksumLength])
109+
if subtle.ConstantTimeCompare(checksum256[:refreshTokenChecksumLength], bytes[len(bytes)-refreshTokenChecksumLength:len(bytes)]) != 1 {
110+
return nil, ErrRefreshTokenChecksumInvalid
111+
}
112+
113+
sessionID, err := uuid.FromBytes(parseFrom[0:16])
114+
if err != nil {
115+
return nil, err
116+
}
117+
118+
parseFrom = parseFrom[16:]
119+
120+
counter, counterBytes := binary.Uvarint(parseFrom)
121+
if counterBytes <= 0 {
122+
return nil, ErrRefreshTokenCounterInvalid
123+
}
124+
125+
parseFrom = parseFrom[counterBytes:]
126+
127+
if len(parseFrom) != 16 {
128+
return nil, ErrRefreshTokenLength
129+
}
130+
131+
signature := parseFrom
132+
133+
return &RefreshToken{
134+
Raw: bytes,
135+
136+
Version: 0,
137+
SessionID: sessionID,
138+
Counter: int64(counter),
139+
Signature: signature,
140+
}, nil
141+
}
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
package crypto
2+
3+
import (
4+
"crypto/sha256"
5+
"encoding/base64"
6+
"fmt"
7+
"strings"
8+
"testing"
9+
10+
"github.com/gofrs/uuid"
11+
"github.com/stretchr/testify/require"
12+
)
13+
14+
func TestRefreshTokenParse(t *testing.T) {
15+
negativeExamples := []struct {
16+
value []byte
17+
error error
18+
}{
19+
{
20+
value: make([]byte, minRefreshTokenLength-1),
21+
error: ErrRefreshTokenLength,
22+
},
23+
{
24+
value: make([]byte, minRefreshTokenLength),
25+
error: ErrRefreshTokenChecksumInvalid,
26+
},
27+
{
28+
value: func() []byte {
29+
b := make([]byte, minRefreshTokenLength)
30+
b[0] = 1
31+
return b
32+
}(),
33+
error: ErrRefreshTokenUnknownVersion,
34+
},
35+
{
36+
value: func() []byte {
37+
b := make([]byte, minRefreshTokenLength)
38+
for i := 1 + 16; i < len(b); i += 1 {
39+
b[i] = 0xFF
40+
}
41+
42+
checksum := sha256.Sum256(b[:len(b)-refreshTokenChecksumLength])
43+
copy(b[len(b)-refreshTokenChecksumLength:], checksum[:refreshTokenChecksumLength])
44+
return b
45+
}(),
46+
error: ErrRefreshTokenCounterInvalid,
47+
},
48+
{
49+
value: func() []byte {
50+
b := make([]byte, minRefreshTokenLength)
51+
b[1+16] = 0xFF
52+
b[1+16+1] = 0
53+
54+
checksum := sha256.Sum256(b[:len(b)-refreshTokenChecksumLength])
55+
copy(b[len(b)-refreshTokenChecksumLength:], checksum[:refreshTokenChecksumLength])
56+
return b
57+
}(),
58+
error: ErrRefreshTokenLength,
59+
},
60+
}
61+
62+
for i, example := range negativeExamples {
63+
t.Run(fmt.Sprintf("negative example %d", i), func(t *testing.T) {
64+
rt, err := ParseRefreshToken(base64.RawURLEncoding.EncodeToString(example.value))
65+
require.Nil(t, rt)
66+
require.Error(t, err)
67+
require.Equal(t, err, example.error)
68+
})
69+
}
70+
71+
rt, err := ParseRefreshToken(strings.Repeat("!", (4*minRefreshTokenLength)/3))
72+
require.Nil(t, rt)
73+
require.Error(t, err)
74+
75+
original := &RefreshToken{
76+
SessionID: uuid.Must(uuid.NewV4()),
77+
Counter: 9223372036854775807,
78+
}
79+
80+
parsed, err := ParseRefreshToken(original.Encode(make([]byte, 32)))
81+
require.Nil(t, err)
82+
require.Equal(t, original.SessionID.String(), parsed.SessionID.String())
83+
require.Equal(t, original.Counter, parsed.Counter)
84+
require.Equal(t, original.Raw, parsed.Raw)
85+
require.Equal(t, original.Signature, parsed.Signature)
86+
}

internal/models/refresh_token.go

Lines changed: 46 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package models
22

33
import (
44
"database/sql"
5+
"encoding/base64"
56
"net/http"
67
"time"
78

@@ -118,6 +119,50 @@ func FindTokenBySessionID(tx *storage.Connection, sessionId *uuid.UUID) (*Refres
118119
return refreshToken, nil
119120
}
120121

122+
func (s *Session) ApplyGrantParams(params *GrantParams) {
123+
s.FactorID = params.FactorID
124+
125+
if params.SessionNotAfter != nil {
126+
s.NotAfter = params.SessionNotAfter
127+
}
128+
129+
if params.UserAgent != "" {
130+
s.UserAgent = &params.UserAgent
131+
}
132+
133+
if params.IP != "" {
134+
s.IP = &params.IP
135+
}
136+
137+
if params.SessionTag != nil && *params.SessionTag != "" {
138+
s.Tag = params.SessionTag
139+
}
140+
141+
if params.OAuthClientID != nil && *params.OAuthClientID != uuid.Nil {
142+
s.OAuthClientID = params.OAuthClientID
143+
}
144+
}
145+
146+
func (s *Session) SetupRefreshTokenData(dbEncryption conf.DatabaseEncryptionConfiguration) error {
147+
hmacKey := base64.RawURLEncoding.EncodeToString(crypto.GenerateRefreshTokenHmacKey())
148+
149+
if dbEncryption.Encrypt {
150+
es, err := crypto.NewEncryptedString(s.ID.String(), []byte(hmacKey), dbEncryption.EncryptionKeyID, dbEncryption.EncryptionKey)
151+
if err != nil {
152+
return err
153+
}
154+
155+
hmacKey = es.String()
156+
}
157+
158+
counter := int64(0)
159+
160+
s.RefreshTokenHmacKey = &hmacKey
161+
s.RefreshTokenCounter = &counter
162+
163+
return nil
164+
}
165+
121166
func createRefreshToken(tx *storage.Connection, user *User, oldToken *RefreshToken, params *GrantParams) (*RefreshToken, error) {
122167
token := &RefreshToken{
123168
UserID: user.ID,
@@ -135,25 +180,7 @@ func createRefreshToken(tx *storage.Connection, user *User, oldToken *RefreshTok
135180
return nil, errors.Wrap(err, "error instantiating new session object")
136181
}
137182

138-
if params.SessionNotAfter != nil {
139-
session.NotAfter = params.SessionNotAfter
140-
}
141-
142-
if params.UserAgent != "" {
143-
session.UserAgent = &params.UserAgent
144-
}
145-
146-
if params.IP != "" {
147-
session.IP = &params.IP
148-
}
149-
150-
if params.SessionTag != nil && *params.SessionTag != "" {
151-
session.Tag = params.SessionTag
152-
}
153-
154-
if params.OAuthClientID != nil && *params.OAuthClientID != uuid.Nil {
155-
session.OAuthClientID = params.OAuthClientID
156-
}
183+
session.ApplyGrantParams(params)
157184

158185
if err := tx.Create(session); err != nil {
159186
return nil, errors.Wrap(err, "error creating new session")

internal/models/refresh_token_test.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,11 @@ func (ts *RefreshTokenTestSuite) TestGrantRefreshTokenSwap() {
5454
s, err := GrantRefreshTokenSwap(ts.config.AuditLog, &http.Request{}, ts.db, u, r)
5555
require.NoError(ts.T(), err)
5656

57-
_, nr, _, err := FindUserWithRefreshToken(ts.db, r.Token, false)
57+
_, anyNR, _, err := FindUserWithRefreshToken(ts.db, ts.config.Security.DBEncryption, r.Token, false)
5858
require.NoError(ts.T(), err)
5959

60+
nr := anyNR.(*RefreshToken)
61+
6062
require.Equal(ts.T(), r.ID, nr.ID)
6163
require.True(ts.T(), nr.Revoked, "expected old token to be revoked")
6264

@@ -69,9 +71,11 @@ func (ts *RefreshTokenTestSuite) TestLogout() {
6971
r, err := GrantAuthenticatedUser(ts.db, u, GrantParams{})
7072
require.NoError(ts.T(), err)
7173

74+
var anyR any
75+
7276
require.NoError(ts.T(), Logout(ts.db, u.ID))
73-
u, r, _, err = FindUserWithRefreshToken(ts.db, r.Token, false)
74-
require.Errorf(ts.T(), err, "expected error when there are no refresh tokens to authenticate. user: %v token: %v", u, r)
77+
u, anyR, _, err = FindUserWithRefreshToken(ts.db, ts.config.Security.DBEncryption, r.Token, false)
78+
require.Errorf(ts.T(), err, "expected error when there are no refresh tokens to authenticate. user: %v token: %v", u, anyR)
7579

7680
require.True(ts.T(), IsNotFoundError(err), "expected NotFoundError")
7781
}

0 commit comments

Comments
 (0)