Skip to content

Commit a9b3758

Browse files
authored
fix: Don't change HTTP request components after they have been signed (#557)
1 parent 0feeb94 commit a9b3758

File tree

16 files changed

+176
-161
lines changed

16 files changed

+176
-161
lines changed

Sources/ClientRuntime/Networking/Endpoint.swift

Lines changed: 14 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import Foundation
77

8-
public struct Endpoint {
8+
public struct Endpoint: Hashable {
99
public let path: String
1010
public let queryItems: [URLQueryItem]?
1111
public let protocolType: ProtocolType?
@@ -58,55 +58,27 @@ public struct Endpoint {
5858
}
5959
}
6060

61-
public extension Endpoint {
61+
extension Endpoint {
6262
// We still have to keep 'url' as an optional, since we're
6363
// dealing with dynamic components that could be invalid.
64-
var url: URL? {
64+
public var url: URL? {
6565
var components = URLComponents()
6666
components.scheme = protocolType?.rawValue
6767
components.host = host
68-
components.path = path
69-
components.percentEncodedQueryItems = queryItems
68+
components.percentEncodedPath = path
69+
components.percentEncodedQuery = queryItemString
7070

7171
return components.url
7272
}
7373

74-
var queryItemString: String {
75-
guard let queryItems = queryItems, !queryItems.isEmpty else {
76-
return ""
77-
}
78-
let queryString = queryItems.map { "\($0.name)=\($0.value ?? "")" }.joined(separator: "&")
79-
return "?\(queryString)"
80-
}
81-
}
82-
83-
// It was discovered that in Swift 5.8 and earlier versions, the URLQueryItem type does not correctly implement
84-
// Hashable: namely, multiple URLQueryItems with the same name & value and that are equal by the == operator will have
85-
// different hash values.
86-
//
87-
// Github issue filed against open-source Foundation:
88-
// https://github.com/apple/swift-corelibs-foundation/issues/4737
89-
//
90-
// This extension is intended to correct this problem for the Endpoint type by substituting a
91-
// different structure with the same properties as URLQueryItem when the Endpoint is hashed.
92-
//
93-
// This extension may be removed, and the compiler-generated Hashable compliance may be used instead, once the
94-
// URLQueryItem's Hashable implementation is fixed in open-source Foundation.
95-
extension Endpoint: Hashable {
96-
97-
private struct QueryItem: Hashable {
98-
let name: String
99-
let value: String?
100-
}
101-
102-
public func hash(into hasher: inout Hasher) {
103-
hasher.combine(path)
104-
let queryItemElements = queryItems?.map { QueryItem(name: $0.name, value: $0.value) }
105-
hasher.combine(queryItemElements)
106-
hasher.combine(protocolType)
107-
hasher.combine(host)
108-
hasher.combine(port)
109-
hasher.combine(headers)
110-
hasher.combine(properties)
74+
var queryItemString: String? {
75+
guard let queryItems = queryItems else { return nil }
76+
return queryItems.map { queryItem in
77+
if let value = queryItem.value {
78+
return "\(queryItem.name)=\(value)"
79+
} else {
80+
return queryItem.name
81+
}
82+
}.joined(separator: "&")
11183
}
11284
}

Sources/ClientRuntime/Networking/Http/CRT/CRTClientEngine.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ public class CRTClientEngine: HttpClientEngine {
4545

4646
return connectionPool
4747
}
48-
48+
4949
private func createConnectionPool(endpoint: Endpoint) throws -> HTTPClientConnectionManager {
5050
let tlsConnectionOptions = TLSConnectionOptions(
5151
context: sharedDefaultIO.tlsContext,
@@ -69,7 +69,7 @@ public class CRTClientEngine: HttpClientEngine {
6969
enableManualWindowManagement: false
7070
) // not using backpressure yet
7171
logger.debug("""
72-
Creating connection pool for \(String(describing: endpoint.url?.absoluteString)) \
72+
Creating connection pool for \(String(describing: endpoint.host)) \
7373
with max connections: \(maxConnectionsPerEndpoint)
7474
""")
7575
return try HTTPClientConnectionManager(options: options)
@@ -96,7 +96,7 @@ public class CRTClientEngine: HttpClientEngine {
9696
enableStreamManualWindowManagement: false
9797
)
9898
logger.debug("""
99-
Creating connection pool for \(String(describing: endpoint.url?.absoluteString)) \
99+
Creating connection pool for \(String(describing: endpoint.host)) \
100100
with max connections: \(maxConnectionsPerEndpoint)
101101
""")
102102

@@ -274,7 +274,7 @@ public class CRTClientEngine: HttpClientEngine {
274274
}
275275

276276
requestOptions.http2ManualDataWrites = http2ManualDataWrites
277-
277+
278278
response.body = .stream(stream)
279279
return requestOptions
280280
}

Sources/ClientRuntime/Networking/Http/SdkHttpRequest.swift

Lines changed: 26 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -10,56 +10,44 @@ import AwsCommonRuntimeKit
1010
// we need to maintain a reference to this same request while we add headers
1111
// in the CRT engine so that is why it's a class
1212
public class SdkHttpRequest {
13-
public var body: HttpBody
14-
public var headers: Headers
15-
public let queryItems: [URLQueryItem]?
13+
public let body: HttpBody
1614
public let endpoint: Endpoint
1715
public let method: HttpMethodType
16+
public var headers: Headers { endpoint.headers ?? Headers() }
17+
public var path: String { endpoint.path }
18+
public var host: String { endpoint.host }
19+
public var queryItems: [URLQueryItem]? { endpoint.queryItems }
1820

1921
public init(method: HttpMethodType,
2022
endpoint: Endpoint,
21-
headers: Headers,
22-
queryItems: [URLQueryItem]? = nil,
2323
body: HttpBody = HttpBody.none) {
2424
self.method = method
2525
self.endpoint = endpoint
26-
self.headers = headers
2726
self.body = body
28-
self.queryItems = queryItems
2927
}
3028
}
3129

32-
// Create a `CharacterSet` of the characters that need not be percent encoded in the
33-
// resulting URL. This set consists of alphanumerics plus underscore, dash, tilde, and
34-
// period. Any other character should be percent-encoded when used in a path segment.
35-
// Forward-slash is added as well because the segments have already been joined into a path.
36-
//
37-
// See, for URL-allowed characters:
38-
// https://www.rfc-editor.org/rfc/rfc3986#section-2.3
39-
private let allowed = CharacterSet.alphanumerics.union(CharacterSet(charactersIn: "/_-.~"))
40-
4130
extension SdkHttpRequest {
42-
public func toHttpRequest() throws -> HTTPRequest {
43-
let httpHeaders = headers.toHttpHeaders()
31+
32+
public func toHttpRequest(escaping: Bool = false) throws -> HTTPRequest {
4433
let httpRequest = try HTTPRequest()
4534
httpRequest.method = method.rawValue
46-
let encodedPath = endpoint.path.addingPercentEncoding(withAllowedCharacters: allowed) ?? endpoint.path
47-
httpRequest.path = "\(encodedPath)\(endpoint.queryItemString)"
48-
httpRequest.addHeaders(headers: httpHeaders)
35+
let encodedPath = escaping ? endpoint.path.urlPercentEncodedForPath : endpoint.path
36+
httpRequest.path = [encodedPath, endpoint.queryItemString].compactMap { $0 }.joined(separator: "?")
37+
httpRequest.addHeaders(headers: headers.toHttpHeaders())
4938
httpRequest.body = StreamableHttpBody(body: body)
5039
return httpRequest
5140
}
5241

5342
/// Convert the SDK request to a CRT HTTPRequestBase
5443
/// CRT converts the HTTPRequestBase to HTTP2Request internally if the protocol is HTTP/2
5544
/// - Returns: the CRT request
56-
public func toHttp2Request() throws -> HTTPRequestBase {
57-
let httpHeaders = headers.toHttpHeaders()
45+
public func toHttp2Request(escaping: Bool = false) throws -> HTTPRequestBase {
5846
let httpRequest = try HTTPRequest()
5947
httpRequest.method = method.rawValue
60-
let encodedPath = endpoint.path.addingPercentEncoding(withAllowedCharacters: allowed) ?? endpoint.path
61-
httpRequest.path = "\(encodedPath)\(endpoint.queryItemString)"
62-
httpRequest.addHeaders(headers: httpHeaders)
48+
let encodedPath = escaping ? endpoint.path.urlPercentEncodedForPath : endpoint.path
49+
httpRequest.path = [encodedPath, endpoint.queryItemString].compactMap { $0 }.joined(separator: "?")
50+
httpRequest.addHeaders(headers: headers.toHttpHeaders())
6351

6452
// HTTP2Request used with manual writes hence we need to set the body to nil
6553
// so that CRT does not write the body for us (we will write it manually)
@@ -96,11 +84,11 @@ extension SdkHttpRequestBuilder {
9684
public func update(from crtRequest: HTTPRequestBase, originalRequest: SdkHttpRequest) -> SdkHttpRequestBuilder {
9785
headers = convertSignedHeadersToHeaders(crtRequest: crtRequest)
9886
methodType = originalRequest.method
99-
host = originalRequest.endpoint.host
100-
if let crtRequest = crtRequest as? HTTPRequest {
101-
let pathAndQueryItems = URLComponents(string: crtRequest.path)
102-
path = pathAndQueryItems?.path ?? "/"
103-
queryItems = pathAndQueryItems?.percentEncodedQueryItems ?? [URLQueryItem]()
87+
host = originalRequest.host
88+
if let crtRequest = crtRequest as? HTTPRequest, let components = URLComponents(string: crtRequest.path) {
89+
path = components.percentEncodedPath
90+
queryItems = components.percentEncodedQueryItems?.map { URLQueryItem(name: $0.name, value: $0.value) }
91+
?? [URLQueryItem]()
10492
} else if crtRequest as? HTTP2Request != nil {
10593
assertionFailure("HTTP2Request not supported")
10694
} else {
@@ -123,11 +111,11 @@ public class SdkHttpRequestBuilder {
123111
var host: String = ""
124112
var path: String = "/"
125113
var body: HttpBody = .none
126-
var queryItems = [URLQueryItem]()
114+
var queryItems: [URLQueryItem]? = nil
127115
var port: Int16 = 443
128116
var protocolType: ProtocolType = .https
129117

130-
public var currentQueryItems: [URLQueryItem] {
118+
public var currentQueryItems: [URLQueryItem]? {
131119
return queryItems
132120
}
133121

@@ -179,14 +167,14 @@ public class SdkHttpRequestBuilder {
179167

180168
@discardableResult
181169
public func withQueryItems(_ value: [URLQueryItem]) -> SdkHttpRequestBuilder {
182-
self.queryItems = value
170+
self.queryItems = self.queryItems ?? []
171+
self.queryItems?.append(contentsOf: value)
183172
return self
184173
}
185174

186175
@discardableResult
187176
public func withQueryItem(_ value: URLQueryItem) -> SdkHttpRequestBuilder {
188-
self.queryItems.append(value)
189-
return self
177+
withQueryItems([value])
190178
}
191179

192180
@discardableResult
@@ -206,11 +194,10 @@ public class SdkHttpRequestBuilder {
206194
path: path,
207195
port: port,
208196
queryItems: queryItems,
209-
protocolType: protocolType)
197+
protocolType: protocolType,
198+
headers: headers)
210199
return SdkHttpRequest(method: methodType,
211200
endpoint: endpoint,
212-
headers: headers,
213-
queryItems: queryItems,
214201
body: body)
215202
}
216203
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
//
2+
// Copyright Amazon.com Inc. or its affiliates.
3+
// All Rights Reserved.
4+
//
5+
// SPDX-License-Identifier: Apache-2.0
6+
//
7+
8+
import Foundation
9+
10+
// Creates a `CharacterSet` of the characters that need not be percent encoded in the
11+
// resulting URL. This set consists of alphanumerics plus underscore, dash, tilde, and
12+
// period. Any other character should be percent-encoded when used in a path segment.
13+
// Forward-slash is added as well because the segments have already been joined into a path.
14+
//
15+
// See, for URL-allowed characters:
16+
// https://www.rfc-editor.org/rfc/rfc3986#section-2.3
17+
private let allowedForPath = CharacterSet.alphanumerics.union(CharacterSet(charactersIn: "/_-.~"))
18+
private let allowedForQuery = CharacterSet.alphanumerics.union(CharacterSet(charactersIn: "_-.~"))
19+
20+
extension String {
21+
22+
/// Encodes a URL component for inclusion in the path or query items, using percent-escaping.
23+
///
24+
/// All characters except alphanumerics plus forward slash, underscore, dash, tilde, and period will be escaped.
25+
var urlPercentEncodedForPath: String {
26+
addingPercentEncoding(withAllowedCharacters: allowedForPath) ?? self
27+
}
28+
29+
/// Encodes a URL component for inclusion in query item name or value, using percent-escaping.
30+
///
31+
/// All characters except alphanumerics plus forward slash, underscore, dash, tilde, and period will be escaped.
32+
var urlPercentEncodedForQuery: String {
33+
addingPercentEncoding(withAllowedCharacters: allowedForQuery) ?? self
34+
}
35+
}

Sources/ClientRuntime/PrimitiveTypeExtensions/String+Extensions.swift

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,7 @@ extension String {
101101

102102
extension String {
103103
public func urlPercentEncoding() -> String {
104-
if let encodedString = self.addingPercentEncoding(withAllowedCharacters: .singleUrlQueryAllowed) {
105-
return encodedString
106-
}
107-
return self
104+
self.urlPercentEncodedForQuery
108105
}
109106
}
110107

Sources/ClientRuntime/PrimitiveTypeExtensions/URL+Extension.swift

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,14 @@
44
*/
55

66
import Foundation
7+
78
public typealias URL = Foundation.URL
9+
810
extension URL {
9-
func toQueryItems() -> [URLQueryItem]? { return URLComponents(url: self,
10-
resolvingAgainstBaseURL: false)?.queryItems }
11+
12+
func toQueryItems() -> [URLQueryItem]? {
13+
URLComponents(url: self, resolvingAgainstBaseURL: false)?
14+
.queryItems?
15+
.map { URLQueryItem(name: $0.name, value: $0.value) }
16+
}
1117
}

Sources/ClientRuntime/PrimitiveTypeExtensions/URLQueryItem+Extensions.swift

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,14 @@
33
* SPDX-License-Identifier: Apache-2.0.
44
*/
55

6-
import struct Foundation.URLQueryItem
7-
public typealias URLQueryItem = Foundation.URLQueryItem
6+
public typealias URLQueryItem = MyURLQueryItem
87

9-
extension URLQueryItem: Comparable {
10-
/// Compares two `URLQueryItem` instances by their `name` property.
11-
/// - Parameters:
12-
/// - lhs: The first `URLQueryItem` to compare.
13-
/// - rhs: The second `URLQueryItem` to compare.
14-
/// - Returns: `true` if the `name` property of `lhs` is less than the `name` property of `rhs`.
15-
public static func < (lhs: URLQueryItem, rhs: URLQueryItem) -> Bool {
16-
lhs.name < rhs.name
8+
public struct MyURLQueryItem: Hashable {
9+
public var name: String
10+
public var value: String?
11+
12+
public init(name: String, value: String?) {
13+
self.name = name
14+
self.value = value
1715
}
1816
}

Sources/SmithyTestUtil/RequestTestUtil/HttpRequestTestBase+FormURL.swift

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
//
55
// SPDX-License-Identifier: Apache-2.0
66
//
7+
78
import XCTest
89
import ClientRuntime
910

@@ -19,12 +20,12 @@ extension HttpRequestTestBase {
1920
assertQueryItems(expectedQueryItems, actualQueryItems, file: file, line: line)
2021
}
2122

22-
private func convertToQueryItems(data: Data) -> [URLQueryItem] {
23+
private func convertToQueryItems(data: Data) -> [ClientRuntime.URLQueryItem] {
2324
guard let queryString = String(data: data, encoding: .utf8) else {
2425
XCTFail("Failed to decode data")
2526
return []
2627
}
27-
var queryItems: [URLQueryItem] = []
28+
var queryItems: [ClientRuntime.URLQueryItem] = []
2829
let sanitizedQueryString = queryString.replacingOccurrences(of: "\n", with: "")
2930
let keyValuePairs = sanitizedQueryString.components(separatedBy: "&")
3031
for keyValue in keyValuePairs {

0 commit comments

Comments
 (0)