Skip to content

Commit 5e6ee97

Browse files
Asynchronous FlightServerChannel to avoid thread contention per request (opensearch-project#19403)
* Async flush of batches at server for Stream Transport Signed-off-by: Rishabh Maurya <[email protected]> * Move BatchTask to record Signed-off-by: Rishabh Maurya <[email protected]> --------- Signed-off-by: Rishabh Maurya <[email protected]>
1 parent 1a7aa0d commit 5e6ee97

File tree

7 files changed

+304
-54
lines changed

7 files changed

+304
-54
lines changed

plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/ServerConfig.java

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import io.netty.channel.nio.NioEventLoopGroup;
3131
import io.netty.channel.socket.nio.NioServerSocketChannel;
3232
import io.netty.channel.socket.nio.NioSocketChannel;
33+
import io.netty.util.NettyRuntime;
3334

3435
/**
3536
* Configuration class for OpenSearch Flight server settings.
@@ -87,6 +88,13 @@ public ServerConfig() {}
8788
Setting.Property.NodeScope
8889
);
8990

91+
static final Setting<Integer> FLIGHT_EVENT_LOOP_THREADS = Setting.intSetting(
92+
"flight.event_loop.threads",
93+
Math.max(1, NettyRuntime.availableProcessors() * 2),
94+
1,
95+
Setting.Property.NodeScope
96+
);
97+
9098
static final Setting<Boolean> ARROW_SSL_ENABLE = Setting.boolSetting(
9199
"flight.ssl.enable",
92100
false, // TODO: get default from security enabled
@@ -112,6 +120,7 @@ public ServerConfig() {}
112120
private static int threadPoolMin;
113121
private static int threadPoolMax;
114122
private static TimeValue keepAlive;
123+
private static int eventLoopThreads;
115124

116125
/**
117126
* Initializes the server configuration with the provided settings.
@@ -134,6 +143,7 @@ public static void init(Settings settings) {
134143
threadPoolMin = FLIGHT_THREAD_POOL_MIN_SIZE.get(settings);
135144
threadPoolMax = FLIGHT_THREAD_POOL_MAX_SIZE.get(settings);
136145
keepAlive = FLIGHT_THREAD_POOL_KEEP_ALIVE.get(settings);
146+
eventLoopThreads = FLIGHT_EVENT_LOOP_THREADS.get(settings);
137147
}
138148

139149
/**
@@ -172,6 +182,15 @@ public static ScalingExecutorBuilder getClientExecutorBuilder() {
172182
return new ScalingExecutorBuilder(FLIGHT_CLIENT_THREAD_POOL_NAME, threadPoolMin, threadPoolMax, keepAlive);
173183
}
174184

185+
/**
186+
* Gets the configured number of event loop threads.
187+
*
188+
* @return The number of event loop threads
189+
*/
190+
public static int getEventLoopThreads() {
191+
return eventLoopThreads;
192+
}
193+
175194
/**
176195
* Returns a list of all settings managed by this configuration class.
177196
*
@@ -184,7 +203,8 @@ public static List<Setting<?>> getSettings() {
184203
ARROW_ENABLE_NULL_CHECK_FOR_GET,
185204
ARROW_ENABLE_DEBUG_ALLOCATOR,
186205
ARROW_ENABLE_UNSAFE_MEMORY_ACCESS,
187-
ARROW_SSL_ENABLE
206+
ARROW_SSL_ENABLE,
207+
FLIGHT_EVENT_LOOP_THREADS
188208
)
189209
);
190210
}

plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/transport/ArrowFlightProducer.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,13 @@ public void getStream(CallContext context, Ticket ticket, ServerStreamListener l
6161
// https://github.com/apache/arrow/issues/38668
6262
executor.execute(() -> {
6363
FlightCallTracker callTracker = statsCollector.createServerCallTracker();
64-
FlightServerChannel channel = new FlightServerChannel(listener, allocator, middleware, callTracker);
64+
FlightServerChannel channel = new FlightServerChannel(
65+
listener,
66+
allocator,
67+
middleware,
68+
callTracker,
69+
flightTransport.getNextFlightExecutor()
70+
);
6571
try {
6672
BytesArray buf = new BytesArray(ticket.getBytes());
6773
callTracker.recordRequestBytes(buf.ramBytesUsed());

plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/transport/FlightOutboundHandler.java

Lines changed: 160 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import org.opensearch.Version;
2121
import org.opensearch.cluster.node.DiscoveryNode;
2222
import org.opensearch.common.io.stream.BytesStreamOutput;
23+
import org.opensearch.common.util.concurrent.ThreadContext;
2324
import org.opensearch.core.common.bytes.BytesReference;
2425
import org.opensearch.core.transport.TransportResponse;
2526
import org.opensearch.threadpool.ThreadPool;
@@ -89,85 +90,195 @@ public void sendResponse(
8990
);
9091
}
9192

92-
/** This needs to be synchronized for the cases when multiple batches are written concurrently,
93-
* as VectorSchemaRoot is shared across batches **/
94-
public synchronized void sendResponseBatch(
93+
@Override
94+
public void sendErrorResponse(
95+
Version nodeVersion,
96+
Set<String> features,
97+
TcpChannel channel,
98+
long requestId,
99+
String action,
100+
Exception error
101+
) throws IOException {
102+
throw new UnsupportedOperationException(
103+
"sendResponse() is not supported for streaming requests in FlightOutboundHandler; use sendResponseBatch()"
104+
);
105+
}
106+
107+
public void sendResponseBatch(
95108
final Version nodeVersion,
96109
final Set<String> features,
97110
final TcpChannel channel,
111+
final FlightTransportChannel transportChannel,
98112
final long requestId,
99113
final String action,
100114
final TransportResponse response,
101115
final boolean compress,
102116
final boolean isHandshake
103117
) throws IOException {
104-
// TODO add support for compression
118+
ThreadContext.StoredContext storedContext = threadPool.getThreadContext().stashContext();
119+
BatchTask task = new BatchTask(
120+
nodeVersion,
121+
features,
122+
channel,
123+
transportChannel,
124+
requestId,
125+
action,
126+
response,
127+
compress,
128+
isHandshake,
129+
false,
130+
false,
131+
null,
132+
storedContext
133+
);
134+
105135
if (!(channel instanceof FlightServerChannel flightChannel)) {
106-
throw new IllegalStateException("Expected FlightServerChannel, got " + channel.getClass().getName());
136+
messageListener.onResponseSent(requestId, action, new IllegalStateException("Expected FlightServerChannel"));
137+
return;
138+
}
139+
140+
flightChannel.getExecutor().execute(() -> {
141+
try (BatchTask ignored = task) {
142+
processBatchTask(task);
143+
} catch (Exception e) {
144+
messageListener.onResponseSent(requestId, action, e);
145+
}
146+
});
147+
}
148+
149+
private void processBatchTask(BatchTask task) {
150+
task.storedContext().restore();
151+
if (!(task.channel() instanceof FlightServerChannel flightChannel)) {
152+
Exception error = new IllegalStateException("Expected FlightServerChannel, got " + task.channel().getClass().getName());
153+
messageListener.onResponseSent(task.requestId(), task.action(), error);
154+
return;
107155
}
156+
108157
try {
109158
try (VectorStreamOutput out = new VectorStreamOutput(flightChannel.getAllocator(), flightChannel.getRoot())) {
110-
response.writeTo(out);
111-
flightChannel.sendBatch(getHeaderBuffer(requestId, nodeVersion, features), out);
112-
messageListener.onResponseSent(requestId, action, response);
159+
task.response().writeTo(out);
160+
flightChannel.sendBatch(getHeaderBuffer(task.requestId(), task.nodeVersion(), task.features()), out);
161+
messageListener.onResponseSent(task.requestId(), task.action(), task.response());
113162
}
114-
} catch (StreamException e) {
115-
messageListener.onResponseSent(requestId, action, e);
116-
// Let StreamException propagate as is - it will be converted to FlightRuntimeException at a higher level
117-
throw e;
118163
} catch (FlightRuntimeException e) {
119-
messageListener.onResponseSent(requestId, action, e);
120-
throw FlightErrorMapper.fromFlightException(e);
164+
messageListener.onResponseSent(task.requestId(), task.action(), FlightErrorMapper.fromFlightException(e));
121165
} catch (Exception e) {
122-
messageListener.onResponseSent(requestId, action, e);
123-
throw e;
166+
messageListener.onResponseSent(task.requestId(), task.action(), e);
124167
}
125168
}
126169

127170
public void completeStream(
128171
final Version nodeVersion,
129172
final Set<String> features,
130173
final TcpChannel channel,
174+
final FlightTransportChannel transportChannel,
131175
final long requestId,
132176
final String action
133177
) {
178+
ThreadContext.StoredContext storedContext = threadPool.getThreadContext().stashContext();
179+
BatchTask completeTask = new BatchTask(
180+
nodeVersion,
181+
features,
182+
channel,
183+
transportChannel,
184+
requestId,
185+
action,
186+
TransportResponse.Empty.INSTANCE,
187+
false,
188+
false,
189+
true,
190+
false,
191+
null,
192+
storedContext
193+
);
194+
134195
if (!(channel instanceof FlightServerChannel flightChannel)) {
135-
throw new IllegalStateException("Expected FlightServerChannel, got " + channel.getClass().getName());
196+
messageListener.onResponseSent(requestId, action, new IllegalStateException("Expected FlightServerChannel"));
197+
return;
136198
}
199+
200+
flightChannel.getExecutor().execute(() -> {
201+
try (BatchTask ignored = completeTask) {
202+
processCompleteTask(completeTask);
203+
} catch (Exception e) {
204+
messageListener.onResponseSent(requestId, action, e);
205+
}
206+
});
207+
}
208+
209+
private void processCompleteTask(BatchTask task) {
210+
task.storedContext().restore();
211+
if (!(task.channel() instanceof FlightServerChannel flightChannel)) {
212+
Exception error = new IllegalStateException("Expected FlightServerChannel, got " + task.channel().getClass().getName());
213+
messageListener.onResponseSent(task.requestId(), task.action(), error);
214+
return;
215+
}
216+
137217
try {
138218
flightChannel.completeStream();
139-
messageListener.onResponseSent(requestId, action, TransportResponse.Empty.INSTANCE);
140-
} catch (FlightRuntimeException e) {
141-
messageListener.onResponseSent(requestId, action, e);
142-
throw FlightErrorMapper.fromFlightException(e);
219+
messageListener.onResponseSent(task.requestId(), task.action(), TransportResponse.Empty.INSTANCE);
143220
} catch (Exception e) {
144-
messageListener.onResponseSent(requestId, action, e);
145-
throw e;
221+
messageListener.onResponseSent(task.requestId(), task.action(), e);
146222
}
147223
}
148224

149-
@Override
150225
public void sendErrorResponse(
151226
final Version nodeVersion,
152227
final Set<String> features,
153228
final TcpChannel channel,
229+
final FlightTransportChannel transportChannel,
154230
final long requestId,
155231
final String action,
156232
final Exception error
157-
) throws IOException {
158-
if (!(channel instanceof FlightServerChannel flightServerChannel)) {
159-
throw new IllegalStateException("Expected FlightServerChannel, got " + channel.getClass().getName());
233+
) {
234+
ThreadContext.StoredContext storedContext = threadPool.getThreadContext().stashContext();
235+
BatchTask errorTask = new BatchTask(
236+
nodeVersion,
237+
features,
238+
channel,
239+
transportChannel,
240+
requestId,
241+
action,
242+
null,
243+
false,
244+
false,
245+
false,
246+
true,
247+
error,
248+
storedContext
249+
);
250+
251+
if (!(channel instanceof FlightServerChannel flightChannel)) {
252+
messageListener.onResponseSent(requestId, action, new IllegalStateException("Expected FlightServerChannel"));
253+
return;
254+
}
255+
256+
flightChannel.getExecutor().execute(() -> {
257+
try (BatchTask ignored = errorTask) {
258+
processErrorTask(errorTask);
259+
} catch (Exception e) {
260+
messageListener.onResponseSent(requestId, action, e);
261+
}
262+
});
263+
}
264+
265+
private void processErrorTask(BatchTask task) {
266+
task.storedContext().restore();
267+
if (!(task.channel() instanceof FlightServerChannel flightServerChannel)) {
268+
Exception error = new IllegalStateException("Expected FlightServerChannel, got " + task.channel().getClass().getName());
269+
messageListener.onResponseSent(task.requestId(), task.action(), error);
270+
return;
160271
}
272+
161273
try {
162-
Exception flightError = error;
163-
if (error instanceof StreamException) {
164-
flightError = FlightErrorMapper.toFlightException((StreamException) error);
274+
Exception flightError = task.error();
275+
if (task.error() instanceof StreamException) {
276+
flightError = FlightErrorMapper.toFlightException((StreamException) task.error());
165277
}
166-
flightServerChannel.sendError(getHeaderBuffer(requestId, version, features), flightError);
167-
messageListener.onResponseSent(requestId, action, error);
278+
flightServerChannel.sendError(getHeaderBuffer(task.requestId(), task.nodeVersion(), task.features()), flightError);
279+
messageListener.onResponseSent(task.requestId(), task.action(), task.error());
168280
} catch (Exception e) {
169-
messageListener.onResponseSent(requestId, action, e);
170-
throw e;
281+
messageListener.onResponseSent(task.requestId(), task.action(), e);
171282
}
172283
}
173284

@@ -197,4 +308,19 @@ private ByteBuffer getHeaderBuffer(long requestId, Version nodeVersion, Set<Stri
197308
return ByteBuffer.wrap(headerBytes.toBytesRef().bytes);
198309
}
199310
}
311+
312+
record BatchTask(Version nodeVersion, Set<String> features, TcpChannel channel, FlightTransportChannel transportChannel, long requestId,
313+
String action, TransportResponse response, boolean compress, boolean isHandshake, boolean isComplete, boolean isError,
314+
Exception error, ThreadContext.StoredContext storedContext) implements AutoCloseable {
315+
316+
@Override
317+
public void close() {
318+
if (storedContext != null) {
319+
storedContext.close();
320+
}
321+
if ((isComplete || isError) && transportChannel != null) {
322+
transportChannel.releaseChannel(isError);
323+
}
324+
}
325+
}
200326
}

0 commit comments

Comments
 (0)