
Commit 48fc0fb

dusantism-db authored and cloud-fan committed
[SPARK-49912] Refactor simple CASE statement to evaluate the case variable only once
### What changes were proposed in this pull request?

In this PR, the CASE statement is refactored. The existing `CaseStatement` is split in two: `SimpleCaseStatement` and `SearchedCaseStatement`. `SearchedCaseStatement` retains the old behavior, while the simple variant gets a new logical node and a new execution node: `SimpleCaseStatement` and `SimpleCaseStatementExec`.

Previously, a simple CASE statement would evaluate the case variable again for every WHEN clause in the CASE. This is both inefficient and can produce unexpected behavior if the evaluation has a side effect. `SimpleCaseStatementExec` now caches the result of the evaluation and compares the WHEN conditions to the cached `Literal`.

### Why are the changes needed?

The previous implementation of simple CASE evaluated the case variable expression multiple times.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Tests were added to `SqlScriptingParserSuite`, `SqlScriptingInterpreterSuite` and `SqlScriptingExecutionNodeSuite`.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#50027 from dusantism-db/scripting-case-improvements-v2.

Authored-by: Dušan Tišma <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
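To illustrate the caching approach described above, the sketch below shows the core idea in Catalyst terms: evaluate the case variable once, wrap the result in a `Literal`, and compare every WHEN expression against that cached value. This is an illustrative helper only, not the actual `SimpleCaseStatementExec` (whose code is not part of the diffs shown below); the function name and parameters are hypothetical.

```scala
import org.apache.spark.sql.catalyst.expressions.{EqualTo, Expression, Literal}
import org.apache.spark.sql.types.DataType

// Hypothetical helper sketching the caching idea: the case variable has already been
// evaluated exactly once, and its result is reused for every WHEN comparison instead
// of re-running the (possibly side-effecting) case-variable expression per clause.
def buildWhenConditions(
    cachedCaseValue: Any,          // result of the single evaluation of the case variable
    cachedCaseType: DataType,      // data type of that result
    whenExpressions: Seq[Expression]): Seq[Expression] = {
  val cachedLiteral = Literal(cachedCaseValue, cachedCaseType)
  // Each WHEN clause becomes `cachedLiteral = <whenExpr>`, matching simple CASE
  // semantics of comparing the case variable with each WHEN value.
  whenExpressions.map(whenExpr => EqualTo(cachedLiteral, whenExpr))
}
```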
1 parent 7e5bf72 commit 48fc0fb

File tree

7 files changed: +526, -106 lines


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

Lines changed: 17 additions & 16 deletions
@@ -445,7 +445,7 @@ class AstBuilder extends DataTypeAstBuilder
 
   private def visitSearchedCaseStatementImpl(
       ctx: SearchedCaseStatementContext,
-      labelCtx: SqlScriptingLabelContext): CaseStatement = {
+      labelCtx: SqlScriptingLabelContext): SearchedCaseStatement = {
     val conditions = ctx.conditions.asScala.toList.map(boolExpr => withOrigin(boolExpr) {
       SingleStatement(
         Project(
@@ -464,7 +464,7 @@ class AstBuilder extends DataTypeAstBuilder
           s" ${conditionalBodies.length} in case statement")
     }
 
-    CaseStatement(
+    SearchedCaseStatement(
       conditions = conditions,
       conditionalBodies = conditionalBodies,
       elseBody = Option(ctx.elseBody).map(
@@ -475,30 +475,31 @@ class AstBuilder extends DataTypeAstBuilder
 
   private def visitSimpleCaseStatementImpl(
       ctx: SimpleCaseStatementContext,
-      labelCtx: SqlScriptingLabelContext): CaseStatement = {
-    // uses EqualTo to compare the case variable(the main case expression)
-    // to the WHEN clause expressions
-    val conditions = ctx.conditionExpressions.asScala.toList.map(expr => withOrigin(expr) {
-      SingleStatement(
-        Project(
-          Seq(Alias(EqualTo(expression(ctx.caseVariable), expression(expr)), "condition")()),
-          OneRowRelation()))
-    })
+      labelCtx: SqlScriptingLabelContext): SimpleCaseStatement = {
+    val caseVariableExpr = withOrigin(ctx.caseVariable) {
+      expression(ctx.caseVariable)
+    }
+    val conditionExpressions =
+      ctx.conditionExpressions.asScala.toList
+        .map(exprCtx => withOrigin(exprCtx) {
+          expression(exprCtx)
+        })
     val conditionalBodies =
       ctx.conditionalBodies.asScala.toList.map(
         body =>
           visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx, isScope = false)
       )
 
-    if (conditions.length != conditionalBodies.length) {
+    if (conditionExpressions.length != conditionalBodies.length) {
       throw SparkException.internalError(
-        s"Mismatched number of conditions ${conditions.length} and condition bodies" +
+        s"Mismatched number of conditions ${conditionExpressions.length} and condition bodies" +
         s" ${conditionalBodies.length} in case statement")
     }
 
-    CaseStatement(
-      conditions = conditions,
-      conditionalBodies = conditionalBodies,
+    SimpleCaseStatement(
+      caseVariableExpr,
+      conditionExpressions,
+      conditionalBodies,
       elseBody = Option(ctx.elseBody).map(
         body =>
           visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx, isScope = false)
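The parser change above stops building per-clause `EqualTo(caseVariable, whenExpr)` conditions and instead hands the raw case-variable and WHEN expressions to the new `SimpleCaseStatement` node, so the interpreter can evaluate the case variable a single time. The plain-Scala sketch below (hypothetical names, not the real execution node, which operates on Catalyst expressions and compound bodies) shows the evaluate-once, compare-many control flow this enables:

```scala
// Illustrative only: evaluate the case variable exactly once, then compare each
// WHEN value against the cached result; fall back to the ELSE body if nothing matches.
def runSimpleCase[A, B](
    caseVariable: () => A,                    // possibly side-effecting evaluation
    whenClauses: Seq[(() => A, () => B)],     // (WHEN value, THEN body) pairs
    elseBody: Option[() => B]): Option[B] = {
  val cached = caseVariable()                 // evaluated once, reused for every WHEN
  whenClauses
    .collectFirst { case (whenValue, body) if whenValue() == cached => body() }
    .orElse(elseBody.map(_()))
}

// The side-effecting case variable runs once, even with three WHEN clauses.
var evaluations = 0
val result = runSimpleCase(
  caseVariable = () => { evaluations += 1; 2 },
  whenClauses = Seq((() => 1, () => "one"), (() => 2, () => "two"), (() => 3, () => "three")),
  elseBody = Some(() => "other"))
assert(evaluations == 1 && result.contains("two"))
```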

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala

Lines changed: 52 additions & 4 deletions
@@ -21,7 +21,7 @@ import java.util.Locale
 
 import scala.collection.mutable
 
-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
 import org.apache.spark.sql.catalyst.plans.logical.ExceptionHandlerType.ExceptionHandlerType
 import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
 import org.apache.spark.sql.errors.SqlScriptingErrors
@@ -220,13 +220,24 @@ case class IterateStatement(label: String) extends CompoundPlanStatement {
 }
 
 /**
- * Logical operator for CASE statement.
+ * Logical operator for CASE statement, SEARCHED variant.<br>
+ * Example:
+ * {{{
+ * CASE
+ *   WHEN x = 1 THEN
+ *     SELECT 1;
+ *   WHEN x = 2 THEN
+ *     SELECT 2;
+ *   ELSE
+ *     SELECT 3;
+ * END CASE;
+ * }}}
  * @param conditions Collection of conditions which correspond to WHEN clauses.
  * @param conditionalBodies Collection of bodies that have a corresponding condition,
  *   in WHEN branches.
 * @param elseBody Body that is executed if none of the conditions are met, i.e. ELSE branch.
 */
-case class CaseStatement(
+case class SearchedCaseStatement(
     conditions: Seq[SingleStatement],
     conditionalBodies: Seq[CompoundBody],
     elseBody: Option[CompoundBody]) extends CompoundPlanStatement {
@@ -253,7 +264,44 @@ case class CaseStatement(
       conditionalBodies = conditionalBodies.dropRight(1)
       elseBody = Some(conditionalBodies.last)
     }
-    CaseStatement(conditions, conditionalBodies, elseBody)
+    SearchedCaseStatement(conditions, conditionalBodies, elseBody)
+  }
+}
+
+/**
+ * Logical operator for CASE statement, SIMPLE variant.<br>
+ * Example:
+ * {{{
+ * CASE x
+ *   WHEN 1 THEN
+ *     SELECT 1;
+ *   WHEN 2 THEN
+ *     SELECT 2;
+ *   ELSE
+ *     SELECT 3;
+ * END CASE;
+ * }}}
+ * @param caseVariableExpression Expression with which all conditionExpressions will be compared to.
+ * @param conditionExpressions Collection of expressions which correspond to WHEN clauses.
+ * @param conditionalBodies Collection of bodies that have a corresponding condition,
+ *   in WHEN branches.
+ * @param elseBody Body that is executed if none of the conditions are met, i.e. ELSE branch.
+ */
+case class SimpleCaseStatement(
+    caseVariableExpression: Expression,
+    conditionExpressions: Seq[Expression],
+    conditionalBodies: Seq[CompoundBody],
+    elseBody: Option[CompoundBody]) extends CompoundPlanStatement {
+  assert(conditionExpressions.length == conditionalBodies.length)
+
+  override def output: Seq[Attribute] = Seq.empty
+
+  override def children: Seq[LogicalPlan] = conditionalBodies
+
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[LogicalPlan]): LogicalPlan = {
+    val conditionalBodies = newChildren.map(_.asInstanceOf[CompoundBody])
+    SimpleCaseStatement(caseVariableExpression, conditionExpressions, conditionalBodies, elseBody)
   }
 }
 
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala

Lines changed: 49 additions & 58 deletions
@@ -18,9 +18,9 @@
 package org.apache.spark.sql.catalyst.parser
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, Expression, In, Literal, ScalarSubquery}
+import org.apache.spark.sql.catalyst.expressions.{In, Literal, ScalarSubquery}
 import org.apache.spark.sql.catalyst.plans.SQLHelper
-import org.apache.spark.sql.catalyst.plans.logical.{CaseStatement, CompoundBody, CreateVariable, ExceptionHandler, ForStatement, IfElseStatement, IterateStatement, LeaveStatement, LoopStatement, Project, RepeatStatement, SetVariable, SingleStatement, WhileStatement}
+import org.apache.spark.sql.catalyst.plans.logical.{CompoundBody, CreateVariable, ExceptionHandler, ForStatement, IfElseStatement, IterateStatement, LeaveStatement, LoopStatement, Project, RepeatStatement, SearchedCaseStatement, SetVariable, SimpleCaseStatement, SingleStatement, WhileStatement}
 import org.apache.spark.sql.errors.DataTypeErrors.toSQLId
 import org.apache.spark.sql.exceptions.SqlScriptingException
 import org.apache.spark.sql.internal.SQLConf
@@ -1462,8 +1462,8 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
        |""".stripMargin
    val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody]
    assert(tree.collection.length == 1)
-    assert(tree.collection.head.isInstanceOf[CaseStatement])
-    val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
+    assert(tree.collection.head.isInstanceOf[SearchedCaseStatement])
+    val caseStmt = tree.collection.head.asInstanceOf[SearchedCaseStatement]
    assert(caseStmt.conditions.length == 1)
    assert(caseStmt.conditions.head.isInstanceOf[SingleStatement])
    assert(caseStmt.conditions.head.getText == "1 = 1")
@@ -1502,9 +1502,9 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
    val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody]
 
    assert(tree.collection.length == 1)
-    assert(tree.collection.head.isInstanceOf[CaseStatement])
+    assert(tree.collection.head.isInstanceOf[SearchedCaseStatement])
 
-    val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
+    val caseStmt = tree.collection.head.asInstanceOf[SearchedCaseStatement]
    assert(caseStmt.conditions.length == 3)
    assert(caseStmt.conditionalBodies.length == 3)
    assert(caseStmt.elseBody.isEmpty)
@@ -1545,8 +1545,8 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
        |""".stripMargin
    val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody]
    assert(tree.collection.length == 1)
-    assert(tree.collection.head.isInstanceOf[CaseStatement])
-    val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
+    assert(tree.collection.head.isInstanceOf[SearchedCaseStatement])
+    val caseStmt = tree.collection.head.asInstanceOf[SearchedCaseStatement]
    assert(caseStmt.elseBody.isDefined)
    assert(caseStmt.conditions.length == 1)
    assert(caseStmt.conditions.head.isInstanceOf[SingleStatement])
@@ -1574,19 +1574,19 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
        |""".stripMargin
    val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody]
    assert(tree.collection.length == 1)
-    assert(tree.collection.head.isInstanceOf[CaseStatement])
+    assert(tree.collection.head.isInstanceOf[SearchedCaseStatement])
 
-    val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
+    val caseStmt = tree.collection.head.asInstanceOf[SearchedCaseStatement]
    assert(caseStmt.conditions.length == 1)
    assert(caseStmt.conditionalBodies.length == 1)
    assert(caseStmt.elseBody.isEmpty)
 
    assert(caseStmt.conditions.head.isInstanceOf[SingleStatement])
    assert(caseStmt.conditions.head.getText == "1 = 1")
 
-    assert(caseStmt.conditionalBodies.head.collection.head.isInstanceOf[CaseStatement])
+    assert(caseStmt.conditionalBodies.head.collection.head.isInstanceOf[SearchedCaseStatement])
    val nestedCaseStmt =
-      caseStmt.conditionalBodies.head.collection.head.asInstanceOf[CaseStatement]
+      caseStmt.conditionalBodies.head.collection.head.asInstanceOf[SearchedCaseStatement]
 
    assert(nestedCaseStmt.conditions.length == 1)
    assert(nestedCaseStmt.conditionalBodies.length == 1)
@@ -1616,11 +1616,15 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
        |""".stripMargin
    val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody]
    assert(tree.collection.length == 1)
-    assert(tree.collection.head.isInstanceOf[CaseStatement])
-    val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
-    assert(caseStmt.conditions.length == 1)
-    assert(caseStmt.conditions.head.isInstanceOf[SingleStatement])
-    checkSimpleCaseStatementCondition(caseStmt.conditions.head, _ == Literal(1), _ == Literal(1))
+    assert(tree.collection.head.isInstanceOf[SimpleCaseStatement])
+    val caseStmt = tree.collection.head.asInstanceOf[SimpleCaseStatement]
+    assert(caseStmt.caseVariableExpression == Literal(1))
+    assert(caseStmt.conditionExpressions.length == 1)
+    assert(caseStmt.conditionExpressions.head == Literal(1))
+
+    assert(caseStmt.conditionalBodies.length == 1)
+    assert(caseStmt.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement]
+      .getText == "SELECT 1")
   }
 
   test("simple case statement with empty body") {
@@ -1656,31 +1660,27 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
    val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody]
 
    assert(tree.collection.length == 1)
-    assert(tree.collection.head.isInstanceOf[CaseStatement])
+    assert(tree.collection.head.isInstanceOf[SimpleCaseStatement])
 
-    val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
-    assert(caseStmt.conditions.length == 3)
+    val caseStmt = tree.collection.head.asInstanceOf[SimpleCaseStatement]
+    assert(caseStmt.caseVariableExpression == Literal(1))
+    assert(caseStmt.conditionExpressions.length == 3)
    assert(caseStmt.conditionalBodies.length == 3)
    assert(caseStmt.elseBody.isEmpty)
 
-    assert(caseStmt.conditions.head.isInstanceOf[SingleStatement])
-    checkSimpleCaseStatementCondition(caseStmt.conditions.head, _ == Literal(1), _ == Literal(1))
+    assert(caseStmt.conditionExpressions.head == Literal(1))
 
    assert(caseStmt.conditionalBodies.head.collection.head.isInstanceOf[SingleStatement])
    assert(caseStmt.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement]
      .getText == "SELECT 1")
 
-    assert(caseStmt.conditions(1).isInstanceOf[SingleStatement])
-    checkSimpleCaseStatementCondition(
-      caseStmt.conditions(1), _ == Literal(1), _.isInstanceOf[ScalarSubquery])
+    assert(caseStmt.conditionExpressions(1).isInstanceOf[ScalarSubquery])
 
    assert(caseStmt.conditionalBodies(1).collection.head.isInstanceOf[SingleStatement])
    assert(caseStmt.conditionalBodies(1).collection.head.asInstanceOf[SingleStatement]
      .getText == "SELECT * FROM b")
 
-    assert(caseStmt.conditions(2).isInstanceOf[SingleStatement])
-    checkSimpleCaseStatementCondition(
-      caseStmt.conditions(2), _ == Literal(1), _.isInstanceOf[In])
+    assert(caseStmt.conditionExpressions(2).isInstanceOf[In])
 
    assert(caseStmt.conditionalBodies(2).collection.head.isInstanceOf[SingleStatement])
    assert(caseStmt.conditionalBodies(2).collection.head.asInstanceOf[SingleStatement]
@@ -1701,12 +1701,17 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
        |""".stripMargin
    val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody]
    assert(tree.collection.length == 1)
-    assert(tree.collection.head.isInstanceOf[CaseStatement])
-    val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
+    assert(tree.collection.head.isInstanceOf[SimpleCaseStatement])
+    val caseStmt = tree.collection.head.asInstanceOf[SimpleCaseStatement]
+
+    assert(caseStmt.caseVariableExpression == Literal(1))
    assert(caseStmt.elseBody.isDefined)
-    assert(caseStmt.conditions.length == 1)
-    assert(caseStmt.conditions.head.isInstanceOf[SingleStatement])
-    checkSimpleCaseStatementCondition(caseStmt.conditions.head, _ == Literal(1), _ == Literal(1))
+    assert(caseStmt.conditionExpressions.length == 1)
+    assert(caseStmt.conditionExpressions.head == Literal(1))
+
+    assert(caseStmt.conditionalBodies.length == 1)
+    assert(caseStmt.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement]
+      .getText == "SELECT 42")
 
    assert(caseStmt.elseBody.get.collection.head.isInstanceOf[SingleStatement])
    assert(caseStmt.elseBody.get.collection.head.asInstanceOf[SingleStatement]
@@ -1730,28 +1735,27 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
        |""".stripMargin
    val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody]
    assert(tree.collection.length == 1)
-    assert(tree.collection.head.isInstanceOf[CaseStatement])
+    assert(tree.collection.head.isInstanceOf[SimpleCaseStatement])
 
-    val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
-    assert(caseStmt.conditions.length == 1)
+    val caseStmt = tree.collection.head.asInstanceOf[SimpleCaseStatement]
+
+    assert(caseStmt.caseVariableExpression.isInstanceOf[ScalarSubquery])
+    assert(caseStmt.conditionExpressions.length == 1)
    assert(caseStmt.conditionalBodies.length == 1)
    assert(caseStmt.elseBody.isEmpty)
 
-    assert(caseStmt.conditions.head.isInstanceOf[SingleStatement])
-    checkSimpleCaseStatementCondition(
-      caseStmt.conditions.head, _.isInstanceOf[ScalarSubquery], _ == Literal(1))
+    assert(caseStmt.conditionExpressions.head == Literal(1))
 
-    assert(caseStmt.conditionalBodies.head.collection.head.isInstanceOf[CaseStatement])
+    assert(caseStmt.conditionalBodies.head.collection.head.isInstanceOf[SimpleCaseStatement])
    val nestedCaseStmt =
-      caseStmt.conditionalBodies.head.collection.head.asInstanceOf[CaseStatement]
+      caseStmt.conditionalBodies.head.collection.head.asInstanceOf[SimpleCaseStatement]
 
-    assert(nestedCaseStmt.conditions.length == 1)
+    assert(nestedCaseStmt.caseVariableExpression == Literal(2))
+    assert(nestedCaseStmt.conditionExpressions.length == 1)
    assert(nestedCaseStmt.conditionalBodies.length == 1)
    assert(nestedCaseStmt.elseBody.isDefined)
 
-    assert(nestedCaseStmt.conditions.head.isInstanceOf[SingleStatement])
-    checkSimpleCaseStatementCondition(
-      nestedCaseStmt.conditions.head, _ == Literal(2), _ == Literal(2))
+    assert(nestedCaseStmt.conditionExpressions.head == Literal(2))
 
    assert(nestedCaseStmt.conditionalBodies.head.collection.head.isInstanceOf[SingleStatement])
    assert(nestedCaseStmt.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement]
@@ -2910,17 +2914,4 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
      .replace("END", "")
      .trim
   }
-
-  private def checkSimpleCaseStatementCondition(
-      conditionStatement: SingleStatement,
-      predicateLeft: Expression => Boolean,
-      predicateRight: Expression => Boolean): Unit = {
-    assert(conditionStatement.parsedPlan.isInstanceOf[Project])
-    val project = conditionStatement.parsedPlan.asInstanceOf[Project]
-    assert(project.projectList.head.isInstanceOf[Alias])
-    assert(project.projectList.head.asInstanceOf[Alias].child.isInstanceOf[EqualTo])
-    val equalTo = project.projectList.head.asInstanceOf[Alias].child.asInstanceOf[EqualTo]
-    assert(predicateLeft(equalTo.left))
-    assert(predicateRight(equalTo.right))
-  }
 }
