Skip to content

Commit 05fbcc5

Browse files
authored
Add WebSocket configuration struct (#139)
* Add WebSocket configuration struct Add additionalHeaders parameter to this struct * Sec-WebSocket-Protocol should be mqtt * Update documentation/tests * additionalHeaders -> initialRequestHeaders * Change Configuration.init webSocketConfiguration signature * Add test for WebSocketInitialRequestHandler
1 parent 85fb8f6 commit 05fbcc5

File tree

7 files changed

+169
-33
lines changed

7 files changed

+169
-33
lines changed

Sources/MQTTNIO/ChannelHandlers/WebSocketInitialRequest.swift

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,20 +23,23 @@ final class WebSocketInitialRequestHandler: ChannelInboundHandler, RemovableChan
2323

2424
let host: String
2525
let urlPath: String
26+
let additionalHeaders: HTTPHeaders
2627
let upgradePromise: EventLoopPromise<Void>
2728

28-
init(host: String, urlPath: String, upgradePromise: EventLoopPromise<Void>) {
29+
init(host: String, urlPath: String, additionalHeaders: HTTPHeaders, upgradePromise: EventLoopPromise<Void>) {
2930
self.host = host
30-
self.upgradePromise = upgradePromise
3131
self.urlPath = urlPath
32+
self.additionalHeaders = additionalHeaders
33+
self.upgradePromise = upgradePromise
3234
}
3335

3436
public func channelActive(context: ChannelHandlerContext) {
3537
// We are connected. It's time to send the message to the server to initialize the upgrade dance.
3638
var headers = HTTPHeaders()
3739
headers.add(name: "Content-Length", value: "0")
3840
headers.add(name: "host", value: self.host)
39-
headers.add(name: "Sec-WebSocket-Protocol", value: "mqttv3.1")
41+
headers.add(name: "Sec-WebSocket-Protocol", value: "mqtt")
42+
headers.add(contentsOf: self.additionalHeaders)
4043

4144
let requestHead = HTTPRequestHead(
4245
version: HTTPVersion(major: 1, minor: 1),

Sources/MQTTNIO/MQTTClient.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,12 @@ public final class MQTTClient {
5757
var connection: MQTTConnection? {
5858
get {
5959
self.lock.withLock {
60-
_connection
60+
self._connection
6161
}
6262
}
6363
set {
6464
self.lock.withLock {
65-
_connection = newValue
65+
self._connection = newValue
6666
}
6767
}
6868
}
@@ -480,7 +480,7 @@ internal extension MQTTClient {
480480
func resendOnRestart() {
481481
let inflight = self.inflight.packets
482482
self.inflight.clear()
483-
inflight.forEach { packet -> Void in
483+
inflight.forEach { packet in
484484
switch packet {
485485
case let publish as MQTTPublishPacket:
486486
let newPacket = MQTTPublishPacket(

Sources/MQTTNIO/MQTTConfiguration.swift

Lines changed: 94 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
//===----------------------------------------------------------------------===//
1313

1414
import NIO
15+
import NIOHTTP1
1516
#if canImport(NIOSSL)
1617
import NIOSSL
1718
#endif
@@ -40,8 +41,74 @@ extension MQTTClient {
4041
#endif
4142
}
4243

44+
public struct WebSocketConfiguration {
45+
/// Initialize MQTTClient WebSocket configuration struct
46+
/// - Parameters:
47+
/// - urlPath: WebSocket URL, defaults to "/mqtt"
48+
/// - maxFrameSize: Max frame size WebSocket client will allow
49+
/// - initialRequestHeaders: Additional headers to add to initial HTTP request
50+
public init(
51+
urlPath: String,
52+
maxFrameSize: Int = 1 << 14,
53+
initialRequestHeaders: HTTPHeaders = [:]
54+
) {
55+
self.urlPath = urlPath
56+
self.maxFrameSize = maxFrameSize
57+
self.initialRequestHeaders = initialRequestHeaders
58+
}
59+
60+
/// WebSocket URL, defaults to "/mqtt"
61+
public let urlPath: String
62+
/// Max frame size WebSocket client will allow
63+
public let maxFrameSize: Int
64+
/// Additional headers to add to initial HTTP request
65+
public let initialRequestHeaders: HTTPHeaders
66+
}
67+
4368
/// Configuration for MQTTClient
4469
public struct Configuration {
70+
/// Initialize MQTTClient configuration struct
71+
/// - Parameters:
72+
/// - version: Version of MQTT server client is connecting to
73+
/// - disablePing: Disable the automatic sending of pingreq messages
74+
/// - keepAliveInterval: MQTT keep alive period.
75+
/// - pingInterval: Override calculated interval between each pingreq message
76+
/// - connectTimeout: Timeout for connecting to server
77+
/// - timeout: Timeout for server ACK responses
78+
/// - userName: MQTT user name
79+
/// - password: MQTT password
80+
/// - useSSL: Use encrypted connection to server
81+
/// - tlsConfiguration: TLS configuration, for SSL connection
82+
/// - sniServerName: Server name used by TLS. This will default to host name if not set
83+
/// - webSocketConfiguration: Set this if you want to use WebSockets
84+
public init(
85+
version: Version = .v3_1_1,
86+
disablePing: Bool = false,
87+
keepAliveInterval: TimeAmount = .seconds(90),
88+
pingInterval: TimeAmount? = nil,
89+
connectTimeout: TimeAmount = .seconds(10),
90+
timeout: TimeAmount? = nil,
91+
userName: String? = nil,
92+
password: String? = nil,
93+
useSSL: Bool = false,
94+
tlsConfiguration: TLSConfigurationType? = nil,
95+
sniServerName: String? = nil,
96+
webSocketConfiguration: WebSocketConfiguration
97+
) {
98+
self.version = version
99+
self.disablePing = disablePing
100+
self.keepAliveInterval = keepAliveInterval
101+
self.pingInterval = pingInterval
102+
self.connectTimeout = connectTimeout
103+
self.timeout = timeout
104+
self.userName = userName
105+
self.password = password
106+
self.useSSL = useSSL
107+
self.tlsConfiguration = tlsConfiguration
108+
self.sniServerName = sniServerName
109+
self.webSocketConfiguration = webSocketConfiguration
110+
}
111+
45112
/// Initialize MQTTClient configuration struct
46113
/// - Parameters:
47114
/// - version: Version of MQTT server client is connecting to
@@ -83,11 +150,13 @@ extension MQTTClient {
83150
self.userName = userName
84151
self.password = password
85152
self.useSSL = useSSL
86-
self.useWebSockets = useWebSockets
87153
self.tlsConfiguration = tlsConfiguration
88154
self.sniServerName = sniServerName
89-
self.webSocketURLPath = webSocketURLPath
90-
self.webSocketMaxFrameSize = webSocketMaxFrameSize
155+
if useWebSockets {
156+
self.webSocketConfiguration = .init(urlPath: webSocketURLPath ?? "/mqtt", maxFrameSize: webSocketMaxFrameSize)
157+
} else {
158+
self.webSocketConfiguration = nil
159+
}
91160
}
92161

93162
/// Initialize MQTTClient configuration struct
@@ -107,7 +176,6 @@ extension MQTTClient {
107176
/// - sniServerName: Server name used by TLS. This will default to host name if not set
108177
/// - webSocketURLPath: URL Path for web socket. Defaults to "/mqtt"
109178
/// - webSocketMaxFrameSize: Maximum frame size for a web socket connection
110-
///
111179
@available(*, deprecated, message: "maxRetryAttempts is no longer used")
112180
public init(
113181
version: Version = .v3_1_1,
@@ -135,11 +203,28 @@ extension MQTTClient {
135203
self.userName = userName
136204
self.password = password
137205
self.useSSL = useSSL
138-
self.useWebSockets = useWebSockets
139206
self.tlsConfiguration = tlsConfiguration
140207
self.sniServerName = sniServerName
141-
self.webSocketURLPath = webSocketURLPath
142-
self.webSocketMaxFrameSize = webSocketMaxFrameSize
208+
if useWebSockets {
209+
self.webSocketConfiguration = .init(urlPath: webSocketURLPath ?? "/mqtt", maxFrameSize: webSocketMaxFrameSize)
210+
} else {
211+
self.webSocketConfiguration = nil
212+
}
213+
}
214+
215+
/// use a websocket connection to server
216+
public var useWebSockets: Bool {
217+
self.webSocketConfiguration != nil
218+
}
219+
220+
/// URL Path for web socket. Defaults to "/mqtt"
221+
public var webSocketURLPath: String? {
222+
self.webSocketConfiguration?.urlPath
223+
}
224+
225+
/// Maximum frame size for a web socket connection
226+
public var webSocketMaxFrameSize: Int {
227+
self.webSocketConfiguration?.maxFrameSize ?? 1 << 14
143228
}
144229

145230
/// Version of MQTT server client is connecting to
@@ -160,15 +245,11 @@ extension MQTTClient {
160245
public let password: String?
161246
/// use encrypted connection to server
162247
public let useSSL: Bool
163-
/// use a websocket connection to server
164-
public let useWebSockets: Bool
165248
/// TLS configuration
166249
public let tlsConfiguration: TLSConfigurationType?
167250
/// server name used by TLS
168251
public let sniServerName: String?
169-
/// URL Path for web socket. Defaults to "/mqtt"
170-
public let webSocketURLPath: String?
171-
/// Maximum frame size for a web socket connection
172-
public let webSocketMaxFrameSize: Int
252+
/// WebSocket configuration
253+
public let webSocketConfiguration: WebSocketConfiguration?
173254
}
174255
}

Sources/MQTTNIO/MQTTConnection.swift

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,18 @@ final class MQTTConnection {
5757
taskHandler,
5858
]
5959
// are we using websockets
60-
if client.configuration.useWebSockets {
60+
if let webSocketConfiguration = client.configuration.webSocketConfiguration {
6161
// prepare for websockets and on upgrade add handlers
6262
let promise = eventLoop.makePromise(of: Void.self)
6363
promise.futureResult.map { _ in channel }
6464
.cascade(to: channelPromise)
6565

66-
return Self.setupChannelForWebsockets(client: client, channel: channel, upgradePromise: promise) {
66+
return Self.setupChannelForWebsockets(
67+
client: client,
68+
channel: channel,
69+
webSocketConfiguration: webSocketConfiguration,
70+
upgradePromise: promise
71+
) {
6772
return channel.pipeline.addHandlers(handlers)
6873
}
6974
} else {
@@ -140,13 +145,15 @@ final class MQTTConnection {
140145
static func setupChannelForWebsockets(
141146
client: MQTTClient,
142147
channel: Channel,
148+
webSocketConfiguration: MQTTClient.WebSocketConfiguration,
143149
upgradePromise promise: EventLoopPromise<Void>,
144150
afterHandlerAdded: @escaping () -> EventLoopFuture<Void>
145151
) -> EventLoopFuture<Void> {
146152
// initial HTTP request handler, before upgrade
147153
let httpHandler = WebSocketInitialRequestHandler(
148154
host: client.configuration.sniServerName ?? client.hostHeader,
149-
urlPath: client.configuration.webSocketURLPath ?? "/mqtt",
155+
urlPath: webSocketConfiguration.urlPath,
156+
additionalHeaders: webSocketConfiguration.initialRequestHeaders,
150157
upgradePromise: promise
151158
)
152159
// create random key for request key

Sources/MQTTNIO/MQTTNIO.docc/mqttnio-aws.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ let client = MQTTClient(
2727
host: host,
2828
identifier: "MyAWSClient",
2929
eventLoopGroupProvider: .createNew,
30-
configuration: .init(useSSL: true, useWebSockets: true, webSocketURLPath: requestUri)
30+
configuration: .init(useSSL: true, webSocketConfiguration: .init(urlPath: requestUri))
3131
)
3232
```
3333
You can find out more about connecting to AWS brokers [here](https://docs.aws.amazon.com/iot/latest/developerguide/protocols.html)

Sources/MQTTNIO/MQTTNIO.docc/mqttnio-connections.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ let client = MQTTClient(
2626

2727
## WebSockets
2828

29-
MQTT also supports Web Socket connections. Set `Configuration.useWebSockets` to `true` and set the URL path in `Configuration.webSocketsURLPath` to enable these.
29+
MQTT also supports Web Socket connections. Provide a `WebSocketConfiguration` when initializing `MQTTClient.Configuration` to enable this.
3030

3131
## NIO Transport Services
3232

Tests/MQTTNIOTests/MQTTNIOTests.swift

Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -590,11 +590,52 @@ final class MQTTNIOTests: XCTestCase {
590590
try client.disconnect().wait()
591591
}
592592

593+
func testWebSocketInitialRequest() throws {
594+
let el = EmbeddedEventLoop()
595+
defer { XCTAssertNoThrow(try el.syncShutdownGracefully()) }
596+
let promise = el.makePromise(of: Void.self)
597+
let initialRequestHandler = WebSocketInitialRequestHandler(
598+
host: "test.mosquitto.org",
599+
urlPath: "/mqtt",
600+
additionalHeaders: ["Test": "Value"],
601+
upgradePromise: promise
602+
)
603+
let channel = EmbeddedChannel(handler: initialRequestHandler, loop: el)
604+
try channel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 0)).wait()
605+
let requestHead = try channel.readOutbound(as: HTTPClientRequestPart.self)
606+
let requestBody = try channel.readOutbound(as: HTTPClientRequestPart.self)
607+
let requestEnd = try channel.readOutbound(as: HTTPClientRequestPart.self)
608+
switch requestHead {
609+
case .head(let head):
610+
XCTAssertEqual(head.uri, "/mqtt")
611+
XCTAssertEqual(head.headers["host"].first, "test.mosquitto.org")
612+
XCTAssertEqual(head.headers["Sec-WebSocket-Protocol"].first, "mqtt")
613+
XCTAssertEqual(head.headers["Test"].first, "Value")
614+
default:
615+
XCTFail("Did not expect \(String(describing: requestHead))")
616+
}
617+
switch requestBody {
618+
case .body(let data):
619+
XCTAssertEqual(data, .byteBuffer(ByteBuffer()))
620+
default:
621+
XCTFail("Did not expect \(String(describing: requestBody))")
622+
}
623+
switch requestEnd {
624+
case .end(nil):
625+
break
626+
default:
627+
XCTFail("Did not expect \(String(describing: requestEnd))")
628+
}
629+
_ = try channel.finish()
630+
promise.succeed()
631+
}
632+
593633
// MARK: Helper variables and functions
594634

595635
func createClient(identifier: String, configuration: MQTTClient.Configuration = .init()) -> MQTTClient {
596636
MQTTClient(
597637
host: Self.hostname,
638+
port: 1883,
598639
identifier: identifier,
599640
eventLoopGroupProvider: .createNew,
600641
logger: self.logger,
@@ -609,7 +650,7 @@ final class MQTTNIOTests: XCTestCase {
609650
identifier: identifier,
610651
eventLoopGroupProvider: .createNew,
611652
logger: self.logger,
612-
configuration: .init(useWebSockets: true, webSocketURLPath: "/mqtt")
653+
configuration: .init(webSocketConfiguration: .init(urlPath: "/mqtt"))
613654
)
614655
}
615656

@@ -630,7 +671,13 @@ final class MQTTNIOTests: XCTestCase {
630671
identifier: identifier,
631672
eventLoopGroupProvider: .createNew,
632673
logger: self.logger,
633-
configuration: .init(timeout: .seconds(5), useSSL: true, useWebSockets: true, tlsConfiguration: Self.getTLSConfiguration(), sniServerName: "soto.codes", webSocketURLPath: "/mqtt")
674+
configuration: .init(
675+
timeout: .seconds(5),
676+
useSSL: true,
677+
tlsConfiguration: Self.getTLSConfiguration(),
678+
sniServerName: "soto.codes",
679+
webSocketConfiguration: .init(urlPath: "/mqtt")
680+
)
634681
)
635682
}
636683

@@ -640,13 +687,11 @@ final class MQTTNIOTests: XCTestCase {
640687
return logger
641688
}()
642689

643-
static var rootPath: String = {
644-
return #file
645-
.split(separator: "/", omittingEmptySubsequences: false)
646-
.dropLast(3)
647-
.map { String(describing: $0) }
648-
.joined(separator: "/")
649-
}()
690+
static var rootPath: String = #file
691+
.split(separator: "/", omittingEmptySubsequences: false)
692+
.dropLast(3)
693+
.map { String(describing: $0) }
694+
.joined(separator: "/")
650695

651696
static var _tlsConfiguration: Result<MQTTClient.TLSConfigurationType, Error> = {
652697
do {

0 commit comments

Comments
 (0)