diff --git a/credentials_provider_test.go b/credentials_provider_test.go index a63826c..efe7aed 100644 --- a/credentials_provider_test.go +++ b/credentials_provider_test.go @@ -117,7 +117,7 @@ func TestCredentialsProviderWithMockIdentityProvider(t *testing.T) { "test-token", time.Now().Add(time.Hour), time.Now(), - int64(time.Hour), + time.Hour.Milliseconds(), ), } @@ -159,7 +159,7 @@ func TestCredentialsProviderWithMockIdentityProvider(t *testing.T) { "test-token", time.Now().Add(time.Hour), time.Now(), - int64(time.Hour), + time.Hour.Milliseconds(), ), } @@ -219,7 +219,7 @@ func TestCredentialsProviderWithMockIdentityProvider(t *testing.T) { "initial-token", time.Now().Add(time.Hour), time.Now(), - int64(time.Hour), + time.Hour.Milliseconds(), ), } @@ -253,7 +253,7 @@ func TestCredentialsProviderWithMockIdentityProvider(t *testing.T) { "updated-token", time.Now().Add(time.Hour), time.Now(), - int64(time.Hour), + time.Hour.Milliseconds(), ) tm.lock.Unlock() @@ -329,14 +329,14 @@ func TestCredentialsProviderSubscribe(t *testing.T) { rawTokenString, time.Now().Add(tokenExpiration), time.Now(), - int64(tokenExpiration), + tokenExpiration.Milliseconds(), ) listener := &mockCredentialsListener{ LastTokenCh: make(chan string, 1), LastErrCh: make(chan error, 1), } - mtm := &mockTokenManager{done: make(chan struct{})} + mtm := &mockTokenManager{done: make(chan struct{}), lock: &sync.Mutex{}} // Set the token manager factory in the options options := opts options.tokenManagerFactory = mockTokenManagerFactory(mtm) @@ -386,9 +386,9 @@ func TestCredentialsProviderSubscribe(t *testing.T) { rawTokenString, time.Now().Add(tokenExpiration), time.Now(), - int64(tokenExpiration), + tokenExpiration.Milliseconds(), ) - mtm := &mockTokenManager{done: make(chan struct{})} + mtm := &mockTokenManager{done: make(chan struct{}), lock: &sync.Mutex{}} // Set the token manager factory in the options options := opts options.tokenManagerFactory = mockTokenManagerFactory(mtm) @@ -459,7 +459,7 @@ func TestCredentialsProviderSubscribe(t *testing.T) { t.Run("concurrent subscribe and get token error ", func(t *testing.T) { t.Parallel() - mtm := &mockTokenManager{done: make(chan struct{})} + mtm := &mockTokenManager{done: make(chan struct{}), lock: &sync.Mutex{}} // Set the token manager factory in the options options := opts options.tokenManagerFactory = mockTokenManagerFactory(mtm) @@ -514,7 +514,7 @@ func TestCredentialsProviderSubscribe(t *testing.T) { rawTokenString, time.Now().Add(tokenExpiration), time.Now(), - int64(tokenExpiration), + tokenExpiration.Milliseconds(), ) // Set the token manager factory in the options options := opts diff --git a/entraid_test.go b/entraid_test.go index 98f2fc8..1e1dfa7 100644 --- a/entraid_test.go +++ b/entraid_test.go @@ -48,7 +48,7 @@ func (m *fakeTokenManager) GetToken(forceRefresh bool) (*token.Token, error) { rawTokenString, time.Now().Add(tokenExpiration), time.Now(), - int64(tokenExpiration.Seconds()), + tokenExpiration.Milliseconds(), ) } return m.token, m.err @@ -136,7 +136,7 @@ type mockTokenManager struct { done chan struct{} options manager.TokenManagerOptions listener manager.TokenListener - lock sync.Mutex + lock *sync.Mutex } func (m *mockTokenManager) GetToken(forceRefresh bool) (*token.Token, error) { diff --git a/examples/custom_idp/go.mod b/examples/custom_idp/go.mod index 6609e22..59d008f 100644 --- a/examples/custom_idp/go.mod +++ b/examples/custom_idp/go.mod @@ -3,23 +3,23 @@ module custom_example go 1.23.4 require ( - github.com/redis/go-redis-entraid v1.0.0 + github.com/redis/go-redis-entraid v1.0.1 github.com/redis/go-redis/v9 v9.9.0 ) require ( - github.com/Azure/azure-sdk-for-go/sdk/azcore v1.17.0 // indirect - github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.8.0-beta.1 // indirect - github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 // indirect - github.com/AzureAD/microsoft-authentication-library-for-go v1.4.1 // indirect + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1 // indirect + github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect - github.com/golang-jwt/jwt/v5 v5.2.1 // indirect + github.com/golang-jwt/jwt/v5 v5.2.2 // indirect github.com/google/uuid v1.6.0 // indirect github.com/kylelemons/godebug v1.1.0 // indirect github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect - golang.org/x/crypto v0.33.0 // indirect - golang.org/x/net v0.35.0 // indirect - golang.org/x/sys v0.30.0 // indirect - golang.org/x/text v0.22.0 // indirect + golang.org/x/crypto v0.38.0 // indirect + golang.org/x/net v0.40.0 // indirect + golang.org/x/sys v0.33.0 // indirect + golang.org/x/text v0.25.0 // indirect ) diff --git a/examples/custom_idp/go.sum b/examples/custom_idp/go.sum index 1b19682..d2e3864 100644 --- a/examples/custom_idp/go.sum +++ b/examples/custom_idp/go.sum @@ -1,15 +1,15 @@ -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.17.0 h1:g0EZJwz7xkXQiZAI5xi9f3WWFYBlX1CPTrR+NDToRkQ= -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.17.0/go.mod h1:XCW7KnZet0Opnr7HccfUw1PLc4CjHqpcaxW8DHklNkQ= -github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.8.0-beta.1 h1:iw4+KCeCoieuKodp1d5YhAa1TU/GgogCbw8RbGvsfLA= -github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.8.0-beta.1/go.mod h1:AP8cDnDTGIVvayqKAhwzpcAyTJosXpvLYNmVFJb98x8= -github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.2.3 h1:BAUsn6/icUFtvUalVwCO0+hSF7qgU9DwwcEfCvtILtw= -github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.2.3/go.mod h1:QlAsNp4gk9zLD2wiZIvIuv699ynpZ2Tq2ZBp+6MrSEw= -github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 h1:ywEEhmNahHBihViHepv3xPBn1663uRv2t2q/ESv9seY= -github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0/go.mod h1:iZDifYGJTIgIIkYRNWPENUnqx6bJ2xnSDFI2tjwZNuY= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0 h1:Gt0j3wceWMwPmiazCa8MzMA0MfhmPIz0Qp0FJ6qcM0U= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0/go.mod h1:Ot/6aikWnKWi4l9QB7qVSwa8iMphQNqkWALMoNT3rzM= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.0 h1:j8BorDEigD8UFOSZQiSqAMOOleyQOOQPnUAwV+Ls1gA= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.0/go.mod h1:JdM5psgjfBf5fo2uWOZhflPWyDBZ/O/CNAH9CtsuZE4= +github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.2 h1:yz1bePFlP5Vws5+8ez6T3HWXPmwOK7Yvq8QxDBD3SKY= +github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.2/go.mod h1:Pa9ZNPuoNu/GztvBSKk9J1cDJW6vk/n0zLtV4mgd8N8= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1 h1:FPKJS1T+clwv+OLGt13a8UjqeRuh0O4SJ3lUriThc+4= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1/go.mod h1:j2chePtV91HrC22tGoRX3sGY42uF13WzmmV80/OdVAA= github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1 h1:WJTmL004Abzc5wDB5VtZG2PJk5ndYDgVacGqfirKxjM= github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1/go.mod h1:tCcJZ0uHAmvjsVYzEFivsRTN00oz5BEsRgQHu5JZ9WE= -github.com/AzureAD/microsoft-authentication-library-for-go v1.4.1 h1:8BKxhZZLX/WosEeoCvWysmKUscfa9v8LIPEEU0JjE2o= -github.com/AzureAD/microsoft-authentication-library-for-go v1.4.1/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI= +github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2 h1:oygO0locgZJe7PpYPXT5A29ZkwJaPqcva7BVeemZOZs= +github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= @@ -20,36 +20,34 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= -github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= -github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/keybase/go-keychain v0.0.0-20231219164618-57a3676c3af6 h1:IsMZxCuZqKuao2vNdfD82fjjgPLfyHLpR41Z88viRWs= -github.com/keybase/go-keychain v0.0.0-20231219164618-57a3676c3af6/go.mod h1:3VeWNIJaW+O5xpRQbPp0Ybqu1vJd/pm7s2F473HRrkw= +github.com/keybase/go-keychain v0.0.1 h1:way+bWYa6lDppZoZcgMbYsvC7GxljxrskdNInRtuthU= +github.com/keybase/go-keychain v0.0.1/go.mod h1:PdEILRW3i9D8JcdM+FmY6RwkHGnhHxXwkPPMeUgOK1k= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/redis/go-redis-entraid v0.0.0-20250415111332-9d087bc29c12 h1:H5ZfgueBAxs2eAvXtCMEbT2/fLQz/wxW5Ds4c0uzl50= -github.com/redis/go-redis-entraid v0.0.0-20250415111332-9d087bc29c12/go.mod h1:uXKLxCMUAu1VKgWdt8gWc4PWCygiL2pAI5XpnRSVc0w= -github.com/redis/go-redis-entraid v1.0.0/go.mod h1:b+YPtHM3oFJ74Y2eFHPuz1Cp59kUL0fwdiARp27VW8Q= -github.com/redis/go-redis/v9 v9.5.3-0.20250415103233-40a89c56cc52 h1:jRx2gINoJsGKxi/RYXCq1VneAAYes9JxUp13xH2oU2g= -github.com/redis/go-redis/v9 v9.5.3-0.20250415103233-40a89c56cc52/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw= +github.com/redis/go-redis-entraid v1.0.1 h1:Q2gxpSRFLn+KyZuPrF7zDUCQ9iISoUxqzaCjxPqJKQI= +github.com/redis/go-redis-entraid v1.0.1/go.mod h1:OS6s3V1DdSRzOJEIjpK38/w4chZpl/Sy+1pzby+6nEk= +github.com/redis/go-redis/v9 v9.9.0 h1:URbPQ4xVQSQhZ27WMQVmZSo3uT3pL+4IdHVcYq2nVfM= github.com/redis/go-redis/v9 v9.9.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus= -golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M= -golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8= -golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk= +golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8= +golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw= +golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY= +golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= -golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= -golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= +golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= +golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4= +golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/manager/defaults.go b/manager/defaults.go index 6b09cf2..8c05bc4 100644 --- a/manager/defaults.go +++ b/manager/defaults.go @@ -3,7 +3,6 @@ package manager import ( "errors" "fmt" - "math" "net" "os" "time" @@ -177,6 +176,6 @@ func (*defaultIdentityProviderResponseParser) ParseResponse(response shared.Iden rawToken, expiresOn, now, - int64(math.Ceil(time.Until(expiresOn).Seconds())), + time.Until(expiresOn).Milliseconds(), ), nil } diff --git a/manager/entraid_manager.go b/manager/entraid_manager.go index 8a7473e..3d633d2 100644 --- a/manager/entraid_manager.go +++ b/manager/entraid_manager.go @@ -11,6 +11,8 @@ import ( "github.com/redis/go-redis-entraid/token" ) +const RefreshRationPrecision = 10000 + // entraidTokenManager is a struct that implements the TokenManager interface. type entraidTokenManager struct { // idp is the identity provider used to obtain the token. @@ -20,7 +22,7 @@ type entraidTokenManager struct { token *token.Token // tokenRWLock is a read-write lock used to protect the token from concurrent access. - tokenRWLock sync.RWMutex + tokenRWLock *sync.RWMutex // identityProviderResponseParser is the parser used to parse the response from the identity provider. // It`s ParseResponse method will be called to parse the response and return the token. @@ -40,7 +42,7 @@ type entraidTokenManager struct { listener TokenListener // lock locks the listener to prevent concurrent access. - lock sync.Mutex + lock *sync.Mutex // expirationRefreshRatio is the ratio of the token expiration time to refresh the token. // It is used to determine when to refresh the token. @@ -221,12 +223,14 @@ func (e *entraidTokenManager) stop() (err error) { err = fmt.Errorf("failed to stop token manager: %s", r) } }() + if e.ctxCancel != nil { + e.ctxCancel() + } if e.closedChan == nil || e.listener == nil { return ErrTokenManagerAlreadyStopped } - e.ctxCancel() e.listener = nil close(e.closedChan) @@ -238,30 +242,63 @@ func (e *entraidTokenManager) stop() (err error) { // If the token is nil, it returns 0. // If the time till expiration is less than the lower bound duration, it returns 0 to renew the token now. func (e *entraidTokenManager) durationToRenewal(t *token.Token) time.Duration { + // Fast path: nil token check if t == nil { return 0 } - expirationRefreshTime := t.ReceivedAt().Add(time.Duration(float64(t.TTL()) * float64(time.Second) * e.expirationRefreshRatio)) - timeTillExpiration := time.Until(t.ExpirationOn()) - now := time.Now().UTC() - if expirationRefreshTime.Before(now) { + // Get current time in milliseconds (UTC) + nowMillis := time.Now().UnixMilli() + + // Get expiration time in milliseconds + expMillis := t.ExpirationOn().UnixMilli() + + // Fast path: token already expired + if expMillis <= nowMillis { + return 0 + } + + // Calculate time until expiration in milliseconds + timeTillExpiration := expMillis - nowMillis + + // Get lower bound in milliseconds + lowerBoundMillis := e.lowerBoundDuration.Milliseconds() + + // Fast path: time until expiration is less than lower bound + if timeTillExpiration <= lowerBoundMillis { return 0 } - // if the timeTillExpiration is less than the lower bound (or 0), return 0 to renew the token NOW - if timeTillExpiration <= e.lowerBoundDuration || timeTillExpiration <= 0 { + // Calculate refresh time using integer math with higher precision + // example tests use 0.001, which would be lost with lower precision + // Example: + // ttlMillis = 10000 + // e.expirationRefreshRatio = 0.001 + // - with int math and 100 precision: 10000 * (0.001*100) = 0ms + // - with int math and 10000 precision: 10000 * (0.001*10000) = 100ms + precision := int64(RefreshRationPrecision) + receivedAtMillis := t.ReceivedAt().UnixMilli() + ttlMillis := t.TTL() // Already in milliseconds + refreshRatioInt := int64(e.expirationRefreshRatio * float64(precision)) + refreshMillis := ttlMillis * refreshRatioInt / precision + refreshTimeMillis := receivedAtMillis + refreshMillis + + // Calculate time until refresh + timeUntilRefresh := refreshTimeMillis - nowMillis + + // Fast path: refresh time is in the past + if timeUntilRefresh <= 0 { return 0 } - // Calculate the time to renew the token based on the expiration refresh ratio - duration := time.Until(expirationRefreshTime) + // Calculate time until lower bound + timeUntilLowerBound := timeTillExpiration - lowerBoundMillis - // if the duration will take us past the lower bound, return the duration to lower bound - if timeTillExpiration-e.lowerBoundDuration < duration { - return timeTillExpiration - e.lowerBoundDuration + // If refresh would occur after lower bound, use time until lower bound + if timeUntilRefresh > timeUntilLowerBound { + return time.Duration(timeUntilLowerBound) * time.Millisecond } - // return the calculated duration - return duration + // Otherwise use time until refresh + return time.Duration(timeUntilRefresh) * time.Millisecond } diff --git a/manager/entraid_manager_test.go b/manager/entraid_manager_test.go new file mode 100644 index 0000000..5cb7132 --- /dev/null +++ b/manager/entraid_manager_test.go @@ -0,0 +1,460 @@ +package manager + +import ( + "testing" + "time" + + "github.com/redis/go-redis-entraid/token" + "github.com/stretchr/testify/assert" +) + +func TestDurationToRenewal(t *testing.T) { + tests := []struct { + name string + token *token.Token + refreshRatio float64 + lowerBoundDuration time.Duration + expectedDuration time.Duration + }{ + { + name: "nil token returns 0", + token: nil, + refreshRatio: 0.75, + lowerBoundDuration: time.Second, + expectedDuration: 0, + }, + { + name: "expired token returns 0", + token: token.New( + "username", + "password", + "rawToken", + time.Now().Add(-time.Hour), + time.Now().Add(-2*time.Hour), + time.Hour.Milliseconds(), + ), + refreshRatio: 0.75, + lowerBoundDuration: time.Second, + expectedDuration: 0, + }, + { + name: "token with TTL less than lower bound returns 0", + token: token.New( + "username", + "password", + "rawToken", + time.Now().Add(500*time.Millisecond), + time.Now().Add(-time.Hour), + time.Hour.Milliseconds(), + ), + refreshRatio: 0.75, + lowerBoundDuration: time.Second, + expectedDuration: 0, + }, + { + name: "token with TTL exactly at lower bound returns 0", + token: token.New( + "username", + "password", + "rawToken", + time.Now().Add(time.Second), + time.Now(), + time.Second.Milliseconds(), + ), + refreshRatio: 0.75, + lowerBoundDuration: time.Second, + expectedDuration: 0, + }, + { + name: "token with refresh time before lower bound", + token: token.New( + "username", + "password", + "rawToken", + time.Now().Add(1*time.Hour), // expires in 1 hour + time.Now().Add(-1*time.Hour), // received 1 hour ago + (2 * time.Hour).Milliseconds(), // TTL is 2 hours + ), + refreshRatio: 0.75, + lowerBoundDuration: time.Second, + // ReceivedAt is 1 hour in the past, TTL is 2 hours, so refresh is at ReceivedAt + 1.5h (75% of 2h). + // Now is ReceivedAt + 1h, so time until refresh is 30 minutes. + expectedDuration: 30 * time.Minute, + }, + { + name: "token with refresh time before lower bound and large lower bound", + token: token.New( + "username", + "password", + "rawToken", + time.Now().Add(1*time.Hour), // expires in 1 hour + time.Now().Add(-1*time.Hour), // received 1 hour ago + (2 * time.Hour).Milliseconds(), // TTL is 2 hours + ), + refreshRatio: 0.75, + lowerBoundDuration: 45 * time.Minute, + // ReceivedAt is 1 hour in the past, TTL is 2 hours, so refresh is at ReceivedAt + 1.5h (75% of 2h). + // Now is ReceivedAt + 1h, so time until refresh is 30 minutes. + // But lower bound is 45 minutes, so refresh is scheduled for 15 minutes from now. + expectedDuration: 15 * time.Minute, + }, + { + name: "token with refresh time after lower bound and past ReceivedAt", + token: token.New( + "username", + "password", + "rawToken", + time.Now().Add(1*time.Hour), // expires in 1 hour + time.Now().Add(-30*time.Minute), // received 30 minutes ago + (90 * time.Minute).Milliseconds(), // TTL is 1.5 hours + ), + refreshRatio: 0.75, + lowerBoundDuration: 10 * time.Minute, + // ReceivedAt is 30 minutes in the past, TTL is 1.5 hours, so refresh is at ReceivedAt + 1.125h (75% of 1.5h). + // Now is ReceivedAt + 0.5h, so time until refresh is 1.125h - 0.5h = 0.625h = 37.5 minutes. + expectedDuration: 37*time.Minute + 30*time.Second, + }, + { + name: "token with refresh time after lower bound", + token: token.New( + "username", + "password", + "rawToken", + time.Now().Add(time.Hour), + time.Now(), + time.Hour.Milliseconds(), + ), + refreshRatio: 0.75, + lowerBoundDuration: 60 * time.Second, + expectedDuration: 45 * time.Minute, + }, + { + name: "token with refresh ratio 1 and lower bound 10 minutes", + token: token.New( + "username", + "password", + "rawToken", + time.Now().Add(time.Hour), + time.Now(), + time.Hour.Milliseconds(), + ), + refreshRatio: 1.0, + lowerBoundDuration: 10 * time.Minute, + expectedDuration: 50 * time.Minute, + }, + { + name: "token with zero refresh ratio", + token: token.New( + "username", + "password", + "rawToken", + time.Now().Add(time.Hour), + time.Now(), + time.Hour.Milliseconds(), + ), + refreshRatio: 0.0, + lowerBoundDuration: time.Second, + expectedDuration: 0, + }, + { + name: "token with negative refresh ratio", + token: token.New( + "username", + "password", + "rawToken", + time.Now().Add(time.Hour), + time.Now(), + time.Hour.Milliseconds(), + ), + refreshRatio: -0.5, + lowerBoundDuration: time.Second, + expectedDuration: 0, + }, + { + name: "token with very large TTL", + token: token.New( + "username", + "password", + "rawToken", + time.Now().Add(24*365*time.Hour), + time.Now(), + (24 * 365 * time.Hour).Milliseconds(), + ), + refreshRatio: 0.75, + lowerBoundDuration: time.Hour, + expectedDuration: 24 * 365 * 45 * time.Minute, + }, + { + name: "token with lower bound equal to TTL", + token: token.New( + "username", + "password", + "rawToken", + time.Now().Add(time.Hour), + time.Now(), + time.Hour.Milliseconds(), + ), + refreshRatio: 0.75, + lowerBoundDuration: time.Hour, + expectedDuration: 0, + }, + { + name: "token with lower bound greater than TTL", + token: token.New( + "username", + "password", + "rawToken", + time.Now().Add(time.Hour), + time.Now(), + time.Hour.Milliseconds(), + ), + refreshRatio: 0.75, + lowerBoundDuration: 2 * time.Hour, + expectedDuration: 0, + }, + { + name: "token with refresh ratio resulting in zero refresh time", + token: token.New( + "username", + "password", + "rawToken", + time.Now().Add(time.Second), + time.Now(), + time.Second.Milliseconds(), + ), + refreshRatio: 0.0001, + lowerBoundDuration: time.Millisecond, + expectedDuration: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + manager := &entraidTokenManager{ + expirationRefreshRatio: tt.refreshRatio, + lowerBoundDuration: tt.lowerBoundDuration, + } + + duration := manager.durationToRenewal(tt.token) + assert.InDelta(t, float64(tt.expectedDuration), float64(duration), float64(time.Millisecond), + "%s: expected %v, got %v", tt.name, tt.expectedDuration, duration) + }) + } +} + +func TestDurationToRenewalMillisecondPrecision(t *testing.T) { + now := time.Now() + tests := []struct { + name string + token *token.Token + refreshRatio float64 + lowerBoundDuration time.Duration + expectedDuration time.Duration + }{ + { + name: "exact millisecond TTL", + token: token.New( + "username", + "password", + "rawToken", + now.Add(time.Second), + now, + 1000, // 1 second in milliseconds + ), + refreshRatio: 0.001, + lowerBoundDuration: time.Millisecond, + expectedDuration: time.Millisecond, // 1ms refresh time + }, + { + name: "sub-millisecond TTL", + token: token.New( + "username", + "password", + "rawToken", + now.Add(100*time.Millisecond), + now, + 100, // 100ms + ), + refreshRatio: 0.001, + lowerBoundDuration: time.Millisecond, + expectedDuration: 0, + }, + { + name: "odd millisecond TTL", + token: token.New( + "username", + "password", + "rawToken", + now.Add(123*time.Millisecond), + now, + 123, // 123ms + ), + refreshRatio: 0.001, + lowerBoundDuration: time.Millisecond, + expectedDuration: 0, // 0.123ms rounds to 0ms + }, + { + name: "exact second TTL with millisecond refresh", + token: token.New( + "username", + "password", + "rawToken", + now.Add(time.Second), + now, + 1000, // 1 second in milliseconds + ), + refreshRatio: 0.001, + lowerBoundDuration: time.Millisecond, + expectedDuration: time.Millisecond, // 1ms refresh time + }, + { + name: "high precision refresh ratio", + token: token.New( + "username", + "password", + "rawToken", + now.Add(time.Second), + now, + 1000, // 1 second in milliseconds + ), + refreshRatio: 0.0001, // 0.01% + lowerBoundDuration: time.Millisecond, + expectedDuration: 0, // 0.1ms rounds to 0ms + }, + { + name: "very small TTL with high precision ratio", + token: token.New( + "username", + "password", + "rawToken", + now.Add(10*time.Millisecond), + now, + 10, // 10ms + ), + refreshRatio: 0.0001, // 0.01% + lowerBoundDuration: time.Millisecond, + expectedDuration: 0, // 0.001ms rounds to 0ms + }, + { + name: "large TTL with high precision ratio", + token: token.New( + "username", + "password", + "rawToken", + now.Add(time.Hour), + now, + time.Hour.Milliseconds(), + ), + refreshRatio: 0.0001, // 0.01% + lowerBoundDuration: time.Millisecond, + expectedDuration: 360 * time.Millisecond, // 0.01% of 1 hour = 360ms + }, + { + name: "boundary case: refresh time exactly 1ms", + token: token.New( + "username", + "password", + "rawToken", + now.Add(100*time.Millisecond), + now, + 100, // 100ms + ), + refreshRatio: 0.01, // 1% + lowerBoundDuration: time.Millisecond, + expectedDuration: time.Millisecond, // 1ms refresh time + }, + { + name: "boundary case: refresh time just below 1ms", + token: token.New( + "username", + "password", + "rawToken", + now.Add(100*time.Millisecond), + now, + 100, // 100ms + ), + refreshRatio: 0.009, // 0.9% + lowerBoundDuration: time.Millisecond, + expectedDuration: 0, // 0.9ms rounds to 0ms + }, + { + name: "boundary case: refresh time just above 1ms", + token: token.New( + "username", + "password", + "rawToken", + now.Add(100*time.Millisecond), + now, + 100, // 100ms + ), + refreshRatio: 0.011, // 1.1% + lowerBoundDuration: time.Millisecond, + expectedDuration: time.Millisecond, // 1.1ms rounds to 1ms + }, + { + name: "large TTL with very small refresh ratio", + token: token.New( + "username", + "password", + "rawToken", + now.Add(24*time.Hour), + now, + (24 * time.Hour).Milliseconds(), + ), + refreshRatio: 0.0001, // 0.01% + lowerBoundDuration: time.Millisecond, + expectedDuration: 8*time.Second + 640*time.Millisecond, // 0.01% of 24 hours = 8.64s + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + manager := &entraidTokenManager{ + expirationRefreshRatio: tt.refreshRatio, + lowerBoundDuration: tt.lowerBoundDuration, + } + + duration := manager.durationToRenewal(tt.token) + assert.InDelta(t, float64(tt.expectedDuration), float64(duration), float64(time.Millisecond), + "%s: expected %v, got %v", tt.name, tt.expectedDuration, duration) + }) + } +} + +func TestDurationToRenewalConcurrent(t *testing.T) { + manager := &entraidTokenManager{ + expirationRefreshRatio: 0.75, + lowerBoundDuration: time.Second, + } + + token := token.New( + "username", + "password", + "rawToken", + time.Now().Add(time.Hour), + time.Now(), + time.Hour.Milliseconds(), + ) + + // Run multiple goroutines to test concurrent access + const goroutines = 10 + results := make(chan time.Duration, goroutines) + + for i := 0; i < goroutines; i++ { + go func() { + results <- manager.durationToRenewal(token) + }() + } + + // Collect results + var firstResult time.Duration + for i := 0; i < goroutines; i++ { + result := <-results + if i == 0 { + firstResult = result + } else { + // All results should be within 10ms of each other + assert.InDelta(t, firstResult.Milliseconds(), result.Milliseconds(), 10) + } + } +} diff --git a/manager/manager_test.go b/manager/manager_test.go index 3f048a9..cac02f7 100644 --- a/manager/manager_test.go +++ b/manager/manager_test.go @@ -61,7 +61,7 @@ var testTokenValid = token.New( "test", time.Now().Add(time.Hour), time.Now(), - int64(time.Hour.Seconds()), + time.Hour.Milliseconds(), ) func newTestJWTToken(expiresOn time.Time) string { diff --git a/manager/token_manager.go b/manager/token_manager.go index 8dcf55c..be01a52 100644 --- a/manager/token_manager.go +++ b/manager/token_manager.go @@ -3,6 +3,7 @@ package manager import ( "context" "fmt" + "sync" "time" "github.com/redis/go-redis-entraid/shared" @@ -16,6 +17,8 @@ type TokenManagerOptions struct { // The value should be between 0 and 1. // For example, if the expiration time is 1 hour and the ratio is 0.75, // the token will be refreshed after 45 minutes. (the token is refreshed when 75% of its lifetime has passed) + // Precision is 4 decimal places. + // Closer to 1, the token will be refreshed later. We recommend not going above 0.9. // // default: 0.7 ExpirationRefreshRatio float64 @@ -125,5 +128,7 @@ func NewTokenManager(idp shared.IdentityProvider, options TokenManagerOptions) ( identityProviderResponseParser: options.IdentityProviderResponseParser, retryOptions: options.RetryOptions, requestTimeout: options.RequestTimeout, + tokenRWLock: &sync.RWMutex{}, + lock: &sync.Mutex{}, }, nil } diff --git a/manager/token_manager_test.go b/manager/token_manager_test.go index dba0a44..8f1ce2a 100644 --- a/manager/token_manager_test.go +++ b/manager/token_manager_test.go @@ -598,6 +598,171 @@ func TestEntraidTokenManager_GetToken(t *testing.T) { assert.NotNil(t, token1) }) + t.Run("GetToken with cached token", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + mParser := &mockIdentityProviderResponseParser{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{ + IdentityProviderResponseParser: mParser, + }, + ) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + tm, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + + // First setup the manager with a token + rawResponse := &authResult{ + ResultType: shared.ResponseTypeRawToken, + RawTokenVal: "test", + } + + idp.On("RequestToken", mock.Anything).Return(rawResponse, nil) + mParser.On("ParseResponse", rawResponse).Return(testTokenValid, nil) + + // Get the token once to cache it + token1, err := tokenManager.GetToken(false) + assert.NoError(t, err) + assert.NotNil(t, token1) + + // Change the mock to return a different token to verify caching + differentToken := token.New( + "different", + "different", + "different", + time.Now().Add(time.Hour), + time.Now(), + time.Hour.Milliseconds(), + ) + mParser = &mockIdentityProviderResponseParser{} + mParser.On("ParseResponse", rawResponse).Return(differentToken, nil) + tm.identityProviderResponseParser = mParser + + // Get the token again, should return the cached token + token2, err := tokenManager.GetToken(false) + assert.NoError(t, err) + assert.NotNil(t, token2) + assert.Equal(t, token1, token2) + + // Verify that RequestToken was not called again + idp.AssertNumberOfCalls(t, "RequestToken", 1) + }) + + t.Run("GetToken with force refresh", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + mParser := &mockIdentityProviderResponseParser{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{ + IdentityProviderResponseParser: mParser, + }, + ) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + tm, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + + // First setup the manager with a token + rawResponse := &authResult{ + ResultType: shared.ResponseTypeRawToken, + RawTokenVal: "test", + } + + idp.On("RequestToken", mock.Anything).Return(rawResponse, nil) + mParser.On("ParseResponse", rawResponse).Return(testTokenValid, nil) + + // Get the token once to cache it + token1, err := tokenManager.GetToken(false) + assert.NoError(t, err) + assert.NotNil(t, token1) + + // Change the mock to return a different token + differentToken := token.New( + "different", + "different", + "different", + time.Now().Add(time.Hour), + time.Now(), + time.Hour.Milliseconds(), + ) + mParser = &mockIdentityProviderResponseParser{} + mParser.On("ParseResponse", rawResponse).Return(differentToken, nil) + tm.identityProviderResponseParser = mParser + + // Get the token with force refresh, should get the new token + token2, err := tokenManager.GetToken(true) + assert.NoError(t, err) + assert.NotNil(t, token2) + assert.Equal(t, differentToken, token2) + + // Verify that RequestToken was called again + idp.AssertNumberOfCalls(t, "RequestToken", 2) + }) + + t.Run("GetToken with valid cached token and positive duration", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + mParser := &mockIdentityProviderResponseParser{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{ + IdentityProviderResponseParser: mParser, + ExpirationRefreshRatio: 0.75, + LowerRefreshBound: time.Hour, + }, + ) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + tm, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + + // Create a token that will have a positive duration + validToken := token.New( + "username", + "password", + "rawToken", + time.Now().Add(2*time.Hour), // Expires in 2 hours + time.Now(), + (2 * time.Hour).Milliseconds(), + ) + + // First get a token to cache it + rawResponse := &authResult{ + ResultType: shared.ResponseTypeRawToken, + RawTokenVal: "test", + } + + idp.On("RequestToken", mock.Anything).Return(rawResponse, nil) + mParser.On("ParseResponse", rawResponse).Return(validToken, nil) + + // Get the token once to cache it + token1, err := tokenManager.GetToken(false) + assert.NoError(t, err) + assert.NotNil(t, token1) + + // Change the mock to return a different token + differentToken := token.New( + "different", + "different", + "different", + time.Now().Add(time.Hour), + time.Now(), + time.Hour.Milliseconds(), + ) + mParser = &mockIdentityProviderResponseParser{} + mParser.On("ParseResponse", rawResponse).Return(differentToken, nil) + tm.identityProviderResponseParser = mParser + + // Get the token again without force refresh + token2, err := tokenManager.GetToken(false) + assert.NoError(t, err) + assert.NotNil(t, token2) + assert.Equal(t, token1, token2) // Should return the cached token + + // Verify that RequestToken was not called again + idp.AssertNumberOfCalls(t, "RequestToken", 1) + }) + t.Run("GetToken with parse error", func(t *testing.T) { t.Parallel() idp := &mockIdentityProvider{} @@ -718,6 +883,67 @@ func TestEntraidTokenManager_GetToken(t *testing.T) { assert.Error(t, err) assert.Nil(t, token1) }) + + t.Run("GetToken with token set between checks", func(t *testing.T) { + t.Skip("Flaky test, can cause a race") + idp := &mockIdentityProvider{} + mParser := &mockIdentityProviderResponseParser{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{ + IdentityProviderResponseParser: mParser, + ExpirationRefreshRatio: 0.5, + LowerRefreshBound: time.Minute, + }, + ) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + tm, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + + validToken := token.New( + "username", + "password", + "rawToken", + time.Now().Add(1*time.Hour), + time.Now(), + (1 * time.Hour).Milliseconds(), + ) + + // Step 1: Acquire the read lock + // This simulates a concurrent GetToken operation + // this should be a write lock since we are actually writing + // but it will block the get token if we acquire the write lock first + tm.tokenRWLock.RLock() + + // Step 2: Start GetToken in a goroutine (it will block on upgrading to write lock) + var token2 *token.Token + var err2 error + getTokenStarted := make(chan struct{}) + getTokenDone := make(chan struct{}) + go func() { + close(getTokenStarted) + token2, err2 = tokenManager.GetToken(false) + close(getTokenDone) + }() + + // Step 3: Wait for GetToken to start and block on write lock + <-getTokenStarted + // Give the goroutine a moment to reach the write lock + time.Sleep(1 * time.Millisecond) + // Step 4: Set the token + tm.token = validToken + // Step 5: Release the read lock so GetToken can proceed + tm.tokenRWLock.RUnlock() + + // Step 6: Wait for GetToken to finish + <-getTokenDone + + // Step 7: Assert the result + assert.NoError(t, err2) + assert.NotNil(t, token2) + assert.Equal(t, validToken, token2) + idp.AssertNotCalled(t, "RequestToken") + }) } func TestEntraidTokenManager_durationToRenewal(t *testing.T) { @@ -805,10 +1031,10 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { token1 := token.New( "test", "test", - "test", + "debug", expiresOn, time.Now(), - int64(time.Until(expiresOn)), + time.Until(expiresOn).Milliseconds(), ) mParser.On("ParseResponse", idpResponse).Return(token1, nil).Once() @@ -820,9 +1046,11 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { toRenewal := tm.durationToRenewal(tm.token) assert.NotEqual(t, time.Duration(0), toRenewal) + time.Sleep(time.Millisecond) + toRenewal = tm.durationToRenewal(tm.token) assert.NotEqual(t, expiresIn, toRenewal) assert.True(t, expiresIn > toRenewal) - <-time.After(toRenewal / 10) + time.Sleep(toRenewal / 10) assert.NotNil(t, tm.listener) assert.NoError(t, stopper()) assert.Nil(t, tm.listener) @@ -830,7 +1058,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { close(tm.closedChan) }) - <-time.After(toRenewal) + time.Sleep(toRenewal) // already stopped assert.Error(t, stopper()) mock.AssertExpectationsForObjects(t, idp, mParser, listener) @@ -892,7 +1120,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { // wait for request token to be called <-done // wait a bit for listener to be notified - <-time.After(10 * time.Millisecond) + time.Sleep(10 * time.Millisecond) assert.NoError(t, cancel()) assert.InDelta(t, stop.Sub(start), tm.retryOptions.InitialDelay, float64(200*time.Millisecond)) @@ -944,7 +1172,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { assert.Equal(t, time.Duration(0), toRenewal) assert.True(t, expiresIn > toRenewal) - <-time.After(time.Duration(tm.retryOptions.InitialDelay / 2)) + time.Sleep(tm.retryOptions.InitialDelay / 2) assert.NoError(t, cancel()) assert.Nil(t, tm.listener) assert.Panics(t, func() { @@ -995,10 +1223,12 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { toRenewal := tm.durationToRenewal(tm.token) assert.NotEqual(t, time.Duration(0), toRenewal) + time.Sleep(time.Millisecond) + toRenewal = tm.durationToRenewal(tm.token) assert.NotEqual(t, expiresIn, toRenewal) assert.True(t, expiresIn > toRenewal) - <-time.After(toRenewal + time.Second) + time.Sleep(toRenewal + time.Second) mock.AssertExpectationsForObjects(t, idp, listener) }) @@ -1048,9 +1278,11 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { toRenewal := tm.durationToRenewal(tm.token) assert.NotEqual(t, time.Duration(0), toRenewal) + time.Sleep(time.Millisecond) + toRenewal = tm.durationToRenewal(tm.token) assert.NotEqual(t, expiresIn, toRenewal) assert.True(t, expiresIn > toRenewal) - <-time.After(toRenewal + 100*time.Millisecond) + time.Sleep(toRenewal + 100*time.Millisecond) idp.AssertNumberOfCalls(t, "RequestToken", 2) mock.AssertExpectationsForObjects(t, idp, listener) }) @@ -1100,9 +1332,11 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { toRenewal := tm.durationToRenewal(tm.token) assert.NotEqual(t, time.Duration(0), toRenewal) + time.Sleep(time.Millisecond) + toRenewal = tm.durationToRenewal(tm.token) assert.NotEqual(t, expiresIn, toRenewal) assert.True(t, expiresIn > toRenewal) - <-time.After(toRenewal + 100*time.Millisecond) + time.Sleep(toRenewal + 100*time.Millisecond) idp.AssertNumberOfCalls(t, "RequestToken", 2) listener.AssertNumberOfCalls(t, "OnError", 1) @@ -1173,6 +1407,8 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { assert.NotNil(t, tm.listener) toRenewal := tm.durationToRenewal(tm.token) assert.NotEqual(t, time.Duration(0), toRenewal) + time.Sleep(time.Millisecond) + toRenewal = tm.durationToRenewal(tm.token) assert.NotEqual(t, expiresIn, toRenewal) assert.True(t, expiresIn > toRenewal) @@ -1249,10 +1485,13 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { toRenewal := tm.durationToRenewal(tm.token) assert.NotEqual(t, time.Duration(0), toRenewal) + + time.Sleep(time.Millisecond) + toRenewal = tm.durationToRenewal(tm.token) assert.NotEqual(t, expiresIn, toRenewal) assert.True(t, expiresIn > toRenewal) - <-time.After(toRenewal + 500*time.Millisecond) + time.Sleep(toRenewal + 500*time.Millisecond) assert.Nil(t, cancel()) select { @@ -1261,7 +1500,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { case <-tm.closedChan: } - <-time.After(50 * time.Millisecond) + time.Sleep(50 * time.Millisecond) // maxAttempts + the initial one idp.AssertNumberOfCalls(t, "RequestToken", 2) @@ -1460,17 +1699,13 @@ func TestConcurrentTokenManagerOperations(t *testing.T) { // Using modulo for a deterministic pattern that exercises all operations opType := j % 3 - // t.Logf("Goroutine %d, Operation %d: Performing operation type %d", routineID, j, opType) - switch opType { case 0: // Start the token manager with a new listener - // t.Logf("Goroutine %d, Operation %d: Attempting to start token manager", routineID, j) closeFunc, err := tm.Start(listener) if err != nil { if err != ErrTokenManagerAlreadyStarted { - // t.Logf("Goroutine %d, Operation %d: Start failed with error: %v", routineID, j, err) select { case errorCh <- fmt.Errorf("failed to start token manager: %w", err): default: @@ -1480,7 +1715,6 @@ func TestConcurrentTokenManagerOperations(t *testing.T) { continue } - // t.Logf("Goroutine %d, Operation %d: Successfully started token manager", routineID, j) // Store the closer for later cleanup closerKey := fmt.Sprintf("closer-%d-%d", routineID, j) closers.Store(closerKey, closeFunc) @@ -1490,17 +1724,14 @@ func TestConcurrentTokenManagerOperations(t *testing.T) { case 1: // Get current token - //t.Logf("Goroutine %d, Operation %d: Getting token", routineID, j) token, err := tm.GetToken(false) if err != nil { - //t.Logf("Goroutine %d, Operation %d: GetToken failed with error: %v", routineID, j, err) select { case errorCh <- fmt.Errorf("failed to get token: %w", err): default: t.Fatalf("Goroutine %d, Operation %d: Failed to get token: %v", routineID, j, err) } } else if token != nil { - //t.Logf("Goroutine %d, Operation %d: Successfully got token, expires: %v", routineID, j, token.ExpirationOn()) select { case tokenCh <- token: default: @@ -1511,28 +1742,17 @@ func TestConcurrentTokenManagerOperations(t *testing.T) { case 2: // Close a previously created token manager listener // This simulates multiple subscriptions being created and destroyed - //t.Logf("Goroutine %d, Operation %d: Attempting to close a token manager", routineID, j) - closedAny := false - closers.Range(func(key, value interface{}) bool { if j%10 > 7 { // Only close some of the time based on a pattern - closedAny = true - //t.Logf("Goroutine %d, Operation %d: Closing token manager with key %v", routineID, j, key) - closeFunc := value.(StopFunc) if err := closeFunc(); err != nil { if err != ErrTokenManagerAlreadyStopped { - // t.Logf("Goroutine %d, Operation %d: Close failed with error: %v", routineID, j, err) select { case errorCh <- fmt.Errorf("failed to close token manager: %w", err): default: t.Fatalf("Goroutine %d, Operation %d: Failed to close token manager: %v", routineID, j, err) } - } else { - //t.Logf("Goroutine %d, Operation %d: TokenManager was already stopped", routineID, j) } - } else { - // t.Logf("Goroutine %d, Operation %d: Successfully closed token manager", routineID, j) } closers.Delete(key) @@ -1540,10 +1760,6 @@ func TestConcurrentTokenManagerOperations(t *testing.T) { } return true }) - - if !closedAny { - //t.Logf("Goroutine %d, Operation %d: No token manager to close or condition not met", routineID, j) - } } } }(i) diff --git a/providers_test.go b/providers_test.go index b0facab..cd0a0ae 100644 --- a/providers_test.go +++ b/providers_test.go @@ -72,7 +72,7 @@ func TestNewManagedIdentityCredentialsProvider(t *testing.T) { rawTokenString, time.Now().Add(time.Hour), time.Now(), - int64(time.Hour), + time.Hour.Milliseconds(), ) // Set the token manager factory in the options @@ -157,7 +157,7 @@ func TestNewConfidentialCredentialsProvider(t *testing.T) { rawTokenString, time.Now().Add(time.Hour), time.Now(), - int64(time.Hour), + time.Hour.Milliseconds(), ) // Set the token manager factory in the options @@ -230,7 +230,7 @@ func TestNewDefaultAzureCredentialsProvider(t *testing.T) { rawTokenString, time.Now().Add(time.Hour), time.Now(), - int64(time.Hour), + time.Hour.Milliseconds(), ) // Set the token manager factory in the options @@ -290,7 +290,7 @@ func TestCredentialsProviderInterface(t *testing.T) { rawTokenString, time.Now().Add(time.Hour), time.Now(), - int64(time.Hour.Seconds()), + time.Hour.Milliseconds(), ) // Set the token manager factory in the options @@ -326,7 +326,7 @@ func TestCredentialsProviderInterface(t *testing.T) { rawTokenString, time.Now().Add(time.Hour), time.Now(), - int64(time.Hour), + time.Hour.Milliseconds(), ) // Set the token manager factory in the options @@ -358,7 +358,7 @@ func TestCredentialsProviderInterface(t *testing.T) { rawTokenString, time.Now().Add(time.Hour), time.Now(), - int64(time.Hour), + time.Hour.Milliseconds(), ) // Set the token manager factory in the options @@ -483,7 +483,7 @@ func TestNewManagedIdentityCredentialsProvider_TokenManagerStartError(t *testing rawTokenString, time.Now().Add(time.Hour), time.Now(), - int64(time.Hour), + time.Hour.Milliseconds(), ) // Create a mock token manager that returns an error on Start @@ -527,7 +527,7 @@ func TestNewConfidentialCredentialsProvider_TokenManagerStartError(t *testing.T) rawTokenString, time.Now().Add(time.Hour), time.Now(), - int64(time.Hour), + time.Hour.Milliseconds(), ) // Create a mock token manager that returns an error on Start @@ -567,7 +567,7 @@ func TestNewDefaultAzureCredentialsProvider_TokenManagerStartError(t *testing.T) rawTokenString, time.Now().Add(time.Hour), time.Now(), - int64(time.Hour), + time.Hour.Milliseconds(), ) // Create a mock token manager that returns an error on Start diff --git a/shared/identity_provider_response_test.go b/shared/identity_provider_response_test.go index 2227506..7150ce2 100644 --- a/shared/identity_provider_response_test.go +++ b/shared/identity_provider_response_test.go @@ -307,7 +307,7 @@ func TestIdentityProvider(t *testing.T) { func TestIdentityProviderResponseParser(t *testing.T) { now := time.Now() expires := now.Add(time.Hour) - testToken := token.New("test-user", "test-password", "test-token", expires, now, int64(time.Hour.Seconds())) + testToken := token.New("test-user", "test-password", "test-token", expires, now, time.Hour.Milliseconds()) tests := []struct { name string diff --git a/token/token.go b/token/token.go index 2e2d7a1..016ad2f 100644 --- a/token/token.go +++ b/token/token.go @@ -11,6 +11,10 @@ var _ auth.Credentials = (*Token)(nil) // New creates a new token with the specified username, password, raw token, expiration time, received at time, and time to live. // NOTE: This won't do any validation on the token, expiresOn, receivedAt, or ttl. It will simply create a new token instance. +// The caller is responsible for ensuring the token is valid. +// Expiration time and TTL are used to determine when the token should be refreshed. +// TTL is in milliseconds. +// receivedAt + ttl should be within a millisecond of expiresOn func New(username, password, rawToken string, expiresOn, receivedAt time.Time, ttl int64) *Token { return &Token{ username: username, @@ -31,11 +35,11 @@ type Token struct { password string // expiresOn is the expiration time of the token. expiresOn time.Time - // ttl is the time to live of the token. + // ttl is the time to live of the token in milliseconds. ttl int64 // rawToken is the authentication token. rawToken string - // receivedAt is the time when the token was received. + // receivedAt is the time when the token was received receivedAt time.Time } @@ -57,7 +61,9 @@ func (t *Token) RawToken() string { // ReceivedAt returns the time when the token was received. func (t *Token) ReceivedAt() time.Time { if t.receivedAt.IsZero() { - return time.Now() + // set it to now, recalculate ttl + t.receivedAt = time.Now() + t.ttl = t.expiresOn.Sub(t.receivedAt).Milliseconds() } return t.receivedAt } diff --git a/token/token_test.go b/token/token_test.go index c845dea..58134f5 100644 --- a/token/token_test.go +++ b/token/token_test.go @@ -12,7 +12,7 @@ func TestNew(t *testing.T) { t.Parallel() expiration := time.Now().Add(1 * time.Hour) receivedAt := time.Now() - ttl := expiration.Unix() - receivedAt.Unix() + ttl := expiration.UnixMilli() - receivedAt.UnixMilli() token := New("username", "password", "rawToken", expiration, receivedAt, ttl) assert.Equal(t, "username", token.username) assert.Equal(t, "password", token.password) @@ -29,7 +29,7 @@ func TestBasicAuth(t *testing.T) { rawToken := fmt.Sprintf("%s:%s", username, password) expiration := time.Now().Add(1 * time.Hour) receivedAt := time.Now() - ttl := expiration.Unix() - receivedAt.Unix() + ttl := expiration.UnixMilli() - receivedAt.UnixMilli() token := New(username, password, rawToken, expiration, receivedAt, ttl) baUsername, baPassword := token.BasicAuth() assert.Equal(t, username, baUsername) @@ -43,7 +43,7 @@ func TestRawCredentials(t *testing.T) { rawToken := fmt.Sprintf("%s:%s", username, password) expiration := time.Now().Add(1 * time.Hour) receivedAt := time.Now() - ttl := expiration.Unix() - receivedAt.Unix() + ttl := expiration.UnixMilli() - receivedAt.UnixMilli() token := New(username, password, rawToken, expiration, receivedAt, ttl) rawCredentials := token.RawCredentials() assert.Equal(t, rawToken, rawCredentials) @@ -58,7 +58,7 @@ func TestExpirationOn(t *testing.T) { rawToken := fmt.Sprintf("%s:%s", username, password) expiration := time.Now().Add(1 * time.Hour) receivedAt := time.Now() - ttl := expiration.Unix() - receivedAt.Unix() + ttl := expiration.UnixMilli() - receivedAt.UnixMilli() token := New(username, password, rawToken, expiration, receivedAt, ttl) expirationOn := token.ExpirationOn() assert.True(t, expirationOn.After(time.Now())) @@ -72,14 +72,14 @@ func TestTokenTTL(t *testing.T) { rawToken := fmt.Sprintf("%s:%s", username, password) expiration := time.Now().Add(1 * time.Hour) receivedAt := time.Now() - ttl := expiration.Unix() - receivedAt.Unix() + ttl := expiration.UnixMilli() - receivedAt.UnixMilli() token := New(username, password, rawToken, expiration, receivedAt, ttl) assert.Equal(t, ttl, token.TTL()) } func TestCopyToken(t *testing.T) { t.Parallel() - token := New("username", "password", "rawToken", time.Now(), time.Now(), 3600) + token := New("username", "password", "rawToken", time.Now(), time.Now(), time.Hour.Milliseconds()) copiedToken := copyToken(token) assert.Equal(t, token.username, copiedToken.username) @@ -108,7 +108,7 @@ func TestTokenReceivedAt(t *testing.T) { t.Parallel() // Create a token with a specific receivedAt time receivedAt := time.Now() - token := New("username", "password", "rawToken", time.Now(), receivedAt, 3600) + token := New("username", "password", "rawToken", time.Now(), receivedAt, time.Hour.Milliseconds()) assert.True(t, token.receivedAt.After(time.Now().Add(-1*time.Hour))) assert.True(t, token.receivedAt.Before(time.Now().Add(1*time.Hour))) @@ -133,12 +133,12 @@ func BenchmarkNew(b *testing.B) { now := time.Now() b.ResetTimer() for i := 0; i < b.N; i++ { - New("username", "password", "rawToken", now, now, 3600) + New("username", "password", "rawToken", now, now, time.Hour.Milliseconds()) } } func BenchmarkBasicAuth(b *testing.B) { - token := New("username", "password", "rawToken", time.Now(), time.Now(), 3600) + token := New("username", "password", "rawToken", time.Now(), time.Now(), time.Hour.Milliseconds()) b.ResetTimer() for i := 0; i < b.N; i++ { token.BasicAuth() @@ -146,7 +146,7 @@ func BenchmarkBasicAuth(b *testing.B) { } func BenchmarkRawCredentials(b *testing.B) { - token := New("username", "password", "rawToken", time.Now(), time.Now(), 3600) + token := New("username", "password", "rawToken", time.Now(), time.Now(), time.Hour.Milliseconds()) b.ResetTimer() for i := 0; i < b.N; i++ { token.RawCredentials() @@ -154,7 +154,7 @@ func BenchmarkRawCredentials(b *testing.B) { } func BenchmarkExpirationOn(b *testing.B) { - token := New("username", "password", "rawToken", time.Now().Add(1*time.Hour), time.Now(), 3600) + token := New("username", "password", "rawToken", time.Now().Add(1*time.Hour), time.Now(), time.Hour.Milliseconds()) b.ResetTimer() for i := 0; i < b.N; i++ { token.ExpirationOn() @@ -162,7 +162,7 @@ func BenchmarkExpirationOn(b *testing.B) { } func BenchmarkCopyToken(b *testing.B) { - token := New("username", "password", "rawToken", time.Now(), time.Now(), 3600) + token := New("username", "password", "rawToken", time.Now(), time.Now(), time.Hour.Milliseconds()) b.ResetTimer() for i := 0; i < b.N; i++ { token.Copy() diff --git a/token_listener_test.go b/token_listener_test.go index 43ebf07..47e6961 100644 --- a/token_listener_test.go +++ b/token_listener_test.go @@ -24,7 +24,7 @@ func TestOnTokenNext(t *testing.T) { listener := tokenListenerFromCP(cp) now := time.Now() - testToken := token.New("test-user", "test-pass", "test-token", now.Add(time.Hour), now, 3600) + testToken := token.New("test-user", "test-pass", "test-token", now.Add(time.Hour), now, time.Hour.Milliseconds()) listener.OnNext(testToken)