Skip to content

Commit cacd325

Browse files
authored
feat: add functions (sum, avg, join, starts_with, ends_with) to JMESPath visitor (#940)
1 parent 6e221cc commit cacd325

File tree

14 files changed

+778
-215
lines changed

14 files changed

+778
-215
lines changed

codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/waiters/KotlinJmespathExpressionVisitor.kt

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,10 @@ class KotlinJmespathExpressionVisitor(
158158
append("if ($isNullExpr) null else ")
159159
}
160160

161-
append("${left.identifier}.compareTo(${right.identifier}) ${expression.comparator} 0")
161+
val unSafeComparatorExpr = "compareTo(${right.identifier}) ${expression.comparator} 0"
162+
val comparatorExpr = if (left.nullable) "?.$unSafeComparatorExpr" else ".$unSafeComparatorExpr"
163+
164+
append("${left.identifier}$comparatorExpr")
162165
}
163166

164167
val ident = addTempVar("comparison", codegen)
@@ -213,33 +216,55 @@ class KotlinJmespathExpressionVisitor(
213216
return acceptSubexpression(this.arguments[0]) to acceptSubexpression(this.arguments[1])
214217
}
215218

219+
private fun VisitedExpression.dotFunction(expression: FunctionExpression, expr: String, elvisExpr: String? = null): VisitedExpression {
220+
val dotFunctionExpr = ensureNullGuard(shape, expr, elvisExpr)
221+
val ident = addTempVar(expression.name, "$identifier$dotFunctionExpr")
222+
223+
return VisitedExpression(ident, shape)
224+
}
225+
216226
override fun visitFunction(expression: FunctionExpression): VisitedExpression = when (expression.name) {
217227
"contains" -> {
218228
val (subject, search) = expression.twoArgs()
219-
220-
val containsExpr = ensureNullGuard(subject.shape, "contains(${search.identifier})", "false")
221-
val ident = addTempVar("contains", "${subject.identifier}$containsExpr")
222-
223-
VisitedExpression(ident)
229+
subject.dotFunction(expression, "contains(${search.identifier})", "false")
224230
}
225231

226232
"length" -> {
227233
writer.addImport(RuntimeTypes.Core.Utils.length)
228234
val subject = expression.singleArg()
229-
230-
val lengthExpr = ensureNullGuard(subject.shape, "length", "0")
231-
val ident = addTempVar("length", "${subject.identifier}$lengthExpr")
232-
233-
VisitedExpression(ident)
235+
subject.dotFunction(expression, "length", "0")
234236
}
235237

236238
"abs", "floor", "ceil" -> {
237239
val number = expression.singleArg()
240+
number.dotFunction(expression, "let { kotlin.math.${expression.name}(it.toDouble()) }")
241+
}
242+
243+
"sum" -> {
244+
val numbers = expression.singleArg()
245+
numbers.dotFunction(expression, "sum()")
246+
}
247+
248+
"avg" -> {
249+
val numbers = expression.singleArg()
250+
val average = numbers.dotFunction(expression, "average()").identifier
251+
val isNaN = ensureNullGuard(numbers.shape, "isNaN() == true")
252+
VisitedExpression(addTempVar("averageOrNull", "if($average$isNaN) null else $average"), nullable = true)
253+
}
254+
255+
"join" -> {
256+
val (glue, list) = expression.twoArgs()
257+
list.dotFunction(expression, "joinToString(${glue.identifier})")
258+
}
238259

239-
val functionExpr = ensureNullGuard(number.shape, "let { kotlin.math.${expression.name}(it.toDouble()) }")
240-
val ident = addTempVar(expression.name, "${number.identifier}$functionExpr")
260+
"starts_with" -> {
261+
val (subject, prefix) = expression.twoArgs()
262+
subject.dotFunction(expression, "startsWith(${prefix.identifier})")
263+
}
241264

242-
VisitedExpression(ident, number.shape)
265+
"ends_with" -> {
266+
val (subject, suffix) = expression.twoArgs()
267+
subject.dotFunction(expression, "endsWith(${suffix.identifier})")
243268
}
244269

245270
else -> throw CodegenException("Unknown function type in $expression")
@@ -437,4 +462,4 @@ class KotlinJmespathExpressionVisitor(
437462
* is `foo`, but the shape that needs carried through to subfield expressions in the following projection
438463
* is the target of `bar`, such that its subfields `baz` and `qux` can be recognized.
439464
*/
440-
data class VisitedExpression(val identifier: String, val shape: Shape? = null, val projected: Shape? = null)
465+
data class VisitedExpression(val identifier: String, val shape: Shape? = null, val projected: Shape? = null, val nullable: Boolean = false)
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
$version: "2"
2+
namespace com.test
3+
4+
use smithy.waiters#waitable
5+
6+
@waitable(
7+
EmptyIntegerListAvgNotEquals: {
8+
acceptors: [
9+
{
10+
state: "success",
11+
matcher: {
12+
output: {
13+
path: "avg(lists.integers) == `10`",
14+
expected: "false",
15+
comparator: "booleanEquals"
16+
}
17+
}
18+
}
19+
]
20+
},
21+
ShortListAvgNotEquals: {
22+
acceptors: [
23+
{
24+
state: "success",
25+
matcher: {
26+
output: {
27+
path: "avg(lists.shorts) == `10`",
28+
expected: "false",
29+
comparator: "booleanEquals"
30+
}
31+
}
32+
}
33+
]
34+
},
35+
IntegerListAvgNotEquals: {
36+
acceptors: [
37+
{
38+
state: "success",
39+
matcher: {
40+
output: {
41+
path: "avg(lists.integers) == `10`",
42+
expected: "false",
43+
comparator: "booleanEquals"
44+
}
45+
}
46+
}
47+
]
48+
},
49+
LongListAvgNotEquals: {
50+
acceptors: [
51+
{
52+
state: "success",
53+
matcher: {
54+
output: {
55+
path: "avg(lists.longs) == `10`",
56+
expected: "false",
57+
comparator: "booleanEquals"
58+
}
59+
}
60+
}
61+
]
62+
},
63+
FloatListAvgNotEquals: {
64+
acceptors: [
65+
{
66+
state: "success",
67+
matcher: {
68+
output: {
69+
path: "avg(lists.floats) == `10`",
70+
expected: "false",
71+
comparator: "booleanEquals"
72+
}
73+
}
74+
}
75+
]
76+
},
77+
DoubleListAvgNotEquals: {
78+
acceptors: [
79+
{
80+
state: "success",
81+
matcher: {
82+
output: {
83+
path: "avg(lists.doubles) == `10`",
84+
expected: "false",
85+
comparator: "booleanEquals"
86+
}
87+
}
88+
}
89+
]
90+
},
91+
)
92+
@readonly
93+
@http(method: "GET", uri: "/avg/{name}", code: 200)
94+
operation GetFunctionAvgEquals {
95+
input: GetEntityRequest,
96+
output: GetEntityResponse,
97+
errors: [NotFound],
98+
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
$version: "2"
2+
namespace com.test
3+
4+
use smithy.waiters#waitable
5+
6+
@suppress(["WaitableTraitJmespathProblem"])
7+
@waitable(
8+
StringEndsWithEquals: {
9+
acceptors: [
10+
{
11+
state: "success",
12+
matcher: {
13+
output: {
14+
path: "ends_with(primitives.string, 'baz')",
15+
expected: "true",
16+
comparator: "booleanEquals"
17+
}
18+
}
19+
}
20+
]
21+
},
22+
)
23+
@readonly
24+
@http(method: "GET", uri: "/ends/{name}", code: 200)
25+
operation GetFunctionEndsWithEquals {
26+
input: GetEntityRequest,
27+
output: GetEntityResponse,
28+
errors: [NotFound],
29+
}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
$version: "2"
2+
namespace com.test
3+
4+
use smithy.waiters#waitable
5+
6+
@waitable(
7+
StringListJoinEquals: {
8+
acceptors: [
9+
{
10+
state: "success",
11+
matcher: {
12+
output: {
13+
path: "join('', lists.strings)",
14+
expected: "foo",
15+
comparator: "stringEquals"
16+
}
17+
}
18+
}
19+
]
20+
},
21+
StringListSeparatorJoinEquals: {
22+
acceptors: [
23+
{
24+
state: "success",
25+
matcher: {
26+
output: {
27+
path: "join(', ', lists.strings)",
28+
expected: "foo, bar",
29+
comparator: "stringEquals"
30+
}
31+
}
32+
}
33+
]
34+
},
35+
)
36+
@readonly
37+
@http(method: "GET", uri: "/join/{name}", code: 200)
38+
operation GetFunctionJoinEquals {
39+
input: GetEntityRequest,
40+
output: GetEntityResponse,
41+
errors: [NotFound],
42+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
$version: "2"
2+
namespace com.test
3+
4+
use smithy.waiters#waitable
5+
6+
@waitable(
7+
StringStartsWithEquals: {
8+
acceptors: [
9+
{
10+
state: "success",
11+
matcher: {
12+
output: {
13+
path: "starts_with(primitives.string, 'foo')",
14+
expected: "true",
15+
comparator: "booleanEquals"
16+
}
17+
}
18+
}
19+
]
20+
},
21+
)
22+
@readonly
23+
@http(method: "GET", uri: "/starts/{name}", code: 200)
24+
operation GetFunctionStartsWithEquals {
25+
input: GetEntityRequest,
26+
output: GetEntityResponse,
27+
errors: [NotFound],
28+
}
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
$version: "2"
2+
namespace com.test
3+
4+
use smithy.waiters#waitable
5+
6+
@waitable(
7+
ShortListSumNotEquals: {
8+
acceptors: [
9+
{
10+
state: "success",
11+
matcher: {
12+
output: {
13+
path: "sum(lists.shorts) == `10`",
14+
expected: "false",
15+
comparator: "booleanEquals"
16+
}
17+
}
18+
}
19+
]
20+
},
21+
IntegerListSumNotEquals: {
22+
acceptors: [
23+
{
24+
state: "success",
25+
matcher: {
26+
output: {
27+
path: "sum(lists.integers) == `10`",
28+
expected: "false",
29+
comparator: "booleanEquals"
30+
}
31+
}
32+
}
33+
]
34+
},
35+
LongListSumNotEquals: {
36+
acceptors: [
37+
{
38+
state: "success",
39+
matcher: {
40+
output: {
41+
path: "sum(lists.longs) == `10`",
42+
expected: "false",
43+
comparator: "booleanEquals"
44+
}
45+
}
46+
}
47+
]
48+
},
49+
FloatListSumNotEquals: {
50+
acceptors: [
51+
{
52+
state: "success",
53+
matcher: {
54+
output: {
55+
path: "sum(lists.floats) == `10`",
56+
expected: "false",
57+
comparator: "booleanEquals"
58+
}
59+
}
60+
}
61+
]
62+
},
63+
DoubleListSumNotEquals: {
64+
acceptors: [
65+
{
66+
state: "success",
67+
matcher: {
68+
output: {
69+
path: "sum(lists.doubles) == `10`",
70+
expected: "false",
71+
comparator: "booleanEquals"
72+
}
73+
}
74+
}
75+
]
76+
},
77+
)
78+
@readonly
79+
@http(method: "GET", uri: "/sum/{name}", code: 200)
80+
operation GetFunctionSumEquals {
81+
input: GetEntityRequest,
82+
output: GetEntityResponse,
83+
errors: [NotFound],
84+
}

0 commit comments

Comments
 (0)