Skip to content

Commit 87f7d84

Browse files
committed
add deadlock test
Kindly contributed by @bobymicroby
1 parent 217097d commit 87f7d84

File tree

4 files changed

+306
-3
lines changed

4 files changed

+306
-3
lines changed

.golangci.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
version: "2"
2+
run:
3+
tests: false
24
linters:
35
disable:
46
- depguard

.testcoverage.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@ threshold:
1515

1616
# (optional; default 0)
1717
# Minimum coverage percentage required for each package.
18-
package: 90
18+
package: 85
1919

2020
# (optional; default 0)
2121
# Minimum overall project coverage percentage required.
22-
total: 85
22+
total: 90
2323

2424
# Holds regexp rules which will override thresholds for matched files or packages
2525
# using their paths.

manager/defaults.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ func defaultRetryOptionsOr(retryOptions RetryOptions) RetryOptions {
7474
// The default token parser is used to parse the raw token and return a Token object.
7575
func defaultIdentityProviderResponseParserOr(idpResponseParser shared.IdentityProviderResponseParser) shared.IdentityProviderResponseParser {
7676
if idpResponseParser == nil {
77-
return &defaultIdentityProviderResponseParser{}
77+
return entraidIdentityProviderResponseParser
7878
}
7979
return idpResponseParser
8080
}

manager/token_manager_test.go

Lines changed: 301 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"os"
88
"reflect"
99
"runtime"
10+
"strings"
1011
"sync"
1112
"sync/atomic"
1213
"testing"
@@ -1373,3 +1374,303 @@ func BenchmarkTokenManager_durationToRenewal(b *testing.B) {
13731374
tm.durationToRenewal()
13741375
}
13751376
}
1377+
1378+
// TestConcurrentTokenManagerOperations tests concurrent operations on the TokenManager
1379+
// to verify there are no deadlocks or race conditions in the implementation.
1380+
func TestConcurrentTokenManagerOperations(t *testing.T) {
1381+
t.Parallel()
1382+
1383+
// Create a mock identity provider that returns predictable tokens
1384+
mockIdp := &concurrentMockIdentityProvider{
1385+
tokenCounter: 0,
1386+
}
1387+
1388+
// Create token manager with the mock provider
1389+
options := TokenManagerOptions{
1390+
ExpirationRefreshRatio: 0.7,
1391+
LowerRefreshBoundMs: 100,
1392+
}
1393+
tm, err := NewTokenManager(mockIdp, options)
1394+
assert.NoError(t, err)
1395+
assert.NotNil(t, tm)
1396+
1397+
// Number of concurrent operations to perform
1398+
const numConcurrentOps = 50
1399+
const numGoroutines = 1000
1400+
1401+
// Channels to track received tokens and errors
1402+
tokenCh := make(chan *token.Token, numConcurrentOps*numGoroutines)
1403+
errorCh := make(chan error, numConcurrentOps*numGoroutines)
1404+
1405+
// Channel to signal completion of all operations
1406+
doneCh := make(chan struct{})
1407+
1408+
// Track closers for cleanup
1409+
var closers sync.Map
1410+
1411+
// Start multiple goroutines that will concurrently interact with the token manager
1412+
var wg sync.WaitGroup
1413+
wg.Add(numGoroutines)
1414+
1415+
for i := 0; i < numGoroutines; i++ {
1416+
go func(routineID int) {
1417+
defer wg.Done()
1418+
1419+
for j := 0; j < numConcurrentOps; j++ {
1420+
// Create a listener for this operation
1421+
listener := &concurrentTestTokenListener{
1422+
onNextFunc: func(t *token.Token) {
1423+
select {
1424+
case tokenCh <- t:
1425+
default:
1426+
// Channel full, ignore
1427+
}
1428+
},
1429+
onErrorFunc: func(err error) {
1430+
select {
1431+
case errorCh <- err:
1432+
default:
1433+
// Channel full, ignore
1434+
}
1435+
},
1436+
}
1437+
1438+
// Choose operation based on a pattern
1439+
// Using modulo for a deterministic pattern that exercises all operations
1440+
opType := j % 3
1441+
1442+
// t.Logf("Goroutine %d, Operation %d: Performing operation type %d", routineID, j, opType)
1443+
1444+
switch opType {
1445+
case 0:
1446+
// Start the token manager with a new listener
1447+
// t.Logf("Goroutine %d, Operation %d: Attempting to start token manager", routineID, j)
1448+
closeFunc, err := tm.Start(listener)
1449+
1450+
if err != nil {
1451+
if err != ErrTokenManagerAlreadyStarted {
1452+
// t.Logf("Goroutine %d, Operation %d: Start failed with error: %v", routineID, j, err)
1453+
select {
1454+
case errorCh <- fmt.Errorf("failed to start token manager: %w", err):
1455+
default:
1456+
t.Fatalf("Goroutine %d, Operation %d: Failed to start token manager: %v", routineID, j, err)
1457+
}
1458+
}
1459+
continue
1460+
}
1461+
1462+
// t.Logf("Goroutine %d, Operation %d: Successfully started token manager", routineID, j)
1463+
// Store the closer for later cleanup
1464+
closerKey := fmt.Sprintf("closer-%d-%d", routineID, j)
1465+
closers.Store(closerKey, closeFunc)
1466+
1467+
// Simulate some work
1468+
time.Sleep(time.Duration(500-rand.Intn(400)) * time.Millisecond)
1469+
1470+
case 1:
1471+
// Get current token
1472+
//t.Logf("Goroutine %d, Operation %d: Getting token", routineID, j)
1473+
token, err := tm.GetToken(false)
1474+
if err != nil {
1475+
//t.Logf("Goroutine %d, Operation %d: GetToken failed with error: %v", routineID, j, err)
1476+
select {
1477+
case errorCh <- fmt.Errorf("failed to get token: %w", err):
1478+
default:
1479+
t.Fatalf("Goroutine %d, Operation %d: Failed to get token: %v", routineID, j, err)
1480+
}
1481+
} else if token != nil {
1482+
//t.Logf("Goroutine %d, Operation %d: Successfully got token, expires: %v", routineID, j, token.ExpirationOn())
1483+
select {
1484+
case tokenCh <- token:
1485+
default:
1486+
// Channel full, ignore
1487+
}
1488+
}
1489+
1490+
case 2:
1491+
// Close a previously created token manager listener
1492+
// This simulates multiple subscriptions being created and destroyed
1493+
//t.Logf("Goroutine %d, Operation %d: Attempting to close a token manager", routineID, j)
1494+
closedAny := false
1495+
1496+
closers.Range(func(key, value interface{}) bool {
1497+
if j%10 > 7 { // Only close some of the time based on a pattern
1498+
closedAny = true
1499+
//t.Logf("Goroutine %d, Operation %d: Closing token manager with key %v", routineID, j, key)
1500+
1501+
closeFunc := value.(CloseFunc)
1502+
if err := closeFunc(); err != nil {
1503+
if err != ErrTokenManagerAlreadyClosed {
1504+
// t.Logf("Goroutine %d, Operation %d: Close failed with error: %v", routineID, j, err)
1505+
select {
1506+
case errorCh <- fmt.Errorf("failed to close token manager: %w", err):
1507+
default:
1508+
t.Fatalf("Goroutine %d, Operation %d: Failed to close token manager: %v", routineID, j, err)
1509+
}
1510+
} else {
1511+
//t.Logf("Goroutine %d, Operation %d: TokenManager was already closed", routineID, j)
1512+
}
1513+
} else {
1514+
// t.Logf("Goroutine %d, Operation %d: Successfully closed token manager", routineID, j)
1515+
}
1516+
1517+
closers.Delete(key)
1518+
return false // stop after finding one to close
1519+
}
1520+
return true
1521+
})
1522+
1523+
if !closedAny {
1524+
//t.Logf("Goroutine %d, Operation %d: No token manager to close or condition not met", routineID, j)
1525+
}
1526+
}
1527+
}
1528+
}(i)
1529+
}
1530+
1531+
// Wait for all operations to complete or timeout
1532+
go func() {
1533+
wg.Wait()
1534+
close(doneCh)
1535+
}()
1536+
1537+
// Use a timeout to detect deadlocks
1538+
select {
1539+
case <-doneCh:
1540+
// All operations completed successfully
1541+
t.Log("All concurrent operations completed successfully")
1542+
case <-time.After(30 * time.Second):
1543+
t.Fatal("test timed out, possible deadlock detected")
1544+
}
1545+
1546+
// Count operations by type
1547+
var startCount, getTokenCount, closeCount int32
1548+
1549+
// Collect all ops from goroutines
1550+
for i := 0; i < numGoroutines; i++ {
1551+
for j := 0; j < numConcurrentOps; j++ {
1552+
opType := j % 3
1553+
switch opType {
1554+
case 0:
1555+
atomic.AddInt32(&startCount, 1)
1556+
case 1:
1557+
atomic.AddInt32(&getTokenCount, 1)
1558+
case 2:
1559+
atomic.AddInt32(&closeCount, 1)
1560+
}
1561+
}
1562+
}
1563+
1564+
// Clean up any remaining closers
1565+
closers.Range(func(key, value interface{}) bool {
1566+
closeFunc := value.(CloseFunc)
1567+
_ = closeFunc() // Ignore errors during cleanup
1568+
return true
1569+
})
1570+
1571+
// Close channels to avoid goroutine leaks
1572+
close(tokenCh)
1573+
close(errorCh)
1574+
1575+
// Count tokens and check their validity
1576+
var tokens []*token.Token
1577+
for t := range tokenCh {
1578+
tokens = append(tokens, t)
1579+
}
1580+
1581+
// Collect and categorize errors
1582+
var startErrors, getTokenErrors, closeErrors, otherErrors []error
1583+
for err := range errorCh {
1584+
errStr := err.Error()
1585+
if strings.Contains(errStr, "failed to start token manager") {
1586+
startErrors = append(startErrors, err)
1587+
} else if strings.Contains(errStr, "failed to get token") {
1588+
getTokenErrors = append(getTokenErrors, err)
1589+
} else if strings.Contains(errStr, "failed to close token manager") {
1590+
closeErrors = append(closeErrors, err)
1591+
} else {
1592+
otherErrors = append(otherErrors, err)
1593+
t.Fatalf("Unexpected error during concurrent operations: %v", err)
1594+
}
1595+
}
1596+
1597+
totalOps := startCount + getTokenCount + closeCount
1598+
expectedOps := int32(numGoroutines * numConcurrentOps)
1599+
1600+
// Report operation counts
1601+
t.Logf("Concurrent test summary:")
1602+
t.Logf("- Total operations executed: %d (expected: %d)", totalOps, expectedOps)
1603+
t.Logf("- Start operations: %d (with %d errors)", startCount, len(startErrors))
1604+
t.Logf("- GetToken operations: %d (with %d errors, %d successful)",
1605+
getTokenCount, len(getTokenErrors), len(tokens))
1606+
t.Logf("- Close operations: %d (with %d errors)", closeCount, len(closeErrors))
1607+
1608+
// Some errors are expected due to concurrent operations
1609+
// but we should have received tokens successfully
1610+
assert.Equal(t, expectedOps, totalOps, "All operations should be accounted for")
1611+
assert.True(t, len(tokens) > 0, "Should have received tokens")
1612+
1613+
// Verify the token manager still works after all the concurrent operations
1614+
finalListener := &concurrentTestTokenListener{
1615+
onNextFunc: func(t *token.Token) {
1616+
// Just verify we get a token - don't use assert within this callback
1617+
if t == nil {
1618+
panic("Final token should not be nil")
1619+
}
1620+
},
1621+
onErrorFunc: func(err error) {
1622+
t.Errorf("Unexpected error in final listener: %v", err)
1623+
},
1624+
}
1625+
1626+
closeFunc, err := tm.Start(finalListener)
1627+
if err != nil && err != ErrTokenManagerAlreadyStarted {
1628+
t.Fatalf("Failed to start token manager after concurrent operations: %v", err)
1629+
}
1630+
if closeFunc != nil {
1631+
defer closeFunc()
1632+
}
1633+
1634+
// Get token one more time to verify everything still works
1635+
finalToken, err := tm.GetToken(true)
1636+
assert.NoError(t, err, "Should be able to get token after concurrent operations")
1637+
assert.NotNil(t, finalToken, "Final token should not be nil")
1638+
}
1639+
1640+
// concurrentTestTokenListener is a test implementation of TokenListener for concurrent tests
1641+
type concurrentTestTokenListener struct {
1642+
onNextFunc func(*token.Token)
1643+
onErrorFunc func(error)
1644+
}
1645+
1646+
func (l *concurrentTestTokenListener) OnTokenNext(t *token.Token) {
1647+
if l.onNextFunc != nil {
1648+
l.onNextFunc(t)
1649+
}
1650+
}
1651+
1652+
func (l *concurrentTestTokenListener) OnTokenError(err error) {
1653+
if l.onErrorFunc != nil {
1654+
l.onErrorFunc(err)
1655+
}
1656+
}
1657+
1658+
// concurrentMockIdentityProvider is a mock implementation of shared.IdentityProvider for concurrent tests
1659+
type concurrentMockIdentityProvider struct {
1660+
tokenCounter int
1661+
mutex sync.Mutex
1662+
}
1663+
1664+
func (m *concurrentMockIdentityProvider) RequestToken() (shared.IdentityProviderResponse, error) {
1665+
m.mutex.Lock()
1666+
defer m.mutex.Unlock()
1667+
1668+
m.tokenCounter++
1669+
1670+
// Use the existing test JWT token which is already properly formatted
1671+
resp, err := shared.NewIDPResponse(shared.ResponseTypeRawToken, testJWTToken)
1672+
if err != nil {
1673+
return nil, err
1674+
}
1675+
return resp, nil
1676+
}

0 commit comments

Comments
 (0)