@@ -57,10 +57,14 @@ class KotlinJmespathExpressionVisitor(
5757 private val currentShape: Shape
5858 get() = shapeCursor.last()
5959
60+ // traverses an independent expression (one whose resolved scope does not persist in the outer evaluation)
6061 private fun acceptSubexpression (expr : JmespathExpression ): VisitedExpression {
61- shapeCursor.addLast(currentShape)
62+ val pos = shapeCursor.size
6263 val out = expr.accept(this )
63- shapeCursor.removeLast()
64+
65+ val diff = shapeCursor.size - pos
66+ repeat(diff) { shapeCursor.removeLast() } // reset the shape cursor
67+
6468 return out
6569 }
6670
@@ -73,15 +77,6 @@ class KotlinJmespathExpressionVisitor(
7377 private fun bestTempVarName (preferredName : String ): String =
7478 suffixSequence.map { " $preferredName$it " }.first(tempVars::add)
7579
76- private fun childBlock (forExpression : JmespathExpression , shape : Shape ): VisitedExpression {
77- val childShape = when (val target = shape.targetOrSelf(ctx.model)) {
78- is ListShape -> target.member
79- is MapShape -> target.value
80- else -> shape
81- }
82- return forExpression.accept(KotlinJmespathExpressionVisitor (ctx, writer, childShape))
83- }
84-
8580 @OptIn(ExperimentalContracts ::class )
8681 private fun codegenReq (condition : Boolean , lazyMessage : () -> String ) {
8782 contract {
@@ -90,21 +85,25 @@ class KotlinJmespathExpressionVisitor(
9085 if (! condition) throw CodegenException (lazyMessage())
9186 }
9287
93- private fun flatMappingBlock (right : JmespathExpression , leftName : String , leftShape : Shape ): VisitedExpression {
88+ private fun flatMappingBlock (right : JmespathExpression , leftName : String , leftShape : Shape , innerShape : Shape ? ): VisitedExpression {
9489 if (right is CurrentExpression ) return VisitedExpression (leftName, leftShape) // nothing to map
9590
9691 val outerName = bestTempVarName(" projection" )
97- writer.openBlock(" val #L = #L.flatMap {" , outerName, leftName)
92+ val flatMapExpr = ensureNullGuard(leftShape, " flatMap" )
93+ writer.openBlock(" val #L = #L#L {" , outerName, leftName, flatMapExpr)
94+
95+ shapeCursor.addLast(innerShape?.targetMemberOrSelf ? : leftShape.targetMemberOrSelf)
96+ val innerResult = acceptSubexpression(right)
97+ shapeCursor.removeLast()
9898
99- val innerResult = childBlock(right, leftShape)
10099 val innerCollector = when (right) {
101100 is MultiSelectListExpression -> innerResult.identifier // Already a list
102101 else -> " listOfNotNull(${innerResult.identifier} )"
103102 }
104103 writer.write(innerCollector)
105104
106105 writer.closeBlock(" }" )
107- return VisitedExpression (outerName, innerResult.shape)
106+ return VisitedExpression (outerName, leftShape, innerResult.shape)
108107 }
109108
110109 private fun subfield (expression : FieldExpression , parentName : String ): VisitedExpression {
@@ -151,7 +150,7 @@ class KotlinJmespathExpressionVisitor(
151150 val codegen = buildString {
152151 val nullables = buildList {
153152 if (left.shape?.isNullable == true ) add(" ${left.identifier} == null" )
154- if (right.shape?.isNullable == true ) add(" ${left .identifier} == null" )
153+ if (right.shape?.isNullable == true ) add(" ${right .identifier} == null" )
155154 }
156155 if (nullables.isNotEmpty()) {
157156 val isNullExpr = nullables.joinToString(" || " )
@@ -176,27 +175,31 @@ class KotlinJmespathExpressionVisitor(
176175 override fun visitField (expression : FieldExpression ): VisitedExpression = subfield(expression, " it" )
177176
178177 override fun visitFilterProjection (expression : FilterProjectionExpression ): VisitedExpression {
179- val left = acceptSubexpression( expression.left)
178+ val left = expression.left.accept( this )
180179 requireNotNull(left.shape) { " filter projection is operating on nothing?" }
181180
182181 val filteredName = bestTempVarName(" ${left.identifier} Filtered" )
183182
184183 val filterExpr = ensureNullGuard(left.shape, " filter" )
185184 writer.withBlock(" val #L = #L#L {" , " }" , filteredName, left.identifier, filterExpr) {
186- val comparison = childBlock(expression.comparison, left.shape)
185+ shapeCursor.addLast(left.shape.targetMemberOrSelf)
186+ val comparison = acceptSubexpression(expression.comparison)
187+ shapeCursor.removeLast()
187188 write(" #L == true" , comparison.identifier)
188189 }
189190
190- return flatMappingBlock(expression.right, filteredName, left.shape)
191+ return flatMappingBlock(expression.right, filteredName, left.shape, left.projected )
191192 }
192193
193194 override fun visitFlatten (expression : FlattenExpression ): VisitedExpression {
194195 writer.addImport(RuntimeTypes .Core .Utils .flattenIfPossible)
195196
196- val inner = acceptSubexpression(expression.expression)
197- val flattenExpr = ensureNullGuard(inner.shape, " flattenIfPossible()" , " listOf()" )
197+ val inner = expression.expression.accept(this )
198+
199+ val flattenExpr = ensureNullGuard(currentShape, " flattenIfPossible()" )
198200 val ident = addTempVar(" ${inner.identifier} OrEmpty" , " ${inner.identifier}$flattenExpr " )
199- return VisitedExpression (ident, inner.shape)
201+
202+ return VisitedExpression (ident, currentShape, inner.projected)
200203 }
201204
202205 override fun visitFunction (expression : FunctionExpression ): VisitedExpression = when (expression.name) {
@@ -278,7 +281,7 @@ class KotlinJmespathExpressionVisitor(
278281 val valuesExpr = ensureNullGuard(left.shape, " values" )
279282 val valuesName = addTempVar(" ${left.identifier} Values" , " ${left.identifier}$valuesExpr " )
280283
281- return flatMappingBlock(expression.right, valuesName, left.shape)
284+ return flatMappingBlock(expression.right, valuesName, left.shape, left.projected )
282285 }
283286
284287 override fun visitOr (expression : OrExpression ): VisitedExpression {
@@ -294,26 +297,23 @@ class KotlinJmespathExpressionVisitor(
294297 }
295298
296299 override fun visitProjection (expression : ProjectionExpression ): VisitedExpression {
297- val left = acceptSubexpression( expression.left)
300+ val left = expression.left.accept( this )
298301 requireNotNull(left.shape) { " projection is operating on nothing?" }
299302
300- return flatMappingBlock(expression.right, left.identifier, left.shape)
303+ return flatMappingBlock(expression.right, left.identifier, left.shape, left.projected )
301304 }
302305
303306 override fun visitSlice (expression : SliceExpression ): VisitedExpression {
304307 throw CodegenException (" SliceExpression is unsupported" )
305308 }
306309
307310 override fun visitSubexpression (expression : Subexpression ): VisitedExpression {
308- val left = acceptSubexpression(expression.left)
309- requireNotNull(left.shape)
311+ val left = expression.left.accept(this )
310312
311- shapeCursor.addLast(left.shape)
312313 val ret = when (val right = expression.right) {
313314 is FieldExpression -> subfield(right, left.identifier)
314315 else -> throw CodegenException (" Subexpression type $right is unsupported" )
315316 }
316- shapeCursor.removeLast()
317317
318318 return ret
319319 }
@@ -338,6 +338,13 @@ class KotlinJmespathExpressionVisitor(
338338 get() = this is MemberShape &&
339339 ctx.model.expectShape(target).let { ! it.hasTrait<OperationInput >() && ! it.hasTrait<OperationOutput >() } &&
340340 nullableIndex.isMemberNullable(this , NullableIndex .CheckMode .CLIENT_ZERO_VALUE_V1_NO_INPUT )
341+
342+ private val Shape .targetMemberOrSelf: Shape
343+ get() = when (val target = targetOrSelf(ctx.model)) {
344+ is ListShape -> target.member
345+ is MapShape -> target.value
346+ else -> this
347+ }
341348}
342349
343350/* *
@@ -346,5 +353,9 @@ class KotlinJmespathExpressionVisitor(
346353 * @param shape The underlying shape (if any) that the identifier represents. Not all expressions reference a modeled
347354 * shape, e.g. [LiteralExpression] (the value is just a literal) or [FunctionExpression]s where the
348355 * returned value is scalar.
356+ * @param projected For projections, the context of the inner shape. For example, given the expression
357+ * `foo[].bar[].baz.qux`, the shape that backs the identifier (and therefore determines overall nullability)
358+ * is `foo`, but the shape that needs carried through to subfield expressions in the following projection
359+ * is the target of `bar`, such that its subfields `baz` and `qux` can be recognized.
349360 */
350- data class VisitedExpression (val identifier : String , val shape : Shape ? = null )
361+ data class VisitedExpression (val identifier : String , val shape : Shape ? = null , val projected : Shape ? = null )
0 commit comments