Skip to content

Commit 98a1ebf

Browse files
committed
feat: introduce v2 refresh token algorithm
1 parent 0184ec2 commit 98a1ebf

13 files changed

+534
-61
lines changed

internal/api/token_test.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -435,9 +435,11 @@ func (ts *TokenTestSuite) TestRefreshTokenReuseRevocation() {
435435

436436
// ensure that the 4 refresh tokens are setup correctly
437437
for i, refreshToken := range refreshTokens {
438-
_, token, _, err := models.FindUserWithRefreshToken(ts.API.db, refreshToken, false)
438+
_, anyToken, _, err := models.FindUserWithRefreshToken(ts.API.db, ts.Config.Security.DBEncryption, refreshToken, false)
439439
require.NoError(ts.T(), err)
440440

441+
token := anyToken.(*models.RefreshToken)
442+
441443
if i == len(refreshTokens)-1 {
442444
require.False(ts.T(), token.Revoked)
443445
} else {
@@ -470,9 +472,10 @@ func (ts *TokenTestSuite) TestRefreshTokenReuseRevocation() {
470472

471473
// ensure that the refresh tokens are marked as revoked in the database
472474
for _, refreshToken := range refreshTokens {
473-
_, token, _, err := models.FindUserWithRefreshToken(ts.API.db, refreshToken, false)
475+
_, anyToken, _, err := models.FindUserWithRefreshToken(ts.API.db, ts.Config.Security.DBEncryption, refreshToken, false)
474476
require.NoError(ts.T(), err)
475477

478+
token := anyToken.(*models.RefreshToken)
476479
require.True(ts.T(), token.Revoked)
477480
}
478481

internal/conf/configuration.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -723,8 +723,10 @@ func (c *DatabaseEncryptionConfiguration) Validate() error {
723723

724724
type SecurityConfiguration struct {
725725
Captcha CaptchaConfiguration `json:"captcha"`
726+
RefreshTokenAlgorithmVersion int `json:"refresh_token_algorithm_version" split_words:"true"`
726727
RefreshTokenRotationEnabled bool `json:"refresh_token_rotation_enabled" split_words:"true" default:"true"`
727728
RefreshTokenReuseInterval int `json:"refresh_token_reuse_interval" split_words:"true"`
729+
RefreshTokenAllowReuse bool `json:"refresh_token_allow_reuse" split_words:"true"`
728730
UpdatePasswordRequireReauthentication bool `json:"update_password_require_reauthentication" split_words:"true"`
729731
ManualLinkingEnabled bool `json:"manual_linking_enabled" split_words:"true" default:"false"`
730732

internal/crypto/crypto.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ func GenerateOtp(digits int) string {
2929

3030
return otp
3131
}
32+
3233
func GenerateTokenHash(emailOrPhone, otp string) string {
3334
return fmt.Sprintf("%x", sha256.Sum224([]byte(emailOrPhone+otp)))
3435
}

internal/crypto/refresh_tokens.go

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
package crypto
2+
3+
import (
4+
"crypto/hmac"
5+
"crypto/rand"
6+
"crypto/sha256"
7+
"crypto/subtle"
8+
"encoding/base64"
9+
"encoding/binary"
10+
"errors"
11+
12+
"github.com/gofrs/uuid"
13+
)
14+
15+
func GenerateRefreshTokenHmacKey() []byte {
16+
key := make([]byte, 32)
17+
must(rand.Read(key))
18+
19+
return key
20+
}
21+
22+
const refreshTokenChecksumLength = 4
23+
const refreshTokenSignatureLength = 16
24+
const minRefreshTokenLength = 1 + 16 + 1 + refreshTokenSignatureLength + refreshTokenChecksumLength
25+
const maxRefreshTokenLength = minRefreshTokenLength + 8
26+
27+
// RefreshToken is an object that encodes a cryptographically authenticated
28+
// (signed) message containing a version, session ID and monotonically
29+
// increasing non-negative counter.
30+
//
31+
// The signature is a truncated (first 128 bits) of HMAC-SHA-256, which saves
32+
// on encoded length without sacrificing security. The checksum of 4 bytes at
33+
// the end is to lessen the load on the server with invalid strings (those that
34+
// are not likely to be a proper refresh token).
35+
type RefreshToken struct {
36+
Raw []byte
37+
38+
Version byte
39+
SessionID uuid.UUID
40+
Counter int64
41+
Signature []byte
42+
}
43+
44+
func (RefreshToken) TableName() string {
45+
panic("crypto.RefreshToken is not meant to be saved in the database")
46+
}
47+
48+
func (r *RefreshToken) CheckSignature(hmacSha256Key []byte) bool {
49+
bytes := r.Raw[:len(r.Raw)-refreshTokenSignatureLength-refreshTokenChecksumLength]
50+
51+
h := hmac.New(sha256.New, hmacSha256Key)
52+
h.Write(bytes)
53+
signature := h.Sum(nil)[:refreshTokenSignatureLength]
54+
55+
return hmac.Equal(signature, r.Signature)
56+
}
57+
58+
func (r *RefreshToken) Encode(hmacSha256Key []byte) string {
59+
result := make([]byte, 0, maxRefreshTokenLength)
60+
61+
result = append(result, 0)
62+
result = append(result, r.SessionID.Bytes()...)
63+
result = binary.AppendUvarint(result, uint64(r.Counter))
64+
65+
// Note on truncating the HMAC-SHA-256 output:
66+
// This does not impact security as the brute-force space is 2^128 and
67+
// the collision space is 2^64, both unattainable in practice.
68+
69+
h := hmac.New(sha256.New, hmacSha256Key)
70+
h.Write(result)
71+
signature := h.Sum(nil)[:refreshTokenSignatureLength]
72+
73+
result = append(result, signature...)
74+
75+
checksum := sha256.Sum256(result)
76+
result = append(result, checksum[:refreshTokenChecksumLength]...)
77+
78+
r.Version = 0
79+
r.Raw = result
80+
r.Signature = signature
81+
82+
return base64.RawURLEncoding.EncodeToString(result)
83+
}
84+
85+
var (
86+
ErrRefreshTokenLength = errors.New("crypto: refresh token length is not valid")
87+
ErrRefreshTokenUnknownVersion = errors.New("crypto: refresh token version is not 0")
88+
ErrRefreshTokenChecksumInvalid = errors.New("crypto: refresh token checksum is not valid")
89+
ErrRefreshTokenCounterInvalid = errors.New("crypto: refresh token's counter is not valid")
90+
)
91+
92+
func ParseRefreshToken(token string) (*RefreshToken, error) {
93+
bytes, err := base64.RawURLEncoding.DecodeString(token)
94+
if err != nil {
95+
return nil, err
96+
}
97+
98+
if len(bytes) < minRefreshTokenLength {
99+
return nil, ErrRefreshTokenLength
100+
}
101+
102+
if bytes[0] != 0 {
103+
return nil, ErrRefreshTokenUnknownVersion
104+
}
105+
106+
parseFrom := bytes[1 : len(bytes)-refreshTokenChecksumLength]
107+
108+
checksum256 := sha256.Sum256(bytes[:len(bytes)-refreshTokenChecksumLength])
109+
if subtle.ConstantTimeCompare(checksum256[:refreshTokenChecksumLength], bytes[len(bytes)-refreshTokenChecksumLength:len(bytes)]) != 1 {
110+
return nil, ErrRefreshTokenChecksumInvalid
111+
}
112+
113+
sessionID, err := uuid.FromBytes(parseFrom[0:16])
114+
if err != nil {
115+
return nil, err
116+
}
117+
118+
parseFrom = parseFrom[16:]
119+
120+
counter, counterBytes := binary.Uvarint(parseFrom)
121+
if counterBytes <= 0 {
122+
return nil, ErrRefreshTokenCounterInvalid
123+
}
124+
125+
parseFrom = parseFrom[counterBytes:]
126+
127+
if len(parseFrom) != 16 {
128+
return nil, ErrRefreshTokenLength
129+
}
130+
131+
signature := parseFrom
132+
133+
return &RefreshToken{
134+
Raw: bytes,
135+
136+
Version: 0,
137+
SessionID: sessionID,
138+
Counter: int64(counter),
139+
Signature: signature,
140+
}, nil
141+
}
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
package crypto
2+
3+
import (
4+
"crypto/sha256"
5+
"encoding/base64"
6+
"fmt"
7+
"strings"
8+
"testing"
9+
10+
"github.com/gofrs/uuid"
11+
"github.com/stretchr/testify/require"
12+
)
13+
14+
func TestRefreshTokenParse(t *testing.T) {
15+
negativeExamples := []struct {
16+
value []byte
17+
error error
18+
}{
19+
{
20+
value: make([]byte, minRefreshTokenLength-1),
21+
error: ErrRefreshTokenLength,
22+
},
23+
{
24+
value: make([]byte, minRefreshTokenLength),
25+
error: ErrRefreshTokenChecksumInvalid,
26+
},
27+
{
28+
value: func() []byte {
29+
b := make([]byte, minRefreshTokenLength)
30+
b[0] = 1
31+
return b
32+
}(),
33+
error: ErrRefreshTokenUnknownVersion,
34+
},
35+
{
36+
value: func() []byte {
37+
b := make([]byte, minRefreshTokenLength)
38+
for i := 1 + 16; i < len(b); i += 1 {
39+
b[i] = 0xFF
40+
}
41+
42+
checksum := sha256.Sum256(b[:len(b)-refreshTokenChecksumLength])
43+
copy(b[len(b)-refreshTokenChecksumLength:], checksum[:refreshTokenChecksumLength])
44+
return b
45+
}(),
46+
error: ErrRefreshTokenCounterInvalid,
47+
},
48+
{
49+
value: func() []byte {
50+
b := make([]byte, minRefreshTokenLength)
51+
b[1+16] = 0xFF
52+
b[1+16+1] = 0
53+
54+
checksum := sha256.Sum256(b[:len(b)-refreshTokenChecksumLength])
55+
copy(b[len(b)-refreshTokenChecksumLength:], checksum[:refreshTokenChecksumLength])
56+
return b
57+
}(),
58+
error: ErrRefreshTokenLength,
59+
},
60+
}
61+
62+
for i, example := range negativeExamples {
63+
t.Run(fmt.Sprintf("negative example %d", i), func(t *testing.T) {
64+
rt, err := ParseRefreshToken(base64.RawURLEncoding.EncodeToString(example.value))
65+
require.Nil(t, rt)
66+
require.Error(t, err)
67+
require.Equal(t, err, example.error)
68+
})
69+
}
70+
71+
rt, err := ParseRefreshToken(strings.Repeat("!", (4*minRefreshTokenLength)/3))
72+
require.Nil(t, rt)
73+
require.Error(t, err)
74+
75+
original := &RefreshToken{
76+
SessionID: uuid.Must(uuid.NewV4()),
77+
Counter: 9223372036854775807,
78+
}
79+
80+
parsed, err := ParseRefreshToken(original.Encode(make([]byte, 32)))
81+
require.Nil(t, err)
82+
require.Equal(t, original.SessionID.String(), parsed.SessionID.String())
83+
require.Equal(t, original.Counter, parsed.Counter)
84+
require.Equal(t, original.Raw, parsed.Raw)
85+
require.Equal(t, original.Signature, parsed.Signature)
86+
}

internal/models/refresh_token.go

Lines changed: 46 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package models
22

33
import (
44
"database/sql"
5+
"encoding/base64"
56
"net/http"
67
"time"
78

@@ -14,15 +15,6 @@ import (
1415
"github.com/supabase/auth/internal/utilities"
1516
)
1617

17-
type TODO struct {
18-
SessionID string `json:"s"`
19-
ID int64 `json:"t"`
20-
}
21-
22-
func (TODO) TableName() string {
23-
panic("not a DB model")
24-
}
25-
2618
// RefreshToken is the database model for refresh tokens.
2719
type RefreshToken struct {
2820
ID int64 `db:"id"`
@@ -127,6 +119,50 @@ func FindTokenBySessionID(tx *storage.Connection, sessionId *uuid.UUID) (*Refres
127119
return refreshToken, nil
128120
}
129121

122+
func (s *Session) ApplyGrantParams(params *GrantParams) {
123+
s.FactorID = params.FactorID
124+
125+
if params.SessionNotAfter != nil {
126+
s.NotAfter = params.SessionNotAfter
127+
}
128+
129+
if params.UserAgent != "" {
130+
s.UserAgent = &params.UserAgent
131+
}
132+
133+
if params.IP != "" {
134+
s.IP = &params.IP
135+
}
136+
137+
if params.SessionTag != nil && *params.SessionTag != "" {
138+
s.Tag = params.SessionTag
139+
}
140+
141+
if params.OAuthClientID != nil && *params.OAuthClientID != uuid.Nil {
142+
s.OAuthClientID = params.OAuthClientID
143+
}
144+
}
145+
146+
func (s *Session) SetupRefreshTokenData(dbEncryption conf.DatabaseEncryptionConfiguration) error {
147+
hmacKey := base64.RawURLEncoding.EncodeToString(crypto.GenerateRefreshTokenHmacKey())
148+
149+
if dbEncryption.Encrypt {
150+
es, err := crypto.NewEncryptedString(s.ID.String(), []byte(hmacKey), dbEncryption.EncryptionKeyID, dbEncryption.EncryptionKey)
151+
if err != nil {
152+
return err
153+
}
154+
155+
hmacKey = es.String()
156+
}
157+
158+
counter := int64(0)
159+
160+
s.RefreshTokenHmacKey = &hmacKey
161+
s.RefreshTokenCounter = &counter
162+
163+
return nil
164+
}
165+
130166
func createRefreshToken(tx *storage.Connection, user *User, oldToken *RefreshToken, params *GrantParams) (*RefreshToken, error) {
131167
token := &RefreshToken{
132168
UserID: user.ID,
@@ -144,25 +180,7 @@ func createRefreshToken(tx *storage.Connection, user *User, oldToken *RefreshTok
144180
return nil, errors.Wrap(err, "error instantiating new session object")
145181
}
146182

147-
if params.SessionNotAfter != nil {
148-
session.NotAfter = params.SessionNotAfter
149-
}
150-
151-
if params.UserAgent != "" {
152-
session.UserAgent = &params.UserAgent
153-
}
154-
155-
if params.IP != "" {
156-
session.IP = &params.IP
157-
}
158-
159-
if params.SessionTag != nil && *params.SessionTag != "" {
160-
session.Tag = params.SessionTag
161-
}
162-
163-
if params.OAuthClientID != nil && *params.OAuthClientID != uuid.Nil {
164-
session.OAuthClientID = params.OAuthClientID
165-
}
183+
session.ApplyGrantParams(params)
166184

167185
if err := tx.Create(session); err != nil {
168186
return nil, errors.Wrap(err, "error creating new session")

0 commit comments

Comments
 (0)