Skip to content

Commit c94ce2c

Browse files
committed
[SPARK-55302][SQL] Fix custom metrics in case of KeyGroupedPartitioning
### What changes were proposed in this pull request?

This PR adds a new `initMetricsValues()` method to `PartitionReader` so as to initialize the custom metrics returned by `currentMetricsValues()`. In the case of `KeyGroupedPartitioning`, multiple input partitions are grouped together, so multiple `PartitionReader`s belong to one output partition. A `PartitionReader` needs to be initialized with the metrics calculated by the previous `PartitionReader` of the same partition group so that the final metric values are computed correctly.

### Why are the changes needed?

To calculate custom metrics correctly.

### Does this PR introduce _any_ user-facing change?

It fixes metrics calculation.

### How was this patch tested?

New UT is added.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#54081 from peter-toth/SPARK-55302-fix-kgp-custom-metrics.

Authored-by: Peter Toth <peter.toth@gmail.com>
Signed-off-by: Peter Toth <peter.toth@gmail.com>
1 parent 15c6849 commit c94ce2c

File tree

6 files changed

+95
-28
lines changed

6 files changed

+95
-28
lines changed

sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,4 +58,13 @@ default CustomTaskMetric[] currentMetricsValues() {
5858
CustomTaskMetric[] NO_METRICS = {};
5959
return NO_METRICS;
6060
}
61+
62+
/**
63+
* Sets the initial value of metrics before fetching any data from the reader. This is called
64+
* when multiple {@link PartitionReader}s are grouped into one partition in case of
65+
* {@link org.apache.spark.sql.connector.read.partitioning.KeyGroupedPartitioning} and the reader
66+
* is initialized with the metrics returned by the previous reader that belongs to the same
67+
* partition. By default, this method does nothing.
68+
*/
69+
default void initMetricsValues(CustomTaskMetric[] metrics) {}
6170
}

sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,10 @@ abstract class InMemoryBaseTable(
543543
}
544544
new BufferedRowsReaderFactory(metadataColumns.toSeq, nonMetadataColumns, tableSchema)
545545
}
546+
547+
override def supportedCustomMetrics(): Array[CustomMetric] = {
548+
Array(new RowsReadCustomMetric)
549+
}
546550
}
547551

548552
case class InMemoryBatchScan(
@@ -830,10 +834,13 @@ private class BufferedRowsReader(
830834
}
831835

832836
private var index: Int = -1
837+
private var rowsRead: Long = 0
833838

834839
override def next(): Boolean = {
835840
index += 1
836-
index < partition.rows.length
841+
val hasNext = index < partition.rows.length
842+
if (hasNext) rowsRead += 1
843+
hasNext
837844
}
838845

839846
override def get(): InternalRow = {
@@ -976,6 +983,22 @@ private class BufferedRowsReader(
976983

977984
private def castElement(elem: Any, toType: DataType, fromType: DataType): Any =
978985
Cast(Literal(elem, fromType), toType, None, EvalMode.TRY).eval(null)
986+
987+
override def initMetricsValues(metrics: Array[CustomTaskMetric]): Unit = {
988+
metrics.foreach { m =>
989+
m.name match {
990+
case "rows_read" => rowsRead = m.value()
991+
}
992+
}
993+
}
994+
995+
override def currentMetricsValues(): Array[CustomTaskMetric] = {
996+
val metric = new CustomTaskMetric {
997+
override def name(): String = "rows_read"
998+
override def value(): Long = rowsRead
999+
}
1000+
Array(metric)
1001+
}
9791002
}
9801003

9811004
private class BufferedRowsWriterFactory(schema: StructType)
@@ -1044,6 +1067,11 @@ class InMemoryCustomDriverTaskMetric(value: Long) extends CustomTaskMetric {
10441067
override def value(): Long = value
10451068
}
10461069

1070+
class RowsReadCustomMetric extends CustomSumMetric {
1071+
override def name(): String = "rows_read"
1072+
override def description(): String = "number of rows read"
1073+
}
1074+
10471075
case class Commit(id: Long, writeSummary: Option[WriteSummary] = None)
10481076

10491077
sealed trait Operation

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
2424
import org.apache.spark.internal.Logging
2525
import org.apache.spark.rdd.RDD
2626
import org.apache.spark.sql.catalyst.InternalRow
27+
import org.apache.spark.sql.connector.metric.CustomTaskMetric
2728
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
2829
import org.apache.spark.sql.errors.QueryExecutionErrors
2930
import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric}
@@ -97,7 +98,8 @@ class DataSourceRDD(
9798
}
9899

99100
// Once we advance to the next partition, update the metric callback for early finish
100-
partitionMetricCallback.advancePartition(iter, reader)
101+
val previousMetrics = partitionMetricCallback.advancePartition(iter, reader)
102+
previousMetrics.foreach(reader.initMetricsValues)
101103

102104
currentIter = Some(iter)
103105
hasNext
@@ -118,19 +120,26 @@ private class PartitionMetricCallback
118120
private var iter: MetricsIterator[_] = null
119121
private var reader: PartitionReader[_] = null
120122

121-
def advancePartition(iter: MetricsIterator[_], reader: PartitionReader[_]): Unit = {
122-
execute()
123+
def advancePartition(
124+
iter: MetricsIterator[_],
125+
reader: PartitionReader[_]): Option[Array[CustomTaskMetric]] = {
126+
val metrics = execute()
123127

124128
this.iter = iter
125129
this.reader = reader
130+
131+
metrics
126132
}
127133

128-
def execute(): Unit = {
134+
def execute(): Option[Array[CustomTaskMetric]] = {
129135
if (iter != null && reader != null) {
130-
CustomMetrics
131-
.updateMetrics(reader.currentMetricsValues.toImmutableArraySeq, customMetrics)
136+
val metrics = reader.currentMetricsValues
137+
CustomMetrics.updateMetrics(metrics.toImmutableArraySeq, customMetrics)
132138
iter.forceUpdateMetrics()
133139
reader.close()
140+
Some(metrics)
141+
} else {
142+
None
134143
}
135144
}
136145
}

sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2823,4 +2823,22 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
28232823
checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0)))
28242824
}
28252825
}
2826+
2827+
test("SPARK-55302: Custom metrics of grouped partitions") {
2828+
val items_partitions = Array(identity("id"))
2829+
createTable(items, itemsColumns, items_partitions)
2830+
2831+
sql(s"INSERT INTO testcat.ns.$items VALUES " +
2832+
"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
2833+
"(4, 'bb', 10.0, cast('2021-01-01' as timestamp)), " +
2834+
"(4, 'cc', 15.5, cast('2021-02-01' as timestamp))")
2835+
2836+
val metrics = runAndFetchMetrics {
2837+
val df = sql(s"SELECT * FROM testcat.ns.$items")
2838+
val scans = collectScans(df.queryExecution.executedPlan)
2839+
assert(scans(0).inputRDD.partitions.length === 2, "items scan should have 2 partition groups")
2840+
df.collect()
2841+
}
2842+
assert(metrics("number of rows read") == "3")
2843+
}
28262844
}

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/InMemoryTableMetricSuite.scala

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.datasources
1919
import java.util.Collections
2020

2121
import org.scalatest.BeforeAndAfter
22-
import org.scalatest.time.SpanSugar._
2322

2423
import org.apache.spark.sql.QueryTest
2524
import org.apache.spark.sql.connector.catalog.{Column, Identifier, InMemoryTable, InMemoryTableCatalog}
@@ -54,27 +53,8 @@ class InMemoryTableMetricSuite
5453
Array(Column.create("i", IntegerType)),
5554
Array.empty[Transform], Collections.emptyMap[String, String])
5655

57-
func("testcat.table_name")
56+
val metrics = runAndFetchMetrics(func("testcat.table_name"))
5857

59-
// Wait until the new execution is started and being tracked.
60-
eventually(timeout(10.seconds), interval(10.milliseconds)) {
61-
assert(statusStore.executionsCount() >= oldCount)
62-
}
63-
64-
// Wait for listener to finish computing the metrics for the execution.
65-
eventually(timeout(10.seconds), interval(10.milliseconds)) {
66-
assert(statusStore.executionsList().nonEmpty &&
67-
statusStore.executionsList().last.metricValues != null)
68-
}
69-
70-
val exec = statusStore.executionsList().last
71-
val execId = exec.executionId
72-
val sqlMetrics = exec.metrics.map { metric =>
73-
metric.accumulatorId -> metric.name
74-
}.toMap
75-
val metrics = statusStore.executionMetrics(execId).map { case (k, v) =>
76-
sqlMetrics(k) -> v
77-
}
7858
checker(metrics)
7959
}
8060
}

sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,29 @@ trait SharedSparkSession extends SQLTestUtils with SharedSparkSessionBase {
5454
doThreadPostAudit()
5555
}
5656
}
57+
58+
def runAndFetchMetrics(func: => Unit): Map[String, String] = {
59+
val statusStore = spark.sharedState.statusStore
60+
val oldCount = statusStore.executionsList().size
61+
62+
func
63+
64+
// Wait until the new execution is started and being tracked.
65+
eventually(timeout(10.seconds), interval(10.milliseconds)) {
66+
assert(statusStore.executionsCount() >= oldCount)
67+
}
68+
69+
// Wait for listener to finish computing the metrics for the execution.
70+
eventually(timeout(10.seconds), interval(10.milliseconds)) {
71+
assert(statusStore.executionsList().nonEmpty &&
72+
statusStore.executionsList().last.metricValues != null)
73+
}
74+
75+
val exec = statusStore.executionsList().last
76+
val execId = exec.executionId
77+
val sqlMetrics = exec.metrics.map { metric => metric.accumulatorId -> metric.name }.toMap
78+
statusStore.executionMetrics(execId).map { case (k, v) => sqlMetrics(k) -> v }
79+
}
5780
}
5881

5982
/**

0 commit comments

Comments
 (0)