@@ -45,7 +45,7 @@ func (k *KubeScheduler) Name() string {
return schedulerInstanceName
}

func createPodGroup(ctx context.Context, app *rayv1.RayCluster) *v1alpha1.PodGroup {
func createPodGroup(app *rayv1.RayCluster) *v1alpha1.PodGroup {
// TODO(troychiu): Consider the case when autoscaling is enabled.

podGroup := &v1alpha1.PodGroup{
@@ -62,7 +62,7 @@ func createPodGroup(ctx context.Context, app *rayv1.RayCluster) *v1alpha1.PodGro
},
},
Spec: v1alpha1.PodGroupSpec{
MinMember: utils.CalculateDesiredReplicas(ctx, app) + 1, // +1 for the head pod
MinMember: utils.CalculateDesiredReplicas(app) + 1, // +1 for the head pod
MinResources: utils.CalculateDesiredResources(app),
},
}
@@ -82,7 +82,7 @@ func (k *KubeScheduler) DoBatchSchedulingOnSubmission(ctx context.Context, objec
if !errors.IsNotFound(err) {
return err
}
podGroup = createPodGroup(ctx, app)
podGroup = createPodGroup(app)
if err := k.cli.Create(ctx, podGroup); err != nil {
if errors.IsAlreadyExists(err) {
return nil
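To illustrate the context-free signature above, here is a minimal sketch (not part of this PR) of how the gang size that ends up in `PodGroupSpec.MinMember` is derived: desired worker replicas plus one head pod. The module import paths and the field values are assumptions for the example.

```go
package main

import (
	"fmt"

	rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
	"github.com/ray-project/kuberay/ray-operator/controllers/ray/utils"
	"k8s.io/utils/ptr"
)

func main() {
	cluster := &rayv1.RayCluster{
		Spec: rayv1.RayClusterSpec{
			WorkerGroupSpecs: []rayv1.WorkerGroupSpec{{
				GroupName:   "workers",
				Replicas:    ptr.To(int32(2)),
				MinReplicas: ptr.To(int32(1)),
				MaxReplicas: ptr.To(int32(4)),
				NumOfHosts:  1,
			}},
		},
	}
	// Desired worker pods (2) plus the head pod, mirroring createPodGroup's MinMember.
	fmt.Println(utils.CalculateDesiredReplicas(cluster) + 1) // 3
}
```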
@@ -82,7 +82,7 @@ func TestCreatePodGroup(t *testing.T) {

cluster := createTestRayCluster(1)

podGroup := createPodGroup(context.TODO(), &cluster)
podGroup := createPodGroup(&cluster)

// 256m * 3 (requests, not limits)
a.Equal("768m", podGroup.Spec.MinResources.Cpu().String())
@@ -102,7 +102,7 @@ func TestCreatePodGroupWithMultipleHosts(t *testing.T) {

cluster := createTestRayCluster(2) // 2 hosts

podGroup := createPodGroup(context.TODO(), &cluster)
podGroup := createPodGroup(&cluster)

// 256m * 5 (requests, not limits)
a.Equal("1280m", podGroup.Spec.MinResources.Cpu().String())
@@ -62,7 +62,7 @@ func (v *VolcanoBatchScheduler) handleRayCluster(ctx context.Context, raycluster
return nil
}

minMember, totalResource := v.calculatePodGroupParams(ctx, &raycluster.Spec)
minMember, totalResource := v.calculatePodGroupParams(&raycluster.Spec)

return v.syncPodGroup(ctx, raycluster, minMember, totalResource)
}
@@ -74,7 +74,7 @@ func (v *VolcanoBatchScheduler) handleRayJob(ctx context.Context, rayJob *rayv1.
}

var totalResourceList []corev1.ResourceList
minMember, totalResource := v.calculatePodGroupParams(ctx, rayJob.Spec.RayClusterSpec)
minMember, totalResource := v.calculatePodGroupParams(rayJob.Spec.RayClusterSpec)
totalResourceList = append(totalResourceList, totalResource)

// MinMember intentionally excludes the submitter pod to avoid a startup deadlock
@@ -186,11 +186,11 @@ func (v *VolcanoBatchScheduler) syncPodGroup(ctx context.Context, owner metav1.O
return nil
}

func (v *VolcanoBatchScheduler) calculatePodGroupParams(ctx context.Context, rayClusterSpec *rayv1.RayClusterSpec) (int32, corev1.ResourceList) {
func (v *VolcanoBatchScheduler) calculatePodGroupParams(rayClusterSpec *rayv1.RayClusterSpec) (int32, corev1.ResourceList) {
rayCluster := &rayv1.RayCluster{Spec: *rayClusterSpec}

if !utils.IsAutoscalingEnabled(rayClusterSpec) {
return utils.CalculateDesiredReplicas(ctx, rayCluster) + 1, utils.CalculateDesiredResources(rayCluster)
return utils.CalculateDesiredReplicas(rayCluster) + 1, utils.CalculateDesiredResources(rayCluster)
}
return utils.CalculateMinReplicas(rayCluster) + 1, utils.CalculateMinResources(rayCluster)
}
@@ -160,7 +160,7 @@ func TestCreatePodGroupForRayCluster(t *testing.T) {

cluster := createTestRayCluster(1)

minMember := utils.CalculateDesiredReplicas(context.Background(), &cluster) + 1
minMember := utils.CalculateDesiredReplicas(&cluster) + 1
totalResource := utils.CalculateDesiredResources(&cluster)
pg, err := createPodGroup(&cluster, getAppPodGroupName(&cluster), minMember, totalResource)
require.NoError(t, err)
@@ -185,7 +185,7 @@ func TestCreatePodGroupForRayCluster_NumOfHosts2(t *testing.T) {

cluster := createTestRayCluster(2)

minMember := utils.CalculateDesiredReplicas(context.Background(), &cluster) + 1
minMember := utils.CalculateDesiredReplicas(&cluster) + 1
totalResource := utils.CalculateDesiredResources(&cluster)
pg, err := createPodGroup(&cluster, getAppPodGroupName(&cluster), minMember, totalResource)
require.NoError(t, err)
@@ -227,7 +227,7 @@ func TestCreatePodGroup_NetworkTopologyBothLabels(t *testing.T) {
NetworkTopologyHighestTierAllowedLabelKey: "3",
})

minMember := utils.CalculateDesiredReplicas(context.Background(), &cluster) + 1
minMember := utils.CalculateDesiredReplicas(&cluster) + 1
totalResource := utils.CalculateDesiredResources(&cluster)
pg, err := createPodGroup(&cluster, getAppPodGroupName(&cluster), minMember, totalResource)
require.NoError(t, err)
@@ -246,7 +246,7 @@ func TestCreatePodGroup_NetworkTopologyOnlyModeLabel(t *testing.T) {
NetworkTopologyModeLabelKey: "hard",
})

minMember := utils.CalculateDesiredReplicas(context.Background(), &cluster) + 1
minMember := utils.CalculateDesiredReplicas(&cluster) + 1
totalResource := utils.CalculateDesiredResources(&cluster)
pg, err := createPodGroup(&cluster, getAppPodGroupName(&cluster), minMember, totalResource)
require.NoError(t, err)
@@ -266,7 +266,7 @@ func TestCreatePodGroup_NetworkTopologyHighestTierAllowedNotInt(t *testing.T) {
NetworkTopologyHighestTierAllowedLabelKey: "not-an-int",
})

minMember := utils.CalculateDesiredReplicas(context.Background(), &cluster) + 1
minMember := utils.CalculateDesiredReplicas(&cluster) + 1
totalResource := utils.CalculateDesiredResources(&cluster)
pg, err := createPodGroup(&cluster, getAppPodGroupName(&cluster), minMember, totalResource)

@@ -474,7 +474,7 @@ func TestCalculatePodGroupParams(t *testing.T) {
t.Run("Autoscaling disabled", func(_ *testing.T) {
cluster := createTestRayCluster(1)

minMember, totalResource := scheduler.calculatePodGroupParams(context.Background(), &cluster.Spec)
minMember, totalResource := scheduler.calculatePodGroupParams(&cluster.Spec)

// 1 head + 2 workers (desired replicas)
a.Equal(int32(3), minMember)
@@ -490,7 +490,7 @@ func TestCalculatePodGroupParams(t *testing.T) {
cluster := createTestRayCluster(1)
cluster.Spec.EnableInTreeAutoscaling = ptr.To(true)

minMember, totalResource := scheduler.calculatePodGroupParams(context.Background(), &cluster.Spec)
minMember, totalResource := scheduler.calculatePodGroupParams(&cluster.Spec)

// 1 head + 1 worker (min replicas)
a.Equal(int32(2), minMember)
6 changes: 3 additions & 3 deletions ray-operator/controllers/ray/raycluster_controller.go
@@ -750,7 +750,7 @@ func (r *RayClusterReconciler) reconcilePods(ctx context.Context, instance *rayv
continue
}
// workerReplicas will store the target number of pods for this worker group.
numExpectedWorkerPods := int(utils.GetWorkerGroupDesiredReplicas(ctx, worker))
numExpectedWorkerPods := int(utils.GetWorkerGroupDesiredReplicas(worker))
logger.Info("reconcilePods", "desired workerReplicas (always adhering to minReplicas/maxReplica)", numExpectedWorkerPods, "worker group", worker.GroupName, "maxReplicas", worker.MaxReplicas, "minReplicas", worker.MinReplicas, "replicas", worker.Replicas)

workerPods := corev1.PodList{}
@@ -1049,7 +1049,7 @@ func (r *RayClusterReconciler) reconcileMultiHostWorkerGroup(ctx context.Context
}
}
numRunningReplicas := len(validReplicaGroups)
numExpectedWorkerPods := int(utils.GetWorkerGroupDesiredReplicas(ctx, *worker))
numExpectedWorkerPods := int(utils.GetWorkerGroupDesiredReplicas(*worker))

// Ensure that if numExpectedWorkerPods is not a multiple of NumOfHosts, we log an error.
if numExpectedWorkerPods%int(worker.NumOfHosts) != 0 {
@@ -1582,7 +1582,7 @@ func (r *RayClusterReconciler) calculateStatus(ctx context.Context, instance *ra

newInstance.Status.ReadyWorkerReplicas = utils.CalculateReadyReplicas(runtimePods)
newInstance.Status.AvailableWorkerReplicas = utils.CalculateAvailableReplicas(runtimePods)
newInstance.Status.DesiredWorkerReplicas = utils.CalculateDesiredReplicas(ctx, newInstance)
newInstance.Status.DesiredWorkerReplicas = utils.CalculateDesiredReplicas(newInstance)
newInstance.Status.MinWorkerReplicas = utils.CalculateMinReplicas(newInstance)
newInstance.Status.MaxWorkerReplicas = utils.CalculateMaxReplicas(newInstance)

30 changes: 15 additions & 15 deletions ray-operator/controllers/ray/utils/util.go
@@ -384,34 +384,31 @@ func GenerateIdentifier(clusterName string, nodeType rayv1.RayNodeType) string {
return fmt.Sprintf("%s-%s", clusterName, nodeType)
}

func GetWorkerGroupDesiredReplicas(ctx context.Context, workerGroupSpec rayv1.WorkerGroupSpec) int32 {
log := ctrl.LoggerFrom(ctx)
func GetWorkerGroupDesiredReplicas(workerGroupSpec rayv1.WorkerGroupSpec) int32 {
// Always adhere to min/max replicas constraints.
var workerReplicas int32
minReplicas := ptr.Deref(workerGroupSpec.MinReplicas, int32(0))
maxReplicas := ptr.Deref(workerGroupSpec.MaxReplicas, int32(math.MaxInt32))
if workerGroupSpec.Suspend != nil && *workerGroupSpec.Suspend {
return 0
}
if *workerGroupSpec.MinReplicas > *workerGroupSpec.MaxReplicas {
log.Info("minReplicas is greater than maxReplicas, using maxReplicas as desired replicas. "+
"Please fix this to avoid any unexpected behaviors.", "minReplicas", *workerGroupSpec.MinReplicas, "maxReplicas", *workerGroupSpec.MaxReplicas)
workerReplicas = *workerGroupSpec.MaxReplicas
} else if workerGroupSpec.Replicas == nil || *workerGroupSpec.Replicas < *workerGroupSpec.MinReplicas {
if workerGroupSpec.Replicas == nil || *workerGroupSpec.Replicas < minReplicas {
// Replicas is impossible to be nil as it has a default value assigned in the CRD.
// Add this check to make testing easier.
workerReplicas = *workerGroupSpec.MinReplicas
} else if *workerGroupSpec.Replicas > *workerGroupSpec.MaxReplicas {
workerReplicas = *workerGroupSpec.MaxReplicas
workerReplicas = minReplicas
} else if *workerGroupSpec.Replicas > maxReplicas {
workerReplicas = maxReplicas
} else {
workerReplicas = *workerGroupSpec.Replicas
}
return workerReplicas * workerGroupSpec.NumOfHosts
}

// CalculateDesiredReplicas calculate desired worker replicas at the cluster level
func CalculateDesiredReplicas(ctx context.Context, cluster *rayv1.RayCluster) int32 {
func CalculateDesiredReplicas(cluster *rayv1.RayCluster) int32 {
count := int32(0)
for _, nodeGroup := range cluster.Spec.WorkerGroupSpecs {
count += GetWorkerGroupDesiredReplicas(ctx, nodeGroup)
count += GetWorkerGroupDesiredReplicas(nodeGroup)
}

return count
@@ -424,7 +421,8 @@ func CalculateMinReplicas(cluster *rayv1.RayCluster) int32 {
if nodeGroup.Suspend != nil && *nodeGroup.Suspend {
continue
}
count += (*nodeGroup.MinReplicas * nodeGroup.NumOfHosts)
minReplicas := ptr.Deref(nodeGroup.MinReplicas, int32(0))
count += (minReplicas * nodeGroup.NumOfHosts)
}

return count
@@ -437,7 +435,8 @@ func CalculateMaxReplicas(cluster *rayv1.RayCluster) int32 {
if nodeGroup.Suspend != nil && *nodeGroup.Suspend {
continue
}
count += int64(*nodeGroup.MaxReplicas) * int64(nodeGroup.NumOfHosts)
maxReplicas := ptr.Deref(nodeGroup.MaxReplicas, int32(math.MaxInt32))
count += int64(maxReplicas) * int64(nodeGroup.NumOfHosts)
}

return SafeInt64ToInt32(count)
@@ -499,7 +498,8 @@ func CalculateMinResources(cluster *rayv1.RayCluster) corev1.ResourceList {
for _, nodeGroup := range cluster.Spec.WorkerGroupSpecs {
podResource := CalculatePodResource(nodeGroup.Template.Spec)
calculateReplicaResource(&podResource, nodeGroup.NumOfHosts)
for i := int32(0); i < *nodeGroup.MinReplicas; i++ {
minReplicas := ptr.Deref(nodeGroup.MinReplicas, int32(0))
for range minReplicas {
minResourcesList = append(minResourcesList, podResource)
}
}
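A minimal sketch (not part of the diff) of how the reworked `GetWorkerGroupDesiredReplicas` behaves now that nil bounds default through `ptr.Deref` (MinReplicas to 0, MaxReplicas to `math.MaxInt32`) and the result is multiplied by `NumOfHosts`. Import paths and values are assumptions for illustration.

```go
package main

import (
	"fmt"

	rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
	"github.com/ray-project/kuberay/ray-operator/controllers/ray/utils"
	"k8s.io/utils/ptr"
)

func main() {
	// MaxReplicas left nil: it defaults to math.MaxInt32, so only the lower bound clamps.
	group := rayv1.WorkerGroupSpec{
		Replicas:    ptr.To(int32(10)),
		MinReplicas: ptr.To(int32(1)),
		NumOfHosts:  2,
	}
	fmt.Println(utils.GetWorkerGroupDesiredReplicas(group)) // 10 replicas * 2 hosts = 20

	// Replicas below MinReplicas is clamped up to MinReplicas.
	group.Replicas = ptr.To(int32(0))
	fmt.Println(utils.GetWorkerGroupDesiredReplicas(group)) // 1 * 2 = 2

	// Suspended groups always report zero.
	group.Suspend = ptr.To(true)
	fmt.Println(utils.GetWorkerGroupDesiredReplicas(group)) // 0
}
```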
25 changes: 9 additions & 16 deletions ray-operator/controllers/ray/utils/util_test.go
@@ -553,7 +553,6 @@ func TestGenerateHeadServiceName(t *testing.T) {
}

func TestGetWorkerGroupDesiredReplicas(t *testing.T) {
ctx := context.Background()
// Test 1: `WorkerGroupSpec.Replicas` is nil.
// `Replicas` is impossible to be nil in a real RayCluster CR as it has a default value assigned in the CRD.
numOfHosts := int32(1)
@@ -565,37 +564,31 @@ func TestGetWorkerGroupDesiredReplicas(t *testing.T) {
MinReplicas: &minReplicas,
MaxReplicas: &maxReplicas,
}
assert.Equal(t, GetWorkerGroupDesiredReplicas(ctx, workerGroupSpec), minReplicas)
assert.Equal(t, GetWorkerGroupDesiredReplicas(workerGroupSpec), minReplicas)

// Test 2: `WorkerGroupSpec.Replicas` is not nil and is within the range.
replicas := int32(3)
workerGroupSpec.Replicas = &replicas
assert.Equal(t, GetWorkerGroupDesiredReplicas(ctx, workerGroupSpec), replicas)
assert.Equal(t, GetWorkerGroupDesiredReplicas(workerGroupSpec), replicas)

// Test 3: `WorkerGroupSpec.Replicas` is not nil but is more than maxReplicas.
replicas = int32(6)
workerGroupSpec.Replicas = &replicas
assert.Equal(t, GetWorkerGroupDesiredReplicas(ctx, workerGroupSpec), maxReplicas)
assert.Equal(t, GetWorkerGroupDesiredReplicas(workerGroupSpec), maxReplicas)

// Test 4: `WorkerGroupSpec.Replicas` is not nil but is less than minReplicas.
replicas = int32(0)
workerGroupSpec.Replicas = &replicas
assert.Equal(t, GetWorkerGroupDesiredReplicas(ctx, workerGroupSpec), minReplicas)
assert.Equal(t, GetWorkerGroupDesiredReplicas(workerGroupSpec), minReplicas)

// Test 5: `WorkerGroupSpec.Replicas` is nil and minReplicas is less than maxReplicas.
workerGroupSpec.Replicas = nil
workerGroupSpec.MinReplicas = &maxReplicas
workerGroupSpec.MaxReplicas = &minReplicas
assert.Equal(t, GetWorkerGroupDesiredReplicas(ctx, workerGroupSpec), *workerGroupSpec.MaxReplicas)

// Test 6: `WorkerGroupSpec.Suspend` is true.
// Test 5: `WorkerGroupSpec.Suspend` is true.
suspend := true
workerGroupSpec.MinReplicas = &maxReplicas
workerGroupSpec.MaxReplicas = &minReplicas
workerGroupSpec.Suspend = &suspend
assert.Zero(t, GetWorkerGroupDesiredReplicas(ctx, workerGroupSpec))
assert.Zero(t, GetWorkerGroupDesiredReplicas(workerGroupSpec))

// Test 7: `WorkerGroupSpec.NumOfHosts` is 4.
// Test 6: `WorkerGroupSpec.NumOfHosts` is 4.
numOfHosts = int32(4)
replicas = int32(5)
suspend = false
@@ -604,7 +597,7 @@ func TestGetWorkerGroupDesiredReplicas(t *testing.T) {
workerGroupSpec.Suspend = &suspend
workerGroupSpec.MinReplicas = &minReplicas
workerGroupSpec.MaxReplicas = &maxReplicas
assert.Equal(t, GetWorkerGroupDesiredReplicas(ctx, workerGroupSpec), replicas*numOfHosts)
assert.Equal(t, GetWorkerGroupDesiredReplicas(workerGroupSpec), replicas*numOfHosts)
}

func TestCalculateMinAndMaxReplicas(t *testing.T) {
@@ -801,7 +794,7 @@ func TestCalculateDesiredReplicas(t *testing.T) {
},
},
}
assert.Equal(t, CalculateDesiredReplicas(context.Background(), &cluster), tc.answer)
assert.Equal(t, CalculateDesiredReplicas(&cluster), tc.answer)
})
}
}
22 changes: 19 additions & 3 deletions ray-operator/controllers/ray/utils/validation.go
@@ -111,11 +111,30 @@ func ValidateRayClusterSpec(spec *rayv1.RayClusterSpec, annotations map[string]s
return err
}

// Check if autoscaling is enabled once to avoid repeated calls
isAutoscalingEnabled := IsAutoscalingEnabled(spec)

for _, workerGroup := range spec.WorkerGroupSpecs {
if len(workerGroup.Template.Spec.Containers) == 0 {
return fmt.Errorf("workerGroupSpec should have at least one container")
}

// When autoscaling is enabled, MinReplicas and MaxReplicas are optional
// as users can manually update them and the autoscaler will handle the adjustment.
if !isAutoscalingEnabled && (workerGroup.MinReplicas == nil || workerGroup.MaxReplicas == nil) {
return fmt.Errorf("worker group %s must set both minReplicas and maxReplicas when autoscaling is disabled", workerGroup.GroupName)
}
if workerGroup.MinReplicas != nil && *workerGroup.MinReplicas < 0 {
return fmt.Errorf("worker group %s has negative minReplicas %d", workerGroup.GroupName, *workerGroup.MinReplicas)
}
if workerGroup.MaxReplicas != nil && *workerGroup.MaxReplicas < 0 {
return fmt.Errorf("worker group %s has negative maxReplicas %d", workerGroup.GroupName, *workerGroup.MaxReplicas)
}
if workerGroup.MinReplicas != nil && workerGroup.MaxReplicas != nil {
if *workerGroup.MinReplicas > *workerGroup.MaxReplicas {
return fmt.Errorf("worker group %s has minReplicas %d greater than maxReplicas %d", workerGroup.GroupName, *workerGroup.MinReplicas, *workerGroup.MaxReplicas)
}
}
Collaborator:

Should we also check that `workerGroup.Replicas` lies within `workerGroup.MinReplicas` and `workerGroup.MaxReplicas`?

@kash2104 (Contributor, Author), Jan 18, 2026:

@machichima Here in validation.go we are only checking that the min and max replica values are not incorrect or impossible before pod creation happens; the actual logic for the number of replicas lives in util.go.

So I think we don't need that check here in validation.go.

Contributor:

Just to confirm: validation.go is strictly for ensuring the min and max replica configurations are valid, while the actual clipping logic resides in util.go. Therefore, we don't need to re-verify that the replicas fall within the [min, max] range here. Please let me know if I've misunderstood anything. Thanks!
if err := validateRayGroupResources(workerGroup.GroupName, workerGroup.RayStartParams, workerGroup.Resources); err != nil {
return err
}
@@ -175,9 +194,6 @@ func ValidateRayClusterSpec(spec *rayv1.RayClusterSpec, annotations map[string]s
}
}

// Check if autoscaling is enabled once to avoid repeated calls
isAutoscalingEnabled := IsAutoscalingEnabled(spec)

// Validate that RAY_enable_autoscaler_v2 environment variable is not set to "1" or "true" when autoscaler is disabled
if !isAutoscalingEnabled {
if envVar, exists := EnvVarByName(RAY_ENABLE_AUTOSCALER_V2, spec.HeadGroupSpec.Template.Spec.Containers[RayContainerIndex].Env); exists {
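To illustrate the new checks, here is a minimal sketch (not from this PR) of a spec the validator would now reject because minReplicas is greater than maxReplicas. The import paths, the bare head and worker templates, and the assumption that the other checks in `ValidateRayClusterSpec` accept such a minimal spec are all illustrative.

```go
package main

import (
	"fmt"

	corev1 "k8s.io/api/core/v1"
	"k8s.io/utils/ptr"

	rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
	"github.com/ray-project/kuberay/ray-operator/controllers/ray/utils"
)

func main() {
	template := corev1.PodTemplateSpec{
		Spec: corev1.PodSpec{
			Containers: []corev1.Container{{Name: "ray", Image: "rayproject/ray"}},
		},
	}
	spec := &rayv1.RayClusterSpec{
		HeadGroupSpec: rayv1.HeadGroupSpec{Template: template},
		WorkerGroupSpecs: []rayv1.WorkerGroupSpec{{
			GroupName:   "small-group",
			Template:    template,
			MinReplicas: ptr.To(int32(5)), // greater than MaxReplicas: now rejected
			MaxReplicas: ptr.To(int32(2)),
		}},
	}
	// Expected to fail with an error like:
	// "worker group small-group has minReplicas 5 greater than maxReplicas 2"
	fmt.Println(utils.ValidateRayClusterSpec(spec, map[string]string{}))
}
```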