Skip to content

Commit 6eb7479

Browse files
authored
Merge pull request #341 from bittcrafter/dev/0.18.0
fix: improve MQTT protocol validation and WebSocket subprotocol support #340
2 parents a9748e3 + a90d040 commit 6eb7479

File tree

5 files changed

+40
-17
lines changed

5 files changed

+40
-17
lines changed

Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ rust-version = "1.85.0"
6464

6565
[workspace.dependencies]
6666
rmqtt = "0.18.0"
67-
rmqtt-codec = "0.2"
68-
rmqtt-net = "0.3"
67+
rmqtt-codec = "0.2.1"
68+
rmqtt-net = "0.3.2"
6969
rmqtt-conf = "0.3"
7070
rmqtt-macros = "0.1"
7171
rmqtt-utils = "0.1"

rmqtt-codec/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "rmqtt-codec"
3-
version = "0.2.0"
3+
version = "0.2.1"
44
description = "MQTT protocol codec implementation with multi-version support and version negotiation"
55
repository = "https://github.com/rmqtt/rmqtt/tree/master/rmqtt-codec"
66
edition.workspace = true

rmqtt-codec/src/version.rs

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,22 @@ impl Decoder for VersionCodec {
5555
);
5656

5757
// Validate protocol name matches MQTT spec
58-
ensure!(
59-
(protocol_len == 4 && &src[consumed + 2..consumed + 6] == MQTT)
60-
|| (protocol_len == 6 && &src[consumed + 2..consumed + 8] == MQISDP),
61-
DecodeError::InvalidProtocol
62-
);
58+
if protocol_len == 4 {
59+
//for mqtt 3.1.1 or 5.0
60+
if &src[consumed + 2..consumed + 6] != MQTT {
61+
return Err(DecodeError::InvalidProtocol);
62+
}
63+
} else if protocol_len == 6 {
64+
//for mqtt 3.1
65+
if len <= consumed + 8 {
66+
return Ok(None);
67+
}
68+
if &src[consumed + 2..consumed + 8] != MQISDP {
69+
return Err(DecodeError::InvalidProtocol);
70+
}
71+
} else {
72+
return Err(DecodeError::InvalidProtocol);
73+
}
6374

6475
// Extract protocol level byte (position after protocol name)
6576
match src[consumed + 2 + protocol_len as usize] {

rmqtt-net/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "rmqtt-net"
3-
version = "0.3.1"
3+
version = "0.3.2"
44
description = "Basic Implementation of MQTT Server"
55
repository = "https://github.com/rmqtt/rmqtt/tree/master/rmqtt-net"
66
edition.workspace = true

rmqtt-net/src/builder.rs

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -904,14 +904,26 @@ fn on_handshake(req: &Request, mut response: Response) -> std::result::Result<Re
904904
let mqtt_protocol = req
905905
.headers()
906906
.get("Sec-WebSocket-Protocol")
907-
.ok_or_else(|| ErrorResponse::new(Some(PROTOCOL_ERROR.into())))?;
908-
if mqtt_protocol != "mqtt" {
909-
return Err(ErrorResponse::new(Some(PROTOCOL_ERROR.into())));
910-
}
911-
response.headers_mut().append(
912-
"Sec-WebSocket-Protocol",
913-
"mqtt".parse().map_err(|_| ErrorResponse::new(Some("InvalidHeaderValue".into())))?,
914-
);
907+
.ok_or_else(|| ErrorResponse::new(Some(PROTOCOL_ERROR.into())))?
908+
.to_str()
909+
.map_err(|_| ErrorResponse::new(Some(PROTOCOL_ERROR.into())))?;
910+
match mqtt_protocol {
911+
"mqtt" => {
912+
response.headers_mut().append(
913+
"Sec-WebSocket-Protocol",
914+
"mqtt".parse().map_err(|_| ErrorResponse::new(Some("InvalidHeaderValue".into())))?,
915+
);
916+
}
917+
"mqttv3.1" => {
918+
response.headers_mut().append(
919+
"Sec-WebSocket-Protocol",
920+
"mqttv3.1".parse().map_err(|_| ErrorResponse::new(Some("InvalidHeaderValue".into())))?,
921+
);
922+
}
923+
_ => {
924+
return Err(ErrorResponse::new(Some(PROTOCOL_ERROR.into())));
925+
}
926+
}
915927
Ok(response)
916928
}
917929

0 commit comments

Comments
 (0)