Skip to content

Commit f387795

Browse files
authored
refactor(manager): small refactors around the manager and token logic (#10)
* fix(manager): Improve channel closure handling - Updated IsClosed function to accurately check if a channel is closed without consuming data unless necessary. - Safely close the closedChan only if it is not already closed to avoid potential panics. * fix(manager): Add panic recovery in token manager goroutine - Implemented panic recovery in the Start method of entraidTokenManager to prevent crashes and ensure listener is notified of errors. * fix(token): Enhance token creation logic and documentation - Updated the New function to return nil if expiresOn is zero to prevent panic. - Added logic to set receivedAt to the current time and recalculate TTL if receivedAt is zero. - Improved documentation to clarify the responsibilities of the caller regarding token validity and behavior when parameters are zero. * chore(token): remove some unnecessary comments * test(manager): change test delta to 5ms
1 parent ab35b6a commit f387795

File tree

5 files changed

+53
-24
lines changed

5 files changed

+53
-24
lines changed

internal/utils.go

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
package internal
22

33
// IsClosed checks if a channel is closed.
4+
// Returns true only if the channel is actually closed, not just if it has data available.
45
//
5-
// NOTE: It returns true if the channel is closed as well
6-
// as if the channel is not empty. Used internally
7-
// to check if the channel is closed.
6+
// WARNING: This function will consume one value from the channel if it has pending data.
7+
// Use with caution on channels where consuming data might cause issues.
88
func IsClosed(ch <-chan struct{}) bool {
99
select {
10-
case <-ch:
11-
return true
10+
case _, ok := <-ch:
11+
// If ok is false, the channel is closed
12+
// If ok is true, the channel had data (which we just consumed)
13+
return !ok
1214
default:
15+
// Channel is open but has no data available
16+
return false
1317
}
14-
15-
return false
1618
}

manager/entraid_manager.go

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,16 @@ func (e *entraidTokenManager) Start(listener TokenListener) (StopFunc, error) {
148148
e.listener = listener
149149

150150
go func(listener TokenListener, closed <-chan struct{}) {
151+
// Add panic recovery to prevent crashes
152+
defer func() {
153+
if r := recover(); r != nil {
154+
// Attempt to notify listener of panic, but don't panic again if that fails
155+
func() {
156+
defer func() { _ = recover() }()
157+
listener.OnError(fmt.Errorf("token manager goroutine panic: %v", r))
158+
}()
159+
}
160+
}()
151161
maxDelay := e.retryOptions.MaxDelay
152162
initialDelay := e.retryOptions.InitialDelay
153163

@@ -223,6 +233,7 @@ func (e *entraidTokenManager) stop() (err error) {
223233
err = fmt.Errorf("failed to stop token manager: %s", r)
224234
}
225235
}()
236+
226237
if e.ctxCancel != nil {
227238
e.ctxCancel()
228239
}
@@ -232,7 +243,11 @@ func (e *entraidTokenManager) stop() (err error) {
232243
}
233244

234245
e.listener = nil
235-
close(e.closedChan)
246+
247+
// Safely close the channel - only close if not already closed
248+
if !internal.IsClosed(e.closedChan) {
249+
close(e.closedChan)
250+
}
236251

237252
return nil
238253
}

manager/entraid_manager_test.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ import (
88
"github.com/stretchr/testify/assert"
99
)
1010

11+
const testDurationDelta = float64(5 * time.Millisecond)
12+
1113
func TestDurationToRenewal(t *testing.T) {
1214
tests := []struct {
1315
name string
@@ -236,7 +238,7 @@ func TestDurationToRenewal(t *testing.T) {
236238
}
237239

238240
duration := manager.durationToRenewal(tt.token)
239-
assert.InDelta(t, float64(tt.expectedDuration), float64(duration), float64(time.Millisecond),
241+
assert.InDelta(t, float64(tt.expectedDuration), float64(duration), testDurationDelta,
240242
"%s: expected %v, got %v", tt.name, tt.expectedDuration, duration)
241243
})
242244
}
@@ -415,7 +417,7 @@ func TestDurationToRenewalMillisecondPrecision(t *testing.T) {
415417
}
416418

417419
duration := manager.durationToRenewal(tt.token)
418-
assert.InDelta(t, float64(tt.expectedDuration), float64(duration), float64(time.Millisecond),
420+
assert.InDelta(t, float64(tt.expectedDuration), float64(duration), testDurationDelta,
419421
"%s: expected %v, got %v", tt.name, tt.expectedDuration, duration)
420422
})
421423
}
@@ -453,8 +455,8 @@ func TestDurationToRenewalConcurrent(t *testing.T) {
453455
if i == 0 {
454456
firstResult = result
455457
} else {
456-
// All results should be within 10ms of each other
457-
assert.InDelta(t, firstResult.Milliseconds(), result.Milliseconds(), 10)
458+
// All results should be within 5ms of each other
459+
assert.InDelta(t, firstResult.Milliseconds(), result.Milliseconds(), 5)
458460
}
459461
}
460462
}

token/token.go

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,22 @@ import (
1010
var _ auth.Credentials = (*Token)(nil)
1111

1212
// New creates a new token with the specified username, password, raw token, expiration time, received at time, and time to live.
13-
// NOTE: This won't do any validation on the token, expiresOn, receivedAt, or ttl. It will simply create a new token instance.
14-
// The caller is responsible for ensuring the token is valid.
13+
// NOTE: The caller is responsible for ensuring the token is valid.
14+
// If the token is invalid, the behavior is undefined.
15+
// - if expiresOn is zero, New returns nil
16+
// - if receivedAt is zero, it will be set to the current time and TTL will be recalculated
1517
// Expiration time and TTL are used to determine when the token should be refreshed.
1618
// TTL is in milliseconds.
1719
// receivedAt + ttl should be within a millisecond of expiresOn
1820
func New(username, password, rawToken string, expiresOn, receivedAt time.Time, ttl int64) *Token {
21+
if expiresOn.IsZero() {
22+
return nil
23+
}
24+
if receivedAt.IsZero() {
25+
receivedAt = time.Now()
26+
ttl = expiresOn.Sub(receivedAt).Milliseconds()
27+
}
28+
1929
return &Token{
2030
username: username,
2131
password: password,
@@ -28,6 +38,10 @@ func New(username, password, rawToken string, expiresOn, receivedAt time.Time, t
2838

2939
// Token represents parsed authentication token used to access the Redis server.
3040
// It implements the auth.Credentials interface.
41+
//
42+
// WARNING: Use New() to create a new token.
43+
// Creating a token with Token{} is invalid and will undefined behavior in the TokenManager.
44+
// The zero value of Token is not valid.
3145
type Token struct {
3246
// username is the username of the user.
3347
username string
@@ -60,11 +74,6 @@ func (t *Token) RawToken() string {
6074

6175
// ReceivedAt returns the time when the token was received.
6276
func (t *Token) ReceivedAt() time.Time {
63-
if t.receivedAt.IsZero() {
64-
// set it to now, recalculate ttl
65-
t.receivedAt = time.Now()
66-
t.ttl = t.expiresOn.Sub(t.receivedAt).Milliseconds()
67-
}
6877
return t.receivedAt
6978
}
7079

token/token_test.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,14 +94,15 @@ func TestCopyToken(t *testing.T) {
9494
assert.NotEqual(t, token.expiresOn, copiedToken.expiresOn)
9595

9696
// copy nil
97-
copiedToken = copyToken(nil)
98-
assert.Nil(t, copiedToken)
97+
nilToken := copyToken(nil)
98+
assert.Nil(t, nilToken)
9999
// copy empty token
100-
copiedToken = copyToken(&Token{})
101-
assert.NotNil(t, copiedToken)
100+
emptyToken := copyToken(&Token{})
101+
assert.Nil(t, emptyToken)
102102
anotherCopy := copiedToken.Copy()
103103
anotherCopy.rawToken = "changed"
104104
assert.NotEqual(t, copiedToken, anotherCopy)
105+
assert.NotEqual(t, copiedToken.rawToken, anotherCopy.rawToken)
105106
}
106107

107108
func TestTokenReceivedAt(t *testing.T) {
@@ -124,7 +125,7 @@ func TestTokenReceivedAt(t *testing.T) {
124125
// Check if the copied token is a new instance
125126
assert.NotNil(t, tcopiedToken)
126127

127-
emptyRecievedAt := &Token{}
128+
emptyRecievedAt := New("username", "password", "rawToken", time.Now(), time.Time{}, time.Hour.Milliseconds())
128129
assert.True(t, emptyRecievedAt.ReceivedAt().After(time.Now().Add(-1*time.Hour)))
129130
assert.True(t, emptyRecievedAt.ReceivedAt().Before(time.Now().Add(1*time.Hour)))
130131
}

0 commit comments

Comments
 (0)