@@ -598,6 +598,171 @@ func TestEntraidTokenManager_GetToken(t *testing.T) {
598
598
assert .NotNil (t , token1 )
599
599
})
600
600
601
+ t .Run ("GetToken with cached token" , func (t * testing.T ) {
602
+ t .Parallel ()
603
+ idp := & mockIdentityProvider {}
604
+ mParser := & mockIdentityProviderResponseParser {}
605
+ tokenManager , err := NewTokenManager (idp ,
606
+ TokenManagerOptions {
607
+ IdentityProviderResponseParser : mParser ,
608
+ },
609
+ )
610
+ assert .NoError (t , err )
611
+ assert .NotNil (t , tokenManager )
612
+ tm , ok := tokenManager .(* entraidTokenManager )
613
+ assert .True (t , ok )
614
+
615
+ // First setup the manager with a token
616
+ rawResponse := & authResult {
617
+ ResultType : shared .ResponseTypeRawToken ,
618
+ RawTokenVal : "test" ,
619
+ }
620
+
621
+ idp .On ("RequestToken" , mock .Anything ).Return (rawResponse , nil )
622
+ mParser .On ("ParseResponse" , rawResponse ).Return (testTokenValid , nil )
623
+
624
+ // Get the token once to cache it
625
+ token1 , err := tokenManager .GetToken (false )
626
+ assert .NoError (t , err )
627
+ assert .NotNil (t , token1 )
628
+
629
+ // Change the mock to return a different token to verify caching
630
+ differentToken := token .New (
631
+ "different" ,
632
+ "different" ,
633
+ "different" ,
634
+ time .Now ().Add (time .Hour ),
635
+ time .Now (),
636
+ time .Hour .Milliseconds (),
637
+ )
638
+ mParser = & mockIdentityProviderResponseParser {}
639
+ mParser .On ("ParseResponse" , rawResponse ).Return (differentToken , nil )
640
+ tm .identityProviderResponseParser = mParser
641
+
642
+ // Get the token again, should return the cached token
643
+ token2 , err := tokenManager .GetToken (false )
644
+ assert .NoError (t , err )
645
+ assert .NotNil (t , token2 )
646
+ assert .Equal (t , token1 , token2 )
647
+
648
+ // Verify that RequestToken was not called again
649
+ idp .AssertNumberOfCalls (t , "RequestToken" , 1 )
650
+ })
651
+
652
+ t .Run ("GetToken with force refresh" , func (t * testing.T ) {
653
+ t .Parallel ()
654
+ idp := & mockIdentityProvider {}
655
+ mParser := & mockIdentityProviderResponseParser {}
656
+ tokenManager , err := NewTokenManager (idp ,
657
+ TokenManagerOptions {
658
+ IdentityProviderResponseParser : mParser ,
659
+ },
660
+ )
661
+ assert .NoError (t , err )
662
+ assert .NotNil (t , tokenManager )
663
+ tm , ok := tokenManager .(* entraidTokenManager )
664
+ assert .True (t , ok )
665
+
666
+ // First setup the manager with a token
667
+ rawResponse := & authResult {
668
+ ResultType : shared .ResponseTypeRawToken ,
669
+ RawTokenVal : "test" ,
670
+ }
671
+
672
+ idp .On ("RequestToken" , mock .Anything ).Return (rawResponse , nil )
673
+ mParser .On ("ParseResponse" , rawResponse ).Return (testTokenValid , nil )
674
+
675
+ // Get the token once to cache it
676
+ token1 , err := tokenManager .GetToken (false )
677
+ assert .NoError (t , err )
678
+ assert .NotNil (t , token1 )
679
+
680
+ // Change the mock to return a different token
681
+ differentToken := token .New (
682
+ "different" ,
683
+ "different" ,
684
+ "different" ,
685
+ time .Now ().Add (time .Hour ),
686
+ time .Now (),
687
+ time .Hour .Milliseconds (),
688
+ )
689
+ mParser = & mockIdentityProviderResponseParser {}
690
+ mParser .On ("ParseResponse" , rawResponse ).Return (differentToken , nil )
691
+ tm .identityProviderResponseParser = mParser
692
+
693
+ // Get the token with force refresh, should get the new token
694
+ token2 , err := tokenManager .GetToken (true )
695
+ assert .NoError (t , err )
696
+ assert .NotNil (t , token2 )
697
+ assert .Equal (t , differentToken , token2 )
698
+
699
+ // Verify that RequestToken was called again
700
+ idp .AssertNumberOfCalls (t , "RequestToken" , 2 )
701
+ })
702
+
703
+ t .Run ("GetToken with valid cached token and positive duration" , func (t * testing.T ) {
704
+ t .Parallel ()
705
+ idp := & mockIdentityProvider {}
706
+ mParser := & mockIdentityProviderResponseParser {}
707
+ tokenManager , err := NewTokenManager (idp ,
708
+ TokenManagerOptions {
709
+ IdentityProviderResponseParser : mParser ,
710
+ ExpirationRefreshRatio : 0.75 ,
711
+ LowerRefreshBound : time .Hour ,
712
+ },
713
+ )
714
+ assert .NoError (t , err )
715
+ assert .NotNil (t , tokenManager )
716
+ tm , ok := tokenManager .(* entraidTokenManager )
717
+ assert .True (t , ok )
718
+
719
+ // Create a token that will have a positive duration
720
+ validToken := token .New (
721
+ "username" ,
722
+ "password" ,
723
+ "rawToken" ,
724
+ time .Now ().Add (2 * time .Hour ), // Expires in 2 hours
725
+ time .Now (),
726
+ (2 * time .Hour ).Milliseconds (),
727
+ )
728
+
729
+ // First get a token to cache it
730
+ rawResponse := & authResult {
731
+ ResultType : shared .ResponseTypeRawToken ,
732
+ RawTokenVal : "test" ,
733
+ }
734
+
735
+ idp .On ("RequestToken" , mock .Anything ).Return (rawResponse , nil )
736
+ mParser .On ("ParseResponse" , rawResponse ).Return (validToken , nil )
737
+
738
+ // Get the token once to cache it
739
+ token1 , err := tokenManager .GetToken (false )
740
+ assert .NoError (t , err )
741
+ assert .NotNil (t , token1 )
742
+
743
+ // Change the mock to return a different token
744
+ differentToken := token .New (
745
+ "different" ,
746
+ "different" ,
747
+ "different" ,
748
+ time .Now ().Add (time .Hour ),
749
+ time .Now (),
750
+ time .Hour .Milliseconds (),
751
+ )
752
+ mParser = & mockIdentityProviderResponseParser {}
753
+ mParser .On ("ParseResponse" , rawResponse ).Return (differentToken , nil )
754
+ tm .identityProviderResponseParser = mParser
755
+
756
+ // Get the token again without force refresh
757
+ token2 , err := tokenManager .GetToken (false )
758
+ assert .NoError (t , err )
759
+ assert .NotNil (t , token2 )
760
+ assert .Equal (t , token1 , token2 ) // Should return the cached token
761
+
762
+ // Verify that RequestToken was not called again
763
+ idp .AssertNumberOfCalls (t , "RequestToken" , 1 )
764
+ })
765
+
601
766
t .Run ("GetToken with parse error" , func (t * testing.T ) {
602
767
t .Parallel ()
603
768
idp := & mockIdentityProvider {}
@@ -718,6 +883,63 @@ func TestEntraidTokenManager_GetToken(t *testing.T) {
718
883
assert .Error (t , err )
719
884
assert .Nil (t , token1 )
720
885
})
886
+
887
+ t .Run ("GetToken with token set between checks" , func (t * testing.T ) {
888
+ idp := & mockIdentityProvider {}
889
+ mParser := & mockIdentityProviderResponseParser {}
890
+ tokenManager , err := NewTokenManager (idp ,
891
+ TokenManagerOptions {
892
+ IdentityProviderResponseParser : mParser ,
893
+ ExpirationRefreshRatio : 0.5 ,
894
+ LowerRefreshBound : time .Minute ,
895
+ },
896
+ )
897
+ assert .NoError (t , err )
898
+ assert .NotNil (t , tokenManager )
899
+ tm , ok := tokenManager .(* entraidTokenManager )
900
+ assert .True (t , ok )
901
+
902
+ validToken := token .New (
903
+ "username" ,
904
+ "password" ,
905
+ "rawToken" ,
906
+ time .Now ().Add (1 * time .Hour ),
907
+ time .Now (),
908
+ (1 * time .Hour ).Milliseconds (),
909
+ )
910
+
911
+ // Step 1: Acquire the read lock
912
+ tm .tokenRWLock .RLock ()
913
+
914
+ // Step 2: Start GetToken in a goroutine (it will block on upgrading to write lock)
915
+ var token2 * token.Token
916
+ var err2 error
917
+ getTokenStarted := make (chan struct {})
918
+ getTokenDone := make (chan struct {})
919
+ go func () {
920
+ close (getTokenStarted )
921
+ token2 , err2 = tokenManager .GetToken (false )
922
+ close (getTokenDone )
923
+ }()
924
+
925
+ // Step 3: Wait for GetToken to start and block on write lock
926
+ <- getTokenStarted
927
+ // Give the goroutine a moment to reach the write lock
928
+ time .Sleep (1 * time .Millisecond )
929
+ // Step 4: Set the token
930
+ tm .token = validToken
931
+ // Step 5: Release the read lock so GetToken can proceed
932
+ tm .tokenRWLock .RUnlock ()
933
+
934
+ // Step 6: Wait for GetToken to finish
935
+ <- getTokenDone
936
+
937
+ // Step 7: Assert the result
938
+ assert .NoError (t , err2 )
939
+ assert .NotNil (t , token2 )
940
+ assert .Equal (t , validToken , token2 )
941
+ idp .AssertNotCalled (t , "RequestToken" )
942
+ })
721
943
}
722
944
723
945
func TestEntraidTokenManager_durationToRenewal (t * testing.T ) {
0 commit comments