Skip to content

Commit ee60a45

Browse files
committed
- support full customization of different MSAL application types and advanced configurations with EntraIDTokenAuthConfigBuilder
- add more unit tests
1 parent 0b1095f commit ee60a45

File tree

4 files changed

+230
-22
lines changed

4 files changed

+230
-22
lines changed

entraid/src/main/java/redis/clients/authentication/entraid/EntraIDIdentityProvider.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@ public EntraIDIdentityProvider(ManagedIdentityInfo info, Set<String> scopes) {
4848
resultSupplier = () -> supplierForManagedIdentityApp(app, params);
4949
}
5050

51+
public EntraIDIdentityProvider(Supplier<IAuthenticationResult> customEntraIdAppSupplier) {
52+
this.resultSupplier = customEntraIdAppSupplier;
53+
}
54+
5155
private IClientCredential getClientCredential(ServicePrincipalInfo servicePrincipalInfo) {
5256
switch (servicePrincipalInfo.getAccessWith()) {
5357
case WithSecret:
Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,33 @@
11
package redis.clients.authentication.entraid;
22

33
import java.util.Set;
4+
import java.util.function.Supplier;
5+
6+
import com.microsoft.aad.msal4j.IAuthenticationResult;
47

58
import redis.clients.authentication.core.IdentityProvider;
69
import redis.clients.authentication.core.IdentityProviderConfig;
710

811
public final class EntraIDIdentityProviderConfig implements IdentityProviderConfig, AutoCloseable {
912

10-
private ServicePrincipalInfo servicePrincipalInfo;
11-
private Set<String> scopes;
12-
private ManagedIdentityInfo managedIdentityInfo;
13+
private Supplier<IdentityProvider> providerSupplier;
14+
15+
public EntraIDIdentityProviderConfig(ServicePrincipalInfo info, Set<String> scopes) {
16+
providerSupplier = () -> new EntraIDIdentityProvider(info, scopes);
17+
}
18+
19+
public EntraIDIdentityProviderConfig(ManagedIdentityInfo info, Set<String> scopes) {
20+
providerSupplier = () -> new EntraIDIdentityProvider(info, scopes);
21+
}
1322

14-
public EntraIDIdentityProviderConfig(ServicePrincipalInfo servicePrincipalInfo,
15-
ManagedIdentityInfo info, Set<String> scopes) {
16-
this.servicePrincipalInfo = servicePrincipalInfo;
17-
this.scopes = scopes;
18-
this.managedIdentityInfo = info;
23+
public EntraIDIdentityProviderConfig(
24+
Supplier<IAuthenticationResult> customEntraIdAppSupplier) {
25+
providerSupplier = () -> new EntraIDIdentityProvider(customEntraIdAppSupplier);
1926
}
2027

2128
@Override
2229
public IdentityProvider getProvider() {
23-
IdentityProvider identityProvider = null;
24-
if (managedIdentityInfo != null) {
25-
identityProvider = new EntraIDIdentityProvider(managedIdentityInfo, scopes);
26-
} else {
27-
identityProvider = new EntraIDIdentityProvider(servicePrincipalInfo, scopes);
28-
}
30+
IdentityProvider identityProvider = providerSupplier.get();
2931
clear();
3032
return identityProvider;
3133
}
@@ -36,8 +38,6 @@ public void close() throws Exception {
3638
}
3739

3840
private void clear() {
39-
servicePrincipalInfo = null;
40-
managedIdentityInfo = null;
41-
scopes = null;
41+
providerSupplier = null;
4242
}
4343
}

entraid/src/main/java/redis/clients/authentication/entraid/EntraIDTokenAuthConfigBuilder.java

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
import java.security.PrivateKey;
44
import java.security.cert.X509Certificate;
55
import java.util.Set;
6+
import java.util.function.Supplier;
7+
8+
import com.microsoft.aad.msal4j.IAuthenticationResult;
69

710
import redis.clients.authentication.core.TokenAuthConfig;
811
import redis.clients.authentication.entraid.ManagedIdentityInfo.UserManagedIdentityType;
@@ -24,6 +27,7 @@ public class EntraIDTokenAuthConfigBuilder extends TokenAuthConfig.Builder
2427
private Set<String> scopes;
2528
private ServicePrincipalAccess accessWith;
2629
private ManagedIdentityInfo mii;
30+
Supplier<IAuthenticationResult> customEntraIdAuthenticationSupplier;
2731

2832
public EntraIDTokenAuthConfigBuilder() {
2933
this.expirationRefreshRatio(DEFAULT_EXPIRATION_REFRESH_RATIO)
@@ -66,6 +70,12 @@ public EntraIDTokenAuthConfigBuilder userAssignedManagedIdentity(
6670
return this;
6771
}
6872

73+
public EntraIDTokenAuthConfigBuilder customEntraIdAuthenticationSupplier(
74+
Supplier<IAuthenticationResult> customEntraIdAuthenticationSupplier) {
75+
76+
return this;
77+
}
78+
6979
public EntraIDTokenAuthConfigBuilder scopes(Set<String> scopes) {
7080
this.scopes = scopes;
7181
return this;
@@ -85,11 +95,22 @@ public TokenAuthConfig build() {
8595
}
8696
if (spi != null && mii != null) {
8797
throw new RedisEntraIDException(
88-
"Cannot have both ServicePrincipal and ManagedIdentity");
98+
"Cannot have both ServicePrincipal and ManagedIdentity!");
99+
}
100+
if (this.customEntraIdAuthenticationSupplier != null && (spi != null || mii != null)) {
101+
throw new RedisEntraIDException(
102+
"Cannot have both customEntraIdAuthenticationSupplier and ServicePrincipal/ManagedIdentity!");
103+
}
104+
if (spi != null) {
105+
super.identityProviderConfig(new EntraIDIdentityProviderConfig(spi, scopes));
106+
}
107+
if (mii != null) {
108+
super.identityProviderConfig(new EntraIDIdentityProviderConfig(mii, scopes));
109+
}
110+
if (customEntraIdAuthenticationSupplier != null) {
111+
super.identityProviderConfig(
112+
new EntraIDIdentityProviderConfig(customEntraIdAuthenticationSupplier));
89113
}
90-
EntraIDIdentityProviderConfig idProviderConfig = new EntraIDIdentityProviderConfig(spi, mii,
91-
scopes);
92-
super.identityProviderConfig(idProviderConfig);
93114
return super.build();
94115
}
95116

@@ -101,6 +122,7 @@ public void close() throws Exception {
101122
cert = null;
102123
authority = null;
103124
scopes = null;
125+
customEntraIdAuthenticationSupplier = null;
104126
}
105127

106128
public static EntraIDTokenAuthConfigBuilder builder() {

entraid/src/test/java/redis/clients/authentication/RedisEntraIDUnitTests.java

Lines changed: 183 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,11 @@
1616
import static org.hamcrest.Matchers.both;
1717
import static org.hamcrest.MatcherAssert.assertThat;
1818

19+
import java.time.Duration;
20+
import java.util.ArrayList;
1921
import java.util.Collections;
2022
import java.util.Date;
23+
import java.util.List;
2124
import java.util.Set;
2225
import java.util.concurrent.CountDownLatch;
2326
import java.util.concurrent.TimeoutException;
@@ -31,6 +34,7 @@
3134

3235
import com.auth0.jwt.JWT;
3336
import com.auth0.jwt.algorithms.Algorithm;
37+
3438
import redis.clients.authentication.core.IdentityProvider;
3539
import redis.clients.authentication.core.IdentityProviderConfig;
3640
import redis.clients.authentication.core.SimpleToken;
@@ -280,6 +284,16 @@ public void onError(Exception e) {
280284
.until(() -> numberOfTokens.get(), is(2));
281285
}
282286

287+
// T.2.2
288+
289+
// Test that the Redis client is not blocked/interrupted during token renewal.
290+
@Test
291+
public void renewalDuringOperationsTest() {
292+
// set the stage with consecutive get/set operations with unique keys which takes at least for 2000 ms with a jedispooled instace,
293+
// configure token manager to renew token every 500ms
294+
// wait till all operations are completed and verify that token was renewed at least 3 times after initial token acquisition
295+
}
296+
283297
// T.2.2
284298
// Ensure the system propagates error during renewal back to the user
285299
@Test
@@ -306,6 +320,7 @@ public void onError(Exception e) {
306320

307321
Awaitility.await().pollInterval(ONE_HUNDRED_MILLISECONDS).atMost(TWO_SECONDS)
308322
.until(() -> numberOfErrors.get(), is(1));
323+
309324
}
310325

311326
// T.2.3
@@ -330,6 +345,7 @@ public void onTokenRenewed(Token token) {
330345
timeDiff.set((int) (token.getExpiresAt() - lastToken.getExpiresAt()));
331346
}
332347
lastToken = token;
348+
333349
}
334350

335351
@Override
@@ -350,7 +366,89 @@ public void onError(Exception e) {
350366
// T.2.3
351367
// Verify behavior with edge case renewal timing configurations (e.g., very low or high percentages).
352368
@Test
353-
public void edgeCaseRenewalTimingTest() {
369+
public void highPercentage_edgeCaseRenewalTimingTest() {
370+
List<Token> tokens = new ArrayList<Token>();
371+
int validDurationInMs = 1000;
372+
373+
IdentityProvider identityProvider = () -> new SimpleToken(TOKEN_VALUE,
374+
System.currentTimeMillis() + validDurationInMs, System.currentTimeMillis(),
375+
Collections.singletonMap("oid", TOKEN_OID));
376+
377+
TokenManagerConfig tokenManagerConfig = new TokenManagerConfig(0.99F, 0,
378+
TOKEN_REQUEST_EXEC_TIMEOUT,
379+
new TokenManagerConfig.RetryPolicy(RETRY_POLICY_MAX_ATTEMPTS, RETRY_POLICY_DELAY));
380+
381+
TokenManager tokenManager = new TokenManager(identityProvider, tokenManagerConfig);
382+
TokenListener listener = new TokenListener() {
383+
384+
@Override
385+
public void onTokenRenewed(Token token) {
386+
tokens.add(token);
387+
}
388+
389+
@Override
390+
public void onError(Exception e) {
391+
}
392+
};
393+
394+
tokenManager.start(listener, false);
395+
396+
Awaitility.await().pollInterval(Duration.ofMillis(10)).atMost(Durations.TWO_SECONDS)
397+
.until(() -> tokens.size(), is(2));
398+
399+
Token initialToken = tokens.get(0);
400+
Token secondToken = tokens.get(1);
401+
Long renewalWindowStart = initialToken.getReceivedAt()
402+
+ (long) (validDurationInMs * tokenManagerConfig.getExpirationRefreshRatio());
403+
Long renewalWindowEnd = initialToken.getExpiresAt();
404+
assertThat((Long) secondToken.getReceivedAt(),
405+
both(greaterThanOrEqualTo(renewalWindowStart))
406+
.and(lessThanOrEqualTo(renewalWindowEnd)));
407+
408+
}
409+
410+
// T.2.3
411+
// Verify behavior with edge case renewal timing configurations (e.g., very low or high percentages).
412+
@Test
413+
public void lowPercentage_edgeCaseRenewalTimingTest() {
414+
List<Token> tokens = new ArrayList<Token>();
415+
int validDurationInMs = 1000;
416+
417+
IdentityProvider identityProvider = () -> new SimpleToken(TOKEN_VALUE,
418+
System.currentTimeMillis() + validDurationInMs, System.currentTimeMillis(),
419+
Collections.singletonMap("oid", TOKEN_OID));
420+
421+
TokenManagerConfig tokenManagerConfig = new TokenManagerConfig(0.01F, 0,
422+
TOKEN_REQUEST_EXEC_TIMEOUT,
423+
new TokenManagerConfig.RetryPolicy(RETRY_POLICY_MAX_ATTEMPTS, RETRY_POLICY_DELAY));
424+
425+
TokenManager tokenManager = new TokenManager(identityProvider, tokenManagerConfig);
426+
TokenListener listener = new TokenListener() {
427+
428+
@Override
429+
public void onTokenRenewed(Token token) {
430+
tokens.add(token);
431+
}
432+
433+
@Override
434+
public void onError(Exception e) {
435+
}
436+
};
437+
438+
tokenManager.start(listener, false);
439+
440+
Awaitility.await().pollInterval(ONE_MILLISECOND).atMost(Durations.TWO_SECONDS)
441+
.until(() -> tokens.size(), is(2));
442+
443+
Token initialToken = tokens.get(0);
444+
Token secondToken = tokens.get(1);
445+
Long renewalWindowStart = initialToken.getReceivedAt()
446+
+ (long) (validDurationInMs * tokenManagerConfig.getExpirationRefreshRatio());
447+
Long renewalWindowEnd = initialToken.getExpiresAt();
448+
assertThat((Long) secondToken.getReceivedAt(),
449+
both(greaterThanOrEqualTo(renewalWindowStart))
450+
.and(lessThanOrEqualTo(renewalWindowEnd)));
451+
354452
}
355453

356454
// T.2.4
@@ -363,6 +461,7 @@ public void expiredTokenCheckTest() {
363461

364462
token = JWT.create().withExpiresAt(new Date(System.currentTimeMillis() + 1000))
365463
.withClaim("oid", "user1").sign(Algorithm.none());
464+
366465
assertFalse(new JWToken(token).isExpired());
367466
}
368467

@@ -382,6 +481,13 @@ public void tokenParserTest() {
382481
lessThanOrEqualTo((Long) 10L));
383482
}
384483

484+
// T.2.5
485+
// Ensure that token objects are immutable and cannot be modified after creation.
486+
@Test
487+
public void tokenImmutabilityTest() {
488+
// ???
489+
}
490+
385491
// T.3.1
386492
// Verify that the most recent valid token is correctly cached and that the cache is initially empty
387493
@Test
@@ -399,6 +505,82 @@ public void tokenCachingTest() {
399505
assertNotNull(tokenManager.getCurrentToken());
400506
}
401507

508+
// T.3.1
509+
// Ensure the token cache is updated when a new token is acquired or renewed.
510+
@Test
511+
public void cacheUpdateOnRenewalTest() {
512+
513+
AtomicInteger numberOfTokens = new AtomicInteger(0);
514+
IdentityProvider identityProvider = () -> {
515+
return new SimpleToken("" + numberOfTokens.incrementAndGet(),
516+
System.currentTimeMillis() + 500, System.currentTimeMillis(),
517+
Collections.singletonMap("oid", "user1"));
518+
};
519+
TokenManager tokenManager = new TokenManager(identityProvider, tokenManagerConfig);
520+
assertNull(tokenManager.getCurrentToken());
521+
tokenManager.start(mock(TokenListener.class), true);
522+
assertNotNull(tokenManager.getCurrentToken());
523+
assertEquals("1", tokenManager.getCurrentToken().getValue());
524+
Awaitility.await().pollInterval(ONE_HUNDRED_MILLISECONDS).atMost(TWO_SECONDS)
525+
.until(() -> tokenManager.getCurrentToken().getValue(), is("2"));
526+
527+
}
528+
529+
// T.3.2
530+
// Verify that all existing connections can be re-authenticated when a new token is received.
531+
@Test
532+
public void allConnectionsReauthTest() {
533+
534+
}
535+
536+
// T.3.2
537+
// Test system behavior when some connections fail to re-authenticate during bulk authentication. e.g when a network partition occurs for 1 or more of them
538+
@Test
539+
public void partialReauthFailureTest() {
540+
541+
}
542+
543+
// T.3.3
544+
// Test authentication of a single connection using the current valid token.
545+
@Test
546+
public void singleConnectionAuthTest() {
547+
548+
}
549+
550+
// T.3.3
551+
// Verify behavior when attempting to authenticate a single connection with an expired token.
552+
@Test
553+
public void connectionAuthWithExpiredTokenTest() {
554+
555+
}
556+
557+
// T.3.4
558+
// Verify handling of reconnection and re-authentication after a network partition. (use cached token)
559+
@Test
560+
public void networkPartitionEvictionTest() {
561+
562+
}
563+
564+
// T.4.1
565+
// Verify that token renewal timing can be configured correctly.
566+
@Test
567+
public void renewalTimingConfigTest() {
568+
569+
}
570+
571+
// T.4.2
572+
// Verify that Azure AD-specific parameters can be configured correctly.
573+
@Test
574+
public void azureADConfigTest() {
575+
576+
}
577+
578+
// T.4.2
579+
// Test configuration of custom identity provider parameters.
580+
@Test
581+
public void customProviderConfigTest() {
582+
}
583+
402584
private void delay(long durationInMs) {
403585
try {
404586
Thread.sleep(durationInMs);

0 commit comments

Comments
 (0)