Skip to content

Commit b109e72

Browse files
dilarasoylu and okhat authored
Finetune fix (#7629)
* Update local fine-tuning * Make train_data_format a required argument for finetune * Update finetune optimizers * Update finetune classification demo * Update launch check * Update lm_local.py * Revert notebook. Revert bootstrap_trace_data. * Add back launch_lms etc --------- Co-authored-by: Omar Khattab <[email protected]>
1 parent 10fe155 commit b109e72

File tree

8 files changed

+272
-166
lines changed

8 files changed

+272
-166
lines changed

dspy/clients/databricks.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import ujson
99

1010
from dspy.clients.provider import Provider, TrainingJob
11-
from dspy.clients.utils_finetune import DataFormat, get_finetune_directory
11+
from dspy.clients.utils_finetune import TrainDataFormat, get_finetune_directory
1212

1313
if TYPE_CHECKING:
1414
from databricks.sdk import WorkspaceClient
@@ -50,7 +50,7 @@ def is_provider_model(model: str) -> bool:
5050
@staticmethod
5151
def deploy_finetuned_model(
5252
model: str,
53-
data_format: Optional[DataFormat] = None,
53+
data_format: Optional[TrainDataFormat] = None,
5454
databricks_host: Optional[str] = None,
5555
databricks_token: Optional[str] = None,
5656
deploy_timeout: int = 900,
@@ -148,11 +148,11 @@ def deploy_finetuned_model(
148148
num_retries = deploy_timeout // 60
149149
for _ in range(num_retries):
150150
try:
151-
if data_format == DataFormat.chat:
151+
if data_format == TrainDataFormat.CHAT:
152152
client.chat.completions.create(
153153
messages=[{"role": "user", "content": "hi"}], model=model_name, max_tokens=1
154154
)
155-
elif data_format == DataFormat.completion:
155+
elif data_format == TrainDataFormat.COMPLETION:
156156
client.completions.create(prompt="hi", model=model_name, max_tokens=1)
157157
logger.info(f"Databricks model serving endpoint {model_name} is ready!")
158158
return
@@ -169,17 +169,17 @@ def finetune(
169169
job: TrainingJobDatabricks,
170170
model: str,
171171
train_data: List[Dict[str, Any]],
172+
train_data_format: Optional[Union[TrainDataFormat, str]] = "chat",
172173
train_kwargs: Optional[Dict[str, Any]] = None,
173-
data_format: Optional[Union[DataFormat, str]] = "chat",
174174
) -> str:
175175
if isinstance(data_format, str):
176176
if data_format == "chat":
177-
data_format = DataFormat.chat
177+
data_format = TrainDataFormat.CHAT
178178
elif data_format == "completion":
179-
data_format = DataFormat.completion
179+
data_format = TrainDataFormat.COMPLETION
180180
else:
181181
raise ValueError(
182-
f"String `data_format` must be one of 'chat' or 'completion', but received: {data_format}."
182+
f"String `train_data_format` must be one of 'chat' or 'completion', but received: {data_format}."
183183
)
184184

185185
if "train_data_path" not in train_kwargs:
@@ -243,7 +243,7 @@ def finetune(
243243
return f"databricks/{job.endpoint_name}"
244244

245245
@staticmethod
246-
def upload_data(train_data: List[Dict[str, Any]], databricks_unity_catalog_path: str, data_format: DataFormat):
246+
def upload_data(train_data: List[Dict[str, Any]], databricks_unity_catalog_path: str, data_format: TrainDataFormat):
247247
logger.info("Uploading finetuning data to Databricks Unity Catalog...")
248248
file_path = _save_data_to_local_file(train_data, data_format)
249249

@@ -303,7 +303,7 @@ def _create_directory_in_databricks_unity_catalog(w: "WorkspaceClient", databric
303303
logger.info(f"Successfully created directory {databricks_unity_catalog_path} in Databricks Unity Catalog!")
304304

305305

306-
def _save_data_to_local_file(train_data: List[Dict[str, Any]], data_format: DataFormat):
306+
def _save_data_to_local_file(train_data: List[Dict[str, Any]], data_format: TrainDataFormat):
307307
import uuid
308308

309309
file_name = f"finetuning_{uuid.uuid4()}.jsonl"
@@ -313,9 +313,9 @@ def _save_data_to_local_file(train_data: List[Dict[str, Any]], data_format: Data
313313
file_path = os.path.abspath(file_path)
314314
with open(file_path, "w") as f:
315315
for item in train_data:
316-
if data_format == DataFormat.chat:
316+
if data_format == TrainDataFormat.CHAT:
317317
_validate_chat_data(item)
318-
elif data_format == DataFormat.completion:
318+
elif data_format == TrainDataFormat.COMPLETION:
319319
_validate_completion_data(item)
320320

321321
f.write(ujson.dumps(item) + "\n")

dspy/clients/lm.py

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from dspy.adapters.base import Adapter
2121
from dspy.clients.openai import OpenAIProvider
2222
from dspy.clients.provider import Provider, TrainingJob
23-
from dspy.clients.utils_finetune import DataFormat, infer_data_format, validate_data_format
23+
from dspy.clients.utils_finetune import TrainDataFormat
2424
from dspy.utils.callback import BaseCallback, with_callbacks
2525

2626
from .base_lm import BaseLM
@@ -46,6 +46,7 @@ def __init__(
4646
provider=None,
4747
finetuning_model: Optional[str] = None,
4848
launch_kwargs: Optional[dict[str, Any]] = None,
49+
train_kwargs: Optional[dict[str, Any]] = None,
4950
**kwargs,
5051
):
5152
"""
@@ -79,7 +80,8 @@ def __init__(
7980
self.callbacks = callbacks or []
8081
self.num_retries = num_retries
8182
self.finetuning_model = finetuning_model
82-
self.launch_kwargs = launch_kwargs
83+
self.launch_kwargs = launch_kwargs or {}
84+
self.train_kwargs = train_kwargs or {}
8385

8486
# Handle model-specific configuration for different model families
8587
model_family = model.split("/")[-1].lower() if "/" in model else model.lower()
@@ -156,18 +158,16 @@ def __call__(self, prompt=None, messages=None, **kwargs):
156158
return outputs
157159

158160
def launch(self, launch_kwargs: Optional[Dict[str, Any]] = None):
159-
launch_kwargs = launch_kwargs or self.launch_kwargs
160161
self.provider.launch(self, launch_kwargs)
161162

162163
def kill(self, launch_kwargs: Optional[Dict[str, Any]] = None):
163-
launch_kwargs = launch_kwargs or self.launch_kwargs
164164
self.provider.kill(self, launch_kwargs)
165165

166166
def finetune(
167167
self,
168168
train_data: List[Dict[str, Any]],
169+
train_data_format: Optional[TrainDataFormat],
169170
train_kwargs: Optional[Dict[str, Any]] = None,
170-
data_format: Optional[DataFormat] = None,
171171
) -> TrainingJob:
172172
from dspy import settings as settings
173173

@@ -178,27 +178,18 @@ def finetune(
178178
err = f"Provider {self.provider} does not support fine-tuning."
179179
assert self.provider.finetunable, err
180180

181-
# Perform data validation before starting the thread to fail early
182-
train_kwargs = train_kwargs or {}
183-
if not data_format:
184-
adapter = self.infer_adapter()
185-
data_format = infer_data_format(adapter)
186-
validate_data_format(data=train_data, data_format=data_format)
187-
188-
# TODO(PR): We can quickly add caching, but doing so requires
189-
# adding functions that just call other functions as we had in the last
190-
# iteration, unless people have other ideas.
191181
def thread_function_wrapper():
192182
return self._run_finetune_job(job)
193183

194184
thread = threading.Thread(target=thread_function_wrapper)
195-
model_to_finetune = self.finetuning_model or self.model
185+
train_kwargs = train_kwargs or self.train_kwargs
186+
model_to_finetune = self.finetuning_model or self.model
196187
job = self.provider.TrainingJob(
197188
thread=thread,
198189
model=model_to_finetune,
199190
train_data=train_data,
191+
train_data_format=train_data_format,
200192
train_kwargs=train_kwargs,
201-
data_format=data_format,
202193
)
203194
thread.start()
204195

@@ -212,8 +203,8 @@ def _run_finetune_job(self, job: TrainingJob):
212203
job=job,
213204
model=job.model,
214205
train_data=job.train_data,
206+
train_data_format=job.train_data_format,
215207
train_kwargs=job.train_kwargs,
216-
data_format=job.data_format,
217208
)
218209
lm = self.copy(model=model)
219210
job.set_result(lm)

0 commit comments

Comments (0)