Skip to content

Commit 191c577

Browse files
Fix constraint-related errors in Rpcv2CBOR server implementation (#3794)
2 parents 684c15f + 27ca7f1 commit 191c577

File tree

15 files changed

+437
-43
lines changed

15 files changed

+437
-43
lines changed

.changelog/2155171.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
---
2+
applies_to: ["server","client"]
3+
authors: ["drganjoo"]
4+
references: [smithy-rs#3573]
5+
breaking: false
6+
new_feature: true
7+
bug_fix: false
8+
---
9+
Support for the [rpcv2Cbor](https://smithy.io/2.0/additional-specs/protocols/smithy-rpc-v2.html) protocol has been added, allowing services to serialize RPC payloads as CBOR (Concise Binary Object Representation), improving performance and efficiency in data transmission.

codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamUnmarshallerGeneratorTest.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import org.junit.jupiter.params.provider.ArgumentsSource
1010
import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest
1111
import software.amazon.smithy.rust.codegen.core.rustlang.rust
1212
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestModels
13+
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamUnmarshallTestCases.generateRustPayloadInitializer
1314
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamUnmarshallTestCases.writeUnmarshallTestCases
1415
import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams
1516
import software.amazon.smithy.rust.codegen.core.testutil.testModule
@@ -46,7 +47,7 @@ class ClientEventStreamUnmarshallerGeneratorTest {
4647
"exception",
4748
"UnmodeledError",
4849
"${testCase.responseContentType}",
49-
br#"${testCase.validUnmodeledError}"#
50+
${testCase.generateRustPayloadInitializer(testCase.validUnmodeledError)}
5051
);
5152
let result = $generator::new().unmarshall(&message);
5253
assert!(result.is_ok(), "expected ok, got: {:?}", result);

codegen-core/build.gradle.kts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ dependencies {
2525
implementation("org.jsoup:jsoup:1.16.2")
2626
api("software.amazon.smithy:smithy-codegen-core:$smithyVersion")
2727
api("com.moandjiezana.toml:toml4j:0.7.2")
28+
implementation("com.fasterxml.jackson.dataformat:jackson-dataformat-cbor:2.13.0")
2829
implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion")
2930
implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion")
3031
implementation("software.amazon.smithy:smithy-waiters:$smithyVersion")

codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import software.amazon.smithy.model.shapes.ServiceShape
1414
import software.amazon.smithy.model.shapes.ToShapeId
1515
import software.amazon.smithy.model.traits.HttpTrait
1616
import software.amazon.smithy.model.traits.TimestampFormatTrait
17+
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
1718
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
1819
import software.amazon.smithy.rust.codegen.core.rustlang.writable
1920
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
@@ -140,9 +141,23 @@ open class RpcV2Cbor(val codegenContext: CodegenContext) : Protocol {
140141
override fun parseHttpErrorMetadata(operationShape: OperationShape): RuntimeType =
141142
RuntimeType.cborErrors(runtimeConfig).resolve("parse_error_metadata")
142143

143-
// TODO(https://github.com/smithy-lang/smithy-rs/issues/3573)
144144
override fun parseEventStreamErrorMetadata(operationShape: OperationShape): RuntimeType =
145-
TODO("rpcv2Cbor event streams have not yet been implemented")
145+
ProtocolFunctions.crossOperationFn("parse_event_stream_error_metadata") { fnName ->
146+
rustTemplate(
147+
"""
148+
pub fn $fnName(payload: &#{Bytes}) -> Result<#{ErrorMetadataBuilder}, #{DeserializeError}> {
149+
#{cbor_errors}::parse_error_metadata(0, &#{Headers}::new(), payload)
150+
}
151+
""",
152+
"cbor_errors" to RuntimeType.cborErrors(runtimeConfig),
153+
"Bytes" to RuntimeType.Bytes,
154+
"ErrorMetadataBuilder" to RuntimeType.errorMetadataBuilder(runtimeConfig),
155+
"DeserializeError" to
156+
CargoDependency.smithyCbor(runtimeConfig).toType()
157+
.resolve("decode::DeserializeError"),
158+
"Headers" to RuntimeType.headers(runtimeConfig),
159+
)
160+
}
146161

147162
// Unlike other protocols, the `rpcv2Cbor` protocol requires that `Content-Length` is always set
148163
// unless there is no input or if the operation is an event stream, see

codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingReso
4848
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation
4949
import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions
5050
import software.amazon.smithy.rust.codegen.core.util.PANIC
51-
import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE
5251
import software.amazon.smithy.rust.codegen.core.util.dq
5352
import software.amazon.smithy.rust.codegen.core.util.hasTrait
5453
import software.amazon.smithy.rust.codegen.core.util.inputShape
@@ -447,7 +446,24 @@ class CborParserGenerator(
447446
}
448447

449448
override fun payloadParser(member: MemberShape): RuntimeType {
450-
UNREACHABLE("No protocol using CBOR serialization supports payload binding")
449+
val shape = model.expectShape(member.target)
450+
val returnSymbol = returnSymbolToParse(shape)
451+
check(shape is UnionShape || shape is StructureShape) {
452+
"Payload parser should only be used on structure and union shapes."
453+
}
454+
return protocolFunctions.deserializeFn(shape, fnNameSuffix = "payload") { fnName ->
455+
rustTemplate(
456+
"""
457+
pub(crate) fn $fnName(value: &[u8]) -> #{Result}<#{ReturnType}, #{Error}> {
458+
let decoder = &mut #{Decoder}::new(value);
459+
#{DeserializeMember}
460+
}
461+
""",
462+
"ReturnType" to returnSymbol.symbol,
463+
"DeserializeMember" to deserializeMember(member),
464+
*codegenScope,
465+
)
466+
}
451467
}
452468

453469
override fun operationParser(operationShape: OperationShape): RuntimeType? {

codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ sealed class CborSerializerSection(name: String) : Section(name) {
6666
/** Manipulate the serializer context for a map prior to it being serialized. **/
6767
data class BeforeIteratingOverMapOrCollection(val shape: Shape, val context: CborSerializerGenerator.Context<Shape>) :
6868
CborSerializerSection("BeforeIteratingOverMapOrCollection")
69+
70+
/** Manipulate the serializer context for a non-null member prior to it being serialized. **/
71+
data class BeforeSerializingNonNullMember(val shape: Shape, val context: CborSerializerGenerator.MemberContext) :
72+
CborSerializerSection("BeforeSerializingNonNullMember")
6973
}
7074

7175
/**
@@ -200,9 +204,26 @@ class CborSerializerGenerator(
200204
}
201205
}
202206

203-
// TODO(https://github.com/smithy-lang/smithy-rs/issues/3573)
204207
override fun payloadSerializer(member: MemberShape): RuntimeType {
205-
TODO("We only call this when serializing in event streams, which are not supported yet: https://github.com/smithy-lang/smithy-rs/issues/3573")
208+
val target = model.expectShape(member.target)
209+
return protocolFunctions.serializeFn(member, fnNameSuffix = "payload") { fnName ->
210+
rustBlockTemplate(
211+
"pub fn $fnName(input: &#{target}) -> std::result::Result<#{Vec}<u8>, #{Error}>",
212+
*codegenScope,
213+
"target" to symbolProvider.toSymbol(target),
214+
) {
215+
rustTemplate("let mut encoder = #{Encoder}::new(#{Vec}::new());", *codegenScope)
216+
rustBlock("") {
217+
rust("let encoder = &mut encoder;")
218+
when (target) {
219+
is StructureShape -> serializeStructure(StructContext("input", target))
220+
is UnionShape -> serializeUnion(Context(ValueExpression.Reference("input"), target))
221+
else -> throw IllegalStateException("CBOR payloadSerializer only supports structs and unions")
222+
}
223+
}
224+
rustTemplate("#{Ok}(encoder.into_writer())", *codegenScope)
225+
}
226+
}
206227
}
207228

208229
override fun unsetStructure(structure: StructureShape): RuntimeType =
@@ -311,6 +332,7 @@ class CborSerializerGenerator(
311332
safeName().also { local ->
312333
rustBlock("if let Some($local) = ${context.valueExpression.asRef()}") {
313334
context.valueExpression = ValueExpression.Reference(local)
335+
resolveValueExpressionForConstrainedType(targetShape, context)
314336
serializeMemberValue(context, targetShape)
315337
}
316338
if (context.writeNulls) {
@@ -320,6 +342,7 @@ class CborSerializerGenerator(
320342
}
321343
}
322344
} else {
345+
resolveValueExpressionForConstrainedType(targetShape, context)
323346
with(serializerUtil) {
324347
ignoreDefaultsForNumbersAndBools(context.shape, context.valueExpression) {
325348
serializeMemberValue(context, targetShape)
@@ -328,6 +351,20 @@ class CborSerializerGenerator(
328351
}
329352
}
330353

354+
private fun RustWriter.resolveValueExpressionForConstrainedType(
355+
targetShape: Shape,
356+
context: MemberContext,
357+
) {
358+
for (customization in customizations) {
359+
customization.section(
360+
CborSerializerSection.BeforeSerializingNonNullMember(
361+
targetShape,
362+
context,
363+
),
364+
)(this)
365+
}
366+
}
367+
331368
private fun RustWriter.serializeMemberValue(
332369
context: MemberContext,
333370
target: Shape,
@@ -362,7 +399,7 @@ class CborSerializerGenerator(
362399
rust("$encoder;") // Encode the member key.
363400
}
364401
when (target) {
365-
is StructureShape -> serializeStructure(StructContext(value.name, target))
402+
is StructureShape -> serializeStructure(StructContext(value.asRef(), target))
366403
is CollectionShape -> serializeCollection(Context(value, target))
367404
is MapShape -> serializeMap(Context(value, target))
368405
is UnionShape -> serializeUnion(Context(value, target))

codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt

Lines changed: 61 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,27 @@
55

66
package software.amazon.smithy.rust.codegen.core.testutil
77

8+
import com.fasterxml.jackson.databind.ObjectMapper
9+
import com.fasterxml.jackson.dataformat.cbor.CBORFactory
810
import software.amazon.smithy.model.Model
911
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
1012
import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJson
1113
import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJsonVersion
1214
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
1315
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson
1416
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestXml
17+
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RpcV2Cbor
18+
import java.util.Base64
1519

1620
private fun fillInBaseModel(
17-
protocolName: String,
21+
namespacedProtocolName: String,
1822
extraServiceAnnotations: String = "",
1923
): String =
2024
"""
2125
namespace test
2226
2327
use smithy.framework#ValidationException
24-
use aws.protocols#$protocolName
28+
use $namespacedProtocolName
2529
2630
union TestUnion {
2731
Foo: String,
@@ -86,22 +90,24 @@ private fun fillInBaseModel(
8690
}
8791
8892
$extraServiceAnnotations
89-
@$protocolName
93+
@${namespacedProtocolName.substringAfter("#")}
9094
service TestService { version: "123", operations: [TestStreamOp] }
9195
"""
9296

9397
object EventStreamTestModels {
94-
private fun restJson1(): Model = fillInBaseModel("restJson1").asSmithyModel()
98+
private fun restJson1(): Model = fillInBaseModel("aws.protocols#restJson1").asSmithyModel()
9599

96-
private fun restXml(): Model = fillInBaseModel("restXml").asSmithyModel()
100+
private fun restXml(): Model = fillInBaseModel("aws.protocols#restXml").asSmithyModel()
97101

98-
private fun awsJson11(): Model = fillInBaseModel("awsJson1_1").asSmithyModel()
102+
private fun awsJson11(): Model = fillInBaseModel("aws.protocols#awsJson1_1").asSmithyModel()
103+
104+
private fun rpcv2Cbor(): Model = fillInBaseModel("smithy.protocols#rpcv2Cbor").asSmithyModel()
99105

100106
private fun awsQuery(): Model =
101-
fillInBaseModel("awsQuery", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel()
107+
fillInBaseModel("aws.protocols#awsQuery", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel()
102108

103109
private fun ec2Query(): Model =
104-
fillInBaseModel("ec2Query", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel()
110+
fillInBaseModel("aws.protocols#ec2Query", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel()
105111

106112
data class TestCase(
107113
val protocolShapeId: String,
@@ -120,39 +126,67 @@ object EventStreamTestModels {
120126
override fun toString(): String = protocolShapeId
121127
}
122128

129+
private fun base64Encode(input: ByteArray): String {
130+
val encodedBytes = Base64.getEncoder().encode(input)
131+
return String(encodedBytes)
132+
}
133+
134+
private fun createCborFromJson(jsonString: String): ByteArray {
135+
val jsonMapper = ObjectMapper()
136+
val cborMapper = ObjectMapper(CBORFactory())
137+
// Parse JSON string to a generic type.
138+
val jsonData = jsonMapper.readValue(jsonString, Any::class.java)
139+
// Convert the parsed data to CBOR.
140+
return cborMapper.writeValueAsBytes(jsonData)
141+
}
142+
143+
private val restJsonTestCase =
144+
TestCase(
145+
protocolShapeId = "aws.protocols#restJson1",
146+
model = restJson1(),
147+
mediaType = "application/json",
148+
requestContentType = "application/vnd.amazon.eventstream",
149+
responseContentType = "application/json",
150+
eventStreamMessageContentType = "application/json",
151+
validTestStruct = """{"someString":"hello","someInt":5}""",
152+
validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""",
153+
validTestUnion = """{"Foo":"hello"}""",
154+
validSomeError = """{"Message":"some error"}""",
155+
validUnmodeledError = """{"Message":"unmodeled error"}""",
156+
) { RestJson(it) }
157+
123158
val TEST_CASES =
124159
listOf(
125160
//
126161
// restJson1
127162
//
128-
TestCase(
129-
protocolShapeId = "aws.protocols#restJson1",
130-
model = restJson1(),
131-
mediaType = "application/json",
132-
requestContentType = "application/vnd.amazon.eventstream",
133-
responseContentType = "application/json",
134-
eventStreamMessageContentType = "application/json",
135-
validTestStruct = """{"someString":"hello","someInt":5}""",
136-
validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""",
137-
validTestUnion = """{"Foo":"hello"}""",
138-
validSomeError = """{"Message":"some error"}""",
139-
validUnmodeledError = """{"Message":"unmodeled error"}""",
140-
) { RestJson(it) },
163+
restJsonTestCase,
164+
//
165+
// rpcV2Cbor
166+
//
167+
restJsonTestCase.copy(
168+
protocolShapeId = "smithy.protocols#rpcv2Cbor",
169+
model = rpcv2Cbor(),
170+
mediaType = "application/cbor",
171+
responseContentType = "application/cbor",
172+
eventStreamMessageContentType = "application/cbor",
173+
validTestStruct = base64Encode(createCborFromJson(restJsonTestCase.validTestStruct)),
174+
validMessageWithNoHeaderPayloadTraits = base64Encode(createCborFromJson(restJsonTestCase.validMessageWithNoHeaderPayloadTraits)),
175+
validTestUnion = base64Encode(createCborFromJson(restJsonTestCase.validTestUnion)),
176+
validSomeError = base64Encode(createCborFromJson(restJsonTestCase.validSomeError)),
177+
validUnmodeledError = base64Encode(createCborFromJson(restJsonTestCase.validUnmodeledError)),
178+
protocolBuilder = { RpcV2Cbor(it) },
179+
),
141180
//
142181
// awsJson1_1
143182
//
144-
TestCase(
183+
restJsonTestCase.copy(
145184
protocolShapeId = "aws.protocols#awsJson1_1",
146185
model = awsJson11(),
147186
mediaType = "application/x-amz-json-1.1",
148187
requestContentType = "application/x-amz-json-1.1",
149188
responseContentType = "application/x-amz-json-1.1",
150189
eventStreamMessageContentType = "application/json",
151-
validTestStruct = """{"someString":"hello","someInt":5}""",
152-
validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""",
153-
validTestUnion = """{"Foo":"hello"}""",
154-
validSomeError = """{"Message":"some error"}""",
155-
validUnmodeledError = """{"Message":"unmodeled error"}""",
156190
) { AwsJson(it, AwsJsonVersion.Json11) },
157191
//
158192
// restXml

0 commit comments

Comments
 (0)