1
+ package com .redis .om .spring ;
2
+
3
+ import com .azure .core .credential .AccessToken ;
4
+ import com .azure .core .credential .TokenCredential ;
5
+ import com .azure .core .credential .TokenRequestContext ;
6
+ import com .azure .core .util .CoreUtils ;
7
+ import com .azure .identity .DefaultAzureCredential ;
8
+ import com .azure .identity .DefaultAzureCredentialBuilder ;
9
+ import org .apache .commons .logging .Log ;
10
+ import org .apache .commons .logging .LogFactory ;
11
+ import org .springframework .beans .factory .annotation .Value ;
12
+ import org .springframework .boot .autoconfigure .condition .ConditionalOnProperty ;
13
+ import org .springframework .boot .context .properties .EnableConfigurationProperties ;
14
+ import org .springframework .context .annotation .Bean ;
15
+ import org .springframework .context .annotation .Configuration ;
16
+ import org .springframework .data .redis .connection .RedisStandaloneConfiguration ;
17
+ import org .springframework .data .redis .connection .jedis .JedisConnectionFactory ;
18
+ import redis .clients .jedis .Jedis ;
19
+
20
+ import java .time .Duration ;
21
+ import java .time .OffsetDateTime ;
22
+ import java .util .Timer ;
23
+ import java .util .TimerTask ;
24
+ import java .util .concurrent .ThreadLocalRandom ;
25
+
26
+ @ Configuration (proxyBeanMethods = false )
27
+ @ EnableConfigurationProperties ({ RedisOMProperties .class })
28
+ @ ConditionalOnProperty (name = "redis.om.spring.authentication.entra-id.enabled" , havingValue = "true" , matchIfMissing = false )
29
+ public class EntraIDConfiguration {
30
+
31
+ private static final Log logger = LogFactory .getLog (EntraIDConfiguration .class );
32
+
33
+ @ Value ("${spring.data.redis.host}" )
34
+ private String host ;
35
+ @ Value ("${spring.data.redis.port}" )
36
+ private int port ;
37
+ @ Value ("${redis.om.spring.authentication.entra-id.enabled}" )
38
+ private String clientType ;
39
+
40
+ public EntraIDConfiguration () {
41
+ logger .info ("EntraIDConfiguration initialized" );
42
+ logger .info ("Redis host: " + host );
43
+ logger .info ("Redis port: " + port );
44
+ logger .info ("Redis client type: " + clientType );
45
+ }
46
+
47
+ @ Bean
48
+ public JedisConnectionFactory jedisConnectionFactory () {
49
+ logger .info ("Creating JedisConnectionFactory for Entra ID authentication" );
50
+
51
+ // Create DefaultAzureCredential
52
+ DefaultAzureCredential defaultAzureCredential = new DefaultAzureCredentialBuilder ().build ();
53
+ TokenRequestContext trc = new TokenRequestContext ().addScopes ("https://redis.azure.com/.default" );
54
+ TokenRefreshCache tokenRefreshCache = new TokenRefreshCache (defaultAzureCredential , trc );
55
+ AccessToken accessToken = tokenRefreshCache .getAccessToken ();
56
+
57
+ boolean useSsl = true ;
58
+ String token = accessToken .getToken ();
59
+ logger .trace ("Token obtained successfully: \n " + token );
60
+ String username = extractUsernameFromToken (token );
61
+ logger .debug ("Username extracted from token: " + username );
62
+
63
+ JedisConnectionFactory jedisConnectionFactory = new JedisConnectionFactory (getRedisStandaloneConfiguration (username , token , useSsl ));
64
+ jedisConnectionFactory .setConvertPipelineAndTxResults (false );
65
+ jedisConnectionFactory .setUseSsl (useSsl );
66
+ logger .info ("JedisConnectionFactory for EntraID created successfully" );
67
+ return jedisConnectionFactory ;
68
+ }
69
+
70
+ private RedisStandaloneConfiguration getRedisStandaloneConfiguration (String username , String token , boolean useSsl ) {
71
+ RedisStandaloneConfiguration redisStandaloneConfiguration = new RedisStandaloneConfiguration ();
72
+ redisStandaloneConfiguration .setHostName (host );
73
+ redisStandaloneConfiguration .setPort (port );
74
+ redisStandaloneConfiguration .setUsername (username );
75
+ redisStandaloneConfiguration .setPassword (token );
76
+ return redisStandaloneConfiguration ;
77
+ }
78
+
79
+ private String extractUsernameFromToken (String token ) {
80
+ // The token is a JWT, and the username is in the "sub" claim
81
+ String [] parts = token .split ("\\ ." );
82
+ if (parts .length != 3 ) {
83
+ throw new IllegalArgumentException ("Invalid JWT token" );
84
+ }
85
+
86
+ String payload = new String (java .util .Base64 .getUrlDecoder ().decode (parts [1 ]));
87
+ com .google .gson .JsonObject jsonObject = com .google .gson .JsonParser .parseString (payload ).getAsJsonObject ();
88
+ return jsonObject .get ("sub" ).getAsString ();
89
+ }
90
+
91
+ /**
92
+ * The token cache to store and proactively refresh the access token.
93
+ */
94
+ private class TokenRefreshCache {
95
+ private final TokenCredential tokenCredential ;
96
+ private final TokenRequestContext tokenRequestContext ;
97
+ private final Timer timer ;
98
+ private volatile AccessToken accessToken ;
99
+ private final Duration maxRefreshOffset = Duration .ofMinutes (5 );
100
+ private final Duration baseRefreshOffset = Duration .ofMinutes (2 );
101
+ private Jedis jedisInstanceToAuthenticate ;
102
+ private String username ;
103
+
104
+ /**
105
+ * Creates an instance of TokenRefreshCache
106
+ * @param tokenCredential the token credential to be used for authentication.
107
+ * @param tokenRequestContext the token request context to be used for authentication.
108
+ */
109
+ public TokenRefreshCache (TokenCredential tokenCredential , TokenRequestContext tokenRequestContext ) {
110
+ this .tokenCredential = tokenCredential ;
111
+ this .tokenRequestContext = tokenRequestContext ;
112
+ this .timer = new Timer ();
113
+ }
114
+
115
+ /**
116
+ * Gets the cached access token.
117
+ * @return the AccessToken
118
+ */
119
+ public AccessToken getAccessToken () {
120
+ if (accessToken != null ) {
121
+ return accessToken ;
122
+ } else {
123
+ TokenRefreshTask tokenRefreshTask = new TokenRefreshTask ();
124
+ accessToken = tokenCredential .getToken (tokenRequestContext ).block ();
125
+ timer .schedule (tokenRefreshTask , getTokenRefreshDelay ());
126
+ return accessToken ;
127
+ }
128
+ }
129
+
130
+ private class TokenRefreshTask extends TimerTask {
131
+ // Add your task here
132
+ public void run () {
133
+ accessToken = tokenCredential .getToken (tokenRequestContext ).block ();
134
+ username = extractUsernameFromToken (accessToken .getToken ());
135
+ System .out .println ("Refreshed Token with Expiry: " + accessToken .getExpiresAt ().toEpochSecond ());
136
+
137
+ if (jedisInstanceToAuthenticate != null && !CoreUtils .isNullOrEmpty (username )) {
138
+ jedisInstanceToAuthenticate .auth (username , accessToken .getToken ());
139
+ System .out .println ("Refreshed Jedis Connection with fresh access token, token expires at : "
140
+ + accessToken .getExpiresAt ().toEpochSecond ());
141
+ }
142
+ timer .schedule (new TokenRefreshTask (), getTokenRefreshDelay ());
143
+ }
144
+ }
145
+
146
+ private long getTokenRefreshDelay () {
147
+ return ((accessToken .getExpiresAt ()
148
+ .minusSeconds (ThreadLocalRandom .current ().nextLong (baseRefreshOffset .getSeconds (), maxRefreshOffset .getSeconds ()))
149
+ .toEpochSecond () - OffsetDateTime .now ().toEpochSecond ()) * 1000 );
150
+ }
151
+
152
+ /**
153
+ * Sets the Jedis to proactively authenticate before token expiry.
154
+ * @param jedisInstanceToAuthenticate the instance to authenticate
155
+ * @return the updated instance
156
+ */
157
+ public TokenRefreshCache setJedisInstanceToAuthenticate (Jedis jedisInstanceToAuthenticate ) {
158
+ this .jedisInstanceToAuthenticate = jedisInstanceToAuthenticate ;
159
+ return this ;
160
+ }
161
+ }
162
+ }
0 commit comments