Skip to content

Commit 7d71a80

Browse files
Merge pull request #292 from rabbitmq/rabbitmq-java-client-290
Add option to ensure RPC reply is for the current request
2 parents dace543 + 0963ee3 commit 7d71a80

File tree

7 files changed

+210
-17
lines changed

7 files changed

+210
-17
lines changed

src/main/java/com/rabbitmq/client/ConnectionFactory.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,13 @@ public class ConnectionFactory implements Cloneable {
127127
*/
128128
private int channelRpcTimeout = DEFAULT_CHANNEL_RPC_TIMEOUT;
129129

130+
/**
131+
* Whether or not channels check the reply type of an RPC call.
132+
* Default is false.
133+
* @since 4.2.0
134+
*/
135+
private boolean channelShouldCheckRpcResponseType = false;
136+
130137
/** @return the default host to use for connections */
131138
public String getHost() {
132139
return host;
@@ -958,6 +965,7 @@ public ConnectionParams params(ExecutorService consumerWorkServiceExecutor) {
958965
result.setShutdownExecutor(shutdownExecutor);
959966
result.setHeartbeatExecutor(heartbeatExecutor);
960967
result.setChannelRpcTimeout(channelRpcTimeout);
968+
result.setChannelShouldCheckRpcResponseType(channelShouldCheckRpcResponseType);
961969
return result;
962970
}
963971

@@ -1126,4 +1134,19 @@ public void setChannelRpcTimeout(int channelRpcTimeout) {
11261134
public int getChannelRpcTimeout() {
11271135
return channelRpcTimeout;
11281136
}
1137+
1138+
/**
1139+
* When set to true, channels will check the response type (e.g. queue.declare
1140+
* expects a queue.declare-ok response) of RPC calls
1141+
* and ignore those that do not match.
1142+
* Default is false.
1143+
* @param channelShouldCheckRpcResponseType
1144+
*/
1145+
public void setChannelShouldCheckRpcResponseType(boolean channelShouldCheckRpcResponseType) {
1146+
this.channelShouldCheckRpcResponseType = channelShouldCheckRpcResponseType;
1147+
}
1148+
1149+
public boolean isChannelShouldCheckRpcResponseType() {
1150+
return channelShouldCheckRpcResponseType;
1151+
}
11291152
}

src/main/java/com/rabbitmq/client/impl/AMQChannel.java

Lines changed: 102 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,11 @@
2020
import java.util.concurrent.TimeoutException;
2121

2222
import com.rabbitmq.client.*;
23+
import com.rabbitmq.client.AMQP.Basic;
24+
import com.rabbitmq.client.AMQP.Confirm;
25+
import com.rabbitmq.client.AMQP.Exchange;
26+
import com.rabbitmq.client.AMQP.Queue;
27+
import com.rabbitmq.client.AMQP.Tx;
2328
import com.rabbitmq.client.Method;
2429
import com.rabbitmq.utility.BlockingValueOrException;
2530
import org.slf4j.Logger;
@@ -66,6 +71,8 @@ public abstract class AMQChannel extends ShutdownNotifierComponent {
6671
/** Timeout for RPC calls */
6772
protected final int _rpcTimeout;
6873

74+
private final boolean _checkRpcResponseType;
75+
6976
/**
7077
* Construct a channel on the given connection, with the given channel number.
7178
* @param connection the underlying connection for this channel
@@ -78,6 +85,7 @@ public AMQChannel(AMQConnection connection, int channelNumber) {
7885
throw new IllegalArgumentException("Continuation timeout on RPC calls cannot be less than 0");
7986
}
8087
this._rpcTimeout = connection.getChannelRpcTimeout();
88+
this._checkRpcResponseType = connection.willCheckRpcResponseType();
8189
}
8290

8391
/**
@@ -153,8 +161,19 @@ public void handleCompleteInboundCommand(AMQCommand command) throws IOException
153161
// waiting RPC continuation.
154162
if (!processAsync(command)) {
155163
// The filter decided not to handle/consume the command,
156-
// so it must be some reply to an earlier RPC.
157-
RpcContinuation nextOutstandingRpc = nextOutstandingRpc();
164+
// so it must be a response to an earlier RPC.
165+
if (_checkRpcResponseType) {
166+
synchronized (_channelMutex) {
167+
// check if this reply command is intended for the current waiting request before calling nextOutstandingRpc()
168+
if (!_activeRpc.canHandleReply(command)) {
169+
// this reply command is not intended for the current waiting request
170+
// most likely a previous request timed out and this command is the reply for that.
171+
// Throw this reply command away so we don't stop the current request from waiting for its reply
172+
return;
173+
}
174+
}
175+
}
176+
final RpcContinuation nextOutstandingRpc = nextOutstandingRpc();
158177
// the outstanding RPC can be null when calling Channel#asyncRpc
159178
if(nextOutstandingRpc != null) {
160179
nextOutstandingRpc.handleCommand(command);
@@ -229,7 +248,7 @@ public AMQCommand rpc(Method m, int timeout)
229248
private AMQCommand privateRpc(Method m)
230249
throws IOException, ShutdownSignalException
231250
{
232-
SimpleBlockingRpcContinuation k = new SimpleBlockingRpcContinuation();
251+
SimpleBlockingRpcContinuation k = new SimpleBlockingRpcContinuation(m);
233252
rpc(m, k);
234253
// At this point, the request method has been sent, and we
235254
// should wait for the reply to arrive.
@@ -266,7 +285,7 @@ protected ChannelContinuationTimeoutException wrapTimeoutException(final Method
266285

267286
private AMQCommand privateRpc(Method m, int timeout)
268287
throws IOException, ShutdownSignalException, TimeoutException {
269-
SimpleBlockingRpcContinuation k = new SimpleBlockingRpcContinuation();
288+
SimpleBlockingRpcContinuation k = new SimpleBlockingRpcContinuation(m);
270289
rpc(m, k);
271290

272291
try {
@@ -384,13 +403,25 @@ public AMQConnection getConnection() {
384403

385404
public interface RpcContinuation {
386405
void handleCommand(AMQCommand command);
406+
/** @return true if the reply command can be handled for this request */
407+
boolean canHandleReply(AMQCommand command);
387408
void handleShutdownSignal(ShutdownSignalException signal);
388409
}
389410

390411
public static abstract class BlockingRpcContinuation<T> implements RpcContinuation {
391412
public final BlockingValueOrException<T, ShutdownSignalException> _blocker =
392413
new BlockingValueOrException<T, ShutdownSignalException>();
393414

415+
protected final Method request;
416+
417+
public BlockingRpcContinuation() {
418+
request = null;
419+
}
420+
421+
public BlockingRpcContinuation(final Method request) {
422+
this.request = request;
423+
}
424+
394425
@Override
395426
public void handleCommand(AMQCommand command) {
396427
_blocker.setValue(transformReply(command));
@@ -412,12 +443,79 @@ public T getReply(int timeout)
412443
return _blocker.uninterruptibleGetValue(timeout);
413444
}
414445

446+
@Override
447+
public boolean canHandleReply(AMQCommand command) {
448+
// make a best effort attempt to ensure the reply was intended for this rpc request
449+
// Ideally each rpc request would tag an id on it that could be returned and referenced on its reply.
450+
// But because that would be a very large undertaking to add passively this logic at least protects against ClassCastExceptions
451+
if (request != null) {
452+
final Method reply = command.getMethod();
453+
if (request instanceof Basic.Qos) {
454+
return reply instanceof Basic.QosOk;
455+
} else if (request instanceof Basic.Get) {
456+
return reply instanceof Basic.GetOk || reply instanceof Basic.GetEmpty;
457+
} else if (request instanceof Basic.Consume) {
458+
if (!(reply instanceof Basic.ConsumeOk))
459+
return false;
460+
// can also check the consumer tags match here. handle case where request consumer tag is empty and server-generated.
461+
final String consumerTag = ((Basic.Consume)request).getConsumerTag();
462+
return consumerTag == null || consumerTag.equals("") || consumerTag.equals(((Basic.ConsumeOk)reply).getConsumerTag());
463+
} else if (request instanceof Basic.Cancel) {
464+
if (!(reply instanceof Basic.CancelOk))
465+
return false;
466+
// can also check the consumer tags match here
467+
return ((Basic.Cancel)request).getConsumerTag().equals(((Basic.CancelOk)reply).getConsumerTag());
468+
} else if (request instanceof Basic.Recover) {
469+
return reply instanceof Basic.RecoverOk;
470+
} else if (request instanceof Exchange.Declare) {
471+
return reply instanceof Exchange.DeclareOk;
472+
} else if (request instanceof Exchange.Delete) {
473+
return reply instanceof Exchange.DeleteOk;
474+
} else if (request instanceof Exchange.Bind) {
475+
return reply instanceof Exchange.BindOk;
476+
} else if (request instanceof Exchange.Unbind) {
477+
return reply instanceof Exchange.UnbindOk;
478+
} else if (request instanceof Queue.Declare) {
479+
// we cannot check the queue name, as the server can strip some characters
480+
// see QueueLifecycle test and https://github.com/rabbitmq/rabbitmq-server/issues/710
481+
return reply instanceof Queue.DeclareOk;
482+
} else if (request instanceof Queue.Delete) {
483+
return reply instanceof Queue.DeleteOk;
484+
} else if (request instanceof Queue.Bind) {
485+
return reply instanceof Queue.BindOk;
486+
} else if (request instanceof Queue.Unbind) {
487+
return reply instanceof Queue.UnbindOk;
488+
} else if (request instanceof Queue.Purge) {
489+
return reply instanceof Queue.PurgeOk;
490+
} else if (request instanceof Tx.Select) {
491+
return reply instanceof Tx.SelectOk;
492+
} else if (request instanceof Tx.Commit) {
493+
return reply instanceof Tx.CommitOk;
494+
} else if (request instanceof Tx.Rollback) {
495+
return reply instanceof Tx.RollbackOk;
496+
} else if (request instanceof Confirm.Select) {
497+
return reply instanceof Confirm.SelectOk;
498+
}
499+
}
500+
// for passivity default to true
501+
return true;
502+
}
503+
415504
public abstract T transformReply(AMQCommand command);
416505
}
417506

418507
public static class SimpleBlockingRpcContinuation
419508
extends BlockingRpcContinuation<AMQCommand>
420509
{
510+
511+
public SimpleBlockingRpcContinuation() {
512+
super();
513+
}
514+
515+
public SimpleBlockingRpcContinuation(final Method method) {
516+
super(method);
517+
}
518+
421519
@Override
422520
public AMQCommand transformReply(AMQCommand command) {
423521
return command;

src/main/java/com/rabbitmq/client/impl/AMQConnection.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ public static Map<String, Object> defaultClientProperties() {
134134
private final Collection<BlockedListener> blockedListeners = new CopyOnWriteArrayList<BlockedListener>();
135135
protected final MetricsCollector metricsCollector;
136136
private final int channelRpcTimeout;
137+
private final boolean channelShouldCheckRpcResponseType;
137138

138139
/* State modified after start - all volatile */
139140

@@ -229,6 +230,7 @@ public AMQConnection(ConnectionParams params, FrameHandler frameHandler, Metrics
229230
throw new IllegalArgumentException("Continuation timeout on RPC calls cannot be less than 0");
230231
}
231232
this.channelRpcTimeout = params.getChannelRpcTimeout();
233+
this.channelShouldCheckRpcResponseType = params.channelShouldCheckRpcResponseType();
232234

233235
this._channel0 = new AMQChannel(this, 0) {
234236
@Override public boolean processAsync(Command c) throws IOException {
@@ -1056,4 +1058,8 @@ public void setId(String id) {
10561058
public int getChannelRpcTimeout() {
10571059
return channelRpcTimeout;
10581060
}
1061+
1062+
public boolean willCheckRpcResponseType() {
1063+
return channelShouldCheckRpcResponseType;
1064+
}
10591065
}

src/main/java/com/rabbitmq/client/impl/ChannelN.java

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1229,7 +1229,15 @@ public String basicConsume(String queue, final boolean autoAck, String consumerT
12291229
final Consumer callback)
12301230
throws IOException
12311231
{
1232-
BlockingRpcContinuation<String> k = new BlockingRpcContinuation<String>() {
1232+
final Method m = new Basic.Consume.Builder()
1233+
.queue(queue)
1234+
.consumerTag(consumerTag)
1235+
.noLocal(noLocal)
1236+
.noAck(autoAck)
1237+
.exclusive(exclusive)
1238+
.arguments(arguments)
1239+
.build();
1240+
BlockingRpcContinuation<String> k = new BlockingRpcContinuation<String>(m) {
12331241
@Override
12341242
public String transformReply(AMQCommand replyCommand) {
12351243
String actualConsumerTag = ((Basic.ConsumeOk) replyCommand.getMethod()).getConsumerTag();
@@ -1243,14 +1251,7 @@ public String transformReply(AMQCommand replyCommand) {
12431251
}
12441252
};
12451253

1246-
final Method m = new Basic.Consume.Builder()
1247-
.queue(queue)
1248-
.consumerTag(consumerTag)
1249-
.noLocal(noLocal)
1250-
.noAck(autoAck)
1251-
.exclusive(exclusive)
1252-
.arguments(arguments)
1253-
.build();
1254+
12541255
rpc(m, k);
12551256

12561257
try {
@@ -1276,7 +1277,9 @@ public void basicCancel(final String consumerTag)
12761277
final Consumer originalConsumer = _consumers.get(consumerTag);
12771278
if (originalConsumer == null)
12781279
throw new IOException("Unknown consumerTag");
1279-
BlockingRpcContinuation<Consumer> k = new BlockingRpcContinuation<Consumer>() {
1280+
1281+
final Method m = new Basic.Cancel(consumerTag, false);
1282+
BlockingRpcContinuation<Consumer> k = new BlockingRpcContinuation<Consumer>(m) {
12801283
@Override
12811284
public Consumer transformReply(AMQCommand replyCommand) {
12821285
if (!(replyCommand.getMethod() instanceof Basic.CancelOk))
@@ -1287,7 +1290,7 @@ public Consumer transformReply(AMQCommand replyCommand) {
12871290
}
12881291
};
12891292

1290-
final Method m = new Basic.Cancel(consumerTag, false);
1293+
12911294
rpc(m, k);
12921295

12931296
try {

src/main/java/com/rabbitmq/client/impl/ConnectionParams.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ public class ConnectionParams {
4040
private long networkRecoveryInterval;
4141
private boolean topologyRecovery;
4242
private int channelRpcTimeout;
43+
private boolean channelShouldCheckRpcResponseType;
4344

4445
private ExceptionHandler exceptionHandler;
4546
private ThreadFactory threadFactory;
@@ -114,6 +115,10 @@ public int getChannelRpcTimeout() {
114115
return channelRpcTimeout;
115116
}
116117

118+
public boolean channelShouldCheckRpcResponseType() {
119+
return channelShouldCheckRpcResponseType;
120+
}
121+
117122
public void setUsername(String username) {
118123
this.username = username;
119124
}
@@ -189,4 +194,8 @@ public void setHeartbeatExecutor(ScheduledExecutorService heartbeatExecutor) {
189194
public void setChannelRpcTimeout(int channelRpcTimeout) {
190195
this.channelRpcTimeout = channelRpcTimeout;
191196
}
197+
198+
public void setChannelShouldCheckRpcResponseType(boolean channelShouldCheckRpcResponseType) {
199+
this.channelShouldCheckRpcResponseType = channelShouldCheckRpcResponseType;
200+
}
192201
}

src/main/java/com/rabbitmq/client/impl/nio/HeaderWriteRequest.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@
1919

2020
import java.io.DataOutputStream;
2121
import java.io.IOException;
22-
import java.nio.ByteBuffer;
23-
import java.nio.channels.WritableByteChannel;
2422

2523
/**
2624
*

src/test/java/com/rabbitmq/client/test/AMQChannelTest.java

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,62 @@ public Void call() throws Exception {
106106
assertThat(rpcResponse.getMethod(), is(response));
107107
}
108108

109+
@Test
110+
public void testRpcTimeoutReplyComesDuringNexRpc() throws Exception {
111+
int rpcTimeout = 100;
112+
AMQConnection connection = mock(AMQConnection.class);
113+
when(connection.getChannelRpcTimeout()).thenReturn(rpcTimeout);
114+
when(connection.willCheckRpcResponseType()).thenReturn(Boolean.TRUE);
115+
116+
final DummyAmqChannel channel = new DummyAmqChannel(connection, 1);
117+
Method method = new AMQImpl.Queue.Declare.Builder()
118+
.queue("123")
119+
.durable(false)
120+
.exclusive(true)
121+
.autoDelete(true)
122+
.arguments(null)
123+
.build();
124+
125+
try {
126+
channel.rpc(method);
127+
fail("Should time out and throw an exception");
128+
} catch(final ChannelContinuationTimeoutException e) {
129+
// OK
130+
assertThat((DummyAmqChannel) e.getChannel(), is(channel));
131+
assertThat(e.getChannelNumber(), is(channel.getChannelNumber()));
132+
assertThat(e.getMethod(), is(method));
133+
assertNull("outstanding RPC should have been cleaned", channel.nextOutstandingRpc());
134+
}
135+
136+
// now do a basic.consume request and have the queue.declareok returned instead
137+
method = new AMQImpl.Basic.Consume.Builder()
138+
.queue("123")
139+
.consumerTag("")
140+
.arguments(null)
141+
.build();
142+
143+
final Method response1 = new AMQImpl.Queue.DeclareOk.Builder()
144+
.queue("123")
145+
.consumerCount(0)
146+
.messageCount(0).build();
147+
148+
final Method response2 = new AMQImpl.Basic.ConsumeOk.Builder()
149+
.consumerTag("456").build();
150+
151+
scheduler.schedule(new Callable<Void>() {
152+
@Override
153+
public Void call() throws Exception {
154+
channel.handleCompleteInboundCommand(new AMQCommand(response1));
155+
Thread.sleep(10);
156+
channel.handleCompleteInboundCommand(new AMQCommand(response2));
157+
return null;
158+
}
159+
}, (long) (rpcTimeout / 2.0), TimeUnit.MILLISECONDS);
160+
161+
AMQCommand rpcResponse = channel.rpc(method);
162+
assertThat(rpcResponse.getMethod(), is(response2));
163+
}
164+
109165
static class DummyAmqChannel extends AMQChannel {
110166

111167
public DummyAmqChannel(AMQConnection connection, int channelNumber) {

0 commit comments

Comments
 (0)