Skip to content

Commit 121528e

Browse files
committed
CARRY: Move annotation env vars to be added last
1 parent 468c0e5 commit 121528e

File tree

2 files changed

+128
-7
lines changed

2 files changed

+128
-7
lines changed

pkg/controller.v1/pytorch/envvar.go

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,6 @@ func setPodEnv(obj interface{}, podTemplateSpec *corev1.PodTemplateSpec, rtype,
5454
podTemplateSpec.Spec.Containers[i].Env = make([]corev1.EnvVar, 0)
5555
}
5656

57-
// Inject checkpointing environment variables from annotations.
58-
checkpointEnvVars := extractCheckpointEnvVars(pytorchjob)
59-
if len(checkpointEnvVars) > 0 {
60-
podTemplateSpec.Spec.Containers[i].Env = append(
61-
podTemplateSpec.Spec.Containers[i].Env, checkpointEnvVars...)
62-
}
63-
6457
// Set PYTHONUNBUFFERED to true, to disable output buffering.
6558
// Ref https://stackoverflow.com/questions/59812009/what-is-the-use-of-pythonunbuffered-in-docker-file.
6659
podTemplateSpec.Spec.Containers[i].Env = append(
@@ -131,6 +124,21 @@ func setPodEnv(obj interface{}, podTemplateSpec *corev1.PodTemplateSpec, rtype,
131124
Value: strconv.Itoa(int(totalReplicas)),
132125
})
133126
}
127+
128+
// Inject checkpointing environment variables from annotations.
129+
checkpointEnvVars := extractCheckpointEnvVars(pytorchjob)
130+
for _, checkpointEnvVar := range checkpointEnvVars {
131+
exist := false
132+
for _, envVar := range podTemplateSpec.Spec.Containers[i].Env {
133+
if envVar.Name == checkpointEnvVar.Name {
134+
exist = true
135+
break
136+
}
137+
}
138+
if !exist {
139+
podTemplateSpec.Spec.Containers[i].Env = append(podTemplateSpec.Spec.Containers[i].Env, checkpointEnvVar)
140+
}
141+
}
134142
}
135143

136144
return nil
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
// Copyright 2023 The Kubeflow Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package pytorch
16+
17+
import (
18+
"testing"
19+
20+
"github.com/onsi/gomega"
21+
corev1 "k8s.io/api/core/v1"
22+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
23+
24+
kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
25+
)
26+
27+
func TestSetPodEnv(t *testing.T) {
28+
g := gomega.NewWithT(t)
29+
30+
// Prepare a base PyTorchJob.
31+
pytorchjob := &kubeflowv1.PyTorchJob{
32+
ObjectMeta: metav1.ObjectMeta{
33+
Name: "test-job",
34+
Annotations: map[string]string{
35+
"checkpoint.config.kubeflow.org/existing-var": "new-value",
36+
"checkpoint.config.kubeflow.org/new-var": "new-value",
37+
},
38+
},
39+
Spec: kubeflowv1.PyTorchJobSpec{
40+
PyTorchReplicaSpecs: map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec{
41+
kubeflowv1.PyTorchJobReplicaTypeMaster: {
42+
Replicas: func(i int32) *int32 { return &i }(1),
43+
Template: corev1.PodTemplateSpec{
44+
Spec: corev1.PodSpec{
45+
Containers: []corev1.Container{{
46+
Name: "pytorch",
47+
Ports: []corev1.ContainerPort{{
48+
Name: "pytorchjob-port",
49+
ContainerPort: 23456,
50+
}},
51+
}},
52+
},
53+
},
54+
},
55+
kubeflowv1.PyTorchJobReplicaTypeWorker: {
56+
Replicas: func(i int32) *int32 { return &i }(1),
57+
Template: corev1.PodTemplateSpec{
58+
Spec: corev1.PodSpec{
59+
Containers: []corev1.Container{{
60+
Name: "pytorch",
61+
Ports: []corev1.ContainerPort{{
62+
Name: "pytorchjob-port",
63+
ContainerPort: 23456,
64+
}},
65+
}},
66+
},
67+
},
68+
},
69+
},
70+
},
71+
}
72+
73+
// Case 1: An environment variable from a checkpoint annotation already exists.
74+
podTemplateSpecWithExistingEnv := &corev1.PodTemplateSpec{
75+
Spec: corev1.PodSpec{
76+
Containers: []corev1.Container{
77+
{
78+
Env: []corev1.EnvVar{
79+
{Name: "EXISTING_VAR", Value: "original-value"},
80+
},
81+
},
82+
},
83+
},
84+
}
85+
86+
err := setPodEnv(pytorchjob, podTemplateSpecWithExistingEnv, "master", "0")
87+
g.Expect(err).NotTo(gomega.HaveOccurred())
88+
89+
// Verify that the existing variable was not overwritten and the new one was added.
90+
g.Expect(podTemplateSpecWithExistingEnv.Spec.Containers[0].Env).To(gomega.ContainElement(corev1.EnvVar{Name: "EXISTING_VAR", Value: "original-value"}))
91+
g.Expect(podTemplateSpecWithExistingEnv.Spec.Containers[0].Env).To(gomega.ContainElement(corev1.EnvVar{Name: "NEW_VAR", Value: "new-value"}))
92+
93+
// Case 2: No conflicting environment variables.
94+
podTemplateSpecNew := &corev1.PodTemplateSpec{
95+
Spec: corev1.PodSpec{
96+
Containers: []corev1.Container{{
97+
Name: "pytorch",
98+
Ports: []corev1.ContainerPort{{
99+
Name: "pytorchjob-port",
100+
ContainerPort: 23456,
101+
}},
102+
}},
103+
},
104+
}
105+
err = setPodEnv(pytorchjob, podTemplateSpecNew, "master", "0")
106+
g.Expect(err).NotTo(gomega.HaveOccurred())
107+
108+
// Verify that the new variable was added.
109+
g.Expect(podTemplateSpecNew.Spec.Containers[0].Env).To(gomega.ContainElement(corev1.EnvVar{Name: "NEW_VAR", Value: "new-value"}))
110+
111+
// Case 3: Check for default env vars like PYTHONUNBUFFERED.
112+
g.Expect(podTemplateSpecNew.Spec.Containers[0].Env).To(gomega.ContainElement(corev1.EnvVar{Name: "PYTHONUNBUFFERED", Value: "1"}))
113+
}

0 commit comments

Comments
 (0)