|
12 | 12 | },
|
13 | 13 | {
|
14 | 14 | "cell_type": "code",
|
15 |
| - "execution_count": 42, |
| 15 | + "execution_count": null, |
16 | 16 | "id": "b55bc3ea-4ce3-49bf-bb1f-e209de8ca47a",
|
17 | 17 | "metadata": {
|
18 | 18 | "tags": []
|
|
21 | 21 | "source": [
|
22 | 22 | "import sys\n",
|
23 | 23 | "sys.path.append(\"../notebooks\") # needed to make kfto_sdk_mnist\n",
|
24 |
| - "from kfto_sdk_mnist import train_func\n", |
25 |
| - "from kubeflow.training import TrainingClient\n", |
26 |
| - "from kubernetes import client as c\n", |
27 |
| - "import time" |
| 24 | + "from kfto_sdk_mnist import train_func " |
28 | 25 | ]
|
29 | 26 | },
|
30 | 27 | {
|
31 | 28 | "cell_type": "code",
|
32 |
| - "execution_count": 41, |
| 29 | + "execution_count": null, |
33 | 30 | "id": "72dd1751",
|
34 | 31 | "metadata": {
|
35 | 32 | "tags": [
|
|
42 | 39 | "num_gpus = \"${num_gpus}\"\n",
|
43 | 40 | "openshift_api_url = \"${api_url}\"\n",
|
44 | 41 | "namespace = \"${namespace}\"\n",
|
45 |
| - "token = \"${password}\"\n", |
46 |
| - "training_image= \"${training_image}\"" |
47 |
| - ] |
48 |
| - }, |
49 |
| - { |
50 |
| - "cell_type": "code", |
51 |
| - "execution_count": null, |
52 |
| - "id": "aadaa4a7", |
53 |
| - "metadata": {}, |
54 |
| - "outputs": [], |
55 |
| - "source": [ |
56 |
| - "def GetDefaultLocalQueue(namespace: str):\n", |
57 |
| - " \"\"\"\n", |
58 |
| - " Fetches the LocalQueue in the given namespace whose annotation\n", |
59 |
| - " \"kueue.x-k8s.io/default-queue\" == \"true\". Returns the dict\n", |
60 |
| - " for the queue, or None.\n", |
61 |
| - " \"\"\"\n", |
62 |
| - " group = \"kueue.x-k8s.io\"\n", |
63 |
| - " version = \"v1beta1\"\n", |
64 |
| - " plural = \"localqueues\"\n", |
65 |
| - "\n", |
66 |
| - " conf = c.Configuration()\n", |
67 |
| - " conf.host = openshift_api_url\n", |
68 |
| - " conf.verify_ssl = False\n", |
69 |
| - " conf.api_key = {\"authorization\": f\"Bearer {token}\"}\n", |
70 |
| - "\n", |
71 |
| - " api_client = c.ApiClient(configuration=conf)\n", |
72 |
| - " api = c.CustomObjectsApi(api_client)\n", |
73 |
| - "\n", |
74 |
| - " resp = api.list_namespaced_custom_object(\n", |
75 |
| - " group=group, version=version, namespace=namespace, plural=plural\n", |
76 |
| - " )\n", |
77 |
| - "\n", |
78 |
| - " default_q = None\n", |
79 |
| - " for item in resp.get(\"items\", []):\n", |
80 |
| - " ann = item.get(\"metadata\", {}).get(\"annotations\") or {}\n", |
81 |
| - " if ann.get(\"kueue.x-k8s.io/default-queue\") == \"true\":\n", |
82 |
| - " if default_q is not None:\n", |
83 |
| - " raise RuntimeError(\n", |
84 |
| - " f\"multiple LocalQueues annotated as default in {namespace}: \"\n", |
85 |
| - " f\"{default_q['metadata']['name']} and {item['metadata']['name']}\"\n", |
86 |
| - " )\n", |
87 |
| - " default_q = item['metadata']['name']\n", |
88 |
| - "\n", |
89 |
| - " if default_q is None:\n", |
90 |
| - " raise RuntimeError(f\"no LocalQueue annotated as default in namespace {namespace}\")\n", |
91 |
| - "\n", |
92 |
| - " return default_q" |
| 42 | + "token = \"${token}\"\n", |
| 43 | + "training_image= \"${training_image}\"\n", |
| 44 | + "localQueue = \"${localQueue}\"" |
93 | 45 | ]
|
94 | 46 | },
|
95 | 47 | {
|
|
99 | 51 | "metadata": {},
|
100 | 52 | "outputs": [],
|
101 | 53 | "source": [
|
102 |
| - "api_key = {\"authorization\": f\"Bearer {token}\"}\n", |
103 |
| - "# config = c.Configuration(host=openshift_api_url, api_key=token)\n", |
104 |
| - "# config.verify_ssl = False\n", |
105 |
| - "tc = TrainingClient()\n", |
| 54 | + "from kubernetes import client\n", |
| 55 | + "from kubeflow.training import TrainingClient\n", |
106 | 56 | "\n",
|
107 |
| - "# get default local queue\n", |
108 |
| - "default_local_queue=GetDefaultLocalQueue(namespace)" |
| 57 | + "configuration = client.Configuration()\n", |
| 58 | + "configuration.host = openshift_api_url\n", |
| 59 | + "configuration.api_key = {\"authorization\": f\"Bearer {token}\"}\n", |
| 60 | + "configuration.verify_ssl = False\n", |
| 61 | + "api_client = client.ApiClient(configuration)\n", |
| 62 | + "client = TrainingClient(client_configuration=api_client.configuration)" |
109 | 63 | ]
|
110 | 64 | },
|
111 | 65 | {
|
|
116 | 70 | "outputs": [],
|
117 | 71 | "source": [
|
118 | 72 | "import os\n",
|
119 |
| - "tc.create_job(\n", |
| 73 | + "client.create_job(\n", |
120 | 74 | " name=\"pytorch-ddp\",\n",
|
121 | 75 | " namespace=namespace,\n",
|
122 | 76 | " train_func=train_func,\n",
|
|
132 | 86 | " \"PIP_TRUSTED_HOST\": os.environ.get(\"PIP_TRUSTED_HOST\")\n",
|
133 | 87 | " },\n",
|
134 | 88 | " labels={\n",
|
135 |
| - " \"kueue.x-k8s.io/queue-name\": default_local_queue,\n", |
| 89 | + " \"kueue.x-k8s.io/queue-name\": localQueue,\n", |
136 | 90 | " }\n",
|
137 | 91 | ")"
|
138 | 92 | ]
|
|
144 | 98 | "metadata": {},
|
145 | 99 | "outputs": [],
|
146 | 100 | "source": [
|
147 |
| - "while not tc.is_job_succeeded(name=\"pytorch-ddp\", namespace=namespace): \n", |
| 101 | + "import time\n", |
| 102 | + "while not client.is_job_succeeded(name=\"pytorch-ddp\", namespace=namespace): \n", |
148 | 103 | " time.sleep(1)\n",
|
149 | 104 | "print(\"PytorchJob Succeeded!\")"
|
150 | 105 | ]
|
|
0 commit comments