Skip to content

Commit 3cc71b1

Browse files
Fix kfto-sdk-mnist-test notebook to update local-queue usage flow
1 parent 5733b42 commit 3cc71b1

File tree

2 files changed

+22
-66
lines changed

2 files changed

+22
-66
lines changed

tests/kfto/kfto_mnist_sdk_test.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,8 @@ func runMnistSDK(t *testing.T, trainingImage string) {
9191

9292
clusterQueue := CreateKueueClusterQueue(test, cqSpec)
9393
defer test.Client().Kueue().KueueV1beta1().ClusterQueues().Delete(test.Ctx(), clusterQueue.Name, metav1.DeleteOptions{})
94-
CreateKueueLocalQueue(test, namespace.Name, clusterQueue.Name, AsDefaultQueue)
94+
95+
localQueue := CreateKueueLocalQueue(test, namespace.Name, clusterQueue.Name, AsDefaultQueue)
9596

9697
jupyterNotebook := string(readFile(test, "resources/mnist_kfto.ipynb"))
9798
requirements := readFile(test, "resources/requirements.txt")
@@ -107,8 +108,8 @@ func runMnistSDK(t *testing.T, trainingImage string) {
107108
"-c",
108109
fmt.Sprintf("pip install papermill && papermill /opt/app-root/notebooks/%s"+
109110
" /opt/app-root/src/mnist-kfto-out.ipynb -p namespace %s -p openshift_api_url %s"+
110-
" -p token %s -p num_gpus %d -p training_image %s --log-output && sleep infinity",
111-
jupyterNotebookConfigMapFileName, namespace.Name, GetOpenShiftApiUrl(test), userToken, 0, trainingImage)}
111+
" -p token %s -p num_gpus %d -p training_image %s -p localQueue %s --log-output && sleep infinity",
112+
jupyterNotebookConfigMapFileName, namespace.Name, GetOpenShiftApiUrl(test), userToken, 0, trainingImage, localQueue.Name)}
112113

113114
// Create PVC for Notebook
114115
notebookPVC := CreatePersistentVolumeClaim(test, namespace.Name, "10Gi", AccessModes(corev1.ReadWriteOnce))

tests/kfto/resources/mnist_kfto.ipynb

Lines changed: 18 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
},
1313
{
1414
"cell_type": "code",
15-
"execution_count": 42,
15+
"execution_count": null,
1616
"id": "b55bc3ea-4ce3-49bf-bb1f-e209de8ca47a",
1717
"metadata": {
1818
"tags": []
@@ -21,15 +21,12 @@
2121
"source": [
2222
"import sys\n",
2323
"sys.path.append(\"../notebooks\") # needed to make kfto_sdk_mnist\n",
24-
"from kfto_sdk_mnist import train_func\n",
25-
"from kubeflow.training import TrainingClient\n",
26-
"from kubernetes import client as c\n",
27-
"import time"
24+
"from kfto_sdk_mnist import train_func "
2825
]
2926
},
3027
{
3128
"cell_type": "code",
32-
"execution_count": 41,
29+
"execution_count": null,
3330
"id": "72dd1751",
3431
"metadata": {
3532
"tags": [
@@ -42,54 +39,9 @@
4239
"num_gpus = \"${num_gpus}\"\n",
4340
"openshift_api_url = \"${api_url}\"\n",
4441
"namespace = \"${namespace}\"\n",
45-
"token = \"${password}\"\n",
46-
"training_image= \"${training_image}\""
47-
]
48-
},
49-
{
50-
"cell_type": "code",
51-
"execution_count": null,
52-
"id": "aadaa4a7",
53-
"metadata": {},
54-
"outputs": [],
55-
"source": [
56-
"def GetDefaultLocalQueue(namespace: str):\n",
57-
" \"\"\"\n",
58-
" Fetches the LocalQueue in the given namespace whose annotation\n",
59-
" \"kueue.x-k8s.io/default-queue\" == \"true\". Returns the dict\n",
60-
" for the queue, or None.\n",
61-
" \"\"\"\n",
62-
" group = \"kueue.x-k8s.io\"\n",
63-
" version = \"v1beta1\"\n",
64-
" plural = \"localqueues\"\n",
65-
"\n",
66-
" conf = c.Configuration()\n",
67-
" conf.host = openshift_api_url\n",
68-
" conf.verify_ssl = False\n",
69-
" conf.api_key = {\"authorization\": f\"Bearer {token}\"}\n",
70-
"\n",
71-
" api_client = c.ApiClient(configuration=conf)\n",
72-
" api = c.CustomObjectsApi(api_client)\n",
73-
"\n",
74-
" resp = api.list_namespaced_custom_object(\n",
75-
" group=group, version=version, namespace=namespace, plural=plural\n",
76-
" )\n",
77-
"\n",
78-
" default_q = None\n",
79-
" for item in resp.get(\"items\", []):\n",
80-
" ann = item.get(\"metadata\", {}).get(\"annotations\") or {}\n",
81-
" if ann.get(\"kueue.x-k8s.io/default-queue\") == \"true\":\n",
82-
" if default_q is not None:\n",
83-
" raise RuntimeError(\n",
84-
" f\"multiple LocalQueues annotated as default in {namespace}: \"\n",
85-
" f\"{default_q['metadata']['name']} and {item['metadata']['name']}\"\n",
86-
" )\n",
87-
" default_q = item['metadata']['name']\n",
88-
"\n",
89-
" if default_q is None:\n",
90-
" raise RuntimeError(f\"no LocalQueue annotated as default in namespace {namespace}\")\n",
91-
"\n",
92-
" return default_q"
42+
"token = \"${token}\"\n",
43+
"training_image= \"${training_image}\"\n",
44+
"localQueue = \"${localQueue}\""
9345
]
9446
},
9547
{
@@ -99,13 +51,15 @@
9951
"metadata": {},
10052
"outputs": [],
10153
"source": [
102-
"api_key = {\"authorization\": f\"Bearer {token}\"}\n",
103-
"# config = c.Configuration(host=openshift_api_url, api_key=token)\n",
104-
"# config.verify_ssl = False\n",
105-
"tc = TrainingClient()\n",
54+
"from kubernetes import client\n",
55+
"from kubeflow.training import TrainingClient\n",
10656
"\n",
107-
"# get default local queue\n",
108-
"default_local_queue=GetDefaultLocalQueue(namespace)"
57+
"configuration = client.Configuration()\n",
58+
"configuration.host = openshift_api_url\n",
59+
"configuration.api_key = {\"authorization\": f\"Bearer {token}\"}\n",
60+
"configuration.verify_ssl = False\n",
61+
"api_client = client.ApiClient(configuration)\n",
62+
"client = TrainingClient(client_configuration=api_client.configuration)"
10963
]
11064
},
11165
{
@@ -116,7 +70,7 @@
11670
"outputs": [],
11771
"source": [
11872
"import os\n",
119-
"tc.create_job(\n",
73+
"client.create_job(\n",
12074
" name=\"pytorch-ddp\",\n",
12175
" namespace=namespace,\n",
12276
" train_func=train_func,\n",
@@ -132,7 +86,7 @@
13286
" \"PIP_TRUSTED_HOST\": os.environ.get(\"PIP_TRUSTED_HOST\")\n",
13387
" },\n",
13488
" labels={\n",
135-
" \"kueue.x-k8s.io/queue-name\": default_local_queue,\n",
89+
" \"kueue.x-k8s.io/queue-name\": localQueue,\n",
13690
" }\n",
13791
")"
13892
]
@@ -144,7 +98,8 @@
14498
"metadata": {},
14599
"outputs": [],
146100
"source": [
147-
"while not tc.is_job_succeeded(name=\"pytorch-ddp\", namespace=namespace): \n",
101+
"import time\n",
102+
"while not client.is_job_succeeded(name=\"pytorch-ddp\", namespace=namespace): \n",
148103
" time.sleep(1)\n",
149104
"print(\"PytorchJob Succeeded!\")"
150105
]

0 commit comments

Comments
 (0)