Skip to content

Commit d224c1d

Browse files
authored
feat: Move RPCv2CBOR generation to smithy-swift (#900)
1 parent 524c8ef commit d224c1d

16 files changed

+992
-11
lines changed
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
//
2+
// Copyright Amazon.com Inc. or its affiliates.
3+
// All Rights Reserved.
4+
//
5+
// SPDX-License-Identifier: Apache-2.0
6+
//
7+
8+
import SmithyIdentity
9+
import SmithyIdentityAPI
10+
import protocol SmithyHTTPAPI.HTTPClient
11+
import struct SmithyRetries.DefaultRetryStrategy
12+
import struct SmithyRetries.ExponentialBackoffStrategy
13+
import struct SmithyRetriesAPI.RetryStrategyOptions
14+
15+
public typealias RuntimeConfigType
16+
= DefaultSDKRuntimeConfiguration<DefaultRetryStrategy, DefaultRetryErrorInfoProvider>
17+
18+
open class ClientConfigDefaultsProvider {
19+
/// Returns a default `HTTPClient` engine.
20+
open class func httpClientEngine() -> HTTPClient {
21+
return RuntimeConfigType.makeClient(
22+
httpClientConfiguration: RuntimeConfigType.defaultHttpClientConfiguration
23+
)
24+
}
25+
26+
/// Returns default `HttpClientConfiguration`.
27+
open class func httpClientConfiguration() -> HttpClientConfiguration {
28+
return RuntimeConfigType.defaultHttpClientConfiguration
29+
}
30+
31+
/// Returns a default idempotency token generator.
32+
open class func idempotencyTokenGenerator() -> IdempotencyTokenGenerator {
33+
return RuntimeConfigType.defaultIdempotencyTokenGenerator
34+
}
35+
36+
/// Returns a default client logging mode.
37+
open class func clientLogMode() -> ClientLogMode {
38+
return RuntimeConfigType.defaultClientLogMode
39+
}
40+
41+
/// Returns default retry strategy options *without* referencing AWS-specific config.
42+
open class func retryStrategyOptions(maxAttempts: Int? = nil) -> RetryStrategyOptions {
43+
// Provide some simple fallback for non-AWS usage, e.g. a standard exponential backoff.
44+
let attempts = maxAttempts ?? 3
45+
return RetryStrategyOptions(
46+
backoffStrategy: ExponentialBackoffStrategy(),
47+
maxRetriesBase: attempts - 1,
48+
rateLimitingMode: .standard
49+
)
50+
}
51+
}

Sources/ClientRuntime/Endpoints/EndpointResolverMiddleware.swift

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,17 @@ import enum SmithyHTTPAuthAPI.SigningPropertyKeys
1616
public struct EndpointResolverMiddleware<OperationStackOutput, Params: EndpointsRequestContextProviding> {
1717
public let id: Swift.String = "EndpointResolverMiddleware"
1818

19-
let endpointResolverBlock: (Params) throws -> Endpoint
20-
21-
let endpointParams: Params
22-
19+
let paramsBlock: (Context) throws -> Params
20+
let resolverBlock: (Params) throws -> Endpoint
2321
let authSchemeResolver: EndpointsAuthSchemeResolver
2422

2523
public init(
26-
endpointResolverBlock: @escaping (Params) throws -> Endpoint,
27-
endpointParams: Params,
24+
paramsBlock: @escaping (Context) throws -> Params,
25+
resolverBlock: @escaping (Params) throws -> Endpoint,
2826
authSchemeResolver: EndpointsAuthSchemeResolver = DefaultEndpointsAuthSchemeResolver()
2927
) {
30-
self.endpointResolverBlock = endpointResolverBlock
31-
self.endpointParams = endpointParams
28+
self.paramsBlock = paramsBlock
29+
self.resolverBlock = resolverBlock
3230
self.authSchemeResolver = authSchemeResolver
3331
}
3432
}
@@ -42,7 +40,7 @@ extension EndpointResolverMiddleware: ApplyEndpoint {
4240
) async throws -> HTTPRequest {
4341
let builder = request.toBuilder()
4442

45-
let endpoint = try endpointResolverBlock(endpointParams)
43+
let endpoint = try resolverBlock(paramsBlock(attributes))
4644

4745
var signingName: String?
4846
var signingAlgorithm: String?
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
//
2+
// Copyright Amazon.com Inc. or its affiliates.
3+
// All Rights Reserved.
4+
//
5+
// SPDX-License-Identifier: Apache-2.0
6+
//
7+
8+
import class SmithyHTTPAPI.HTTPResponse
9+
@_spi(SmithyReadWrite) import class SmithyCBOR.Reader
10+
11+
public struct RpcV2CborError: BaseError {
12+
public let code: String
13+
public let message: String?
14+
public let requestID: String?
15+
@_spi(SmithyReadWrite) public var errorBodyReader: Reader { responseReader }
16+
17+
public let httpResponse: HTTPResponse
18+
private let responseReader: Reader
19+
20+
@_spi(SmithyReadWrite)
21+
public init(httpResponse: HTTPResponse, responseReader: Reader, noErrorWrapping: Bool, code: String? = nil) throws {
22+
switch responseReader.cborValue {
23+
case .map(let errorDetails):
24+
if case let .text(errorCode) = errorDetails["__type"] {
25+
self.code = sanitizeErrorType(errorCode)
26+
} else {
27+
self.code = "UnknownError"
28+
}
29+
30+
if case let .text(errorMessage) = errorDetails["Message"] {
31+
self.message = errorMessage
32+
} else {
33+
self.message = nil
34+
}
35+
default:
36+
self.code = "UnknownError"
37+
self.message = nil
38+
}
39+
40+
self.httpResponse = httpResponse
41+
self.responseReader = responseReader
42+
self.requestID = nil
43+
}
44+
}
45+
46+
/// Filter additional information from error name and sanitize it
47+
/// Reference: https://awslabs.github.io/smithy/1.0/spec/aws/aws-restjson1-protocol.html#operation-error-serialization
48+
func sanitizeErrorType(_ type: String) -> String {
49+
return type.substringAfter("#").substringBefore(":").trim()
50+
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
//
2+
// Copyright Amazon.com Inc. or its affiliates.
3+
// All Rights Reserved.
4+
//
5+
// SPDX-License-Identifier: Apache-2.0
6+
//
7+
8+
import SmithyHTTPAPI
9+
10+
public struct CborValidateResponseHeaderMiddleware<Input, Output> {
11+
public let id: Swift.String = "CborValidateResponseHeaderMiddleware"
12+
13+
public init() {}
14+
}
15+
16+
public enum ServiceResponseError: Error {
17+
case missingHeader(String)
18+
case badHeaderValue(String)
19+
}
20+
21+
extension CborValidateResponseHeaderMiddleware: Interceptor {
22+
23+
public typealias InputType = Input
24+
public typealias OutputType = Output
25+
public typealias RequestType = HTTPRequest
26+
public typealias ResponseType = HTTPResponse
27+
28+
public func readBeforeDeserialization(
29+
context: some BeforeDeserialization<InputType, RequestType, ResponseType>
30+
) async throws {
31+
let response = context.getResponse()
32+
let smithyProtocolHeader = response.headers.value(for: "smithy-protocol")
33+
34+
guard let smithyProtocolHeader else {
35+
throw ServiceResponseError.missingHeader(
36+
"smithy-protocol header is missing from a response over RpcV2 Cbor!"
37+
)
38+
}
39+
40+
guard smithyProtocolHeader == "rpc-v2-cbor" else {
41+
throw ServiceResponseError.badHeaderValue(
42+
"smithy-protocol header is set to \(smithyProtocolHeader) instead of expected value rpc-v2-cbor"
43+
)
44+
}
45+
}
46+
}
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0.
4+
*/
5+
6+
package software.amazon.smithy.swift.codegen
7+
8+
import software.amazon.smithy.codegen.core.CodegenException
9+
import software.amazon.smithy.model.node.Node
10+
import software.amazon.smithy.rulesengine.language.EndpointRuleSet
11+
import software.amazon.smithy.rulesengine.language.evaluation.value.ArrayValue
12+
import software.amazon.smithy.rulesengine.language.evaluation.value.BooleanValue
13+
import software.amazon.smithy.rulesengine.language.evaluation.value.EmptyValue
14+
import software.amazon.smithy.rulesengine.language.evaluation.value.IntegerValue
15+
import software.amazon.smithy.rulesengine.language.evaluation.value.RecordValue
16+
import software.amazon.smithy.rulesengine.language.evaluation.value.StringValue
17+
import software.amazon.smithy.rulesengine.language.evaluation.value.Value
18+
import software.amazon.smithy.rulesengine.traits.EndpointTestsTrait
19+
import software.amazon.smithy.swift.codegen.endpoints.EndpointTypes
20+
import software.amazon.smithy.swift.codegen.integration.ProtocolGenerator
21+
import software.amazon.smithy.swift.codegen.swiftmodules.ClientRuntimeTypes
22+
import software.amazon.smithy.swift.codegen.swiftmodules.SmithyHTTPAPITypes
23+
import software.amazon.smithy.swift.codegen.swiftmodules.SmithyTestUtilTypes
24+
import software.amazon.smithy.swift.codegen.swiftmodules.XCTestTypes
25+
import software.amazon.smithy.swift.codegen.utils.toLowerCamelCase
26+
27+
/**
28+
* Generates code for EndpointResolver tests.
29+
*/
30+
class EndpointTestGenerator(
31+
private val endpointTest: EndpointTestsTrait,
32+
private val endpointRuleSet: EndpointRuleSet?,
33+
private val ctx: ProtocolGenerator.GenerationContext
34+
) {
35+
fun render(writer: SwiftWriter): Int {
36+
if (endpointTest.testCases.isEmpty()) { return 0 }
37+
38+
writer.addImport(ctx.settings.moduleName, isTestable = true)
39+
writer.addImport(SwiftDependency.XCTest.target)
40+
41+
// used to filter out test params that are not valid
42+
val endpointParamsMembers = endpointRuleSet?.parameters?.toList()?.map { it.name.name.value }?.toSet() ?: emptySet()
43+
44+
var count = 0
45+
writer.openBlock("class EndpointResolverTest: \$N {", "}", XCTestTypes.XCTestCase) {
46+
writer.write("")
47+
writer.openBlock("override class func setUp() {", "}") {
48+
writer.write("\$N.initialize()", SmithyTestUtilTypes.TestInitializer)
49+
}
50+
writer.write("")
51+
52+
endpointTest.testCases.forEach { testCase ->
53+
writer.write("/// \$L", testCase.documentation)
54+
writer.openBlock("func testResolve${++count}() throws {", "}") {
55+
writer.openBlock("let endpointParams = \$N(", ")", EndpointTypes.EndpointParams) {
56+
val applicableParams =
57+
testCase.params.members.filter { endpointParamsMembers.contains(it.key.value) }
58+
.toSortedMap(compareBy { it.value }).map { (key, value) ->
59+
key to value
60+
}
61+
62+
applicableParams.forEachIndexed { idx, pair ->
63+
writer.writeInline("${pair.first.value.toLowerCamelCase()}: ")
64+
val value = Value.fromNode(pair.second)
65+
writer.call {
66+
generateValue(
67+
writer, value, if (idx < applicableParams.count() - 1) "," else "", false
68+
)
69+
}
70+
}
71+
}
72+
writer.write("let resolver = try \$N()", EndpointTypes.DefaultEndpointResolver).write("")
73+
74+
testCase.expect.error.ifPresent { error ->
75+
writer.openBlock(
76+
"XCTAssertThrowsError(try resolver.resolve(params: endpointParams)) { error in", "}"
77+
) {
78+
writer.openBlock("switch error {", "}") {
79+
writer.dedent().write("case \$N.unresolved(let message):", ClientRuntimeTypes.Core.EndpointError)
80+
writer.indent().write("XCTAssertEqual(\$S, message)", error)
81+
writer.dedent().write("default:")
82+
writer.indent().write("XCTFail()")
83+
}
84+
}
85+
}
86+
testCase.expect.endpoint.ifPresent { endpoint ->
87+
writer.write("let actual = try resolver.resolve(params: endpointParams)").write("")
88+
89+
// [String: AnyHashable] can't be constructed from a dictionary literal
90+
// first create a string JSON string literal
91+
// then convert to [String: AnyHashable] using JSONSerialization.jsonObject(with:)
92+
writer.openBlock("let properties: [String: AnyHashable] = ", "") {
93+
generateProperties(writer, endpoint.properties)
94+
}
95+
96+
val reference = if (endpoint.headers.isNotEmpty()) "var" else "let"
97+
writer.write("$reference headers = \$N()", SmithyHTTPAPITypes.Headers)
98+
endpoint.headers.forEach { (name, values) ->
99+
writer.write("headers.add(name: \$S, values: [\$S])", name, values.sorted().joinToString(","))
100+
}
101+
writer.write(
102+
"let expected = try \$N(urlString: \$S, headers: headers, properties: properties)",
103+
SmithyHTTPAPITypes.Endpoint,
104+
endpoint.url
105+
).write("")
106+
writer.write("XCTAssertEqual(expected, actual)")
107+
}
108+
}
109+
writer.write("")
110+
}
111+
}
112+
113+
return count
114+
}
115+
116+
/**
117+
* Recursively traverse map of properties and generate JSON string literal.
118+
*/
119+
private fun generateProperties(writer: SwiftWriter, properties: Map<String, Node>) {
120+
if (properties.isEmpty()) {
121+
writer.write("[:]")
122+
} else {
123+
writer.openBlock("[", "]") {
124+
properties.map { it.key to it.value }.forEachIndexed { idx, (first, second) ->
125+
val value = Value.fromNode(second)
126+
writer.writeInline("\$S: ", first)
127+
writer.call {
128+
generateValue(writer, value, if (idx < properties.values.count() - 1) "," else "", true)
129+
}
130+
}
131+
}
132+
}
133+
}
134+
135+
/**
136+
* Recursively traverse the value and render a JSON string literal.
137+
*/
138+
private fun generateValue(writer: SwiftWriter, value: Value, delimeter: String, castToAnyHashable: Boolean) {
139+
when (value) {
140+
is StringValue -> {
141+
writer.write("\$S$delimeter", value.toString())
142+
}
143+
144+
is IntegerValue -> {
145+
writer.write("\$L$delimeter", value.toString())
146+
}
147+
148+
is BooleanValue -> {
149+
writer.write("\$L$delimeter", value.toString())
150+
}
151+
152+
is EmptyValue -> {
153+
writer.write("nil$delimeter")
154+
}
155+
156+
is ArrayValue -> {
157+
val castStmt = if (castToAnyHashable) " as [AnyHashable]$delimeter" else delimeter
158+
writer.openBlock("[", "]$castStmt") {
159+
value.values.forEachIndexed { idx, item ->
160+
writer.call {
161+
generateValue(writer, item, if (idx < value.values.count() - 1) "," else "", castToAnyHashable)
162+
}
163+
}
164+
}
165+
}
166+
167+
is RecordValue -> {
168+
if (value.value.isEmpty()) {
169+
writer.writeInline("[:]")
170+
} else {
171+
writer.openBlock("[", "] as [String: AnyHashable]$delimeter") {
172+
value.value.map { it.key to it.value }.forEachIndexed { idx, (first, second) ->
173+
writer.writeInline("\$S: ", first.name)
174+
writer.call {
175+
generateValue(writer, second, if (idx < value.value.count() - 1) "," else "", castToAnyHashable)
176+
}
177+
}
178+
}
179+
}
180+
}
181+
182+
else -> {
183+
throw CodegenException("Unsupported value type: $value")
184+
}
185+
}
186+
}
187+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package software.amazon.smithy.swift.codegen.integration
2+
3+
import software.amazon.smithy.swift.codegen.SwiftWriter
4+
import software.amazon.smithy.swift.codegen.middleware.OperationMiddleware
5+
6+
class SmithyHttpProtocolClientGeneratorFactory : HttpProtocolClientGeneratorFactory {
7+
override fun createHttpProtocolClientGenerator(
8+
ctx: ProtocolGenerator.GenerationContext,
9+
httpBindingResolver: HttpBindingResolver,
10+
writer: SwiftWriter,
11+
serviceName: String,
12+
defaultContentType: String,
13+
httpProtocolCustomizable: HTTPProtocolCustomizable,
14+
operationMiddleware: OperationMiddleware
15+
): HttpProtocolClientGenerator {
16+
val config = SmithyServiceConfig(writer, ctx)
17+
return HttpProtocolClientGenerator(ctx, writer, config, httpBindingResolver, defaultContentType, httpProtocolCustomizable, operationMiddleware)
18+
}
19+
}

0 commit comments

Comments
 (0)