Skip to content

Commit 75f32f5

Browse files
authored
Single channel handler to manage tasks (#76)
* Move timeout to MQTTTask * Combine all task handlers into one with a list of tasks * remove tasks on eventLoop * Add testMultipleTasks * swift format * Remove task from list, before completing it * Fixed docker-compose file
1 parent 483a171 commit 75f32f5

File tree

8 files changed

+158
-132
lines changed

8 files changed

+158
-132
lines changed
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import Logging
2+
import NIO
3+
4+
/// Handler encoding MQTT Messages into ByteBuffers
5+
final class MQTTEncodeHandler: ChannelOutboundHandler {
6+
public typealias OutboundIn = MQTTPacket
7+
public typealias OutboundOut = ByteBuffer
8+
9+
let client: MQTTClient
10+
11+
init(client: MQTTClient) {
12+
self.client = client
13+
}
14+
15+
func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
16+
let message = unwrapOutboundIn(data)
17+
self.client.logger.trace("MQTT Out", metadata: ["mqtt_message": .string("\(message)"), "mqtt_packet_id": .string("\(message.packetId)")])
18+
var bb = context.channel.allocator.buffer(capacity: 0)
19+
do {
20+
try message.write(version: self.client.configuration.version, to: &bb)
21+
context.write(wrapOutboundOut(bb), promise: promise)
22+
} catch {
23+
promise?.fail(error)
24+
}
25+
}
26+
}

Sources/MQTTNIO/MQTTChannelHandlers.swift renamed to Sources/MQTTNIO/ChannelHandlers/MQTTMessageDecoder.swift

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,6 @@
11
import Logging
22
import NIO
33

4-
/// Handler encoding MQTT Messages into ByteBuffers
5-
final class MQTTEncodeHandler: ChannelOutboundHandler {
6-
public typealias OutboundIn = MQTTPacket
7-
public typealias OutboundOut = ByteBuffer
8-
9-
let client: MQTTClient
10-
11-
init(client: MQTTClient) {
12-
self.client = client
13-
}
14-
15-
func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
16-
let message = unwrapOutboundIn(data)
17-
self.client.logger.trace("MQTT Out", metadata: ["mqtt_message": .string("\(message)"), "mqtt_packet_id": .string("\(message.packetId)")])
18-
var bb = context.channel.allocator.buffer(capacity: 0)
19-
do {
20-
try message.write(version: self.client.configuration.version, to: &bb)
21-
context.write(wrapOutboundOut(bb), promise: promise)
22-
} catch {
23-
promise?.fail(error)
24-
}
25-
}
26-
}
27-
284
/// Decode ByteBuffers into MQTT Messages
295
struct ByteToMQTTMessageDecoder: ByteToMessageDecoder {
306
typealias InboundOut = MQTTPacket
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import NIO
2+
3+
final class MQTTTaskHandler: ChannelInboundHandler, RemovableChannelHandler {
4+
typealias InboundIn = MQTTPacket
5+
6+
var eventLoop: EventLoop!
7+
8+
init() {
9+
self.eventLoop = nil
10+
self.tasks = []
11+
}
12+
13+
func addTask(_ task: MQTTTask) -> EventLoopFuture<Void> {
14+
return self.eventLoop.submit {
15+
self.tasks.append(task)
16+
}
17+
}
18+
19+
func _removeTask(_ task: MQTTTask) {
20+
self.tasks.removeAll { $0 === task }
21+
}
22+
23+
func removeTask(_ task: MQTTTask) {
24+
if self.eventLoop.inEventLoop {
25+
self._removeTask(task)
26+
} else {
27+
self.eventLoop.execute {
28+
self._removeTask(task)
29+
}
30+
}
31+
}
32+
33+
func handlerAdded(context: ChannelHandlerContext) {
34+
self.eventLoop = context.eventLoop
35+
}
36+
37+
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
38+
let response = self.unwrapInboundIn(data)
39+
for task in self.tasks {
40+
do {
41+
if try task.checkInbound(response) {
42+
self.removeTask(task)
43+
task.succeed(response)
44+
return
45+
}
46+
} catch {
47+
self.removeTask(task)
48+
task.fail(error)
49+
return
50+
}
51+
}
52+
}
53+
54+
func channelInactive(context: ChannelHandlerContext) {
55+
self.tasks.forEach { $0.fail(MQTTError.serverClosedConnection) }
56+
self.tasks.removeAll()
57+
}
58+
59+
func errorCaught(context: ChannelHandlerContext, error: Error) {
60+
self.tasks.forEach { $0.fail(error) }
61+
self.tasks.removeAll()
62+
}
63+
64+
var tasks: [MQTTTask]
65+
}
66+
67+
/// If packet reaches this handler then it was never dealt with by a task
68+
final class MQTTUnhandledPacketHandler: ChannelInboundHandler {
69+
typealias InboundIn = MQTTPacket
70+
let client: MQTTClient
71+
72+
init(client: MQTTClient) {
73+
self.client = client
74+
}
75+
76+
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
77+
// we only send response to v5 server
78+
guard self.client.configuration.version == .v5_0 else { return }
79+
guard let connection = client.connection else { return }
80+
let response = self.unwrapInboundIn(data)
81+
switch response.type {
82+
case .PUBREC:
83+
_ = connection.sendMessageNoWait(MQTTPubAckPacket(type: .PUBREL, packetId: response.packetId, reason: .packetIdentifierNotFound))
84+
case .PUBREL:
85+
_ = connection.sendMessageNoWait(MQTTPubAckPacket(type: .PUBCOMP, packetId: response.packetId, reason: .packetIdentifierNotFound))
86+
default:
87+
break
88+
}
89+
}
90+
}

Sources/MQTTNIO/MQTTConnection.swift

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,21 @@ import NIOWebSocket
1313
final class MQTTConnection {
1414
let channel: Channel
1515
let timeout: TimeAmount?
16-
let unhandledHandler: MQTTUnhandledPacketHandler
16+
let taskHandler: MQTTTaskHandler
1717

18-
private init(channel: Channel, timeout: TimeAmount?, unhandledHandler: MQTTUnhandledPacketHandler) {
18+
private init(channel: Channel, timeout: TimeAmount?, taskHandler: MQTTTaskHandler) {
1919
self.channel = channel
2020
self.timeout = timeout
21-
self.unhandledHandler = unhandledHandler
21+
self.taskHandler = taskHandler
2222
}
2323

2424
static func create(client: MQTTClient, pingInterval: TimeAmount) -> EventLoopFuture<MQTTConnection> {
25-
let unhandledHandler = MQTTUnhandledPacketHandler(client: client)
26-
return self.createBootstrap(client: client, pingInterval: pingInterval, unhandledHandler: unhandledHandler)
27-
.map { MQTTConnection(channel: $0, timeout: client.configuration.timeout, unhandledHandler: unhandledHandler) }
25+
let taskHandler = MQTTTaskHandler()
26+
return self.createBootstrap(client: client, pingInterval: pingInterval, taskHandler: taskHandler)
27+
.map { MQTTConnection(channel: $0, timeout: client.configuration.timeout, taskHandler: taskHandler) }
2828
}
2929

30-
static func createBootstrap(client: MQTTClient, pingInterval: TimeAmount, unhandledHandler: MQTTUnhandledPacketHandler) -> EventLoopFuture<Channel> {
30+
static func createBootstrap(client: MQTTClient, pingInterval: TimeAmount, taskHandler: MQTTTaskHandler) -> EventLoopFuture<Channel> {
3131
let eventLoop = client.eventLoopGroup.next()
3232
let channelPromise = eventLoop.makePromise(of: Channel.self)
3333
do {
@@ -41,7 +41,7 @@ final class MQTTConnection {
4141
// Work out what handlers to add
4242
var handlers: [ChannelHandler] = [
4343
ByteToMessageHandler(ByteToMQTTMessageDecoder(client: client)),
44-
unhandledHandler,
44+
taskHandler,
4545
MQTTEncodeHandler(client: client),
4646
]
4747
if !client.configuration.disablePing {
@@ -179,9 +179,8 @@ final class MQTTConnection {
179179

180180
func sendMessage(_ message: MQTTPacket, checkInbound: @escaping (MQTTPacket) throws -> Bool) -> EventLoopFuture<MQTTPacket> {
181181
let task = MQTTTask(on: channel.eventLoop, timeout: self.timeout, checkInbound: checkInbound)
182-
let taskHandler = MQTTTaskHandler(task: task, channel: channel)
183182

184-
self.channel.pipeline.addHandler(taskHandler, position: .before(self.unhandledHandler))
183+
self.taskHandler.addTask(task)
185184
.flatMap {
186185
self.channel.writeAndFlush(message)
187186
}

Sources/MQTTNIO/MQTTTask.swift

Lines changed: 13 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -1,108 +1,34 @@
11

22
import NIO
33

4+
/// Class encapsulating a single task
45
final class MQTTTask {
56
let promise: EventLoopPromise<MQTTPacket>
67
let checkInbound: (MQTTPacket) throws -> Bool
78
let timeout: TimeAmount?
9+
let timeoutTask: Scheduled<Void>?
810

911
init(on eventLoop: EventLoop, timeout: TimeAmount?, checkInbound: @escaping (MQTTPacket) throws -> Bool) {
10-
self.promise = eventLoop.makePromise(of: MQTTPacket.self)
12+
let promise = eventLoop.makePromise(of: MQTTPacket.self)
13+
self.promise = promise
1114
self.checkInbound = checkInbound
1215
self.timeout = timeout
13-
}
14-
15-
func succeed(_ response: MQTTPacket) {
16-
self.promise.succeed(response)
17-
}
18-
19-
func fail(_ error: Error) {
20-
self.promise.fail(error)
21-
}
22-
}
23-
24-
final class MQTTTaskHandler: ChannelInboundHandler, RemovableChannelHandler {
25-
typealias InboundIn = MQTTPacket
26-
27-
let task: MQTTTask
28-
let channel: Channel
29-
var timeoutTask: Scheduled<Void>?
30-
31-
init(task: MQTTTask, channel: Channel) {
32-
self.task = task
33-
self.channel = channel
34-
self.timeoutTask = nil
35-
}
36-
37-
public func handlerAdded(context: ChannelHandlerContext) {
38-
self.addTimeoutTask()
39-
}
40-
41-
public func handlerRemoved(context: ChannelHandlerContext) {
42-
self.timeoutTask?.cancel()
43-
}
44-
45-
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
46-
let response = self.unwrapInboundIn(data)
47-
do {
48-
if try self.task.checkInbound(response) {
49-
self.channel.pipeline.removeHandler(self).whenSuccess { _ in
50-
self.timeoutTask?.cancel()
51-
self.task.succeed(response)
52-
}
53-
} else {
54-
context.fireChannelRead(data)
55-
}
56-
} catch {
57-
self.errorCaught(context: context, error: error)
58-
}
59-
}
60-
61-
func channelInactive(context: ChannelHandlerContext) {
62-
self.task.fail(MQTTError.serverClosedConnection)
63-
}
64-
65-
func errorCaught(context: ChannelHandlerContext, error: Error) {
66-
self.timeoutTask?.cancel()
67-
self.channel.pipeline.removeHandler(self).whenSuccess { _ in
68-
self.task.fail(error)
69-
}
70-
}
71-
72-
func addTimeoutTask() {
73-
if let timeout = task.timeout {
74-
self.timeoutTask = self.channel.eventLoop.scheduleTask(in: timeout) {
75-
self.channel.pipeline.removeHandler(self).whenSuccess { _ in
76-
self.task.fail(MQTTError.timeout)
77-
}
16+
if let timeout = timeout {
17+
self.timeoutTask = eventLoop.scheduleTask(in: timeout) {
18+
promise.fail(MQTTError.timeout)
7819
}
7920
} else {
8021
self.timeoutTask = nil
8122
}
8223
}
83-
}
84-
85-
/// If packet reaches this handler then it was never dealt with by a task
86-
final class MQTTUnhandledPacketHandler: ChannelInboundHandler {
87-
typealias InboundIn = MQTTPacket
88-
let client: MQTTClient
8924

90-
init(client: MQTTClient) {
91-
self.client = client
25+
func succeed(_ response: MQTTPacket) {
26+
self.timeoutTask?.cancel()
27+
self.promise.succeed(response)
9228
}
9329

94-
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
95-
// we only send response to v5 server
96-
guard self.client.configuration.version == .v5_0 else { return }
97-
guard let connection = client.connection else { return }
98-
let response = self.unwrapInboundIn(data)
99-
switch response.type {
100-
case .PUBREC:
101-
_ = connection.sendMessageNoWait(MQTTPubAckPacket(type: .PUBREL, packetId: response.packetId, reason: .packetIdentifierNotFound))
102-
case .PUBREL:
103-
_ = connection.sendMessageNoWait(MQTTPubAckPacket(type: .PUBCOMP, packetId: response.packetId, reason: .packetIdentifierNotFound))
104-
default:
105-
break
106-
}
30+
func fail(_ error: Error) {
31+
self.timeoutTask?.cancel()
32+
self.promise.fail(error)
10733
}
10834
}

Tests/MQTTNIOTests/MQTTNIOTests+async.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,6 @@ final class AsyncMQTTNIOTests: XCTestCase {
114114

115115
let client = self.createClient(identifier: "testAsyncSequencePublishListener+async", version: .v5_0)
116116
let client2 = self.createClient(identifier: "testAsyncSequencePublishListener+async2", version: .v5_0)
117-
let payloadString = "Hello"
118117

119118
self.XCTRunAsyncAndBlock {
120119
try await client.connect()
@@ -127,16 +126,17 @@ final class AsyncMQTTNIOTests: XCTestCase {
127126
case .success(let publish):
128127
var buffer = publish.payload
129128
let string = buffer.readString(length: buffer.readableBytes)
130-
XCTAssertEqual(string, payloadString)
129+
print("Received: \(string ?? "nothing")")
131130
expectation.fulfill()
131+
132132
case .failure(let error):
133133
XCTFail("\(error)")
134134
}
135135
}
136136
finishExpectation.fulfill()
137137
}
138-
try await client.publish(to: "TestSubject", payload: ByteBufferAllocator().buffer(string: payloadString), qos: .atLeastOnce)
139-
try await client.publish(to: "TestSubject", payload: ByteBufferAllocator().buffer(string: payloadString), qos: .atLeastOnce)
138+
try await client.publish(to: "TestSubject", payload: ByteBufferAllocator().buffer(string: "Hello"), qos: .atLeastOnce)
139+
try await client.publish(to: "TestSubject", payload: ByteBufferAllocator().buffer(string: "Goodbye"), qos: .atLeastOnce)
140140
try await client.disconnect()
141141

142142
self.wait(for: [expectation], timeout: 5.0)

Tests/MQTTNIOTests/MQTTNIOTests.swift

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,17 @@ final class MQTTNIOTests: XCTestCase {
8989
try client.disconnect().wait()
9090
}
9191

92+
func testMultipleTasks() throws {
93+
let client = self.createClient(identifier: "testMultipleTasks")
94+
defer { XCTAssertNoThrow(try client.syncShutdownGracefully()) }
95+
_ = try client.connect().wait()
96+
let publishFutures = (0..<16).map { client.publish(to: "test/multiple", payload: ByteBuffer(integer: $0), qos: .exactlyOnce) }
97+
_ = client.ping()
98+
try EventLoopFuture.andAllComplete(publishFutures, on: client.eventLoopGroup.next()).wait()
99+
XCTAssertEqual(client.connection?.taskHandler.tasks.count, 0)
100+
try client.disconnect().wait()
101+
}
102+
92103
func testMQTTSubscribe() throws {
93104
let client = self.createClient(identifier: "testMQTTSubscribe")
94105
defer { XCTAssertNoThrow(try client.syncShutdownGracefully()) }
@@ -385,11 +396,12 @@ final class MQTTNIOTests: XCTestCase {
385396
eventLoopGroupProvider: .createNew,
386397
logger: self.logger
387398
)
399+
defer { XCTAssertNoThrow(try client.syncShutdownGracefully()) }
400+
388401
_ = try client.connect().wait()
389402
_ = try client.subscribe(to: [.init(topicFilter: "#", qos: .exactlyOnce)]).wait()
390403
Thread.sleep(forTimeInterval: 5)
391404
try client.disconnect().wait()
392-
try client.syncShutdownGracefully()
393405
}
394406

395407
func testRawIPConnect() throws {

0 commit comments

Comments
 (0)