diff --git a/mqtt-common/src/main/java/io/streamnative/pulsar/handlers/mqtt/common/MQTTCommonInboundHandler.java b/mqtt-common/src/main/java/io/streamnative/pulsar/handlers/mqtt/common/MQTTCommonInboundHandler.java index 807e50f93..34354d1d2 100644 --- a/mqtt-common/src/main/java/io/streamnative/pulsar/handlers/mqtt/common/MQTTCommonInboundHandler.java +++ b/mqtt-common/src/main/java/io/streamnative/pulsar/handlers/mqtt/common/MQTTCommonInboundHandler.java @@ -19,12 +19,18 @@ import io.netty.channel.ChannelHandler.Sharable; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.handler.codec.mqtt.MqttConnectVariableHeader; import io.netty.handler.codec.mqtt.MqttMessage; import io.netty.handler.codec.mqtt.MqttMessageType; +import io.netty.handler.codec.mqtt.MqttVersion; import io.netty.handler.timeout.IdleState; import io.netty.handler.timeout.IdleStateEvent; import io.netty.util.ReferenceCountUtil; import io.streamnative.pulsar.handlers.mqtt.common.adapter.MqttAdapterMessage; +import io.streamnative.pulsar.handlers.mqtt.common.messages.ack.MqttAck; +import io.streamnative.pulsar.handlers.mqtt.common.messages.ack.MqttConnectAck; +import io.streamnative.pulsar.handlers.mqtt.common.messages.ack.MqttDisconnectAck; +import io.streamnative.pulsar.handlers.mqtt.common.messages.codes.mqtt5.Mqtt5DisConnReasonCode; import io.streamnative.pulsar.handlers.mqtt.common.utils.NettyUtils; import java.util.concurrent.ConcurrentHashMap; import lombok.extern.slf4j.Slf4j; @@ -106,6 +112,50 @@ public void channelRead(ChannelHandlerContext ctx, Object message) { default: throw new UnsupportedOperationException("Unknown MessageType: " + messageType); } + } catch (IllegalStateException ex) { + ReferenceCountUtil.safeRelease(mqttMessage); + MqttMessageType mqttMessageType = mqttMessage.fixedHeader().messageType(); + log.warn("Invalid MQTT message state: {}, mqttMessageType:{}", ex.getMessage(), mqttMessageType); + + int protocolVersion = MqttVersion.MQTT_3_1.protocolLevel(); + try { + Connection existingConnection = NettyUtils.getConnection(ctx.channel()); + if (existingConnection != null) { + protocolVersion = existingConnection.getProtocolVersion(); + } + } catch (Exception e) { + } + + if (mqttMessageType == MqttMessageType.CONNECT) { + MqttConnectVariableHeader connectVariableHeader = + (MqttConnectVariableHeader) mqttMessage.variableHeader(); + protocolVersion = connectVariableHeader.version(); + + // For CONNECT message errors, send a CONNACK error response. + MqttMessage errorResponse = MqttConnectAck.errorBuilder().protocolError(protocolVersion); + MqttAdapterMessage errorAdapterMsg = new MqttAdapterMessage( + adapterMsg.getClientId(), + errorResponse, + adapterMsg.fromProxy() + ); + ctx.writeAndFlush(errorAdapterMsg); + } else { + // For other message errors, send a DISCONNECT error response if supported. + MqttAck errorAck = MqttDisconnectAck.errorBuilder(protocolVersion) + .reasonCode(Mqtt5DisConnReasonCode.MALFORMED_PACKET) + .reasonString("Invalid message format: " + ex.getMessage()) + .build(); + + if (errorAck.isProtocolSupported()) { + MqttAdapterMessage errorAdapterMsg = new MqttAdapterMessage( + adapterMsg.getClientId(), + errorAck.getMqttMessage(), + adapterMsg.fromProxy() + ); + ctx.writeAndFlush(errorAdapterMsg); + } + } + ctx.close(); } catch (Throwable ex) { ReferenceCountUtil.safeRelease(mqttMessage); log.error("Exception was caught while processing MQTT message, ", ex); diff --git a/tests/src/test/java/io/streamnative/pulsar/handlers/mqtt/base/MessageConverTest.java b/tests/src/test/java/io/streamnative/pulsar/handlers/mqtt/base/MessageConverTest.java index 000780610..83a0867ed 100644 --- a/tests/src/test/java/io/streamnative/pulsar/handlers/mqtt/base/MessageConverTest.java +++ b/tests/src/test/java/io/streamnative/pulsar/handlers/mqtt/base/MessageConverTest.java @@ -27,7 +27,7 @@ import org.apache.pulsar.broker.service.Topic; import org.apache.pulsar.client.impl.MessageImpl; import org.apache.pulsar.common.naming.TopicDomain; -import org.junit.Assert; +import org.testng.Assert; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; diff --git a/tests/src/test/java/io/streamnative/pulsar/handlers/mqtt/mqtt3/fusesource/base/SimpleIntegrationTest.java b/tests/src/test/java/io/streamnative/pulsar/handlers/mqtt/mqtt3/fusesource/base/SimpleIntegrationTest.java index 9416bb919..1e6eec042 100644 --- a/tests/src/test/java/io/streamnative/pulsar/handlers/mqtt/mqtt3/fusesource/base/SimpleIntegrationTest.java +++ b/tests/src/test/java/io/streamnative/pulsar/handlers/mqtt/mqtt3/fusesource/base/SimpleIntegrationTest.java @@ -26,7 +26,6 @@ import io.streamnative.pulsar.handlers.mqtt.common.utils.PulsarTopicUtils; import io.streamnative.pulsar.handlers.mqtt.mqtt3.fusesource.psk.PSKClient; import java.io.BufferedReader; -import java.io.EOFException; import java.io.InputStream; import java.io.InputStreamReader; import java.nio.charset.StandardCharsets; @@ -50,6 +49,7 @@ import org.awaitility.Awaitility; import org.fusesource.mqtt.client.BlockingConnection; import org.fusesource.mqtt.client.MQTT; +import org.fusesource.mqtt.client.MQTTException; import org.fusesource.mqtt.client.Message; import org.fusesource.mqtt.client.QoS; import org.fusesource.mqtt.client.Topic; @@ -374,14 +374,19 @@ public void testSubscribeWithTopicFilter() throws Exception { connection2.disconnect(); } - @Test(expectedExceptions = {EOFException.class, IllegalStateException.class}) + @Test(expectedExceptions = {MQTTException.class}, + expectedExceptionsMessageRegExp = ".*CONNECTION_REFUSED_SERVER_UNAVAILABLE.*") public void testInvalidClientId() throws Exception { MQTT mqtt = createMQTTClient(); - mqtt.setConnectAttemptsMax(1); // ClientId is invalid, for max length is 23 in mqtt 3.1 mqtt.setClientId(UUID.randomUUID().toString().replace("-", "")); BlockingConnection connection = Mockito.spy(mqtt.blockingConnection()); - connection.connect(); + try { + connection.connect(); + } catch (Exception ex) { + log.info("Expected exception: {}", ex.getMessage()); + throw ex; // rethrow to verify the exception + } verify(connection, Mockito.times(2)).connect(); } diff --git a/tests/src/test/java/io/streamnative/pulsar/handlers/mqtt/mqtt3/fusesource/proxy/ProxyTest.java b/tests/src/test/java/io/streamnative/pulsar/handlers/mqtt/mqtt3/fusesource/proxy/ProxyTest.java index 454d64f5b..cdfc5f318 100644 --- a/tests/src/test/java/io/streamnative/pulsar/handlers/mqtt/mqtt3/fusesource/proxy/ProxyTest.java +++ b/tests/src/test/java/io/streamnative/pulsar/handlers/mqtt/mqtt3/fusesource/proxy/ProxyTest.java @@ -29,7 +29,6 @@ import io.streamnative.pulsar.handlers.mqtt.common.TopicFilterImpl; import io.streamnative.pulsar.handlers.mqtt.mqtt3.fusesource.psk.PSKClient; import java.io.BufferedReader; -import java.io.EOFException; import java.io.InputStream; import java.io.InputStreamReader; import java.net.URI; @@ -62,6 +61,7 @@ import org.awaitility.Awaitility; import org.fusesource.mqtt.client.BlockingConnection; import org.fusesource.mqtt.client.MQTT; +import org.fusesource.mqtt.client.MQTTException; import org.fusesource.mqtt.client.Message; import org.fusesource.mqtt.client.QoS; import org.fusesource.mqtt.client.Topic; @@ -167,14 +167,19 @@ public void testSendAndConsume(String topicName) throws Exception { connection.disconnect(); } - @Test(expectedExceptions = {EOFException.class, IllegalStateException.class}, priority = 3) + @Test(expectedExceptions = {MQTTException.class}, + expectedExceptionsMessageRegExp = ".*CONNECTION_REFUSED_SERVER_UNAVAILABLE.*") public void testInvalidClientId() throws Exception { MQTT mqtt = createMQTTProxyClient(); - mqtt.setConnectAttemptsMax(1); // ClientId is invalid, for max length is 23 in mqtt 3.1 mqtt.setClientId(UUID.randomUUID().toString().replace("-", "")); BlockingConnection connection = Mockito.spy(mqtt.blockingConnection()); - connection.connect(); + try { + connection.connect(); + } catch (Exception ex) { + log.info("Expected exception: {}", ex.getMessage()); + throw ex; // rethrow to verify the exception + } verify(connection, Mockito.times(2)).connect(); }