Skip to content

Commit 9d061e3

Browse files
Peng-Lei authored and Max Gekk committed
[SPARK-35926][SQL] Add support YearMonthIntervalType for width_bucket
### What changes were proposed in this pull request? Support width_bucket(YearMonthIntervalType, YearMonthIntervalType, YearMonthIntervalType, Long); it returns a long result, e.g.: ``` width_bucket(input_value, min_value, max_value, bucket_nums) width_bucket(INTERVAL '1' YEAR, INTERVAL '0' YEAR, INTERVAL '10' YEAR, 10) It divides the range between the max_value and min_value into 10 buckets. [ INTERVAL '0' YEAR, INTERVAL '1' YEAR), [ INTERVAL '1' YEAR, INTERVAL '2' YEAR)...... [INTERVAL '9' YEAR, INTERVAL '10' YEAR) Then, it calculates which bucket the given input_value falls into. ``` The function `width_bucket` was introduced in [SPARK-21117](https://issues.apache.org/jira/browse/SPARK-21117) ### Why are the changes needed? [35926](https://issues.apache.org/jira/browse/SPARK-35926) 1. The `WIDTH_BUCKET` function assigns values to buckets (individual segments) in an equiwidth histogram. The ANSI SQL Standard syntax is as follows: `WIDTH_BUCKET( expression, min, max, buckets)`. [Reference](https://www.oreilly.com/library/view/sql-in-a/9780596155322/re91.html). 2. `WIDTH_BUCKET` only supports `Double` at the moment. Of course, we can cast `Int` to `Double` to use it, but we could not cast `YearMonthIntervalType` to `Double`. 3. I think it has a use scenario, e.g.: a histogram of employee years of service, where the `years of service` is a column of `YearMonthIntervalType` data type. ### Does this PR introduce _any_ user-facing change? Yes. The user can use `width_bucket` with YearMonthIntervalType. ### How was this patch tested? Added unit tests. Closes apache#33132 from Peng-Lei/SPARK-35926. Authored-by: PengLei <[email protected]> Signed-off-by: Max Gekk <[email protected]>
1 parent 9865c09 commit 9d061e3

File tree

6 files changed

+96
-7
lines changed

6 files changed

+96
-7
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult
2525
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
2626
import org.apache.spark.sql.catalyst.expressions.codegen._
2727
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
28-
import org.apache.spark.sql.catalyst.util.NumberConverter
28+
import org.apache.spark.sql.catalyst.util.{NumberConverter, TypeUtils}
2929
import org.apache.spark.sql.types._
3030
import org.apache.spark.unsafe.types.UTF8String
3131

@@ -1613,6 +1613,10 @@ object WidthBucket {
16131613
5
16141614
> SELECT _FUNC_(-0.9, 5.2, 0.5, 2);
16151615
3
1616+
> SELECT _FUNC_(INTERVAL '0' YEAR, INTERVAL '0' YEAR, INTERVAL '10' YEAR, 10);
1617+
1
1618+
> SELECT _FUNC_(INTERVAL '1' YEAR, INTERVAL '0' YEAR, INTERVAL '10' YEAR, 10);
1619+
2
16161620
""",
16171621
since = "3.1.0",
16181622
group = "math_funcs")
@@ -1623,16 +1627,35 @@ case class WidthBucket(
16231627
numBucket: Expression)
16241628
extends QuaternaryExpression with ImplicitCastInputTypes with NullIntolerant {
16251629

1626-
override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType, DoubleType, LongType)
1630+
override def inputTypes: Seq[AbstractDataType] = Seq(
1631+
TypeCollection(DoubleType, YearMonthIntervalType),
1632+
TypeCollection(DoubleType, YearMonthIntervalType),
1633+
TypeCollection(DoubleType, YearMonthIntervalType),
1634+
LongType)
1635+
1636+
override def checkInputDataTypes(): TypeCheckResult = {
1637+
super.checkInputDataTypes() match {
1638+
case TypeCheckSuccess =>
1639+
(value.dataType, minValue.dataType, maxValue.dataType) match {
1640+
case (_: YearMonthIntervalType, _: YearMonthIntervalType, _: YearMonthIntervalType) =>
1641+
TypeCheckSuccess
1642+
case _ =>
1643+
val types = Seq(value.dataType, minValue.dataType, maxValue.dataType)
1644+
TypeUtils.checkForSameTypeInputExpr(types, s"function $prettyName")
1645+
}
1646+
case f => f
1647+
}
1648+
}
1649+
16271650
override def dataType: DataType = LongType
16281651
override def nullable: Boolean = true
16291652
override def prettyName: String = "width_bucket"
16301653

16311654
override protected def nullSafeEval(input: Any, min: Any, max: Any, numBucket: Any): Any = {
16321655
WidthBucket.computeBucketNumber(
1633-
input.asInstanceOf[Double],
1634-
min.asInstanceOf[Double],
1635-
max.asInstanceOf[Double],
1656+
input.asInstanceOf[Number].doubleValue(),
1657+
min.asInstanceOf[Number].doubleValue(),
1658+
max.asInstanceOf[Number].doubleValue(),
16361659
numBucket.asInstanceOf[Long])
16371660
}
16381661

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -725,4 +725,19 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
725725
checkEvaluation(Signum(Literal(Duration.of(Long.MaxValue, ChronoUnit.MICROS))), 1.0)
726726
checkEvaluation(Signum(Literal(Duration.of(Long.MinValue, ChronoUnit.MICROS))), -1.0)
727727
}
728+
729+
test("SPARK-35926: Support YearMonthIntervalType in width-bucket function") {
730+
Seq(
731+
(Period.ofMonths(-1), Period.ofYears(0), Period.ofYears(10), 10L) -> 0L,
732+
(Period.ofMonths(0), Period.ofYears(0), Period.ofYears(10), 10L) -> 1L,
733+
(Period.ofMonths(13), Period.ofYears(0), Period.ofYears(10), 10L) -> 2L,
734+
(Period.ofYears(1), Period.ofYears(0), Period.ofYears(10), 10L) -> 2L,
735+
(Period.ofYears(1), Period.ofYears(0), Period.ofYears(1), 10L) -> 11L,
736+
(Period.ofMonths(Int.MaxValue), Period.ofYears(0), Period.ofYears(1), 10L) -> 11L,
737+
(Period.ofMonths(0), Period.ofMonths(Int.MinValue), Period.ofMonths(Int.MaxValue), 10L) -> 6L,
738+
(Period.ofMonths(-1), Period.ofMonths(Int.MinValue), Period.ofMonths(Int.MaxValue), 10L) -> 5L
739+
).foreach { case ((v, s, e, n), expected) =>
740+
checkEvaluation(WidthBucket(Literal(v), Literal(s), Literal(e), Literal(n)), expected)
741+
}
742+
}
728743
}

sql/core/src/test/resources/sql-tests/inputs/interval.sql

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,3 +382,5 @@ SELECT signum(INTERVAL '0-0' YEAR TO MONTH);
382382
SELECT signum(INTERVAL '-10' DAY);
383383
SELECT signum(INTERVAL '10' HOUR);
384384
SELECT signum(INTERVAL '0 0:0:0' DAY TO SECOND);
385+
SELECT width_bucket(INTERVAL '0' YEAR, INTERVAL '0' YEAR, INTERVAL '10' YEAR, 10);
386+
SELECT width_bucket(INTERVAL '-1' YEAR, INTERVAL -'1-2' YEAR TO MONTH, INTERVAL '1-2' YEAR TO MONTH, 10);

sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
-- Automatically generated by SQLQueryTestSuite
2-
-- Number of queries: 282
2+
-- Number of queries: 284
33

44

55
-- !query
@@ -2657,3 +2657,19 @@ SELECT signum(INTERVAL '0 0:0:0' DAY TO SECOND)
26572657
struct<SIGNUM(INTERVAL '0 00:00:00' DAY TO SECOND):double>
26582658
-- !query output
26592659
0.0
2660+
2661+
2662+
-- !query
2663+
SELECT width_bucket(INTERVAL '0' YEAR, INTERVAL '0' YEAR, INTERVAL '10' YEAR, 10)
2664+
-- !query schema
2665+
struct<width_bucket(INTERVAL '0' YEAR, INTERVAL '0' YEAR, INTERVAL '10' YEAR, 10):bigint>
2666+
-- !query output
2667+
1
2668+
2669+
2670+
-- !query
2671+
SELECT width_bucket(INTERVAL '-1' YEAR, INTERVAL -'1-2' YEAR TO MONTH, INTERVAL '1-2' YEAR TO MONTH, 10)
2672+
-- !query schema
2673+
struct<width_bucket(INTERVAL '-1' YEAR, INTERVAL '-1-2' YEAR TO MONTH, INTERVAL '1-2' YEAR TO MONTH, 10):bigint>
2674+
-- !query output
2675+
1

sql/core/src/test/resources/sql-tests/results/interval.sql.out

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
-- Automatically generated by SQLQueryTestSuite
2-
-- Number of queries: 282
2+
-- Number of queries: 284
33

44

55
-- !query
@@ -2646,3 +2646,19 @@ SELECT signum(INTERVAL '0 0:0:0' DAY TO SECOND)
26462646
struct<SIGNUM(INTERVAL '0 00:00:00' DAY TO SECOND):double>
26472647
-- !query output
26482648
0.0
2649+
2650+
2651+
-- !query
2652+
SELECT width_bucket(INTERVAL '0' YEAR, INTERVAL '0' YEAR, INTERVAL '10' YEAR, 10)
2653+
-- !query schema
2654+
struct<width_bucket(INTERVAL '0' YEAR, INTERVAL '0' YEAR, INTERVAL '10' YEAR, 10):bigint>
2655+
-- !query output
2656+
1
2657+
2658+
2659+
-- !query
2660+
SELECT width_bucket(INTERVAL '-1' YEAR, INTERVAL -'1-2' YEAR TO MONTH, INTERVAL '1-2' YEAR TO MONTH, 10)
2661+
-- !query schema
2662+
struct<width_bucket(INTERVAL '-1' YEAR, INTERVAL '-1-2' YEAR TO MONTH, INTERVAL '1-2' YEAR TO MONTH, 10):bigint>
2663+
-- !query output
2664+
1

sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.sql
1919

2020
import java.nio.charset.StandardCharsets
21+
import java.time.Period
2122

2223
import org.apache.spark.sql.functions._
2324
import org.apache.spark.sql.functions.{log => logarithm}
@@ -520,4 +521,20 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession {
520521
checkAnswer(df.selectExpr("positive(a)"), Row(1))
521522
checkAnswer(df.selectExpr("positive(b)"), Row(-1))
522523
}
524+
525+
test("SPARK-35926: Support YearMonthIntervalType in width-bucket function") {
526+
Seq(
527+
(Period.ofMonths(-1), Period.ofYears(0), Period.ofYears(10), 10) -> 0,
528+
(Period.ofMonths(0), Period.ofYears(0), Period.ofYears(10), 10) -> 1,
529+
(Period.ofMonths(13), Period.ofYears(0), Period.ofYears(10), 10) -> 2,
530+
(Period.ofYears(1), Period.ofYears(0), Period.ofYears(10), 10) -> 2,
531+
(Period.ofYears(1), Period.ofYears(0), Period.ofYears(1), 10) -> 11,
532+
(Period.ofMonths(Int.MaxValue), Period.ofYears(0), Period.ofYears(1), 10) -> 11,
533+
(Period.ofMonths(0), Period.ofMonths(Int.MinValue), Period.ofMonths(Int.MaxValue), 10) -> 6,
534+
(Period.ofMonths(-1), Period.ofMonths(Int.MinValue), Period.ofMonths(Int.MaxValue), 10) -> 5
535+
).foreach { case ((value, start, end, num), expected) =>
536+
val df = Seq((value, start, end, num)).toDF("v", "s", "e", "n")
537+
checkAnswer(df.selectExpr("width_bucket(v, s, e, n)"), Row(expected))
538+
}
539+
}
523540
}

0 commit comments

Comments
 (0)