|
| 1 | +from typing import Any, Dict, List, Optional |
| 2 | +import json |
| 3 | +import yaml |
| 4 | +import os |
| 5 | + |
| 6 | +from dspy.utils.logging import logger |
| 7 | +from dspy.clients.finetune import ( |
| 8 | + FinetuneJob, |
| 9 | + TrainingMethod, |
| 10 | + save_data, |
| 11 | +) |
| 12 | +from dspy.clients.openai import openai_data_validation |
| 13 | + |
| 14 | +try: |
| 15 | + # AnyScale fine-tuning requires the following additional imports |
| 16 | + import anyscale |
| 17 | + from anyscale.job import JobConfig |
| 18 | +except ImportError: |
| 19 | + anyscale = None |
| 20 | + |
| 21 | + |
| 22 | +# List of training methods supported by AnyScale |
| 23 | +TRAINING_METHODS_ANYSCALE = [ |
| 24 | + TrainingMethod.SFT, |
| 25 | +] |
| 26 | + |
| 27 | +PROVIDER_ANYSCALE = "anyscale" |
| 28 | + |
| 29 | + |
| 30 | +def is_anyscale_model(model: str) -> bool: |
| 31 | + """Check if the model is an AnyScale model.""" |
| 32 | + # TODO: This needs to be implemented to support fine-tuning |
| 33 | + logger.info("Is AnyScale model is not implemented, returning False as a default to not break lm.py") |
| 34 | + return False |
| 35 | + |
| 36 | + |
| 37 | +class FinetuneJobAnyScale(FinetuneJob): |
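|  | + """FinetuneJob for AnyScale; tracks the remote job ID and the names of the fine-tuned checkpoints.""" |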
| 38 | + |
| 39 | + def __init__(self, *args, **kwargs): |
| 40 | + self.job_id = None |
| 41 | + self.model_names = None |
| 42 | + super().__init__(*args, **kwargs) |
| 43 | + |
| 44 | + |
| 45 | +def finetune_anyscale( |
| 46 | + job: FinetuneJobAnyScale, |
| 47 | + model: str, |
| 48 | + train_data: List[Dict[str, Any]], |
| 49 | + train_kwargs: Optional[Dict[str, Any]] = None, |
| 50 | + train_method: TrainingMethod = TrainingMethod.SFT, |
| 51 | + ) -> str: |
| 52 | + """Start the finetune job.""" |
| 53 | + train_kwargs = train_kwargs or {} |
| 54 | + assert "model" not in train_kwargs, "Model should not be in the train_kwargs" |
| 55 | + train_kwargs_copy = train_kwargs.copy() |
| 56 | + train_kwargs_copy["model"] = model |
| 57 | + |
| 58 | + logger.info("[Finetune] Starting training process...") |
| 59 | + if train_method not in TRAINING_METHODS_ANYSCALE: |
| 60 | + raise NotImplementedError(f"AnyScale fine-tuning currently supports only the following training methods: {TRAINING_METHODS_ANYSCALE}") |
| 61 | + |
| 62 | + logger.info("[Finetune] Validating the dataset format...") |
| 63 | + if not verify_dataset(train_data): |
| 64 | + # TODO: Does AnyScale support text completion models? |
| 65 | + err = "[Finetune] Error: Unable to verify that the dataset is in the correct format." |
| 66 | + logger.error(err) |
| 67 | + raise RuntimeError(err) |
| 68 | + |
| 69 | + logger.info("[Finetune] Converting data to JSONL format...") |
| 70 | + train_data_path = save_data(train_data, provider_name=PROVIDER_ANYSCALE) |
| 71 | + logger.info("[Finetune] Submitting data to remote storage...") |
| 72 | + remote_train_path, _ = submit_data(train_path=train_data_path) |
| 73 | + logger.info(f"[Finetune] Data submitted. Remote train path: {remote_train_path}") |
| 74 | + |
| 75 | + logger.info("[Finetune] Generating configuration files...") |
| 76 | + _, compute_config = generate_config_files(train_path=remote_train_path, **train_kwargs_copy) |
| 77 | + |
| 78 | + logger.info("[Finetune] Starting remote training...") |
| 79 | + job_id = start_remote_training(compute_config=compute_config, **train_kwargs_copy) |
| 80 | + job.job_id = job_id |
| 81 | + logger.info(f"[Finetune] Remote training started. Job ID: {job_id}") |
| 82 | + |
| 83 | + logger.info("[Finetune] Waiting for training to complete...") |
| 84 | + wait_for_training(job.job_id) |
| 85 | + logger.info("[Finetune] Training completed.") |
| 86 | + |
| 87 | + logger.info("[Finetune] Retrieving model information...") |
| 88 | + model_info = get_model_info(job.job_id) |
| 89 | + logger.info(f"[Finetune] Model info retrieved: {model_info}") |
| 90 | + |
| 91 | + storage_uri = model_info["storage_uri"] |
| 92 | + logger.info(f"[Finetune] Copying LoRA weights from {storage_uri}...") |
| 93 | + model_names, lora_dynamic_path = copy_lora_weights(storage_uri, model_info, job.job_id) |
| 94 | + logger.info(f"[Finetune] LoRA weights copied. Model names: {model_names}") |
| 95 | + |
| 97 | + logger.info("[Finetune] Setting result in future object...") |
| 98 | + model_step_pairs = sorted([(model_name, int(model_name.split("-")[-1])) for model_name in model_names], key=lambda x: x[1]) |
| 99 | + last_model_checkpoint = model_step_pairs[-1][0] |
| 100 | + logger.info("[Finetune] Training process completed successfully.") |
| 101 | + |
| 102 | + logger.info("[Finetune] Updating model config with the proper dynamic path") |
| 103 | + serve_config_path = train_kwargs.pop("serve_config_path", "serve_1B.yaml") |
| 104 | + update_model_config(lora_dynamic_path, serve_config_path, job_id) |
| 105 | + job.model_names = model_names |
| 106 | + |
| 107 | + return last_model_checkpoint |
| 108 | + |
| 109 | +def wait_for_training(job_id): |
| 110 | + """Wait for the training to complete.""" |
| 111 | + anyscale.job.wait(id=job_id, timeout_s=18000) |
| 112 | + |
| 113 | + |
| 114 | +def update_model_config(lora_dynamic_path: str, serve_config_path: str, job_id: str): |
| 115 | + """Update the model config storage location with the job_id.""" |
| 116 | + with open(serve_config_path, "r") as f: |
| 117 | + serve_config = yaml.safe_load(f) |
| 118 | + |
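|  | + # The serve config points at a separate per-model llm_config YAML; that file holds the dynamic LoRA loading path |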
| 119 | + model_config_location = serve_config["applications"][0]["args"]["llm_configs"][0] |
| 120 | + |
| 121 | + with open(model_config_location, "r") as f: |
| 122 | + model_config = yaml.safe_load(f) |
| 123 | + |
| 124 | + dynamic_path_until_job_id = lora_dynamic_path.split(job_id)[0] + job_id |
| 125 | + model_config["lora_config"]["dynamic_lora_loading_path"] = dynamic_path_until_job_id |
| 126 | + |
| 127 | + with open(model_config_location, "w") as f: |
| 128 | + yaml.safe_dump(model_config, f) |
| 129 | + |
| 130 | + |
| 131 | +def verify_dataset(dataset: List[Dict[str, Any]]) -> bool: |
| 132 | + """Verify that the training dataset is in the expected format before starting training.""" |
| 133 | + dataset_validation = openai_data_validation(dataset) |
| 134 | + |
| 135 | + if dataset_validation: |
| 136 | + logger.error(f"Dataset validation failed: {dataset_validation}") |
| 137 | + return False |
| 138 | + |
| 139 | + return True |
| 140 | + |
| 141 | + |
| 142 | +def submit_data(train_path: str): |
| 143 | + """Upload the data to the Workspace cloud storage.""" |
| 144 | + storage = os.environ['ANYSCALE_ARTIFACT_STORAGE'] |
| 145 | + |
| 146 | + datasets = {"train": train_path} |
| 147 | + |
| 148 | + fine_tuning_file_ids = {} |
| 149 | + for name, path in datasets.items(): |
| 150 | + num_items = len(read_jsonl(path)) |
| 151 | + logger.info(f"Number of items in {name} data: {num_items}") |
| 152 | + |
| 153 | + remote_path = os.path.join(storage, path.split("/")[-1]) |
| 154 | + logger.info(f"Uploading {name} data to S3 at {remote_path}") |
| 155 | + if remote_path[:2] == "s3": |
| 156 | + os.system(f"aws s3 cp {path} {remote_path}") |
| 157 | + elif remote_path[:2] == "gs": |
| 158 | + os.system(f"gcloud storage cp {path} {remote_path}") |
| 159 | + else: |
| 160 | + os.system(f"cp {path} {remote_path}") |
| 161 | + logger.info(f"Copied {path} to {remote_path}") |
| 162 | + fine_tuning_file_ids[name] = remote_path |
| 163 | + |
| 164 | + return fine_tuning_file_ids["train"], fine_tuning_file_ids.get("val", None) |
| 165 | + |
| 166 | + |
| 167 | +def generate_config_files(train_path: str, **kwargs): |
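|  | + """Generate the LLMForge model config YAML and the Anyscale JobConfig used to launch fine-tuning.""" |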
| 168 | + base_model_yaml_path = kwargs.get("train_config_yaml", None) |
| 169 | + assert kwargs["model"] is not None, "Model is required to generate the config files" |
| 170 | + |
| 171 | + use_lora = kwargs.get("use_lora", False) |
| 172 | + example_dir = "" |
| 173 | + lora_path = "configs/training/lora" if use_lora else "configs/training/full_param" |
| 174 | + |
| 176 | + if not base_model_yaml_path: |
| 177 | + def get_yaml_config(model_name): |
| 178 | + if "llama" in model_name.lower(): |
| 179 | + if "70b" in model_name: |
| 180 | + return "llama-3-70b.yaml" |
| 181 | + elif "13b" in model_name: |
| 182 | + return "llama-3-70b.yaml" |
| 183 | + else: |
| 184 | + return "llama-3-8b.yaml" |
| 185 | + elif "mistral" in model_name.lower(): |
| 186 | + if "mixtral" in model_name.lower(): |
| 187 | + return "mixtral-8x7b.yaml" |
| 188 | + else: |
| 189 | + return "mistral-7b.yaml" |
| 190 | + else: |
| 191 | + raise RuntimeError("No default yaml found for the model") |
| 192 | + |
| 193 | + default_model_yaml_path = get_yaml_config(kwargs["model"]) |
| 194 | + base_model_yaml_path = os.path.join(example_dir, lora_path, default_model_yaml_path) |
| 195 | + logger.info(f"Using default yaml template for model: {base_model_yaml_path}") |
| 196 | + |
| 197 | + with open(base_model_yaml_path, "r") as f: |
|  | + model_config_data = yaml.safe_load(f) |
| 198 | + model_config_data.update(kwargs.get("hyperparameters", {})) |
| 199 | + |
| 200 | + model_config_data["model_id"] = kwargs["model"] |
| 201 | + |
| 202 | + custom_modifications = { |
| 203 | + "model_id": kwargs["model"], |
| 204 | + "train_path": train_path, |
| 205 | + "logger": { |
| 206 | + "provider": "wandb", |
| 207 | + }, |
| 208 | + "num_checkpoints_to_keep": 10 |
| 209 | + } |
| 210 | + if kwargs.get("output_dir", None): |
| 211 | + custom_modifications["output_dir"] = kwargs["output_dir"] |
| 212 | + |
| 213 | + model_config_data.update(custom_modifications) |
| 214 | + model_config_data = {k: v for k, v in model_config_data.items() if v is not None} |
| 215 | + |
| 216 | + def freeze(d): |
| 217 | + if isinstance(d, dict): |
| 218 | + return tuple(sorted((key, freeze(value)) for key, value in d.items())) |
| 219 | + elif isinstance(d, list): |
| 220 | + return tuple(freeze(value) for value in sorted(d)) |
| 221 | + elif isinstance(d, set): |
| 222 | + return tuple(freeze(value) for value in sorted(d)) |
| 223 | + return d |
| 224 | + |
| 225 | + def hash_dict(d): |
| 226 | + return hash(freeze(d)) |
| 227 | + # Hash the final config so the generated YAML gets a unique, content-derived filename |
| 228 | + dict_sorted_hash = abs(hash_dict(model_config_data)) |
| 230 | + filename = f"model_config_dspy_{dict_sorted_hash}.yaml" |
| 231 | + logger.info(f"Model config data: {model_config_data}") |
| 232 | + with open(filename, "w") as f: |
|  | + yaml.safe_dump(model_config_data, f) |
| 233 | + |
| 234 | + ft_path = os.path.join("utils", "ft.py") |
| 235 | + |
| 236 | + compute_config_dict = { |
| 237 | + "name": "dspy-llmforge-fine-tuning-job", |
| 238 | + "entrypoint": f"llmforge anyscale finetune {filename}", |
| 239 | + "working_dir": ".", |
| 240 | + "image_uri": "localhost:5555/anyscale/llm-forge:0.5.6", |
| 241 | + "requirements": [ |
| 242 | + "wandb", |
| 243 | + ], |
| 244 | + "env_vars": { |
| 245 | + "WANDB_API_KEY": os.environ.get("WANDB_API_KEY", ""), |
| 246 | + "HF_TOKEN": os.environ.get("HF_TOKEN", ""), |
| 247 | + "HF_HOME": os.environ.get("HF_HOME", ""), |
| 248 | + } |
| 249 | + } |
| 250 | + compute_config_kwargs = kwargs.get("compute_config", {}) |
| 251 | + compute_config_dict.update(compute_config_kwargs) |
| 252 | + compute_config = JobConfig(**compute_config_dict) |
| 253 | + |
| 254 | + job_runner_config_path = kwargs.get("compute_yaml_path", "job_runner_config.yaml") |
| 255 | + |
| 256 | + return job_runner_config_path, compute_config |
| 257 | + |
| 258 | + |
| 259 | +def start_remote_training(compute_config, **kwargs) -> str: |
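|  | + """Submit the fine-tuning job to Anyscale and return its job ID.""" |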
| 260 | + job_id: str = anyscale.job.submit(compute_config) |
| 261 | + return job_id |
| 262 | + |
| 263 | + |
| 269 | +def get_model_info(job_id): |
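|  | + """Retrieve information about the fine-tuned model produced by the given Anyscale job.""" |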
| 270 | + return anyscale.llm.model.get(job_id=job_id).to_dict() |
| 271 | + |
| 272 | + |
| 273 | +def copy_lora_weights(storage_uri, model_info, job_id): |
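|  | + """Copy the job's LoRA checkpoint folders to a dspy-specific dynamic-loading path in the same bucket and return the checkpoint model names along with that path.""" |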
| 274 | + try: |
| 275 | + from google.cloud import storage |
| 276 | + |
| 277 | + storage_client = storage.Client() |
| 278 | + |
| 279 | + bucket_name = storage_uri.split('/')[2] |
| 280 | + source_folder = '/'.join(storage_uri.split('/')[3:-1]) |
| 281 | + logger.info(f"Source folder: {source_folder}") |
| 282 | + |
| 283 | + bucket = storage_client.bucket(bucket_name) |
| 284 | + |
| 285 | + blobs = bucket.list_blobs(prefix=source_folder) |
| 286 | + |
| 287 | + subfolders = set() |
| 288 | + for blob in blobs: |
| 289 | + if '/' in blob.name[len(source_folder):]: |
| 290 | + subfolder = blob.name.split('/')[:-1] |
| 291 | + subfolders.add('/'.join(subfolder)) |
| 292 | + |
| 293 | + base_model_id = model_info["base_model_id"] |
| 294 | + lora_dynamic_path = f"dspy/lora_weights/{job_id}/{base_model_id}" |
| 295 | + |
| 296 | + model_names = [] |
| 297 | + for subfolder in subfolders: |
| 298 | + subfolder_name = subfolder.split('/')[-1] |
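|  | + # Each epoch checkpoint is copied to "<lora_dynamic_path>:<subfolder>"; its last two path components are recorded as the checkpoint's model name |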
| 299 | + destination_folder = f"{lora_dynamic_path}:{subfolder_name}" |
| 300 | + if subfolder_name.startswith("epoch"): |
| 301 | + model_names.append("/".join(destination_folder.split("/")[-2:])) |
| 302 | + else: |
| 303 | + continue |
| 304 | + |
| 305 | + subfolder_blobs = bucket.list_blobs(prefix=subfolder) |
| 306 | + |
| 307 | + for blob in subfolder_blobs: |
| 308 | + source_blob = bucket.blob(blob.name) |
| 309 | + destination_blob_name = f"{destination_folder}/{blob.name.split('/')[-1]}" |
| 310 | + bucket.copy_blob(source_blob, bucket, destination_blob_name) |
| 311 | + logger.info(f"Copied {source_blob.name} to {destination_blob_name}") |
| 312 | + |
| 313 | + logger.info(f"All subfolders copied to: gs://{bucket_name}/{lora_dynamic_path}") |
| 314 | + completed_path = f"gs://{bucket_name}/{lora_dynamic_path}" |
| 315 | + return model_names, completed_path |
| 316 | + |
| 317 | + except Exception as e: |
| 318 | + logger.error(f"An error occurred while copying LoRA weights: {e}") |
| 319 | + raise |
| 320 | + |
| 321 | + |
| 322 | +def read_jsonl(filename): |
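|  | + """Read a JSONL file and return its records as a list of dicts.""" |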
| 323 | + with open(filename, "r") as f: |
| 324 | + return [json.loads(line) for line in f] |
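|  | + |
|  | + |
|  | +# Example usage (sketch): inside an Anyscale Workspace with ANYSCALE_ARTIFACT_STORAGE set, the |
|  | +# fine-tuning flow is expected to call finetune_anyscale roughly as follows; the model name and |
|  | +# kwargs below are illustrative, not defaults: |
|  | +# |
|  | +#   job = FinetuneJobAnyScale(...) |
|  | +#   last_checkpoint = finetune_anyscale( |
|  | +#       job=job, |
|  | +#       model="meta-llama/Meta-Llama-3-8B-Instruct", |
|  | +#       train_data=[{"messages": [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}], |
|  | +#       train_kwargs={"use_lora": True, "serve_config_path": "serve_1B.yaml"}, |
|  | +#   ) |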