Skip to content

Commit 03ec067

Browse files
authored
feat: Move retry middleware from SDK (#502)
1 parent 56aa830 commit 03ec067

File tree

12 files changed

+194
-1
lines changed

12 files changed

+194
-1
lines changed

Sources/ClientRuntime/Config/DefaultSDKRuntimeConfiguration.swift

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,17 @@ public struct DefaultSDKRuntimeConfiguration: SDKRuntimeConfiguration {
1515
public let retryer: SDKRetryer
1616
public var clientLogMode: ClientLogMode
1717
public var endpoint: String?
18+
19+
/// The partition ID to be used for this configuration.
20+
///
21+
/// Requests made with the same partition ID will be grouped together for retry throttling purposes.
22+
/// If no partition ID is provided, requests will be partitioned based on the hostname.
23+
public var partitionID: String?
1824

1925
public init(
2026
_ clientName: String,
21-
clientLogMode: ClientLogMode = .request
27+
clientLogMode: ClientLogMode = .request,
28+
partitionID: String? = nil
2229
) throws {
2330
self.encoder = nil
2431
self.decoder = nil
@@ -28,5 +35,6 @@ public struct DefaultSDKRuntimeConfiguration: SDKRuntimeConfiguration {
2835
self.retryer = try SDKRetryer()
2936
self.logger = SwiftLogger(label: clientName)
3037
self.clientLogMode = clientLogMode
38+
self.partitionID = partitionID
3139
}
3240
}

Sources/ClientRuntime/Config/SDKRuntimeConfiguration.swift

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,10 @@ public protocol SDKRuntimeConfiguration {
1818
var clientLogMode: ClientLogMode {get}
1919
var retryer: SDKRetryer {get}
2020
var endpoint: String? {get set}
21+
22+
/// The partition ID to be used for this configuration.
23+
///
24+
/// Requests made with the same partition ID will be grouped together for retry throttling purposes.
25+
/// If no partition ID is provided, requests will be partitioned based on the hostname.
26+
var partitionID: String? { get }
2127
}
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
//
2+
// Copyright Amazon.com Inc. or its affiliates.
3+
// All Rights Reserved.
4+
//
5+
// SPDX-License-Identifier: Apache-2.0
6+
//
7+
8+
public struct RetryerMiddleware<Output: HttpResponseBinding,
9+
OutputError: HttpResponseBinding>: Middleware {
10+
11+
public var id: String = "Retryer"
12+
13+
let retryer: SDKRetryer
14+
15+
public init(retryer: SDKRetryer) {
16+
self.retryer = retryer
17+
}
18+
19+
public func handle<H>(
20+
context: Context,
21+
input: SdkHttpRequestBuilder,
22+
next: H
23+
) async throws -> OperationOutput<Output> where
24+
H: Handler,
25+
Self.MInput == H.Input,
26+
Self.MOutput == H.Output,
27+
Self.Context == H.Context {
28+
29+
// Select a partition ID to be used for throttling retry requests. Requests with the
30+
// same partition ID will be "pooled" together for throttling purposes.
31+
let partitionID: String
32+
if let customPartitionID = context.getPartitionID(), !customPartitionID.isEmpty {
33+
// use custom partition ID provided by context
34+
partitionID = customPartitionID
35+
} else if !input.host.isEmpty {
36+
// fall back to the hostname for partition ID, which is a "commonsense" default
37+
partitionID = input.host
38+
} else {
39+
throw SdkError<OutputError>.client(ClientError.unknownError("Partition ID could not be determined"))
40+
}
41+
42+
do {
43+
let token = try await retryer.acquireToken(partitionId: partitionID)
44+
return try await tryRequest(
45+
token: token,
46+
partitionID: partitionID,
47+
context: context,
48+
input: input,
49+
next: next
50+
)
51+
} catch {
52+
throw SdkError<OutputError>.client(ClientError.retryError(error))
53+
}
54+
}
55+
56+
func tryRequest<H>(
57+
token: RetryToken,
58+
errorType: RetryError? = nil,
59+
partitionID: String,
60+
context: Context,
61+
input: SdkHttpRequestBuilder,
62+
next: H
63+
) async throws -> OperationOutput<Output> where
64+
H: Handler,
65+
Self.MInput == H.Input,
66+
Self.MOutput == H.Output,
67+
Self.Context == H.Context {
68+
69+
do {
70+
let serviceResponse = try await next.handle(context: context, input: input)
71+
retryer.recordSuccess(token: token)
72+
return serviceResponse
73+
} catch let error as SdkError<OutputError> where retryer.isErrorRetryable(error: error) {
74+
let errorType = retryer.getErrorType(error: error)
75+
let newToken = try await retryer.scheduleRetry(token: token, error: errorType)
76+
// TODO: rewind the stream once streaming is properly implemented
77+
return try await tryRequest(
78+
token: newToken,
79+
partitionID: partitionID,
80+
context: context,
81+
input: input,
82+
next: next
83+
)
84+
}
85+
}
86+
87+
public typealias MInput = SdkHttpRequestBuilder
88+
public typealias MOutput = OperationOutput<Output>
89+
public typealias Context = HttpContext
90+
}

Sources/ClientRuntime/Networking/Http/HttpContext.swift

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,14 @@ public struct HttpContext: MiddlewareContext {
4545
public func getLogger() -> LogAgent? {
4646
return attributes.get(key: AttributeKey<LogAgent>(name: "Logger"))
4747
}
48+
49+
/// The partition ID to be used for this context.
50+
///
51+
/// Requests made with the same partition ID will be grouped together for retry throttling purposes.
52+
/// If no partition ID is provided, requests will be partitioned based on the hostname.
53+
public func getPartitionID() -> String? {
54+
return attributes.get(key: AttributeKey<String>(name: "PartitionID"))
55+
}
4856
}
4957

5058
public class HttpContextBuilder {
@@ -63,6 +71,7 @@ public class HttpContextBuilder {
6371
let idempotencyTokenGenerator = AttributeKey<IdempotencyTokenGenerator>(name: "IdempotencyTokenGenerator")
6472
let hostPrefix = AttributeKey<String>(name: "HostPrefix")
6573
let logger = AttributeKey<LogAgent>(name: "Logger")
74+
let partitionID = AttributeKey<String>(name: "PartitionID")
6675

6776
// We follow the convention of returning the builder object
6877
// itself from any configuration methods, and by adding the
@@ -140,6 +149,18 @@ public class HttpContextBuilder {
140149
self.attributes.set(key: logger, value: value)
141150
return self
142151
}
152+
153+
/// Sets the partition ID on the context builder.
154+
///
155+
/// Requests made with the same partition ID will be grouped together for retry throttling purposes.
156+
/// If no partition ID is provided, requests will be partitioned based on the hostname.
157+
/// - Parameter value: The partition ID to be set on this builder, or `nil`.
158+
/// - Returns: `self`, after the partition ID is set as specified.
159+
@discardableResult
160+
public func withPartitionID(value: String?) -> HttpContextBuilder {
161+
self.attributes.set(key: partitionID, value: value)
162+
return self
163+
}
143164

144165
public func build() -> HttpContext {
145166
return HttpContext(attributes: attributes)

smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/ClientRuntimeTypes.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ object ClientRuntimeTypes {
6262
val QueryItemMiddleware = runtimeSymbol("QueryItemMiddleware")
6363
val HeaderMiddleware = runtimeSymbol("HeaderMiddleware")
6464
val SerializableBodyMiddleware = runtimeSymbol("SerializableBodyMiddleware")
65+
val RetryerMiddleware = runtimeSymbol("RetryerMiddleware")
6566
val NoopHandler = runtimeSymbol("NoopHandler")
6667

6768
object Providers {

smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/HttpBindingProtocolGenerator.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ import software.amazon.smithy.swift.codegen.integration.middlewares.OperationInp
4747
import software.amazon.smithy.swift.codegen.integration.middlewares.OperationInputQueryItemMiddleware
4848
import software.amazon.smithy.swift.codegen.integration.middlewares.OperationInputUrlHostMiddleware
4949
import software.amazon.smithy.swift.codegen.integration.middlewares.OperationInputUrlPathMiddleware
50+
import software.amazon.smithy.swift.codegen.integration.middlewares.RetryMiddleware
5051
import software.amazon.smithy.swift.codegen.integration.middlewares.handlers.HttpBodyMiddleware
5152
import software.amazon.smithy.swift.codegen.integration.middlewares.providers.HttpHeaderProvider
5253
import software.amazon.smithy.swift.codegen.integration.middlewares.providers.HttpQueryItemProvider
@@ -404,6 +405,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
404405

405406
operationMiddleware.appendMiddleware(operation, LoggingMiddleware(ctx.model, ctx.symbolProvider))
406407
operationMiddleware.appendMiddleware(operation, DeserializeMiddleware(ctx.model, ctx.symbolProvider))
408+
operationMiddleware.appendMiddleware(operation, RetryMiddleware(ctx.model, ctx.symbolProvider))
407409

408410
addProtocolSpecificMiddleware(ctx, operation)
409411

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0.
4+
*/
5+
6+
package software.amazon.smithy.swift.codegen.integration.middlewares
7+
8+
import software.amazon.smithy.codegen.core.SymbolProvider
9+
import software.amazon.smithy.model.Model
10+
import software.amazon.smithy.model.shapes.OperationShape
11+
import software.amazon.smithy.swift.codegen.ClientRuntimeTypes
12+
import software.amazon.smithy.swift.codegen.SwiftWriter
13+
import software.amazon.smithy.swift.codegen.integration.middlewares.handlers.MiddlewareShapeUtils
14+
import software.amazon.smithy.swift.codegen.middleware.MiddlewarePosition
15+
import software.amazon.smithy.swift.codegen.middleware.MiddlewareRenderable
16+
import software.amazon.smithy.swift.codegen.middleware.MiddlewareStep
17+
18+
class RetryMiddleware(
19+
val model: Model,
20+
val symbolProvider: SymbolProvider
21+
) : MiddlewareRenderable {
22+
23+
override val name = "RetryMiddleware"
24+
25+
override val middlewareStep = MiddlewareStep.FINALIZESTEP
26+
27+
override val position = MiddlewarePosition.AFTER
28+
29+
override fun render(writer: SwiftWriter, op: OperationShape, operationStackName: String) {
30+
val output = MiddlewareShapeUtils.outputSymbol(symbolProvider, model, op)
31+
val outputError = MiddlewareShapeUtils.outputErrorSymbol(op)
32+
writer.write("$operationStackName.${middlewareStep.stringValue()}.intercept(position: ${position.stringValue()}, middleware: \$N<\$N, \$N>(retryer: config.retryer))", ClientRuntimeTypes.Middleware.RetryerMiddleware, output, outputError)
33+
}
34+
}

smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/middleware/MiddlewareExecutionGenerator.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ class MiddlewareExecutionGenerator(
4949
writer.write(" .withOperation(value: \"${op.toLowerCamelCase()}\")")
5050
writer.write(" .withIdempotencyTokenGenerator(value: config.idempotencyTokenGenerator)")
5151
writer.write(" .withLogger(value: config.logger)")
52+
writer.write(" .withPartitionID(value: config.partitionID)")
5253

5354
val serviceShape = ctx.service
5455
httpProtocolCustomizable.renderContextAttributes(ctx, writer, serviceShape, op)

smithy-swift-codegen/src/test/kotlin/ContentMd5MiddlewareTests.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ class ContentMd5MiddlewareTests {
1919
.withOperation(value: "idempotencyTokenWithStructure")
2020
.withIdempotencyTokenGenerator(value: config.idempotencyTokenGenerator)
2121
.withLogger(value: config.logger)
22+
.withPartitionID(value: config.partitionID)
2223
var operation = ClientRuntime.OperationStack<IdempotencyTokenWithStructureInput, IdempotencyTokenWithStructureOutputResponse, IdempotencyTokenWithStructureOutputError>(id: "idempotencyTokenWithStructure")
2324
operation.initializeStep.intercept(position: .after, id: "IdempotencyTokenMiddleware") { (context, input, next) -> ClientRuntime.OperationOutput<IdempotencyTokenWithStructureOutputResponse> in
2425
let idempotencyTokenGenerator = context.getIdempotencyTokenGenerator()
@@ -34,6 +35,7 @@ class ContentMd5MiddlewareTests {
3435
operation.serializeStep.intercept(position: .after, middleware: ContentTypeMiddleware<IdempotencyTokenWithStructureInput, IdempotencyTokenWithStructureOutputResponse>(contentType: "application/xml"))
3536
operation.serializeStep.intercept(position: .after, middleware: ClientRuntime.SerializableBodyMiddleware<IdempotencyTokenWithStructureInput, IdempotencyTokenWithStructureOutputResponse>(xmlName: "IdempotencyToken"))
3637
operation.finalizeStep.intercept(position: .before, middleware: ClientRuntime.ContentLengthMiddleware())
38+
operation.finalizeStep.intercept(position: .after, middleware: ClientRuntime.RetryerMiddleware<IdempotencyTokenWithStructureOutputResponse, IdempotencyTokenWithStructureOutputError>(retryer: config.retryer))
3739
operation.deserializeStep.intercept(position: .before, middleware: ClientRuntime.LoggerMiddleware<IdempotencyTokenWithStructureOutputResponse, IdempotencyTokenWithStructureOutputError>(clientLogMode: config.clientLogMode))
3840
operation.deserializeStep.intercept(position: .after, middleware: ClientRuntime.DeserializeMiddleware<IdempotencyTokenWithStructureOutputResponse, IdempotencyTokenWithStructureOutputError>())
3941
let result = try await operation.handleMiddleware(context: context.build(), input: input, next: client.getHandler())

smithy-swift-codegen/src/test/kotlin/HttpProtocolClientGeneratorTests.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ class HttpProtocolClientGeneratorTests {
128128
.withOperation(value: "allocateWidget")
129129
.withIdempotencyTokenGenerator(value: config.idempotencyTokenGenerator)
130130
.withLogger(value: config.logger)
131+
.withPartitionID(value: config.partitionID)
131132
var operation = ClientRuntime.OperationStack<AllocateWidgetInput, AllocateWidgetOutputResponse, AllocateWidgetOutputError>(id: "allocateWidget")
132133
operation.initializeStep.intercept(position: .after, id: "IdempotencyTokenMiddleware") { (context, input, next) -> ClientRuntime.OperationOutput<AllocateWidgetOutputResponse> in
133134
let idempotencyTokenGenerator = context.getIdempotencyTokenGenerator()
@@ -142,6 +143,7 @@ class HttpProtocolClientGeneratorTests {
142143
operation.serializeStep.intercept(position: .after, middleware: ContentTypeMiddleware<AllocateWidgetInput, AllocateWidgetOutputResponse>(contentType: "application/json"))
143144
operation.serializeStep.intercept(position: .after, middleware: ClientRuntime.SerializableBodyMiddleware<AllocateWidgetInput, AllocateWidgetOutputResponse>(xmlName: "AllocateWidgetInput"))
144145
operation.finalizeStep.intercept(position: .before, middleware: ClientRuntime.ContentLengthMiddleware())
146+
operation.finalizeStep.intercept(position: .after, middleware: ClientRuntime.RetryerMiddleware<AllocateWidgetOutputResponse, AllocateWidgetOutputError>(retryer: config.retryer))
145147
operation.deserializeStep.intercept(position: .before, middleware: ClientRuntime.LoggerMiddleware<AllocateWidgetOutputResponse, AllocateWidgetOutputError>(clientLogMode: config.clientLogMode))
146148
operation.deserializeStep.intercept(position: .after, middleware: ClientRuntime.DeserializeMiddleware<AllocateWidgetOutputResponse, AllocateWidgetOutputError>())
147149
let result = try await operation.handleMiddleware(context: context.build(), input: input, next: client.getHandler())

0 commit comments

Comments
 (0)