Skip to content

Commit 44e4a16

Browse files
authored
fixes data race in CRTClientEngine (#424)
* fixes data race in CRTClientEngine * creates nested actor named SerialExecutor in CRTClientEngine * public functions were directed to the actor to manage connection pools
1 parent 902cbcd commit 44e4a16

File tree

3 files changed

+78
-55
lines changed

3 files changed

+78
-55
lines changed

Packages/ClientRuntime/Sources/Networking/Http/CRT/CRTClientEngine.swift

Lines changed: 74 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,67 @@ import Darwin
1111
#endif
1212

1313
public class CRTClientEngine: HttpClientEngine {
14+
actor SerialExecutor {
15+
private var logger: LogAgent
16+
17+
private let windowSize: Int
18+
private let maxConnectionsPerEndpoint: Int
19+
private var connectionPools: [Endpoint: HttpClientConnectionManager] = [:]
20+
21+
init(config: CRTClientEngineConfig) {
22+
self.windowSize = config.windowSize
23+
self.maxConnectionsPerEndpoint = config.maxConnectionsPerEndpoint
24+
self.logger = SwiftLogger(label: "SerialExecutor")
25+
}
26+
27+
func getOrCreateConnectionPool(endpoint: Endpoint) -> HttpClientConnectionManager {
28+
guard let connectionPool = connectionPools[endpoint] else {
29+
let newConnectionPool = createConnectionPool(endpoint: endpoint)
30+
connectionPools[endpoint] = newConnectionPool // save in dictionary
31+
return newConnectionPool
32+
}
33+
34+
return connectionPool
35+
}
36+
37+
func closeAllPendingConnections() {
38+
for (endpoint, value) in connectionPools {
39+
logger.debug("Connection to endpoint: \(String(describing: endpoint.url?.absoluteString)) is closing")
40+
value.closePendingConnections()
41+
}
42+
}
43+
44+
private func createConnectionPool(endpoint: Endpoint) -> HttpClientConnectionManager {
45+
let tlsConnectionOptions = SDKDefaultIO.shared.tlsContext.newConnectionOptions()
46+
do {
47+
try tlsConnectionOptions.setServerName(endpoint.host)
48+
} catch let err {
49+
logger.error("Server name was not able to be set in TLS Connection Options. TLS Negotiation will fail.")
50+
logger.error("Error: \(err.localizedDescription)")
51+
}
52+
let socketOptions = SocketOptions(socketType: .stream)
53+
#if os(iOS) || os(watchOS)
54+
socketOptions.connectTimeoutMs = 30_000
55+
#endif
56+
let options = HttpClientConnectionOptions(clientBootstrap: SDKDefaultIO.shared.clientBootstrap,
57+
hostName: endpoint.host,
58+
initialWindowSize: windowSize,
59+
port: UInt16(endpoint.port),
60+
proxyOptions: nil,
61+
socketOptions: socketOptions,
62+
tlsOptions: tlsConnectionOptions,
63+
monitoringOptions: nil,
64+
maxConnections: maxConnectionsPerEndpoint,
65+
enableManualWindowManagement: false) // not using backpressure yet
66+
logger.debug("Creating connection pool for \(String(describing: endpoint.url?.absoluteString))" +
67+
"with max connections: \(maxConnectionsPerEndpoint)")
68+
return HttpClientConnectionManager(options: options)
69+
}
70+
}
71+
1472
public typealias StreamContinuation = CheckedContinuation<HttpResponse, Error>
1573
private var logger: LogAgent
16-
private var connectionPools: [Endpoint: HttpClientConnectionManager] = [:]
74+
private let serialExecutor: SerialExecutor
1775
private let CONTENT_LENGTH_HEADER = "Content-Length"
1876
private let AWS_COMMON_RUNTIME = "AwsCommonRuntime"
1977
private let DEFAULT_STREAM_WINDOW_SIZE = 16 * 1024 * 1024 // 16 MB
@@ -26,70 +84,33 @@ public class CRTClientEngine: HttpClientEngine {
2684
self.maxConnectionsPerEndpoint = config.maxConnectionsPerEndpoint
2785
self.windowSize = config.windowSize
2886
self.logger = SwiftLogger(label: "CRTClientEngine")
29-
}
30-
31-
private func createConnectionPool(endpoint: Endpoint) -> HttpClientConnectionManager {
32-
let tlsConnectionOptions = SDKDefaultIO.shared.tlsContext.newConnectionOptions()
33-
do {
34-
try tlsConnectionOptions.setServerName(endpoint.host)
35-
} catch let err {
36-
logger.error("Server name was not able to be set in TLS Connection Options. TLS Negotiation will fail.")
37-
logger.error("Error: \(err.localizedDescription)")
38-
}
39-
let socketOptions = SocketOptions(socketType: .stream)
40-
#if os(iOS) || os(watchOS)
41-
socketOptions.connectTimeoutMs = 30_000
42-
#endif
43-
let options = HttpClientConnectionOptions(clientBootstrap: SDKDefaultIO.shared.clientBootstrap,
44-
hostName: endpoint.host,
45-
initialWindowSize: windowSize,
46-
port: UInt16(endpoint.port),
47-
proxyOptions: nil,
48-
socketOptions: socketOptions,
49-
tlsOptions: tlsConnectionOptions,
50-
monitoringOptions: nil,
51-
maxConnections: maxConnectionsPerEndpoint,
52-
enableManualWindowManagement: false) // not using backpressure yet
53-
logger.debug("Creating connection pool for \(String(describing: endpoint.url?.absoluteString))" +
54-
"with max connections: \(maxConnectionsPerEndpoint)")
55-
return HttpClientConnectionManager(options: options)
56-
}
57-
58-
private func getOrCreateConnectionPool(endpoint: Endpoint) -> HttpClientConnectionManager {
59-
60-
guard let connectionPool = connectionPools[endpoint] else {
61-
let newConnectionPool = createConnectionPool(endpoint: endpoint)
62-
connectionPools[endpoint] = newConnectionPool // save in dictionary
63-
return newConnectionPool
64-
}
65-
66-
return connectionPool
87+
self.serialExecutor = SerialExecutor(config: config)
6788
}
6889

6990
public func execute(request: SdkHttpRequest) async throws -> HttpResponse {
70-
let connectionMgr = getOrCreateConnectionPool(endpoint: request.endpoint)
91+
let connectionMgr = await serialExecutor.getOrCreateConnectionPool(endpoint: request.endpoint)
7192
let connection = try await connectionMgr.acquireConnection()
7293
self.logger.debug("Connection was acquired to: \(String(describing: request.endpoint.url?.absoluteString))")
7394
return try await withCheckedThrowingContinuation({ (continuation: StreamContinuation) in
74-
let requestOptions = makeHttpRequestStreamOptions(request, continuation)
75-
let stream = connection.makeRequest(requestOptions: requestOptions)
76-
stream.activate()
95+
do {
96+
let requestOptions = makeHttpRequestStreamOptions(request, continuation)
97+
let stream = try connection.makeRequest(requestOptions: requestOptions)
98+
try stream.activate()
99+
} catch {
100+
continuation.resume(throwing: error)
101+
}
77102
})
78-
79103
}
80104

81-
public func close() {
82-
for (endpoint, value) in connectionPools {
83-
logger.debug("Connection to endpoint: \(String(describing: endpoint.url?.absoluteString)) is closing")
84-
value.closePendingConnections()
85-
}
105+
public func close() async {
106+
await serialExecutor.closeAllPendingConnections()
86107
}
87108

88109
public func makeHttpRequestStreamOptions(_ request: SdkHttpRequest, _ continuation: StreamContinuation) -> HttpRequestOptions {
89110
let response = HttpResponse()
90111
let crtRequest = request.toHttpRequest(bufferSize: windowSize)
91112
let streamReader: StreamReader = DataStreamReader()
92-
113+
93114
let requestOptions = HttpRequestOptions(request: crtRequest) { [self] (stream, _, httpHeaders) in
94115
logger.debug("headers were received")
95116
response.statusCode = HttpStatusCode(rawValue: Int(stream.statusCode)) ?? HttpStatusCode.notFound
@@ -113,11 +134,11 @@ public class CRTClientEngine: HttpClientEngine {
113134
return
114135
}
115136
}
116-
137+
117138
response.body = .stream(.reader(streamReader))
118-
139+
119140
response.statusCode = HttpStatusCode(rawValue: Int(stream.statusCode)) ?? HttpStatusCode.notFound
120-
141+
121142
continuation.resume(returning: response)
122143
}
123144
return requestOptions

Packages/ClientRuntime/Sources/Networking/Http/HttpClientEngine.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,5 @@ import AwsCommonRuntimeKit
66

77
public protocol HttpClientEngine {
88
func execute(request: SdkHttpRequest) async throws -> HttpResponse
9-
func close()
9+
func close() async
1010
}

Packages/ClientRuntime/Sources/Networking/Http/SdkHttpClient.swift

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ public class SdkHttpClient {
2424
}
2525

2626
public func close() {
27-
engine.close()
27+
Task {
28+
await self.engine.close()
29+
}
2830
}
2931

3032
}

0 commit comments

Comments
 (0)