Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion src/main/java/com/github/tonivade/resp/RespServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ public static class Builder implements Recoverable {
private String host = DEFAULT_HOST;
private int port = DEFAULT_PORT;
private CommandSuite commands = new CommandSuite();
private boolean parallelExecution = false;

public Builder host(String host) {
this.host = host;
Expand All @@ -254,8 +255,19 @@ public Builder commands(CommandSuite commands) {
return this;
}

/**
* Enables parallel command execution on Netty I/O threads, bypassing the
* single-thread RxJava scheduler. State uses ConcurrentHashMap for
* thread safety. Best for stateless, thread-safe commands where
* maximum throughput and lowest latency are required.
*/
public Builder parallelExecution() {
this.parallelExecution = true;
return this;
}

public RespServer build() {
return new RespServer(new RespServerContext(host, port, commands));
return new RespServer(new RespServerContext(host, port, commands, SessionListener.nullListener(), parallelExecution));
}
}
}
60 changes: 46 additions & 14 deletions src/main/java/com/github/tonivade/resp/RespServerContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import static com.github.tonivade.resp.SessionListener.nullListener;
import static java.util.concurrent.Executors.newSingleThreadExecutor;

import java.util.HashMap;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
Expand All @@ -33,14 +34,10 @@ public class RespServerContext implements ServerContext {

private static final Logger LOGGER = LoggerFactory.getLogger(RespServerContext.class);

private final StateHolder state = new StateHolder();
private final StateHolder state;
private final ConcurrentHashMap<String, Session> clients = new ConcurrentHashMap<>();
private final Scheduler scheduler = Schedulers.from(newSingleThreadExecutor(runnable -> {
Thread thread = new Thread(runnable);
thread.setName(RESP_SERVER);
thread.setDaemon(true);
return thread;
}));
private final Scheduler scheduler;
private final boolean parallelExecution;

private final String host;
private final int port;
Expand All @@ -52,10 +49,24 @@ public RespServerContext(String host, int port, CommandSuite commands) {
}

public RespServerContext(String host, int port, CommandSuite commands, SessionListener sessionListener) {
this(host, port, commands, sessionListener, false);
}

public RespServerContext(String host, int port, CommandSuite commands,
SessionListener sessionListener, boolean parallelExecution) {
this.host = checkNonEmpty(host);
this.port = checkRange(port, 1024, 65535);
this.commands = checkNonNull(commands);
this.sessionListener = checkNonNull(sessionListener);
this.parallelExecution = parallelExecution;
if (parallelExecution) {
this.state = new StateHolder(new ConcurrentHashMap<>());
this.scheduler = null;
} else {
this.state = new StateHolder(new HashMap<>());
this.scheduler = Schedulers.from(
newSingleThreadExecutor(runnable -> newDaemonThread(runnable, RESP_SERVER)));
}
}

public void start() {
Expand All @@ -64,7 +75,9 @@ public void start() {

public void stop() {
clear();
scheduler.shutdown();
if (scheduler != null) {
scheduler.shutdown();
}
}

@Override
Expand Down Expand Up @@ -114,12 +127,21 @@ void processCommand(Request request) {
LOGGER.debug("received command: {}", request);

var command = getCommand(request.getCommand());
try {
enqueue(Observable.fromCallable(() -> executeCommand(command, request)))
.subscribe(response -> processResponse(request, response),
ex -> LOGGER.error("error executing command: " + request, ex));
} catch (RuntimeException ex) {
LOGGER.error("error executing command: " + request, ex);
if (parallelExecution) {
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use the condition scheduler == null

try {
RedisToken response = executeCommand(command, request);
processResponse(request, response);
} catch (RuntimeException ex) {
LOGGER.error("error executing command: " + request, ex);
}
} else {
try {
enqueue(Observable.fromCallable(() -> executeCommand(command, request)))
.subscribe(response -> processResponse(request, response),
ex -> LOGGER.error("error executing command: " + request, ex));
} catch (RuntimeException ex) {
LOGGER.error("error executing command: " + request, ex);
}
}
}

Expand All @@ -143,6 +165,9 @@ protected RedisToken executeCommand(RespCommand command, Request request) {
}

protected <T> Observable<T> enqueue(Observable<T> observable) {
if (scheduler == null) {
return observable;
}
return observable.subscribeOn(scheduler);
}

Expand All @@ -157,4 +182,11 @@ private void clear() {
clients.clear();
state.clear();
}

private static Thread newDaemonThread(Runnable runnable, String name) {
Thread thread = new Thread(runnable);
thread.setName(name);
thread.setDaemon(true);
return thread;
}
}
10 changes: 9 additions & 1 deletion src/main/java/com/github/tonivade/resp/StateHolder.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,15 @@

public class StateHolder {

private final Map<String, Object> state = new HashMap<>();
private final Map<String, Object> state;

public StateHolder() {
this(new HashMap<>());
}

public StateHolder(Map<String, Object> state) {
this.state = state;
}

@SuppressWarnings("unchecked")
public <T> Optional<T> getValue(String key) {
Expand Down
28 changes: 28 additions & 0 deletions src/test/java/com/github/tonivade/resp/RespServerContextTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,34 @@ public void requireCallback() {
assertThrows(IllegalArgumentException.class, () -> new RespServerContext(HOST, PORT, null));
}

@Test
public void processCommandParallelExecution() {
var parallelContext = new RespServerContext(HOST, PORT, commands,
SessionListener.nullListener(), true);
Request request = new DefaultRequest(parallelContext, session, safeString("test"), Collections.emptyList());
when(commands.getCommand(request.getCommand())).thenReturn(respCommand);
when(respCommand.execute(request)).thenReturn(nullString());

parallelContext.processCommand(request);

verify(respCommand, timeout(1000)).execute(request);
verify(session, timeout(1000)).publish(nullString());
}

@Test
public void processCommandParallelExecutionException() {
var parallelContext = new RespServerContext(HOST, PORT, commands,
SessionListener.nullListener(), true);
Request request = new DefaultRequest(parallelContext, session, safeString("test"), Collections.emptyList());
when(commands.getCommand(request.getCommand())).thenReturn(respCommand);
doThrow(RuntimeException.class).when(respCommand).execute(request);

parallelContext.processCommand(request);

verify(respCommand, timeout(1000)).execute(request);
verify(session, timeout(1000).atLeast(0)).publish(any());
}

private Request newRequest(String command) {
return new DefaultRequest(serverContext, session, safeString(command), Collections.emptyList());
}
Expand Down