Skip to content

Commit 7d73dd1

Browse files
committed
Test: Add unit test for replica validation.
This is added since we moved the validation logic.
1 parent 17f7099 commit 7d73dd1

File tree

2 files changed

+107
-18
lines changed

2 files changed

+107
-18
lines changed

ray-operator/controllers/ray/utils/util.go

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -387,15 +387,17 @@ func GenerateIdentifier(clusterName string, nodeType rayv1.RayNodeType) string {
387387
func GetWorkerGroupDesiredReplicas(workerGroupSpec rayv1.WorkerGroupSpec) int32 {
388388
// Always adhere to min/max replicas constraints.
389389
var workerReplicas int32
390+
minReplicas := ptr.Deref(workerGroupSpec.MinReplicas, int32(0))
391+
maxReplicas := ptr.Deref(workerGroupSpec.MaxReplicas, int32(math.MaxInt32))
390392
if workerGroupSpec.Suspend != nil && *workerGroupSpec.Suspend {
391393
return 0
392394
}
393-
if workerGroupSpec.Replicas == nil || *workerGroupSpec.Replicas < *workerGroupSpec.MinReplicas {
395+
if workerGroupSpec.Replicas == nil || *workerGroupSpec.Replicas < minReplicas {
394396
// Replicas is impossible to be nil as it has a default value assigned in the CRD.
395397
// Add this check to make testing easier.
396-
workerReplicas = *workerGroupSpec.MinReplicas
397-
} else if *workerGroupSpec.Replicas > *workerGroupSpec.MaxReplicas {
398-
workerReplicas = *workerGroupSpec.MaxReplicas
398+
workerReplicas = minReplicas
399+
} else if *workerGroupSpec.Replicas > maxReplicas {
400+
workerReplicas = maxReplicas
399401
} else {
400402
workerReplicas = *workerGroupSpec.Replicas
401403
}
@@ -419,7 +421,8 @@ func CalculateMinReplicas(cluster *rayv1.RayCluster) int32 {
419421
if nodeGroup.Suspend != nil && *nodeGroup.Suspend {
420422
continue
421423
}
422-
count += (*nodeGroup.MinReplicas * nodeGroup.NumOfHosts)
424+
minReplicas := ptr.Deref(nodeGroup.MinReplicas, int32(0))
425+
count += (minReplicas * nodeGroup.NumOfHosts)
423426
}
424427

425428
return count
@@ -432,7 +435,8 @@ func CalculateMaxReplicas(cluster *rayv1.RayCluster) int32 {
432435
if nodeGroup.Suspend != nil && *nodeGroup.Suspend {
433436
continue
434437
}
435-
count += int64(*nodeGroup.MaxReplicas) * int64(nodeGroup.NumOfHosts)
438+
maxReplicas := ptr.Deref(nodeGroup.MaxReplicas, int32(math.MaxInt32))
439+
count += int64(maxReplicas) * int64(nodeGroup.NumOfHosts)
436440
}
437441

438442
return SafeInt64ToInt32(count)
@@ -494,11 +498,8 @@ func CalculateMinResources(cluster *rayv1.RayCluster) corev1.ResourceList {
494498
for _, nodeGroup := range cluster.Spec.WorkerGroupSpecs {
495499
podResource := CalculatePodResource(nodeGroup.Template.Spec)
496500
calculateReplicaResource(&podResource, nodeGroup.NumOfHosts)
497-
minReplicas := int32(0)
498-
if nodeGroup.MinReplicas != nil {
499-
minReplicas = *nodeGroup.MinReplicas
500-
}
501-
for i := int32(0); i < minReplicas; i++ {
501+
minReplicas := ptr.Deref(nodeGroup.MinReplicas, int32(0))
502+
for range minReplicas {
502503
minResourcesList = append(minResourcesList, podResource)
503504
}
504505
}

ray-operator/controllers/ray/utils/validation_test.go

Lines changed: 95 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,9 @@ func TestValidateRayClusterSpecEmptyContainers(t *testing.T) {
395395
Template: podTemplateSpec(nil, nil),
396396
}
397397
workerGroupSpecWithOneContainer := rayv1.WorkerGroupSpec{
398-
Template: podTemplateSpec(nil, nil),
398+
Template: podTemplateSpec(nil, nil),
399+
MinReplicas: ptr.To(int32(0)),
400+
MaxReplicas: ptr.To(int32(5)),
399401
}
400402
headGroupSpecWithNoContainers := *headGroupSpecWithOneContainer.DeepCopy()
401403
headGroupSpecWithNoContainers.Template.Spec.Containers = []corev1.Container{}
@@ -459,8 +461,10 @@ func TestValidateRayClusterSpecSuspendingWorkerGroup(t *testing.T) {
459461
Template: podTemplateSpec(nil, nil),
460462
}
461463
workerGroupSpecSuspended := rayv1.WorkerGroupSpec{
462-
GroupName: "worker-group-1",
463-
Template: podTemplateSpec(nil, nil),
464+
GroupName: "worker-group-1",
465+
Template: podTemplateSpec(nil, nil),
466+
MinReplicas: ptr.To(int32(0)),
467+
MaxReplicas: ptr.To(int32(5)),
464468
}
465469
workerGroupSpecSuspended.Suspend = ptr.To(true)
466470

@@ -692,8 +696,10 @@ func TestValidateRayClusterSpec_Resources(t *testing.T) {
692696
},
693697
WorkerGroupSpecs: []rayv1.WorkerGroupSpec{
694698
{
695-
GroupName: "worker-group",
696-
Template: podTemplateSpec(nil, nil),
699+
GroupName: "worker-group",
700+
Template: podTemplateSpec(nil, nil),
701+
MinReplicas: ptr.To(int32(0)),
702+
MaxReplicas: ptr.To(int32(5)),
697703
},
698704
},
699705
}
@@ -773,8 +779,10 @@ func TestValidateRayClusterSpec_Labels(t *testing.T) {
773779
},
774780
WorkerGroupSpecs: []rayv1.WorkerGroupSpec{
775781
{
776-
GroupName: "worker-group",
777-
Template: podTemplateSpec(nil, nil),
782+
GroupName: "worker-group",
783+
Template: podTemplateSpec(nil, nil),
784+
MinReplicas: ptr.To(int32(0)),
785+
MaxReplicas: ptr.To(int32(5)),
778786
},
779787
},
780788
}
@@ -2473,3 +2481,83 @@ func TestValidateRayClusterUpgradeOptions(t *testing.T) {
24732481
})
24742482
}
24752483
}
2484+
2485+
func TestValidateRayClusterSpec_WorkerGroupReplicaValidation(t *testing.T) {
2486+
createSpec := func() rayv1.RayClusterSpec {
2487+
return rayv1.RayClusterSpec{
2488+
EnableInTreeAutoscaling: ptr.To(false),
2489+
HeadGroupSpec: rayv1.HeadGroupSpec{
2490+
Template: podTemplateSpec(nil, nil),
2491+
},
2492+
}
2493+
}
2494+
2495+
tests := []struct {
2496+
name string
2497+
errorMsg string
2498+
spec rayv1.RayClusterSpec
2499+
expectError bool
2500+
}{
2501+
{
2502+
name: "minReplicas greater than maxReplicas",
2503+
spec: func() rayv1.RayClusterSpec {
2504+
s := createSpec()
2505+
s.WorkerGroupSpecs = []rayv1.WorkerGroupSpec{
2506+
{
2507+
GroupName: "worker-group-3",
2508+
Template: podTemplateSpec(nil, nil),
2509+
MinReplicas: ptr.To(int32(5)),
2510+
MaxReplicas: ptr.To(int32(3)),
2511+
},
2512+
}
2513+
return s
2514+
}(),
2515+
expectError: true,
2516+
errorMsg: "worker group worker-group-3 has minReplicas 5 greater than maxReplicas 3",
2517+
},
2518+
{
2519+
name: "replicas smaller than minReplicas when autoscaling disabled",
2520+
spec: func() rayv1.RayClusterSpec {
2521+
s := createSpec()
2522+
s.WorkerGroupSpecs = []rayv1.WorkerGroupSpec{
2523+
{
2524+
GroupName: "worker-group-3",
2525+
Template: podTemplateSpec(nil, nil),
2526+
Replicas: ptr.To(int32(1)),
2527+
MinReplicas: ptr.To(int32(2)),
2528+
MaxReplicas: ptr.To(int32(5)),
2529+
},
2530+
}
2531+
return s
2532+
}(),
2533+
expectError: false,
2534+
},
2535+
{
2536+
name: "valid when autoscaling enabled",
2537+
spec: func() rayv1.RayClusterSpec {
2538+
s := createSpec()
2539+
s.EnableInTreeAutoscaling = ptr.To(true)
2540+
s.WorkerGroupSpecs = []rayv1.WorkerGroupSpec{
2541+
{
2542+
GroupName: "worker-group-3",
2543+
Template: podTemplateSpec(nil, nil),
2544+
},
2545+
}
2546+
return s
2547+
}(),
2548+
expectError: false,
2549+
},
2550+
}
2551+
2552+
for _, tt := range tests {
2553+
t.Run(tt.name, func(t *testing.T) {
2554+
err := ValidateRayClusterSpec(&tt.spec, nil)
2555+
if tt.expectError {
2556+
require.Error(t, err)
2557+
require.EqualError(t, err, tt.errorMsg)
2558+
} else {
2559+
require.NoError(t, err)
2560+
}
2561+
})
2562+
}
2563+
}

0 commit comments

Comments
 (0)