Skip to content

Commit d926bc0

Browse files
committed
CARRY: Use temp folder to store progress file
1 parent d07ec3b commit d926bc0

File tree

3 files changed

+46
-19
lines changed

3 files changed

+46
-19
lines changed

pkg/controller.v1/pytorch/envvar.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,12 @@ func setPodEnv(obj interface{}, podTemplateSpec *corev1.PodTemplateSpec, rtype,
121121
Value: strconv.Itoa(int(totalReplicas)),
122122
})
123123
}
124+
125+
// Set the training progress file path.
126+
podTemplateSpec.Spec.Containers[i].Env = append(podTemplateSpec.Spec.Containers[i].Env, corev1.EnvVar{
127+
Name: EnvTrainingProgressFilePath,
128+
Value: GetProgressFilePath(pytorchjob),
129+
})
124130
}
125131

126132
return nil

pkg/controller.v1/pytorch/pytorchjob_controller.go

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -191,15 +191,7 @@ func (r *PyTorchJobReconciler) Reconcile(ctx context.Context, req ctrl.Request)
191191
return ctrl.Result{}, err
192192
}
193193

194-
jobIsRunning := false
195-
for _, condition := range pytorchjob.Status.Conditions {
196-
if condition.Type == kubeflowv1.JobRunning && condition.Status == corev1.ConditionTrue {
197-
jobIsRunning = true
198-
break
199-
}
200-
}
201-
202-
if jobIsRunning {
194+
if commonutil.IsRunning(pytorchjob.Status) {
203195
if content, err := r.readCompletionPercentageFromPod(pytorchjob); err == nil {
204196
if percentage, parseErr := r.parseCompletionPercentage(content); parseErr == nil {
205197
// Assuming your PyTorchJobStatus has a CompletionPercentage field.
@@ -219,12 +211,12 @@ func (r *PyTorchJobReconciler) Reconcile(ctx context.Context, req ctrl.Request)
219211
} else {
220212
logrus.Debugf("Failed to read completion percentage from rank-0 pod for PyTorchJob %s: %v", pytorchjob.Name, err)
221213
}
222-
214+
223215
// Return a short requeue interval for running jobs
224216
// TODO instead of hard coding the requeue interval we could make this configurable
225217
return ctrl.Result{RequeueAfter: 30 * time.Second}, nil
226218
}
227-
219+
228220
t, err := util.DurationUntilExpireTime(&pytorchjob.Spec.RunPolicy, pytorchjob.Status)
229221
if err != nil {
230222
logrus.Warnf("Reconcile PyTorchJob error %v", err)
@@ -527,7 +519,9 @@ func (r *PyTorchJobReconciler) execInPod(pod *corev1.Pod, containerName string,
527519
}
528520

529521
var stdout, stderr bytes.Buffer
530-
err = executor.StreamWithContext(context.Background(), remotecommand.StreamOptions{
522+
ctx, cancel := context.WithTimeoutCause(context.Background(), 10*time.Second, fmt.Errorf("pod execution timed out"))
523+
defer cancel()
524+
err = executor.StreamWithContext(ctx, remotecommand.StreamOptions{
531525
Stdout: &stdout,
532526
Stderr: &stderr,
533527
})
@@ -576,16 +570,13 @@ func (r *PyTorchJobReconciler) readCompletionPercentageFromPod(pytorchjob *kubef
576570
return "", fmt.Errorf("rank-0 pod %s is not in running state: %s", rankZeroPod.Name, rankZeroPod.Status.Phase)
577571
}
578572

579-
// Get the container name (use default PyTorch container name)
580-
containerName := kubeflowv1.PyTorchJobDefaultContainerName
581-
if len(rankZeroPod.Spec.Containers) > 0 {
582-
containerName = rankZeroPod.Spec.Containers[0].Name
583-
}
573+
// Get the container name
574+
containerName := rankZeroPod.Spec.Containers[0].Name
584575

585576
// Read the progress file from the per-job temp folder (see GetProgressFilePath) - /var/run is not accessible by non-root user
586577
// TODO we could have the user add the file path in an annotation instead of hardcoding it here
587578
// later we could update the CRD spec to allow for checkpoint config
588-
progressFilePath := "/mnt/checkpoints/progress.json"
579+
progressFilePath := GetProgressFilePath(pytorchjob)
589580
catCommand := []string{"cat", progressFilePath}
590581
content, err := r.execInPod(rankZeroPod, containerName, catCommand)
591582
if err != nil {
@@ -600,7 +591,7 @@ func (r *PyTorchJobReconciler) parseCompletionPercentage(content string) (string
600591
var progress ProgressData
601592

602593
if err := json.Unmarshal([]byte(content), &progress); err != nil {
603-
return "", fmt.Errorf("failed to parse JSON: %v", err)
594+
return "", fmt.Errorf("failed to parse JSON from content '%s': %v", content, err)
604595
}
605596

606597
// Extract current and total steps
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// Copyright 2025 The Kubeflow Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License
14+
15+
package pytorch
16+
17+
import (
18+
"fmt"
19+
20+
kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
21+
)
22+
23+
const (
24+
// EnvTrainingProgressFilePath is the environment variable name for the training progress file path.
25+
EnvTrainingProgressFilePath = "TRAINING_PROGRESS_FILE_PATH"
26+
)
27+
28+
func GetProgressFilePath(job *kubeflowv1.PyTorchJob) string {
29+
return fmt.Sprintf("/tmp/training_data/%s/progress.json", job.Name)
30+
}

0 commit comments

Comments
 (0)