|
7 | 7 | "os"
|
8 | 8 | "reflect"
|
9 | 9 | "runtime"
|
| 10 | + "strings" |
10 | 11 | "sync"
|
11 | 12 | "sync/atomic"
|
12 | 13 | "testing"
|
@@ -1373,3 +1374,303 @@ func BenchmarkTokenManager_durationToRenewal(b *testing.B) {
|
1373 | 1374 | tm.durationToRenewal()
|
1374 | 1375 | }
|
1375 | 1376 | }
|
| 1377 | + |
| 1378 | +// TestConcurrentTokenManagerOperations tests concurrent operations on the TokenManager |
| 1379 | +// to verify there are no deadlocks or race conditions in the implementation. |
| 1380 | +func TestConcurrentTokenManagerOperations(t *testing.T) { |
| 1381 | + t.Parallel() |
| 1382 | + |
| 1383 | + // Create a mock identity provider that returns predictable tokens |
| 1384 | + mockIdp := &concurrentMockIdentityProvider{ |
| 1385 | + tokenCounter: 0, |
| 1386 | + } |
| 1387 | + |
| 1388 | + // Create token manager with the mock provider |
| 1389 | + options := TokenManagerOptions{ |
| 1390 | + ExpirationRefreshRatio: 0.7, |
| 1391 | + LowerRefreshBoundMs: 100, |
| 1392 | + } |
| 1393 | + tm, err := NewTokenManager(mockIdp, options) |
| 1394 | + assert.NoError(t, err) |
| 1395 | + assert.NotNil(t, tm) |
| 1396 | + |
| 1397 | + // Number of concurrent operations to perform |
| 1398 | + const numConcurrentOps = 50 |
| 1399 | + const numGoroutines = 1000 |
| 1400 | + |
| 1401 | + // Channels to track received tokens and errors |
| 1402 | + tokenCh := make(chan *token.Token, numConcurrentOps*numGoroutines) |
| 1403 | + errorCh := make(chan error, numConcurrentOps*numGoroutines) |
| 1404 | + |
| 1405 | + // Channel to signal completion of all operations |
| 1406 | + doneCh := make(chan struct{}) |
| 1407 | + |
| 1408 | + // Track closers for cleanup |
| 1409 | + var closers sync.Map |
| 1410 | + |
| 1411 | + // Start multiple goroutines that will concurrently interact with the token manager |
| 1412 | + var wg sync.WaitGroup |
| 1413 | + wg.Add(numGoroutines) |
| 1414 | + |
| 1415 | + for i := 0; i < numGoroutines; i++ { |
| 1416 | + go func(routineID int) { |
| 1417 | + defer wg.Done() |
| 1418 | + |
| 1419 | + for j := 0; j < numConcurrentOps; j++ { |
| 1420 | + // Create a listener for this operation |
| 1421 | + listener := &concurrentTestTokenListener{ |
| 1422 | + onNextFunc: func(t *token.Token) { |
| 1423 | + select { |
| 1424 | + case tokenCh <- t: |
| 1425 | + default: |
| 1426 | + // Channel full, ignore |
| 1427 | + } |
| 1428 | + }, |
| 1429 | + onErrorFunc: func(err error) { |
| 1430 | + select { |
| 1431 | + case errorCh <- err: |
| 1432 | + default: |
| 1433 | + // Channel full, ignore |
| 1434 | + } |
| 1435 | + }, |
| 1436 | + } |
| 1437 | + |
| 1438 | + // Choose operation based on a pattern |
| 1439 | + // Using modulo for a deterministic pattern that exercises all operations |
| 1440 | + opType := j % 3 |
| 1441 | + |
| 1442 | + // t.Logf("Goroutine %d, Operation %d: Performing operation type %d", routineID, j, opType) |
| 1443 | + |
| 1444 | + switch opType { |
| 1445 | + case 0: |
| 1446 | + // Start the token manager with a new listener |
| 1447 | + // t.Logf("Goroutine %d, Operation %d: Attempting to start token manager", routineID, j) |
| 1448 | + closeFunc, err := tm.Start(listener) |
| 1449 | + |
| 1450 | + if err != nil { |
| 1451 | + if err != ErrTokenManagerAlreadyStarted { |
| 1452 | + // t.Logf("Goroutine %d, Operation %d: Start failed with error: %v", routineID, j, err) |
| 1453 | + select { |
| 1454 | + case errorCh <- fmt.Errorf("failed to start token manager: %w", err): |
| 1455 | + default: |
| 1456 | + t.Fatalf("Goroutine %d, Operation %d: Failed to start token manager: %v", routineID, j, err) |
| 1457 | + } |
| 1458 | + } |
| 1459 | + continue |
| 1460 | + } |
| 1461 | + |
| 1462 | + // t.Logf("Goroutine %d, Operation %d: Successfully started token manager", routineID, j) |
| 1463 | + // Store the closer for later cleanup |
| 1464 | + closerKey := fmt.Sprintf("closer-%d-%d", routineID, j) |
| 1465 | + closers.Store(closerKey, closeFunc) |
| 1466 | + |
| 1467 | + // Simulate some work |
| 1468 | + time.Sleep(time.Duration(500-rand.Intn(400)) * time.Millisecond) |
| 1469 | + |
| 1470 | + case 1: |
| 1471 | + // Get current token |
| 1472 | + //t.Logf("Goroutine %d, Operation %d: Getting token", routineID, j) |
| 1473 | + token, err := tm.GetToken(false) |
| 1474 | + if err != nil { |
| 1475 | + //t.Logf("Goroutine %d, Operation %d: GetToken failed with error: %v", routineID, j, err) |
| 1476 | + select { |
| 1477 | + case errorCh <- fmt.Errorf("failed to get token: %w", err): |
| 1478 | + default: |
| 1479 | + t.Fatalf("Goroutine %d, Operation %d: Failed to get token: %v", routineID, j, err) |
| 1480 | + } |
| 1481 | + } else if token != nil { |
| 1482 | + //t.Logf("Goroutine %d, Operation %d: Successfully got token, expires: %v", routineID, j, token.ExpirationOn()) |
| 1483 | + select { |
| 1484 | + case tokenCh <- token: |
| 1485 | + default: |
| 1486 | + // Channel full, ignore |
| 1487 | + } |
| 1488 | + } |
| 1489 | + |
| 1490 | + case 2: |
| 1491 | + // Close a previously created token manager listener |
| 1492 | + // This simulates multiple subscriptions being created and destroyed |
| 1493 | + //t.Logf("Goroutine %d, Operation %d: Attempting to close a token manager", routineID, j) |
| 1494 | + closedAny := false |
| 1495 | + |
| 1496 | + closers.Range(func(key, value interface{}) bool { |
| 1497 | + if j%10 > 7 { // Only close some of the time based on a pattern |
| 1498 | + closedAny = true |
| 1499 | + //t.Logf("Goroutine %d, Operation %d: Closing token manager with key %v", routineID, j, key) |
| 1500 | + |
| 1501 | + closeFunc := value.(CloseFunc) |
| 1502 | + if err := closeFunc(); err != nil { |
| 1503 | + if err != ErrTokenManagerAlreadyClosed { |
| 1504 | + // t.Logf("Goroutine %d, Operation %d: Close failed with error: %v", routineID, j, err) |
| 1505 | + select { |
| 1506 | + case errorCh <- fmt.Errorf("failed to close token manager: %w", err): |
| 1507 | + default: |
| 1508 | + t.Fatalf("Goroutine %d, Operation %d: Failed to close token manager: %v", routineID, j, err) |
| 1509 | + } |
| 1510 | + } else { |
| 1511 | + //t.Logf("Goroutine %d, Operation %d: TokenManager was already closed", routineID, j) |
| 1512 | + } |
| 1513 | + } else { |
| 1514 | + // t.Logf("Goroutine %d, Operation %d: Successfully closed token manager", routineID, j) |
| 1515 | + } |
| 1516 | + |
| 1517 | + closers.Delete(key) |
| 1518 | + return false // stop after finding one to close |
| 1519 | + } |
| 1520 | + return true |
| 1521 | + }) |
| 1522 | + |
| 1523 | + if !closedAny { |
| 1524 | + //t.Logf("Goroutine %d, Operation %d: No token manager to close or condition not met", routineID, j) |
| 1525 | + } |
| 1526 | + } |
| 1527 | + } |
| 1528 | + }(i) |
| 1529 | + } |
| 1530 | + |
| 1531 | + // Wait for all operations to complete or timeout |
| 1532 | + go func() { |
| 1533 | + wg.Wait() |
| 1534 | + close(doneCh) |
| 1535 | + }() |
| 1536 | + |
| 1537 | + // Use a timeout to detect deadlocks |
| 1538 | + select { |
| 1539 | + case <-doneCh: |
| 1540 | + // All operations completed successfully |
| 1541 | + t.Log("All concurrent operations completed successfully") |
| 1542 | + case <-time.After(30 * time.Second): |
| 1543 | + t.Fatal("test timed out, possible deadlock detected") |
| 1544 | + } |
| 1545 | + |
| 1546 | + // Count operations by type |
| 1547 | + var startCount, getTokenCount, closeCount int32 |
| 1548 | + |
| 1549 | + // Collect all ops from goroutines |
| 1550 | + for i := 0; i < numGoroutines; i++ { |
| 1551 | + for j := 0; j < numConcurrentOps; j++ { |
| 1552 | + opType := j % 3 |
| 1553 | + switch opType { |
| 1554 | + case 0: |
| 1555 | + atomic.AddInt32(&startCount, 1) |
| 1556 | + case 1: |
| 1557 | + atomic.AddInt32(&getTokenCount, 1) |
| 1558 | + case 2: |
| 1559 | + atomic.AddInt32(&closeCount, 1) |
| 1560 | + } |
| 1561 | + } |
| 1562 | + } |
| 1563 | + |
| 1564 | + // Clean up any remaining closers |
| 1565 | + closers.Range(func(key, value interface{}) bool { |
| 1566 | + closeFunc := value.(CloseFunc) |
| 1567 | + _ = closeFunc() // Ignore errors during cleanup |
| 1568 | + return true |
| 1569 | + }) |
| 1570 | + |
| 1571 | + // Close channels to avoid goroutine leaks |
| 1572 | + close(tokenCh) |
| 1573 | + close(errorCh) |
| 1574 | + |
| 1575 | + // Count tokens and check their validity |
| 1576 | + var tokens []*token.Token |
| 1577 | + for t := range tokenCh { |
| 1578 | + tokens = append(tokens, t) |
| 1579 | + } |
| 1580 | + |
| 1581 | + // Collect and categorize errors |
| 1582 | + var startErrors, getTokenErrors, closeErrors, otherErrors []error |
| 1583 | + for err := range errorCh { |
| 1584 | + errStr := err.Error() |
| 1585 | + if strings.Contains(errStr, "failed to start token manager") { |
| 1586 | + startErrors = append(startErrors, err) |
| 1587 | + } else if strings.Contains(errStr, "failed to get token") { |
| 1588 | + getTokenErrors = append(getTokenErrors, err) |
| 1589 | + } else if strings.Contains(errStr, "failed to close token manager") { |
| 1590 | + closeErrors = append(closeErrors, err) |
| 1591 | + } else { |
| 1592 | + otherErrors = append(otherErrors, err) |
| 1593 | + t.Fatalf("Unexpected error during concurrent operations: %v", err) |
| 1594 | + } |
| 1595 | + } |
| 1596 | + |
| 1597 | + totalOps := startCount + getTokenCount + closeCount |
| 1598 | + expectedOps := int32(numGoroutines * numConcurrentOps) |
| 1599 | + |
| 1600 | + // Report operation counts |
| 1601 | + t.Logf("Concurrent test summary:") |
| 1602 | + t.Logf("- Total operations executed: %d (expected: %d)", totalOps, expectedOps) |
| 1603 | + t.Logf("- Start operations: %d (with %d errors)", startCount, len(startErrors)) |
| 1604 | + t.Logf("- GetToken operations: %d (with %d errors, %d successful)", |
| 1605 | + getTokenCount, len(getTokenErrors), len(tokens)) |
| 1606 | + t.Logf("- Close operations: %d (with %d errors)", closeCount, len(closeErrors)) |
| 1607 | + |
| 1608 | + // Some errors are expected due to concurrent operations |
| 1609 | + // but we should have received tokens successfully |
| 1610 | + assert.Equal(t, expectedOps, totalOps, "All operations should be accounted for") |
| 1611 | + assert.True(t, len(tokens) > 0, "Should have received tokens") |
| 1612 | + |
| 1613 | + // Verify the token manager still works after all the concurrent operations |
| 1614 | + finalListener := &concurrentTestTokenListener{ |
| 1615 | + onNextFunc: func(t *token.Token) { |
| 1616 | + // Just verify we get a token - don't use assert within this callback |
| 1617 | + if t == nil { |
| 1618 | + panic("Final token should not be nil") |
| 1619 | + } |
| 1620 | + }, |
| 1621 | + onErrorFunc: func(err error) { |
| 1622 | + t.Errorf("Unexpected error in final listener: %v", err) |
| 1623 | + }, |
| 1624 | + } |
| 1625 | + |
| 1626 | + closeFunc, err := tm.Start(finalListener) |
| 1627 | + if err != nil && err != ErrTokenManagerAlreadyStarted { |
| 1628 | + t.Fatalf("Failed to start token manager after concurrent operations: %v", err) |
| 1629 | + } |
| 1630 | + if closeFunc != nil { |
| 1631 | + defer closeFunc() |
| 1632 | + } |
| 1633 | + |
| 1634 | + // Get token one more time to verify everything still works |
| 1635 | + finalToken, err := tm.GetToken(true) |
| 1636 | + assert.NoError(t, err, "Should be able to get token after concurrent operations") |
| 1637 | + assert.NotNil(t, finalToken, "Final token should not be nil") |
| 1638 | +} |
| 1639 | + |
| 1640 | +// concurrentTestTokenListener is a test implementation of TokenListener for concurrent tests |
| 1641 | +type concurrentTestTokenListener struct { |
| 1642 | + onNextFunc func(*token.Token) |
| 1643 | + onErrorFunc func(error) |
| 1644 | +} |
| 1645 | + |
| 1646 | +func (l *concurrentTestTokenListener) OnTokenNext(t *token.Token) { |
| 1647 | + if l.onNextFunc != nil { |
| 1648 | + l.onNextFunc(t) |
| 1649 | + } |
| 1650 | +} |
| 1651 | + |
| 1652 | +func (l *concurrentTestTokenListener) OnTokenError(err error) { |
| 1653 | + if l.onErrorFunc != nil { |
| 1654 | + l.onErrorFunc(err) |
| 1655 | + } |
| 1656 | +} |
| 1657 | + |
| 1658 | +// concurrentMockIdentityProvider is a mock implementation of shared.IdentityProvider for concurrent tests |
| 1659 | +type concurrentMockIdentityProvider struct { |
| 1660 | + tokenCounter int |
| 1661 | + mutex sync.Mutex |
| 1662 | +} |
| 1663 | + |
| 1664 | +func (m *concurrentMockIdentityProvider) RequestToken() (shared.IdentityProviderResponse, error) { |
| 1665 | + m.mutex.Lock() |
| 1666 | + defer m.mutex.Unlock() |
| 1667 | + |
| 1668 | + m.tokenCounter++ |
| 1669 | + |
| 1670 | + // Use the existing test JWT token which is already properly formatted |
| 1671 | + resp, err := shared.NewIDPResponse(shared.ResponseTypeRawToken, testJWTToken) |
| 1672 | + if err != nil { |
| 1673 | + return nil, err |
| 1674 | + } |
| 1675 | + return resp, nil |
| 1676 | +} |
0 commit comments