Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions src/main/java/redis/clients/jedis/JedisClusterInfoCache.java
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,27 @@ public Map<String, ConnectionPool> getNodes() {
}
}

public Map<String, ConnectionPool> getPrimaryNodes() {
r.lock();
try {
Map<String, ConnectionPool> primaryNodes = new HashMap<>();
Set<ConnectionPool> addedPools = new HashSet<>();

for (int slot = 0; slot < slots.length; slot++) {
ConnectionPool pool = slots[slot];
if (pool != null && addedPools.add(pool)) {
HostAndPort hostAndPort = slotNodes[slot];
if (hostAndPort != null) {
primaryNodes.put(getNodeKey(hostAndPort), pool);
}
}
}
return primaryNodes;
} finally {
r.unlock();
}
}

public List<ConnectionPool> getShuffledNodesPool() {
r.lock();
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,18 @@

import java.time.Duration;
import java.time.Instant;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import redis.clients.jedis.CommandArguments;
import redis.clients.jedis.CommandObject;
import redis.clients.jedis.Connection;
import redis.clients.jedis.ConnectionPool;
Expand All @@ -21,6 +27,13 @@

public class ClusterCommandExecutor implements CommandExecutor {

private static final Set<String> PRIMARY_ONLY_COMMANDS;
static {
PRIMARY_ONLY_COMMANDS = Collections.unmodifiableSet(new HashSet<>(Arrays.asList(
"FUNCTION_DELETE", "FUNCTION_FLUSH", "FUNCTION_LOAD", "FUNCTION_RESTORE", "FUNCTION_KILL"
)));
}

private final Logger log = LoggerFactory.getLogger(getClass());

public final ClusterConnectionProvider provider;
Expand All @@ -45,7 +58,9 @@ public void close() {

@Override
public final <T> T broadcastCommand(CommandObject<T> commandObject) {
Map<String, ConnectionPool> connectionMap = provider.getConnectionMap();
Map<String, ConnectionPool> connectionMap = requiresPrimaryOnly(commandObject)
? provider.getPrimaryConnectionMap()
: provider.getConnectionMap();

boolean isErrored = false;
T reply = null;
Expand Down Expand Up @@ -76,6 +91,45 @@ public final <T> T broadcastCommand(CommandObject<T> commandObject) {
return reply;
}

private boolean requiresPrimaryOnly(CommandObject<?> commandObject) {
try {
String commandName = new String(commandObject.getArguments().getCommand().getRaw());

if ("FUNCTION".equals(commandName)) {
CommandArguments args = commandObject.getArguments();
Iterator<?> iterator = args.iterator();

if (iterator.hasNext()) {
iterator.next();

if (iterator.hasNext()) {
Object subCommandObj = iterator.next();

if (subCommandObj != null) {
try {
java.lang.reflect.Method getRawMethod = subCommandObj.getClass().getMethod("getRaw");
Object rawValue = getRawMethod.invoke(subCommandObj);

if (rawValue instanceof byte[]) {
String subCommand = new String((byte[]) rawValue);
String fullCommand = "FUNCTION_" + subCommand.toUpperCase();
return PRIMARY_ONLY_COMMANDS.contains(fullCommand);
}
} catch (Exception e) {
return false;
}
}
}
}
return false;
}

return PRIMARY_ONLY_COMMANDS.contains(commandName);
} catch (Exception e) {
return false;
}
}

@Override
public final <T> T executeCommand(CommandObject<T> commandObject) {
return doExecuteCommand(commandObject, false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ public Map<String, ConnectionPool> getNodes() {
return cache.getNodes();
}

public Map<String, ConnectionPool> getPrimaryNodes() {
return cache.getPrimaryNodes();
}

public HostAndPort getNode(int slot) {
return slot >= 0 ? cache.getSlotNode(slot) : null;
}
Expand Down Expand Up @@ -209,4 +213,8 @@ public Connection getReplicaConnectionFromSlot(int slot) {
public Map<String, ConnectionPool> getConnectionMap() {
return Collections.unmodifiableMap(getNodes());
}

public Map<String, ConnectionPool> getPrimaryConnectionMap() {
return Collections.unmodifiableMap(getPrimaryNodes());
}
}
42 changes: 42 additions & 0 deletions src/test/java/redis/clients/jedis/ClusterCommandExecutorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,13 @@
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.never;


import java.time.Duration;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.LongConsumer;
import org.hamcrest.MatcherAssert;
Expand Down Expand Up @@ -357,4 +362,41 @@ protected void sleep(long sleepMillis) {
inOrder.verifyNoMoreInteractions();
assertEquals(0L, totalSleepMs.get());
}

@Test
public void runFunctionCommandUsesPrimaryOnly() {
ClusterConnectionProvider connectionHandler = mock(ClusterConnectionProvider.class);

Map<String, ConnectionPool> primaryNodes = new HashMap<>();
primaryNodes.put("127.0.0.1:6379", mock(ConnectionPool.class));

Map<String, ConnectionPool> allNodes = new HashMap<>();
allNodes.put("127.0.0.1:6379", mock(ConnectionPool.class));
allNodes.put("127.0.0.1:6380", mock(ConnectionPool.class));

when(connectionHandler.getPrimaryConnectionMap()).thenReturn(primaryNodes);
when(connectionHandler.getConnectionMap()).thenReturn(allNodes);

Connection conn = mock(Connection.class);
for (ConnectionPool pool : primaryNodes.values()) {
when(pool.getResource()).thenReturn(conn);
}

ClusterCommandExecutor executor = new ClusterCommandExecutor(connectionHandler, 10, Duration.ZERO) {
@Override
public <T> T execute(Connection connection, CommandObject<T> commandObject) {
return (T) "mylib";
}
};

CommandObjects commandObjects = new CommandObjects();
CommandObject<String> functionLoadReplaceCommand = commandObjects.functionLoadReplace("script");

String result = executor.broadcastCommand(functionLoadReplaceCommand);

assertEquals("mylib", result);

verify(connectionHandler).getPrimaryConnectionMap();
verify(connectionHandler, never()).getConnectionMap();
}
}
Loading