diff --git a/internal/utils.go b/internal/utils.go index 46b3842..32b8912 100644 --- a/internal/utils.go +++ b/internal/utils.go @@ -1,16 +1,18 @@ package internal // IsClosed checks if a channel is closed. +// Returns true only if the channel is actually closed, not just if it has data available. // -// NOTE: It returns true if the channel is closed as well -// as if the channel is not empty. Used internally -// to check if the channel is closed. +// WARNING: This function will consume one value from the channel if it has pending data. +// Use with caution on channels where consuming data might cause issues. func IsClosed(ch <-chan struct{}) bool { select { - case <-ch: - return true + case _, ok := <-ch: + // If ok is false, the channel is closed + // If ok is true, the channel had data (which we just consumed) + return !ok default: + // Channel is open but has no data available + return false } - - return false } diff --git a/manager/entraid_manager.go b/manager/entraid_manager.go index 3d633d2..8accc70 100644 --- a/manager/entraid_manager.go +++ b/manager/entraid_manager.go @@ -148,6 +148,16 @@ func (e *entraidTokenManager) Start(listener TokenListener) (StopFunc, error) { e.listener = listener go func(listener TokenListener, closed <-chan struct{}) { + // Add panic recovery to prevent crashes + defer func() { + if r := recover(); r != nil { + // Attempt to notify listener of panic, but don't panic again if that fails + func() { + defer func() { _ = recover() }() + listener.OnError(fmt.Errorf("token manager goroutine panic: %v", r)) + }() + } + }() maxDelay := e.retryOptions.MaxDelay initialDelay := e.retryOptions.InitialDelay @@ -223,6 +233,7 @@ func (e *entraidTokenManager) stop() (err error) { err = fmt.Errorf("failed to stop token manager: %s", r) } }() + if e.ctxCancel != nil { e.ctxCancel() } @@ -232,7 +243,11 @@ func (e *entraidTokenManager) stop() (err error) { } e.listener = nil - close(e.closedChan) + + // Safely close the channel - only close if not already closed + if !internal.IsClosed(e.closedChan) { + close(e.closedChan) + } return nil } diff --git a/manager/entraid_manager_test.go b/manager/entraid_manager_test.go index 5cb7132..7ac8987 100644 --- a/manager/entraid_manager_test.go +++ b/manager/entraid_manager_test.go @@ -8,6 +8,8 @@ import ( "github.com/stretchr/testify/assert" ) +const testDurationDelta = float64(5 * time.Millisecond) + func TestDurationToRenewal(t *testing.T) { tests := []struct { name string @@ -236,7 +238,7 @@ func TestDurationToRenewal(t *testing.T) { } duration := manager.durationToRenewal(tt.token) - assert.InDelta(t, float64(tt.expectedDuration), float64(duration), float64(time.Millisecond), + assert.InDelta(t, float64(tt.expectedDuration), float64(duration), testDurationDelta, "%s: expected %v, got %v", tt.name, tt.expectedDuration, duration) }) } @@ -415,7 +417,7 @@ func TestDurationToRenewalMillisecondPrecision(t *testing.T) { } duration := manager.durationToRenewal(tt.token) - assert.InDelta(t, float64(tt.expectedDuration), float64(duration), float64(time.Millisecond), + assert.InDelta(t, float64(tt.expectedDuration), float64(duration), testDurationDelta, "%s: expected %v, got %v", tt.name, tt.expectedDuration, duration) }) } @@ -453,8 +455,8 @@ func TestDurationToRenewalConcurrent(t *testing.T) { if i == 0 { firstResult = result } else { - // All results should be within 10ms of each other - assert.InDelta(t, firstResult.Milliseconds(), result.Milliseconds(), 10) + // All results should be within 5ms of each other + assert.InDelta(t, firstResult.Milliseconds(), result.Milliseconds(), 5) } } } diff --git a/token/token.go b/token/token.go index 016ad2f..8e59a44 100644 --- a/token/token.go +++ b/token/token.go @@ -10,12 +10,22 @@ import ( var _ auth.Credentials = (*Token)(nil) // New creates a new token with the specified username, password, raw token, expiration time, received at time, and time to live. -// NOTE: This won't do any validation on the token, expiresOn, receivedAt, or ttl. It will simply create a new token instance. -// The caller is responsible for ensuring the token is valid. +// NOTE: The caller is responsible for ensuring the token is valid. +// If the token is invalid, the behavior is undefined. +// - if expiresOn is zero, New returns nil +// - if receivedAt is zero, it will be set to the current time and TTL will be recalculated // Expiration time and TTL are used to determine when the token should be refreshed. // TTL is in milliseconds. // receivedAt + ttl should be within a millisecond of expiresOn func New(username, password, rawToken string, expiresOn, receivedAt time.Time, ttl int64) *Token { + if expiresOn.IsZero() { + return nil + } + if receivedAt.IsZero() { + receivedAt = time.Now() + ttl = expiresOn.Sub(receivedAt).Milliseconds() + } + return &Token{ username: username, password: password, @@ -28,6 +38,10 @@ func New(username, password, rawToken string, expiresOn, receivedAt time.Time, t // Token represents parsed authentication token used to access the Redis server. // It implements the auth.Credentials interface. +// +// WARNING: Use New() to create a new token. +// Creating a token with Token{} is invalid and will undefined behavior in the TokenManager. +// The zero value of Token is not valid. type Token struct { // username is the username of the user. username string @@ -60,11 +74,6 @@ func (t *Token) RawToken() string { // ReceivedAt returns the time when the token was received. func (t *Token) ReceivedAt() time.Time { - if t.receivedAt.IsZero() { - // set it to now, recalculate ttl - t.receivedAt = time.Now() - t.ttl = t.expiresOn.Sub(t.receivedAt).Milliseconds() - } return t.receivedAt } diff --git a/token/token_test.go b/token/token_test.go index 58134f5..54b1c99 100644 --- a/token/token_test.go +++ b/token/token_test.go @@ -94,14 +94,15 @@ func TestCopyToken(t *testing.T) { assert.NotEqual(t, token.expiresOn, copiedToken.expiresOn) // copy nil - copiedToken = copyToken(nil) - assert.Nil(t, copiedToken) + nilToken := copyToken(nil) + assert.Nil(t, nilToken) // copy empty token - copiedToken = copyToken(&Token{}) - assert.NotNil(t, copiedToken) + emptyToken := copyToken(&Token{}) + assert.Nil(t, emptyToken) anotherCopy := copiedToken.Copy() anotherCopy.rawToken = "changed" assert.NotEqual(t, copiedToken, anotherCopy) + assert.NotEqual(t, copiedToken.rawToken, anotherCopy.rawToken) } func TestTokenReceivedAt(t *testing.T) { @@ -124,7 +125,7 @@ func TestTokenReceivedAt(t *testing.T) { // Check if the copied token is a new instance assert.NotNil(t, tcopiedToken) - emptyRecievedAt := &Token{} + emptyRecievedAt := New("username", "password", "rawToken", time.Now(), time.Time{}, time.Hour.Milliseconds()) assert.True(t, emptyRecievedAt.ReceivedAt().After(time.Now().Add(-1*time.Hour))) assert.True(t, emptyRecievedAt.ReceivedAt().Before(time.Now().Add(1*time.Hour))) }