@@ -136,6 +136,12 @@ func createUpgradePyTorchJob(test Test, namespace, localQueueName string, config
136136 test .T ().Fatalf ("Error retrieving PyTorchJob with name `%s`: %v" , pyTorchJobName , err )
137137 }
138138
139+ storage_bucket_endpoint , storage_bucket_endpoint_exists := GetStorageBucketDefaultEndpoint ()
140+ storage_bucket_access_key_id , storage_bucket_access_key_id_exists := GetStorageBucketAccessKeyId ()
141+ storage_bucket_secret_key , storage_bucket_secret_key_exists := GetStorageBucketSecretKey ()
142+ storage_bucket_name , storage_bucket_name_exists := GetStorageBucketName ()
143+ storage_bucket_mnist_dir , storage_bucket_mnist_dir_exists := GetStorageBucketMnistDir ()
144+
139145 tuningJob := & kftov1.PyTorchJob {
140146 TypeMeta : metav1.TypeMeta {
141147 APIVersion : corev1 .SchemeGroupVersion .String (),
@@ -321,6 +327,62 @@ func createUpgradePyTorchJob(test Test, namespace, localQueueName string, config
321327 },
322328 }
323329
330+ // Add PIP Index to download python packages, use provided custom PYPI mirror index url in case of disconnected environemnt
331+ tuningJob .Spec .PyTorchReplicaSpecs [kftov1 .PyTorchJobReplicaTypeMaster ].Template .Spec .Containers [0 ].Env = []corev1.EnvVar {
332+ {
333+ Name : "PIP_INDEX_URL" ,
334+ Value : GetPipIndexURL (),
335+ },
336+ {
337+ Name : "PIP_TRUSTED_HOST" ,
338+ Value : GetPipTrustedHost (),
339+ },
340+ }
341+ tuningJob .Spec .PyTorchReplicaSpecs [kftov1 .PyTorchJobReplicaTypeWorker ].Template .Spec .Containers [0 ].Env = []corev1.EnvVar {
342+ {
343+ Name : "PIP_INDEX_URL" ,
344+ Value : GetPipIndexURL (),
345+ },
346+ {
347+ Name : "PIP_TRUSTED_HOST" ,
348+ Value : GetPipTrustedHost (),
349+ },
350+ }
351+
352+ // Use storage bucket to download the MNIST datasets if required environment variables are provided, else use default MNIST mirror references as the fallback
353+ if storage_bucket_endpoint_exists && storage_bucket_access_key_id_exists && storage_bucket_secret_key_exists && storage_bucket_name_exists && storage_bucket_mnist_dir_exists {
354+ storage_bucket_env_vars := []corev1.EnvVar {
355+ {
356+ Name : "AWS_DEFAULT_ENDPOINT" ,
357+ Value : storage_bucket_endpoint ,
358+ },
359+ {
360+ Name : "AWS_ACCESS_KEY_ID" ,
361+ Value : storage_bucket_access_key_id ,
362+ },
363+ {
364+ Name : "AWS_SECRET_ACCESS_KEY" ,
365+ Value : storage_bucket_secret_key ,
366+ },
367+ {
368+ Name : "AWS_STORAGE_BUCKET" ,
369+ Value : storage_bucket_name ,
370+ },
371+ {
372+ Name : "AWS_STORAGE_BUCKET_MNIST_DIR" ,
373+ Value : storage_bucket_mnist_dir ,
374+ },
375+ }
376+
377+ // Append the list of environment variables for the worker container
378+ for _ , envVar := range storage_bucket_env_vars {
379+ tuningJob .Spec .PyTorchReplicaSpecs [kftov1 .PyTorchJobReplicaTypeMaster ].Template .Spec .Containers [0 ].Env = upsert (tuningJob .Spec .PyTorchReplicaSpecs [kftov1 .PyTorchJobReplicaTypeMaster ].Template .Spec .Containers [0 ].Env , envVar , withEnvVarName (envVar .Name ))
380+ tuningJob .Spec .PyTorchReplicaSpecs [kftov1 .PyTorchJobReplicaTypeWorker ].Template .Spec .Containers [0 ].Env = upsert (tuningJob .Spec .PyTorchReplicaSpecs [kftov1 .PyTorchJobReplicaTypeWorker ].Template .Spec .Containers [0 ].Env , envVar , withEnvVarName (envVar .Name ))
381+ }
382+ } else {
383+ test .T ().Logf ("Skipped usage of S3 storage bucket, because required environment variables aren't provided!\n Required environment variables : AWS_DEFAULT_ENDPOINT, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_STORAGE_BUCKET, AWS_STORAGE_BUCKET_MNIST_DIR" )
384+ }
385+
324386 tuningJob , err = test .Client ().Kubeflow ().KubeflowV1 ().PyTorchJobs (namespace ).Create (test .Ctx (), tuningJob , metav1.CreateOptions {})
325387 test .Expect (err ).NotTo (HaveOccurred ())
326388 test .T ().Logf ("Created PytorchJob %s/%s successfully" , tuningJob .Namespace , tuningJob .Name )
0 commit comments