44 */
55package software.amazon.smithy.swift.codegen.integration
66
7+ import software.amazon.smithy.aws.traits.auth.UnsignedPayloadTrait
78import software.amazon.smithy.codegen.core.Symbol
89import software.amazon.smithy.model.knowledge.HttpBinding
910import software.amazon.smithy.model.knowledge.HttpBindingIndex
@@ -30,6 +31,7 @@ import software.amazon.smithy.model.traits.HttpPrefixHeadersTrait
3031import software.amazon.smithy.model.traits.HttpQueryParamsTrait
3132import software.amazon.smithy.model.traits.HttpQueryTrait
3233import software.amazon.smithy.model.traits.MediaTypeTrait
34+ import software.amazon.smithy.model.traits.RequiresLengthTrait
3335import software.amazon.smithy.model.traits.StreamingTrait
3436import software.amazon.smithy.model.traits.TimestampFormatTrait
3537import software.amazon.smithy.swift.codegen.ClientRuntimeTypes
@@ -60,6 +62,7 @@ import software.amazon.smithy.swift.codegen.integration.serde.UnionEncodeGenerat
6062import software.amazon.smithy.swift.codegen.middleware.OperationMiddlewareGenerator
6163import software.amazon.smithy.swift.codegen.model.ShapeMetadata
6264import software.amazon.smithy.swift.codegen.model.bodySymbol
65+ import software.amazon.smithy.swift.codegen.model.findStreamingMember
6366import software.amazon.smithy.swift.codegen.model.hasEventStreamMember
6467import software.amazon.smithy.swift.codegen.model.hasTrait
6568import software.amazon.smithy.utils.OptionalUtils
@@ -91,9 +94,8 @@ fun formatHeaderOrQueryValue(
9194 memberShape : MemberShape ,
9295 location : HttpBinding .Location ,
9396 bindingIndex : HttpBindingIndex ,
94- defaultTimestampFormat : TimestampFormatTrait .Format
97+ defaultTimestampFormat : TimestampFormatTrait .Format ,
9598): Pair <String , Boolean > {
96-
9799 return when (val shape = ctx.model.expectShape(memberShape.target)) {
98100 is TimestampShape -> {
99101 val timestampFormat = bindingIndex.determineTimestampFormat(memberShape, location, defaultTimestampFormat)
@@ -165,7 +167,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
165167 writer.openBlock(
166168 " extension $symbolName : \$ N {" ,
167169 " }" ,
168- SwiftTypes .Protocols .Encodable
170+ SwiftTypes .Protocols .Encodable ,
169171 ) {
170172 writer.addImport(SwiftDependency .CLIENT_RUNTIME .target)
171173
@@ -286,7 +288,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
286288 private fun generateCodingKeysForMembers (
287289 ctx : ProtocolGenerator .GenerationContext ,
288290 writer : SwiftWriter ,
289- members : List <MemberShape >
291+ members : List <MemberShape >,
290292 ) {
291293 codingKeysGenerator.generateCodingKeysForMembers(ctx, writer, members)
292294 }
@@ -298,7 +300,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
298300 val inputType = ctx.model.expectShape(operation.input.get())
299301 var metadata = mapOf<ShapeMetadata , Any >(
300302 Pair (ShapeMetadata .OPERATION_SHAPE , operation),
301- Pair (ShapeMetadata .SERVICE_VERSION , ctx.service.version)
303+ Pair (ShapeMetadata .SERVICE_VERSION , ctx.service.version),
302304 )
303305 shapesInfo.put(inputType, metadata)
304306 }
@@ -336,7 +338,6 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
336338 }
337339
338340 private fun resolveShapesNeedingCodableConformance (ctx : ProtocolGenerator .GenerationContext ): Set <Shape > {
339-
340341 val topLevelOutputMembers = getHttpBindingOperations(ctx).flatMap {
341342 val outputShape = ctx.model.expectShape(it.output.get())
342343 outputShape.members()
@@ -390,7 +391,8 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
390391 RelationshipType .LIST_MEMBER ,
391392 RelationshipType .SET_MEMBER ,
392393 RelationshipType .MAP_VALUE ,
393- RelationshipType .UNION_MEMBER -> true
394+ RelationshipType .UNION_MEMBER ,
395+ -> true
394396 else -> false
395397 }
396398 }.forEach {
@@ -403,6 +405,29 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
403405 return resolved
404406 }
405407
408+ // Checks for @requiresLength trait
409+ // Returns true if the operation:
410+ // - has a streaming member with @httpPayload trait
411+ // - target is a blob shape with @requiresLength trait
412+ private fun hasRequiresLengthTrait (ctx : ProtocolGenerator .GenerationContext , op : OperationShape ): Boolean {
413+ if (op.input.isPresent) {
414+ val inputShape = ctx.model.expectShape(op.input.get())
415+ val streamingMember = inputShape.findStreamingMember(ctx.model)
416+ if (streamingMember != null ) {
417+ val targetShape = ctx.model.expectShape(streamingMember.target)
418+ if (targetShape != null ) {
419+ return streamingMember.hasTrait<HttpPayloadTrait >() &&
420+ targetShape.isBlobShape &&
421+ targetShape.hasTrait<RequiresLengthTrait >()
422+ }
423+ }
424+ }
425+ return false
426+ }
427+
428+ // Checks for @unsignedPayload trait on an operation
429+ private fun hasUnsignedPayloadTrait (op : OperationShape ): Boolean = op.hasTrait<UnsignedPayloadTrait >()
430+
406431 override fun generateProtocolClient (ctx : ProtocolGenerator .GenerationContext ) {
407432 val symbol = ctx.symbolProvider.toSymbol(ctx.service)
408433 ctx.delegator.useFileWriter(" ./${ctx.settings.moduleName} /${symbol.name} .swift" ) { writer ->
@@ -414,7 +439,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
414439 serviceSymbol.name,
415440 defaultContentType,
416441 httpProtocolCustomizable,
417- operationMiddleware
442+ operationMiddleware,
418443 )
419444 clientGenerator.render()
420445 }
@@ -433,7 +458,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
433458 operationMiddleware.appendMiddleware(operation, ContentTypeMiddleware (ctx.model, ctx.symbolProvider, resolver.determineRequestContentType(operation)))
434459 operationMiddleware.appendMiddleware(operation, OperationInputBodyMiddleware (ctx.model, ctx.symbolProvider))
435460
436- operationMiddleware.appendMiddleware(operation, ContentLengthMiddleware (ctx.model, shouldRenderEncodableConformance))
461+ operationMiddleware.appendMiddleware(operation, ContentLengthMiddleware (ctx.model, shouldRenderEncodableConformance, hasRequiresLengthTrait(ctx, operation), hasUnsignedPayloadTrait(operation) ))
437462
438463 operationMiddleware.appendMiddleware(operation, DeserializeMiddleware (ctx.model, ctx.symbolProvider))
439464 operationMiddleware.appendMiddleware(operation, LoggingMiddleware (ctx.model, ctx.symbolProvider))
@@ -463,15 +488,15 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
463488 members : List <MemberShape >,
464489 writer : SwiftWriter ,
465490 defaultTimestampFormat : TimestampFormatTrait .Format ,
466- path : String? = null
491+ path : String? = null,
467492 )
468493 protected abstract fun renderStructDecode (
469494 ctx : ProtocolGenerator .GenerationContext ,
470495 shapeMetaData : Map <ShapeMetadata , Any >,
471496 members : List <MemberShape >,
472497 writer : SwiftWriter ,
473498 defaultTimestampFormat : TimestampFormatTrait .Format ,
474- path : String
499+ path : String ,
475500 )
476501 protected abstract fun addProtocolSpecificMiddleware (ctx : ProtocolGenerator .GenerationContext , operation : OperationShape )
477502
@@ -487,11 +512,11 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
487512 for (operation in topDownIndex.getContainedOperations(ctx.service)) {
488513 OptionalUtils .ifPresentOrElse(
489514 Optional .of(getProtocolHttpBindingResolver(ctx, defaultContentType).httpTrait(operation)::class .java),
490- { containedOperations.add(operation) }
515+ { containedOperations.add(operation) },
491516 ) {
492517 LOGGER .warning(
493518 " Unable to fetch $protocolName protocol request bindings for ${operation.id} because " +
494- " it does not have an http binding trait"
519+ " it does not have an http binding trait" ,
495520 )
496521 }
497522 }
0 commit comments