Skip to content

Commit 60b659a

Browse files
committed
feat: introduce v2 refresh token algorithm
1 parent 9a8d0df commit 60b659a

File tree

12 files changed

+593
-82
lines changed

12 files changed

+593
-82
lines changed

internal/api/token_test.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -435,9 +435,11 @@ func (ts *TokenTestSuite) TestRefreshTokenReuseRevocation() {
435435

436436
// ensure that the 4 refresh tokens are setup correctly
437437
for i, refreshToken := range refreshTokens {
438-
_, token, _, err := models.FindUserWithRefreshToken(ts.API.db, refreshToken, false)
438+
_, anyToken, _, err := models.FindUserWithRefreshToken(ts.API.db, ts.Config.Security.DBEncryption, refreshToken, false)
439439
require.NoError(ts.T(), err)
440440

441+
token := anyToken.(*models.RefreshToken)
442+
441443
if i == len(refreshTokens)-1 {
442444
require.False(ts.T(), token.Revoked)
443445
} else {
@@ -470,9 +472,10 @@ func (ts *TokenTestSuite) TestRefreshTokenReuseRevocation() {
470472

471473
// ensure that the refresh tokens are marked as revoked in the database
472474
for _, refreshToken := range refreshTokens {
473-
_, token, _, err := models.FindUserWithRefreshToken(ts.API.db, refreshToken, false)
475+
_, anyToken, _, err := models.FindUserWithRefreshToken(ts.API.db, ts.Config.Security.DBEncryption, refreshToken, false)
474476
require.NoError(ts.T(), err)
475477

478+
token := anyToken.(*models.RefreshToken)
476479
require.True(ts.T(), token.Revoked)
477480
}
478481

internal/conf/configuration.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -723,8 +723,10 @@ func (c *DatabaseEncryptionConfiguration) Validate() error {
723723

724724
type SecurityConfiguration struct {
725725
Captcha CaptchaConfiguration `json:"captcha"`
726+
RefreshTokenAlgorithmVersion int `json:"refresh_token_algorithm_version" split_words:"true"`
726727
RefreshTokenRotationEnabled bool `json:"refresh_token_rotation_enabled" split_words:"true" default:"true"`
727728
RefreshTokenReuseInterval int `json:"refresh_token_reuse_interval" split_words:"true"`
729+
RefreshTokenAllowReuse bool `json:"refresh_token_allow_reuse" split_words:"true"`
728730
UpdatePasswordRequireReauthentication bool `json:"update_password_require_reauthentication" split_words:"true"`
729731
ManualLinkingEnabled bool `json:"manual_linking_enabled" split_words:"true" default:"false"`
730732

internal/crypto/crypto.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ func GenerateOtp(digits int) string {
2929

3030
return otp
3131
}
32+
3233
func GenerateTokenHash(emailOrPhone, otp string) string {
3334
return fmt.Sprintf("%x", sha256.Sum224([]byte(emailOrPhone+otp)))
3435
}

internal/crypto/refresh_tokens.go

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
package crypto
2+
3+
import (
4+
"crypto/hmac"
5+
"crypto/rand"
6+
"crypto/sha256"
7+
"crypto/subtle"
8+
"encoding/base64"
9+
"encoding/binary"
10+
"errors"
11+
"math"
12+
13+
"github.com/gofrs/uuid"
14+
)
15+
16+
func GenerateRefreshTokenHmacKey() []byte {
17+
key := make([]byte, 32)
18+
must(rand.Read(key))
19+
20+
return key
21+
}
22+
23+
const refreshTokenChecksumLength = 4
24+
const refreshTokenSignatureLength = 16
25+
const minRefreshTokenLength = 1 + 16 + 1 + refreshTokenSignatureLength + refreshTokenChecksumLength
26+
const maxRefreshTokenLength = minRefreshTokenLength + 8
27+
28+
// RefreshToken is an object that encodes a cryptographically authenticated
29+
// (signed) message containing a version, session ID and monotonically
30+
// increasing non-negative counter.
31+
//
32+
// The signature is a truncated (first 128 bits) of HMAC-SHA-256, which saves
33+
// on encoded length without sacrificing security. The checksum of 4 bytes at
34+
// the end is to lessen the load on the server with invalid strings (those that
35+
// are not likely to be a proper refresh token).
36+
type RefreshToken struct {
37+
Raw []byte
38+
39+
Version byte
40+
SessionID uuid.UUID
41+
Counter int64
42+
Signature []byte
43+
}
44+
45+
func (RefreshToken) TableName() string {
46+
panic("crypto.RefreshToken is not meant to be saved in the database")
47+
}
48+
49+
func (r *RefreshToken) CheckSignature(hmacSha256Key []byte) bool {
50+
bytes := r.Raw[:len(r.Raw)-refreshTokenSignatureLength-refreshTokenChecksumLength]
51+
52+
h := hmac.New(sha256.New, hmacSha256Key)
53+
h.Write(bytes)
54+
signature := h.Sum(nil)[:refreshTokenSignatureLength]
55+
56+
return hmac.Equal(signature, r.Signature)
57+
}
58+
59+
func (r *RefreshToken) Encode(hmacSha256Key []byte) string {
60+
result := make([]byte, 0, maxRefreshTokenLength)
61+
62+
result = append(result, 0)
63+
result = append(result, r.SessionID.Bytes()...)
64+
result = binary.AppendUvarint(result, uint64(r.Counter))
65+
66+
// Note on truncating the HMAC-SHA-256 output:
67+
// This does not impact security as the brute-force space is 2^128 and
68+
// the collision space is 2^64, both unattainable in practice.
69+
70+
h := hmac.New(sha256.New, hmacSha256Key)
71+
h.Write(result)
72+
signature := h.Sum(nil)[:refreshTokenSignatureLength]
73+
74+
result = append(result, signature...)
75+
76+
checksum := sha256.Sum256(result)
77+
result = append(result, checksum[:refreshTokenChecksumLength]...)
78+
79+
r.Version = 0
80+
r.Raw = result
81+
r.Signature = signature
82+
83+
return base64.RawURLEncoding.EncodeToString(result)
84+
}
85+
86+
var (
87+
ErrRefreshTokenLength = errors.New("crypto: refresh token length is not valid")
88+
ErrRefreshTokenUnknownVersion = errors.New("crypto: refresh token version is not 0")
89+
ErrRefreshTokenChecksumInvalid = errors.New("crypto: refresh token checksum is not valid")
90+
ErrRefreshTokenCounterInvalid = errors.New("crypto: refresh token's counter is not valid")
91+
)
92+
93+
func ParseRefreshToken(token string) (*RefreshToken, error) {
94+
bytes, err := base64.RawURLEncoding.DecodeString(token)
95+
if err != nil {
96+
return nil, err
97+
}
98+
99+
if len(bytes) < minRefreshTokenLength {
100+
return nil, ErrRefreshTokenLength
101+
}
102+
103+
if bytes[0] != 0 {
104+
return nil, ErrRefreshTokenUnknownVersion
105+
}
106+
107+
parseFrom := bytes[1 : len(bytes)-refreshTokenChecksumLength]
108+
109+
checksum256 := sha256.Sum256(bytes[:len(bytes)-refreshTokenChecksumLength])
110+
if subtle.ConstantTimeCompare(checksum256[:refreshTokenChecksumLength], bytes[len(bytes)-refreshTokenChecksumLength:]) != 1 {
111+
return nil, ErrRefreshTokenChecksumInvalid
112+
}
113+
114+
sessionID, err := uuid.FromBytes(parseFrom[0:16])
115+
if err != nil {
116+
return nil, err
117+
}
118+
119+
parseFrom = parseFrom[16:]
120+
121+
counter, counterBytes := binary.Uvarint(parseFrom)
122+
if counterBytes <= 0 {
123+
return nil, ErrRefreshTokenCounterInvalid
124+
}
125+
126+
if counter > math.MaxInt64 {
127+
return nil, ErrRefreshTokenCounterInvalid
128+
}
129+
130+
parseFrom = parseFrom[counterBytes:]
131+
132+
if len(parseFrom) != 16 {
133+
return nil, ErrRefreshTokenLength
134+
}
135+
136+
signature := parseFrom
137+
138+
return &RefreshToken{
139+
Raw: bytes,
140+
141+
Version: 0,
142+
SessionID: sessionID,
143+
Counter: int64(counter),
144+
Signature: signature,
145+
}, nil
146+
}
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
package crypto
2+
3+
import (
4+
"crypto/sha256"
5+
"encoding/base64"
6+
"fmt"
7+
"strings"
8+
"testing"
9+
10+
"github.com/gofrs/uuid"
11+
"github.com/stretchr/testify/require"
12+
)
13+
14+
func TestRefreshTokenParse(t *testing.T) {
15+
negativeExamples := []struct {
16+
value []byte
17+
error error
18+
}{
19+
{
20+
value: make([]byte, minRefreshTokenLength-1),
21+
error: ErrRefreshTokenLength,
22+
},
23+
{
24+
value: make([]byte, minRefreshTokenLength),
25+
error: ErrRefreshTokenChecksumInvalid,
26+
},
27+
{
28+
value: func() []byte {
29+
b := make([]byte, minRefreshTokenLength)
30+
b[0] = 1
31+
return b
32+
}(),
33+
error: ErrRefreshTokenUnknownVersion,
34+
},
35+
{
36+
value: func() []byte {
37+
b := make([]byte, minRefreshTokenLength)
38+
for i := 1 + 16; i < len(b); i += 1 {
39+
b[i] = 0xFF
40+
}
41+
42+
checksum := sha256.Sum256(b[:len(b)-refreshTokenChecksumLength])
43+
copy(b[len(b)-refreshTokenChecksumLength:], checksum[:refreshTokenChecksumLength])
44+
return b
45+
}(),
46+
error: ErrRefreshTokenCounterInvalid,
47+
},
48+
{
49+
value: func() []byte {
50+
b := make([]byte, minRefreshTokenLength)
51+
b[1+16] = 0xFF
52+
b[1+16+1] = 0
53+
54+
checksum := sha256.Sum256(b[:len(b)-refreshTokenChecksumLength])
55+
copy(b[len(b)-refreshTokenChecksumLength:], checksum[:refreshTokenChecksumLength])
56+
return b
57+
}(),
58+
error: ErrRefreshTokenLength,
59+
},
60+
}
61+
62+
for i, example := range negativeExamples {
63+
t.Run(fmt.Sprintf("negative example %d", i), func(t *testing.T) {
64+
rt, err := ParseRefreshToken(base64.RawURLEncoding.EncodeToString(example.value))
65+
require.Nil(t, rt)
66+
require.Error(t, err)
67+
require.Equal(t, err, example.error)
68+
})
69+
}
70+
71+
rt, err := ParseRefreshToken(strings.Repeat("!", (4*minRefreshTokenLength)/3))
72+
require.Nil(t, rt)
73+
require.Error(t, err)
74+
75+
original := &RefreshToken{
76+
SessionID: uuid.Must(uuid.NewV4()),
77+
Counter: 9223372036854775807,
78+
}
79+
80+
parsed, err := ParseRefreshToken(original.Encode(make([]byte, 32)))
81+
require.Nil(t, err)
82+
require.Equal(t, original.SessionID.String(), parsed.SessionID.String())
83+
require.Equal(t, original.Counter, parsed.Counter)
84+
require.Equal(t, original.Raw, parsed.Raw)
85+
require.Equal(t, original.Signature, parsed.Signature)
86+
}

internal/models/refresh_token.go

Lines changed: 46 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package models
22

33
import (
44
"database/sql"
5+
"encoding/base64"
56
"net/http"
67
"time"
78

@@ -118,6 +119,50 @@ func FindTokenBySessionID(tx *storage.Connection, sessionId *uuid.UUID) (*Refres
118119
return refreshToken, nil
119120
}
120121

122+
func (s *Session) ApplyGrantParams(params *GrantParams) {
123+
s.FactorID = params.FactorID
124+
125+
if params.SessionNotAfter != nil {
126+
s.NotAfter = params.SessionNotAfter
127+
}
128+
129+
if params.UserAgent != "" {
130+
s.UserAgent = &params.UserAgent
131+
}
132+
133+
if params.IP != "" {
134+
s.IP = &params.IP
135+
}
136+
137+
if params.SessionTag != nil && *params.SessionTag != "" {
138+
s.Tag = params.SessionTag
139+
}
140+
141+
if params.OAuthClientID != nil && *params.OAuthClientID != uuid.Nil {
142+
s.OAuthClientID = params.OAuthClientID
143+
}
144+
}
145+
146+
func (s *Session) SetupRefreshTokenData(dbEncryption conf.DatabaseEncryptionConfiguration) error {
147+
hmacKey := base64.RawURLEncoding.EncodeToString(crypto.GenerateRefreshTokenHmacKey())
148+
149+
if dbEncryption.Encrypt {
150+
es, err := crypto.NewEncryptedString(s.ID.String(), []byte(hmacKey), dbEncryption.EncryptionKeyID, dbEncryption.EncryptionKey)
151+
if err != nil {
152+
return err
153+
}
154+
155+
hmacKey = es.String()
156+
}
157+
158+
counter := int64(0)
159+
160+
s.RefreshTokenHmacKey = &hmacKey
161+
s.RefreshTokenCounter = &counter
162+
163+
return nil
164+
}
165+
121166
func createRefreshToken(tx *storage.Connection, user *User, oldToken *RefreshToken, params *GrantParams) (*RefreshToken, error) {
122167
token := &RefreshToken{
123168
UserID: user.ID,
@@ -135,25 +180,7 @@ func createRefreshToken(tx *storage.Connection, user *User, oldToken *RefreshTok
135180
return nil, errors.Wrap(err, "error instantiating new session object")
136181
}
137182

138-
if params.SessionNotAfter != nil {
139-
session.NotAfter = params.SessionNotAfter
140-
}
141-
142-
if params.UserAgent != "" {
143-
session.UserAgent = &params.UserAgent
144-
}
145-
146-
if params.IP != "" {
147-
session.IP = &params.IP
148-
}
149-
150-
if params.SessionTag != nil && *params.SessionTag != "" {
151-
session.Tag = params.SessionTag
152-
}
153-
154-
if params.OAuthClientID != nil && *params.OAuthClientID != uuid.Nil {
155-
session.OAuthClientID = params.OAuthClientID
156-
}
183+
session.ApplyGrantParams(params)
157184

158185
if err := tx.Create(session); err != nil {
159186
return nil, errors.Wrap(err, "error creating new session")

internal/models/refresh_token_test.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,11 @@ func (ts *RefreshTokenTestSuite) TestGrantRefreshTokenSwap() {
5454
s, err := GrantRefreshTokenSwap(ts.config.AuditLog, &http.Request{}, ts.db, u, r)
5555
require.NoError(ts.T(), err)
5656

57-
_, nr, _, err := FindUserWithRefreshToken(ts.db, r.Token, false)
57+
_, anyNR, _, err := FindUserWithRefreshToken(ts.db, ts.config.Security.DBEncryption, r.Token, false)
5858
require.NoError(ts.T(), err)
5959

60+
nr := anyNR.(*RefreshToken)
61+
6062
require.Equal(ts.T(), r.ID, nr.ID)
6163
require.True(ts.T(), nr.Revoked, "expected old token to be revoked")
6264

@@ -69,9 +71,11 @@ func (ts *RefreshTokenTestSuite) TestLogout() {
6971
r, err := GrantAuthenticatedUser(ts.db, u, GrantParams{})
7072
require.NoError(ts.T(), err)
7173

74+
var anyR any
75+
7276
require.NoError(ts.T(), Logout(ts.db, u.ID))
73-
u, r, _, err = FindUserWithRefreshToken(ts.db, r.Token, false)
74-
require.Errorf(ts.T(), err, "expected error when there are no refresh tokens to authenticate. user: %v token: %v", u, r)
77+
u, anyR, _, err = FindUserWithRefreshToken(ts.db, ts.config.Security.DBEncryption, r.Token, false)
78+
require.Errorf(ts.T(), err, "expected error when there are no refresh tokens to authenticate. user: %v token: %v", u, anyR)
7579

7680
require.True(ts.T(), IsNotFoundError(err), "expected NotFoundError")
7781
}

0 commit comments

Comments
 (0)