Commit cf25842

Adding osft e2e test
1 parent 01c2d35 commit cf25842

4 files changed: +500 -0 lines changed

tests/trainer/kubeflow_sdk_test.go

Lines changed: 7 additions & 0 deletions
@@ -20,10 +20,17 @@ import (
 	"testing"
 
 	. "github.com/opendatahub-io/distributed-workloads/tests/common"
+	support "github.com/opendatahub-io/distributed-workloads/tests/common/support"
 	sdktests "github.com/opendatahub-io/distributed-workloads/tests/trainer/sdk_tests"
 )
 
 func TestKubeflowSdkSanity(t *testing.T) {
 	Tags(t, Sanity)
 	sdktests.RunFashionMnistCpuDistributedTraining(t)
 }
+
+// TestOsftTrainingHubMultiNodeMultiGPU tests OSFT training using TrainingHubTrainer
+func TestOsftTrainingHubMultiNodeMultiGPU(t *testing.T) {
+	Tags(t, KftoCuda, MultiNodeMultiGpu(2, support.NVIDIA, 1)) // TODO: may need to be updated once https://issues.redhat.com/browse/RHOAIENG-30719 and https://issues.redhat.com/browse/RHOAIENG-24552 are resolved
+	sdktests.RunOsftTrainingHubMultiGpuDistributedTraining(t)
+}
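
The body of sdktests.RunOsftTrainingHubMultiGpuDistributedTraining is not included in this diff; it presumably provisions a workbench and executes the notebook added below. As a rough, hypothetical sketch of that kind of harness, assuming papermill-style headless execution and the environment variables the notebook reads (none of which this commit confirms):

# Hypothetical harness sketch: run the test notebook headlessly with papermill.
# The notebook reads OPENSHIFT_API_URL, NOTEBOOK_USER_TOKEN and SHARED_PVC_NAME;
# how the real Go test injects these is not shown in this commit.
import os
import papermill as pm

os.environ.setdefault("SHARED_PVC_NAME", "shared")

pm.execute_notebook(
    "tests/trainer/resources/osft.ipynb",
    "/tmp/osft-output.ipynb",  # executed copy, useful when debugging failures
    kernel_name="python3",
)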

tests/trainer/resources/osft.ipynb

Lines changed: 286 additions & 0 deletions
New file contents:

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard library imports\n",
    "import logging\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "from io import StringIO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from kubernetes import client as k8s, config as k8s_config\n",
    "\n",
    "# Edit to match your specific settings\n",
    "api_server = os.getenv(\"OPENSHIFT_API_URL\")\n",
    "token = os.getenv(\"NOTEBOOK_USER_TOKEN\")\n",
    "PVC_NAME = os.getenv(\"SHARED_PVC_NAME\", \"shared\")\n",
    "\n",
    "configuration = k8s.Configuration()\n",
    "configuration.host = api_server\n",
    "# Disable TLS verification when the cluster API server uses a self-signed\n",
    "# certificate or an un-trusted CA\n",
    "configuration.verify_ssl = False\n",
    "configuration.api_key = {\"authorization\": f\"Bearer {token}\"}\n",
    "api_client = k8s.ApiClient(configuration)\n",
    "\n",
    "PVC_MOUNT_PATH = \"/opt/app-root/src\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import random\n",
    "\n",
    "from datasets import load_dataset\n",
    "\n",
    "# Load the Table-GPT dataset\n",
    "print(\"Loading Table-GPT dataset...\")\n",
    "dataset = load_dataset(\"LipengCS/Table-GPT\", \"All\")\n",
    "\n",
    "# Get the training split\n",
    "train_data = dataset[\"train\"]\n",
    "print(f\"Original training set size: {len(train_data)}\")\n",
    "\n",
    "# Create a random subset of 100 samples\n",
    "random.seed(42)  # For reproducibility\n",
    "subset_indices = random.sample(range(len(train_data)), min(100, len(train_data)))\n",
    "subset_data = train_data.select(subset_indices)\n",
    "\n",
    "print(f\"Subset size: {len(subset_data)}\")\n",
    "\n",
    "# Save the subset to a JSONL file on the shared PVC; use an absolute path so it\n",
    "# matches the data_path passed to the training job below\n",
    "output_dir = os.path.join(PVC_MOUNT_PATH, \"table-gpt-data/train\")\n",
    "output_file = f\"{output_dir}/train_All_100.jsonl\"\n",
    "\n",
    "print(f\"Creating directory: {output_dir}\")\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "with open(output_file, \"w\") as f:\n",
    "    for example in subset_data:\n",
    "        f.write(json.dumps(example) + \"\\n\")\n",
    "\n",
    "print(f\"Subset saved to {output_file}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "params = {\n",
    "    ###########################################################################\n",
    "    # 🤖 Model + Data Paths                                                   #\n",
    "    ###########################################################################\n",
    "    \"model_path\": \"Qwen/Qwen2.5-1.5B-Instruct\",\n",
    "    \"data_path\": \"/opt/app-root/src/table-gpt-data/train/train_All_100.jsonl\",\n",
    "    \"ckpt_output_dir\": \"/opt/app-root/src/checkpoints-logs-dir\",\n",
    "    \"data_output_path\": \"/opt/app-root/src/osft-json/_data\",\n",
    "    ###########################################################################\n",
    "    # 🏋️‍♀️ Training Hyperparameters                                            #\n",
    "    ###########################################################################\n",
    "    # Important for OSFT\n",
    "    \"unfreeze_rank_ratio\": 0.25,\n",
    "    # Standard parameters\n",
    "    \"effective_batch_size\": 128,\n",
    "    \"learning_rate\": 5.0e-6,\n",
    "    \"num_epochs\": 1,\n",
    "    \"lr_scheduler\": \"cosine\",\n",
    "    \"warmup_steps\": 0,\n",
    "    \"seed\": 42,\n",
    "    ###########################################################################\n",
    "    # 🏎️ Performance Hyperparameters                                          #\n",
    "    ###########################################################################\n",
    "    \"use_liger\": True,\n",
    "    \"max_tokens_per_gpu\": 32000,\n",
    "    \"max_seq_len\": 2048,\n",
    "    ###########################################################################\n",
    "    # 💾 Checkpointing Settings                                               #\n",
    "    ###########################################################################\n",
    "    # Here we only want to save the very last checkpoint\n",
    "    \"save_final_checkpoint\": True,\n",
    "    \"checkpoint_at_epoch\": False,\n",
    "    # \"nproc_per_node\": 2,\n",
    "    # \"nnodes\": 2,\n",
    "    # Note: the distributed training parameters above are commented out because\n",
    "    # they are delegated to the Kubeflow Trainer runtime\n",
    "}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from kubeflow.trainer import TrainerClient\n",
    "from kubeflow.trainer.rhai import TrainingHubAlgorithms, TrainingHubTrainer\n",
    "from kubeflow.common.types import KubernetesBackendConfig\n",
    "\n",
    "backend_cfg = KubernetesBackendConfig(\n",
    "    # Reuse the authenticated Kubernetes client configuration created above\n",
    "    client_configuration=api_client.configuration,\n",
    ")\n",
    "\n",
    "client = TrainerClient(backend_cfg)\n",
    "print(client)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "th_runtime = None\n",
    "for runtime in client.list_runtimes():\n",
    "    if runtime.name == \"training-hub-2node-1gpu\":\n",
    "        th_runtime = runtime\n",
    "        print(\"Found runtime: \" + str(th_runtime))\n",
    "        break\n",
    "\n",
    "if th_runtime is None:\n",
    "    raise RuntimeError(\"Required runtime 'training-hub-2node-1gpu' not found\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from kubeflow.trainer.options.kubernetes import (\n",
    "    PodTemplateOverrides,\n",
    "    PodTemplateOverride,\n",
    "    PodSpecOverride,\n",
    "    ContainerOverride,\n",
    ")\n",
    "\n",
    "cache_root = \"/opt/app-root/src/.cache/huggingface\"\n",
    "triton_cache = \"/opt/app-root/src/.triton\"\n",
    "\n",
    "job_name = client.train(\n",
    "    trainer=TrainingHubTrainer(\n",
    "        algorithm=TrainingHubAlgorithms.OSFT,\n",
    "        func_args=params,\n",
    "        env={\n",
    "            \"HF_HOME\": cache_root,\n",
    "            \"TRITON_CACHE_DIR\": triton_cache,\n",
    "            \"XDG_CACHE_HOME\": \"/opt/app-root/src/.cache\",\n",
    "            \"NCCL_DEBUG\": \"INFO\",\n",
    "        },\n",
    "    ),\n",
    "    options=[\n",
    "        PodTemplateOverrides(\n",
    "            PodTemplateOverride(\n",
    "                target_jobs=[\"node\"],\n",
    "                spec=PodSpecOverride(\n",
    "                    volumes=[\n",
    "                        {\"name\": \"work\", \"persistentVolumeClaim\": {\"claimName\": PVC_NAME}},\n",
    "                    ],\n",
    "                    containers=[\n",
    "                        ContainerOverride(\n",
    "                            name=\"node\",\n",
    "                            volume_mounts=[\n",
    "                                {\"name\": \"work\", \"mountPath\": \"/opt/app-root/src\", \"readOnly\": False},\n",
    "                            ],\n",
    "                        )\n",
    "                    ],\n",
    "                ),\n",
    "            )\n",
    "        )\n",
    "    ],\n",
    "    runtime=th_runtime,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wait for the job to start running, then for it to complete\n",
    "client.wait_for_job_status(name=job_name, status={\"Running\"}, timeout=300)\n",
    "client.wait_for_job_status(name=job_name, status={\"Complete\"}, timeout=600)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for c in client.get_job(name=job_name).steps:\n",
    "    print(f\"Step: {c.name}, Status: {c.status}, Devices: {c.device} x {c.device_count}\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for logline in client.get_job_logs(job_name, follow=False):\n",
    "    print(logline)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "client.delete_job(job_name)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.12",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
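
Once wait_for_job_status reports Complete, the final checkpoint should land on the shared PVC under the ckpt_output_dir configured in params. A minimal post-run sanity check, assuming safetensors weight files are produced there (the exact subdirectory layout is a training-hub detail this commit does not show):

# Post-run sanity check: confirm model weights were written to the shared PVC.
# The *.safetensors pattern and directory layout are assumptions; adjust to
# whatever training-hub actually writes under ckpt_output_dir.
import pathlib

ckpt_root = pathlib.Path("/opt/app-root/src/checkpoints-logs-dir")
weights = list(ckpt_root.rglob("*.safetensors"))
print(f"Found {len(weights)} weight shard(s) under {ckpt_root}")
assert weights, "training completed but no model weights were found"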
