Skip to content

Commit 031b50c

Browse files
authored
feat: add functions (sort, sort_by, max_by, min_by, map) to JMESPath visitor (#969)
1 parent bc522c6 commit 031b50c

File tree

14 files changed

+509
-30
lines changed

14 files changed

+509
-30
lines changed

codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/waiters/KotlinJmespathExpressionVisitor.kt

Lines changed: 68 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,6 @@ import software.amazon.smithy.model.shapes.ListShape
2626
import software.amazon.smithy.model.shapes.MapShape
2727
import software.amazon.smithy.model.shapes.MemberShape
2828
import software.amazon.smithy.model.shapes.Shape
29-
import kotlin.contracts.ExperimentalContracts
30-
import kotlin.contracts.contract
3129

3230
private val suffixSequence = sequenceOf("") + generateSequence(2) { it + 1 }.map(Int::toString) // "", "2", "3", etc.
3331

@@ -76,14 +74,6 @@ class KotlinJmespathExpressionVisitor(
7674
private fun bestTempVarName(preferredName: String): String =
7775
suffixSequence.map { "$preferredName$it" }.first(tempVars::add)
7876

79-
@OptIn(ExperimentalContracts::class)
80-
private fun codegenReq(condition: Boolean, lazyMessage: () -> String) {
81-
contract {
82-
returns() implies condition
83-
}
84-
if (!condition) throw CodegenException(lazyMessage())
85-
}
86-
8777
private fun flatMappingBlock(right: JmespathExpression, leftName: String, leftShape: Shape, innerShape: Shape?): VisitedExpression {
8878
if (right is CurrentExpression) return VisitedExpression(leftName, leftShape) // nothing to map
8979

@@ -105,7 +95,9 @@ class KotlinJmespathExpressionVisitor(
10595
return VisitedExpression(outerName, leftShape, innerResult.shape)
10696
}
10797

108-
private fun subfield(expression: FieldExpression, parentName: String, isObject: Boolean = false): VisitedExpression {
98+
private data class SubFieldData(val name: String, val codegen: String, val member: Shape?)
99+
100+
private fun subfieldLogic(expression: FieldExpression, parentName: String, isObject: Boolean = false): SubFieldData {
109101
val member = currentShape.targetOrSelf(ctx.model).getMember(expression.name).getOrNull()
110102

111103
val name = expression.name.toCamelCase()
@@ -130,9 +122,17 @@ class KotlinJmespathExpressionVisitor(
130122
}
131123

132124
member?.let { shapeCursor.addLast(it) }
133-
return VisitedExpression(addTempVar(name, codegen), member)
125+
return SubFieldData(name, codegen, member)
126+
}
127+
128+
private fun subfield(expression: FieldExpression, parentName: String, isObject: Boolean = false): VisitedExpression {
129+
val (name, codegen, member) = subfieldLogic(expression, parentName, isObject)
130+
return VisitedExpression(addTempVar(name, codegen), member, nullable = currentShape.isNullable)
134131
}
135132

133+
private fun subfieldCodegen(expression: FieldExpression, parentName: String, isObject: Boolean = false): String =
134+
subfieldLogic(expression, parentName, isObject).codegen
135+
136136
override fun visitAnd(expression: AndExpression): VisitedExpression {
137137
writer.addImport(RuntimeTypes.Core.Utils.truthiness)
138138

@@ -151,22 +151,21 @@ class KotlinJmespathExpressionVisitor(
151151

152152
val codegen = buildString {
153153
val nullables = buildList {
154-
if (left.shape?.isNullable == true) add("${left.identifier} == null")
155-
if (right.shape?.isNullable == true) add("${right.identifier} == null")
154+
if (left.shape?.isNullable == true || left.nullable) add("${left.identifier} == null")
155+
if (right.shape?.isNullable == true || right.nullable) add("${right.identifier} == null")
156156
}
157+
157158
if (nullables.isNotEmpty()) {
158159
val isNullExpr = nullables.joinToString(" || ")
159160
append("if ($isNullExpr) null else ")
160161
}
161162

162-
val unSafeComparatorExpr = "compareTo(${right.identifier}) ${expression.comparator} 0"
163-
val comparatorExpr = if (left.nullable) "?.$unSafeComparatorExpr" else ".$unSafeComparatorExpr"
164-
163+
val comparatorExpr = ".compareTo(${right.identifier}) ${expression.comparator} 0"
165164
append("${left.identifier}$comparatorExpr")
166165
}
167166

168-
val ident = addTempVar("comparison", codegen)
169-
return VisitedExpression(ident)
167+
val identifier = addTempVar("comparison", codegen)
168+
return VisitedExpression(identifier)
170169
}
171170

172171
override fun visitCurrentNode(expression: CurrentExpression): VisitedExpression {
@@ -207,14 +206,11 @@ class KotlinJmespathExpressionVisitor(
207206
return VisitedExpression(ident, currentShape, inner.projected)
208207
}
209208

210-
private fun FunctionExpression.singleArg(): VisitedExpression {
211-
codegenReq(arguments.size == 1) { "Unexpected number of arguments to $this" }
212-
return acceptSubexpression(this.arguments[0])
213-
}
214-
private fun FunctionExpression.twoArgs(): Pair<VisitedExpression, VisitedExpression> {
215-
codegenReq(arguments.size == 2) { "Unexpected number of arguments to $this" }
216-
return acceptSubexpression(this.arguments[0]) to acceptSubexpression(this.arguments[1])
217-
}
209+
private fun FunctionExpression.singleArg(): VisitedExpression =
210+
acceptSubexpression(this.arguments[0])
211+
212+
private fun FunctionExpression.twoArgs(): Pair<VisitedExpression, VisitedExpression> =
213+
acceptSubexpression(this.arguments[0]) to acceptSubexpression(this.arguments[1])
218214

219215
private fun FunctionExpression.args(): List<VisitedExpression> =
220216
this.arguments.map { acceptSubexpression(it) }
@@ -257,9 +253,7 @@ class KotlinJmespathExpressionVisitor(
257253

258254
"avg" -> {
259255
val numbers = expression.singleArg()
260-
val average = numbers.dotFunction(expression, "average()").identifier
261-
val isNaN = ensureNullGuard(numbers.shape, "isNaN() == true")
262-
VisitedExpression(addTempVar("averageOrNull", "if($average$isNaN) null else $average"), nullable = true)
256+
numbers.dotFunction(expression, "average()")
263257
}
264258

265259
"join" -> {
@@ -334,6 +328,35 @@ class KotlinJmespathExpressionVisitor(
334328
arg.dotFunction(expression, "type()", ensureNullGuard = false)
335329
}
336330

331+
"sort" -> {
332+
val arg = expression.singleArg()
333+
arg.dotFunction(expression, "sorted()")
334+
}
335+
336+
"sort_by" -> {
337+
val list = expression.arguments[0].accept(this)
338+
val expressionValue = expression.arguments[1]
339+
list.applyFunction(expression.name.toCamelCase(), "sortedBy", expressionValue)
340+
}
341+
342+
"max_by" -> {
343+
val list = expression.arguments[0].accept(this)
344+
val expressionValue = expression.arguments[1]
345+
list.applyFunction(expression.name.toCamelCase(), "maxBy", expressionValue)
346+
}
347+
348+
"min_by" -> {
349+
val list = expression.arguments[0].accept(this)
350+
val expressionValue = expression.arguments[1]
351+
list.applyFunction(expression.name.toCamelCase(), "minBy", expressionValue)
352+
}
353+
354+
"map" -> {
355+
val list = expression.arguments[1].accept(this)
356+
val expressionValue = expression.arguments[0]
357+
list.applyFunction(expression.name.toCamelCase(), "map", expressionValue)
358+
}
359+
337360
else -> throw CodegenException("Unknown function type in $expression")
338361
}
339362

@@ -548,6 +571,21 @@ class KotlinJmespathExpressionVisitor(
548571
return notNull
549572
}
550573

574+
private fun VisitedExpression.applyFunction(
575+
name: String,
576+
operation: String,
577+
expression: JmespathExpression,
578+
): VisitedExpression {
579+
val result = bestTempVarName(name)
580+
581+
writer.withBlock("val $result = ${this.identifier}?.$operation {", "}") {
582+
val expressionValue = subfieldCodegen((expression as ExpressionTypeExpression).expression as FieldExpression, "it")
583+
write("$expressionValue!!")
584+
}
585+
586+
return VisitedExpression(result)
587+
}
588+
551589
private val Shape.isNullable: Boolean
552590
get() = this is MemberShape &&
553591
ctx.model.expectShape(target).let { !it.hasTrait<OperationInput>() && !it.hasTrait<OperationOutput>() }
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
$version: "2"
2+
namespace com.test
3+
4+
use smithy.waiters#waitable
5+
6+
@suppress(["WaitableTraitJmespathProblem"])
7+
@waitable(
8+
MapStructEquals: {
9+
acceptors: [
10+
{
11+
state: "success",
12+
matcher: {
13+
output: {
14+
path: "map(&string, lists.structs)",
15+
expected: "foo",
16+
comparator: "allStringEquals"
17+
}
18+
}
19+
}
20+
]
21+
},
22+
)
23+
@readonly
24+
@http(method: "GET", uri: "/map/{name}", code: 200)
25+
operation GetFunctionMapEquals {
26+
input: GetEntityRequest,
27+
output: GetEntityResponse,
28+
errors: [NotFound],
29+
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
$version: "2"
2+
namespace com.test
3+
4+
use smithy.waiters#waitable
5+
6+
@suppress(["WaitableTraitJmespathProblem"])
7+
@waitable(
8+
MaxByNumberEquals: {
9+
acceptors: [
10+
{
11+
state: "success",
12+
matcher: {
13+
output: {
14+
path: "max_by(lists.structs, &integer).integer == `100`",
15+
expected: "true",
16+
comparator: "booleanEquals"
17+
}
18+
}
19+
}
20+
]
21+
},
22+
MaxByStringEquals: {
23+
acceptors: [
24+
{
25+
state: "success",
26+
matcher: {
27+
output: {
28+
path: "max_by(lists.structs, &string).string",
29+
expected: "foo",
30+
comparator: "stringEquals"
31+
}
32+
}
33+
}
34+
]
35+
},
36+
)
37+
@readonly
38+
@http(method: "GET", uri: "/maxBy/{name}", code: 200)
39+
operation GetFunctionMaxByEquals {
40+
input: GetEntityRequest,
41+
output: GetEntityResponse,
42+
errors: [NotFound],
43+
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
$version: "2"
2+
namespace com.test
3+
4+
use smithy.waiters#waitable
5+
6+
@suppress(["WaitableTraitJmespathProblem"])
7+
@waitable(
8+
MinByNumberEquals: {
9+
acceptors: [
10+
{
11+
state: "success",
12+
matcher: {
13+
output: {
14+
path: "min_by(lists.structs, &integer).integer == `100`",
15+
expected: "true",
16+
comparator: "booleanEquals"
17+
}
18+
}
19+
}
20+
]
21+
},
22+
MinByStringEquals: {
23+
acceptors: [
24+
{
25+
state: "success",
26+
matcher: {
27+
output: {
28+
path: "min_by(lists.structs, &string).string",
29+
expected: "foo",
30+
comparator: "stringEquals"
31+
}
32+
}
33+
}
34+
]
35+
},
36+
)
37+
@readonly
38+
@http(method: "GET", uri: "/minBy/{name}", code: 200)
39+
operation GetFunctionMinByEquals {
40+
input: GetEntityRequest,
41+
output: GetEntityResponse,
42+
errors: [NotFound],
43+
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
$version: "2"
2+
namespace com.test
3+
4+
use smithy.waiters#waitable
5+
6+
@suppress(["WaitableTraitJmespathProblem"])
7+
@waitable(
8+
SortByNumberEquals: {
9+
acceptors: [
10+
{
11+
state: "success",
12+
matcher: {
13+
output: {
14+
path: "sort_by(lists.structs, &integer)[0].integer == `1`",
15+
expected: "true",
16+
comparator: "booleanEquals"
17+
}
18+
}
19+
}
20+
]
21+
},
22+
SortByStringEquals: {
23+
acceptors: [
24+
{
25+
state: "success",
26+
matcher: {
27+
output: {
28+
path: "sort_by(lists.structs, &string)[2].string",
29+
expected: "foo",
30+
comparator: "stringEquals"
31+
}
32+
}
33+
}
34+
]
35+
},
36+
)
37+
@readonly
38+
@http(method: "GET", uri: "/sortBy/{name}", code: 200)
39+
operation GetFunctionSortByEquals {
40+
input: GetEntityRequest,
41+
output: GetEntityResponse,
42+
errors: [NotFound],
43+
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
$version: "2"
2+
namespace com.test
3+
4+
use smithy.waiters#waitable
5+
6+
@suppress(["WaitableTraitJmespathProblem"])
7+
@waitable(
8+
SortNumberEquals: {
9+
acceptors: [
10+
{
11+
state: "success",
12+
matcher: {
13+
output: {
14+
path: "sort(lists.integers)[2] == `2`",
15+
expected: "true",
16+
comparator: "booleanEquals"
17+
}
18+
}
19+
}
20+
]
21+
},
22+
SortStringEquals: {
23+
acceptors: [
24+
{
25+
state: "success",
26+
matcher: {
27+
output: {
28+
path: "sort(lists.strings)[2]",
29+
expected: "foo",
30+
comparator: "stringEquals"
31+
}
32+
}
33+
}
34+
]
35+
},
36+
)
37+
@readonly
38+
@http(method: "GET", uri: "/sort/{name}", code: 200)
39+
operation GetFunctionSortEquals {
40+
input: GetEntityRequest,
41+
output: GetEntityResponse,
42+
errors: [NotFound],
43+
}

tests/codegen/waiter-tests/model/utils/structures.smithy

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,8 @@ map StructMap {
155155
structure Struct {
156156
primitives: EntityPrimitives,
157157
strings: StringList,
158+
integer: Integer,
159+
string: String,
158160
enums: EnumList,
159161
subStructs: SubStructList,
160162
}

tests/codegen/waiter-tests/model/utils/waiter-operations.smithy

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,5 +31,10 @@ service WaitersTestService {
3131
GetFunctionToStringEquals,
3232
GetFunctionToNumberEquals,
3333
GetFunctionTypeEquals,
34+
GetFunctionSortByEquals,
35+
GetFunctionSortEquals,
36+
GetFunctionMaxByEquals,
37+
GetFunctionMinByEquals,
38+
GetFunctionMapEquals,
3439
]
3540
}

0 commit comments

Comments
 (0)