Skip to content

Commit f5ec52e

Browse files
authored
fix: Adds a Hashable trait to structs that need to be Hashable (#268)
1 parent bd2c5db commit f5ec52e

File tree

6 files changed

+268
-1
lines changed

6 files changed

+268
-1
lines changed

smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/CodegenVisitor.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class CodegenVisitor(context: PluginContext) : ShapeVisitor.Default<Void>() {
5151
// Add operation input/output shapes if not provided for future evolution of sdk
5252
resolvedModel = AddOperationShapes.execute(resolvedModel, settings.getService(resolvedModel), settings.moduleName)
5353
resolvedModel = RecursiveShapeBoxer.transform(resolvedModel)
54+
resolvedModel = HashableShapeTransformer.transform(resolvedModel)
5455
model = resolvedModel
5556

5657
service = settings.getService(model)
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
package software.amazon.smithy.swift.codegen
2+
3+
import software.amazon.smithy.model.Model
4+
import software.amazon.smithy.model.loader.Prelude
5+
import software.amazon.smithy.model.neighbor.RelationshipType
6+
import software.amazon.smithy.model.neighbor.Walker
7+
import software.amazon.smithy.model.shapes.Shape
8+
import software.amazon.smithy.model.shapes.StructureShape
9+
import software.amazon.smithy.model.transform.ModelTransformer
10+
import software.amazon.smithy.swift.codegen.model.hasTrait
11+
12+
object HashableShapeTransformer {
13+
14+
fun transform(model: Model): Model {
15+
val next = transformInner(model)
16+
return if (next == null) {
17+
model
18+
} else {
19+
transform(next)
20+
}
21+
}
22+
23+
private fun transformInner(model: Model): Model? {
24+
// find all the shapes in this models shapes that have a struct shape contained within a set and don't already have the trait
25+
val allShapesNeedingHashable = mutableSetOf<Shape>()
26+
model.shapes().filter { needsHashableTrait(model, it) }
27+
.forEach { allShapesNeedingHashable.add(it) }
28+
// find all the other shapes referencing that shape and mark with hashable.
29+
allShapesNeedingHashable.addAll(getNestedTypesNeedingHashable(model, allShapesNeedingHashable))
30+
31+
if (allShapesNeedingHashable.isEmpty()) {
32+
return null
33+
}
34+
35+
return ModelTransformer.create().mapShapes(model) { shape ->
36+
if (allShapesNeedingHashable.contains(shape)) {
37+
shape.asStructureShape().get().toBuilder().addTrait(HashableTrait()).build()
38+
} else {
39+
shape
40+
}
41+
}
42+
}
43+
44+
private fun getNestedTypesNeedingHashable(model: Model, shapes: Set<Shape>): Set<Shape> {
45+
val nestedTypes = mutableSetOf<Shape>()
46+
val walker = Walker(model)
47+
// walk all the shapes in the set and find all other
48+
// structs in the graph from that shape that are nested
49+
shapes.forEach { shape ->
50+
walker.iterateShapes(shape) { relationship ->
51+
when (relationship.relationshipType) {
52+
RelationshipType.STRUCTURE_MEMBER,
53+
RelationshipType.MEMBER_TARGET -> true
54+
else -> false
55+
}
56+
}.forEach {
57+
if (it is StructureShape && !it.hasTrait<HashableTrait>()) {
58+
nestedTypes.add(it)
59+
}
60+
}
61+
}
62+
return nestedTypes
63+
}
64+
65+
private fun needsHashableTrait(model: Model, shape: Shape): Boolean {
66+
return if (shape is StructureShape && !Prelude.isPreludeShape(shape)) {
67+
val allCollectionShapes = model.setShapes.filter { !Prelude.isPreludeShape(it) }
68+
allCollectionShapes.any { it.member.target == shape.toShapeId() } && !shape.hasTrait<HashableTrait>()
69+
} else {
70+
false
71+
}
72+
}
73+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
package software.amazon.smithy.swift.codegen
2+
3+
import software.amazon.smithy.model.node.Node
4+
import software.amazon.smithy.model.shapes.ShapeId
5+
import software.amazon.smithy.model.traits.Trait
6+
7+
class HashableTrait : Trait {
8+
val ID = ShapeId.from("software.amazon.smithy.swift.codegen.swift.synthetic#hashable")
9+
override fun toNode(): Node = Node.objectNode()
10+
11+
override fun toShapeId(): ShapeId = ID
12+
}

smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/StructureGenerator.kt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import software.amazon.smithy.model.traits.IdempotencyTokenTrait
1717
import software.amazon.smithy.model.traits.RetryableTrait
1818
import software.amazon.smithy.swift.codegen.integration.ProtocolGenerator
1919
import software.amazon.smithy.swift.codegen.model.getTrait
20+
import software.amazon.smithy.swift.codegen.model.hasTrait
2021
import software.amazon.smithy.swift.codegen.model.isError
2122

2223
fun MemberShape.isRecursiveMember(index: TopologicalIndex): Boolean {
@@ -97,7 +98,8 @@ class StructureGenerator(
9798
private fun renderNonErrorStructure() {
9899
writer.writeShapeDocs(shape)
99100
writer.writeAvailableAttribute(model, shape)
100-
writer.openBlock("public struct \$struct.name:L: Equatable {")
101+
val needsHashable = if (shape.hasTrait<HashableTrait>()) ", Hashable" else ""
102+
writer.openBlock("public struct \$struct.name:L: Equatable$needsHashable {")
101103
.call { generateStructMembers() }
102104
.write("")
103105
.call { generateInitializerForStructure() }
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
import io.kotest.matchers.string.shouldContain
2+
import org.junit.jupiter.api.Assertions
3+
import org.junit.jupiter.api.Test
4+
import software.amazon.smithy.build.MockManifest
5+
import software.amazon.smithy.model.shapes.ShapeId
6+
import software.amazon.smithy.swift.codegen.HashableShapeTransformer
7+
import software.amazon.smithy.swift.codegen.HashableTrait
8+
import software.amazon.smithy.swift.codegen.SwiftCodegenPlugin
9+
import software.amazon.smithy.swift.codegen.model.hasTrait
10+
import kotlin.streams.toList
11+
12+
class HashableShapeTransformerTests {
13+
14+
@Test
15+
fun `leave non-hashable models unchanged`() {
16+
val model = javaClass.getResource("simple-service-with-operation-and-dependency.smithy").asSmithy()
17+
val transformed = HashableShapeTransformer.transform(model)
18+
transformed.shapes().toList().forEach {
19+
Assertions.assertFalse(transformed.getShape(it.id).get().hasTrait<HashableTrait>())
20+
}
21+
}
22+
23+
@Test
24+
fun `add the hashable trait to hashable shapes`() {
25+
val model = javaClass.getResource("hashable-trait-test.smithy").asSmithy()
26+
val transformed = HashableShapeTransformer.transform(model)
27+
28+
val traitedMember = "smithy.example#HashableStructure"
29+
val traitedMemberShape = transformed.getShape(ShapeId.from(traitedMember)).get()
30+
31+
Assertions.assertTrue(traitedMemberShape.hasTrait<HashableTrait>())
32+
}
33+
34+
@Test
35+
fun `test for nested types with hashable trait`() {
36+
val model = javaClass.getResource("hashable-trait-test.smithy").asSmithy()
37+
val transformed = HashableShapeTransformer.transform(model)
38+
39+
val traitedMember2 = "smithy.example#NestedHashableStructure"
40+
val traitedMemberShape2 = transformed.getShape(ShapeId.from(traitedMember2)).get()
41+
42+
Assertions.assertTrue(traitedMemberShape2.hasTrait<HashableTrait>())
43+
}
44+
45+
@Test
46+
fun `test that certain types do not receive the trait`() {
47+
val model = javaClass.getResource("hashable-trait-test.smithy").asSmithy()
48+
val transformed = HashableShapeTransformer.transform(model)
49+
50+
val untraitedMember = "smithy.example#HashableInput"
51+
val untraitedMemberShape = transformed.getShape(ShapeId.from(untraitedMember)).get()
52+
53+
Assertions.assertFalse(untraitedMemberShape.hasTrait<HashableTrait>())
54+
}
55+
56+
@Test
57+
fun `add the hashable trait to hashable shapes during integration with SwiftCodegenPlugin`() {
58+
val model = javaClass.getResource("hashable-trait-test.smithy").asSmithy()
59+
val manifest = MockManifest()
60+
val context = buildMockPluginContext(model, manifest, "smithy.example#Example")
61+
SwiftCodegenPlugin().execute(context)
62+
63+
val hashableShapeInput = manifest
64+
.getFileString("example/models/HashableShapesInput.swift").get()
65+
Assertions.assertNotNull(hashableShapeInput)
66+
val expected = """
67+
public struct HashableShapesInput: Equatable {
68+
public let `set`: Set<HashableStructure>?
69+
public let bar: String?
70+
71+
public init (
72+
`set`: Set<HashableStructure>? = nil,
73+
bar: String? = nil
74+
)
75+
{
76+
self.`set` = `set`
77+
self.bar = bar
78+
}
79+
}
80+
""".trimIndent()
81+
hashableShapeInput.shouldContain(expected)
82+
83+
val hashableShapeOutput = manifest
84+
.getFileString("example/models/HashableShapesOutputResponse.swift").get()
85+
Assertions.assertNotNull(hashableShapeOutput)
86+
val expectedOutput = """
87+
public struct HashableShapesOutputResponse: Equatable {
88+
public let quz: String?
89+
90+
public init (
91+
quz: String? = nil
92+
)
93+
{
94+
self.quz = quz
95+
}
96+
}
97+
""".trimIndent()
98+
hashableShapeOutput.shouldContain(expectedOutput)
99+
100+
val hashableSetShape = manifest
101+
.getFileString("example/models/HashableStructure.swift").get()
102+
Assertions.assertNotNull(hashableSetShape)
103+
val expectedStructureShape = """
104+
public struct HashableStructure: Equatable, Hashable {
105+
public let baz: NestedHashableStructure?
106+
public let foo: String?
107+
108+
public init (
109+
baz: NestedHashableStructure? = nil,
110+
foo: String? = nil
111+
)
112+
{
113+
self.baz = baz
114+
self.foo = foo
115+
}
116+
}
117+
""".trimIndent()
118+
hashableSetShape.shouldContain(expectedStructureShape)
119+
120+
val hashableNestedStructure = manifest
121+
.getFileString("example/models/NestedHashableStructure.swift").get()
122+
Assertions.assertNotNull(hashableNestedStructure)
123+
val expectedNestedStructureShape = """
124+
public struct NestedHashableStructure: Equatable, Hashable {
125+
public let bar: String?
126+
public let quz: Int?
127+
128+
public init (
129+
bar: String? = nil,
130+
quz: Int? = nil
131+
)
132+
{
133+
self.bar = bar
134+
self.quz = quz
135+
}
136+
}
137+
""".trimIndent()
138+
hashableNestedStructure.shouldContain(expectedNestedStructureShape)
139+
}
140+
}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
$version: "1.0"
2+
namespace smithy.example
3+
4+
use aws.protocols#restJson1
5+
6+
service Example {
7+
version: "1.0.0",
8+
operations: [
9+
HashableShapes
10+
]
11+
}
12+
13+
operation HashableShapes {
14+
input: HashableInput,
15+
output: HashableOutput
16+
}
17+
18+
structure HashableInput {
19+
bar: String,
20+
set: HashableSet
21+
}
22+
23+
set HashableSet {
24+
member: HashableStructure
25+
}
26+
27+
structure HashableStructure {
28+
foo: String,
29+
baz: NestedHashableStructure
30+
}
31+
32+
structure NestedHashableStructure {
33+
bar: String,
34+
quz: Integer
35+
}
36+
37+
structure HashableOutput {
38+
quz: String
39+
}

0 commit comments

Comments
 (0)