Skip to content

Commit 55aa8c3

Browse files
authored
chore: refactor middleware handler generation (#371)
1 parent 5fea7db commit 55aa8c3

22 files changed

+261
-205
lines changed

smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/ServiceGenerator.kt

Lines changed: 6 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ import software.amazon.smithy.model.shapes.OperationShape
1515
import software.amazon.smithy.model.shapes.StructureShape
1616
import software.amazon.smithy.model.traits.StreamingTrait
1717
import software.amazon.smithy.swift.codegen.integration.ProtocolGenerator
18+
import software.amazon.smithy.swift.codegen.integration.middlewares.handlers.MiddlewareShapeUtils
1819
import software.amazon.smithy.swift.codegen.model.camelCaseName
19-
import software.amazon.smithy.swift.codegen.model.capitalizedName
2020

2121
/*
2222
* Generates a Swift protocol for the service
@@ -76,49 +76,24 @@ class ServiceGenerator(
7676
fun renderAsyncOperationDefinition(model: Model, symbolProvider: SymbolProvider, writer: SwiftWriter, opIndex: OperationIndex, op: OperationShape) {
7777
if (!op.input.isPresent || !op.output.isPresent) throw CodegenException("model should have been preprocessed to ensure operations always have an input or output shape: $op.id")
7878

79-
val operationName = op.camelCaseName()
80-
val inputShape = opIndex.getInput(op).get()
81-
val inputSymbolName = symbolProvider.toSymbol(inputShape).name
79+
val inputSymbolName = MiddlewareShapeUtils.inputSymbol(symbolProvider, model, op).name
8280
val inputParam = "input: $inputSymbolName"
83-
84-
val outputType = getOperationOutputShapeName(symbolProvider, opIndex, op)
81+
val outputType = MiddlewareShapeUtils.outputSymbol(symbolProvider, model, op).name
8582

8683
writer.writeShapeDocs(op)
8784
writer.writeAvailableAttribute(model, op)
8885

86+
val operationName = op.camelCaseName()
8987
writer.write("func \$L(\$L) async throws -> \$L", operationName, inputParam, outputType)
9088
}
9189

9290
fun createOutputType(opIndex: OperationIndex, op: OperationShape, symbolProvider: SymbolProvider): String {
9391
val outputShape = opIndex.getOutput(op).get()
9492
val outputShapeName = symbolProvider.toSymbol(outputShape).name
95-
val errorTypeName = getOperationErrorShapeName(op)
93+
val errorTypeName = MiddlewareShapeUtils.outputErrorSymbolName(op)
9694

9795
return "${ClientRuntimeTypes.Core.SdkResult}<$outputShapeName, $errorTypeName>"
9896
}
99-
100-
fun getOperationInputShapeName(symbolProvider: SymbolProvider, opIndex: OperationIndex, op: OperationShape): String {
101-
val inputShape = opIndex.getInput(op).get()
102-
return symbolProvider.toSymbol(inputShape).name
103-
}
104-
fun getOperationInputShapeName(symbolProvider: SymbolProvider, model: Model, op: OperationShape): String {
105-
val inputShape = model.expectShape(op.input.get())
106-
return symbolProvider.toSymbol(inputShape).name
107-
}
108-
109-
fun getOperationOutputShapeName(symbolProvider: SymbolProvider, opIndex: OperationIndex, op: OperationShape): String {
110-
val outputShape = opIndex.getOutput(op).get()
111-
return symbolProvider.toSymbol(outputShape).name
112-
}
113-
114-
fun getOperationOutputShapeName(symbolProvider: SymbolProvider, model: Model, op: OperationShape): String {
115-
val outputShape = model.expectShape(op.output.get())
116-
return symbolProvider.toSymbol(outputShape).name
117-
}
118-
119-
fun getOperationErrorShapeName(op: OperationShape): String {
120-
return "${op.capitalizedName()}OutputError"
121-
}
12297
}
12398

12499
fun render() {
@@ -183,7 +158,7 @@ class ServiceGenerator(
183158
op: OperationShape
184159
) {
185160
val errorShapes = op.errors.map { model.expectShape(it) as StructureShape }.toSet().sorted()
186-
val operationErrorName = getOperationErrorShapeName(op)
161+
val operationErrorName = MiddlewareShapeUtils.outputErrorSymbolName(op)
187162
val operationErrorSymbol = Symbol.builder()
188163
.definitionFile("./$rootNamespace/models/$operationErrorName.swift")
189164
.name(operationErrorName)

smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/HttpBindingProtocolGenerator.kt

Lines changed: 11 additions & 132 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ package software.amazon.smithy.swift.codegen.integration
77
import software.amazon.smithy.codegen.core.Symbol
88
import software.amazon.smithy.model.knowledge.HttpBinding
99
import software.amazon.smithy.model.knowledge.HttpBindingIndex
10-
import software.amazon.smithy.model.knowledge.OperationIndex
1110
import software.amazon.smithy.model.knowledge.TopDownIndex
1211
import software.amazon.smithy.model.neighbor.RelationshipType
1312
import software.amazon.smithy.model.neighbor.Walker
@@ -31,9 +30,6 @@ import software.amazon.smithy.model.traits.HttpQueryTrait
3130
import software.amazon.smithy.model.traits.MediaTypeTrait
3231
import software.amazon.smithy.model.traits.TimestampFormatTrait
3332
import software.amazon.smithy.swift.codegen.ClientRuntimeTypes
34-
import software.amazon.smithy.swift.codegen.Middleware
35-
import software.amazon.smithy.swift.codegen.MiddlewareGenerator
36-
import software.amazon.smithy.swift.codegen.ServiceGenerator
3733
import software.amazon.smithy.swift.codegen.SwiftDependency
3834
import software.amazon.smithy.swift.codegen.SwiftTypes
3935
import software.amazon.smithy.swift.codegen.SwiftWriter
@@ -49,13 +45,16 @@ import software.amazon.smithy.swift.codegen.integration.middlewares.OperationInp
4945
import software.amazon.smithy.swift.codegen.integration.middlewares.OperationInputHeadersMiddleware
5046
import software.amazon.smithy.swift.codegen.integration.middlewares.OperationInputQueryItemMiddleware
5147
import software.amazon.smithy.swift.codegen.integration.middlewares.OperationInputUrlPathMiddleware
48+
import software.amazon.smithy.swift.codegen.integration.middlewares.handlers.HttpBodyMiddleware
49+
import software.amazon.smithy.swift.codegen.integration.middlewares.handlers.HttpHeaderMiddleware
50+
import software.amazon.smithy.swift.codegen.integration.middlewares.handlers.HttpQueryItemMiddleware
51+
import software.amazon.smithy.swift.codegen.integration.middlewares.handlers.HttpUrlPathMiddleware
5252
import software.amazon.smithy.swift.codegen.integration.serde.DynamicNodeDecodingGeneratorStrategy
5353
import software.amazon.smithy.swift.codegen.integration.serde.UnionDecodeGeneratorStrategy
5454
import software.amazon.smithy.swift.codegen.integration.serde.UnionEncodeGeneratorStrategy
5555
import software.amazon.smithy.swift.codegen.middleware.OperationMiddlewareGenerator
5656
import software.amazon.smithy.swift.codegen.model.ShapeMetadata
5757
import software.amazon.smithy.swift.codegen.model.bodySymbol
58-
import software.amazon.smithy.swift.codegen.model.capitalizedName
5958
import software.amazon.smithy.utils.OptionalUtils
6059
import java.util.Optional
6160
import java.util.logging.Logger
@@ -128,11 +127,11 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
128127
// The input shape is referenced by more than one operation
129128
continue
130129
}
131-
renderUrlPathMiddleware(ctx, operation)
132-
renderHeaderMiddleware(ctx, operation)
133-
renderQueryMiddleware(ctx, operation)
134-
renderBodyMiddleware(ctx, operation)
135-
130+
val httpBindingResolver = getProtocolHttpBindingResolver(ctx, defaultContentType)
131+
HttpUrlPathMiddleware.renderUrlPathMiddleware(ctx, operation, httpBindingResolver)
132+
HttpHeaderMiddleware.renderHeaderMiddleware(ctx, operation, httpBindingResolver, defaultTimestampFormat)
133+
HttpQueryItemMiddleware.renderQueryMiddleware(ctx, operation, httpBindingResolver, defaultTimestampFormat)
134+
HttpBodyMiddleware.renderBodyMiddleware(ctx, operation, httpBindingResolver, httpProtocolBodyMiddleware())
136135
inputShapesWithHttpBindings.add(inputShapeId)
137136
}
138137
}
@@ -359,115 +358,6 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
359358
return resolved
360359
}
361360

362-
private fun renderHeaderMiddleware(
363-
ctx: ProtocolGenerator.GenerationContext,
364-
op: OperationShape
365-
) {
366-
val opIndex = OperationIndex.of(ctx.model)
367-
val httpBindingResolver = getProtocolHttpBindingResolver(ctx, defaultContentType)
368-
val requestBindings = httpBindingResolver.requestBindings(op)
369-
val inputShape = opIndex.getInput(op).get()
370-
val outputShape = opIndex.getOutput(op).get()
371-
val operationErrorName = "${op.capitalizedName()}OutputError"
372-
val inputSymbol = ctx.symbolProvider.toSymbol(inputShape)
373-
val outputSymbol = ctx.symbolProvider.toSymbol(outputShape)
374-
val outputErrorSymbol = Symbol.builder().name(operationErrorName).build()
375-
376-
val headerBindings = requestBindings
377-
.filter { it.location == HttpBinding.Location.HEADER }
378-
.sortedBy { it.memberName }
379-
val prefixHeaderBindings = requestBindings
380-
.filter { it.location == HttpBinding.Location.PREFIX_HEADERS }
381-
382-
val rootNamespace = ctx.settings.moduleName
383-
val headerMiddlewareSymbol = Symbol.builder()
384-
.definitionFile("./$rootNamespace/models/${inputSymbol.name}+HeaderMiddleware.swift")
385-
.name(inputSymbol.name)
386-
.build()
387-
ctx.delegator.useShapeWriter(headerMiddlewareSymbol) { writer ->
388-
writer.addImport(SwiftDependency.CLIENT_RUNTIME.target)
389-
val headerMiddleware = HttpHeaderMiddleware(writer, ctx, inputSymbol, outputSymbol, outputErrorSymbol, headerBindings, prefixHeaderBindings, defaultTimestampFormat)
390-
MiddlewareGenerator(writer, headerMiddleware).generate()
391-
}
392-
}
393-
394-
private fun renderQueryMiddleware(ctx: ProtocolGenerator.GenerationContext, op: OperationShape) {
395-
val opIndex = OperationIndex.of(ctx.model)
396-
val httpBindingResolver = getProtocolHttpBindingResolver(ctx, defaultContentType)
397-
val httpTrait = httpBindingResolver.httpTrait(op)
398-
val requestBindings = httpBindingResolver.requestBindings(op)
399-
val inputShape = opIndex.getInput(op).get()
400-
val outputShape = opIndex.getOutput(op).get()
401-
val operationErrorName = "${op.capitalizedName()}OutputError"
402-
val inputSymbol = ctx.symbolProvider.toSymbol(inputShape)
403-
val outputSymbol = ctx.symbolProvider.toSymbol(outputShape)
404-
val outputErrorSymbol = Symbol.builder().name(operationErrorName).build()
405-
val queryBindings = requestBindings.filter { it.location == HttpBinding.Location.QUERY || it.location == HttpBinding.Location.QUERY_PARAMS }
406-
val queryLiterals = httpTrait.uri.queryLiterals
407-
408-
val rootNamespace = ctx.settings.moduleName
409-
val headerMiddlewareSymbol = Symbol.builder()
410-
.definitionFile("./$rootNamespace/models/${inputSymbol.name}+QueryItemMiddleware.swift")
411-
.name(inputSymbol.name)
412-
.build()
413-
ctx.delegator.useShapeWriter(headerMiddlewareSymbol) { writer ->
414-
writer.addImport(SwiftDependency.CLIENT_RUNTIME.target)
415-
val queryItemMiddleware = HttpQueryItemMiddleware(ctx, inputSymbol, outputSymbol, outputErrorSymbol, queryLiterals, queryBindings, defaultTimestampFormat, writer)
416-
MiddlewareGenerator(writer, queryItemMiddleware).generate()
417-
}
418-
}
419-
420-
private fun renderUrlPathMiddleware(ctx: ProtocolGenerator.GenerationContext, op: OperationShape) {
421-
val opIndex = OperationIndex.of(ctx.model)
422-
val httpBindingResolver = getProtocolHttpBindingResolver(ctx, defaultContentType)
423-
val httpTrait = httpBindingResolver.httpTrait(op)
424-
val requestBindings = httpBindingResolver.requestBindings(op)
425-
val pathBindings = requestBindings.filter { it.location == HttpBinding.Location.LABEL }
426-
val inputShape = opIndex.getInput(op).get()
427-
val outputShape = opIndex.getOutput(op).get()
428-
val operationErrorName = ServiceGenerator.getOperationErrorShapeName(op)
429-
val inputSymbol = ctx.symbolProvider.toSymbol(inputShape)
430-
val outputSymbol = ctx.symbolProvider.toSymbol(outputShape)
431-
val outputErrorSymbol = Symbol.builder().name(operationErrorName).build()
432-
433-
val rootNamespace = ctx.settings.moduleName
434-
val urlPathMiddlewareSymbol = Symbol.builder()
435-
.definitionFile("./$rootNamespace/models/${inputSymbol.name}+UrlPathMiddleware.swift")
436-
.name(inputSymbol.name)
437-
.build()
438-
ctx.delegator.useShapeWriter(urlPathMiddlewareSymbol) { writer ->
439-
writer.addImport(SwiftDependency.CLIENT_RUNTIME.target)
440-
val urlPathMiddleware = HttpUrlPathMiddleware(ctx, inputSymbol, outputSymbol, outputErrorSymbol, httpTrait, pathBindings, writer)
441-
MiddlewareGenerator(writer, urlPathMiddleware).generate()
442-
}
443-
}
444-
445-
private fun renderBodyMiddleware(ctx: ProtocolGenerator.GenerationContext, op: OperationShape) {
446-
val opIndex = OperationIndex.of(ctx.model)
447-
val inputShape = opIndex.getInput(op).get()
448-
449-
if (shouldRenderHttpBodyMiddleware(inputShape)) {
450-
val rootNamespace = ctx.settings.moduleName
451-
val inputSymbol = ctx.symbolProvider.toSymbol(inputShape)
452-
val headerMiddlewareSymbol = Symbol.builder()
453-
.definitionFile("./$rootNamespace/models/${inputSymbol.name}+BodyMiddleware.swift")
454-
.name(inputSymbol.name)
455-
.build()
456-
ctx.delegator.useShapeWriter(headerMiddlewareSymbol) { writer ->
457-
writer.addImport(SwiftDependency.CLIENT_RUNTIME.target)
458-
val outputShape = opIndex.getOutput(op).get()
459-
val outputSymbol = ctx.symbolProvider.toSymbol(outputShape)
460-
val operationErrorName = "${op.capitalizedName()}OutputError"
461-
val outputErrorSymbol = Symbol.builder().name(operationErrorName).build()
462-
val httpBindingResolver = getProtocolHttpBindingResolver(ctx, defaultContentType)
463-
val requestBindings = httpBindingResolver.requestBindings(op)
464-
val bodyMiddleware = httpBodyMiddleware(writer, ctx, inputSymbol, outputSymbol, outputErrorSymbol, requestBindings)
465-
466-
MiddlewareGenerator(writer, bodyMiddleware).generate()
467-
}
468-
}
469-
}
470-
471361
override fun generateProtocolClient(ctx: ProtocolGenerator.GenerationContext) {
472362
val symbol = ctx.symbolProvider.toSymbol(ctx.service)
473363
ctx.delegator.useFileWriter("./${ctx.settings.moduleName}/${symbol.name}.swift") { writer ->
@@ -536,19 +426,8 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
536426
)
537427
protected abstract fun addProtocolSpecificMiddleware(ctx: ProtocolGenerator.GenerationContext, operation: OperationShape)
538428

539-
open fun shouldRenderHttpBodyMiddleware(shape: Shape): Boolean {
540-
return shape.members().filter { it.isInHttpBody() }.count() > 0
541-
}
542-
543-
open fun httpBodyMiddleware(
544-
writer: SwiftWriter,
545-
ctx: ProtocolGenerator.GenerationContext,
546-
inputSymbol: Symbol,
547-
outputSymbol: Symbol,
548-
outputErrorSymbol: Symbol,
549-
requestBindings: List<HttpBindingDescriptor>
550-
): Middleware {
551-
return HttpBodyMiddleware(writer, ctx, inputSymbol, outputSymbol, outputErrorSymbol, requestBindings)
429+
open fun httpProtocolBodyMiddleware(): HttpProtocolBodyMiddlewareGeneratorFactory {
430+
return DefaultHttpProtocolBodyMiddlewareGeneratorFactory()
552431
}
553432

554433
/**
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package software.amazon.smithy.swift.codegen.integration
2+
3+
import software.amazon.smithy.codegen.core.Symbol
4+
import software.amazon.smithy.model.shapes.Shape
5+
import software.amazon.smithy.swift.codegen.Middleware
6+
import software.amazon.smithy.swift.codegen.SwiftWriter
7+
import software.amazon.smithy.swift.codegen.integration.middlewares.handlers.HttpBodyMiddleware
8+
9+
interface HttpProtocolBodyMiddlewareGeneratorFactory {
10+
fun shouldRenderHttpBodyMiddleware(shape: Shape): Boolean
11+
12+
fun httpBodyMiddleware(
13+
writer: SwiftWriter,
14+
ctx: ProtocolGenerator.GenerationContext,
15+
inputSymbol: Symbol,
16+
outputSymbol: Symbol,
17+
outputErrorSymbol: Symbol,
18+
requestBindings: List<HttpBindingDescriptor>
19+
): Middleware
20+
}
21+
22+
class DefaultHttpProtocolBodyMiddlewareGeneratorFactory : HttpProtocolBodyMiddlewareGeneratorFactory {
23+
override fun shouldRenderHttpBodyMiddleware(shape: Shape): Boolean {
24+
return shape.members().filter { it.isInHttpBody() }.count() > 0
25+
}
26+
27+
override fun httpBodyMiddleware(
28+
writer: SwiftWriter,
29+
ctx: ProtocolGenerator.GenerationContext,
30+
inputSymbol: Symbol,
31+
outputSymbol: Symbol,
32+
outputErrorSymbol: Symbol,
33+
requestBindings: List<HttpBindingDescriptor>
34+
): Middleware {
35+
return HttpBodyMiddleware(writer, ctx, inputSymbol, outputSymbol, outputErrorSymbol, requestBindings)
36+
}
37+
}

smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/HttpProtocolClientGenerator.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import software.amazon.smithy.swift.codegen.ServiceGenerator
1414
import software.amazon.smithy.swift.codegen.SwiftDependency
1515
import software.amazon.smithy.swift.codegen.SwiftTypes
1616
import software.amazon.smithy.swift.codegen.SwiftWriter
17+
import software.amazon.smithy.swift.codegen.integration.middlewares.handlers.MiddlewareShapeUtils
1718
import software.amazon.smithy.swift.codegen.middleware.MiddlewareExecutionGenerator
1819
import software.amazon.smithy.swift.codegen.middleware.OperationMiddleware
1920
import software.amazon.smithy.swift.codegen.model.camelCaseName
@@ -95,7 +96,7 @@ open class HttpProtocolClientGenerator(
9596
private fun renderContinuation(opIndex: OperationIndex, op: OperationShape, writer: SwiftWriter) {
9697
val operationName = op.camelCaseName()
9798
val continuationName = "${operationName}Continuation"
98-
writer.write("typealias $continuationName = CheckedContinuation<${ServiceGenerator.getOperationOutputShapeName(ctx.symbolProvider, opIndex, op)}, \$N>", SwiftTypes.Error)
99+
writer.write("typealias $continuationName = CheckedContinuation<${MiddlewareShapeUtils.outputSymbol(ctx.symbolProvider, opIndex, op).name}, \$N>", SwiftTypes.Error)
99100
writer.openBlock("return try await withCheckedThrowingContinuation { (continuation: $continuationName) in", "}") {
100101
writer.openBlock("$operationName(input: input) { result in", "}") {
101102
writer.openBlock("switch result {", "}") {

smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/httpResponse/HttpResponseBindingErrorNarrowGenerator.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@ import software.amazon.smithy.codegen.core.Symbol
1010
import software.amazon.smithy.model.shapes.OperationShape
1111
import software.amazon.smithy.model.shapes.StructureShape
1212
import software.amazon.smithy.swift.codegen.ClientRuntimeTypes
13-
import software.amazon.smithy.swift.codegen.ServiceGenerator
1413
import software.amazon.smithy.swift.codegen.SwiftDependency
1514
import software.amazon.smithy.swift.codegen.SwiftTypes
1615
import software.amazon.smithy.swift.codegen.declareSection
1716
import software.amazon.smithy.swift.codegen.integration.ProtocolGenerator
1817
import software.amazon.smithy.swift.codegen.integration.SectionId
18+
import software.amazon.smithy.swift.codegen.integration.middlewares.handlers.MiddlewareShapeUtils
1919
import software.amazon.smithy.swift.codegen.model.getTrait
2020

2121
class HttpResponseBindingErrorNarrowGenerator(
@@ -27,7 +27,7 @@ class HttpResponseBindingErrorNarrowGenerator(
2727

2828
fun render() {
2929
val errorShapes = op.errors.map { ctx.model.expectShape(it) as StructureShape }.toSet().sorted()
30-
val operationErrorName = ServiceGenerator.getOperationErrorShapeName(op)
30+
val operationErrorName = MiddlewareShapeUtils.outputErrorSymbolName(op)
3131
val rootNamespace = ctx.settings.moduleName
3232
val httpBindingSymbol = Symbol.builder()
3333
.definitionFile("./$rootNamespace/models/$operationErrorName+HttpResponseBinding.swift")

0 commit comments

Comments
 (0)