Commit 2a1f87f

Fix: Move replica validation logic to the right place.

1. Remove the minReplicas/maxReplicas validation logic from GetWorkerGroupDesiredReplicas (util.go) and add it to ValidateRayClusterSpec (validation.go).
2. Remove the now-unnecessary tests from TestGetWorkerGroupDesiredReplicas.
3. Remove the unused ctx parameter.
1 parent 7ef6196 commit 2a1f87f
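
In effect, a spec whose minReplicas exceeds maxReplicas used to be silently clamped (with a log line) inside GetWorkerGroupDesiredReplicas; it is now rejected up front by ValidateRayClusterSpec. A minimal sketch of the new behavior, using an illustrative spec and assuming no earlier check in ValidateRayClusterSpec trips on such a bare-bones object:

package main

import (
    "fmt"

    rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
    "github.com/ray-project/kuberay/ray-operator/controllers/ray/utils"
    corev1 "k8s.io/api/core/v1"
    "k8s.io/utils/ptr"
)

func main() {
    // Every group needs at least one container to pass the container checks.
    tmpl := corev1.PodTemplateSpec{
        Spec: corev1.PodSpec{Containers: []corev1.Container{{Name: "ray", Image: "rayproject/ray"}}},
    }
    spec := &rayv1.RayClusterSpec{
        HeadGroupSpec: rayv1.HeadGroupSpec{Template: tmpl},
        WorkerGroupSpecs: []rayv1.WorkerGroupSpec{{
            GroupName:   "small-group",
            MinReplicas: ptr.To(int32(5)), // deliberately greater than MaxReplicas
            MaxReplicas: ptr.To(int32(3)),
            Template:    tmpl,
        }},
    }
    // Before this commit, GetWorkerGroupDesiredReplicas logged a warning and
    // clamped the desired count to MaxReplicas; now validation fails fast:
    // "worker group small-group has minReplicas 5 greater than maxReplicas 3"
    fmt.Println(utils.ValidateRayClusterSpec(spec, nil))
}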

File tree

8 files changed: +49 −55 lines

ray-operator/controllers/ray/batchscheduler/scheduler-plugins/scheduler_plugins.go

Lines changed: 3 additions & 3 deletions

@@ -45,7 +45,7 @@ func (k *KubeScheduler) Name() string {
     return schedulerInstanceName
 }

-func createPodGroup(ctx context.Context, app *rayv1.RayCluster) *v1alpha1.PodGroup {
+func createPodGroup(app *rayv1.RayCluster) *v1alpha1.PodGroup {
     // TODO(troychiu): Consider the case when autoscaling is enabled.

     podGroup := &v1alpha1.PodGroup{
@@ -62,7 +62,7 @@ func createPodGroup(ctx context.Context, app *rayv1.RayCluster) *v1alpha1.PodGroup {
             },
         },
         Spec: v1alpha1.PodGroupSpec{
-            MinMember:    utils.CalculateDesiredReplicas(ctx, app) + 1, // +1 for the head pod
+            MinMember:    utils.CalculateDesiredReplicas(app) + 1, // +1 for the head pod
             MinResources: utils.CalculateDesiredResources(app),
         },
     }
@@ -82,7 +82,7 @@ func (k *KubeScheduler) DoBatchSchedulingOnSubmission(ctx context.Context, objec
         if !errors.IsNotFound(err) {
             return err
         }
-        podGroup = createPodGroup(ctx, app)
+        podGroup = createPodGroup(app)
         if err := k.cli.Create(ctx, podGroup); err != nil {
             if errors.IsAlreadyExists(err) {
                 return nil

ray-operator/controllers/ray/batchscheduler/scheduler-plugins/scheduler_plugins_test.go

Lines changed: 2 additions & 2 deletions

@@ -82,7 +82,7 @@ func TestCreatePodGroup(t *testing.T) {

     cluster := createTestRayCluster(1)

-    podGroup := createPodGroup(context.TODO(), &cluster)
+    podGroup := createPodGroup(&cluster)

     // 256m * 3 (requests, not limits)
     a.Equal("768m", podGroup.Spec.MinResources.Cpu().String())
@@ -102,7 +102,7 @@ func TestCreatePodGroupWithMultipleHosts(t *testing.T) {

     cluster := createTestRayCluster(2) // 2 hosts

-    podGroup := createPodGroup(context.TODO(), &cluster)
+    podGroup := createPodGroup(&cluster)

     // 256m * 5 (requests, not limits)
     a.Equal("1280m", podGroup.Spec.MinResources.Cpu().String())

ray-operator/controllers/ray/batchscheduler/volcano/volcano_scheduler.go

Lines changed: 4 additions & 4 deletions

@@ -62,7 +62,7 @@ func (v *VolcanoBatchScheduler) handleRayCluster(ctx context.Context, raycluster
         return nil
     }

-    minMember, totalResource := v.calculatePodGroupParams(ctx, &raycluster.Spec)
+    minMember, totalResource := v.calculatePodGroupParams(&raycluster.Spec)

     return v.syncPodGroup(ctx, raycluster, minMember, totalResource)
 }
@@ -74,7 +74,7 @@ func (v *VolcanoBatchScheduler) handleRayJob(ctx context.Context, rayJob *rayv1.
     }

     var totalResourceList []corev1.ResourceList
-    minMember, totalResource := v.calculatePodGroupParams(ctx, rayJob.Spec.RayClusterSpec)
+    minMember, totalResource := v.calculatePodGroupParams(rayJob.Spec.RayClusterSpec)
     totalResourceList = append(totalResourceList, totalResource)

     // MinMember intentionally excludes the submitter pod to avoid a startup deadlock
@@ -186,11 +186,11 @@ func (v *VolcanoBatchScheduler) syncPodGroup(ctx context.Context, owner metav1.O
     return nil
 }

-func (v *VolcanoBatchScheduler) calculatePodGroupParams(ctx context.Context, rayClusterSpec *rayv1.RayClusterSpec) (int32, corev1.ResourceList) {
+func (v *VolcanoBatchScheduler) calculatePodGroupParams(rayClusterSpec *rayv1.RayClusterSpec) (int32, corev1.ResourceList) {
     rayCluster := &rayv1.RayCluster{Spec: *rayClusterSpec}

     if !utils.IsAutoscalingEnabled(rayClusterSpec) {
-        return utils.CalculateDesiredReplicas(ctx, rayCluster) + 1, utils.CalculateDesiredResources(rayCluster)
+        return utils.CalculateDesiredReplicas(rayCluster) + 1, utils.CalculateDesiredResources(rayCluster)
     }
     return utils.CalculateMinReplicas(rayCluster) + 1, utils.CalculateMinResources(rayCluster)
 }
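
For context, the sizing rule this function encodes, restated as a standalone sketch (minMember below is illustrative, not part of the commit): the PodGroup gang is the head pod plus all desired workers when autoscaling is off, and the head pod plus only the guaranteed minimum workers when it is on.

package main

import (
    "fmt"

    rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
    "github.com/ray-project/kuberay/ray-operator/controllers/ray/utils"
)

// minMember mirrors the two branches of calculatePodGroupParams above.
func minMember(spec *rayv1.RayClusterSpec) int32 {
    rc := &rayv1.RayCluster{Spec: *spec}
    if !utils.IsAutoscalingEnabled(spec) {
        return utils.CalculateDesiredReplicas(rc) + 1 // +1 for the head pod
    }
    return utils.CalculateMinReplicas(rc) + 1
}

func main() {
    fmt.Println(minMember(&rayv1.RayClusterSpec{})) // 1: head pod only
}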

ray-operator/controllers/ray/batchscheduler/volcano/volcano_scheduler_test.go

Lines changed: 7 additions & 7 deletions

@@ -160,7 +160,7 @@ func TestCreatePodGroupForRayCluster(t *testing.T) {

     cluster := createTestRayCluster(1)

-    minMember := utils.CalculateDesiredReplicas(context.Background(), &cluster) + 1
+    minMember := utils.CalculateDesiredReplicas(&cluster) + 1
     totalResource := utils.CalculateDesiredResources(&cluster)
     pg, err := createPodGroup(&cluster, getAppPodGroupName(&cluster), minMember, totalResource)
     require.NoError(t, err)
@@ -185,7 +185,7 @@ func TestCreatePodGroupForRayCluster_NumOfHosts2(t *testing.T) {

     cluster := createTestRayCluster(2)

-    minMember := utils.CalculateDesiredReplicas(context.Background(), &cluster) + 1
+    minMember := utils.CalculateDesiredReplicas(&cluster) + 1
     totalResource := utils.CalculateDesiredResources(&cluster)
     pg, err := createPodGroup(&cluster, getAppPodGroupName(&cluster), minMember, totalResource)
     require.NoError(t, err)
@@ -227,7 +227,7 @@ func TestCreatePodGroup_NetworkTopologyBothLabels(t *testing.T) {
         NetworkTopologyHighestTierAllowedLabelKey: "3",
     })

-    minMember := utils.CalculateDesiredReplicas(context.Background(), &cluster) + 1
+    minMember := utils.CalculateDesiredReplicas(&cluster) + 1
     totalResource := utils.CalculateDesiredResources(&cluster)
     pg, err := createPodGroup(&cluster, getAppPodGroupName(&cluster), minMember, totalResource)
     require.NoError(t, err)
@@ -246,7 +246,7 @@ func TestCreatePodGroup_NetworkTopologyOnlyModeLabel(t *testing.T) {
         NetworkTopologyModeLabelKey: "hard",
     })

-    minMember := utils.CalculateDesiredReplicas(context.Background(), &cluster) + 1
+    minMember := utils.CalculateDesiredReplicas(&cluster) + 1
     totalResource := utils.CalculateDesiredResources(&cluster)
     pg, err := createPodGroup(&cluster, getAppPodGroupName(&cluster), minMember, totalResource)
     require.NoError(t, err)
@@ -266,7 +266,7 @@ func TestCreatePodGroup_NetworkTopologyHighestTierAllowedNotInt(t *testing.T) {
         NetworkTopologyHighestTierAllowedLabelKey: "not-an-int",
     })

-    minMember := utils.CalculateDesiredReplicas(context.Background(), &cluster) + 1
+    minMember := utils.CalculateDesiredReplicas(&cluster) + 1
     totalResource := utils.CalculateDesiredResources(&cluster)
     pg, err := createPodGroup(&cluster, getAppPodGroupName(&cluster), minMember, totalResource)

@@ -474,7 +474,7 @@ func TestCalculatePodGroupParams(t *testing.T) {
     t.Run("Autoscaling disabled", func(_ *testing.T) {
         cluster := createTestRayCluster(1)

-        minMember, totalResource := scheduler.calculatePodGroupParams(context.Background(), &cluster.Spec)
+        minMember, totalResource := scheduler.calculatePodGroupParams(&cluster.Spec)

         // 1 head + 2 workers (desired replicas)
         a.Equal(int32(3), minMember)
@@ -490,7 +490,7 @@
         cluster := createTestRayCluster(1)
         cluster.Spec.EnableInTreeAutoscaling = ptr.To(true)

-        minMember, totalResource := scheduler.calculatePodGroupParams(context.Background(), &cluster.Spec)
+        minMember, totalResource := scheduler.calculatePodGroupParams(&cluster.Spec)

         // 1 head + 1 worker (min replicas)
         a.Equal(int32(2), minMember)

ray-operator/controllers/ray/raycluster_controller.go

Lines changed: 3 additions & 3 deletions

@@ -752,7 +752,7 @@ func (r *RayClusterReconciler) reconcilePods(ctx context.Context, instance *rayv
             continue
         }
         // workerReplicas will store the target number of pods for this worker group.
-        numExpectedWorkerPods := int(utils.GetWorkerGroupDesiredReplicas(ctx, worker))
+        numExpectedWorkerPods := int(utils.GetWorkerGroupDesiredReplicas(worker))
         logger.Info("reconcilePods", "desired workerReplicas (always adhering to minReplicas/maxReplica)", numExpectedWorkerPods, "worker group", worker.GroupName, "maxReplicas", worker.MaxReplicas, "minReplicas", worker.MinReplicas, "replicas", worker.Replicas)

         workerPods := corev1.PodList{}
@@ -1051,7 +1051,7 @@ func (r *RayClusterReconciler) reconcileMultiHostWorkerGroup(ctx context.Context
         }
     }
     numRunningReplicas := len(validReplicaGroups)
-    numExpectedWorkerPods := int(utils.GetWorkerGroupDesiredReplicas(ctx, *worker))
+    numExpectedWorkerPods := int(utils.GetWorkerGroupDesiredReplicas(*worker))

     // Ensure that if numExpectedWorkerPods is not a multiple of NumOfHosts, we log an error.
     if numExpectedWorkerPods%int(worker.NumOfHosts) != 0 {
@@ -1580,7 +1580,7 @@ func (r *RayClusterReconciler) calculateStatus(ctx context.Context, instance *ra

     newInstance.Status.ReadyWorkerReplicas = utils.CalculateReadyReplicas(runtimePods)
     newInstance.Status.AvailableWorkerReplicas = utils.CalculateAvailableReplicas(runtimePods)
-    newInstance.Status.DesiredWorkerReplicas = utils.CalculateDesiredReplicas(ctx, newInstance)
+    newInstance.Status.DesiredWorkerReplicas = utils.CalculateDesiredReplicas(newInstance)
     newInstance.Status.MinWorkerReplicas = utils.CalculateMinReplicas(newInstance)
     newInstance.Status.MaxWorkerReplicas = utils.CalculateMaxReplicas(newInstance)

ray-operator/controllers/ray/utils/util.go

Lines changed: 4 additions & 9 deletions

@@ -384,18 +384,13 @@ func GenerateIdentifier(clusterName string, nodeType rayv1.RayNodeType) string {
     return fmt.Sprintf("%s-%s", clusterName, nodeType)
 }

-func GetWorkerGroupDesiredReplicas(ctx context.Context, workerGroupSpec rayv1.WorkerGroupSpec) int32 {
-    log := ctrl.LoggerFrom(ctx)
+func GetWorkerGroupDesiredReplicas(workerGroupSpec rayv1.WorkerGroupSpec) int32 {
     // Always adhere to min/max replicas constraints.
     var workerReplicas int32
     if workerGroupSpec.Suspend != nil && *workerGroupSpec.Suspend {
         return 0
     }
-    if *workerGroupSpec.MinReplicas > *workerGroupSpec.MaxReplicas {
-        log.Info("minReplicas is greater than maxReplicas, using maxReplicas as desired replicas. "+
-            "Please fix this to avoid any unexpected behaviors.", "minReplicas", *workerGroupSpec.MinReplicas, "maxReplicas", *workerGroupSpec.MaxReplicas)
-        workerReplicas = *workerGroupSpec.MaxReplicas
-    } else if workerGroupSpec.Replicas == nil || *workerGroupSpec.Replicas < *workerGroupSpec.MinReplicas {
+    if workerGroupSpec.Replicas == nil || *workerGroupSpec.Replicas < *workerGroupSpec.MinReplicas {
         // Replicas is impossible to be nil as it has a default value assigned in the CRD.
         // Add this check to make testing easier.
         workerReplicas = *workerGroupSpec.MinReplicas
@@ -408,10 +403,10 @@ func GetWorkerGroupDesiredReplicas(ctx context.Context, workerGroupSpec rayv1.WorkerGroupSpec) int32 {
 }

 // CalculateDesiredReplicas calculate desired worker replicas at the cluster level
-func CalculateDesiredReplicas(ctx context.Context, cluster *rayv1.RayCluster) int32 {
+func CalculateDesiredReplicas(cluster *rayv1.RayCluster) int32 {
     count := int32(0)
     for _, nodeGroup := range cluster.Spec.WorkerGroupSpecs {
-        count += GetWorkerGroupDesiredReplicas(ctx, nodeGroup)
+        count += GetWorkerGroupDesiredReplicas(nodeGroup)
     }

     return count
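
With ctx gone, GetWorkerGroupDesiredReplicas is a pure function of the worker group spec. A short usage sketch; the values mirror the updated tests, and ptr is k8s.io/utils/ptr:

package main

import (
    "fmt"

    rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
    "github.com/ray-project/kuberay/ray-operator/controllers/ray/utils"
    "k8s.io/utils/ptr"
)

func main() {
    group := rayv1.WorkerGroupSpec{
        NumOfHosts:  4,
        Replicas:    ptr.To(int32(5)), // within [MinReplicas, MaxReplicas]
        MinReplicas: ptr.To(int32(1)),
        MaxReplicas: ptr.To(int32(10)),
    }
    // Desired worker pods = Replicas * NumOfHosts.
    fmt.Println(utils.GetWorkerGroupDesiredReplicas(group)) // 20

    // A suspended group always wants zero pods, regardless of replica settings.
    group.Suspend = ptr.To(true)
    fmt.Println(utils.GetWorkerGroupDesiredReplicas(group)) // 0
}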

ray-operator/controllers/ray/utils/util_test.go

Lines changed: 7 additions & 24 deletions

@@ -553,7 +553,6 @@ func TestGenerateHeadServiceName(t *testing.T) {
 }

 func TestGetWorkerGroupDesiredReplicas(t *testing.T) {
-    ctx := context.Background()
     // Test 1: `WorkerGroupSpec.Replicas` is nil.
     // `Replicas` is impossible to be nil in a real RayCluster CR as it has a default value assigned in the CRD.
     numOfHosts := int32(1)
@@ -565,37 +564,21 @@ func TestGetWorkerGroupDesiredReplicas(t *testing.T) {
         MinReplicas: &minReplicas,
         MaxReplicas: &maxReplicas,
     }
-    assert.Equal(t, GetWorkerGroupDesiredReplicas(ctx, workerGroupSpec), minReplicas)
+    assert.Equal(t, GetWorkerGroupDesiredReplicas(workerGroupSpec), minReplicas)

     // Test 2: `WorkerGroupSpec.Replicas` is not nil and is within the range.
     replicas := int32(3)
     workerGroupSpec.Replicas = &replicas
-    assert.Equal(t, GetWorkerGroupDesiredReplicas(ctx, workerGroupSpec), replicas)
+    assert.Equal(t, GetWorkerGroupDesiredReplicas(workerGroupSpec), replicas)

-    // Test 3: `WorkerGroupSpec.Replicas` is not nil but is more than maxReplicas.
-    replicas = int32(6)
-    workerGroupSpec.Replicas = &replicas
-    assert.Equal(t, GetWorkerGroupDesiredReplicas(ctx, workerGroupSpec), maxReplicas)
-
-    // Test 4: `WorkerGroupSpec.Replicas` is not nil but is less than minReplicas.
-    replicas = int32(0)
-    workerGroupSpec.Replicas = &replicas
-    assert.Equal(t, GetWorkerGroupDesiredReplicas(ctx, workerGroupSpec), minReplicas)
-
-    // Test 5: `WorkerGroupSpec.Replicas` is nil and minReplicas is less than maxReplicas.
-    workerGroupSpec.Replicas = nil
-    workerGroupSpec.MinReplicas = &maxReplicas
-    workerGroupSpec.MaxReplicas = &minReplicas
-    assert.Equal(t, GetWorkerGroupDesiredReplicas(ctx, workerGroupSpec), *workerGroupSpec.MaxReplicas)
-
-    // Test 6: `WorkerGroupSpec.Suspend` is true.
+    // Test 3: `WorkerGroupSpec.Suspend` is true.
     suspend := true
     workerGroupSpec.MinReplicas = &maxReplicas
     workerGroupSpec.MaxReplicas = &minReplicas
     workerGroupSpec.Suspend = &suspend
-    assert.Zero(t, GetWorkerGroupDesiredReplicas(ctx, workerGroupSpec))
+    assert.Zero(t, GetWorkerGroupDesiredReplicas(workerGroupSpec))

-    // Test 7: `WorkerGroupSpec.NumOfHosts` is 4.
+    // Test 4: `WorkerGroupSpec.NumOfHosts` is 4.
     numOfHosts = int32(4)
     replicas = int32(5)
     suspend = false
@@ -604,7 +587,7 @@ func TestGetWorkerGroupDesiredReplicas(t *testing.T) {
     workerGroupSpec.Suspend = &suspend
     workerGroupSpec.MinReplicas = &minReplicas
     workerGroupSpec.MaxReplicas = &maxReplicas
-    assert.Equal(t, GetWorkerGroupDesiredReplicas(ctx, workerGroupSpec), replicas*numOfHosts)
+    assert.Equal(t, GetWorkerGroupDesiredReplicas(workerGroupSpec), replicas*numOfHosts)
 }

 func TestCalculateMinAndMaxReplicas(t *testing.T) {
@@ -801,7 +784,7 @@ func TestCalculateDesiredReplicas(t *testing.T) {
                 },
             },
         }
-        assert.Equal(t, CalculateDesiredReplicas(context.Background(), &cluster), tc.answer)
+        assert.Equal(t, CalculateDesiredReplicas(&cluster), tc.answer)
         })
     }
 }

ray-operator/controllers/ray/utils/validation.go

Lines changed: 19 additions & 3 deletions

@@ -111,11 +111,30 @@ func ValidateRayClusterSpec(spec *rayv1.RayClusterSpec, annotations map[string]s
         return err
     }

+    // Check if autoscaling is enabled once to avoid repeated calls
+    isAutoscalingEnabled := IsAutoscalingEnabled(spec)
+
     for _, workerGroup := range spec.WorkerGroupSpecs {
         if len(workerGroup.Template.Spec.Containers) == 0 {
             return fmt.Errorf("workerGroupSpec should have at least one container")
         }

+        // When autoscaling is enabled, MinReplicas and MaxReplicas are optional
+        // as users can manually update them and the autoscaler will handle the adjustment.
+        if !isAutoscalingEnabled && (workerGroup.MinReplicas == nil || workerGroup.MaxReplicas == nil) {
+            return fmt.Errorf("worker group %s must set both minReplicas and maxReplicas when autoscaling is disabled", workerGroup.GroupName)
+        }
+        if workerGroup.MinReplicas != nil && *workerGroup.MinReplicas < 0 {
+            return fmt.Errorf("worker group %s has negative minReplicas %d", workerGroup.GroupName, *workerGroup.MinReplicas)
+        }
+        if workerGroup.MaxReplicas != nil && *workerGroup.MaxReplicas < 0 {
+            return fmt.Errorf("worker group %s has negative maxReplicas %d", workerGroup.GroupName, *workerGroup.MaxReplicas)
+        }
+        if workerGroup.MinReplicas != nil && workerGroup.MaxReplicas != nil {
+            if *workerGroup.MinReplicas > *workerGroup.MaxReplicas {
+                return fmt.Errorf("worker group %s has minReplicas %d greater than maxReplicas %d", workerGroup.GroupName, *workerGroup.MinReplicas, *workerGroup.MaxReplicas)
+            }
+        }
         if err := validateRayGroupResources(workerGroup.GroupName, workerGroup.RayStartParams, workerGroup.Resources); err != nil {
             return err
         }
@@ -175,9 +194,6 @@ func ValidateRayClusterSpec(spec *rayv1.RayClusterSpec, annotations map[string]s
         }
     }

-    // Check if autoscaling is enabled once to avoid repeated calls
-    isAutoscalingEnabled := IsAutoscalingEnabled(spec)
-
     // Validate that RAY_enable_autoscaler_v2 environment variable is not set to "1" or "true" when autoscaler is disabled
     if !isAutoscalingEnabled {
         if envVar, exists := EnvVarByName(RAY_ENABLE_AUTOSCALER_V2, spec.HeadGroupSpec.Template.Spec.Containers[RayContainerIndex].Env); exists {
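
The new checks lend themselves to a table-driven test. An illustrative sketch, assuming the minimal spec below passes the surrounding checks in ValidateRayClusterSpec; newTestSpec is a hypothetical helper, not part of this commit:

package utils_test

import (
    "testing"

    rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
    "github.com/ray-project/kuberay/ray-operator/controllers/ray/utils"
    "github.com/stretchr/testify/require"
    corev1 "k8s.io/api/core/v1"
    "k8s.io/utils/ptr"
)

// newTestSpec (hypothetical) builds an otherwise-valid spec with one worker
// group and autoscaling disabled.
func newTestSpec() *rayv1.RayClusterSpec {
    tmpl := corev1.PodTemplateSpec{
        Spec: corev1.PodSpec{Containers: []corev1.Container{{Name: "ray", Image: "rayproject/ray"}}},
    }
    return &rayv1.RayClusterSpec{
        HeadGroupSpec:    rayv1.HeadGroupSpec{Template: tmpl},
        WorkerGroupSpecs: []rayv1.WorkerGroupSpec{{GroupName: "wg", Template: tmpl}},
    }
}

func TestWorkerGroupReplicaBounds(t *testing.T) {
    tests := []struct {
        name     string
        min, max *int32
        wantErr  bool
    }{
        {"valid bounds", ptr.To(int32(1)), ptr.To(int32(3)), false},
        {"min greater than max", ptr.To(int32(5)), ptr.To(int32(3)), true},
        {"negative minReplicas", ptr.To(int32(-1)), ptr.To(int32(3)), true},
        {"nil maxReplicas while autoscaling is disabled", ptr.To(int32(1)), nil, true},
    }
    for _, tc := range tests {
        t.Run(tc.name, func(t *testing.T) {
            spec := newTestSpec()
            spec.WorkerGroupSpecs[0].MinReplicas = tc.min
            spec.WorkerGroupSpecs[0].MaxReplicas = tc.max
            err := utils.ValidateRayClusterSpec(spec, nil)
            if tc.wantErr {
                require.Error(t, err)
            } else {
                require.NoError(t, err)
            }
        })
    }
}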
