Skip to content

Commit 936f11f

Browse files
authored
Use swift-atomics (#122)
* Use swift-atomics * Add tests to ensure swift-atomics change works * Use ManagedAtomic in tests * Remove @preconcurrency from NIO imports * Remove NIOConcurrencyHelpers where not needed
1 parent bc68c7c commit 936f11f

File tree

8 files changed

+59
-37
lines changed

8 files changed

+59
-37
lines changed

Package.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ let package = Package(
99
.library(name: "MQTTNIO", targets: ["MQTTNIO"]),
1010
],
1111
dependencies: [
12+
.package(url: "https://github.com/apple/swift-atomics.git", from: "1.0.0"),
1213
.package(url: "https://github.com/apple/swift-log.git", from: "1.0.0"),
1314
.package(url: "https://github.com/apple/swift-nio.git", from: "2.33.0"),
1415
.package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.14.0"),
@@ -19,6 +20,7 @@ let package = Package(
1920
name: "MQTTNIO",
2021
dependencies:
2122
[
23+
.product(name: "Atomics", package: "swift-atomics"),
2224
.product(name: "Logging", package: "swift-log"),
2325
.product(name: "NIO", package: "swift-nio"),
2426
.product(name: "NIOConcurrencyHelpers", package: "swift-nio"),

Sources/MQTTNIO/MQTTClient.swift

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
//
1212
//===----------------------------------------------------------------------===//
1313

14+
import Atomics
1415
import Dispatch
1516
import Logging
1617
#if canImport(Network)
@@ -73,13 +74,13 @@ public final class MQTTClient {
7374
return self.host
7475
}
7576

76-
private let globalPacketId = NIOAtomic<UInt16>.makeAtomic(value: 1)
77+
internal let globalPacketId = ManagedAtomic<UInt16>(1)
7778
/// default logger that logs nothing
7879
private static let loggingDisabled = Logger(label: "MQTT-do-not-log", factory: { _ in SwiftLogNoOpLogHandler() })
7980
/// inflight messages
8081
private var inflight: MQTTInflight
8182
/// flag to tell is client is shutdown
82-
private let isShutdown = NIOAtomic<Bool>.makeAtomic(value: false)
83+
private let isShutdown = ManagedAtomic(false)
8384

8485
/// Create MQTT client
8586
/// - Parameters:
@@ -139,7 +140,7 @@ public final class MQTTClient {
139140
}
140141

141142
deinit {
142-
guard isShutdown.load() else {
143+
guard isShutdown.load(ordering: .relaxed) else {
143144
preconditionFailure("Client not shut down before the deinit. Please call client.syncShutdownGracefully() when no longer needed.")
144145
}
145146
}
@@ -188,7 +189,7 @@ public final class MQTTClient {
188189
/// - queue: Dispatch Queue to run shutdown on
189190
/// - callback: Callback called when shutdown is complete. If there was an error it will return with Error in callback
190191
public func shutdown(queue: DispatchQueue = .global(), _ callback: @escaping (Error?) -> Void) {
191-
guard self.isShutdown.compareAndExchange(expected: false, desired: true) else {
192+
guard self.isShutdown.compareExchange(expected: false, desired: true, ordering: .relaxed).exchanged else {
192193
callback(MQTTError.alreadyShutdown)
193194
return
194195
}
@@ -379,11 +380,13 @@ public final class MQTTClient {
379380
}
380381

381382
internal func updatePacketId() -> UInt16 {
383+
let id = self.globalPacketId.wrappingIncrementThenLoad(by: 1, ordering: .relaxed)
384+
382385
// packet id must be non-zero
383-
if self.globalPacketId.compareAndExchange(expected: 0, desired: 1) {
384-
return 1
386+
if id == 0 {
387+
return self.globalPacketId.wrappingIncrementThenLoad(by: 1, ordering: .relaxed)
385388
} else {
386-
return self.globalPacketId.add(1)
389+
return id
387390
}
388391
}
389392

@@ -490,7 +493,7 @@ internal extension MQTTClient {
490493

491494
func processConnack(_ connack: MQTTConnAckPacket) throws -> MQTTConnAckPacket {
492495
// connack doesn't return a packet id so this is alway 32767. Need a better way to choose first packet id
493-
_ = self.globalPacketId.exchange(with: connack.packetId + 32767)
496+
self.globalPacketId.store(connack.packetId + 32767, ordering: .relaxed)
494497
switch self.configuration.version {
495498
case .v3_1_1:
496499
if connack.returnCode != 0 {

Sources/MQTTNIO/MQTTCoreTypes.swift

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,7 @@
1111
//
1212
//===----------------------------------------------------------------------===//
1313

14-
#if compiler(>=5.6)
15-
@preconcurrency import NIOCore
16-
#else
1714
import NIOCore
18-
#endif
1915

2016
/// Indicates the level of assurance for delivery of a packet.
2117
public enum MQTTQoS: UInt8, _MQTTSendable {

Sources/MQTTNIO/MQTTPacket.swift

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,7 @@
1111
//
1212
//===----------------------------------------------------------------------===//
1313

14-
#if compiler(>=5.6)
15-
@preconcurrency import NIOCore
16-
#else
1714
import NIOCore
18-
#endif
1915

2016
internal enum InternalError: Swift.Error {
2117
case incompletePacket

Sources/MQTTNIO/MQTTProperties.swift

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,7 @@
1111
//
1212
//===----------------------------------------------------------------------===//
1313

14-
#if compiler(>=5.6)
15-
@preconcurrency import NIOCore
16-
#else
1714
import NIOCore
18-
#endif
1915

2016
/// MQTT v5.0 properties. A property consists of a identifier and a value
2117
public struct MQTTProperties: _MQTTSendable {

Tests/MQTTNIOTests/MQTTNIOTests+async.swift

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313

1414
#if compiler(>=5.5) && canImport(_Concurrency)
1515

16+
import Atomics
1617
import Logging
1718
import NIO
18-
import NIOConcurrencyHelpers
1919
import NIOFoundationCompat
2020
import NIOHTTP1
2121
import XCTest
@@ -99,8 +99,8 @@ final class AsyncMQTTNIOTests: XCTestCase {
9999
}
100100

101101
func testAsyncSequencePublishListener() async throws {
102-
let expectation = NIOAtomic.makeAtomic(value: 0)
103-
let finishExpectation = NIOAtomic.makeAtomic(value: 0)
102+
let expectation = ManagedAtomic(0)
103+
let finishExpectation = ManagedAtomic(0)
104104

105105
let client = self.createClient(identifier: "testAsyncSequencePublishListener+async", version: .v5_0)
106106
let client2 = self.createClient(identifier: "testAsyncSequencePublishListener+async2", version: .v5_0)
@@ -116,13 +116,13 @@ final class AsyncMQTTNIOTests: XCTestCase {
116116
var buffer = publish.payload
117117
let string = buffer.readString(length: buffer.readableBytes)
118118
print("Received: \(string ?? "nothing")")
119-
expectation.add(1)
119+
expectation.wrappingIncrement(ordering: .relaxed)
120120

121121
case .failure(let error):
122122
XCTFail("\(error)")
123123
}
124124
}
125-
finishExpectation.add(1)
125+
finishExpectation.wrappingIncrement(ordering: .relaxed)
126126
}
127127
try await client.publish(to: "TestSubject", payload: ByteBufferAllocator().buffer(string: "Hello"), qos: .atLeastOnce)
128128
try await client.publish(to: "TestSubject", payload: ByteBufferAllocator().buffer(string: "Goodbye"), qos: .atLeastOnce)
@@ -136,13 +136,13 @@ final class AsyncMQTTNIOTests: XCTestCase {
136136

137137
_ = await task.result
138138

139-
XCTAssertEqual(expectation.load(), 2)
140-
XCTAssertEqual(finishExpectation.load(), 1)
139+
XCTAssertEqual(expectation.load(ordering: .relaxed), 2)
140+
XCTAssertEqual(finishExpectation.load(ordering: .relaxed), 1)
141141
}
142142

143143
func testAsyncSequencePublishSubscriptionIdListener() async throws {
144-
let expectation = NIOAtomic.makeAtomic(value: 0)
145-
let expectation2 = NIOAtomic.makeAtomic(value: 0)
144+
let expectation = ManagedAtomic(0)
145+
let expectation2 = ManagedAtomic(0)
146146

147147
let client = self.createClient(identifier: "testAsyncSequencePublishSubscriptionIdListener+async", version: .v5_0)
148148
let client2 = self.createClient(identifier: "testAsyncSequencePublishSubscriptionIdListener+async2", version: .v5_0)
@@ -155,16 +155,16 @@ final class AsyncMQTTNIOTests: XCTestCase {
155155
let task = Task {
156156
let publishListener = client2.v5.createPublishListener(subscriptionId: 1)
157157
for await _ in publishListener {
158-
expectation.add(1)
158+
expectation.wrappingIncrement(ordering: .relaxed)
159159
}
160-
expectation.add(1)
160+
expectation.wrappingIncrement(ordering: .relaxed)
161161
}
162162
let task2 = Task {
163163
let publishListener = client2.v5.createPublishListener(subscriptionId: 2)
164164
for await _ in publishListener {
165-
expectation2.add(1)
165+
expectation2.wrappingIncrement(ordering: .relaxed)
166166
}
167-
expectation2.add(1)
167+
expectation2.wrappingIncrement(ordering: .relaxed)
168168
}
169169
try await client.publish(to: "TestSubject", payload: ByteBufferAllocator().buffer(string: payloadString), qos: .atLeastOnce)
170170
try await client.publish(to: "TestSubject", payload: ByteBufferAllocator().buffer(string: payloadString), qos: .atLeastOnce)
@@ -181,8 +181,8 @@ final class AsyncMQTTNIOTests: XCTestCase {
181181
_ = await task.result
182182
_ = await task2.result
183183

184-
XCTAssertEqual(expectation.load(), 3)
185-
XCTAssertEqual(expectation2.load(), 2)
184+
XCTAssertEqual(expectation.load(ordering: .relaxed), 3)
185+
XCTAssertEqual(expectation2.load(ordering: .relaxed), 2)
186186
}
187187
}
188188

Tests/MQTTNIOTests/MQTTNIOTests.swift

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import Foundation
1515
import Logging
1616
import NIO
17-
import NIOConcurrencyHelpers
1817
import NIOFoundationCompat
1918
import NIOHTTP1
2019
import XCTest
@@ -546,6 +545,37 @@ final class MQTTNIOTests: XCTestCase {
546545
#endif
547546
}
548547

548+
func testDoubleShutdown() throws {
549+
let client = self.createClient(identifier: "DoubleShutdown")
550+
try client.syncShutdownGracefully()
551+
do {
552+
try client.syncShutdownGracefully()
553+
XCTFail("testDoubleShutdown: Should fail after second shutdown")
554+
} catch MQTTError.alreadyShutdown {}
555+
}
556+
557+
func testPacketId() throws {
558+
var logger = Logger(label: "MQTTTests")
559+
logger.logLevel = .info
560+
let client = MQTTClient(
561+
host: Self.hostname,
562+
port: 1883,
563+
identifier: "testPacketId",
564+
eventLoopGroupProvider: .createNew,
565+
logger: logger
566+
)
567+
defer { XCTAssertNoThrow(try client.syncShutdownGracefully()) }
568+
569+
_ = try client.connect().wait()
570+
let packetId = client.globalPacketId.load(ordering: .relaxed)
571+
try client.publish(to: "testPersistentAtLeastOnce", payload: ByteBufferAllocator().buffer(capacity: 0), qos: .atLeastOnce).wait()
572+
XCTAssertEqual(packetId + 1, client.globalPacketId.load(ordering: .relaxed))
573+
try client.publish(to: "testPersistentAtLeastOnce", payload: ByteBufferAllocator().buffer(capacity: 0), qos: .atLeastOnce).wait()
574+
XCTAssertEqual(packetId + 2, client.globalPacketId.load(ordering: .relaxed))
575+
576+
try client.disconnect().wait()
577+
}
578+
549579
// MARK: Helper variables and functions
550580

551581
func createClient(identifier: String, configuration: MQTTClient.Configuration = .init()) -> MQTTClient {

Tests/MQTTNIOTests/MQTTNIOv5Tests.swift

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import Foundation
1515
import Logging
1616
import NIO
17-
import NIOConcurrencyHelpers
1817
import NIOFoundationCompat
1918
import NIOHTTP1
2019
import XCTest

0 commit comments

Comments
 (0)