Skip to content

Commit 818a5c7

Browse files
authored
feat: add support for requiresLength trait and Transfer-Encoding: Chunked (#604)
1 parent 62ad3a2 commit 818a5c7

File tree

8 files changed

+294
-26
lines changed

8 files changed

+294
-26
lines changed

Sources/ClientRuntime/Networking/Http/Middlewares/ContentLengthMiddleware.swift

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,14 @@ public struct ContentLengthMiddleware<OperationStackOutput: HttpResponseBinding>
66

77
private let contentLengthHeaderName = "Content-Length"
88

9-
public init() {}
9+
private var requiresLength: Bool = false
10+
11+
private var unsignedPayload: Bool = false
12+
13+
public init(requiresLength: Bool = false, unsignedPayload: Bool = false) {
14+
self.requiresLength = requiresLength
15+
self.unsignedPayload = unsignedPayload
16+
}
1017

1118
public func handle<H>(context: Context,
1219
input: MInput,
@@ -22,8 +29,16 @@ public struct ContentLengthMiddleware<OperationStackOutput: HttpResponseBinding>
2229
case .stream(let stream):
2330
if let length = stream.length {
2431
input.headers.update(name: "Content-Length", value: String(length))
32+
} else if !requiresLength && unsignedPayload {
33+
// only for HTTP/1.1 requests, will be removed in all HTTP/2 requests
34+
input.headers.update(name: "Transfer-Encoding", value: "Chunked")
2535
} else {
26-
input.headers.update(name: "Transfer-Encoded", value: "Chunked")
36+
let operation = context.attributes.get(key: AttributeKey<String>(name: "Operation"))
37+
?? "Error getting operation name"
38+
let errorMessage = unsignedPayload ?
39+
"Missing content-length for operation: \(operation)" :
40+
"Missing content-length for SigV4 signing on operation: \(operation)"
41+
throw StreamError.notSupported(errorMessage)
2742
}
2843
default:
2944
input.headers.update(name: "Content-Length", value: "0")

Sources/ClientRuntime/Networking/Http/SdkHttpRequest.swift

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ extension SdkHttpRequest {
7575
httpRequest.path = [endpoint.path, endpoint.queryItemString].compactMap { $0 }.joined(separator: "?")
7676
httpRequest.addHeaders(headers: headers.toHttpHeaders())
7777

78+
// Remove the "Transfer-Encoding" header if it exists since h2 does not support it
79+
httpRequest.removeHeader(name: "Transfer-Encoding")
80+
7881
// HTTP2Request used with manual writes hence we need to set the body to nil
7982
// so that CRT does not write the body for us (we will write it manually)
8083
httpRequest.body = nil
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0.
3+
4+
import XCTest
5+
import SmithyTestUtil
6+
@testable import ClientRuntime
7+
8+
class ContentLengthMiddlewareTests: XCTestCase {
9+
private var builtContext: HttpContext!
10+
private var stack: OperationStack<MockInput, MockOutput, MockMiddlewareError>!
11+
12+
override func setUpWithError() throws {
13+
try super.setUpWithError()
14+
builtContext = HttpContextBuilder()
15+
.withMethod(value: .get)
16+
.withPath(value: "/")
17+
.withEncoder(value: JSONEncoder())
18+
.withDecoder(value: JSONDecoder())
19+
.withOperation(value: "Test Operation")
20+
.build()
21+
stack = OperationStack<MockInput, MockOutput, MockMiddlewareError>(id: "Test Operation")
22+
}
23+
24+
func testTransferEncodingChunkedSetWhenStreamLengthIsNil() async throws {
25+
addContentLengthMiddlewareWith(requiresLength: false, unsignedPayload: true)
26+
forceEmptyStream()
27+
try await AssertHeadersArePresent(expectedHeaders: ["Transfer-Encoding": "Chunked"])
28+
}
29+
30+
func testContentLengthSetWhenStreamLengthAvailableAndRequiresLengthSet() async throws {
31+
addContentLengthMiddlewareWith(requiresLength: true, unsignedPayload: false)
32+
try await AssertHeadersArePresent(expectedHeaders: ["Content-Length": "0"])
33+
}
34+
35+
func testContentLengthSetWhenRequiresLengthAndUnsignedPayload() async throws {
36+
addContentLengthMiddlewareWith(requiresLength: true, unsignedPayload: true)
37+
try await AssertHeadersArePresent(expectedHeaders: ["Content-Length": "0"])
38+
}
39+
40+
func testRequiresLengthSetWithNilStreamShouldThrowError() async throws {
41+
addContentLengthMiddlewareWith(requiresLength: true, unsignedPayload: false)
42+
forceEmptyStream()
43+
do {
44+
try await AssertHeadersArePresent(expectedHeaders: ["Content-Length": "0"])
45+
XCTFail("Should throw error")
46+
} catch let error as StreamError {
47+
switch error {
48+
case .notSupported("Missing content-length for SigV4 signing on operation: Test Operation"), .notSupported("Missing content-length for operation: Test Operation"):
49+
// The error matches one of the expected cases, test passes
50+
break
51+
default:
52+
XCTFail("Error is not StreamError.notSupported with expected message")
53+
}
54+
}
55+
}
56+
57+
private func addContentLengthMiddlewareWith(requiresLength: Bool, unsignedPayload: Bool) {
58+
stack.finalizeStep.intercept(
59+
position: .before,
60+
middleware: ContentLengthMiddleware(requiresLength: requiresLength, unsignedPayload: unsignedPayload)
61+
)
62+
}
63+
64+
private func forceEmptyStream() {
65+
// Force stream length to be nil
66+
stack.finalizeStep.intercept(position: .before, id: "set nil stream length") { (context, input, next) -> OperationOutput<MockOutput> in
67+
input.body = .stream(BufferedStream()) // Set the stream length to nil
68+
return try await next.handle(context: context, input: input)
69+
}
70+
}
71+
72+
private func AssertHeadersArePresent(expectedHeaders: [String: String], file: StaticString = #file, line: UInt = #line) async throws -> Void {
73+
let mockHandler = MockHandler { (_, input) in
74+
for (key, value) in expectedHeaders {
75+
XCTAssert(input.headers.value(for: key) == value, file: file, line: line)
76+
}
77+
let httpResponse = HttpResponse(body: HttpBody.none, statusCode: HttpStatusCode.ok)
78+
let mockOutput = try! MockOutput(httpResponse: httpResponse, decoder: nil)
79+
let output = OperationOutput<MockOutput>(httpResponse: httpResponse, output: mockOutput)
80+
return output
81+
}
82+
83+
_ = try await stack.handleMiddleware(context: builtContext, input: MockInput(), next: mockHandler)
84+
}
85+
}

smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/HttpBindingProtocolGenerator.kt

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
*/
55
package software.amazon.smithy.swift.codegen.integration
66

7+
import software.amazon.smithy.aws.traits.auth.UnsignedPayloadTrait
78
import software.amazon.smithy.codegen.core.Symbol
89
import software.amazon.smithy.model.knowledge.HttpBinding
910
import software.amazon.smithy.model.knowledge.HttpBindingIndex
@@ -30,6 +31,7 @@ import software.amazon.smithy.model.traits.HttpPrefixHeadersTrait
3031
import software.amazon.smithy.model.traits.HttpQueryParamsTrait
3132
import software.amazon.smithy.model.traits.HttpQueryTrait
3233
import software.amazon.smithy.model.traits.MediaTypeTrait
34+
import software.amazon.smithy.model.traits.RequiresLengthTrait
3335
import software.amazon.smithy.model.traits.StreamingTrait
3436
import software.amazon.smithy.model.traits.TimestampFormatTrait
3537
import software.amazon.smithy.swift.codegen.ClientRuntimeTypes
@@ -60,6 +62,7 @@ import software.amazon.smithy.swift.codegen.integration.serde.UnionEncodeGenerat
6062
import software.amazon.smithy.swift.codegen.middleware.OperationMiddlewareGenerator
6163
import software.amazon.smithy.swift.codegen.model.ShapeMetadata
6264
import software.amazon.smithy.swift.codegen.model.bodySymbol
65+
import software.amazon.smithy.swift.codegen.model.findStreamingMember
6366
import software.amazon.smithy.swift.codegen.model.hasEventStreamMember
6467
import software.amazon.smithy.swift.codegen.model.hasTrait
6568
import software.amazon.smithy.utils.OptionalUtils
@@ -91,9 +94,8 @@ fun formatHeaderOrQueryValue(
9194
memberShape: MemberShape,
9295
location: HttpBinding.Location,
9396
bindingIndex: HttpBindingIndex,
94-
defaultTimestampFormat: TimestampFormatTrait.Format
97+
defaultTimestampFormat: TimestampFormatTrait.Format,
9598
): Pair<String, Boolean> {
96-
9799
return when (val shape = ctx.model.expectShape(memberShape.target)) {
98100
is TimestampShape -> {
99101
val timestampFormat = bindingIndex.determineTimestampFormat(memberShape, location, defaultTimestampFormat)
@@ -165,7 +167,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
165167
writer.openBlock(
166168
"extension $symbolName: \$N {",
167169
"}",
168-
SwiftTypes.Protocols.Encodable
170+
SwiftTypes.Protocols.Encodable,
169171
) {
170172
writer.addImport(SwiftDependency.CLIENT_RUNTIME.target)
171173

@@ -286,7 +288,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
286288
private fun generateCodingKeysForMembers(
287289
ctx: ProtocolGenerator.GenerationContext,
288290
writer: SwiftWriter,
289-
members: List<MemberShape>
291+
members: List<MemberShape>,
290292
) {
291293
codingKeysGenerator.generateCodingKeysForMembers(ctx, writer, members)
292294
}
@@ -298,7 +300,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
298300
val inputType = ctx.model.expectShape(operation.input.get())
299301
var metadata = mapOf<ShapeMetadata, Any>(
300302
Pair(ShapeMetadata.OPERATION_SHAPE, operation),
301-
Pair(ShapeMetadata.SERVICE_VERSION, ctx.service.version)
303+
Pair(ShapeMetadata.SERVICE_VERSION, ctx.service.version),
302304
)
303305
shapesInfo.put(inputType, metadata)
304306
}
@@ -336,7 +338,6 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
336338
}
337339

338340
private fun resolveShapesNeedingCodableConformance(ctx: ProtocolGenerator.GenerationContext): Set<Shape> {
339-
340341
val topLevelOutputMembers = getHttpBindingOperations(ctx).flatMap {
341342
val outputShape = ctx.model.expectShape(it.output.get())
342343
outputShape.members()
@@ -390,7 +391,8 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
390391
RelationshipType.LIST_MEMBER,
391392
RelationshipType.SET_MEMBER,
392393
RelationshipType.MAP_VALUE,
393-
RelationshipType.UNION_MEMBER -> true
394+
RelationshipType.UNION_MEMBER,
395+
-> true
394396
else -> false
395397
}
396398
}.forEach {
@@ -403,6 +405,29 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
403405
return resolved
404406
}
405407

408+
// Checks for @requiresLength trait
409+
// Returns true if the operation:
410+
// - has a streaming member with @httpPayload trait
411+
// - target is a blob shape with @requiresLength trait
412+
private fun hasRequiresLengthTrait(ctx: ProtocolGenerator.GenerationContext, op: OperationShape): Boolean {
413+
if (op.input.isPresent) {
414+
val inputShape = ctx.model.expectShape(op.input.get())
415+
val streamingMember = inputShape.findStreamingMember(ctx.model)
416+
if (streamingMember != null) {
417+
val targetShape = ctx.model.expectShape(streamingMember.target)
418+
if (targetShape != null) {
419+
return streamingMember.hasTrait<HttpPayloadTrait>() &&
420+
targetShape.isBlobShape &&
421+
targetShape.hasTrait<RequiresLengthTrait>()
422+
}
423+
}
424+
}
425+
return false
426+
}
427+
428+
// Checks for @unsignedPayload trait on an operation
429+
private fun hasUnsignedPayloadTrait(op: OperationShape): Boolean = op.hasTrait<UnsignedPayloadTrait>()
430+
406431
override fun generateProtocolClient(ctx: ProtocolGenerator.GenerationContext) {
407432
val symbol = ctx.symbolProvider.toSymbol(ctx.service)
408433
ctx.delegator.useFileWriter("./${ctx.settings.moduleName}/${symbol.name}.swift") { writer ->
@@ -414,7 +439,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
414439
serviceSymbol.name,
415440
defaultContentType,
416441
httpProtocolCustomizable,
417-
operationMiddleware
442+
operationMiddleware,
418443
)
419444
clientGenerator.render()
420445
}
@@ -433,7 +458,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
433458
operationMiddleware.appendMiddleware(operation, ContentTypeMiddleware(ctx.model, ctx.symbolProvider, resolver.determineRequestContentType(operation)))
434459
operationMiddleware.appendMiddleware(operation, OperationInputBodyMiddleware(ctx.model, ctx.symbolProvider))
435460

436-
operationMiddleware.appendMiddleware(operation, ContentLengthMiddleware(ctx.model, shouldRenderEncodableConformance))
461+
operationMiddleware.appendMiddleware(operation, ContentLengthMiddleware(ctx.model, shouldRenderEncodableConformance, hasRequiresLengthTrait(ctx, operation), hasUnsignedPayloadTrait(operation)))
437462

438463
operationMiddleware.appendMiddleware(operation, DeserializeMiddleware(ctx.model, ctx.symbolProvider))
439464
operationMiddleware.appendMiddleware(operation, LoggingMiddleware(ctx.model, ctx.symbolProvider))
@@ -463,15 +488,15 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
463488
members: List<MemberShape>,
464489
writer: SwiftWriter,
465490
defaultTimestampFormat: TimestampFormatTrait.Format,
466-
path: String? = null
491+
path: String? = null,
467492
)
468493
protected abstract fun renderStructDecode(
469494
ctx: ProtocolGenerator.GenerationContext,
470495
shapeMetaData: Map<ShapeMetadata, Any>,
471496
members: List<MemberShape>,
472497
writer: SwiftWriter,
473498
defaultTimestampFormat: TimestampFormatTrait.Format,
474-
path: String
499+
path: String,
475500
)
476501
protected abstract fun addProtocolSpecificMiddleware(ctx: ProtocolGenerator.GenerationContext, operation: OperationShape)
477502

@@ -487,11 +512,11 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
487512
for (operation in topDownIndex.getContainedOperations(ctx.service)) {
488513
OptionalUtils.ifPresentOrElse(
489514
Optional.of(getProtocolHttpBindingResolver(ctx, defaultContentType).httpTrait(operation)::class.java),
490-
{ containedOperations.add(operation) }
515+
{ containedOperations.add(operation) },
491516
) {
492517
LOGGER.warning(
493518
"Unable to fetch $protocolName protocol request bindings for ${operation.id} because " +
494-
"it does not have an http binding trait"
519+
"it does not have an http binding trait",
495520
)
496521
}
497522
}

smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/middlewares/ContentLengthMiddleware.kt

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import software.amazon.smithy.swift.codegen.middleware.MiddlewarePosition
99
import software.amazon.smithy.swift.codegen.middleware.MiddlewareRenderable
1010
import software.amazon.smithy.swift.codegen.middleware.MiddlewareStep
1111

12-
class ContentLengthMiddleware(val model: Model, private val alwaysIntercept: Boolean) : MiddlewareRenderable {
12+
class ContentLengthMiddleware(val model: Model, private val alwaysIntercept: Boolean, private val requiresLength: Boolean, private val unsignedPayload: Boolean) : MiddlewareRenderable {
1313

1414
override val name = "ContentLengthMiddleware"
1515

@@ -20,17 +20,17 @@ class ContentLengthMiddleware(val model: Model, private val alwaysIntercept: Boo
2020
override fun render(
2121
writer: SwiftWriter,
2222
op: OperationShape,
23-
operationStackName: String
23+
operationStackName: String,
2424
) {
2525
val hasHttpBody = MiddlewareShapeUtils.hasHttpBody(model, op)
2626
if (hasHttpBody || alwaysIntercept) {
27-
writer.write(
28-
"\$L.\$L.intercept(position: \$L, middleware: \$N())",
29-
operationStackName,
30-
middlewareStep.stringValue(),
31-
position.stringValue(),
32-
ClientRuntimeTypes.Middleware.ContentLengthMiddleware
33-
)
27+
val str = "requiresLength: $requiresLength, unsignedPayload: $unsignedPayload"
28+
val middlewareArgs = str.takeIf { requiresLength || unsignedPayload } ?: ""
29+
30+
val interceptStatement = "$operationStackName.${middlewareStep.stringValue()}.intercept(" +
31+
"position: ${position.stringValue()}, middleware: ${ClientRuntimeTypes.Middleware.ContentLengthMiddleware}($middlewareArgs))"
32+
33+
writer.write(interceptStatement)
3434
}
3535
}
3636
}

smithy-swift-codegen/src/test/kotlin/HttpBindingProtocolGeneratorTests.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class TestHttpProtocolClientGeneratorFactory : HttpProtocolClientGeneratorFactor
3838
private fun getClientProperties(ctx: ProtocolGenerator.GenerationContext): List<ClientProperty> {
3939
return mutableListOf(
4040
DefaultRequestEncoder(),
41-
DefaultResponseDecoder()
41+
DefaultResponseDecoder(),
4242
)
4343
}
4444

@@ -125,6 +125,7 @@ extension InlineDocumentAsPayloadOutput: ClientRuntime.HttpResponseBinding {
125125
""".trimIndent()
126126
contents.shouldContainOnlyOnce(expectedContents)
127127
}
128+
128129
@Test
129130
fun `default fooMap to an empty map if keysForFooMap is empty`() {
130131
val contents = getModelFileContents("example", "HttpPrefixHeadersOutput+HttpResponseBinding.swift", newTestContext.manifest)

0 commit comments

Comments
 (0)