Skip to content

Commit ab05e6c

Browse files
committed
- introduce fastfailover using objectMaker injection into connectionFactory
1 parent 397f437 commit ab05e6c

File tree

5 files changed

+190
-16
lines changed

5 files changed

+190
-16
lines changed

src/main/java/redis/clients/jedis/Connection.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,10 @@ public void disconnect() {
288288
}
289289
}
290290

291+
public void forceDisconnect() throws IOException {
292+
socket.close();
293+
}
294+
291295
public boolean isConnected() {
292296
return socket != null && socket.isBound() && !socket.isClosed() && socket.isConnected()
293297
&& !socket.isInputShutdown() && !socket.isOutputShutdown();

src/main/java/redis/clients/jedis/ConnectionFactory.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import org.slf4j.LoggerFactory;
88

99
import java.util.function.Supplier;
10+
import java.util.function.UnaryOperator;
1011

1112
import redis.clients.jedis.annots.Experimental;
1213
import redis.clients.jedis.authentication.AuthXManager;
@@ -21,12 +22,15 @@
2122
*/
2223
public class ConnectionFactory implements PooledObjectFactory<Connection> {
2324

25+
public interface MakerInjector extends UnaryOperator<Supplier<Connection>> {
26+
};
27+
2428
private static final Logger logger = LoggerFactory.getLogger(ConnectionFactory.class);
2529

2630
private final JedisSocketFactory jedisSocketFactory;
2731
private final JedisClientConfig clientConfig;
2832
private final Cache clientSideCache;
29-
private final Supplier<Connection> objectMaker;
33+
private Supplier<Connection> objectMaker;
3034

3135
private final AuthXEventListener authXEventListener;
3236

@@ -73,6 +77,10 @@ private Supplier<Connection> connectionSupplier() {
7377
: () -> new CacheConnection(jedisSocketFactory, clientConfig, clientSideCache);
7478
}
7579

80+
public void injectMaker(MakerInjector injector) {
81+
this.objectMaker = injector.apply(objectMaker);
82+
}
83+
7684
@Override
7785
public void activateObject(PooledObject<Connection> pooledConnection) throws Exception {
7886
// what to do ??
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
package redis.clients.jedis.mcf;
2+
3+
import java.util.Set;
4+
import java.util.concurrent.CompletableFuture;
5+
import java.util.concurrent.ConcurrentHashMap;
6+
import java.util.concurrent.CountDownLatch;
7+
import java.util.concurrent.ExecutorService;
8+
import java.util.concurrent.Executors;
9+
import java.util.concurrent.TimeUnit;
10+
import java.util.function.Supplier;
11+
12+
import org.apache.commons.pool2.impl.GenericObjectPoolConfig;
13+
import org.slf4j.Logger;
14+
import org.slf4j.LoggerFactory;
15+
16+
import redis.clients.jedis.Connection;
17+
import redis.clients.jedis.ConnectionFactory;
18+
import redis.clients.jedis.ConnectionPool;
19+
import redis.clients.jedis.HostAndPort;
20+
import redis.clients.jedis.JedisClientConfig;
21+
import redis.clients.jedis.exceptions.JedisConnectionException;
22+
23+
public class TrackingConnectionPool extends ConnectionPool {
24+
private static final Logger log = LoggerFactory.getLogger(TrackingConnectionPool.class);
25+
26+
private final Set<Connection> allCreatedObjects = ConcurrentHashMap.newKeySet();
27+
private volatile boolean forcingDisconnect;
28+
private ConnectionFactory factory;
29+
30+
// Executor for running connection creation subtasks
31+
private final ExecutorService connectionCreationExecutor = Executors.newCachedThreadPool(r -> {
32+
Thread t = new Thread(r, "connection-creator");
33+
t.setDaemon(true);
34+
return t;
35+
});
36+
37+
// Simple latch for external unblocking of all connection creation threads
38+
private volatile CountDownLatch unblockLatch = new CountDownLatch(1);
39+
40+
public TrackingConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig,
41+
GenericObjectPoolConfig<Connection> poolConfig) {
42+
super(hostAndPort, clientConfig, poolConfig);
43+
this.factory = (ConnectionFactory) this.getFactory();
44+
factory.injectMaker(this::injector);
45+
}
46+
47+
private Supplier<Connection> injector(Supplier<Connection> supplier) {
48+
return () -> make(supplier);
49+
}
50+
51+
private Connection make(Supplier<Connection> supplier) {
52+
// Create CompletableFutures for both the connection task and unblock signal
53+
CompletableFuture<Connection> connectionFuture = CompletableFuture.supplyAsync(() -> supplier.get(),
54+
connectionCreationExecutor);
55+
56+
CompletableFuture<Void> unblockFuture = CompletableFuture.runAsync(() -> {
57+
try {
58+
unblockLatch.await();
59+
} catch (InterruptedException e) {
60+
Thread.currentThread().interrupt();
61+
}
62+
}, connectionCreationExecutor);
63+
64+
try {
65+
// Wait for whichever completes first
66+
CompletableFuture.anyOf(connectionFuture, unblockFuture).join();
67+
68+
if (connectionFuture.isDone() && !connectionFuture.isCompletedExceptionally()) {
69+
return connectionFuture.join();
70+
} else {
71+
connectionFuture.cancel(true);
72+
throw new JedisConnectionException("Connection creation was cancelled due to forced disconnect!");
73+
}
74+
} catch (JedisConnectionException e) {
75+
connectionFuture.cancel(true);
76+
unblockFuture.cancel(true);
77+
throw e;
78+
} catch (Exception e) {
79+
connectionFuture.cancel(true);
80+
unblockFuture.cancel(true);
81+
throw new JedisConnectionException("Connection creation failed", e);
82+
}
83+
}
84+
85+
public TrackingConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig) {
86+
super(hostAndPort, clientConfig);
87+
this.factory = (ConnectionFactory) this.getFactory();
88+
}
89+
90+
@Override
91+
public Connection getResource() {
92+
if (forcingDisconnect) {
93+
throw new JedisConnectionException("Forced disconnect in progress!");
94+
}
95+
96+
Connection conn = super.getResource();
97+
allCreatedObjects.add(conn);
98+
return conn;
99+
}
100+
101+
@Override
102+
public void returnResource(final Connection resource) {
103+
super.returnResource(resource);
104+
allCreatedObjects.remove(resource);
105+
}
106+
107+
@Override
108+
public void returnBrokenResource(final Connection resource) {
109+
if (forcingDisconnect) {
110+
super.returnResource(resource);
111+
} else {
112+
super.returnBrokenResource(resource);
113+
}
114+
allCreatedObjects.remove(resource);
115+
}
116+
117+
public void forceDisconnect() {
118+
this.forcingDisconnect = true;
119+
120+
// First, unblock any pending connection creation
121+
unblockConnectionCreation();
122+
123+
this.clear();
124+
for (Connection connection : allCreatedObjects) {
125+
try {
126+
connection.forceDisconnect();
127+
} catch (Exception e) {
128+
log.warn("Error while force disconnecting connection: " + connection.toIdentityString());
129+
}
130+
}
131+
this.clear();
132+
this.forcingDisconnect = false;
133+
}
134+
135+
/**
136+
* Externally unblock ALL waiting connection creation threads.
137+
*/
138+
public void unblockConnectionCreation() {
139+
CountDownLatch oldLatch = unblockLatch;
140+
unblockLatch = new CountDownLatch(1); // Reset first
141+
oldLatch.countDown(); // Then signal old one
142+
log.info("Externally unblocked waiting connection creation threads");
143+
}
144+
145+
@Override
146+
public void close() {
147+
// Shutdown the connection creation executor
148+
connectionCreationExecutor.shutdown();
149+
try {
150+
if (!connectionCreationExecutor.awaitTermination(1, TimeUnit.SECONDS)) {
151+
connectionCreationExecutor.shutdownNow();
152+
}
153+
} catch (InterruptedException e) {
154+
connectionCreationExecutor.shutdownNow();
155+
Thread.currentThread().interrupt();
156+
}
157+
158+
super.close();
159+
}
160+
}

src/main/java/redis/clients/jedis/providers/MultiClusterPooledConnectionProvider.java

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
import redis.clients.jedis.mcf.HealthStatusManager;
4242
import redis.clients.jedis.mcf.StatusTracker;
4343
import redis.clients.jedis.mcf.SwitchReason;
44+
import redis.clients.jedis.mcf.TrackingConnectionPool;
4445
import redis.clients.jedis.MultiClusterClientConfig.StrategySupplier;
4546

4647
import redis.clients.jedis.util.Pool;
@@ -292,11 +293,11 @@ private void addClusterInternal(MultiClusterClientConfig multiClusterClientConfi
292293
circuitBreakerEventPublisher.onFailureRateExceeded(event -> log.error(String.valueOf(event)));
293294
circuitBreakerEventPublisher.onSlowCallRateExceeded(event -> log.error(String.valueOf(event)));
294295

295-
ConnectionPool pool;
296+
TrackingConnectionPool pool;
296297
if (poolConfig != null) {
297-
pool = new ConnectionPool(config.getHostAndPort(), config.getJedisClientConfig(), poolConfig);
298+
pool = new TrackingConnectionPool(config.getHostAndPort(), config.getJedisClientConfig(), poolConfig);
298299
} else {
299-
pool = new ConnectionPool(config.getHostAndPort(), config.getJedisClientConfig());
300+
pool = new TrackingConnectionPool(config.getHostAndPort(), config.getJedisClientConfig());
300301
}
301302
Cluster cluster = new Cluster(pool, retry, circuitBreaker, config.getWeight(), multiClusterClientConfig);
302303
multiClusterMap.put(config.getHostAndPort(), cluster);
@@ -546,6 +547,7 @@ private boolean setActiveCluster(Cluster cluster, boolean validateConnection) {
546547
oldCluster.circuitBreaker.getName());
547548
oldCluster.forceDisconnect();
548549
log.info("Disconnected all active connections in old cluster: {}", oldCluster.circuitBreaker.getName());
550+
549551
}
550552
return switched;
551553

@@ -630,7 +632,7 @@ public List<Class<? extends Throwable>> getFallbackExceptionList() {
630632

631633
public static class Cluster {
632634

633-
private final ConnectionPool connectionPool;
635+
private final TrackingConnectionPool connectionPool;
634636
private final Retry retry;
635637
private final CircuitBreaker circuitBreaker;
636638
private final float weight;
@@ -643,7 +645,7 @@ public static class Cluster {
643645
private volatile long gracePeriodEndsAt = 0;
644646
private final Logger log = LoggerFactory.getLogger(getClass());
645647

646-
private Cluster(ConnectionPool connectionPool, Retry retry, CircuitBreaker circuitBreaker, float weight,
648+
private Cluster(TrackingConnectionPool connectionPool, Retry retry, CircuitBreaker circuitBreaker, float weight,
647649
MultiClusterClientConfig multiClusterClientConfig) {
648650
this.connectionPool = connectionPool;
649651
this.retry = retry;
@@ -741,8 +743,7 @@ public boolean isFailbackSupported() {
741743
}
742744

743745
public void forceDisconnect() {
744-
log.info("Forcing disconnect of all active connections in old cluster: {}", circuitBreaker.getName());
745-
// TODO: disconnect all active connections here
746+
connectionPool.forceDisconnect();
746747
}
747748

748749
public void close() {

src/test/java/redis/clients/jedis/mcf/PeriodicFailbackTest.java

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import static org.junit.jupiter.api.Assertions.*;
44
import static org.mockito.Mockito.*;
5+
import static redis.clients.jedis.providers.MultiClusterPooledConnectionProviderHelper.onHealthStatusChange;
56

67
import org.junit.jupiter.api.BeforeEach;
78
import org.junit.jupiter.api.Test;
@@ -92,7 +93,7 @@ void testPeriodicFailbackCheckWithHealthyCluster() throws InterruptedException {
9293
assertEquals(provider.getCluster(endpoint2), provider.getCluster());
9394

9495
// Make cluster2 unhealthy to force failover to cluster1
95-
MultiClusterPooledConnectionProviderHelper.onHealthStatusChange(provider, endpoint2, HealthStatus.HEALTHY, HealthStatus.UNHEALTHY);
96+
onHealthStatusChange(provider, endpoint2, HealthStatus.HEALTHY, HealthStatus.UNHEALTHY);
9697

9798
// Should now be on cluster1 (cluster2 is in grace period)
9899
assertEquals(provider.getCluster(endpoint1), provider.getCluster());
@@ -101,7 +102,7 @@ void testPeriodicFailbackCheckWithHealthyCluster() throws InterruptedException {
101102
assertTrue(provider.getCluster(endpoint2).isInGracePeriod());
102103

103104
// Make cluster2 healthy again (but it's still in grace period)
104-
MultiClusterPooledConnectionProviderHelper.onHealthStatusChange(provider, endpoint2, HealthStatus.UNHEALTHY, HealthStatus.HEALTHY);
105+
onHealthStatusChange(provider, endpoint2, HealthStatus.UNHEALTHY, HealthStatus.HEALTHY);
105106

106107
// Trigger periodic check immediately - should still be on cluster1
107108
MultiClusterPooledConnectionProviderHelper.periodicFailbackCheck(provider);
@@ -137,13 +138,13 @@ void testPeriodicFailbackCheckWithFailbackDisabled() throws InterruptedException
137138
assertEquals(provider.getCluster(endpoint2), provider.getCluster());
138139

139140
// Make cluster2 unhealthy to force failover to cluster1
140-
MultiClusterPooledConnectionProviderHelper.onHealthStatusChange(provider, endpoint2, HealthStatus.HEALTHY, HealthStatus.UNHEALTHY);
141+
onHealthStatusChange(provider, endpoint2, HealthStatus.HEALTHY, HealthStatus.UNHEALTHY);
141142

142143
// Should now be on cluster1
143144
assertEquals(provider.getCluster(endpoint1), provider.getCluster());
144145

145146
// Make cluster2 healthy again
146-
MultiClusterPooledConnectionProviderHelper.onHealthStatusChange(provider, endpoint2, HealthStatus.UNHEALTHY, HealthStatus.HEALTHY);
147+
onHealthStatusChange(provider, endpoint2, HealthStatus.UNHEALTHY, HealthStatus.HEALTHY);
147148

148149
// Wait for stability period
149150
Thread.sleep(100);
@@ -181,20 +182,20 @@ void testPeriodicFailbackCheckSelectsHighestWeightCluster() throws InterruptedEx
181182
assertEquals(provider.getCluster(endpoint3), provider.getCluster());
182183

183184
// Make cluster3 unhealthy to force failover to cluster2 (next highest weight)
184-
MultiClusterPooledConnectionProviderHelper.onHealthStatusChange(provider, endpoint3, HealthStatus.HEALTHY, HealthStatus.UNHEALTHY);
185+
onHealthStatusChange(provider, endpoint3, HealthStatus.HEALTHY, HealthStatus.UNHEALTHY);
185186

186187
// Should now be on cluster2 (weight 2.0f, higher than cluster1's 1.0f)
187188
assertEquals(provider.getCluster(endpoint2), provider.getCluster());
188189

189190
// Make cluster2 unhealthy to force failover to cluster1
190-
MultiClusterPooledConnectionProviderHelper.onHealthStatusChange(provider, endpoint2, HealthStatus.HEALTHY, HealthStatus.UNHEALTHY);
191+
onHealthStatusChange(provider, endpoint2, HealthStatus.HEALTHY, HealthStatus.UNHEALTHY);
191192

192193
// Should now be on cluster1 (only healthy cluster left)
193194
assertEquals(provider.getCluster(endpoint1), provider.getCluster());
194195

195196
// Make cluster2 and cluster3 healthy again
196-
MultiClusterPooledConnectionProviderHelper.onHealthStatusChange(provider, endpoint2, HealthStatus.UNHEALTHY, HealthStatus.HEALTHY);
197-
MultiClusterPooledConnectionProviderHelper.onHealthStatusChange(provider, endpoint3, HealthStatus.UNHEALTHY, HealthStatus.HEALTHY);
197+
onHealthStatusChange(provider, endpoint2, HealthStatus.UNHEALTHY, HealthStatus.HEALTHY);
198+
onHealthStatusChange(provider, endpoint3, HealthStatus.UNHEALTHY, HealthStatus.HEALTHY);
198199

199200
// Wait for grace period to expire
200201
Thread.sleep(150);

0 commit comments

Comments
 (0)