Skip to content

Commit d23329c

Browse files
committed
Handle protocol version mismatch in Netty IO handler
1 parent 12331e4 commit d23329c

File tree

10 files changed

+139
-17
lines changed

10 files changed

+139
-17
lines changed

pom.xml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -779,6 +779,7 @@
779779
<include>src/main/java/com/rabbitmq/client/observation/**/*.java</include>
780780
<include>src/test/java/com/rabbitmq/client/test/functional/MicrometerObservationCollectorMetrics.java</include>
781781
<include>src/test/java/com/rabbitmq/client/test/NettyTest.java</include>
782+
<include>src/test/java/com/rabbitmq/client/test/ProtocolVersionMismatch.java</include>
782783
</includes>
783784
<googleJavaFormat>
784785
<version>${google-java-format.version}</version>

src/main/java/com/rabbitmq/client/ConnectionFactory.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -813,7 +813,7 @@ public ConnectionFactory setExceptionHandler(ExceptionHandler exceptionHandler)
813813
public boolean isSSL() {
814814
return getSocketFactory() instanceof SSLSocketFactory
815815
|| sslContextFactory != null
816-
|| this.netty().isTls();
816+
|| this.nettyConf.isTls();
817817
}
818818

819819
/**

src/main/java/com/rabbitmq/client/impl/AMQConnection.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,7 @@ public void start()
493493

494494
// We can now respond to errors having finished tailoring the connection
495495
this._inConnectionNegotiation = false;
496+
this._frameHandler.finishConnectionNegotiation();
496497
}
497498

498499
protected ChannelManager instantiateChannelManager(int channelMax, ThreadFactory threadFactory) {

src/main/java/com/rabbitmq/client/impl/Frame.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,6 @@ public static Frame readFrom(DataInputStream is, int maxPayloadSize) throws IOEx
9090
try {
9191
type = is.readUnsignedByte();
9292
} catch (SocketTimeoutException ste) {
93-
// System.err.println("Timed out waiting for a frame.");
9493
return null; // failed
9594
}
9695

src/main/java/com/rabbitmq/client/impl/FrameHandler.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ default boolean internalHearbeat() {
5454

5555
void initialize(AMQConnection connection);
5656

57+
default void finishConnectionNegotiation() {
58+
59+
}
60+
5761
/**
5862
* Read a {@link Frame} from the underlying data connection.
5963
* @return an incoming Frame, or null if there is none
@@ -77,4 +81,5 @@ default boolean internalHearbeat() {
7781

7882
/** Close the underlying data connection (complaint not permitted). */
7983
void close();
84+
8085
}

src/main/java/com/rabbitmq/client/impl/NettyFrameHandlerFactory.java

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@
4444
import io.netty.handler.timeout.IdleStateEvent;
4545
import io.netty.handler.timeout.IdleStateHandler;
4646
import io.netty.handler.timeout.ReadTimeoutHandler;
47+
import java.io.ByteArrayInputStream;
48+
import java.io.DataInputStream;
4749
import java.io.IOException;
4850
import java.net.InetAddress;
4951
import java.net.InetSocketAddress;
@@ -137,6 +139,8 @@ private static final class NettyFrameHandler implements FrameHandler {
137139
LengthFieldBasedFrameDecoder.class.getSimpleName();
138140
private static final String HANDLER_READ_TIMEOUT = ReadTimeoutHandler.class.getSimpleName();
139141
private static final String HANDLER_IDLE_STATE = IdleStateHandler.class.getSimpleName();
142+
private static final String HANDLER_PROTOCOL_VERSION_MISMATCH =
143+
ProtocolVersionMismatchHandler.class.getSimpleName();
140144
private static final byte[] HEADER =
141145
new byte[] {
142146
'A', 'M', 'Q', 'P', 0, AMQP.PROTOCOL.MAJOR, AMQP.PROTOCOL.MINOR, AMQP.PROTOCOL.REVISION
@@ -193,6 +197,8 @@ public void initChannel(SocketChannel ch) {
193197
HANDLER_FLUSH_CONSOLIDATION,
194198
new FlushConsolidationHandler(
195199
FlushConsolidationHandler.DEFAULT_EXPLICIT_FLUSH_AFTER_FLUSHES, true));
200+
ch.pipeline()
201+
.addLast(HANDLER_PROTOCOL_VERSION_MISMATCH, new ProtocolVersionMismatchHandler());
196202
ch.pipeline()
197203
.addLast(
198204
HANDLER_FRAME_DECODER,
@@ -284,6 +290,11 @@ public void initialize(AMQConnection connection) {
284290
this.handler.connection = connection;
285291
}
286292

293+
@Override
294+
public void finishConnectionNegotiation() {
295+
maybeRemoveHandler(HANDLER_PROTOCOL_VERSION_MISMATCH);
296+
}
297+
287298
@Override
288299
public Frame readFrame() {
289300
throw new UnsupportedOperationException();
@@ -412,11 +423,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception
412423
if (noProblem
413424
&& (!this.connection.isRunning() || this.connection.hasBrokerInitiatedShutdown())) {
414425
// looks like the frame was Close-Ok or Close
415-
ctx.executor()
416-
.submit(
417-
() -> {
418-
this.connection.doFinalShutdown();
419-
});
426+
ctx.executor().submit(() -> this.connection.doFinalShutdown());
420427
}
421428
} finally {
422429
m.release();
@@ -503,4 +510,23 @@ private CountDownLatch writableLatch() {
503510
return this.writableLatch.get();
504511
}
505512
}
513+
514+
private static final class ProtocolVersionMismatchHandler extends ChannelInboundHandlerAdapter {
515+
516+
@Override
517+
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
518+
ByteBuf b = (ByteBuf) msg;
519+
if (b.readByte() == 'A') {
520+
// likely an AMQP header that indicates a protocol version mismatch
521+
// the header is small, we read everything in memory and use the Frame class
522+
int toRead = Math.min(b.readableBytes(), NettyFrameHandler.HEADER.length - 1);
523+
byte[] header = new byte[toRead];
524+
b.readBytes(header);
525+
Frame.protocolVersionMismatch(new DataInputStream(new ByteArrayInputStream(header)));
526+
} else {
527+
b.readerIndex(0);
528+
ctx.fireChannelRead(msg);
529+
}
530+
}
531+
}
506532
}

src/test/java/com/rabbitmq/client/test/ClientTestSuite.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,8 @@
7575
AMQConnectionRefreshCredentialsTest.class,
7676
ValueWriterTest.class,
7777
BlockedConnectionTest.class,
78-
NettyTest.class
78+
NettyTest.class,
79+
ProtocolVersionMismatch.class
7980
})
8081
public class ClientTestSuite {
8182

src/test/java/com/rabbitmq/client/test/FrameTest.java

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,12 @@ private void checkWrittenChunks(int totalFrameSize, AccumulatorWritableByteChann
7272

7373
private static class AccumulatorWritableByteChannel implements WritableByteChannel {
7474

75-
List<byte[]> chunks = new ArrayList<byte[]>();
75+
List<byte[]> chunks = new ArrayList<>();
7676

7777
Random random = new Random();
7878

7979
@Override
80-
public int write(ByteBuffer src) throws IOException {
80+
public int write(ByteBuffer src) {
8181
int remaining = src.remaining();
8282
if(remaining > 0) {
8383
int toRead = random.nextInt(remaining) + 1;
@@ -88,7 +88,6 @@ public int write(ByteBuffer src) throws IOException {
8888
} else {
8989
return remaining;
9090
}
91-
9291
}
9392

9493
@Override
@@ -97,9 +96,7 @@ public boolean isOpen() {
9796
}
9897

9998
@Override
100-
public void close() throws IOException {
101-
102-
}
99+
public void close() { }
103100
}
104101

105102
public static void drain(WritableByteChannel channel, ByteBuffer buffer) throws IOException {
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
// Copyright (c) 2025 Broadcom. All Rights Reserved.
2+
// The term "Broadcom" refers to Broadcom Inc. and/or its subsidiaries.
3+
//
4+
// This software, the RabbitMQ Java client library, is triple-licensed under the
5+
// Mozilla Public License 2.0 ("MPL"), the GNU General Public License version 2
6+
// ("GPL") and the Apache License version 2 ("ASL"). For the MPL, please see
7+
// LICENSE-MPL-RabbitMQ. For the GPL, please see LICENSE-GPL2. For the ASL,
8+
// please see LICENSE-APACHE2.
9+
//
10+
// This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY KIND,
11+
// either express or implied. See the LICENSE file for specific language governing
12+
// rights and limitations of this software.
13+
//
14+
// If you have any questions regarding licensing, please contact us at
15+
16+
package com.rabbitmq.client.test;
17+
18+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
19+
20+
import com.rabbitmq.client.ConnectionFactory;
21+
import com.rabbitmq.client.MalformedFrameException;
22+
import io.netty.bootstrap.ServerBootstrap;
23+
import io.netty.buffer.ByteBuf;
24+
import io.netty.channel.ChannelFutureListener;
25+
import io.netty.channel.ChannelHandlerContext;
26+
import io.netty.channel.ChannelInboundHandlerAdapter;
27+
import io.netty.channel.ChannelInitializer;
28+
import io.netty.channel.EventLoopGroup;
29+
import io.netty.channel.MultiThreadIoEventLoopGroup;
30+
import io.netty.channel.nio.NioIoHandler;
31+
import io.netty.channel.socket.SocketChannel;
32+
import io.netty.channel.socket.nio.NioServerSocketChannel;
33+
import io.netty.util.ReferenceCountUtil;
34+
import java.util.concurrent.TimeUnit;
35+
import org.junit.jupiter.params.ParameterizedTest;
36+
import org.junit.jupiter.params.provider.MethodSource;
37+
38+
public class ProtocolVersionMismatch {
39+
40+
@ParameterizedTest
41+
@MethodSource("com.rabbitmq.client.test.TestUtils#ioLayers")
42+
void connectionShouldFailWithProtocolVersionMismatch(String ioLayer) throws Exception {
43+
int port = TestUtils.randomNetworkPort();
44+
try (SimpleServer ignored = new SimpleServer(port)) {
45+
ConnectionFactory cf = TestUtils.connectionFactory();
46+
TestUtils.setIoLayer(cf, ioLayer);
47+
cf.setPort(port);
48+
assertThatThrownBy(cf::newConnection).hasRootCauseInstanceOf(MalformedFrameException.class);
49+
}
50+
}
51+
52+
private static class SimpleServer implements AutoCloseable {
53+
54+
private final EventLoopGroup elp =
55+
new MultiThreadIoEventLoopGroup(1, NioIoHandler.newFactory());
56+
57+
private static final byte[] AMQP_HEADER = new byte[] {'A', 'M', 'Q', 'P', 0, 1, 0, 0};
58+
59+
private SimpleServer(int port) throws InterruptedException {
60+
ServerBootstrap b = new ServerBootstrap();
61+
b.group(elp);
62+
b.channel(NioServerSocketChannel.class);
63+
b.childHandler(
64+
new ChannelInitializer<SocketChannel>() {
65+
@Override
66+
protected void initChannel(SocketChannel ch) {
67+
ch.pipeline()
68+
.addLast(
69+
new ChannelInboundHandlerAdapter() {
70+
@Override
71+
public void channelRead(ChannelHandlerContext ctx, Object msg) {
72+
// discard the data
73+
ReferenceCountUtil.release(msg);
74+
ByteBuf b = ctx.alloc().buffer(AMQP_HEADER.length);
75+
b.writeBytes(AMQP_HEADER);
76+
ctx.channel()
77+
.writeAndFlush(b)
78+
.addListener(ChannelFutureListener.CLOSE);
79+
}
80+
});
81+
}
82+
})
83+
.validate();
84+
b.bind(port).sync();
85+
}
86+
87+
@Override
88+
public void close() {
89+
this.elp.shutdownGracefully(0, 0, TimeUnit.SECONDS);
90+
}
91+
}
92+
}

src/test/java/com/rabbitmq/client/test/TestUtils.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,6 @@ public class TestUtils {
6262
public static ConnectionFactory connectionFactory() {
6363
ConnectionFactory connectionFactory = new ConnectionFactory();
6464
setIoLayer(connectionFactory);
65-
if (isNetty()) {
66-
connectionFactory.netty().eventLoopGroup(EVENT_LOOP_GROUP.get());
67-
}
6865
return connectionFactory;
6966
}
7067

@@ -98,6 +95,9 @@ private static boolean isNetty(String layer) {
9895

9996
public static void setIoLayer(ConnectionFactory cf) {
10097
setIoLayer(cf, IO_LAYER);
98+
if (isNetty()) {
99+
cf.netty().eventLoopGroup(EVENT_LOOP_GROUP.get());
100+
}
101101
}
102102

103103
public static void setIoLayer(ConnectionFactory cf, String layer) {

0 commit comments

Comments
 (0)