Skip to content

Commit 0fa94c2

Browse files
authored
fix: locate event stream member more carefully (#1623)
1 parent 90d03e4 commit 0fa94c2

File tree

3 files changed

+150
-10
lines changed

3 files changed

+150
-10
lines changed

smithy-typescript-codegen/src/main/java/software/amazon/smithy/typescript/codegen/integration/EventStreamGenerator.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import software.amazon.smithy.model.traits.ErrorTrait;
4040
import software.amazon.smithy.model.traits.EventHeaderTrait;
4141
import software.amazon.smithy.model.traits.EventPayloadTrait;
42+
import software.amazon.smithy.model.traits.HttpPayloadTrait;
4243
import software.amazon.smithy.model.traits.StreamingTrait;
4344
import software.amazon.smithy.typescript.codegen.TypeScriptDependency;
4445
import software.amazon.smithy.typescript.codegen.TypeScriptWriter;
@@ -82,6 +83,26 @@ public static UnionShape getEventStreamOutputShape(GenerationContext context, Op
8283
return eventStreamInfo.getEventStreamTarget().asUnionShape().get();
8384
}
8485

86+
public static MemberShape getEventStreamMember(GenerationContext context, StructureShape struct) {
87+
List<MemberShape> eventStreamMembers = struct.members()
88+
.stream()
89+
.filter(shape -> {
90+
Shape target = context.getModel().expectShape(shape.getTarget());
91+
boolean targetStreaming = target.hasTrait(StreamingTrait.class);
92+
boolean targetUnion = target.isUnionShape();
93+
boolean memberStreaming = shape.hasTrait(StreamingTrait.class);
94+
boolean memberPayload = shape.hasTrait(HttpPayloadTrait.class);
95+
return memberPayload && targetUnion && (targetStreaming || memberStreaming);
96+
}).toList();
97+
98+
if (eventStreamMembers.isEmpty()) {
99+
throw new CodegenException("No event stream member found in " + struct.getId().toString());
100+
} else if (eventStreamMembers.size() > 1) {
101+
throw new CodegenException("More than one event stream member in " + struct.getId().toString());
102+
}
103+
return eventStreamMembers.get(0);
104+
}
105+
85106
/**
86107
* Generate eventstream serializers, and related serializers for events.
87108
* @param context Code generation context instance.

smithy-typescript-codegen/src/main/java/software/amazon/smithy/typescript/codegen/integration/HttpRpcProtocolGenerator.java

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import java.util.Set;
1919
import java.util.TreeSet;
2020
import java.util.logging.Logger;
21-
import java.util.stream.Collectors;
2221
import software.amazon.smithy.codegen.core.Symbol;
2322
import software.amazon.smithy.codegen.core.SymbolProvider;
2423
import software.amazon.smithy.codegen.core.SymbolReference;
@@ -359,12 +358,13 @@ protected boolean writeRequestBody(GenerationContext context, OperationShape ope
359358
// Write the default `body` property.
360359
writer.write("let body: any;");
361360
if (EventStreamGenerator.hasEventStreamInput(context, operation)) {
362-
// There must only one eventstream member in request structure.
363-
MemberShape member = inputShape.members().stream().collect(Collectors.toList()).get(0);
364-
Shape target = context.getModel().expectShape(member.getTarget());
361+
MemberShape eventStreamMember = EventStreamGenerator.getEventStreamMember(
362+
context, inputShape
363+
);
364+
Shape target = context.getModel().expectShape(eventStreamMember.getTarget());
365365
Symbol targetSymbol = context.getSymbolProvider().toSymbol(target);
366366
String serFunctionName = ProtocolGenerator.getSerFunctionShortName(targetSymbol);
367-
String memberName = member.getMemberName();
367+
String memberName = eventStreamMember.getMemberName();
368368
writer.write("body = $L(input.$L, context);", serFunctionName, memberName);
369369
} else {
370370
// Track input shapes so their serializers may be generated.
@@ -540,12 +540,16 @@ protected void readResponseBody(GenerationContext context, OperationShape operat
540540
// If there's an output present, we know it's a structure.
541541
StructureShape outputShape = context.getModel().expectShape(outputId).asStructureShape().get();
542542
if (EventStreamGenerator.hasEventStreamOutput(context, operation)) {
543-
// There must only one eventstream member in response structure.
544-
MemberShape member = outputShape.members().stream().collect(Collectors.toList()).get(0);
545-
Shape target = context.getModel().expectShape(member.getTarget());
543+
MemberShape eventStreamMember = EventStreamGenerator.getEventStreamMember(
544+
context, outputShape
545+
);
546+
Shape target = context.getModel().expectShape(eventStreamMember.getTarget());
546547
Symbol targetSymbol = context.getSymbolProvider().toSymbol(target);
547-
writer.write("const contents = { $L: $L(output.body, context) };", member.getMemberName(),
548-
ProtocolGenerator.getDeserFunctionShortName(targetSymbol));
548+
writer.write(
549+
"const contents = { $L: $L(output.body, context) };",
550+
eventStreamMember.getMemberName(),
551+
ProtocolGenerator.getDeserFunctionShortName(targetSymbol)
552+
);
549553
} else {
550554
// We only need to load the body and prepare a contents object if there is a response.
551555
writer.write("const data: any = await parseBody(output.body, context)");
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
package software.amazon.smithy.typescript.codegen.integration;
2+
3+
import java.util.List;
4+
import org.junit.jupiter.api.Test;
5+
import org.junit.jupiter.api.extension.ExtendWith;
6+
import org.mockito.Mock;
7+
import org.mockito.junit.jupiter.MockitoExtension;
8+
import software.amazon.smithy.codegen.core.CodegenException;
9+
import software.amazon.smithy.model.Model;
10+
import software.amazon.smithy.model.shapes.MemberShape;
11+
import software.amazon.smithy.model.shapes.ShapeId;
12+
import software.amazon.smithy.model.shapes.StructureShape;
13+
import software.amazon.smithy.model.shapes.UnionShape;
14+
import software.amazon.smithy.model.traits.HttpPayloadTrait;
15+
import software.amazon.smithy.model.traits.StreamingTrait;
16+
17+
import static org.junit.jupiter.api.Assertions.*;
18+
import static org.mockito.Mockito.when;
19+
20+
@ExtendWith(MockitoExtension.class)
21+
class EventStreamGeneratorTest {
22+
@Test
23+
void getEventStreamMember(
24+
@Mock ProtocolGenerator.GenerationContext context,
25+
@Mock Model model,
26+
@Mock StructureShape struct,
27+
@Mock MemberShape eventStreamMember1,
28+
@Mock ShapeId streamingMember1ShapeId,
29+
@Mock UnionShape streamingTarget1
30+
) {
31+
when(struct.members()).thenReturn(List.of(eventStreamMember1));
32+
when(eventStreamMember1.getTarget()).thenReturn(streamingMember1ShapeId);
33+
when(context.getModel()).thenReturn(model);
34+
when(model.expectShape(streamingMember1ShapeId)).thenReturn(streamingTarget1);
35+
36+
when(streamingTarget1.hasTrait(StreamingTrait.class)).thenReturn(true);
37+
when(streamingTarget1.isUnionShape()).thenReturn(true);
38+
when(eventStreamMember1.hasTrait(StreamingTrait.class)).thenReturn(false);
39+
when(eventStreamMember1.hasTrait(HttpPayloadTrait.class)).thenReturn(true);
40+
41+
MemberShape eventStreamMember = EventStreamGenerator.getEventStreamMember(
42+
context,
43+
struct
44+
);
45+
46+
assertEquals(eventStreamMember1, eventStreamMember);
47+
}
48+
49+
@Test
50+
void getEventStreamMemberTooFew(
51+
@Mock ProtocolGenerator.GenerationContext context,
52+
@Mock StructureShape struct
53+
) {
54+
when(struct.members()).thenReturn(List.of());
55+
when(struct.getId()).thenReturn(ShapeId.from("namespace#Shape"));
56+
57+
try {
58+
MemberShape eventStreamMember = EventStreamGenerator.getEventStreamMember(
59+
context,
60+
struct
61+
);
62+
} catch (CodegenException e) {
63+
assertEquals(
64+
"No event stream member found in " + struct.getId().toString(),
65+
e.getMessage()
66+
);
67+
}
68+
}
69+
70+
@Test
71+
void getEventStreamMemberTooMany(
72+
@Mock ProtocolGenerator.GenerationContext context,
73+
@Mock Model model,
74+
@Mock StructureShape struct,
75+
@Mock MemberShape eventStreamMember1,
76+
@Mock ShapeId streamingMember1ShapeId,
77+
@Mock UnionShape streamingTarget1,
78+
@Mock MemberShape eventStreamMember2,
79+
@Mock ShapeId streamingMember2ShapeId,
80+
@Mock UnionShape streamingTarget2
81+
) {
82+
when(struct.members()).thenReturn(List.of(
83+
eventStreamMember1,
84+
eventStreamMember2
85+
));
86+
when(context.getModel()).thenReturn(model);
87+
when(struct.getId()).thenReturn(ShapeId.from("namespace#Shape"));
88+
89+
when(eventStreamMember1.getTarget()).thenReturn(streamingMember1ShapeId);
90+
when(model.expectShape(streamingMember1ShapeId)).thenReturn(streamingTarget1);
91+
when(streamingTarget1.hasTrait(StreamingTrait.class)).thenReturn(true);
92+
when(streamingTarget1.isUnionShape()).thenReturn(true);
93+
when(eventStreamMember1.hasTrait(StreamingTrait.class)).thenReturn(false);
94+
when(eventStreamMember1.hasTrait(HttpPayloadTrait.class)).thenReturn(true);
95+
96+
when(eventStreamMember2.getTarget()).thenReturn(streamingMember2ShapeId);
97+
when(model.expectShape(streamingMember2ShapeId)).thenReturn(streamingTarget2);
98+
when(streamingTarget2.hasTrait(StreamingTrait.class)).thenReturn(true);
99+
when(streamingTarget2.isUnionShape()).thenReturn(true);
100+
when(eventStreamMember2.hasTrait(StreamingTrait.class)).thenReturn(false);
101+
when(eventStreamMember2.hasTrait(HttpPayloadTrait.class)).thenReturn(true);
102+
103+
try {
104+
MemberShape eventStreamMember = EventStreamGenerator.getEventStreamMember(
105+
context,
106+
struct
107+
);
108+
} catch (CodegenException e) {
109+
assertEquals(
110+
"More than one event stream member in " + struct.getId().toString(),
111+
e.getMessage()
112+
);
113+
}
114+
}
115+
}

0 commit comments

Comments
 (0)