Skip to content

Commit f6fb6ad

Browse files
authored
feat: add async operations in a client extension (#289)
1 parent d2cda87 commit f6fb6ad

File tree

3 files changed

+196
-197
lines changed

3 files changed

+196
-197
lines changed

smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/ServiceGenerator.kt

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,8 @@ class ServiceGenerator(
5353
val inputShapeName = symbolProvider.toSymbol(inputShape).name
5454
val inputParam = "input: $inputShapeName"
5555

56-
val outputShape = opIndex.getOutput(op).get()
57-
val outputShapeName = symbolProvider.toSymbol(outputShape).name
58-
val errorTypeName = getOperationErrorShapeName(op)
59-
60-
val outputParam = "completion: @escaping (SdkResult<$outputShapeName, $errorTypeName>) -> Void"
56+
val outputType = createOutputType(opIndex, op, symbolProvider)
57+
val outputParam = "completion: @escaping ($outputType) -> Void"
6158

6259
val paramTerminator = ", "
6360

@@ -74,6 +71,30 @@ class ServiceGenerator(
7471
)
7572
}
7673

74+
fun renderAsyncOperationDefinition(model: Model, symbolProvider: SymbolProvider, writer: SwiftWriter, opIndex: OperationIndex, op: OperationShape) {
75+
if (!op.input.isPresent || !op.output.isPresent) throw CodegenException("model should have been preprocessed to ensure operations always have an input or output shape: $op.id")
76+
77+
val operationName = op.camelCaseName()
78+
val inputShape = opIndex.getInput(op).get()
79+
val inputSymbolName = symbolProvider.toSymbol(inputShape).name
80+
val inputParam = "input: $inputSymbolName"
81+
82+
val outputType = getOperationOutputShapeName(symbolProvider, opIndex, op)
83+
84+
writer.writeShapeDocs(op)
85+
writer.writeAvailableAttribute(model, op)
86+
87+
writer.write("func \$L(\$L) async throws -> \$L", operationName, inputParam, outputType)
88+
}
89+
90+
fun createOutputType(opIndex: OperationIndex, op: OperationShape, symbolProvider: SymbolProvider): String {
91+
val outputShape = opIndex.getOutput(op).get()
92+
val outputShapeName = symbolProvider.toSymbol(outputShape).name
93+
val errorTypeName = getOperationErrorShapeName(op)
94+
95+
return "SdkResult<$outputShapeName, $errorTypeName>"
96+
}
97+
7798
fun getOperationInputShapeName(symbolProvider: SymbolProvider, opIndex: OperationIndex, op: OperationShape): String {
7899
val inputShape = opIndex.getInput(op).get()
79100
return symbolProvider.toSymbol(inputShape).name
@@ -160,7 +181,7 @@ class ServiceGenerator(
160181

161182
delegator.useShapeWriter(operationErrorSymbol) { writer ->
162183
writer.addImport(unknownServiceErrorSymbol)
163-
writer.openBlock("public enum $operationErrorName: Equatable {", "}") {
184+
writer.openBlock("public enum $operationErrorName: Swift.Error, Equatable {", "}") {
164185
for (errorShape in errorShapes) {
165186
val errorShapeName = symbolProvider.toSymbol(errorShape).name
166187
writer.write("case \$L(\$L)", errorShapeName.decapitalize(), errorShapeName)

smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/HttpProtocolClientGenerator.kt

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,13 @@ open class HttpProtocolClientGenerator(
5252
httpProtocolServiceClient.render(serviceSymbol)
5353
writer.write("")
5454
renderOperationsInExtension(serviceSymbol)
55+
val rootNamespace = ctx.settings.moduleName
56+
ctx.delegator.useFileWriter("./$rootNamespace/${serviceSymbol.name}+Async.swift") {
57+
it.write("#if swift(>=5.5)")
58+
it.addImport(SwiftDependency.CLIENT_RUNTIME.target)
59+
renderAsyncOperationsInExtension(serviceSymbol, it)
60+
it.write("#endif")
61+
}
5562
}
5663

5764
private fun renderOperationsInExtension(serviceSymbol: Symbol) {
@@ -70,6 +77,38 @@ open class HttpProtocolClientGenerator(
7077
}
7178
}
7279

80+
private fun renderAsyncOperationsInExtension(serviceSymbol: Symbol, writer: SwiftWriter) {
81+
val topDownIndex = TopDownIndex.of(model)
82+
val operations = topDownIndex.getContainedOperations(serviceShape).sortedBy { it.capitalizedName() }
83+
val operationsIndex = OperationIndex.of(model)
84+
writer.write("@available(macOS 12.0, iOS 15.0, tvOS 15.0, watchOS 8.0, macCatalyst 15.0, *)")
85+
writer.openBlock("public extension ${serviceSymbol.name} {", "}") {
86+
operations.forEach {
87+
ServiceGenerator.renderAsyncOperationDefinition(model, symbolProvider, writer, operationsIndex, it)
88+
writer.openBlock("{", "}") {
89+
renderContinuation(operationsIndex, it, writer)
90+
}
91+
writer.write("")
92+
}
93+
}
94+
}
95+
96+
private fun renderContinuation(opIndex: OperationIndex, op: OperationShape, writer: SwiftWriter) {
97+
val operationName = op.camelCaseName()
98+
val continuationName = "${operationName}Continuation"
99+
writer.write("typealias $continuationName = CheckedContinuation<${ServiceGenerator.getOperationOutputShapeName(ctx.symbolProvider, opIndex, op)}, Swift.Error>")
100+
writer.openBlock("return try await withCheckedThrowingContinuation { (continuation: $continuationName) in", "}") {
101+
writer.openBlock("$operationName(input: input) { result in", "}") {
102+
writer.openBlock("switch result {", "}") {
103+
writer.write("case .success(let output):")
104+
writer.indent().write("continuation.resume(returning: output)").dedent()
105+
writer.write("case .failure(let error):")
106+
writer.indent().write("continuation.resume(throwing: error)").dedent()
107+
}
108+
}
109+
}
110+
}
111+
73112
// replace labels with any path bindings
74113
private fun renderUriPath(httpTrait: HttpTrait, pathBindings: List<HttpBindingDescriptor>, writer: SwiftWriter) {
75114
val resolvedURIComponents = mutableListOf<String>()

0 commit comments

Comments
 (0)