Skip to content

Commit d2c9b00

Browse files
wForget and Junfan Zhang
authored and committed
[apache#1824] feat(spark): Support map side combine of shuffle writer (apache#1825)
Support map side combine of shuffle write. Fixes apache#1824. This adds new (opt-in) shuffle writer behavior. An integration test was added.
1 parent c77caea commit d2c9b00

File tree

6 files changed

+227
-195
lines changed

6 files changed

+227
-195
lines changed

client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,12 @@ public class RssSparkConfig {
9494
.defaultValue(1)
9595
.withDescription("The block retry max times when partition reassign is enabled.");
9696

97+
public static final ConfigOption<Boolean> RSS_CLIENT_MAP_SIDE_COMBINE_ENABLED =
98+
ConfigOptions.key("rss.client.mapSideCombine.enabled")
99+
.booleanType()
100+
.defaultValue(false)
101+
.withDescription("Whether to enable map side combine of shuffle writer.");
102+
97103
public static final String SPARK_RSS_CONFIG_PREFIX = "spark.";
98104

99105
public static final ConfigEntry<Integer> RSS_PARTITION_NUM_PER_RANGE =

client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
import scala.Function1;
4242
import scala.Option;
4343
import scala.Product2;
44+
import scala.Tuple2;
4445
import scala.collection.Iterator;
4546

4647
import com.google.common.annotations.VisibleForTesting;
@@ -95,6 +96,7 @@
9596
import org.apache.uniffle.common.rpc.StatusCode;
9697
import org.apache.uniffle.storage.util.StorageType;
9798

99+
import static org.apache.spark.shuffle.RssSparkConfig.RSS_CLIENT_MAP_SIDE_COMBINE_ENABLED;
98100
import static org.apache.spark.shuffle.RssSparkConfig.RSS_PARTITION_REASSIGN_BLOCK_RETRY_MAX_TIMES;
99101
import static org.apache.spark.shuffle.RssSparkConfig.RSS_TASK_FAILED_CALLBACK_ENABLED;
100102

@@ -337,25 +339,27 @@ private void reportTaskFailure(Exception exception) {
337339
protected void writeImpl(Iterator<Product2<K, V>> records) {
338340
List<ShuffleBlockInfo> shuffleBlockInfos;
339341
boolean isCombine = shuffleDependency.mapSideCombine();
340-
Function1<V, C> createCombiner = null;
342+
343+
Iterator<? extends Product2<K, ?>> iterator = records;
341344
if (isCombine) {
342-
createCombiner = shuffleDependency.aggregator().get().createCombiner();
345+
if (RssSparkConfig.toRssConf(sparkConf).get(RSS_CLIENT_MAP_SIDE_COMBINE_ENABLED)) {
346+
iterator = shuffleDependency.aggregator().get().combineValuesByKey(records, taskContext);
347+
} else {
348+
Function1<V, C> combiner = shuffleDependency.aggregator().get().createCombiner();
349+
iterator =
350+
records.map(
351+
(Function1<Product2<K, V>, Product2<K, C>>)
352+
x -> new Tuple2<>(x._1(), combiner.apply(x._2())));
353+
}
343354
}
344355
long recordCount = 0;
345-
while (records.hasNext()) {
356+
while (iterator.hasNext()) {
346357
recordCount++;
347-
348358
checkDataIfAnyFailure();
349-
350-
Product2<K, V> record = records.next();
359+
Product2<K, ?> record = iterator.next();
351360
K key = record._1();
352361
int partition = getPartition(key);
353-
if (isCombine) {
354-
Object c = createCombiner.apply(record._2());
355-
shuffleBlockInfos = bufferManager.addRecord(partition, record._1(), c);
356-
} else {
357-
shuffleBlockInfos = bufferManager.addRecord(partition, record._1(), record._2());
358-
}
362+
shuffleBlockInfos = bufferManager.addRecord(partition, record._1(), record._2());
359363
if (shuffleBlockInfos != null && !shuffleBlockInfos.isEmpty()) {
360364
processShuffleBlockInfos(shuffleBlockInfos);
361365
}

docs/client_guide/spark_client_guide.md

Lines changed: 0 additions & 152 deletions
This file was deleted.

integration-test/spark-common/src/test/java/org/apache/uniffle/test/WriteAndReadMetricsTest.java

Lines changed: 3 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,14 @@
2121
import java.util.List;
2222
import java.util.Map;
2323

24-
import org.apache.spark.executor.TaskMetrics;
25-
import org.apache.spark.scheduler.SparkListener;
26-
import org.apache.spark.scheduler.SparkListenerTaskEnd;
2724
import org.apache.spark.sql.Dataset;
2825
import org.apache.spark.sql.Row;
2926
import org.apache.spark.sql.SparkSession;
3027
import org.apache.spark.sql.functions;
3128
import org.junit.jupiter.api.Test;
3229

30+
import org.apache.uniffle.test.listener.WriteAndReadMetricsSparkListener;
31+
3332
public class WriteAndReadMetricsTest extends SimpleTestBase {
3433

3534
@Test
@@ -63,6 +62,7 @@ public Map<String, Long> runTest(SparkSession spark, String fileName) throws Exc
6362

6463
// take a rest to make sure all task metrics are updated before read stageData
6564
Thread.sleep(100);
65+
6666
for (int stageId : spark.sparkContext().statusTracker().getJobInfo(0).get().stageIds()) {
6767
long writeRecords = listener.getWriteRecords(stageId);
6868
long readRecords = listener.getReadRecords(stageId);
@@ -72,32 +72,4 @@ public Map<String, Long> runTest(SparkSession spark, String fileName) throws Exc
7272

7373
return result;
7474
}
75-
76-
private static class WriteAndReadMetricsSparkListener extends SparkListener {
77-
private HashMap<Integer, Long> stageIdToWriteRecords = new HashMap<>();
78-
private HashMap<Integer, Long> stageIdToReadRecords = new HashMap<>();
79-
80-
@Override
81-
public void onTaskEnd(SparkListenerTaskEnd event) {
82-
int stageId = event.stageId();
83-
TaskMetrics taskMetrics = event.taskMetrics();
84-
if (taskMetrics != null) {
85-
long writeRecords = taskMetrics.shuffleWriteMetrics().recordsWritten();
86-
long readRecords = taskMetrics.shuffleReadMetrics().recordsRead();
87-
// Accumulate writeRecords and readRecords for the given stageId
88-
stageIdToWriteRecords.put(
89-
stageId, stageIdToWriteRecords.getOrDefault(stageId, 0L) + writeRecords);
90-
stageIdToReadRecords.put(
91-
stageId, stageIdToReadRecords.getOrDefault(stageId, 0L) + readRecords);
92-
}
93-
}
94-
95-
public long getWriteRecords(int stageId) {
96-
return stageIdToWriteRecords.getOrDefault(stageId, 0L);
97-
}
98-
99-
public long getReadRecords(int stageId) {
100-
return stageIdToReadRecords.getOrDefault(stageId, 0L);
101-
}
102-
}
10375
}
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.uniffle.test.listener;
19+
20+
import java.util.HashMap;
21+
22+
import org.apache.spark.executor.TaskMetrics;
23+
import org.apache.spark.scheduler.SparkListener;
24+
import org.apache.spark.scheduler.SparkListenerTaskEnd;
25+
26+
public class WriteAndReadMetricsSparkListener extends SparkListener {
27+
private HashMap<Integer, Long> stageIdToWriteRecords = new HashMap<>();
28+
private HashMap<Integer, Long> stageIdToReadRecords = new HashMap<>();
29+
30+
@Override
31+
public void onTaskEnd(SparkListenerTaskEnd event) {
32+
int stageId = event.stageId();
33+
TaskMetrics taskMetrics = event.taskMetrics();
34+
if (taskMetrics != null) {
35+
long writeRecords = taskMetrics.shuffleWriteMetrics().recordsWritten();
36+
long readRecords = taskMetrics.shuffleReadMetrics().recordsRead();
37+
// Accumulate writeRecords and readRecords for the given stageId
38+
stageIdToWriteRecords.put(
39+
stageId, stageIdToWriteRecords.getOrDefault(stageId, 0L) + writeRecords);
40+
stageIdToReadRecords.put(
41+
stageId, stageIdToReadRecords.getOrDefault(stageId, 0L) + readRecords);
42+
}
43+
}
44+
45+
public long getWriteRecords(int stageId) {
46+
return stageIdToWriteRecords.getOrDefault(stageId, 0L);
47+
}
48+
49+
public long getReadRecords(int stageId) {
50+
return stageIdToReadRecords.getOrDefault(stageId, 0L);
51+
}
52+
}

0 commit comments

Comments
 (0)