@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
1414limitations under the License.
1515*/
1616
17- package fms
17+ package kfto
1818
1919import (
2020 "testing"
@@ -29,8 +29,6 @@ import (
2929 metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
3030 kueuev1beta1 "sigs.k8s.io/kueue/apis/kueue/v1beta1"
3131 kueueacv1beta1 "sigs.k8s.io/kueue/client-go/applyconfiguration/kueue/v1beta1"
32-
33- "github.com/opendatahub-io/distributed-workloads/tests/kfto"
3432)
3533
3634var (
@@ -47,10 +45,16 @@ func TestSetupPytorchjob(t *testing.T) {
4745 createOrGetUpgradeTestNamespace (test , namespaceName )
4846
4947 // Create a ConfigMap with training dataset and configuration
48+ mnist := readFile (test , "resources/mnist.py" )
49+ download_mnist_dataset := readFile (test , "resources/download_mnist_datasets.py" )
50+ requirementsFileName := readFile (test , "resources/requirements.txt" )
51+
5052 configData := map [string ][]byte {
51- "config.json" : ReadFile (test , "resources/config.json" ),
52- "twitter_complaints_small.json" : ReadFile (test , "resources/twitter_complaints_small.json" ),
53+ "mnist.py" : mnist ,
54+ "download_mnist_datasets.py" : download_mnist_dataset ,
55+ "requirements.txt" : requirementsFileName ,
5356 }
57+
5458 config := CreateConfigMap (test , namespaceName , configData )
5559
5660 // Create Kueue resources
@@ -70,7 +74,7 @@ func TestSetupPytorchjob(t *testing.T) {
7074 WithName (kueuev1beta1 .ResourceFlavorReference (resourceFlavorName )).
7175 WithResources (
7276 kueueacv1beta1 .ResourceQuota ().WithName (corev1 .ResourceCPU ).WithNominalQuota (resource .MustParse ("8" )),
73- kueueacv1beta1 .ResourceQuota ().WithName (corev1 .ResourceMemory ).WithNominalQuota (resource .MustParse ("12Gi " )),
77+ kueueacv1beta1 .ResourceQuota ().WithName (corev1 .ResourceMemory ).WithNominalQuota (resource .MustParse ("18Gi " )),
7478 ),
7579 ),
7680 ).
@@ -133,6 +137,10 @@ func createUpgradePyTorchJob(test Test, namespace, localQueueName string, config
133137 }
134138
135139 tuningJob := & kftov1.PyTorchJob {
140+ TypeMeta : metav1.TypeMeta {
141+ APIVersion : corev1 .SchemeGroupVersion .String (),
142+ Kind : "PyTorchJob" ,
143+ },
136144 ObjectMeta : metav1.ObjectMeta {
137145 Name : pyTorchJobName ,
138146 Labels : map [string ]string {
@@ -141,85 +149,76 @@ func createUpgradePyTorchJob(test Test, namespace, localQueueName string, config
141149 },
142150 Spec : kftov1.PyTorchJobSpec {
143151 PyTorchReplicaSpecs : map [kftov1.ReplicaType ]* kftov1.ReplicaSpec {
144- "Master" : {
152+ kftov1 . PyTorchJobReplicaTypeMaster : {
145153 Replicas : Ptr (int32 (1 )),
146- RestartPolicy : "OnFailure" ,
154+ RestartPolicy : kftov1 . RestartPolicyOnFailure ,
147155 Template : corev1.PodTemplateSpec {
156+ ObjectMeta : metav1.ObjectMeta {
157+ Labels : map [string ]string {
158+ "app" : "kfto-mnist" ,
159+ "role" : "master" ,
160+ },
161+ },
148162 Spec : corev1.PodSpec {
149- InitContainers : []corev1.Container {
150- {
151- Name : "copy-model" ,
152- Image : kfto .GetBloomModelImage (),
153- ImagePullPolicy : corev1 .PullIfNotPresent ,
154- VolumeMounts : []corev1.VolumeMount {
163+ Affinity : & corev1.Affinity {
164+ PodAntiAffinity : & corev1.PodAntiAffinity {
165+ RequiredDuringSchedulingIgnoredDuringExecution : []corev1.PodAffinityTerm {
155166 {
156- Name : "tmp-volume" ,
157- MountPath : "/tmp" ,
167+ LabelSelector : & metav1.LabelSelector {
168+ MatchLabels : map [string ]string {
169+ "app" : "kfto-mnist" ,
170+ },
171+ },
172+ TopologyKey : "kubernetes.io/hostname" ,
158173 },
159174 },
160- Command : []string {"/bin/sh" , "-c" },
161- Args : []string {"mkdir /tmp/model; cp -r /models/bloom-560m /tmp/model" },
162175 },
163176 },
164177 Containers : []corev1.Container {
165178 {
166179 Name : "pytorch" ,
167- Image : GetFmsHfTuningImage ( test ),
180+ Image : GetCudaTrainingImage ( ),
168181 ImagePullPolicy : corev1 .PullIfNotPresent ,
169- Env : []corev1. EnvVar {
170- {
171- Name : "SFT_TRAINER_CONFIG_JSON_PATH" ,
172- Value : "/etc/config/config.json" ,
173- },
174- {
175- Name : "HF_HOME" ,
176- Value : "/tmp/huggingface" ,
177- } ,
182+ Command : []string {
183+ "/bin/bash" , "-c" ,
184+ ( `mkdir -p /tmp/lib /tmp/datasets/mnist && export PYTHONPATH=$PYTHONPATH:/tmp/lib && \
185+ pip install --no-cache-dir -r /mnt/files/requirements.txt --target=/tmp/lib --verbose && \
186+ echo "Downloading MNIST dataset..." && \
187+ python3 /mnt/files/download_mnist_datasets.py --dataset_path "/tmp/datasets/mnist" && \
188+ echo -e "\n\n Dataset downloaded to /tmp/datasets/mnist" && ls -R /tmp/datasets/mnist && \
189+ echo -e "\n\n Starting training..." && \
190+ torchrun --nproc_per_node 2 /mnt/files/mnist.py --dataset_path "/tmp/datasets/mnist" --epochs 7 --save_every 2 --batch_size 128 --lr 0.001 --snapshot_path "mnist_snapshot.pt" --backend "gloo"` ) ,
178191 },
179192 VolumeMounts : []corev1.VolumeMount {
180193 {
181- Name : " config-volume" ,
182- MountPath : "/etc/config " ,
194+ Name : config . Name ,
195+ MountPath : "/mnt/files " ,
183196 },
184197 {
185198 Name : "tmp-volume" ,
186199 MountPath : "/tmp" ,
187200 },
188- {
189- Name : "output-volume" ,
190- MountPath : "/mnt/output" ,
191- },
192201 },
193202 Resources : corev1.ResourceRequirements {
194203 Requests : corev1.ResourceList {
195204 corev1 .ResourceCPU : resource .MustParse ("2" ),
196- corev1 .ResourceMemory : resource .MustParse ("7Gi " ),
205+ corev1 .ResourceMemory : resource .MustParse ("6Gi " ),
197206 },
198207 Limits : corev1.ResourceList {
199208 corev1 .ResourceCPU : resource .MustParse ("2" ),
200- corev1 .ResourceMemory : resource .MustParse ("7Gi " ),
209+ corev1 .ResourceMemory : resource .MustParse ("6Gi " ),
201210 },
202211 },
203212 },
204213 },
205214 Volumes : []corev1.Volume {
206215 {
207- Name : " config-volume" ,
216+ Name : config . Name ,
208217 VolumeSource : corev1.VolumeSource {
209218 ConfigMap : & corev1.ConfigMapVolumeSource {
210219 LocalObjectReference : corev1.LocalObjectReference {
211220 Name : config .Name ,
212221 },
213- Items : []corev1.KeyToPath {
214- {
215- Key : "config.json" ,
216- Path : "config.json" ,
217- },
218- {
219- Key : "twitter_complaints_small.json" ,
220- Path : "twitter_complaints_small.json" ,
221- },
222- },
223222 },
224223 },
225224 },
@@ -229,13 +228,92 @@ func createUpgradePyTorchJob(test Test, namespace, localQueueName string, config
229228 EmptyDir : & corev1.EmptyDirVolumeSource {},
230229 },
231230 },
231+ },
232+ RestartPolicy : corev1 .RestartPolicyOnFailure ,
233+ },
234+ },
235+ },
236+ kftov1 .PyTorchJobReplicaTypeWorker : {
237+ Replicas : Ptr (int32 (2 )),
238+ RestartPolicy : kftov1 .RestartPolicyOnFailure ,
239+ Template : corev1.PodTemplateSpec {
240+ ObjectMeta : metav1.ObjectMeta {
241+ Labels : map [string ]string {
242+ "app" : "kfto-mnist" ,
243+ "role" : "worker" ,
244+ },
245+ },
246+ Spec : corev1.PodSpec {
247+ Affinity : & corev1.Affinity {
248+ PodAntiAffinity : & corev1.PodAntiAffinity {
249+ RequiredDuringSchedulingIgnoredDuringExecution : []corev1.PodAffinityTerm {
250+ {
251+ LabelSelector : & metav1.LabelSelector {
252+ MatchLabels : map [string ]string {
253+ "app" : "kfto-mnist" ,
254+ },
255+ },
256+ TopologyKey : "kubernetes.io/hostname" ,
257+ },
258+ },
259+ },
260+ },
261+ Containers : []corev1.Container {
232262 {
233- Name : "output-volume" ,
263+ Name : "pytorch" ,
264+ Image : GetCudaTrainingImage (),
265+ ImagePullPolicy : corev1 .PullIfNotPresent ,
266+ Command : []string {
267+ "/bin/bash" , "-c" ,
268+ (`mkdir -p /tmp/lib /tmp/datasets/mnist && export PYTHONPATH=$PYTHONPATH:/tmp/lib && \
269+ pip install --no-cache-dir -r /mnt/files/requirements.txt --target=/tmp/lib --verbose && \
270+ echo "Downloading MNIST dataset..." && \
271+ python3 /mnt/files/download_mnist_datasets.py --dataset_path "/tmp/datasets/mnist" && \
272+ echo -e "\n\n Dataset downloaded to /tmp/datasets/mnist" && ls -R /tmp/datasets/mnist && \
273+ echo -e "\n\n Starting training..." && \
274+ torchrun --nproc_per_node 2 /mnt/files/mnist.py --dataset_path "/tmp/datasets/mnist" --epochs 7 --save_every 2 --batch_size 128 --lr 0.001 --snapshot_path "mnist_snapshot.pt" --backend "gloo"` ),
275+ },
276+ VolumeMounts : []corev1.VolumeMount {
277+ {
278+ Name : config .Name ,
279+ MountPath : "/mnt/files" ,
280+ },
281+ {
282+ Name : "tmp-volume" ,
283+ MountPath : "/tmp" ,
284+ },
285+ },
286+ Resources : corev1.ResourceRequirements {
287+ Requests : corev1.ResourceList {
288+ corev1 .ResourceCPU : resource .MustParse ("2" ),
289+ corev1 .ResourceMemory : resource .MustParse ("6Gi" ),
290+ },
291+ Limits : corev1.ResourceList {
292+ corev1 .ResourceCPU : resource .MustParse ("2" ),
293+ corev1 .ResourceMemory : resource .MustParse ("6Gi" ),
294+ },
295+ },
296+ },
297+ },
298+ Volumes : []corev1.Volume {
299+ {
300+ Name : config .Name ,
301+ VolumeSource : corev1.VolumeSource {
302+ ConfigMap : & corev1.ConfigMapVolumeSource {
303+ LocalObjectReference : corev1.LocalObjectReference {
304+ Name : config .Name ,
305+ },
306+ },
307+ },
308+ },
309+ {
310+ Name : "tmp-volume" ,
234311 VolumeSource : corev1.VolumeSource {
235312 EmptyDir : & corev1.EmptyDirVolumeSource {},
236313 },
237314 },
238315 },
316+ RestartPolicy : corev1 .RestartPolicyOnFailure ,
239317 },
240318 },
241319 },
0 commit comments