Skip to content

Commit 9a952fe

Browse files
authored
Merge pull request #1594 from stanfordnlp/dev_finetune
Refactor finetuning implementation to be 2.5 compatible
2 parents 23f9e9e + 9f040a8 commit 9a952fe

File tree

10 files changed

+1146
-22
lines changed

10 files changed

+1146
-22
lines changed

dspy/adapters/chat_adapter.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@ def parse(self, signature, completion, _parse_values=True):
6767

6868
return fields
6969

70+
def format_turn(self, signature, values, role, incomplete=False):
    """Format a single turn by delegating to the module-level ``format_turn``.

    All four arguments are forwarded positionally and unchanged; ``self`` is
    not used.
    """
    return format_turn(signature, values, role, incomplete)
72+
7073

7174
def format_blob(blob):
7275
if "\n" not in blob and "«" not in blob and "»" not in blob:

dspy/clients/anyscale.py

Lines changed: 324 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,324 @@
1+
from typing import Any, Dict, List, Optional
2+
import json
3+
import yaml
4+
import os
5+
6+
from dspy.utils.logging import logger
7+
from dspy.clients.finetune import (
8+
FinetuneJob,
9+
TrainingMethod,
10+
save_data,
11+
)
12+
from dspy.clients.openai import openai_data_validation
13+
14+
try:
15+
# AnyScale fine-tuning requires the following additional imports
16+
import anyscale
17+
from anyscale.job import JobConfig
18+
except ImportError:
19+
anyscale = None
20+
21+
22+
# Training methods that the AnyScale fine-tuning path accepts.
TRAINING_METHODS_ANYSCALE = [TrainingMethod.SFT]

# Provider name handed to ``save_data`` when the training file is written.
PROVIDER_ANYSCALE = "anyscale"
28+
29+
30+
def is_anyscale_model(model: str) -> bool:
    """Report whether ``model`` is an AnyScale model.

    Detection is not implemented: the function logs a notice and returns
    ``False`` for every input.
    """
    # TODO: This needs to be implemented to support fine-tuning
    notice = "Is AnyScale model is not implemented, returning False as a default to not break lm.py"
    logger.info(notice)
    return False
35+
36+
37+
class FinetuneJobAnyScale(FinetuneJob):
    """``FinetuneJob`` variant that carries the AnyScale job ID and model names."""

    def __init__(self, *args, **kwargs):
        """Reset the AnyScale fields, then run the ``FinetuneJob`` initialiser.

        All positional and keyword arguments are forwarded untouched.
        """
        # Both fields start as None and are assigned before the parent
        # initialiser runs (job_id first, then model_names).
        self.job_id, self.model_names = None, None
        super().__init__(*args, **kwargs)
43+
44+
45+
def finetune_anyscale(
    job: FinetuneJobAnyScale,
    model: str,
    train_data: List[Dict[str, Any]],
    train_kwargs: Optional[Dict[str, Any]] = None,
    train_method: TrainingMethod = TrainingMethod.SFT,
) -> str:
    """Run an AnyScale fine-tuning job end to end and return the last checkpoint.

    The steps are: validate and save the dataset, upload it, generate the job
    configuration, submit the job, block until it finishes, copy the resulting
    LoRA weights, and update the serving model config.

    Args:
        job: Job handle; ``job.job_id`` is set once the remote job has been
            submitted and ``job.model_names`` once the weights are copied.
        model: Model name; stored under the ``"model"`` key of the keyword
            arguments forwarded to the config/submit helpers.
        train_data: Training examples, checked with ``verify_dataset`` and
            written out with ``save_data``.
        train_kwargs: Extra options forwarded to ``generate_config_files`` and
            ``start_remote_training``. Must not contain ``"model"``. The
            optional ``"serve_config_path"`` entry (default ``"serve_1B.yaml"``)
            selects the serve config to update. The dict is not modified.
        train_method: Must be one of ``TRAINING_METHODS_ANYSCALE``.

    Returns:
        The model name whose trailing ``-<int>`` suffix is the largest.

    Raises:
        AssertionError: If ``train_kwargs`` contains ``"model"``.
        NotImplementedError: If ``train_method`` is not supported.
        RuntimeError: If dataset validation fails or no checkpoint was copied.
    """
    train_kwargs = train_kwargs or {}
    assert "model" not in train_kwargs, "Model should not be in the train_kwargs"
    train_kwargs_copy = train_kwargs.copy()
    train_kwargs_copy["model"] = model

    logger.info("[Finetune] Starting training process...")
    if train_method not in TRAINING_METHODS_ANYSCALE:
        raise NotImplementedError(f"AnyScale can only support {TRAINING_METHODS_ANYSCALE} for the time being")

    logger.info("[Finetune] Validating the dataset format...")
    if not verify_dataset(train_data):
        # TODO: Does AnyScale support text completion models?
        err = "[Finetune] Error: Unable to verify that the dataset is in the correct format."
        logger.error(err)
        raise RuntimeError(err)

    logger.info("[Finetune] Converting data to JSONL format...")
    train_data_path = save_data(train_data, provider_name=PROVIDER_ANYSCALE)
    logger.info("[Finetune] Submitting data to remote storage...")
    remote_train_path, _ = submit_data(train_path=train_data_path)
    logger.info(f"[Finetune] Data submitted. Remote train path: {remote_train_path}")

    logger.info("[Finetune] Generating configuration files...")
    _, compute_config = generate_config_files(train_path=remote_train_path, **train_kwargs_copy)

    logger.info("[Finetune] Starting remote training...")
    job_id = start_remote_training(compute_config=compute_config, **train_kwargs_copy)
    job.job_id = job_id
    logger.info(f"[Finetune] Remote training started. Job ID: {job_id}")

    logger.info("[Finetune] Waiting for training to complete...")
    wait_for_training(job.job_id)
    logger.info("[Finetune] Training completed.")

    logger.info("[Finetune] Retrieving model information...")
    model_info = get_model_info(job.job_id)
    logger.info(f"[Finetune] Model info retrieved: {model_info}")

    storage_uri = model_info["storage_uri"]
    logger.info(f"[Finetune] Copying LoRA weights from {storage_uri}...")
    model_names, lora_dynamic_path = copy_lora_weights(storage_uri, model_info, job.job_id)
    logger.info(f"[Finetune] LoRA weights copied. Model names: {model_names}")

    # Fail with a clear message instead of an IndexError when nothing was copied.
    if not model_names:
        err = "[Finetune] Error: No LoRA checkpoints were copied, cannot select a model."
        logger.error(err)
        raise RuntimeError(err)

    logger.info("[Finetune] Setting result in future object...")
    # Checkpoint names end in "-<step>"; the highest step is the final model.
    # A stable sort keeps the original tie-breaking (last of equal steps wins).
    last_model_checkpoint = sorted(
        model_names, key=lambda model_name: int(model_name.split("-")[-1])
    )[-1]

    logger.info("[Finetune] Updating model config with the proper dynamic path")
    # Read, rather than pop, so the caller's train_kwargs dict is left intact.
    serve_config_path = train_kwargs.get("serve_config_path", "serve_1B.yaml")
    update_model_config(lora_dynamic_path, serve_config_path, job_id)
    job.model_names = model_names
    logger.info("[Finetune] Training process completed successfully.")

    return last_model_checkpoint
108+
109+
def wait_for_training(job_id, timeout_s=18000):
    """Block until the AnyScale job ``job_id`` finishes or the timeout elapses.

    NOTE: this module defines ``wait_for_training`` a second time further
    down; that later definition is the one bound to the name after import.

    Args:
        job_id: Identifier passed to ``anyscale.job.wait`` as ``id``.
        timeout_s: Value passed to ``anyscale.job.wait`` as ``timeout_s``;
            defaults to the previously hard-coded 18000.
    """
    anyscale.job.wait(id=job_id, timeout_s=timeout_s)
112+
113+
114+
def update_model_config(lora_dynamic_path: str, serve_config_path: str, job_id: str):
    """Rewrite the model config's dynamic LoRA loading path for ``job_id``.

    The serve config at ``serve_config_path`` is read to find the first LLM
    config file of the first application. That file is loaded, its
    ``lora_config.dynamic_lora_loading_path`` is set to the part of
    ``lora_dynamic_path`` that precedes ``job_id`` followed by ``job_id``
    itself, and the file is written back in place.
    """
    with open(serve_config_path, "r") as serve_file:
        serve_settings = yaml.safe_load(serve_file)

    first_app_args = serve_settings["applications"][0]["args"]
    llm_config_file = first_app_args["llm_configs"][0]

    with open(llm_config_file, "r") as llm_file:
        llm_settings = yaml.safe_load(llm_file)

    # Keep everything up to and including the job ID; drop whatever follows.
    prefix_before_job = lora_dynamic_path.split(job_id)[0]
    llm_settings["lora_config"]["dynamic_lora_loading_path"] = prefix_before_job + job_id

    with open(llm_config_file, "w") as llm_file:
        yaml.safe_dump(llm_settings, llm_file)
129+
130+
131+
def verify_dataset(dataset: List[dict[str, Any]]) -> bool:
    """Return ``True`` when ``openai_data_validation`` reports nothing for ``dataset``.

    A truthy validation result is logged as an error and yields ``False``.
    """
    problem = openai_data_validation(dataset)
    if not problem:
        return True

    logger.error(f"Dataset validation failed: {problem}")
    return False
140+
141+
142+
def submit_data(train_path: str):
    """Copy the training file into the storage named by ``ANYSCALE_ARTIFACT_STORAGE``.

    The destination is the storage location joined with the file's base name.
    Destinations starting with ``s3`` are copied with ``aws s3 cp``, those
    starting with ``gs`` with ``gcloud storage cp``, and anything else with
    plain ``cp``.

    Args:
        train_path: Local path of the JSONL training file.

    Returns:
        A ``(train_remote_path, val_remote_path)`` tuple; the second element
        is ``None`` because only the training file is uploaded here.

    Raises:
        KeyError: If ``ANYSCALE_ARTIFACT_STORAGE`` is not set.
        subprocess.CalledProcessError: If the copy command exits non-zero.
        FileNotFoundError: If the copy tool itself is not installed.
    """
    # Local import keeps this block self-contained without touching the
    # module's import section.
    import subprocess

    storage = os.environ['ANYSCALE_ARTIFACT_STORAGE']

    datasets = {"train": train_path}

    fine_tuning_file_ids = {}
    for name, path in datasets.items():
        num_items = len(read_jsonl(path))
        logger.info(f"Number of items in {name} data: {num_items}")

        remote_path = os.path.join(storage, os.path.basename(path))
        logger.info(f"Uploading {name} data to S3 at {remote_path}")
        if remote_path.startswith("s3"):
            command = ["aws", "s3", "cp", path, remote_path]
        elif remote_path.startswith("gs"):
            command = ["gcloud", "storage", "cp", path, remote_path]
        else:
            command = ["cp", path, remote_path]

        # An argument list avoids shell quoting problems with the paths, and
        # check=True stops a failed upload from being reported as "Copied".
        subprocess.run(command, check=True)
        logger.info(f"Copied {path} to {remote_path}")
        fine_tuning_file_ids[name] = remote_path

    return fine_tuning_file_ids["train"], fine_tuning_file_ids.get("val", None)
165+
166+
167+
def generate_config_files(train_path: str, **kwargs):
    """Write the fine-tuning model config and build the AnyScale job config.

    A YAML template is loaded (either ``train_config_yaml`` or a default
    chosen from the model name), overlaid with ``hyperparameters`` and the
    fixed settings below, stripped of ``None`` values, and written to
    ``model_config_dspy_<digest>.yaml`` in the current directory. A
    ``JobConfig`` whose entrypoint references that file is then constructed.

    Args:
        train_path: Stored under ``"train_path"`` in the model config.
        **kwargs: Recognised keys:
            ``model`` (required, must not be ``None``): stored as ``model_id``
            and used to pick the default template.
            ``train_config_yaml``: path of a template to use instead of the
            default.
            ``use_lora``: selects ``configs/training/lora`` over
            ``configs/training/full_param`` for the default template.
            ``hyperparameters``: dict merged into the template.
            ``output_dir``: stored in the model config when truthy.
            ``compute_config``: dict merged over the default job settings.
            ``compute_yaml_path``: returned as the first tuple element
            (default ``"job_runner_config.yaml"``).

    Returns:
        ``(job_runner_config_path, compute_config)``.

    Raises:
        KeyError: If ``model`` is missing.
        AssertionError: If ``model`` is ``None``.
        RuntimeError: If no default template matches the model name.
    """
    # Local import keeps this block self-contained without touching the
    # module's import section.
    import hashlib

    base_model_yaml_path = kwargs.get("train_config_yaml", None)
    model = kwargs["model"]
    assert model is not None, "Model is required to generate the config files"

    use_lora = kwargs.get("use_lora", False)
    lora_path = "configs/training/lora" if use_lora else "configs/training/full_param"

    if not base_model_yaml_path:
        def get_yaml_config(model_name):
            # Compare case-insensitively throughout so "70B" and "70b" agree.
            lowered = model_name.lower()
            if "llama" in lowered:
                # NOTE(review): 13b maps to the 70b template, as in the
                # original code - confirm this is intended.
                if "70b" in lowered or "13b" in lowered:
                    return "llama-3-70b.yaml"
                return "llama-3-8b.yaml"
            # "mixtral" does not contain "mistral", so it needs its own test
            # and must be checked before the plain mistral case.
            if "mixtral" in lowered:
                return "mixtral-8x7b.yaml"
            if "mistral" in lowered:
                return "mistral-7b.yaml"
            raise RuntimeError("No default yaml found for the model")

        default_model_yaml_path = get_yaml_config(model)
        base_model_yaml_path = os.path.join(lora_path, default_model_yaml_path)
        logger.info(f"Using default yaml template for model: {base_model_yaml_path}")

    with open(base_model_yaml_path, "r") as template_file:
        model_config_data = yaml.safe_load(template_file)
    model_config_data.update(kwargs.get("hyperparameters", {}))

    custom_modifications = {
        "model_id": model,
        "train_path": train_path,
        "logger": {
            "provider": "wandb",
        },
        "num_checkpoints_to_keep": 10,
    }
    if kwargs.get("output_dir", None):
        custom_modifications["output_dir"] = kwargs["output_dir"]

    model_config_data.update(custom_modifications)
    model_config_data = {k: v for k, v in model_config_data.items() if v is not None}

    # Name the file after a content digest. The built-in hash() is salted per
    # process for strings, so it would yield a different name on every run.
    canonical = json.dumps(model_config_data, sort_keys=True, default=str)
    digest = hashlib.sha256(canonical.encode("utf-8")).hexdigest()[:16]
    filename = f"model_config_dspy_{digest}.yaml"
    logger.info(f"Model config data: {model_config_data}")
    # Close the file explicitly so it is fully written before the job that
    # references it is submitted.
    with open(filename, "w") as config_file:
        yaml.safe_dump(model_config_data, config_file)

    compute_config_dict = {
        "name": "dspy-llmforge-fine-tuning-job",
        "entrypoint": f"llmforge anyscale finetune {filename}",
        "working_dir": ".",
        "image_uri": "localhost:5555/anyscale/llm-forge:0.5.6",
        "requirements": [
            "wandb",
        ],
        "env_vars": {
            "WANDB_API_KEY": os.environ.get("WANDB_API_KEY", ""),
            "HF_TOKEN": os.environ.get("HF_TOKEN", ""),
            "HF_HOME": os.environ.get("HF_HOME", ""),
        },
    }
    # Caller-supplied settings override the defaults above key by key.
    compute_config_dict.update(kwargs.get("compute_config", {}))
    compute_config = JobConfig(**compute_config_dict)

    job_runner_config_path = kwargs.get("compute_yaml_path", "job_runner_config.yaml")

    return job_runner_config_path, compute_config
257+
258+
259+
def start_remote_training(compute_config, **kwargs) -> str:
    """Submit ``compute_config`` via ``anyscale.job.submit`` and return the job ID.

    Additional keyword arguments are accepted but not used.
    """
    return anyscale.job.submit(compute_config)
262+
263+
264+
def wait_for_training(job_id, timeout_s=18000):
    """Log a notice, then block until the AnyScale job finishes or times out.

    Args:
        job_id: Identifier passed to ``anyscale.job.wait`` as ``id``.
        timeout_s: Value passed to ``anyscale.job.wait`` as ``timeout_s``;
            defaults to the previously hard-coded 18000.
    """
    logger.info("Waiting for training to complete")
    anyscale.job.wait(id=job_id, timeout_s=timeout_s)
267+
268+
269+
def get_model_info(job_id):
    """Fetch the model associated with ``job_id`` and return it as a dictionary."""
    model_record = anyscale.llm.model.get(job_id=job_id)
    return model_record.to_dict()
271+
272+
273+
def copy_lora_weights(storage_uri, model_info, job_id):
    """Copy each ``epoch*`` checkpoint folder to the job's dynamic LoRA location.

    The bucket name is the third ``/``-separated component of ``storage_uri``
    and the source folder is everything after it except the final component.
    Every folder under the source whose name starts with ``epoch`` is copied,
    within the same bucket, to
    ``dspy/lora_weights/<job_id>/<base_model_id>:<folder name>``.

    Args:
        storage_uri: URI of the stored model
            (presumably ``gs://bucket/path/...`` - confirm against caller).
        model_info: Mapping that provides ``"base_model_id"``.
        job_id: Used to build the destination path.

    Returns:
        ``(model_names, completed_path)``, where ``model_names`` holds the
        last two ``/``-separated components of each destination folder and
        ``completed_path`` is ``gs://<bucket>/<destination root>``.

    Raises:
        Exception: Any error is logged and re-raised unchanged.
    """
    try:
        from google.cloud import storage

        storage_client = storage.Client()

        uri_parts = storage_uri.split('/')
        bucket_name = uri_parts[2]
        source_folder = '/'.join(uri_parts[3:-1])
        logger.info(f"Source folder: {source_folder}")

        bucket = storage_client.bucket(bucket_name)

        # Collect the parent folder of every blob that sits below the source.
        subfolders = set()
        for blob in bucket.list_blobs(prefix=source_folder):
            if '/' in blob.name[len(source_folder):]:
                subfolders.add(blob.name.rsplit('/', 1)[0])

        base_model_id = model_info["base_model_id"]
        lora_dynamic_path = f"dspy/lora_weights/{job_id}/{base_model_id}"

        model_names = []
        # Sorted so the copy order and the returned list are deterministic.
        for subfolder in sorted(subfolders):
            subfolder_name = subfolder.split('/')[-1]
            if not subfolder_name.startswith("epoch"):
                continue

            destination_folder = f"{lora_dynamic_path}:{subfolder_name}"
            model_names.append("/".join(destination_folder.split("/")[-2:]))

            # The trailing slash limits the listing to this folder; without it
            # the prefix "epoch-1" would also match "epoch-10/..." and mix
            # files from different checkpoints into one destination.
            for blob in bucket.list_blobs(prefix=f"{subfolder}/"):
                source_blob = bucket.blob(blob.name)
                destination_blob_name = f"{destination_folder}/{blob.name.split('/')[-1]}"
                bucket.copy_blob(source_blob, bucket, destination_blob_name)
                logger.info(f"Copied {source_blob.name} to {destination_blob_name}")

        logger.info(f"All subfolders copied to: gs://{bucket_name}/{lora_dynamic_path}")
        completed_path = f"gs://{bucket_name}/{lora_dynamic_path}"
        return model_names, completed_path

    except Exception as e:
        logger.error(f"An error occurred: {str(e)}")
        # Bare raise preserves the original traceback.
        raise
320+
321+
322+
def read_jsonl(filename):
    """Load a JSON Lines file and return its records as a list.

    Each non-blank line is parsed with ``json.loads``. Blank or
    whitespace-only lines (for example a trailing empty line) are skipped
    rather than raising ``json.JSONDecodeError``.

    Args:
        filename: Path of the JSONL file, read as UTF-8.

    Returns:
        One parsed object per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(filename, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]

0 commit comments

Comments
 (0)