
Commit c72caac

Authored by GonzaloSaez
Fix launcher job scheduling directives when unsuspending (kubeflow#772)
Signed-off-by: GonzaloSaez <11050889+GonzaloSaez@users.noreply.github.com>
1 parent: 9504d8c · commit: c72caac

File tree

3 files changed: +284 -6 lines

pkg/controller/mpi_job_controller.go

Lines changed: 42 additions & 3 deletions
@@ -24,6 +24,7 @@ import (
     "encoding/pem"
     "errors"
     "fmt"
+    "maps"
     "reflect"
     "sort"
     "strconv"
@@ -687,9 +688,29 @@ func (c *MPIJobController) syncHandler(key string) error {
     }

     if launcher != nil {
-        if isMPIJobSuspended(mpiJob) != isJobSuspended(launcher) {
-            // align the suspension state of launcher with the MPIJob
-            launcher.Spec.Suspend = ptr.To(isMPIJobSuspended(mpiJob))
+        if !isMPIJobSuspended(mpiJob) && isJobSuspended(launcher) {
+            // We are unsuspending, hence we need to sync the pod template with the current MPIJob spec.
+            // This is important for interop with Kueue as it may have injected schedulingGates.
+            // Kubernetes validates that a Job template is immutable once StartTime is set,
+            // so we must clear it first via a status sub-resource update (consistent with JobSet).
+            if launcher.Status.StartTime != nil {
+                launcher.Status.StartTime = nil
+                var err error
+                if launcher, err = c.kubeClient.BatchV1().Jobs(namespace).UpdateStatus(context.TODO(), launcher, metav1.UpdateOptions{}); err != nil {
+                    return err
+                }
+            }
+
+            // Sync mutable scheduling directives (KEP-2926) and unsuspend.
+            desiredPodTemplate := c.newLauncherPodTemplate(mpiJob)
+            syncLauncherSchedulingDirectives(launcher, &desiredPodTemplate)
+            launcher.Spec.Suspend = ptr.To(false)
+            if _, err := c.kubeClient.BatchV1().Jobs(namespace).Update(context.TODO(), launcher, metav1.UpdateOptions{}); err != nil {
+                return err
+            }
+        } else if isMPIJobSuspended(mpiJob) && !isJobSuspended(launcher) {
+            // align the suspension state of launcher with the MPIJob.
+            launcher.Spec.Suspend = ptr.To(true)
             if _, err := c.kubeClient.BatchV1().Jobs(namespace).Update(context.TODO(), launcher, metav1.UpdateOptions{}); err != nil {
                 return err
             }
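
Note on ordering in the branch above: the StartTime clear must go through the status sub-resource and land before the spec update that rewrites the pod template and flips Suspend, because (as the comment in the diff states) Kubernetes treats the Job template as immutable once StartTime is set. Below is a minimal, self-contained sketch of that call sequence against a plain client-go clientset; the package name, the resumeLauncher helper, and the trimmed-down syncDirectives stand-in are illustrative only and not part of this commit.

package resumesketch

import (
    "context"

    batchv1 "k8s.io/api/batch/v1"
    corev1 "k8s.io/api/core/v1"
    metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    "k8s.io/client-go/kubernetes"
    "k8s.io/utils/ptr"
)

// syncDirectives is a trimmed-down stand-in for syncLauncherSchedulingDirectives
// in the diff above: it copies only the pod-spec level scheduling directives.
func syncDirectives(launcher *batchv1.Job, desired *corev1.PodTemplateSpec) {
    launcher.Spec.Template.Spec.NodeSelector = desired.Spec.NodeSelector
    launcher.Spec.Template.Spec.Tolerations = desired.Spec.Tolerations
    launcher.Spec.Template.Spec.SchedulingGates = desired.Spec.SchedulingGates
}

// resumeLauncher condenses the resume flow: clear StartTime through the
// status sub-resource first, then sync the scheduling directives and
// unsuspend in a single spec update.
func resumeLauncher(ctx context.Context, c kubernetes.Interface, launcher *batchv1.Job, desired corev1.PodTemplateSpec) (*batchv1.Job, error) {
    if launcher.Status.StartTime != nil {
        launcher.Status.StartTime = nil
        updated, err := c.BatchV1().Jobs(launcher.Namespace).UpdateStatus(ctx, launcher, metav1.UpdateOptions{})
        if err != nil {
            return nil, err
        }
        launcher = updated
    }
    syncDirectives(launcher, &desired)
    launcher.Spec.Suspend = ptr.To(false)
    return c.BatchV1().Jobs(launcher.Namespace).Update(ctx, launcher, metav1.UpdateOptions{})
}
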
@@ -1623,6 +1644,24 @@ func (c *MPIJobController) newLauncherPodTemplate(mpiJob *kubeflow.MPIJob) corev
     }
 }

+func mergeMaps[K comparable, V any](a, b map[K]V) map[K]V {
+    merged := make(map[K]V, max(len(a), len(b)))
+    maps.Copy(merged, a)
+    maps.Copy(merged, b)
+    return merged
+}
+
+// syncLauncherSchedulingDirectives updates the mutable scheduling directives (as per KEP-2926) on
+// the launcher Job's pod template to match the desired template.
+func syncLauncherSchedulingDirectives(launcher *batchv1.Job, desired *corev1.PodTemplateSpec) {
+    launcher.Spec.Template.Labels = mergeMaps(launcher.Spec.Template.Labels, desired.Labels)
+    launcher.Spec.Template.Annotations = mergeMaps(launcher.Spec.Template.Annotations, desired.Annotations)
+
+    launcher.Spec.Template.Spec.NodeSelector = desired.Spec.NodeSelector
+    launcher.Spec.Template.Spec.Tolerations = desired.Spec.Tolerations
+    launcher.Spec.Template.Spec.SchedulingGates = desired.Spec.SchedulingGates
+}
+
 func (c *MPIJobController) jobPods(j *batchv1.Job) ([]*corev1.Pod, error) {
     selector, err := metav1.LabelSelectorAsSelector(j.Spec.Selector)
     if err != nil {
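
A note on mergeMaps: maps.Copy applies b after a, so on conflicting keys the desired template's labels and annotations win, while keys present only on the launcher's current template are preserved. The standalone sketch below (Go 1.21+ for the maps package and the max builtin) uses hypothetical key/value pairs, not values from the repository, to illustrate that precedence.

package main

import (
    "fmt"
    "maps"
)

// Same shape as the helper added in the diff above: the later Copy wins on conflicts.
func mergeMaps[K comparable, V any](a, b map[K]V) map[K]V {
    merged := make(map[K]V, max(len(a), len(b)))
    maps.Copy(merged, a)
    maps.Copy(merged, b)
    return merged
}

func main() {
    existing := map[string]string{"foo": "old", "launcher-only": "kept"} // launcher's current template labels (hypothetical)
    desired := map[string]string{"foo": "new", "injected": "by-kueue"}   // freshly rendered template labels (hypothetical)
    fmt.Println(mergeMaps(existing, desired))
    // Output: map[foo:new injected:by-kueue launcher-only:kept]
}
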

pkg/controller/mpi_job_controller_test.go

Lines changed: 181 additions & 2 deletions
@@ -1024,14 +1024,16 @@ func TestResumeMPIJob(t *testing.T) {
     // resume the MPIJob
     mpiJob.Spec.RunPolicy.Suspend = ptr.To(false)

-    // expect creation of the pods
+    // expect creation of the worker pods
     for i := 0; i < int(replicas); i++ {
         worker := fmjc.newWorker(mpiJob, i)
         f.kubeActions = append(f.kubeActions, core.NewCreateAction(schema.GroupVersionResource{Resource: "pods"}, mpiJob.Namespace, worker))
     }

-    // expect the launcher update to resume it
+    // expect the launcher update to sync scheduling directives and resume it
     launcherCopy := launcher.DeepCopy()
+    desiredPodTemplate := fmjc.newLauncherPodTemplate(mpiJob)
+    syncLauncherSchedulingDirectives(launcherCopy, &desiredPodTemplate)
     launcherCopy.Spec.Suspend = ptr.To(false)
     f.expectUpdateJobAction(launcherCopy)

@@ -1044,6 +1046,183 @@
     f.runWithClock(t.Context(), getKey(mpiJob, t), fakeClock)
 }

+func TestResumeMPIJobWithExistingLauncher(t *testing.T) {
+    // Tests the running→suspended→resumed path where a launcher already exists
+    // (from before suspension) with startTime == nil. The launcher should be
+    // updated in place with synced scheduling directives (KEP-2926).
+    fakeClock := clocktesting.NewFakeClock(time.Now().Truncate(time.Second))
+    f := newFixture(t, "")
+
+    var replicas int32 = 8
+    startTime := metav1.Now()
+    mpiJob := newMPIJob("test", &replicas, &startTime, nil)
+    mpiJob.Spec.RunPolicy.Suspend = ptr.To(true)
+    msg := fmt.Sprintf("MPIJob %s/%s is created.", mpiJob.Namespace, mpiJob.Name)
+    updateMPIJobConditions(mpiJob, kubeflow.JobCreated, corev1.ConditionTrue, mpiJobCreatedReason, msg)
+    updateMPIJobConditions(mpiJob, kubeflow.JobSuspended, corev1.ConditionTrue, mpiJobSuspendedReason, "MPIJob suspended")
+    msg = fmt.Sprintf("MPIJob %s/%s is suspended.", mpiJob.Namespace, mpiJob.Name)
+    updateMPIJobConditions(mpiJob, kubeflow.JobRunning, corev1.ConditionFalse, mpiJobSuspendedReason, msg)
+    mpiJob.Status.ReplicaStatuses = map[kubeflow.MPIReplicaType]*kubeflow.ReplicaStatus{
+        kubeflow.MPIReplicaTypeLauncher: {},
+        kubeflow.MPIReplicaTypeWorker:   {},
+    }
+    f.setUpMPIJob(mpiJob)
+
+    scheme.Scheme.Default(mpiJob)
+    f.expectCreateServiceAction(newJobService(mpiJob))
+    cfgMap := newConfigMap(mpiJob, replicas, "")
+    updateDiscoverHostsInConfigMap(cfgMap, mpiJob, nil, "")
+    f.setUpConfigMap(cfgMap)
+    secret, err := newSSHAuthSecret(mpiJob)
+    if err != nil {
+        t.Fatalf("Failed creating secret")
+    }
+    f.setUpSecret(secret)
+
+    // set up an existing suspended launcher (startTime == nil, never started)
+    fmjc := f.newFakeMPIJobController()
+    launcher := fmjc.newLauncherJob(mpiJob)
+    launcher.Spec.Suspend = ptr.To(true)
+    // Simulate Kueue injecting scheduling directives into the MPIJob template
+    // after the launcher was already created (so the launcher has stale templates).
+    launcherSpec := &mpiJob.Spec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeLauncher].Template
+    launcherSpec.Spec.NodeSelector = map[string]string{
+        "foo": "bar",
+    }
+    launcherSpec.Spec.Tolerations = []corev1.Toleration{
+        {Key: "gpu", Operator: corev1.TolerationOpEqual, Value: "true", Effect: corev1.TaintEffectNoSchedule},
+    }
+    launcherSpec.Spec.SchedulingGates = []corev1.PodSchedulingGate{
+        {Name: "kueue.x-k8s.io/topology"},
+    }
+    if launcherSpec.Annotations == nil {
+        launcherSpec.Annotations = make(map[string]string)
+    }
+    launcherSpec.Annotations["kueue.x-k8s.io/workload"] = "my-workload"
+    f.setUpLauncher(launcher)
+
+    fakeClock.Sleep(time.Second)
+
+    // resume the MPIJob
+    mpiJob.Spec.RunPolicy.Suspend = ptr.To(false)
+
+    // expect creation of the worker pods
+    for i := 0; i < int(replicas); i++ {
+        worker := fmjc.newWorker(mpiJob, i)
+        f.kubeActions = append(f.kubeActions, core.NewCreateAction(schema.GroupVersionResource{Resource: "pods"}, mpiJob.Namespace, worker))
+    }
+
+    // expect the launcher to be updated (scheduling directives synced + unsuspended)
+    launcherCopy := launcher.DeepCopy()
+    desiredPodTemplate := fmjc.newLauncherPodTemplate(mpiJob)
+    syncLauncherSchedulingDirectives(launcherCopy, &desiredPodTemplate)
+    launcherCopy.Spec.Suspend = ptr.To(false)
+
+    // Verify the synced launcher has the Kueue-injected scheduling directives.
+    tmpl := &launcherCopy.Spec.Template
+    if tmpl.Spec.NodeSelector["foo"] != "bar" {
+        t.Errorf("expected nodeSelector to be synced, got %v", tmpl.Spec.NodeSelector)
+    }
+    if len(tmpl.Spec.Tolerations) != 1 || tmpl.Spec.Tolerations[0].Key != "gpu" {
+        t.Errorf("expected tolerations to be synced, got %v", tmpl.Spec.Tolerations)
+    }
+    if len(tmpl.Spec.SchedulingGates) != 1 || tmpl.Spec.SchedulingGates[0].Name != "kueue.x-k8s.io/topology" {
+        t.Errorf("expected schedulingGates to be synced, got %v", tmpl.Spec.SchedulingGates)
+    }
+    if tmpl.Annotations["kueue.x-k8s.io/workload"] != "my-workload" {
+        t.Errorf("expected annotations to be synced, got %v", tmpl.Annotations)
+    }
+
+    f.expectUpdateJobAction(launcherCopy)
+
+    // expect status update
+    mpiJobCopy := mpiJob.DeepCopy()
+    mpiJobCopy.Status.StartTime = &metav1.Time{Time: fakeClock.Now()}
+    updateMPIJobConditions(mpiJobCopy, kubeflow.JobSuspended, corev1.ConditionFalse, "MPIJobResumed", "MPIJob resumed")
+    f.expectUpdateMPIJobStatusAction(mpiJobCopy)
+
+    f.runWithClock(t.Context(), getKey(mpiJob, t), fakeClock)
+}
+
+func TestResumeMPIJobClearsStartTime(t *testing.T) {
+    // Tests the re-admission case where the launcher has startTime != nil.
+    // The controller should clear StartTime via a status sub-resource update
+    // (consistent with JobSet), then sync scheduling directives and unsuspend.
+    fakeClock := clocktesting.NewFakeClock(time.Now().Truncate(time.Second))
+    f := newFixture(t, "")
+
+    var replicas int32 = 8
+    startTime := metav1.Now()
+    mpiJob := newMPIJob("test", &replicas, &startTime, nil)
+    mpiJob.Spec.RunPolicy.Suspend = ptr.To(true)
+    msg := fmt.Sprintf("MPIJob %s/%s is created.", mpiJob.Namespace, mpiJob.Name)
+    updateMPIJobConditions(mpiJob, kubeflow.JobCreated, corev1.ConditionTrue, mpiJobCreatedReason, msg)
+    updateMPIJobConditions(mpiJob, kubeflow.JobSuspended, corev1.ConditionTrue, mpiJobSuspendedReason, "MPIJob suspended")
+    msg = fmt.Sprintf("MPIJob %s/%s is suspended.", mpiJob.Namespace, mpiJob.Name)
+    updateMPIJobConditions(mpiJob, kubeflow.JobRunning, corev1.ConditionFalse, mpiJobSuspendedReason, msg)
+    mpiJob.Status.ReplicaStatuses = map[kubeflow.MPIReplicaType]*kubeflow.ReplicaStatus{
+        kubeflow.MPIReplicaTypeLauncher: {},
+        kubeflow.MPIReplicaTypeWorker:   {},
+    }
+    f.setUpMPIJob(mpiJob)
+
+    scheme.Scheme.Default(mpiJob)
+    f.expectCreateServiceAction(newJobService(mpiJob))
+    cfgMap := newConfigMap(mpiJob, replicas, "")
+    updateDiscoverHostsInConfigMap(cfgMap, mpiJob, nil, "")
+    f.setUpConfigMap(cfgMap)
+    secret, err := newSSHAuthSecret(mpiJob)
+    if err != nil {
+        t.Fatalf("Failed creating secret")
+    }
+    f.setUpSecret(secret)
+
+    // set up an existing suspended launcher that was previously started (startTime != nil)
+    fmjc := f.newFakeMPIJobController()
+    launcher := fmjc.newLauncherJob(mpiJob)
+    launcher.Spec.Suspend = ptr.To(true)
+    launcherStartTime := metav1.Now()
+    launcher.Status.StartTime = &launcherStartTime
+    f.setUpLauncher(launcher)
+
+    fakeClock.Sleep(time.Second)
+
+    // resume the MPIJob
+    mpiJob.Spec.RunPolicy.Suspend = ptr.To(false)
+
+    // expect creation of worker pods
+    for i := 0; i < int(replicas); i++ {
+        worker := fmjc.newWorker(mpiJob, i)
+        f.kubeActions = append(f.kubeActions, core.NewCreateAction(schema.GroupVersionResource{Resource: "pods"}, mpiJob.Namespace, worker))
+    }
+
+    // expect a status sub-resource update to clear launcher's StartTime
+    launcherStatusCleared := launcher.DeepCopy()
+    launcherStatusCleared.Status.StartTime = nil
+    f.kubeActions = append(f.kubeActions, core.NewUpdateSubresourceAction(
+        schema.GroupVersionResource{Resource: "jobs", Group: "batch", Version: "v1"},
+        "status",
+        mpiJob.Namespace,
+        launcherStatusCleared,
+    ))
+
+    // expect the launcher to be updated (scheduling directives synced + unsuspended)
+    launcherCopy := launcher.DeepCopy()
+    launcherCopy.Status.StartTime = nil
+    desiredPodTemplate := fmjc.newLauncherPodTemplate(mpiJob)
+    syncLauncherSchedulingDirectives(launcherCopy, &desiredPodTemplate)
+    launcherCopy.Spec.Suspend = ptr.To(false)
+    f.expectUpdateJobAction(launcherCopy)
+
+    // expect MPIJob status update
+    mpiJobCopy := mpiJob.DeepCopy()
+    mpiJobCopy.Status.StartTime = &metav1.Time{Time: fakeClock.Now()}
+    updateMPIJobConditions(mpiJobCopy, kubeflow.JobSuspended, corev1.ConditionFalse, "MPIJobResumed", "MPIJob resumed")
+    f.expectUpdateMPIJobStatusAction(mpiJobCopy)
+
+    f.runWithClock(t.Context(), getKey(mpiJob, t), fakeClock)
+}
+
 func TestWorkerNotControlledByUs(t *testing.T) {
     f := newFixture(t, "")
     startTime := metav1.Now()

test/integration/mpi_job_controller_test.go

Lines changed: 61 additions & 1 deletion
@@ -385,14 +385,37 @@ func TestMPIJobResumingAndSuspending(t *testing.T) {
         t.Errorf("MPIJob missing Suspended condition")
     }
     if !isJobSuspended(launcherJob) {
-        t.Errorf("LauncherJob is suspended")
+        t.Errorf("LauncherJob is not suspended")
     }
     if mpiJob.Status.StartTime != nil {
         t.Errorf("MPIJob has unexpected start time: %v", mpiJob.Status.StartTime)
     }

     s.events.verify(t)

+    // Simulate Kueue injecting scheduling directives into the MPIJob template
+    // while suspended. When resumed, these must propagate to the launcher Job.
+    launcherTemplate := &mpiJob.Spec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeLauncher].Template
+    launcherTemplate.Labels = map[string]string{
+        "foo": "bar",
+    }
+    launcherTemplate.Annotations = map[string]string{
+        "kueue.x-k8s.io/workload": "my-workload",
+    }
+    launcherTemplate.Spec.NodeSelector = map[string]string{
+        "example.com/accelerator": "example-model",
+    }
+    launcherTemplate.Spec.Tolerations = []corev1.Toleration{
+        {Key: "gpu", Operator: corev1.TolerationOpEqual, Value: "true", Effect: corev1.TaintEffectNoSchedule},
+    }
+    launcherTemplate.Spec.SchedulingGates = []corev1.PodSchedulingGate{
+        {Name: "kueue.x-k8s.io/topology"},
+    }
+    mpiJob, err = s.mpiClient.KubeflowV2beta1().MPIJobs(s.namespace).Update(ctx, mpiJob, metav1.UpdateOptions{})
+    if err != nil {
+        t.Fatalf("Failed to update the MPIJob: %v", err)
+    }
+
     // 2. Resume the MPIJob
     mpiJob.Spec.RunPolicy.Suspend = ptr.To(false)
     mpiJob, err = s.mpiClient.KubeflowV2beta1().MPIJobs(mpiJob.Namespace).Update(ctx, mpiJob, metav1.UpdateOptions{})
@@ -422,6 +445,24 @@

     s.events.verify(t)

+    // Verify all scheduling directives were propagated to the launcher Job's pod template.
+    launcherTmpl := &launcherJob.Spec.Template
+    if launcherTmpl.Labels["foo"] != "bar" {
+        t.Errorf("expected label 'foo=bar' on launcher Job template, got labels: %v", launcherTmpl.Labels)
+    }
+    if launcherTmpl.Annotations["kueue.x-k8s.io/workload"] != "my-workload" {
+        t.Errorf("expected annotation 'kueue.x-k8s.io/workload' on launcher Job template, got annotations: %v", launcherTmpl.Annotations)
+    }
+    if launcherTmpl.Spec.NodeSelector["example.com/accelerator"] != "example-model" {
+        t.Errorf("expected nodeSelector 'example.com/accelerator=example-model' on launcher Job template, got: %v", launcherTmpl.Spec.NodeSelector)
+    }
+    if len(launcherTmpl.Spec.Tolerations) == 0 || launcherTmpl.Spec.Tolerations[len(launcherTmpl.Spec.Tolerations)-1].Key != "gpu" {
+        t.Errorf("expected toleration with key 'gpu' on launcher Job template, got: %v", launcherTmpl.Spec.Tolerations)
+    }
+    if len(launcherTmpl.Spec.SchedulingGates) == 0 || launcherTmpl.Spec.SchedulingGates[len(launcherTmpl.Spec.SchedulingGates)-1].Name != "kueue.x-k8s.io/topology" {
+        t.Errorf("expected schedulingGate 'kueue.x-k8s.io/topology' on launcher Job template, got: %v", launcherTmpl.Spec.SchedulingGates)
+    }
+
     // 3. Set the pods to be running
     err = updatePodsToPhase(ctx, s.kClient, workerPods, corev1.PodRunning)
     if err != nil {
@@ -473,6 +514,25 @@
     if !mpiJobHasConditionWithStatus(mpiJob, kubeflow.JobRunning, corev1.ConditionFalse) {
         t.Errorf("MPIJob has unexpected Running condition")
     }
+
+    // Update the MPIJob launcher template again and resume, verifying the
+    // launcher Job gets the updated scheduling directives on second resume.
+    mpiJobLauncherTemplate := &mpiJob.Spec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeLauncher].Template
+    mpiJobLauncherTemplate.Labels["foo"] = "baz"
+    mpiJobLauncherTemplate.Spec.NodeSelector["example.com/accelerator"] = "example-model-v2"
+    mpiJob.Spec.RunPolicy.Suspend = ptr.To(false)
+    mpiJob, err = s.mpiClient.KubeflowV2beta1().MPIJobs(s.namespace).Update(ctx, mpiJob, metav1.UpdateOptions{})
+    if err != nil {
+        t.Fatalf("Failed to update the MPIJob: %v", err)
+    }
+
+    _, launcherJob = validateMPIJobDependencies(ctx, t, s.kClient, mpiJob, 2, nil)
+    if launcherJob.Spec.Template.Labels["foo"] != "baz" {
+        t.Errorf("expected label 'foo=baz' on launcher Job template, got labels: %v", launcherJob.Spec.Template.Labels)
+    }
+    if launcherJob.Spec.Template.Spec.NodeSelector["example.com/accelerator"] != "example-model-v2" {
+        t.Errorf("expected nodeSelector 'example.com/accelerator=example-model-v2' on launcher Job template, got: %v", launcherJob.Spec.Template.Spec.NodeSelector)
+    }
 }

 func TestMPIJobFailure(t *testing.T) {
