Skip to content

Commit 6d780ca

Browse files
authored
fix: correct bad waiter codegen caused by dropped projection scope (#848)
1 parent 2b760b2 commit 6d780ca

File tree

4 files changed

+108
-30
lines changed

4 files changed

+108
-30
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
{
2+
"id": "a0f70d1d-ab72-4d18-812d-fe72059ffb2c",
3+
"type": "bugfix",
4+
"description": "Fix incorrect waiter codegen due to dropped projection scope"
5+
}

codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/waiters/KotlinJmespathExpressionVisitor.kt

Lines changed: 41 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,14 @@ class KotlinJmespathExpressionVisitor(
5757
private val currentShape: Shape
5858
get() = shapeCursor.last()
5959

60+
// traverses an independent expression (one whose resolved scope does not persist in the outer evaluation)
6061
private fun acceptSubexpression(expr: JmespathExpression): VisitedExpression {
61-
shapeCursor.addLast(currentShape)
62+
val pos = shapeCursor.size
6263
val out = expr.accept(this)
63-
shapeCursor.removeLast()
64+
65+
val diff = shapeCursor.size - pos
66+
repeat(diff) { shapeCursor.removeLast() } // reset the shape cursor
67+
6468
return out
6569
}
6670

@@ -73,15 +77,6 @@ class KotlinJmespathExpressionVisitor(
7377
private fun bestTempVarName(preferredName: String): String =
7478
suffixSequence.map { "$preferredName$it" }.first(tempVars::add)
7579

76-
private fun childBlock(forExpression: JmespathExpression, shape: Shape): VisitedExpression {
77-
val childShape = when (val target = shape.targetOrSelf(ctx.model)) {
78-
is ListShape -> target.member
79-
is MapShape -> target.value
80-
else -> shape
81-
}
82-
return forExpression.accept(KotlinJmespathExpressionVisitor(ctx, writer, childShape))
83-
}
84-
8580
@OptIn(ExperimentalContracts::class)
8681
private fun codegenReq(condition: Boolean, lazyMessage: () -> String) {
8782
contract {
@@ -90,21 +85,25 @@ class KotlinJmespathExpressionVisitor(
9085
if (!condition) throw CodegenException(lazyMessage())
9186
}
9287

93-
private fun flatMappingBlock(right: JmespathExpression, leftName: String, leftShape: Shape): VisitedExpression {
88+
private fun flatMappingBlock(right: JmespathExpression, leftName: String, leftShape: Shape, innerShape: Shape?): VisitedExpression {
9489
if (right is CurrentExpression) return VisitedExpression(leftName, leftShape) // nothing to map
9590

9691
val outerName = bestTempVarName("projection")
97-
writer.openBlock("val #L = #L.flatMap {", outerName, leftName)
92+
val flatMapExpr = ensureNullGuard(leftShape, "flatMap")
93+
writer.openBlock("val #L = #L#L {", outerName, leftName, flatMapExpr)
94+
95+
shapeCursor.addLast(innerShape?.targetMemberOrSelf ?: leftShape.targetMemberOrSelf)
96+
val innerResult = acceptSubexpression(right)
97+
shapeCursor.removeLast()
9898

99-
val innerResult = childBlock(right, leftShape)
10099
val innerCollector = when (right) {
101100
is MultiSelectListExpression -> innerResult.identifier // Already a list
102101
else -> "listOfNotNull(${innerResult.identifier})"
103102
}
104103
writer.write(innerCollector)
105104

106105
writer.closeBlock("}")
107-
return VisitedExpression(outerName, innerResult.shape)
106+
return VisitedExpression(outerName, leftShape, innerResult.shape)
108107
}
109108

110109
private fun subfield(expression: FieldExpression, parentName: String): VisitedExpression {
@@ -151,7 +150,7 @@ class KotlinJmespathExpressionVisitor(
151150
val codegen = buildString {
152151
val nullables = buildList {
153152
if (left.shape?.isNullable == true) add("${left.identifier} == null")
154-
if (right.shape?.isNullable == true) add("${left.identifier} == null")
153+
if (right.shape?.isNullable == true) add("${right.identifier} == null")
155154
}
156155
if (nullables.isNotEmpty()) {
157156
val isNullExpr = nullables.joinToString(" || ")
@@ -176,27 +175,31 @@ class KotlinJmespathExpressionVisitor(
176175
override fun visitField(expression: FieldExpression): VisitedExpression = subfield(expression, "it")
177176

178177
override fun visitFilterProjection(expression: FilterProjectionExpression): VisitedExpression {
179-
val left = acceptSubexpression(expression.left)
178+
val left = expression.left.accept(this)
180179
requireNotNull(left.shape) { "filter projection is operating on nothing?" }
181180

182181
val filteredName = bestTempVarName("${left.identifier}Filtered")
183182

184183
val filterExpr = ensureNullGuard(left.shape, "filter")
185184
writer.withBlock("val #L = #L#L {", "}", filteredName, left.identifier, filterExpr) {
186-
val comparison = childBlock(expression.comparison, left.shape)
185+
shapeCursor.addLast(left.shape.targetMemberOrSelf)
186+
val comparison = acceptSubexpression(expression.comparison)
187+
shapeCursor.removeLast()
187188
write("#L == true", comparison.identifier)
188189
}
189190

190-
return flatMappingBlock(expression.right, filteredName, left.shape)
191+
return flatMappingBlock(expression.right, filteredName, left.shape, left.projected)
191192
}
192193

193194
override fun visitFlatten(expression: FlattenExpression): VisitedExpression {
194195
writer.addImport(RuntimeTypes.Core.Utils.flattenIfPossible)
195196

196-
val inner = acceptSubexpression(expression.expression)
197-
val flattenExpr = ensureNullGuard(inner.shape, "flattenIfPossible()", "listOf()")
197+
val inner = expression.expression.accept(this)
198+
199+
val flattenExpr = ensureNullGuard(currentShape, "flattenIfPossible()")
198200
val ident = addTempVar("${inner.identifier}OrEmpty", "${inner.identifier}$flattenExpr")
199-
return VisitedExpression(ident, inner.shape)
201+
202+
return VisitedExpression(ident, currentShape, inner.projected)
200203
}
201204

202205
override fun visitFunction(expression: FunctionExpression): VisitedExpression = when (expression.name) {
@@ -278,7 +281,7 @@ class KotlinJmespathExpressionVisitor(
278281
val valuesExpr = ensureNullGuard(left.shape, "values")
279282
val valuesName = addTempVar("${left.identifier}Values", "${left.identifier}$valuesExpr")
280283

281-
return flatMappingBlock(expression.right, valuesName, left.shape)
284+
return flatMappingBlock(expression.right, valuesName, left.shape, left.projected)
282285
}
283286

284287
override fun visitOr(expression: OrExpression): VisitedExpression {
@@ -294,26 +297,23 @@ class KotlinJmespathExpressionVisitor(
294297
}
295298

296299
override fun visitProjection(expression: ProjectionExpression): VisitedExpression {
297-
val left = acceptSubexpression(expression.left)
300+
val left = expression.left.accept(this)
298301
requireNotNull(left.shape) { "projection is operating on nothing?" }
299302

300-
return flatMappingBlock(expression.right, left.identifier, left.shape)
303+
return flatMappingBlock(expression.right, left.identifier, left.shape, left.projected)
301304
}
302305

303306
override fun visitSlice(expression: SliceExpression): VisitedExpression {
304307
throw CodegenException("SliceExpression is unsupported")
305308
}
306309

307310
override fun visitSubexpression(expression: Subexpression): VisitedExpression {
308-
val left = acceptSubexpression(expression.left)
309-
requireNotNull(left.shape)
311+
val left = expression.left.accept(this)
310312

311-
shapeCursor.addLast(left.shape)
312313
val ret = when (val right = expression.right) {
313314
is FieldExpression -> subfield(right, left.identifier)
314315
else -> throw CodegenException("Subexpression type $right is unsupported")
315316
}
316-
shapeCursor.removeLast()
317317

318318
return ret
319319
}
@@ -338,6 +338,13 @@ class KotlinJmespathExpressionVisitor(
338338
get() = this is MemberShape &&
339339
ctx.model.expectShape(target).let { !it.hasTrait<OperationInput>() && !it.hasTrait<OperationOutput>() } &&
340340
nullableIndex.isMemberNullable(this, NullableIndex.CheckMode.CLIENT_ZERO_VALUE_V1_NO_INPUT)
341+
342+
private val Shape.targetMemberOrSelf: Shape
343+
get() = when (val target = targetOrSelf(ctx.model)) {
344+
is ListShape -> target.member
345+
is MapShape -> target.value
346+
else -> this
347+
}
341348
}
342349

343350
/**
@@ -346,5 +353,9 @@ class KotlinJmespathExpressionVisitor(
346353
* @param shape The underlying shape (if any) that the identifier represents. Not all expressions reference a modeled
347354
* shape, e.g. [LiteralExpression] (the value is just a literal) or [FunctionExpression]s where the
348355
* returned value is scalar.
356+
* @param projected For projections, the context of the inner shape. For example, given the expression
357+
* `foo[].bar[].baz.qux`, the shape that backs the identifier (and therefore determines overall nullability)
358+
* is `foo`, but the shape that needs carried through to subfield expressions in the following projection
359+
* is the target of `bar`, such that its subfields `baz` and `qux` can be recognized.
349360
*/
350-
data class VisitedExpression(val identifier: String, val shape: Shape? = null)
361+
data class VisitedExpression(val identifier: String, val shape: Shape? = null, val projected: Shape? = null)

tests/codegen/waiter-tests/model/waiter-operations.smithy

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -729,6 +729,20 @@ service WaitersTestService {
729729
}
730730
]
731731
},
732+
HasFilteredSubStruct: {
733+
acceptors: [
734+
{
735+
state: "success",
736+
matcher: {
737+
output: {
738+
path: "lists.structs[].subStructs[?subStructPrimitives.integer > `0`][].subStructPrimitives.string"
739+
expected: "foo",
740+
comparator: "anyStringEquals"
741+
}
742+
}
743+
}
744+
]
745+
},
732746
)
733747
@readonly
734748
@http(method: "GET", uri: "/entities/{name}", code: 200)

tests/codegen/waiter-tests/src/test/kotlin/com/test/WaiterTest.kt

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -785,4 +785,52 @@ class WaiterTest {
785785
}
786786
},
787787
)
788+
789+
@Test fun testHasFilteredSubStruct() = successTest(
790+
WaitersTestClient::waitUntilHasFilteredSubStruct,
791+
GetEntityResponse { },
792+
GetEntityResponse { lists = EntityLists { } },
793+
GetEntityResponse {
794+
lists = EntityLists { structs = listOf() }
795+
},
796+
GetEntityResponse {
797+
lists = EntityLists {
798+
structs = listOf(
799+
Struct { subStructs = listOf(SubStruct { }) },
800+
)
801+
}
802+
},
803+
GetEntityResponse {
804+
lists = EntityLists {
805+
structs = listOf(
806+
Struct { subStructs = listOf(SubStruct { subStructPrimitives = EntityPrimitives { string = "foo" } }) },
807+
)
808+
}
809+
},
810+
GetEntityResponse {
811+
lists = EntityLists {
812+
structs = listOf(
813+
Struct { subStructs = listOf(SubStruct { subStructPrimitives = EntityPrimitives { string = "foo"; integer = -1 } }) },
814+
Struct { subStructs = listOf(SubStruct { subStructPrimitives = EntityPrimitives { string = "bar"; integer = 2 } }) },
815+
)
816+
}
817+
},
818+
GetEntityResponse {
819+
lists = EntityLists {
820+
structs = listOf(
821+
Struct {
822+
subStructs = listOf(
823+
SubStruct { subStructPrimitives = EntityPrimitives { string = "foo"; integer = -1 } },
824+
SubStruct { subStructPrimitives = EntityPrimitives { string = "bar"; integer = 2 } },
825+
)
826+
},
827+
Struct {
828+
subStructs = listOf(
829+
SubStruct { subStructPrimitives = EntityPrimitives { string = "foo"; integer = 2 } },
830+
)
831+
},
832+
)
833+
}
834+
},
835+
)
788836
}

0 commit comments

Comments
 (0)