Commit f4d10b8
[Data] Sample finalized partitions randomly to avoid lensing finalization on a single node (#58456)
## Description

Currently, finalization is scheduled in sequential batches, i.e. a batch of N adjacent partitions is finalized at once (in a sliding window). This creates a lensing effect since:

1. Adjacent partitions i and i+1 get scheduled onto adjacent aggregators j and j+1 (since membership is determined as j = i % num_aggregators).
2. Adjacent aggregators have a high likelihood of getting scheduled on the same node (due to being scheduled at about the same time, in sequence).

To address that, this change applies random sampling when choosing the next partitions to finalize, so that partitions are chosen uniformly, reducing concurrent finalization of adjacent partitions (see the sketch after this description).

---------

Signed-off-by: Alexey Kudinkin <[email protected]>
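To illustrate the effect described above, here is a minimal, self-contained sketch, not the actual Ray code: `num_partitions`, `num_aggregators`, `batch_size`, and `aggregator_for` are illustrative stand-ins. It contrasts the old sliding-window selection with the new random sampling:

```python
import random

# Illustrative values (hypothetical, not Ray's defaults).
num_partitions = 16
num_aggregators = 8
batch_size = 4

def aggregator_for(partition_id: int) -> int:
    # Membership is determined by modulo, so adjacent partitions land on
    # adjacent aggregators (which tend to be co-located on the same node).
    return partition_id % num_aggregators

# Old behavior: finalize the next N adjacent partitions (sliding window).
last_finalized_partition_id = -1
sequential_batch = list(
    range(last_finalized_partition_id + 1, last_finalized_partition_id + 1 + batch_size)
)
print([aggregator_for(p) for p in sequential_batch])  # e.g. [0, 1, 2, 3] -- adjacent

# New behavior: sample uniformly from the partitions still pending
# finalization, spreading the batch across aggregators (and hence nodes).
pending_partition_ids = set(range(num_partitions))
random_batch = random.sample(sorted(pending_partition_ids), batch_size)
print([aggregator_for(p) for p in random_batch])  # spread out uniformly
```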
1 parent a70a1b1 commit f4d10b8

File tree

2 files changed (+31 −29 lines)


python/ray/data/_internal/execution/operators/hash_shuffle.py

Lines changed: 26 additions & 20 deletions
```diff
@@ -3,6 +3,7 @@
 import itertools
 import logging
 import math
+import random
 import threading
 import time
 from collections import defaultdict, deque
@@ -16,6 +17,7 @@
     Dict,
     List,
     Optional,
+    Set,
     Tuple,
     Union,
 )
@@ -601,8 +603,10 @@ def __init__(
         # aggregators (keeps track which input sequences have already broadcasted
         # their schemas)
         self._has_schemas_broadcasted: DefaultDict[int, bool] = defaultdict(bool)
-        # Id of the last partition finalization of which had already been scheduled
-        self._last_finalized_partition_id: int = -1
+        # Set of partitions still pending finalization
+        self._pending_finalization_partition_ids: Set[int] = set(
+            range(target_num_partitions)
+        )
 
         self._output_queue: Deque[RefBundle] = deque()
 
@@ -823,11 +827,6 @@ def _try_finalize(self):
         if not self._is_shuffling_done():
             return
 
-        logger.debug(
-            f"Scheduling next shuffling finalization batch (last finalized "
-            f"partition id is {self._last_finalized_partition_id})"
-        )
-
         def _on_bundle_ready(partition_id: int, bundle: RefBundle):
             # Add finalized block to the output queue
             self._output_queue.append(bundle)
@@ -872,10 +871,8 @@ def _on_aggregation_done(partition_id: int, exc: Optional[Exception]):
             or self._aggregator_pool.num_aggregators
         )
 
-        num_remaining_partitions = (
-            self._num_partitions - 1 - self._last_finalized_partition_id
-        )
         num_running_finalizing_tasks = len(self._finalizing_tasks)
+        num_remaining_partitions = len(self._pending_finalization_partition_ids)
 
         # Finalization is executed in batches of no more than
         # `DataContext.max_hash_shuffle_finalization_batch_size` tasks at a time.
@@ -899,12 +896,21 @@ def _on_aggregation_done(partition_id: int, exc: Optional[Exception]):
         if next_batch_size == 0:
             return
 
-        # Next partition to be scheduled for finalization is the one right
-        # after the last one scheduled
-        next_partition_id = self._last_finalized_partition_id + 1
-
-        target_partition_ids = list(
-            range(next_partition_id, next_partition_id + next_batch_size)
+        # We sample the next set of partitions to finalize randomly,
+        # to distribute the finalization window uniformly across the nodes of the
+        # cluster and avoid a "sliding lens" effect where we finalize a batch of
+        # N *adjacent* partitions that may be co-located on the same node:
+        #
+        # - Adjacent partitions i and i+1 are handled by adjacent
+        #   aggregators (since membership is determined as i % num_aggregators)
+        #
+        # - Adjacent aggregators have a high likelihood of running on the
+        #   same node (when num aggregators > num nodes)
+        #
+        # NOTE: This doesn't affect determinism, since it only impacts the order
+        #   of finalization (hence it's not required to be seeded)
+        target_partition_ids = random.sample(
+            list(self._pending_finalization_partition_ids), next_batch_size
         )
 
         logger.debug(
@@ -941,15 +947,15 @@ def _on_aggregation_done(partition_id: int, exc: Optional[Exception]):
                 ),
             )
 
+            # Pop partition id from the remaining set
+            self._pending_finalization_partition_ids.remove(partition_id)
+
             # Update Finalize Metrics on task submission
             # NOTE: This is empty because the input is directly forwarded from the
             # output of the shuffling stage, which we don't return.
             empty_bundle = RefBundle([], schema=None, owns_blocks=False)
             self.reduce_metrics.on_task_submitted(partition_id, empty_bundle)
 
-        # Update last finalized partition id
-        self._last_finalized_partition_id = max(target_partition_ids)
-
     def _do_shutdown(self, force: bool = False) -> None:
         self._aggregator_pool.shutdown(force=True)
         # NOTE: It's critical for Actor Pool to release actors before calling into
@@ -1021,7 +1027,7 @@ def implements_accurate_memory_accounting(self) -> bool:
         return True
 
     def _is_finalized(self):
-        return self._last_finalized_partition_id == self._num_partitions - 1
+        return len(self._pending_finalization_partition_ids) == 0
 
     def _handle_shuffled_block_metadata(
         self,
```
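A minimal sketch of the bookkeeping change above, under the assumption that the surrounding scheduling loop works roughly as the diff suggests (`FinalizationTracker` is a hypothetical stand-in, not a class from the Ray codebase): the pending set replaces the last-finalized counter, so completion becomes "the set is empty" rather than "the counter reached num_partitions - 1".

```python
import random

class FinalizationTracker:
    """Hypothetical stand-in for the operator's new pending-set bookkeeping."""

    def __init__(self, num_partitions: int, max_batch_size: int):
        # Set of partitions still pending finalization (replaces the old
        # `_last_finalized_partition_id` counter).
        self._pending: set[int] = set(range(num_partitions))
        self._max_batch_size = max_batch_size

    def next_batch(self) -> list[int]:
        # Sample uniformly from the pending set; `random.sample` needs a
        # sequence, hence the `sorted(...)` conversion.
        batch_size = min(self._max_batch_size, len(self._pending))
        batch = random.sample(sorted(self._pending), batch_size)
        for partition_id in batch:
            # Pop each partition id from the remaining set on task submission.
            self._pending.remove(partition_id)
        return batch

    def is_finalized(self) -> bool:
        # Done once no partitions remain pending.
        return len(self._pending) == 0

tracker = FinalizationTracker(num_partitions=16, max_batch_size=4)
while not tracker.is_finalized():
    print(tracker.next_batch())
```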

python/ray/data/_internal/execution/operators/join.py

Lines changed: 5 additions & 9 deletions
```diff
@@ -166,6 +166,7 @@ def _preprocess(
         left_seq_partition: pa.Table = self._get_partition_builder(
             input_seq_id=0, partition_id=partition_id
         ).build()
+
         right_seq_partition: pa.Table = self._get_partition_builder(
             input_seq_id=1, partition_id=partition_id
         ).build()
@@ -198,7 +199,6 @@ def _preprocess(
         should_index_r = self._should_index_side("right", supported_r, unsupported_r)
 
         # Add index columns for back-referencing if we have unsupported columns
-        # TODO: what are the chances of a collision with the index column?
         if should_index_l:
             supported_l = self._append_index_column(
                 table=supported_l, col_name=self._index_name("left")
@@ -246,7 +246,7 @@ def _postprocess(
         return supported
 
     def _index_name(self, suffix: str) -> str:
-        return f"__ray_data_index_level_{suffix}__"
+        return f"__rd_index_level_{suffix}__"
 
     def clear(self, partition_id: int):
         self._left_input_seq_partition_builders.pop(partition_id)
@@ -263,9 +263,6 @@ def _get_partition_builder(self, *, input_seq_id: int, partition_id: int):
         )
         return partition_builder
 
-    def _get_index_col_name(self, index: int) -> str:
-        return f"__index_level_{index}__"
-
     def _should_index_side(
         self, side: str, supported_table: "pa.Table", unsupported_table: "pa.Table"
     ) -> bool:
@@ -318,9 +315,8 @@ def _split_unsupported_columns(
         """
         supported, unsupported = [], []
         for idx in range(len(table.columns)):
-            column: "pa.ChunkedArray" = table.column(idx)
-
-            col_type = column.type
+            col: "pa.ChunkedArray" = table.column(idx)
+            col_type: "pa.DataType" = col.type
 
             if _is_pa_extension_type(col_type) or self._is_pa_join_not_supported(
                 col_type
@@ -329,7 +325,7 @@ def _split_unsupported_columns(
             else:
                 supported.append(idx)
 
-        return (table.select(supported), table.select(unsupported))
+        return table.select(supported), table.select(unsupported)
 
     def _add_back_unsupported_columns(
         self,
```
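The join.py changes touch `_split_unsupported_columns`, which partitions a table's columns by index into join-supported and join-unsupported subsets. A minimal sketch of that idea, assuming pyarrow is installed and simplifying "unsupported" to just Arrow extension types (the real operator also applies its own `_is_pa_join_not_supported` check):

```python
import pyarrow as pa

def split_unsupported_columns(table: pa.Table) -> tuple[pa.Table, pa.Table]:
    # Partition column indices by whether the column's type can participate
    # in a native Arrow join.
    supported, unsupported = [], []
    for idx in range(table.num_columns):
        col: pa.ChunkedArray = table.column(idx)
        col_type: pa.DataType = col.type
        # Simplification: treat only extension types as unsupported here.
        if isinstance(col_type, pa.ExtensionType):
            unsupported.append(idx)
        else:
            supported.append(idx)
    return table.select(supported), table.select(unsupported)

t = pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]})
sup, unsup = split_unsupported_columns(t)
print(sup.column_names, unsup.column_names)  # ['id', 'name'] []
```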
