1616import com .facebook .airlift .log .Logger ;
1717import com .facebook .presto .execution .NodeTaskMap ;
1818import com .facebook .presto .execution .RemoteTask ;
19+ import com .facebook .presto .execution .TaskStatus ;
1920import com .facebook .presto .execution .scheduler .BucketNodeMap ;
2021import com .facebook .presto .execution .scheduler .InternalNodeInfo ;
2122import com .facebook .presto .execution .scheduler .NodeAssignmentStats ;
3536import com .google .common .collect .Multimap ;
3637import com .google .common .util .concurrent .ListenableFuture ;
3738
39+ import java .util .HashMap ;
3840import java .util .HashSet ;
3941import java .util .List ;
42+ import java .util .Map ;
4043import java .util .Objects ;
4144import java .util .Optional ;
4245import java .util .OptionalInt ;
@@ -71,9 +74,11 @@ public class SimpleNodeSelector
7174 private final NodeSelectionStats nodeSelectionStats ;
7275 private final NodeTaskMap nodeTaskMap ;
7376 private final boolean includeCoordinator ;
77+ private final boolean scheduleSplitsBasedOnTaskLoad ;
7478 private final AtomicReference <Supplier <NodeMap >> nodeMap ;
7579 private final int minCandidates ;
7680 private final long maxSplitsWeightPerNode ;
81+ private final long maxSplitsWeightPerTask ;
7782 private final long maxPendingSplitsWeightPerTask ;
7883 private final int maxUnacknowledgedSplitsPerTask ;
7984 private final int maxTasksPerStage ;
@@ -84,9 +89,11 @@ public SimpleNodeSelector(
8489 NodeSelectionStats nodeSelectionStats ,
8590 NodeTaskMap nodeTaskMap ,
8691 boolean includeCoordinator ,
92+ boolean scheduleSplitsBasedOnTaskLoad ,
8793 Supplier <NodeMap > nodeMap ,
8894 int minCandidates ,
8995 long maxSplitsWeightPerNode ,
96+ long maxSplitsWeightPerTask ,
9097 long maxPendingSplitsWeightPerTask ,
9198 int maxUnacknowledgedSplitsPerTask ,
9299 int maxTasksPerStage ,
@@ -96,9 +103,11 @@ public SimpleNodeSelector(
96103 this .nodeSelectionStats = requireNonNull (nodeSelectionStats , "nodeSelectionStats is null" );
97104 this .nodeTaskMap = requireNonNull (nodeTaskMap , "nodeTaskMap is null" );
98105 this .includeCoordinator = includeCoordinator ;
106+ this .scheduleSplitsBasedOnTaskLoad = scheduleSplitsBasedOnTaskLoad ;
99107 this .nodeMap = new AtomicReference <>(nodeMap );
100108 this .minCandidates = minCandidates ;
101109 this .maxSplitsWeightPerNode = maxSplitsWeightPerNode ;
110+ this .maxSplitsWeightPerTask = maxSplitsWeightPerTask ;
102111 this .maxPendingSplitsWeightPerTask = maxPendingSplitsWeightPerTask ;
103112 this .maxUnacknowledgedSplitsPerTask = maxUnacknowledgedSplitsPerTask ;
104113 checkArgument (maxUnacknowledgedSplitsPerTask > 0 , "maxUnacknowledgedSplitsPerTask must be > 0, found: %s" , maxUnacknowledgedSplitsPerTask );
@@ -149,6 +158,11 @@ public SplitPlacementResult computeAssignments(Set<Split> splits, List<RemoteTas
149158 Set <InternalNode > blockedExactNodes = new HashSet <>();
150159 boolean splitWaitingForAnyNode = false ;
151160
161+ Optional <ToLongFunction <InternalNode >> taskLoadSplitWeightProvider = Optional .empty ();
162+ if (this .scheduleSplitsBasedOnTaskLoad ) {
163+ taskLoadSplitWeightProvider = Optional .of (createTaskLoadSplitWeightProvider (existingTasks , assignmentStats ));
164+ }
165+
152166 NodeProvider nodeProvider = nodeMap .getNodeProvider (maxPreferredNodes );
153167 OptionalInt preferredNodeCount = OptionalInt .empty ();
154168 for (Split split : splits ) {
@@ -179,9 +193,16 @@ public SplitPlacementResult computeAssignments(Set<Split> splits, List<RemoteTas
179193 }
180194
181195 SplitWeight splitWeight = split .getSplitWeight ();
182- Optional <InternalNodeInfo > chosenNodeInfo = chooseLeastBusyNode (splitWeight , candidateNodes , assignmentStats ::getTotalSplitsWeight , preferredNodeCount , maxSplitsWeightPerNode , assignmentStats );
183- if (!chosenNodeInfo .isPresent ()) {
184- chosenNodeInfo = chooseLeastBusyNode (splitWeight , candidateNodes , assignmentStats ::getQueuedSplitsWeightForStage , preferredNodeCount , maxPendingSplitsWeightPerTask , assignmentStats );
196+ Optional <InternalNodeInfo > chosenNodeInfo = Optional .empty ();
197+
198+ if (taskLoadSplitWeightProvider .isPresent ()) {
199+ chosenNodeInfo = chooseLeastBusyNode (splitWeight , candidateNodes , taskLoadSplitWeightProvider .get (), preferredNodeCount , maxSplitsWeightPerTask , assignmentStats );
200+ }
201+ else {
202+ chosenNodeInfo = chooseLeastBusyNode (splitWeight , candidateNodes , assignmentStats ::getTotalSplitsWeight , preferredNodeCount , maxSplitsWeightPerNode , assignmentStats );
203+ if (!chosenNodeInfo .isPresent ()) {
204+ chosenNodeInfo = chooseLeastBusyNode (splitWeight , candidateNodes , assignmentStats ::getQueuedSplitsWeightForStage , preferredNodeCount , maxPendingSplitsWeightPerTask , assignmentStats );
205+ }
185206 }
186207
187208 if (chosenNodeInfo .isPresent ()) {
@@ -223,6 +244,28 @@ public SplitPlacementResult computeAssignments(Set<Split> splits, List<RemoteTas
223244 return selectDistributionNodes (nodeMap .get ().get (), nodeTaskMap , maxSplitsWeightPerNode , maxPendingSplitsWeightPerTask , maxUnacknowledgedSplitsPerTask , splits , existingTasks , bucketNodeMap , nodeSelectionStats );
224245 }
225246
247+ private ToLongFunction <InternalNode > createTaskLoadSplitWeightProvider (List <RemoteTask > existingTasks , NodeAssignmentStats assignmentStats )
248+ {
249+ // Create a map from nodeId to RemoteTask for efficient lookup
250+ Map <String , RemoteTask > tasksByNodeId = new HashMap <>();
251+ for (RemoteTask task : existingTasks ) {
252+ tasksByNodeId .put (task .getNodeId (), task );
253+ }
254+
255+ return node -> {
256+ RemoteTask remoteTask = tasksByNodeId .get (node .getNodeIdentifier ());
257+ if (remoteTask == null ) {
258+ // No task for this node, return only the queued splits weight for the stage
259+ return assignmentStats .getAssignedSplitsWeightForStage (node );
260+ }
261+
262+ TaskStatus taskStatus = remoteTask .getTaskStatus ();
263+ return taskStatus .getQueuedPartitionedSplitsWeight () +
264+ taskStatus .getRunningPartitionedSplitsWeight () +
265+ assignmentStats .getAssignedSplitsWeightForStage (node );
266+ };
267+ }
268+
226269 protected Optional <InternalNodeInfo > chooseLeastBusyNode (SplitWeight splitWeight , List <InternalNode > candidateNodes , ToLongFunction <InternalNode > splitWeightProvider , OptionalInt preferredNodeCount , long maxSplitsWeight , NodeAssignmentStats assignmentStats )
227270 {
228271 long minWeight = Long .MAX_VALUE ;
0 commit comments