Skip to content

Commit 995f993

Browse files
refactor shared logic into common module
Signed-off-by: Kevin <[email protected]>
1 parent c16fd7f commit 995f993

19 files changed

+158
-515
lines changed

tests/odh/environment.go renamed to tests/common/environment.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
1414
limitations under the License.
1515
*/
1616

17-
package odh
17+
package common
1818

1919
import (
2020
"os"

tests/odh/notebook.go renamed to tests/common/notebook.go

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,11 @@ See the License for the specific language governing permissions and
1414
limitations under the License.
1515
*/
1616

17-
package odh
17+
package common
1818

1919
import (
2020
"bytes"
21+
"embed"
2122

2223
gomega "github.com/onsi/gomega"
2324
. "github.com/project-codeflare/codeflare-common/support"
@@ -34,6 +35,16 @@ const (
3435
NOTEBOOK_CONTAINER_NAME = "jupyter-nb-kube-3aadmin"
3536
)
3637

38+
//go:embed resources/*
39+
var files embed.FS
40+
41+
func readFile(t Test, fileName string) []byte {
42+
t.T().Helper()
43+
file, err := files.ReadFile(fileName)
44+
t.Expect(err).NotTo(gomega.HaveOccurred())
45+
return file
46+
}
47+
3748
var notebookResource = schema.GroupVersionResource{Group: "kubeflow.org", Version: "v1", Resource: "notebooks"}
3849

3950
type NotebookProps struct {
@@ -42,7 +53,7 @@ type NotebookProps struct {
4253
KubernetesUserBearerToken string
4354
Namespace string
4455
OpenDataHubNamespace string
45-
RayImage string
56+
Command []string
4657
NotebookImage string
4758
NotebookConfigMapName string
4859
NotebookConfigMapFileName string
@@ -57,7 +68,7 @@ type NotebookProps struct {
5768
S3DefaultRegion string
5869
}
5970

60-
func createNotebook(test Test, namespace *corev1.Namespace, notebookUserToken, rayImage string, jupyterNotebookConfigMapName, jupyterNotebookConfigMapFileName string, numGpus int) {
71+
func CreateNotebook(test Test, namespace *corev1.Namespace, notebookUserToken string, command []string, jupyterNotebookConfigMapName, jupyterNotebookConfigMapFileName string, numGpus int) {
6172
// Create PVC for Notebook
6273
notebookPVC := CreatePersistentVolumeClaim(test, namespace.Name, "10Gi", corev1.ReadWriteOnce)
6374
s3BucketName, s3BucketNameExists := GetStorageBucketName()
@@ -81,7 +92,7 @@ func createNotebook(test Test, namespace *corev1.Namespace, notebookUserToken, r
8192
KubernetesUserBearerToken: notebookUserToken,
8293
Namespace: namespace.Name,
8394
OpenDataHubNamespace: GetOpenDataHubNamespace(test),
84-
RayImage: rayImage,
95+
Command: command,
8596
NotebookImage: GetNotebookImage(test),
8697
NotebookConfigMapName: jupyterNotebookConfigMapName,
8798
NotebookConfigMapFileName: jupyterNotebookConfigMapFileName,
@@ -108,12 +119,12 @@ func createNotebook(test Test, namespace *corev1.Namespace, notebookUserToken, r
108119
test.Expect(err).NotTo(gomega.HaveOccurred())
109120
}
110121

111-
func deleteNotebook(test Test, namespace *corev1.Namespace) {
122+
func DeleteNotebook(test Test, namespace *corev1.Namespace) {
112123
err := test.Client().Dynamic().Resource(notebookResource).Namespace(namespace.Name).Delete(test.Ctx(), "jupyter-nb-kube-3aadmin", metav1.DeleteOptions{})
113124
test.Expect(err).NotTo(gomega.HaveOccurred())
114125
}
115126

116-
func listNotebooks(test Test, namespace *corev1.Namespace) []*unstructured.Unstructured {
127+
func ListNotebooks(test Test, namespace *corev1.Namespace) []*unstructured.Unstructured {
117128
ntbs, err := test.Client().Dynamic().Resource(notebookResource).Namespace(namespace.Name).List(test.Ctx(), metav1.ListOptions{})
118129
test.Expect(err).NotTo(gomega.HaveOccurred())
119130

tests/common/template.go

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/*
2+
Copyright 2024.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package common
18+
19+
import (
20+
"bytes"
21+
"text/template"
22+
23+
"github.com/onsi/gomega"
24+
"github.com/project-codeflare/codeflare-common/support"
25+
)
26+
27+
func ParseAWSArgs(t support.Test, inputTemplate []byte) []byte {
28+
storage_bucket_endpoint, storage_bucket_endpoint_exists := support.GetStorageBucketDefaultEndpoint()
29+
storage_bucket_access_key_id, storage_bucket_access_key_id_exists := support.GetStorageBucketAccessKeyId()
30+
storage_bucket_secret_key, storage_bucket_secret_key_exists := support.GetStorageBucketSecretKey()
31+
storage_bucket_name, storage_bucket_name_exists := support.GetStorageBucketName()
32+
storage_bucket_mnist_dir, storage_bucket_mnist_dir_exists := support.GetStorageBucketMnistDir()
33+
34+
props := struct {
35+
StorageBucketDefaultEndpoint string
36+
StorageBucketDefaultEndpointExists bool
37+
StorageBucketAccessKeyId string
38+
StorageBucketAccessKeyIdExists bool
39+
StorageBucketSecretKey string
40+
StorageBucketSecretKeyExists bool
41+
StorageBucketName string
42+
StorageBucketNameExists bool
43+
StorageBucketMnistDir string
44+
StorageBucketMnistDirExists bool
45+
}{
46+
StorageBucketDefaultEndpoint: storage_bucket_endpoint,
47+
StorageBucketDefaultEndpointExists: storage_bucket_endpoint_exists,
48+
StorageBucketAccessKeyId: storage_bucket_access_key_id,
49+
StorageBucketAccessKeyIdExists: storage_bucket_access_key_id_exists,
50+
StorageBucketSecretKey: storage_bucket_secret_key,
51+
StorageBucketSecretKeyExists: storage_bucket_secret_key_exists,
52+
StorageBucketName: storage_bucket_name,
53+
StorageBucketNameExists: storage_bucket_name_exists,
54+
StorageBucketMnistDir: storage_bucket_mnist_dir,
55+
StorageBucketMnistDirExists: storage_bucket_mnist_dir_exists,
56+
}
57+
58+
return ParseTemplate(t, inputTemplate, props)
59+
}
60+
61+
func ParseTemplate(t support.Test, inputTemplate []byte, props interface{}) []byte {
62+
t.T().Helper()
63+
64+
// Parse input template
65+
parsedTemplate, err := template.New("template").Parse(string(inputTemplate))
66+
t.Expect(err).NotTo(gomega.HaveOccurred())
67+
68+
// Filter template and store results to the buffer
69+
buffer := new(bytes.Buffer)
70+
err = parsedTemplate.Execute(buffer, props)
71+
t.Expect(err).NotTo(gomega.HaveOccurred())
72+
err = parsedTemplate.Execute(buffer, props) // NOTE: not sure if template package handles recursive case
73+
t.Expect(err).NotTo(gomega.HaveOccurred())
74+
75+
return buffer.Bytes()
76+
}

tests/kfto/kfto_mnist_sdk_test.go

Lines changed: 15 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"testing"
2222

2323
. "github.com/onsi/gomega"
24+
. "github.com/opendatahub-io/distributed-workloads/tests/common"
2425
. "github.com/project-codeflare/codeflare-common/support"
2526

2627
v1 "k8s.io/api/core/v1"
@@ -33,7 +34,7 @@ func TestMnistSDK(t *testing.T) {
3334
userName := GetNotebookUserName(test)
3435
userToken := GetNotebookUserToken(test)
3536
jupyterNotebookConfigMapFileName := "mnist_kfto.ipynb"
36-
mnist := readMnistScriptTemplate(test, "resources/kfto_sdk_mnist.py")
37+
mnist := ParseAWSArgs(test, readFile(test, "resources/kfto_sdk_mnist.py"))
3738

3839
// Create role binding with Namespace specific admin cluster role
3940
CreateUserRoleBindingWithClusterRole(test, userName, namespace.Name, "admin")
@@ -45,8 +46,8 @@ func TestMnistSDK(t *testing.T) {
4546
"${namespace}": namespace.Name,
4647
}
4748

48-
jupyterNotebook := string(ReadFile(test, "resources/mnist_kfto.ipynb"))
49-
requirements := ReadFile(test, "resources/requirements.txt")
49+
jupyterNotebook := string(readFile(test, "resources/mnist_kfto.ipynb"))
50+
requirements := readFile(test, "resources/requirements.txt")
5051
for oldValue, newValue := range requiredChangesInNotebook {
5152
jupyterNotebook = strings.Replace(string(jupyterNotebook), oldValue, newValue, -1)
5253
}
@@ -57,13 +58,21 @@ func TestMnistSDK(t *testing.T) {
5758
"requirements.txt": requirements,
5859
})
5960

61+
notebookCommand := []string{
62+
"bin/sh",
63+
"-c",
64+
"pip install papermill && papermill /opt/app-root/notebooks/{{.NotebookConfigMapFileName}}" +
65+
" /opt/app-root/src/mcad-out.ipynb -p namespace {{.Namespace}} -p openshift_api_url {{.OpenShiftApiUrl}}" +
66+
" -p kubernetes_user_bearer_token {{.KubernetesUserBearerToken}}" +
67+
" -p num_gpus {{ .NumGpus }} --log-output && sleep infinity",
68+
}
6069
// Create Notebook CR
61-
createNotebook(test, namespace, userToken, config.Name, jupyterNotebookConfigMapFileName, 0)
70+
CreateNotebook(test, namespace, userToken, notebookCommand, config.Name, jupyterNotebookConfigMapFileName, 0)
6271

6372
// Gracefully cleanup Notebook
6473
defer func() {
65-
deleteNotebook(test, namespace)
66-
test.Eventually(listNotebooks(test, namespace), TestTimeoutGpuProvisioning).Should(HaveLen(0))
74+
DeleteNotebook(test, namespace)
75+
test.Eventually(ListNotebooks(test, namespace), TestTimeoutGpuProvisioning).Should(HaveLen(0))
6776
}()
6877

6978
// Make sure pytorch job is created
@@ -77,40 +86,3 @@ func TestMnistSDK(t *testing.T) {
7786
// TODO: write torch job logs?
7887
// time.Sleep(60 * time.Second)
7988
}
80-
81-
func readMnistScriptTemplate(test Test, filePath string) []byte {
82-
// Read the mnist.py from resources and perform replacements for custom values using go template
83-
storage_bucket_endpoint, storage_bucket_endpoint_exists := GetStorageBucketDefaultEndpoint()
84-
storage_bucket_access_key_id, storage_bucket_access_key_id_exists := GetStorageBucketAccessKeyId()
85-
storage_bucket_secret_key, storage_bucket_secret_key_exists := GetStorageBucketSecretKey()
86-
storage_bucket_name, storage_bucket_name_exists := GetStorageBucketName()
87-
storage_bucket_mnist_dir, storage_bucket_mnist_dir_exists := GetStorageBucketMnistDir()
88-
89-
props := struct {
90-
StorageBucketDefaultEndpoint string
91-
StorageBucketDefaultEndpointExists bool
92-
StorageBucketAccessKeyId string
93-
StorageBucketAccessKeyIdExists bool
94-
StorageBucketSecretKey string
95-
StorageBucketSecretKeyExists bool
96-
StorageBucketName string
97-
StorageBucketNameExists bool
98-
StorageBucketMnistDir string
99-
StorageBucketMnistDirExists bool
100-
}{
101-
StorageBucketDefaultEndpoint: storage_bucket_endpoint,
102-
StorageBucketDefaultEndpointExists: storage_bucket_endpoint_exists,
103-
StorageBucketAccessKeyId: storage_bucket_access_key_id,
104-
StorageBucketAccessKeyIdExists: storage_bucket_access_key_id_exists,
105-
StorageBucketSecretKey: storage_bucket_secret_key,
106-
StorageBucketSecretKeyExists: storage_bucket_secret_key_exists,
107-
StorageBucketName: storage_bucket_name,
108-
StorageBucketNameExists: storage_bucket_name_exists,
109-
StorageBucketMnistDir: storage_bucket_mnist_dir,
110-
StorageBucketMnistDirExists: storage_bucket_mnist_dir_exists,
111-
}
112-
template, err := files.ReadFile(filePath)
113-
test.Expect(err).NotTo(HaveOccurred())
114-
115-
return ParseTemplate(test, template, props)
116-
}

tests/kfto/kfto_mnist_training_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,9 @@ func runKFTOPyTorchMnistJob(t *testing.T, accelerator Accelerator, image string,
5959
// Create a namespace
6060
namespace := test.NewTestNamespace()
6161

62-
mnist := ReadFile(test, "resources/mnist.py")
63-
download_mnist_dataset := ReadFile(test, "resources/download_mnist_datasets.py")
64-
requirementsFileName := ReadFile(test, requirementsFile)
62+
mnist := readFile(test, "resources/mnist.py")
63+
download_mnist_dataset := readFile(test, "resources/download_mnist_datasets.py")
64+
requirementsFileName := readFile(test, requirementsFile)
6565

6666
if accelerator.isGpu() {
6767
mnist = bytes.Replace(mnist, []byte("accelerator=\"has to be specified\""), []byte("accelerator=\"gpu\""), 1)

tests/kfto/kfto_training_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ func runKFTOPyTorchJob(t *testing.T, image string, gpu Accelerator, numGpus, num
7070

7171
// Create a ConfigMap with training script
7272
configData := map[string][]byte{
73-
"hf_llm_training.py": ReadFile(test, "resources/hf_llm_training.py"),
73+
"hf_llm_training.py": readFile(test, "resources/hf_llm_training.py"),
7474
}
7575
config := CreateConfigMap(test, namespace, configData)
7676

0 commit comments

Comments
 (0)