
Commit df5577f

Support gang scheduling with Apache YuniKorn (#2396)
1 parent 6786350 commit df5577f
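
For orientation, the change is driven entirely by RayCluster labels: the "yunikorn.apache.org/app-id" and "yunikorn.apache.org/queue" labels are translated onto every pod, and the presence of a gang-scheduling label switches on the task-group annotations added below. The following is a minimal sketch of such a label set; the literal value behind utils.RayClusterGangSchedulingEnabled is not shown in this diff, so the "ray.io/gang-scheduling-enabled" key used here is an assumption.

package main

import "fmt"

func main() {
    // Sketch of the RayCluster labels that opt a cluster into YuniKorn gang scheduling.
    // "ray.io/gang-scheduling-enabled" is an assumed literal for utils.RayClusterGangSchedulingEnabled;
    // the scheduler only checks that the key exists, not its value.
    labels := map[string]string{
        "yunikorn.apache.org/app-id":     "job-1-01234",  // copied to the pod label "applicationId"
        "yunikorn.apache.org/queue":      "root.default", // copied to the pod label "queue"
        "ray.io/gang-scheduling-enabled": "true",         // any value (even empty) enables gang scheduling
    }
    fmt.Println(labels)
}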

File tree: 5 files changed, +363 -16 lines

ray-operator/controllers/ray/batchscheduler/yunikorn/yunikorn_scheduler.go

Lines changed: 57 additions & 7 deletions
@@ -12,14 +12,17 @@ import (
 
 	rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
 	schedulerinterface "github.com/ray-project/kuberay/ray-operator/controllers/ray/batchscheduler/interface"
+	"github.com/ray-project/kuberay/ray-operator/controllers/ray/utils"
 )
 
 const (
-	SchedulerName                     string = "yunikorn"
-	YuniKornPodApplicationIDLabelName string = "applicationId"
-	YuniKornPodQueueLabelName         string = "queue"
-	RayClusterApplicationIDLabelName  string = "yunikorn.apache.org/application-id"
-	RayClusterQueueLabelName          string = "yunikorn.apache.org/queue-name"
+	SchedulerName                       string = "yunikorn"
+	YuniKornPodApplicationIDLabelName   string = "applicationId"
+	YuniKornPodQueueLabelName           string = "queue"
+	RayClusterApplicationIDLabelName    string = "yunikorn.apache.org/app-id"
+	RayClusterQueueLabelName            string = "yunikorn.apache.org/queue"
+	YuniKornTaskGroupNameAnnotationName string = "yunikorn.apache.org/task-group-name"
+	YuniKornTaskGroupsAnnotationName    string = "yunikorn.apache.org/task-groups"
 )
 
 type YuniKornScheduler struct {
@@ -42,19 +45,66 @@ func (y *YuniKornScheduler) DoBatchSchedulingOnSubmission(_ context.Context, _ *
 	return nil
 }
 
+// populatePodLabels is a helper function that copies the RayCluster's label to the given pod based on the label key.
+// TODO: remove the legacy labels, i.e. "applicationId" and "queue", and directly populate the labels
+// RayClusterApplicationIDLabelName and RayClusterQueueLabelName to pod labels.
+// Currently we use this function to translate the labels "yunikorn.apache.org/app-id" and "yunikorn.apache.org/queue"
+// to the legacy labels "applicationId" and "queue", for better compatibility with older yunikorn versions.
 func (y *YuniKornScheduler) populatePodLabels(app *rayv1.RayCluster, pod *corev1.Pod, sourceKey string, targetKey string) {
 	// check labels
 	if value, exist := app.Labels[sourceKey]; exist {
-		y.log.Info("Updating pod label based on RayCluster annotations",
+		y.log.Info("Updating pod label based on RayCluster labels",
 			"sourceKey", sourceKey, "targetKey", targetKey, "value", value)
 		pod.Labels[targetKey] = value
 	}
 }
 
-func (y *YuniKornScheduler) AddMetadataToPod(app *rayv1.RayCluster, _ string, pod *corev1.Pod) {
+// AddMetadataToPod adds essential labels and annotations to the Ray pods;
+// the yunikorn scheduler needs these labels and annotations in order to do the scheduling properly.
+func (y *YuniKornScheduler) AddMetadataToPod(app *rayv1.RayCluster, groupName string, pod *corev1.Pod) {
+	// the applicationID and queue name must be provided in the labels
 	y.populatePodLabels(app, pod, RayClusterApplicationIDLabelName, YuniKornPodApplicationIDLabelName)
 	y.populatePodLabels(app, pod, RayClusterQueueLabelName, YuniKornPodQueueLabelName)
 	pod.Spec.SchedulerName = y.Name()
+
+	// when gang scheduling is enabled, extra annotations need to be added to all pods
+	if y.isGangSchedulingEnabled(app) {
+		// populate the taskGroups info to each pod
+		y.populateTaskGroupsAnnotationToPod(app, pod)
+
+		// set the task group name based on the head or worker group name;
+		// the group name for the head and each of the worker groups should be different
+		pod.Annotations[YuniKornTaskGroupNameAnnotationName] = groupName
+	}
+}
+
+func (y *YuniKornScheduler) isGangSchedulingEnabled(app *rayv1.RayCluster) bool {
+	_, exist := app.Labels[utils.RayClusterGangSchedulingEnabled]
+	return exist
+}
+
+func (y *YuniKornScheduler) populateTaskGroupsAnnotationToPod(app *rayv1.RayCluster, pod *corev1.Pod) {
+	taskGroups := newTaskGroupsFromApp(app)
+	taskGroupsAnnotationValue, err := taskGroups.marshal()
+	if err != nil {
+		y.log.Error(err, "failed to add gang scheduling related annotations to pod, "+
+			"gang scheduling will not be enabled for this workload",
+			"rayCluster", app.Name, "name", pod.Name, "namespace", pod.Namespace)
+		return
+	}
+
+	y.log.Info("add task groups info to pod's annotation",
+		"key", YuniKornTaskGroupsAnnotationName,
+		"value", taskGroupsAnnotationValue,
+		"numOfTaskGroups", taskGroups.size())
+	if pod.Annotations == nil {
+		pod.Annotations = make(map[string]string)
+	}
+	pod.Annotations[YuniKornTaskGroupsAnnotationName] = taskGroupsAnnotationValue
+
+	y.log.Info("Gang Scheduling enabled for RayCluster",
+		"RayCluster", app.Name, "Namespace", app.Namespace)
 }
 
 func (yf *YuniKornSchedulerFactory) New(_ *rest.Config) (schedulerinterface.BatchScheduler, error) {
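
The scheduler above relies on task-group helpers (TaskGroup, newTaskGroups, newTaskGroupsFromApp, marshal, unmarshalFrom, size, getTaskGroup) that are defined in the remaining changed files, which this page does not show. The following is a rough sketch reconstructed only from the call sites in this diff and the fields the tests read back (Name, MinMember, MinResource); it is an assumption about their shape, not the committed implementation.

// Sketch only: a plausible shape for the task-group helpers referenced above,
// inferred from their call sites in this commit. The real definitions live in
// the other changed files, which are not shown on this page.
package yunikorn

import (
    "encoding/json"

    "k8s.io/apimachinery/pkg/api/resource"

    rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
    "github.com/ray-project/kuberay/ray-operator/controllers/ray/utils"
)

// TaskGroup carries the fields the tests read back from the annotation.
type TaskGroup struct {
    Name        string                       `json:"name"`
    MinMember   int32                        `json:"minMember"`
    MinResource map[string]resource.Quantity `json:"minResource"`
}

// TaskGroups wraps the list so the helper methods can hang off it.
type TaskGroups struct {
    Groups []TaskGroup `json:"groups"`
}

func newTaskGroups() *TaskGroups { return &TaskGroups{Groups: []TaskGroup{}} }

// newTaskGroupsFromApp builds one group for the head pod and one per worker group,
// summing the container resource requests of each pod template.
func newTaskGroupsFromApp(app *rayv1.RayCluster) *TaskGroups {
    groups := newTaskGroups()

    headResources := map[string]resource.Quantity{}
    for _, c := range app.Spec.HeadGroupSpec.Template.Spec.Containers {
        for name, quantity := range c.Resources.Requests {
            q := headResources[name.String()]
            q.Add(quantity)
            headResources[name.String()] = q
        }
    }
    groups.Groups = append(groups.Groups, TaskGroup{
        Name:        utils.RayNodeHeadGroupLabelValue,
        MinMember:   1,
        MinResource: headResources,
    })

    for _, workerGroup := range app.Spec.WorkerGroupSpecs {
        workerResources := map[string]resource.Quantity{}
        for _, c := range workerGroup.Template.Spec.Containers {
            for name, quantity := range c.Resources.Requests {
                q := workerResources[name.String()]
                q.Add(quantity)
                workerResources[name.String()] = q
            }
        }
        minMember := int32(1)
        if workerGroup.MinReplicas != nil {
            minMember = *workerGroup.MinReplicas // assumption: MinReplicas drives the gang size
        }
        groups.Groups = append(groups.Groups, TaskGroup{
            Name:        workerGroup.GroupName,
            MinMember:   minMember,
            MinResource: workerResources,
        })
    }
    return groups
}

// marshal serializes only the slice, matching how the tests json.Unmarshal the
// annotation value directly into []TaskGroup.
func (t *TaskGroups) marshal() (string, error) {
    out, err := json.Marshal(t.Groups)
    return string(out), err
}

func (t *TaskGroups) unmarshalFrom(spec string) error {
    return json.Unmarshal([]byte(spec), &t.Groups)
}

func (t *TaskGroups) size() int { return len(t.Groups) }

func (t *TaskGroups) getTaskGroup(name string) *TaskGroup {
    for i := range t.Groups {
        if t.Groups[i].Name == name {
            return &t.Groups[i]
        }
    }
    return nil
}

The one detail the tests do pin down is that the annotation value is a plain JSON array of task groups, which is why the sketch serializes the slice rather than the wrapper struct.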

ray-operator/controllers/ray/batchscheduler/yunikorn/yunikorn_scheduler_test.go

Lines changed: 190 additions & 1 deletion
@@ -1,10 +1,15 @@
 package yunikorn
 
 import (
+	"encoding/json"
+	"fmt"
 	"testing"
 
+	"github.com/ray-project/kuberay/ray-operator/controllers/ray/utils"
+
 	"github.com/stretchr/testify/assert"
 	v1 "k8s.io/api/core/v1"
+	"k8s.io/apimachinery/pkg/api/resource"
 	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
 
 	rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
@@ -41,7 +46,7 @@ func TestPopulatePodLabels(t *testing.T) {
 
 	rayCluster2 := createRayClusterWithLabels(
 		"ray-cluster-without-labels",
-		"test",
+		"test1",
 		nil, // empty labels
 	)
 	rayPod3 := createPod("my-pod-2", "test")
@@ -51,6 +56,114 @@
 	assert.Equal(t, podLabelsContains(rayPod3, YuniKornPodQueueLabelName, queue2), false)
 }
 
+func TestIsGangSchedulingEnabled(t *testing.T) {
+	yk := &YuniKornScheduler{}
+
+	job1 := "job-1-01234"
+	queue1 := "root.default"
+	rayCluster1 := createRayClusterWithLabels(
+		"ray-cluster-with-gang-scheduling",
+		"test1",
+		map[string]string{
+			RayClusterApplicationIDLabelName:      job1,
+			RayClusterQueueLabelName:              queue1,
+			utils.RayClusterGangSchedulingEnabled: "true",
+		},
+	)
+
+	assert.Equal(t, yk.isGangSchedulingEnabled(rayCluster1), true)
+
+	rayCluster2 := createRayClusterWithLabels(
+		"ray-cluster-with-gang-scheduling",
+		"test2",
+		map[string]string{
+			RayClusterApplicationIDLabelName:      job1,
+			RayClusterQueueLabelName:              queue1,
+			utils.RayClusterGangSchedulingEnabled: "",
+		},
+	)
+
+	assert.Equal(t, yk.isGangSchedulingEnabled(rayCluster2), true)
+
+	rayCluster3 := createRayClusterWithLabels(
+		"ray-cluster-with-gang-scheduling",
+		"test3",
+		map[string]string{
+			RayClusterApplicationIDLabelName: job1,
+			RayClusterQueueLabelName:         queue1,
+		},
+	)
+
+	assert.Equal(t, yk.isGangSchedulingEnabled(rayCluster3), false)
+}
+
+func TestPopulateGangSchedulingAnnotations(t *testing.T) {
+	yk := &YuniKornScheduler{}
+
+	job1 := "job-1-01234"
+	queue1 := "root.default"
+
+	// test the case when gang-scheduling is enabled
+	rayClusterWithGangScheduling := createRayClusterWithLabels(
+		"ray-cluster-with-gang-scheduling",
+		"test3",
+		map[string]string{
+			RayClusterApplicationIDLabelName:      job1,
+			RayClusterQueueLabelName:              queue1,
+			utils.RayClusterGangSchedulingEnabled: "true",
+		},
+	)
+
+	// head pod:
+	//   cpu: 5
+	//   memory: 5Gi
+	addHeadPodSpec(rayClusterWithGangScheduling, v1.ResourceList{
+		v1.ResourceCPU:    resource.MustParse("5"),
+		v1.ResourceMemory: resource.MustParse("5Gi"),
+	})
+
+	// worker pod:
+	//   cpu: 2
+	//   memory: 10Gi
+	//   nvidia.com/gpu: 1
+	addWorkerPodSpec(rayClusterWithGangScheduling,
+		"worker-group-1", 1, 1, 2, v1.ResourceList{
+			v1.ResourceCPU:    resource.MustParse("2"),
+			v1.ResourceMemory: resource.MustParse("10Gi"),
+			"nvidia.com/gpu":  resource.MustParse("1"),
+		})
+
+	// gang-scheduling enabled case, the plugin should populate the taskGroup annotation to the app
+	rayPod := createPod("ray-pod", "default")
+	yk.populateTaskGroupsAnnotationToPod(rayClusterWithGangScheduling, rayPod)
+
+	kk, err := getTaskGroupsFromAnnotation(rayPod)
+	assert.NoError(t, err)
+	assert.Equal(t, len(kk), 2)
+	// verify the annotation value
+	taskGroupsSpec := rayPod.Annotations[YuniKornTaskGroupsAnnotationName]
+	assert.Equal(t, true, len(taskGroupsSpec) > 0)
+	taskGroups := newTaskGroups()
+	err = taskGroups.unmarshalFrom(taskGroupsSpec)
+	assert.NoError(t, err)
+	assert.Equal(t, len(taskGroups.Groups), 2)
+
+	// verify the correctness of the head group
+	headGroup := taskGroups.getTaskGroup(utils.RayNodeHeadGroupLabelValue)
+	assert.NotNil(t, headGroup)
+	assert.Equal(t, int32(1), headGroup.MinMember)
+	assert.Equal(t, resource.MustParse("5"), headGroup.MinResource[v1.ResourceCPU.String()])
+	assert.Equal(t, resource.MustParse("5Gi"), headGroup.MinResource[v1.ResourceMemory.String()])
+
+	// verify the correctness of the worker group
+	workerGroup := taskGroups.getTaskGroup("worker-group-1")
+	assert.NotNil(t, workerGroup)
+	assert.Equal(t, int32(1), workerGroup.MinMember)
+	assert.Equal(t, resource.MustParse("2"), workerGroup.MinResource[v1.ResourceCPU.String()])
+	assert.Equal(t, resource.MustParse("10Gi"), workerGroup.MinResource[v1.ResourceMemory.String()])
+	assert.Equal(t, resource.MustParse("1"), workerGroup.MinResource["nvidia.com/gpu"])
+}
+
 func createRayClusterWithLabels(name string, namespace string, labels map[string]string) *rayv1.RayCluster {
 	rayCluster := &rayv1.RayCluster{
 		ObjectMeta: metav1.ObjectMeta{
@@ -63,6 +176,49 @@ func createRayClusterWithLabels(name string, namespace string, labels map[string
 	return rayCluster
 }
 
+func addHeadPodSpec(app *rayv1.RayCluster, resource v1.ResourceList) {
+	// app.Spec.HeadGroupSpec.Template.Spec.Containers
+	headContainers := []v1.Container{
+		{
+			Name:  "head-pod",
+			Image: "ray.io/ray-head:latest",
+			Resources: v1.ResourceRequirements{
+				Limits:   nil,
+				Requests: resource,
+			},
+		},
+	}
+
+	app.Spec.HeadGroupSpec.Template.Spec.Containers = headContainers
+}
+
+func addWorkerPodSpec(app *rayv1.RayCluster, workerGroupName string,
+	replicas int32, minReplicas int32, maxReplicas int32, resources v1.ResourceList,
+) {
+	workerContainers := []v1.Container{
+		{
+			Name:  "worker-pod",
+			Image: "ray.io/ray-head:latest",
+			Resources: v1.ResourceRequirements{
+				Limits:   nil,
+				Requests: resources,
+			},
+		},
+	}
+
+	app.Spec.WorkerGroupSpecs = append(app.Spec.WorkerGroupSpecs, rayv1.WorkerGroupSpec{
+		GroupName:   workerGroupName,
+		Replicas:    &replicas,
+		MinReplicas: &minReplicas,
+		MaxReplicas: &maxReplicas,
+		Template: v1.PodTemplateSpec{
+			Spec: v1.PodSpec{
+				Containers: workerContainers,
+			},
+		},
+	})
+}
+
 func createPod(name string, namespace string) *v1.Pod {
 	return &v1.Pod{
 		ObjectMeta: metav1.ObjectMeta{
@@ -90,3 +246,36 @@ func podLabelsContains(pod *v1.Pod, key string, value string) bool {
 
 	return false
 }
+
+func getTaskGroupsFromAnnotation(pod *v1.Pod) ([]TaskGroup, error) {
+	taskGroupInfo, exist := pod.Annotations[YuniKornTaskGroupsAnnotationName]
+	if !exist {
+		return nil, fmt.Errorf("not found")
+	}
+
+	taskGroups := []TaskGroup{}
+	err := json.Unmarshal([]byte(taskGroupInfo), &taskGroups)
+	if err != nil {
+		return nil, err
+	}
+	// json.Unmarshal won't return an error if Name or MinMember is empty, but it will return an error
+	// if MinResource is empty or has an invalid format.
+	for _, taskGroup := range taskGroups {
+		if taskGroup.Name == "" {
+			return nil, fmt.Errorf("can't get taskGroup Name from pod annotation, %s",
+				taskGroupInfo)
+		}
+		if taskGroup.MinResource == nil {
+			return nil, fmt.Errorf("can't get taskGroup MinResource from pod annotation, %s",
+				taskGroupInfo)
+		}
+		if taskGroup.MinMember == int32(0) {
+			return nil, fmt.Errorf("can't get taskGroup MinMember from pod annotation, %s",
+				taskGroupInfo)
+		}
+		if taskGroup.MinMember < int32(0) {
+			return nil, fmt.Errorf("minMember cannot be negative, %s",
+				taskGroupInfo)
+		}
+	}
+	return taskGroups, nil
+}
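
Putting the two files together, a pod-level view of the flow looks roughly like the sketch below. It reuses the test helpers from this diff and would sit in the same package as the tests; the head group name passed by the controller is assumed to be utils.RayNodeHeadGroupLabelValue, since only the worker group names are visible on this page.

// Usage sketch (assumption): how AddMetadataToPod decorates a head pod and a worker pod
// when gang scheduling is enabled. Helpers (createRayClusterWithLabels, addHeadPodSpec,
// addWorkerPodSpec, createPod) come from the test file above.
package yunikorn

import (
    "fmt"

    v1 "k8s.io/api/core/v1"
    "k8s.io/apimachinery/pkg/api/resource"

    "github.com/ray-project/kuberay/ray-operator/controllers/ray/utils"
)

func sketchGangSchedulingFlow() {
    yk := &YuniKornScheduler{}

    cluster := createRayClusterWithLabels("demo", "default", map[string]string{
        RayClusterApplicationIDLabelName:      "job-1-01234",
        RayClusterQueueLabelName:              "root.default",
        utils.RayClusterGangSchedulingEnabled: "true",
    })
    addHeadPodSpec(cluster, v1.ResourceList{
        v1.ResourceCPU:    resource.MustParse("1"),
        v1.ResourceMemory: resource.MustParse("2Gi"),
    })
    addWorkerPodSpec(cluster, "worker-group-1", 1, 1, 2, v1.ResourceList{
        v1.ResourceCPU: resource.MustParse("2"),
    })

    // Head pod: assumed group name utils.RayNodeHeadGroupLabelValue.
    headPod := createPod("demo-head", "default")
    yk.AddMetadataToPod(cluster, utils.RayNodeHeadGroupLabelValue, headPod)

    // Worker pod: group name matches the worker group spec.
    workerPod := createPod("demo-worker-0", "default")
    yk.AddMetadataToPod(cluster, "worker-group-1", workerPod)

    // Both pods now carry the legacy YuniKorn labels ("applicationId", "queue"),
    // the shared task-groups JSON annotation, and a per-group task-group-name annotation.
    fmt.Println(headPod.Labels[YuniKornPodApplicationIDLabelName])          // "job-1-01234"
    fmt.Println(headPod.Annotations[YuniKornTaskGroupNameAnnotationName])   // head group name
    fmt.Println(workerPod.Annotations[YuniKornTaskGroupNameAnnotationName]) // "worker-group-1"
}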
