@@ -26,8 +26,6 @@ import software.amazon.smithy.model.shapes.ListShape
2626import software.amazon.smithy.model.shapes.MapShape
2727import software.amazon.smithy.model.shapes.MemberShape
2828import software.amazon.smithy.model.shapes.Shape
29- import kotlin.contracts.ExperimentalContracts
30- import kotlin.contracts.contract
3129
3230private val suffixSequence = sequenceOf(" " ) + generateSequence(2 ) { it + 1 }.map(Int ::toString) // "", "2", "3", etc.
3331
@@ -76,14 +74,6 @@ class KotlinJmespathExpressionVisitor(
7674 private fun bestTempVarName (preferredName : String ): String =
7775 suffixSequence.map { " $preferredName$it " }.first(tempVars::add)
7876
79- @OptIn(ExperimentalContracts ::class )
80- private fun codegenReq (condition : Boolean , lazyMessage : () -> String ) {
81- contract {
82- returns() implies condition
83- }
84- if (! condition) throw CodegenException (lazyMessage())
85- }
86-
8777 private fun flatMappingBlock (right : JmespathExpression , leftName : String , leftShape : Shape , innerShape : Shape ? ): VisitedExpression {
8878 if (right is CurrentExpression ) return VisitedExpression (leftName, leftShape) // nothing to map
8979
@@ -105,7 +95,9 @@ class KotlinJmespathExpressionVisitor(
10595 return VisitedExpression (outerName, leftShape, innerResult.shape)
10696 }
10797
108- private fun subfield (expression : FieldExpression , parentName : String , isObject : Boolean = false): VisitedExpression {
98+ private data class SubFieldData (val name : String , val codegen : String , val member : Shape ? )
99+
100+ private fun subfieldLogic (expression : FieldExpression , parentName : String , isObject : Boolean = false): SubFieldData {
109101 val member = currentShape.targetOrSelf(ctx.model).getMember(expression.name).getOrNull()
110102
111103 val name = expression.name.toCamelCase()
@@ -130,9 +122,17 @@ class KotlinJmespathExpressionVisitor(
130122 }
131123
132124 member?.let { shapeCursor.addLast(it) }
133- return VisitedExpression (addTempVar(name, codegen), member)
125+ return SubFieldData (name, codegen, member)
126+ }
127+
128+ private fun subfield (expression : FieldExpression , parentName : String , isObject : Boolean = false): VisitedExpression {
129+ val (name, codegen, member) = subfieldLogic(expression, parentName, isObject)
130+ return VisitedExpression (addTempVar(name, codegen), member, nullable = currentShape.isNullable)
134131 }
135132
133+ private fun subfieldCodegen (expression : FieldExpression , parentName : String , isObject : Boolean = false): String =
134+ subfieldLogic(expression, parentName, isObject).codegen
135+
136136 override fun visitAnd (expression : AndExpression ): VisitedExpression {
137137 writer.addImport(RuntimeTypes .Core .Utils .truthiness)
138138
@@ -151,22 +151,21 @@ class KotlinJmespathExpressionVisitor(
151151
152152 val codegen = buildString {
153153 val nullables = buildList {
154- if (left.shape?.isNullable == true ) add(" ${left.identifier} == null" )
155- if (right.shape?.isNullable == true ) add(" ${right.identifier} == null" )
154+ if (left.shape?.isNullable == true || left.nullable ) add(" ${left.identifier} == null" )
155+ if (right.shape?.isNullable == true || right.nullable ) add(" ${right.identifier} == null" )
156156 }
157+
157158 if (nullables.isNotEmpty()) {
158159 val isNullExpr = nullables.joinToString(" || " )
159160 append(" if ($isNullExpr ) null else " )
160161 }
161162
162- val unSafeComparatorExpr = " compareTo(${right.identifier} ) ${expression.comparator} 0"
163- val comparatorExpr = if (left.nullable) " ?.$unSafeComparatorExpr " else " .$unSafeComparatorExpr "
164-
163+ val comparatorExpr = " .compareTo(${right.identifier} ) ${expression.comparator} 0"
165164 append(" ${left.identifier}$comparatorExpr " )
166165 }
167166
168- val ident = addTempVar(" comparison" , codegen)
169- return VisitedExpression (ident )
167+ val identifier = addTempVar(" comparison" , codegen)
168+ return VisitedExpression (identifier )
170169 }
171170
172171 override fun visitCurrentNode (expression : CurrentExpression ): VisitedExpression {
@@ -207,14 +206,11 @@ class KotlinJmespathExpressionVisitor(
207206 return VisitedExpression (ident, currentShape, inner.projected)
208207 }
209208
210- private fun FunctionExpression.singleArg (): VisitedExpression {
211- codegenReq(arguments.size == 1 ) { " Unexpected number of arguments to $this " }
212- return acceptSubexpression(this .arguments[0 ])
213- }
214- private fun FunctionExpression.twoArgs (): Pair <VisitedExpression , VisitedExpression > {
215- codegenReq(arguments.size == 2 ) { " Unexpected number of arguments to $this " }
216- return acceptSubexpression(this .arguments[0 ]) to acceptSubexpression(this .arguments[1 ])
217- }
209+ private fun FunctionExpression.singleArg (): VisitedExpression =
210+ acceptSubexpression(this .arguments[0 ])
211+
212+ private fun FunctionExpression.twoArgs (): Pair <VisitedExpression , VisitedExpression > =
213+ acceptSubexpression(this .arguments[0 ]) to acceptSubexpression(this .arguments[1 ])
218214
219215 private fun FunctionExpression.args (): List <VisitedExpression > =
220216 this .arguments.map { acceptSubexpression(it) }
@@ -257,9 +253,7 @@ class KotlinJmespathExpressionVisitor(
257253
258254 " avg" -> {
259255 val numbers = expression.singleArg()
260- val average = numbers.dotFunction(expression, " average()" ).identifier
261- val isNaN = ensureNullGuard(numbers.shape, " isNaN() == true" )
262- VisitedExpression (addTempVar(" averageOrNull" , " if($average$isNaN ) null else $average " ), nullable = true )
256+ numbers.dotFunction(expression, " average()" )
263257 }
264258
265259 " join" -> {
@@ -334,6 +328,35 @@ class KotlinJmespathExpressionVisitor(
334328 arg.dotFunction(expression, " type()" , ensureNullGuard = false )
335329 }
336330
331+ " sort" -> {
332+ val arg = expression.singleArg()
333+ arg.dotFunction(expression, " sorted()" )
334+ }
335+
336+ " sort_by" -> {
337+ val list = expression.arguments[0 ].accept(this )
338+ val expressionValue = expression.arguments[1 ]
339+ list.applyFunction(expression.name.toCamelCase(), " sortedBy" , expressionValue)
340+ }
341+
342+ " max_by" -> {
343+ val list = expression.arguments[0 ].accept(this )
344+ val expressionValue = expression.arguments[1 ]
345+ list.applyFunction(expression.name.toCamelCase(), " maxBy" , expressionValue)
346+ }
347+
348+ " min_by" -> {
349+ val list = expression.arguments[0 ].accept(this )
350+ val expressionValue = expression.arguments[1 ]
351+ list.applyFunction(expression.name.toCamelCase(), " minBy" , expressionValue)
352+ }
353+
354+ " map" -> {
355+ val list = expression.arguments[1 ].accept(this )
356+ val expressionValue = expression.arguments[0 ]
357+ list.applyFunction(expression.name.toCamelCase(), " map" , expressionValue)
358+ }
359+
337360 else -> throw CodegenException (" Unknown function type in $expression " )
338361 }
339362
@@ -548,6 +571,21 @@ class KotlinJmespathExpressionVisitor(
548571 return notNull
549572 }
550573
574+ private fun VisitedExpression.applyFunction (
575+ name : String ,
576+ operation : String ,
577+ expression : JmespathExpression ,
578+ ): VisitedExpression {
579+ val result = bestTempVarName(name)
580+
581+ writer.withBlock(" val $result = ${this .identifier} ?.$operation {" , " }" ) {
582+ val expressionValue = subfieldCodegen((expression as ExpressionTypeExpression ).expression as FieldExpression , " it" )
583+ write(" $expressionValue !!" )
584+ }
585+
586+ return VisitedExpression (result)
587+ }
588+
551589 private val Shape .isNullable: Boolean
552590 get() = this is MemberShape &&
553591 ctx.model.expectShape(target).let { ! it.hasTrait<OperationInput >() && ! it.hasTrait<OperationOutput >() }
0 commit comments