Skip to content

Commit f7ceed0

Browse files
Luigi Fugarobsbodden
authored andcommitted
feat: Add Microsoft Entra ID authentication support for Redis and Azure OpenAI
- Implement Entra ID authentication mechanism for Redis connections - Add dynamic support for Azure OpenAI authentication via Entra ID - Create EntraIDConfiguration for automatic JedisConnectionFactory creation - Add Azure Identity dependency for Entra ID integration - Refactor Azure client properties under a unified structure - Improve OpenAI configuration with response timeout setting - Centralize Spring AI properties and embedding model creation - Update Spring Boot auto-configuration - Preserve TDigest implementation compatibility
1 parent 7c92fd9 commit f7ceed0

14 files changed

+381
-83
lines changed

redis-om-spring/pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
<groupId>com.redis.om</groupId>
99
<artifactId>redis-om-spring</artifactId>
10-
<version>0.9.12-SNAPSHOT</version>
10+
<version>0.9.13-SNAPSHOT</version>
1111
<packaging>jar</packaging>
1212

1313
<name>redis-om-spring</name>
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
package com.redis.om.spring;
2+
3+
import com.azure.core.credential.AccessToken;
4+
import com.azure.core.credential.TokenCredential;
5+
import com.azure.core.credential.TokenRequestContext;
6+
import com.azure.core.util.CoreUtils;
7+
import com.azure.identity.DefaultAzureCredential;
8+
import com.azure.identity.DefaultAzureCredentialBuilder;
9+
import org.apache.commons.logging.Log;
10+
import org.apache.commons.logging.LogFactory;
11+
import org.springframework.beans.factory.annotation.Value;
12+
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
13+
import org.springframework.boot.context.properties.EnableConfigurationProperties;
14+
import org.springframework.context.annotation.Bean;
15+
import org.springframework.context.annotation.Configuration;
16+
import org.springframework.data.redis.connection.RedisStandaloneConfiguration;
17+
import org.springframework.data.redis.connection.jedis.JedisConnectionFactory;
18+
import redis.clients.jedis.Jedis;
19+
20+
import java.time.Duration;
21+
import java.time.OffsetDateTime;
22+
import java.util.Timer;
23+
import java.util.TimerTask;
24+
import java.util.concurrent.ThreadLocalRandom;
25+
26+
@Configuration(proxyBeanMethods = false)
27+
@EnableConfigurationProperties({ RedisOMProperties.class })
28+
@ConditionalOnProperty(name = "redis.om.spring.authentication.entra-id.enabled", havingValue = "true", matchIfMissing = false)
29+
public class EntraIDConfiguration {
30+
31+
private static final Log logger = LogFactory.getLog(EntraIDConfiguration.class);
32+
33+
@Value("${spring.data.redis.host}")
34+
private String host;
35+
@Value("${spring.data.redis.port}")
36+
private int port;
37+
@Value("${redis.om.spring.authentication.entra-id.enabled}")
38+
private String clientType;
39+
40+
public EntraIDConfiguration() {
41+
logger.info("EntraIDConfiguration initialized");
42+
logger.info("Redis host: " + host);
43+
logger.info("Redis port: " + port);
44+
logger.info("Redis client type: " + clientType);
45+
}
46+
47+
@Bean
48+
public JedisConnectionFactory jedisConnectionFactory() {
49+
logger.info("Creating JedisConnectionFactory for Entra ID authentication");
50+
51+
// Create DefaultAzureCredential
52+
DefaultAzureCredential defaultAzureCredential = new DefaultAzureCredentialBuilder().build();
53+
TokenRequestContext trc = new TokenRequestContext().addScopes("https://redis.azure.com/.default");
54+
TokenRefreshCache tokenRefreshCache = new TokenRefreshCache(defaultAzureCredential, trc);
55+
AccessToken accessToken = tokenRefreshCache.getAccessToken();
56+
57+
boolean useSsl = true;
58+
String token = accessToken.getToken();
59+
logger.trace("Token obtained successfully: \n" + token);
60+
String username = extractUsernameFromToken(token);
61+
logger.debug("Username extracted from token: " + username);
62+
63+
JedisConnectionFactory jedisConnectionFactory = new JedisConnectionFactory(getRedisStandaloneConfiguration(username, token, useSsl));
64+
jedisConnectionFactory.setConvertPipelineAndTxResults(false);
65+
jedisConnectionFactory.setUseSsl(useSsl);
66+
logger.info("JedisConnectionFactory for EntraID created successfully");
67+
return jedisConnectionFactory;
68+
}
69+
70+
private RedisStandaloneConfiguration getRedisStandaloneConfiguration(String username, String token, boolean useSsl) {
71+
RedisStandaloneConfiguration redisStandaloneConfiguration = new RedisStandaloneConfiguration();
72+
redisStandaloneConfiguration.setHostName(host);
73+
redisStandaloneConfiguration.setPort(port);
74+
redisStandaloneConfiguration.setUsername(username);
75+
redisStandaloneConfiguration.setPassword(token);
76+
return redisStandaloneConfiguration;
77+
}
78+
79+
private String extractUsernameFromToken(String token) {
80+
// The token is a JWT, and the username is in the "sub" claim
81+
String[] parts = token.split("\\.");
82+
if (parts.length != 3) {
83+
throw new IllegalArgumentException("Invalid JWT token");
84+
}
85+
86+
String payload = new String(java.util.Base64.getUrlDecoder().decode(parts[1]));
87+
com.google.gson.JsonObject jsonObject = com.google.gson.JsonParser.parseString(payload).getAsJsonObject();
88+
return jsonObject.get("sub").getAsString();
89+
}
90+
91+
/**
92+
* The token cache to store and proactively refresh the access token.
93+
*/
94+
private class TokenRefreshCache {
95+
private final TokenCredential tokenCredential;
96+
private final TokenRequestContext tokenRequestContext;
97+
private final Timer timer;
98+
private volatile AccessToken accessToken;
99+
private final Duration maxRefreshOffset = Duration.ofMinutes(5);
100+
private final Duration baseRefreshOffset = Duration.ofMinutes(2);
101+
private Jedis jedisInstanceToAuthenticate;
102+
private String username;
103+
104+
/**
105+
* Creates an instance of TokenRefreshCache
106+
* @param tokenCredential the token credential to be used for authentication.
107+
* @param tokenRequestContext the token request context to be used for authentication.
108+
*/
109+
public TokenRefreshCache(TokenCredential tokenCredential, TokenRequestContext tokenRequestContext) {
110+
this.tokenCredential = tokenCredential;
111+
this.tokenRequestContext = tokenRequestContext;
112+
this.timer = new Timer();
113+
}
114+
115+
/**
116+
* Gets the cached access token.
117+
* @return the AccessToken
118+
*/
119+
public AccessToken getAccessToken() {
120+
if (accessToken != null) {
121+
return accessToken;
122+
} else {
123+
TokenRefreshTask tokenRefreshTask = new TokenRefreshTask();
124+
accessToken = tokenCredential.getToken(tokenRequestContext).block();
125+
timer.schedule(tokenRefreshTask, getTokenRefreshDelay());
126+
return accessToken;
127+
}
128+
}
129+
130+
private class TokenRefreshTask extends TimerTask {
131+
// Add your task here
132+
public void run() {
133+
accessToken = tokenCredential.getToken(tokenRequestContext).block();
134+
username = extractUsernameFromToken(accessToken.getToken());
135+
System.out.println("Refreshed Token with Expiry: " + accessToken.getExpiresAt().toEpochSecond());
136+
137+
if (jedisInstanceToAuthenticate != null && !CoreUtils.isNullOrEmpty(username)) {
138+
jedisInstanceToAuthenticate.auth(username, accessToken.getToken());
139+
System.out.println("Refreshed Jedis Connection with fresh access token, token expires at : "
140+
+ accessToken.getExpiresAt().toEpochSecond());
141+
}
142+
timer.schedule(new TokenRefreshTask(), getTokenRefreshDelay());
143+
}
144+
}
145+
146+
private long getTokenRefreshDelay() {
147+
return ((accessToken.getExpiresAt()
148+
.minusSeconds(ThreadLocalRandom.current().nextLong(baseRefreshOffset.getSeconds(), maxRefreshOffset.getSeconds()))
149+
.toEpochSecond() - OffsetDateTime.now().toEpochSecond()) * 1000);
150+
}
151+
152+
/**
153+
* Sets the Jedis to proactively authenticate before token expiry.
154+
* @param jedisInstanceToAuthenticate the instance to authenticate
155+
* @return the updated instance
156+
*/
157+
public TokenRefreshCache setJedisInstanceToAuthenticate(Jedis jedisInstanceToAuthenticate) {
158+
this.jedisInstanceToAuthenticate = jedisInstanceToAuthenticate;
159+
return this;
160+
}
161+
}
162+
}

redis-om-spring/src/main/java/com/redis/om/spring/RedisOMAiProperties.java

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@ public class RedisOMAiProperties {
1313
private int embeddingBatchSize = 1000;
1414
private final Djl djl = new Djl();
1515
private final OpenAi openAi = new OpenAi();
16-
private final AzureOpenAi azureOpenAi = new AzureOpenAi();
17-
private final AzureEntraId azureEntraId = new AzureEntraId();
16+
private final AzureClients azure = new AzureClients();
1817
private final VertexAi vertexAi = new VertexAi();
1918
private final Aws aws = new Aws();
2019
private final Ollama ollama = new Ollama();
@@ -35,12 +34,8 @@ public OpenAi getOpenAi() {
3534
return openAi;
3635
}
3736

38-
public AzureOpenAi getAzureOpenAi() {
39-
return azureOpenAi;
40-
}
41-
42-
public AzureEntraId getAzureEntraId() {
43-
return azureEntraId;
37+
public AzureClients getAzure() {
38+
return azure;
4439
}
4540

4641
public VertexAi getVertexAi() {
@@ -252,6 +247,27 @@ public void setBaseUrl(String baseUrl) {
252247
}
253248
}
254249

250+
public static class AzureClients {
251+
private AzureOpenAi openAi;
252+
private AzureEntraId entraId;
253+
254+
public AzureOpenAi getOpenAi() {
255+
return openAi;
256+
}
257+
258+
public void setOpenAi(AzureOpenAi openAi) {
259+
this.openAi = openAi;
260+
}
261+
262+
public AzureEntraId getEntraId() {
263+
return entraId;
264+
}
265+
266+
public void setEntraId(AzureEntraId entraId) {
267+
this.entraId = entraId;
268+
}
269+
}
270+
255271
public static class AzureOpenAi {
256272
private String apiKey;
257273
private String endpoint;

redis-om-spring/src/main/java/com/redis/om/spring/RedisOMProperties.java

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,20 @@
1010
prefix = "redis.om.spring", ignoreInvalidFields = true
1111
)
1212
public class RedisOMProperties {
13-
public static final String ROMS_VERSION = "0.9.11-SNAPSHOT";
13+
public static final String ROMS_VERSION = "0.9.13-SNAPSHOT";
1414
public static final int MAX_SEARCH_RESULTS = 10000;
1515
public static final double DEFAULT_DISTANCE = 0.0005;
1616
public static final Metrics DEFAULT_DISTANCE_METRIC = Metrics.MILES;
1717
// repository properties
1818
private final Repository repository = new Repository();
1919
private final References references = new References();
20+
// Entra ID Authentication
21+
private final Authentication authentication = new Authentication();
22+
23+
public Authentication getAuthentication() {
24+
return authentication;
25+
}
26+
2027

2128
public Repository getRepository() {
2229
return repository;
@@ -26,6 +33,39 @@ public References getReferences() {
2633
return references;
2734
}
2835

36+
public static class Authentication {
37+
private EntraId entraId = new EntraId();
38+
39+
public EntraId getEntraId() {
40+
return entraId;
41+
}
42+
43+
public void setEntraId(EntraId entraId) {
44+
this.entraId = entraId;
45+
}
46+
}
47+
48+
public static class EntraId {
49+
private boolean enabled = false;
50+
private String tenantId;
51+
52+
public boolean isEnabled() {
53+
return enabled;
54+
}
55+
56+
public void setEnabled(boolean enabled) {
57+
this.enabled = enabled;
58+
}
59+
60+
public String getTenantId() {
61+
return tenantId;
62+
}
63+
64+
public void setTenantId(String tenantId) {
65+
this.tenantId = tenantId;
66+
}
67+
}
68+
2969
public static class Repository {
3070
private final Query query = new Query();
3171
private boolean dropAndRecreateIndexOnDeleteAll = false;

redis-om-spring/src/main/java/com/redis/om/spring/client/RedisModulesClient.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import redis.clients.jedis.bloom.commands.TopKFilterCommands;
1818
import redis.clients.jedis.json.commands.RedisJsonCommands;
1919
import redis.clients.jedis.search.RediSearchCommands;
20+
import redis.clients.jedis.bloom.commands.TDigestSketchCommands;
2021

2122
import java.util.Objects;
2223
import java.util.Optional;
@@ -61,6 +62,10 @@ public CuckooFilterCommands clientForCuckoo() {
6162
public TopKFilterCommands clientForTopK() {
6263
return unifiedJedis;
6364
}
65+
66+
public TDigestSketchCommands clientForTDigest() {
67+
return unifiedJedis;
68+
}
6469

6570
private UnifiedJedis getUnifiedJedis() {
6671

redis-om-spring/src/main/java/com/redis/om/spring/ops/RedisModulesOperations.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,8 @@ public CuckooFilterOperations<K> opsForCuckoFilter() {
3535
public TopKOperations<K> opsForTopK() {
3636
return new TopKOperationsImpl<>(client);
3737
}
38+
39+
public TDigestOperations<K> opsForTDigest() {
40+
return new TDigestOperationsImpl<>(client);
41+
}
3842
}

0 commit comments

Comments
 (0)