Skip to content

Add support for DefaultAzureCredentials #14

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Mar 10, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/entraid_integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,4 @@ jobs:
AZURE_CERT: ${{secrets.AZURE_CERT}}
AZURE_PRIVATE_KEY: ${{secrets.AZURE_PRIVATE_KEY}}
AZURE_REDIS_SCOPES: ${{secrets.AZURE_REDIS_SCOPES}}
AZURE_TENANT_ID: ${{secrets.AZURE_TENANT_ID}}
5 changes: 5 additions & 0 deletions entraid/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@
<artifactId>msal4j</artifactId>
<version>1.17.2</version>
</dependency>
<dependency>
<groupId>com.azure</groupId>
<artifactId>azure-identity</artifactId>
<version>1.15.3</version>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Copyright 2024, Redis Ltd. and Contributors All rights reserved. Licensed under the MIT License.
*/
package redis.clients.authentication.entraid;

import java.time.Duration;
import java.util.ArrayList;
import java.util.Set;
import java.util.function.Supplier;

import com.azure.core.credential.AccessToken;
import com.azure.core.credential.TokenRequestContext;
import com.azure.identity.DefaultAzureCredential;
import redis.clients.authentication.core.IdentityProvider;
import redis.clients.authentication.core.Token;

public final class AzureIdentityProvider implements IdentityProvider {

private Supplier<AccessToken> accessTokenSupplier;

public AzureIdentityProvider(DefaultAzureCredential defaultAzureCredential, Set<String> scopes,
int timeout) {
TokenRequestContext ctx = new TokenRequestContext()
.setScopes(new ArrayList<String>(scopes));
accessTokenSupplier = () -> defaultAzureCredential.getToken(ctx)
.block(Duration.ofMillis(timeout));
}

@Override
public Token requestToken() {
return new JWToken(accessTokenSupplier.get().getToken());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright 2024, Redis Ltd. and Contributors
* All rights reserved.
*
* Licensed under the MIT License.
*/
package redis.clients.authentication.entraid;

import java.util.Set;
import java.util.function.Supplier;

import com.azure.identity.DefaultAzureCredential;

import redis.clients.authentication.core.IdentityProvider;
import redis.clients.authentication.core.IdentityProviderConfig;

public final class AzureIdentityProviderConfig implements IdentityProviderConfig {

private final Supplier<IdentityProvider> providerSupplier;

public AzureIdentityProviderConfig(DefaultAzureCredential defaultAzureCredential, Set<String> scopes, int timeout) {
providerSupplier = () -> new AzureIdentityProvider(defaultAzureCredential, scopes, timeout);
}

@Override
public IdentityProvider getProvider() {
return providerSupplier.get();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright 2024, Redis Ltd. and Contributors All rights reserved. Licensed under the MIT License.
*/
package redis.clients.authentication.entraid;

import java.util.Collections;
import java.util.Set;

import com.azure.identity.DefaultAzureCredential;

import redis.clients.authentication.core.TokenAuthConfig;
import redis.clients.authentication.core.TokenManagerConfig;

public class AzureTokenAuthConfigBuilder
extends TokenAuthConfig.Builder<AzureTokenAuthConfigBuilder> implements AutoCloseable {
public static final float DEFAULT_EXPIRATION_REFRESH_RATIO = 0.75F;
public static final int DEFAULT_LOWER_REFRESH_BOUND_MILLIS = 2 * 60 * 1000;
public static final int DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS = 1000;
public static final int DEFAULT_MAX_ATTEMPTS_TO_RETRY = 5;
public static final int DEFAULT_DELAY_IN_MS_TO_RETRY = 100;
public static final Set<String> DEFAULT_SCOPES = Collections.singleton("https://redis.azure.com/.default");;

private DefaultAzureCredential defaultAzureCredential;
private Set<String> scopes = DEFAULT_SCOPES;
private int tokenRequestExecTimeoutInMs = DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS;

public AzureTokenAuthConfigBuilder() {
this.expirationRefreshRatio(DEFAULT_EXPIRATION_REFRESH_RATIO)
.lowerRefreshBoundMillis(DEFAULT_LOWER_REFRESH_BOUND_MILLIS)
.tokenRequestExecTimeoutInMs(DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS)
.maxAttemptsToRetry(DEFAULT_MAX_ATTEMPTS_TO_RETRY)
.delayInMsToRetry(DEFAULT_DELAY_IN_MS_TO_RETRY);
}

public AzureTokenAuthConfigBuilder defaultAzureCredential(
DefaultAzureCredential defaultAzureCredential) {
this.defaultAzureCredential = defaultAzureCredential;
return this;
}

public AzureTokenAuthConfigBuilder scopes(Set<String> scopes) {
this.scopes = scopes;
return this;
}

@Override
public AzureTokenAuthConfigBuilder tokenRequestExecTimeoutInMs(
int tokenRequestExecTimeoutInMs) {
super.tokenRequestExecTimeoutInMs(tokenRequestExecTimeoutInMs);
this.tokenRequestExecTimeoutInMs = tokenRequestExecTimeoutInMs;
return this;
}

public TokenAuthConfig build() {
super.identityProviderConfig(new AzureIdentityProviderConfig(defaultAzureCredential, scopes,
tokenRequestExecTimeoutInMs));
return super.build();
}

@Override
public void close() throws Exception {
defaultAzureCredential = null;
scopes = null;
}

public static AzureTokenAuthConfigBuilder builder() {
return new AzureTokenAuthConfigBuilder();
}

public static AzureTokenAuthConfigBuilder from(AzureTokenAuthConfigBuilder sample) {
TokenAuthConfig tokenAuthConfig = TokenAuthConfig.Builder.from(sample).build();
TokenManagerConfig tokenManagerConfig = tokenAuthConfig.getTokenManagerConfig();

AzureTokenAuthConfigBuilder builder = (AzureTokenAuthConfigBuilder) new AzureTokenAuthConfigBuilder()
.expirationRefreshRatio(tokenManagerConfig.getExpirationRefreshRatio())
.lowerRefreshBoundMillis(tokenManagerConfig.getLowerRefreshBoundMillis())
.tokenRequestExecTimeoutInMs(tokenManagerConfig.getTokenRequestExecTimeoutInMs())
.maxAttemptsToRetry(tokenManagerConfig.getRetryPolicy().getMaxAttempts())
.delayInMsToRetry(tokenManagerConfig.getRetryPolicy().getdelayInMs())
.identityProviderConfig(tokenAuthConfig.getIdentityProviderConfig());
builder.scopes = sample.scopes;
return builder;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Copyright 2024, Redis Ltd. and Contributors All rights reserved. Licensed under the MIT License.
*/
package redis.clients.authentication;

import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import org.junit.Test;
import com.azure.identity.DefaultAzureCredential;
import com.azure.identity.DefaultAzureCredentialBuilder;

import redis.clients.authentication.core.Token;
import redis.clients.authentication.entraid.AzureIdentityProvider;
import redis.clients.authentication.entraid.AzureTokenAuthConfigBuilder;

public class AzureIdentityProviderIntegrationTests {

@Test
public void requestTokenWithDefaultAzureCredential() {
// ensure environment variables are set
String client_id = System.getenv(TestContext.AZURE_CLIENT_ID);
assertNotNull(client_id);
assertFalse(client_id.isEmpty());
String clientSecret = System.getenv(TestContext.AZURE_CLIENT_SECRET);
assertNotNull(clientSecret);
assertFalse(clientSecret.isEmpty());
String tenantId = System.getenv("AZURE_TENANT_ID");
assertNotNull(tenantId);
assertFalse(tenantId.isEmpty());

DefaultAzureCredential defaultAzureCredential = new DefaultAzureCredentialBuilder().build();
Token token = new AzureIdentityProvider(defaultAzureCredential,
AzureTokenAuthConfigBuilder.DEFAULT_SCOPES, 2000).requestToken();
assertNotNull(token.getValue());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*
* Copyright 2024, Redis Ltd. and Contributors
* All rights reserved.
*
* Licensed under the MIT License.
*/
package redis.clients.authentication;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockConstruction;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import java.time.OffsetDateTime;
import java.util.Date;
import java.util.Set;

import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.MockedConstruction;

import com.auth0.jwt.JWT;
import com.auth0.jwt.algorithms.Algorithm;
import com.azure.core.credential.AccessToken;
import com.azure.core.credential.TokenRequestContext;
import com.azure.identity.DefaultAzureCredential;

import reactor.core.publisher.Mono;
import redis.clients.authentication.entraid.AzureIdentityProvider;
import redis.clients.authentication.entraid.AzureIdentityProviderConfig;
import redis.clients.authentication.entraid.AzureTokenAuthConfigBuilder;

public class AzureIdentityProviderUnitTests {
@Test
public void testAzureTokenAuthConfigBuilder() {
DefaultAzureCredential mockCredential = mock(DefaultAzureCredential.class);
Set<String> scopes = AzureTokenAuthConfigBuilder.DEFAULT_SCOPES;
int timeout = 2000;

try (MockedConstruction<AzureIdentityProviderConfig> mockedConstructor = mockConstruction(
AzureIdentityProviderConfig.class,
(mock, context) -> {
assertEquals(mockCredential, context.arguments().get(0));
assertEquals(scopes, context.arguments().get(1));
assertEquals(timeout, context.arguments().get(2));
})) {
AzureTokenAuthConfigBuilder.builder().defaultAzureCredential(mockCredential).scopes(scopes)
.tokenRequestExecTimeoutInMs(timeout).build();
}
}

public void testAzureIdentityProviderConfig() {
DefaultAzureCredential mockCredential = mock(DefaultAzureCredential.class);
Set<String> scopes = AzureTokenAuthConfigBuilder.DEFAULT_SCOPES;
int timeout = 2000;

try (MockedConstruction<AzureIdentityProvider> mockedConstructor = mockConstruction(
AzureIdentityProvider.class,
(mock, context) -> {
assertEquals(mockCredential, context.arguments().get(0));
assertEquals(scopes, context.arguments().get(1));
assertEquals(timeout, context.arguments().get(2));
})) {
new AzureIdentityProviderConfig(mockCredential, scopes, timeout).getProvider();
}
}

@Test
public void testRequestWithMockCredential() {
String token = JWT.create().withExpiresAt(new Date(System.currentTimeMillis()
- 1000))
.withClaim("oid", "user1").sign(Algorithm.none());

AccessToken t = new AccessToken(token, OffsetDateTime.now());
Mono<AccessToken> monoToken = Mono.just(t);
DefaultAzureCredential mockCredential = mock(DefaultAzureCredential.class);
when(mockCredential.getToken(any(TokenRequestContext.class))).thenReturn(monoToken);
new AzureIdentityProviderConfig(mockCredential,
AzureTokenAuthConfigBuilder.DEFAULT_SCOPES, 0).getProvider().requestToken();

ArgumentCaptor<TokenRequestContext> argument = ArgumentCaptor.forClass(TokenRequestContext.class);

verify(mockCredential, atLeast(1)).getToken(argument.capture());
AzureTokenAuthConfigBuilder.DEFAULT_SCOPES
.forEach((item) -> assertTrue(argument.getValue().getScopes().contains(item)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@

public class TestContext {

private static final String AZURE_CLIENT_ID = "AZURE_CLIENT_ID";
private static final String AZURE_AUTHORITY = "AZURE_AUTHORITY";
private static final String AZURE_CLIENT_SECRET = "AZURE_CLIENT_SECRET";
private static final String AZURE_PRIVATE_KEY = "AZURE_PRIVATE_KEY";
private static final String AZURE_CERT = "AZURE_CERT";
private static final String AZURE_REDIS_SCOPES = "AZURE_REDIS_SCOPES";
private static final String AZURE_USER_ASSIGNED_MANAGED_IDENTITY_CLIENT_ID = "AZURE_USER_ASSIGNED_MANAGED_IDENTITY_CLIENT_ID";
public static final String AZURE_CLIENT_ID = "AZURE_CLIENT_ID";
public static final String AZURE_AUTHORITY = "AZURE_AUTHORITY";
public static final String AZURE_CLIENT_SECRET = "AZURE_CLIENT_SECRET";
public static final String AZURE_PRIVATE_KEY = "AZURE_PRIVATE_KEY";
public static final String AZURE_CERT = "AZURE_CERT";
public static final String AZURE_REDIS_SCOPES = "AZURE_REDIS_SCOPES";
public static final String AZURE_USER_ASSIGNED_MANAGED_IDENTITY_CLIENT_ID = "AZURE_USER_ASSIGNED_MANAGED_IDENTITY_CLIENT_ID";

private String clientId;
private String authority;
Expand Down
Loading