Skip to content

Commit fcda935

Browse files
mihailomilosevic2001MaxGekk
authored andcommitted
[SPARK-49864][SQL] Improve message of BINARY_ARITHMETIC_OVERFLOW
### What changes were proposed in this pull request? BINARY_ARITHMETIC_OVERFLOW did not have a suggestion on bypassing the error. This PR improves on that. ### Why are the changes needed? All errors should suggest a way to overcome an issue, so that customers can fix problems easier. ### Does this PR introduce _any_ user-facing change? Yes, change in error message. ### How was this patch tested? Tests added for all paths for bytes. ### Was this patch authored or co-authored using generative AI tooling? No. Closes apache#48335 from mihailom-db/binary_arithmetic_overflow. Authored-by: Mihailo Milosevic <[email protected]> Signed-off-by: Max Gekk <[email protected]>
1 parent 4cf9d14 commit fcda935

File tree

5 files changed

+58
-15
lines changed

5 files changed

+58
-15
lines changed

common/utils/src/main/resources/error/error-conditions.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@
127127
},
128128
"BINARY_ARITHMETIC_OVERFLOW" : {
129129
"message" : [
130-
"<value1> <symbol> <value2> caused overflow."
130+
"<value1> <symbol> <value2> caused overflow. Use <functionName> to ignore overflow problem and return NULL."
131131
],
132132
"sqlState" : "22003"
133133
},

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,12 +294,18 @@ abstract class BinaryArithmetic extends BinaryOperator
294294
case ByteType | ShortType =>
295295
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
296296
val tmpResult = ctx.freshName("tmpResult")
297+
val try_suggestion = symbol match {
298+
case "+" => "try_add"
299+
case "-" => "try_subtract"
300+
case "*" => "try_multiply"
301+
case _ => ""
302+
}
297303
val overflowCheck = if (failOnError) {
298304
val javaType = CodeGenerator.boxedType(dataType)
299305
s"""
300306
|if ($tmpResult < $javaType.MIN_VALUE || $tmpResult > $javaType.MAX_VALUE) {
301307
| throw QueryExecutionErrors.binaryArithmeticCauseOverflowError(
302-
| $eval1, "$symbol", $eval2);
308+
| $eval1, "$symbol", $eval2, "$try_suggestion");
303309
|}
304310
""".stripMargin
305311
} else {

sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -610,13 +610,17 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
610610
}
611611

612612
def binaryArithmeticCauseOverflowError(
613-
eval1: Short, symbol: String, eval2: Short): SparkArithmeticException = {
613+
eval1: Short,
614+
symbol: String,
615+
eval2: Short,
616+
suggestedFunc: String): SparkArithmeticException = {
614617
new SparkArithmeticException(
615618
errorClass = "BINARY_ARITHMETIC_OVERFLOW",
616619
messageParameters = Map(
617620
"value1" -> toSQLValue(eval1, ShortType),
618621
"symbol" -> symbol,
619-
"value2" -> toSQLValue(eval2, ShortType)),
622+
"value2" -> toSQLValue(eval2, ShortType),
623+
"functionName" -> toSQLId(suggestedFunc)),
620624
context = Array.empty,
621625
summary = "")
622626
}

sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,27 +24,27 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
2424
import org.apache.spark.sql.types.Decimal.DecimalIsConflicted
2525

2626
private[sql] object ByteExactNumeric extends ByteIsIntegral with Ordering.ByteOrdering {
27-
private def checkOverflow(res: Int, x: Byte, y: Byte, op: String): Unit = {
27+
private def checkOverflow(res: Int, x: Byte, y: Byte, op: String, hint: String): Unit = {
2828
if (res > Byte.MaxValue || res < Byte.MinValue) {
29-
throw QueryExecutionErrors.binaryArithmeticCauseOverflowError(x, op, y)
29+
throw QueryExecutionErrors.binaryArithmeticCauseOverflowError(x, op, y, hint)
3030
}
3131
}
3232

3333
override def plus(x: Byte, y: Byte): Byte = {
3434
val tmp = x + y
35-
checkOverflow(tmp, x, y, "+")
35+
checkOverflow(tmp, x, y, "+", "try_add")
3636
tmp.toByte
3737
}
3838

3939
override def minus(x: Byte, y: Byte): Byte = {
4040
val tmp = x - y
41-
checkOverflow(tmp, x, y, "-")
41+
checkOverflow(tmp, x, y, "-", "try_subtract")
4242
tmp.toByte
4343
}
4444

4545
override def times(x: Byte, y: Byte): Byte = {
4646
val tmp = x * y
47-
checkOverflow(tmp, x, y, "*")
47+
checkOverflow(tmp, x, y, "*", "try_multiply")
4848
tmp.toByte
4949
}
5050

@@ -55,27 +55,27 @@ private[sql] object ByteExactNumeric extends ByteIsIntegral with Ordering.ByteOr
5555

5656

5757
private[sql] object ShortExactNumeric extends ShortIsIntegral with Ordering.ShortOrdering {
58-
private def checkOverflow(res: Int, x: Short, y: Short, op: String): Unit = {
58+
private def checkOverflow(res: Int, x: Short, y: Short, op: String, hint: String): Unit = {
5959
if (res > Short.MaxValue || res < Short.MinValue) {
60-
throw QueryExecutionErrors.binaryArithmeticCauseOverflowError(x, op, y)
60+
throw QueryExecutionErrors.binaryArithmeticCauseOverflowError(x, op, y, hint)
6161
}
6262
}
6363

6464
override def plus(x: Short, y: Short): Short = {
6565
val tmp = x + y
66-
checkOverflow(tmp, x, y, "+")
66+
checkOverflow(tmp, x, y, "+", "try_add")
6767
tmp.toShort
6868
}
6969

7070
override def minus(x: Short, y: Short): Short = {
7171
val tmp = x - y
72-
checkOverflow(tmp, x, y, "-")
72+
checkOverflow(tmp, x, y, "-", "try_subtract")
7373
tmp.toShort
7474
}
7575

7676
override def times(x: Short, y: Short): Short = {
7777
val tmp = x * y
78-
checkOverflow(tmp, x, y, "*")
78+
checkOverflow(tmp, x, y, "*", "try_multiply")
7979
tmp.toShort
8080
}
8181

sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -767,7 +767,40 @@ class QueryExecutionErrorsSuite
767767
parameters = Map(
768768
"value1" -> "127S",
769769
"symbol" -> "+",
770-
"value2" -> "5S"),
770+
"value2" -> "5S",
771+
"functionName" -> "`try_add`"),
772+
sqlState = "22003")
773+
}
774+
}
775+
776+
test("BINARY_ARITHMETIC_OVERFLOW: byte minus byte result overflow") {
777+
withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
778+
checkError(
779+
exception = intercept[SparkArithmeticException] {
780+
sql(s"select -2Y - 127Y").collect()
781+
},
782+
condition = "BINARY_ARITHMETIC_OVERFLOW",
783+
parameters = Map(
784+
"value1" -> "-2S",
785+
"symbol" -> "-",
786+
"value2" -> "127S",
787+
"functionName" -> "`try_subtract`"),
788+
sqlState = "22003")
789+
}
790+
}
791+
792+
test("BINARY_ARITHMETIC_OVERFLOW: byte multiply byte result overflow") {
793+
withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
794+
checkError(
795+
exception = intercept[SparkArithmeticException] {
796+
sql(s"select 127Y * 5Y").collect()
797+
},
798+
condition = "BINARY_ARITHMETIC_OVERFLOW",
799+
parameters = Map(
800+
"value1" -> "127S",
801+
"symbol" -> "*",
802+
"value2" -> "5S",
803+
"functionName" -> "`try_multiply`"),
771804
sqlState = "22003")
772805
}
773806
}

0 commit comments

Comments
 (0)