Skip to content

Commit fe5d1b5

Browse files
authored
Async shutdown (#73)
1 parent 5380dfc commit fe5d1b5

File tree

9 files changed

+124
-26
lines changed

9 files changed

+124
-26
lines changed

Sources/MQTTNIO/AsyncAwaitSupport/MQTTClient+async.swift

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,18 @@ import NIOCore
55

66
@available(macOS 12.0, iOS 15.0, watchOS 8.0, tvOS 15.0, *)
77
extension MQTTClient {
8+
public func shutdown(queue: DispatchQueue = .global()) async throws {
9+
return try await withUnsafeThrowingContinuation { cont in
10+
shutdown(queue: queue) { error in
11+
if let error = error {
12+
cont.resume(throwing: error)
13+
} else {
14+
cont.resume()
15+
}
16+
}
17+
}
18+
}
19+
820
/// Connect to MQTT server
921
///
1022
/// Completes when CONNACK is received
@@ -111,7 +123,7 @@ public class MQTTPublishListener: AsyncSequence {
111123
}
112124

113125
public __consuming func makeAsyncIterator() -> AsyncStream<Element>.AsyncIterator {
114-
return stream.makeAsyncIterator()
126+
return self.stream.makeAsyncIterator()
115127
}
116128
}
117129

Sources/MQTTNIO/AsyncAwaitSupport/MQTTClientV5+async.swift

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,8 @@ public class MQTTPublishIdListener: AsyncSequence {
130130
}
131131

132132
public __consuming func makeAsyncIterator() -> AsyncStream<Element>.AsyncIterator {
133-
return stream.makeAsyncIterator()
133+
return self.stream.makeAsyncIterator()
134134
}
135135
}
136136

137-
138137
#endif // compiler(>=5.5)

Sources/MQTTNIO/MQTTClient.swift

Lines changed: 80 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ public final class MQTTClient {
6565
private static let loggingDisabled = Logger(label: "MQTT-do-not-log", factory: { _ in SwiftLogNoOpLogHandler() })
6666
/// inflight messages
6767
private var inflight: MQTTInflight
68+
/// flag to tell is client is shutdown
69+
private let isShutdown = NIOAtomic<Bool>.makeAtomic(value: false)
6870

6971
/// Create MQTT client
7072
/// - Parameters:
@@ -121,15 +123,88 @@ public final class MQTTClient {
121123
self.inflight = .init()
122124
}
123125

124-
/// Close down client. Must be called before the client is destroyed
126+
deinit {
127+
guard isShutdown.load() else {
128+
preconditionFailure("Client not shut down before the deinit. Please call client.syncShutdown() when no longer needed.")
129+
}
130+
}
131+
132+
/// Shutdown client synchronously. Before an `MQTTClient` is deleted you need to call this function or the async version `shutdown`
133+
/// to do a clean shutdown of the client. It closes the connection, notifies everything listening for shutdown and shuts down the
134+
/// EventLoopGroup if the client created it
135+
///
136+
/// - Throws: MQTTError.alreadyShutdown: You have already shutdown the client
125137
public func syncShutdownGracefully() throws {
126-
try self.connection?.close().wait()
138+
if let eventLoop = MultiThreadedEventLoopGroup.currentEventLoop {
139+
preconditionFailure("""
140+
BUG DETECTED: syncShutdown() must not be called when on an EventLoop.
141+
Calling syncShutdown() on any EventLoop can lead to deadlocks.
142+
Current eventLoop: \(eventLoop)
143+
""")
144+
}
145+
let errorStorageLock = Lock()
146+
var errorStorage: Error?
147+
let continuation = DispatchWorkItem {}
148+
self.shutdown(queue: DispatchQueue(label: "mqtt-client.shutdown")) { error in
149+
if let error = error {
150+
errorStorageLock.withLock {
151+
errorStorage = error
152+
}
153+
}
154+
continuation.perform()
155+
}
156+
continuation.wait()
157+
try errorStorageLock.withLock {
158+
if let error = errorStorage {
159+
throw error
160+
}
161+
}
162+
}
163+
164+
/// Shutdown MQTTClient asynchronously. Before an `AWSClient` is deleted you need to call this function or the synchronous
165+
/// version `syncShutdownGracefully` to do a clean shutdown of the client. It closes the connection, notifies everything
166+
/// listening for shutdown and shuts down the EventLoopGroup if the client created it
167+
///
168+
/// - Parameters:
169+
/// - queue: Dispatch Queue to run shutdown on
170+
/// - callback: Callback called when shutdown is complete. If there was an error it will return with Error in callback
171+
public func shutdown(queue: DispatchQueue = .global(), _ callback: @escaping (Error?) -> Void) {
172+
guard self.isShutdown.compareAndExchange(expected: false, desired: true) else {
173+
callback(MQTTError.alreadyShutdown)
174+
return
175+
}
176+
let eventLoop = self.eventLoopGroup.next()
177+
let closeFuture: EventLoopFuture<Void>
127178
self.shutdownListeners.notify(.success(()))
179+
if let connection = self.connection {
180+
closeFuture = connection.close()
181+
} else {
182+
closeFuture = eventLoop.makeSucceededVoidFuture()
183+
}
184+
closeFuture.whenComplete { result in
185+
let closeError: Error?
186+
switch result {
187+
case .failure(let error):
188+
closeError = error
189+
case .success:
190+
closeError = nil
191+
}
192+
self.shutdownListeners.notify(.success(()))
193+
self.shutdownEventLoopGroup(queue: queue) { error in
194+
callback(closeError ?? error)
195+
}
196+
}
197+
}
198+
199+
/// shutdown EventLoopGroup
200+
private func shutdownEventLoopGroup(queue: DispatchQueue, _ callback: @escaping (Error?) -> Void) {
128201
switch self.eventLoopGroupProvider {
129-
case .createNew:
130-
try self.eventLoopGroup.syncShutdownGracefully()
131202
case .shared:
132-
break
203+
queue.async {
204+
callback(nil)
205+
}
206+
case .createNew:
207+
self.eventLoopGroup.shutdownGracefully(queue: queue, callback)
133208
}
134209
}
135210

Sources/MQTTNIO/MQTTClientV5.swift

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,8 @@ extension MQTTClient {
158158
if case .success(let info) = result {
159159
for property in info.properties {
160160
if case .subscriptionIdentifier(let id) = property,
161-
id == subscriptionId {
161+
id == subscriptionId
162+
{
162163
listener(info)
163164
break
164165
}

Sources/MQTTNIO/MQTTConnection.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ final class MQTTConnection {
181181
let task = MQTTTask(on: channel.eventLoop, timeout: self.timeout, checkInbound: checkInbound)
182182
let taskHandler = MQTTTaskHandler(task: task, channel: channel)
183183

184-
channel.pipeline.addHandler(taskHandler, position: .before(self.unhandledHandler))
184+
self.channel.pipeline.addHandler(taskHandler, position: .before(self.unhandledHandler))
185185
.flatMap {
186186
self.channel.writeAndFlush(message)
187187
}

Sources/MQTTNIO/MQTTError.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ public enum MQTTError: Error {
2020

2121
/// You called connect on a client that is already connected to the broker
2222
case alreadyConnected
23+
/// Client has already been shutdown
24+
case alreadyShutdown
2325
/// We received an unexpected message while connecting
2426
case failedToConnect
2527
/// We received an unsuccessful connection return value

Sources/MQTTNIO/WebSocketHandler.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ final class WebSocketHandler: ChannelDuplexHandler {
2525
guard context.channel.isActive else { return }
2626

2727
let buffer = unwrapOutboundIn(data)
28-
send(context: context, buffer: buffer, opcode: .binary, fin: true, promise: promise)
28+
self.send(context: context, buffer: buffer, opcode: .binary, fin: true, promise: promise)
2929
}
3030

3131
/// Read WebSocket frame

Tests/MQTTNIOTests/MQTTNIOTests+async.swift

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,14 @@ final class AsyncMQTTNIOTests: XCTestCase {
5151
self.XCTRunAsyncAndBlock {
5252
try await client.connect()
5353
try await client.disconnect()
54-
try client.syncShutdownGracefully()
54+
try await client.shutdown()
5555
}
5656
}
5757

5858
func testPublishSubscribe() {
59+
let expectation = XCTestExpectation(description: "testPublishSubscribe")
60+
expectation.expectedFulfillmentCount = 1
61+
5962
let client = self.createClient(identifier: "testPublish+async")
6063
let client2 = self.createClient(identifier: "testPublish+async2")
6164
let payloadString = "Hello"
@@ -69,15 +72,19 @@ final class AsyncMQTTNIOTests: XCTestCase {
6972
var buffer = publish.payload
7073
let string = buffer.readString(length: buffer.readableBytes)
7174
XCTAssertEqual(string, payloadString)
75+
expectation.fulfill()
7276
case .failure(let error):
7377
XCTFail("\(error)")
7478
}
7579
}
7680
try await client.publish(to: "TestSubject", payload: ByteBufferAllocator().buffer(string: payloadString), qos: .atLeastOnce)
81+
82+
self.wait(for: [expectation], timeout: 2)
83+
7784
try await client.disconnect()
78-
Thread.sleep(forTimeInterval: 2)
7985
try await client2.disconnect()
80-
try client.syncShutdownGracefully()
86+
try await client.shutdown()
87+
try await client2.shutdown()
8188
}
8289
}
8390

@@ -95,18 +102,19 @@ final class AsyncMQTTNIOTests: XCTestCase {
95102
try await client.connect()
96103
try await client.ping()
97104
try await client.disconnect()
98-
try client.syncShutdownGracefully()
105+
try await client.shutdown()
99106
}
100107
}
101108

102109
func testAsyncSequencePublishListener() {
103-
let client = createClient(identifier: "testAsyncSequencePublishListener+async", version: .v5_0)
104-
let client2 = createClient(identifier: "testAsyncSequencePublishListener+async2", version: .v5_0)
105-
let payloadString = "Hello"
106-
let expectation = XCTestExpectation(description: "publish listener")
110+
let expectation = XCTestExpectation(description: "testAsyncSequencePublishListener")
107111
expectation.expectedFulfillmentCount = 3
108-
109-
XCTRunAsyncAndBlock {
112+
113+
let client = self.createClient(identifier: "testAsyncSequencePublishListener+async", version: .v5_0)
114+
let client2 = self.createClient(identifier: "testAsyncSequencePublishListener+async2", version: .v5_0)
115+
let payloadString = "Hello"
116+
117+
self.XCTRunAsyncAndBlock {
110118
try await client.connect()
111119
try await client2.connect()
112120
_ = try await client2.v5.subscribe(to: [.init(topicFilter: "TestSubject", qos: .atLeastOnce)])
@@ -131,24 +139,24 @@ final class AsyncMQTTNIOTests: XCTestCase {
131139
Thread.sleep(forTimeInterval: 0.5)
132140
try await client2.disconnect()
133141
Thread.sleep(forTimeInterval: 0.5)
134-
try client.syncShutdownGracefully()
135-
try client2.syncShutdownGracefully()
142+
try await client.shutdown()
143+
try await client2.shutdown()
136144

137145
_ = await task.result
138146
}
139147
wait(for: [expectation], timeout: 5.0)
140148
}
141149

142150
func testAsyncSequencePublishSubscriptionIdListener() {
143-
let client = createClient(identifier: "testAsyncSequencePublishSubscriptionIdListener+async", version: .v5_0)
144-
let client2 = createClient(identifier: "testAsyncSequencePublishSubscriptionIdListener+async2", version: .v5_0)
151+
let client = self.createClient(identifier: "testAsyncSequencePublishSubscriptionIdListener+async", version: .v5_0)
152+
let client2 = self.createClient(identifier: "testAsyncSequencePublishSubscriptionIdListener+async2", version: .v5_0)
145153
let payloadString = "Hello"
146154
let expectation = XCTestExpectation(description: "publish listener")
147155
let expectation2 = XCTestExpectation(description: "publish listener2")
148156
expectation.expectedFulfillmentCount = 3
149157
expectation2.expectedFulfillmentCount = 2
150158

151-
XCTRunAsyncAndBlock {
159+
self.XCTRunAsyncAndBlock {
152160
try await client.connect()
153161
try await client2.connect()
154162
_ = try await client2.v5.subscribe(to: [.init(topicFilter: "TestSubject", qos: .atLeastOnce)], properties: [.subscriptionIdentifier(1)])

Tests/MQTTNIOTests/MQTTNIOv5Tests.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ final class MQTTNIOv5Tests: XCTestCase {
234234
connack = try client.v5.connect(cleanStart: false).wait()
235235
XCTAssertEqual(connack.sessionPresent, true)
236236
try client.disconnect().wait()
237+
try client.syncShutdownGracefully()
237238
}
238239

239240
func testPersistentSession() throws {

0 commit comments

Comments
 (0)