Commit 3881eb7

lint fixes
1 parent 1eb95ce commit 3881eb7

File tree: 5 files changed (+111, -165 lines)


dspy/clients/finetune.py

Lines changed: 21 additions & 20 deletions
@@ -1,11 +1,11 @@
+import os
 from abc import abstractmethod
 from concurrent.futures import Future
 from enum import Enum
-import os
 from pathlib import Path
-from typing import List, Dict, Any, Optional
-import ujson
+from typing import Any, Dict, List, Optional

+import ujson
 from datasets.fingerprint import Hasher


@@ -14,7 +14,7 @@ def get_finetune_directory() -> str:
     # TODO: Move to a centralized location with all the other env variables
     dspy_cachedir = os.environ.get("DSPY_CACHEDIR")
     dspy_cachedir = dspy_cachedir or os.path.join(Path.home(), ".dspy_cache")
-    finetune_dir = os.path.join(dspy_cachedir, 'finetune')
+    finetune_dir = os.path.join(dspy_cachedir, "finetune")
     finetune_dir = os.path.abspath(finetune_dir)
     return finetune_dir

@@ -24,17 +24,19 @@

 class TrainingMethod(str, Enum):
     """Enum class for training methods.
-
+
     When comparing enums, Python checks for object IDs, which means that the
     enums can't be compared directly. Subclassing the Enum class along with the
     str class allows for direct comparison of the enums.
     """
+
     SFT = "SFT"
     Preference = "Preference"


 class TrainingStatus(str, Enum):
     """Enum class for remote training status."""
+
     not_started = "not_started"
     pending = "pending"
     running = "running"

@@ -49,12 +51,13 @@ class TrainingStatus(str, Enum):
     TrainingMethod.Preference: ["prompt", "chosen", "rejected"],
 }

-class FinetuneJob(Future):

-    def __init__(self,
+class FinetuneJob(Future):
+    def __init__(
+        self,
         model: str,
         train_data: List[Dict[str, Any]],
-        train_kwargs: Optional[Dict[str, Any]]=None,
+        train_kwargs: Optional[Dict[str, Any]] = None,
         train_method: TrainingMethod = TrainingMethod.SFT,
         provider: str = "openai",
     ):

@@ -64,7 +67,7 @@ def __init__(self,
         self.train_method = train_method
         self.provider = provider
         super().__init__()
-
+
     def get_kwargs(self):
         return dict(
             model=self.model,

@@ -89,26 +92,24 @@ def status(self):
         raise NotImplementedError("Method `status` is not implemented.")


-def validate_finetune_data(
-    data: List[Dict[str, Any]],
-    train_method: TrainingMethod
-) -> Optional[AssertionError]:
+def validate_finetune_data(data: List[Dict[str, Any]], train_method: TrainingMethod) -> Optional[AssertionError]:
     """Validate the finetune data based on the training method."""
     # Get the required data keys for the training method
     required_keys = TRAINING_METHOD_TO_DATA_KEYS[train_method]

     # Check if the training data has the required keys
     for ind, data_dict in enumerate(data):
-        err_msg = f"The datapoint at index {ind} is missing the keys required for {train_method} training."
-        err_msg = f"\n Expected: {required_keys}"
-        err_msg = f"\n Found: {data_dict.keys()}"
-        assert all([key in data_dict for key in required_keys]), err_msg
+        if not all([key in data_dict for key in required_keys]):
+            raise ValueError(
+                f"The datapoint at index {ind} is missing the keys required for {train_method} training. Expected: "
+                f"{required_keys}, Found: {data_dict.keys()}"
+            )


 def save_data(
-        data: List[Dict[str, Any]],
-        provider_name: Optional[str]=None,
-    ) -> str:
+    data: List[Dict[str, Any]],
+    provider_name: Optional[str] = None,
+) -> str:
     """Save the fine-tuning data to a file."""
     # Construct the file name based on the data hash
     hash = Hasher.hash(data)
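The `TrainingMethod` docstring above carries the reasoning for the `str, Enum` double inheritance, and the `validate_finetune_data` hunk swaps a chain of `assert`s for a single `ValueError`. A minimal standalone sketch (not dspy code) of the comparison behavior the docstring describes:

# Why TrainingMethod subclasses both str and Enum: a plain Enum member
# compares by identity, so it never equals the raw string it wraps.
from enum import Enum

class PlainMethod(Enum):
    SFT = "SFT"

class StrMethod(str, Enum):
    SFT = "SFT"

print(PlainMethod.SFT == "SFT")  # False: plain Enum compares by identity
print(StrMethod.SFT == "SFT")    # True: the str base compares by value

The `ValueError` rewrite also fixes a latent bug visible in the removed lines: each `err_msg =` rebound the variable, so the assert could only ever report the last fragment, and `python -O` would have skipped the check entirely.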

dspy/clients/lm_finetune_utils.py

Lines changed: 5 additions & 10 deletions
@@ -1,10 +1,9 @@
 from typing import Any, Dict, List, Optional, Type, Union

-from dspy.utils.logging import logger
+from dspy.clients.anyscale import FinetuneJobAnyScale, finetune_anyscale
 from dspy.clients.finetune import FinetuneJob, TrainingMethod
 from dspy.clients.openai import FinetuneJobOpenAI, finetune_openai
-from dspy.clients.anyscale import FinetuneJobAnyScale, finetune_anyscale
-
+from dspy.utils.logging import logger

 _PROVIDER_ANYSCALE = "anyscale"
 _PROVIDER_OPENAI = "openai"

@@ -30,11 +29,7 @@ def get_provider_finetune_function(provider: str) -> callable:

 # Note: Type of LM should be LM. We aren't importing it here to avoid
 # circular imports.
-def execute_finetune_job(
-    job: FinetuneJob,
-    lm: Any,
-    cache_finetune: bool=True
-):
+def execute_finetune_job(job: FinetuneJob, lm: Any, cache_finetune: bool = True):
     """Execute the finetune job in a blocking manner."""
     try:
         job_kwargs = job.get_kwargs()

@@ -54,7 +49,7 @@ def cached_finetune(
     job,
     model: str,
     train_data: List[Dict[str, Any]],
-    train_kwargs: Optional[Dict[str, Any]]=None,
+    train_kwargs: Optional[Dict[str, Any]] = None,
     train_method: TrainingMethod = TrainingMethod.SFT,
     provider: str = "openai",
 ) -> Union[str, Exception]:

@@ -72,7 +67,7 @@ def finetune(
     job,
     model: str,
     train_data: List[Dict[str, Any]],
-    train_kwargs: Optional[Dict[str, Any]]=None,
+    train_kwargs: Optional[Dict[str, Any]] = None,
     train_method: TrainingMethod = TrainingMethod.SFT,
     provider: str = "openai",
 ) -> Union[str, Exception]:
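For orientation, a hedged usage sketch of the blocking entry point above. The `FinetuneJob` constructor arguments mirror the signature in dspy/clients/finetune.py; the training example is made up, and `lm` is left as a placeholder since it is deliberately typed `Any` here to avoid the circular import noted in the diff:

from dspy.clients.finetune import FinetuneJob, TrainingMethod
from dspy.clients.lm_finetune_utils import execute_finetune_job

# Hypothetical SFT datapoint; the prompt/completion keys assumed here
# follow the schema consumed by _convert_data in openai.py.
train_data = [{"prompt": "2 + 2 =", "completion": "4"}]

job = FinetuneJob(
    model="gpt-4o-mini-2024-07-18",
    train_data=train_data,
    train_method=TrainingMethod.SFT,
    provider="openai",
)
# `lm` is whatever LM instance the caller already holds:
# execute_finetune_job(job, lm=my_lm, cache_finetune=True)  # blocks until done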

dspy/clients/openai.py

Lines changed: 32 additions & 75 deletions
@@ -1,74 +1,34 @@
-from collections import defaultdict
 import re
 import time
+from collections import defaultdict
 from typing import Any, Dict, List, Optional, Union

 import openai

-from dspy.utils.logging import logger
 from dspy.clients.finetune import (
     FinetuneJob,
     TrainingMethod,
     TrainingStatus,
-    validate_finetune_data,
     save_data,
+    validate_finetune_data,
 )
+from dspy.utils.logging import logger

 # Provider name
 PROVIDER_OPENAI = "openai"

-# List of model IDs
-_MODEL_IDS = [
-    "gpt-4o",
-    "gpt-4o-2024-08-06",
-    "gpt-4o-2024-05-13",
-    "chatgpt-4o-latest",
-    "gpt-4o-mini",
-    "gpt-4o-mini-2024-07-18",
-    "gpt-4o-realtime-preview",
-    "gpt-4o-realtime-preview-2024-10-01",
-    "o1-preview",
-    "o1-preview-2024-09-12",
-    "o1-mini",
-    "o1-mini-2024-09-12",
-    "gpt-4-turbo",
-    "gpt-4-turbo-2024-04-09",
-    "gpt-4-turbo-preview",
-    "gpt-4-0125-preview",
-    "gpt-4-1106-preview",
-    "gpt-4",
-    "gpt-4-0613",
-    "gpt-4-0314",
-    "gpt-3.5-turbo-0125",
-    "gpt-3.5-turbo",
-    "gpt-3.5-turbo-1106",
-    "gpt-3.5-turbo-instruct",
-    "dall-e-3",
-    "dall-e-2",
-    "tts-1",
-    "tts-1-hd",
-    "text-embedding-3-large",
-    "text-embedding-3-small",
-    "text-embedding-ada-002",
-    "omni-moderation-latest",
-    "omni-moderation-2024-09-26",
-    "text-moderation-latest",
-    "text-moderation-stable",
-    "text-moderation-007",
-    "babbage-002",
-    "davinci-002"
-]
-

 def is_openai_model(model: str) -> bool:
     """Check if the model is an OpenAI model."""
     # Filter the provider_prefix, if exists
     provider_prefix = f"{PROVIDER_OPENAI}/"
     if model.startswith(provider_prefix):
-        model = model[len(provider_prefix):]
+        model = model[len(provider_prefix) :]

+    client = openai.OpenAI()
+    valid_model_names = [model.id for model in client.models.list().data]
     # Check if the model is a base OpenAI model
-    if model in _MODEL_IDS:
+    if model in valid_model_names:
         return True

     # Check if the model is a fine-tuned OpneAI model. Fine-tuned OpenAI models

@@ -77,15 +37,15 @@ def is_openai_model(model: str) -> bool:
     # base model name.
     # TODO: This part can be updated to match the actual fine-tuned model names
     # by making a call to the OpenAI API to be more exact, but this might
-    # require an API key with the right permissions. 
+    # require an API key with the right permissions.
     match = re.match(r"ft:([^:]+):", model)
-    if match and match.group(1) in _MODEL_IDS:
+    if match and match.group(1) in valid_model_names:
         return True

     return False

-class FinetuneJobOpenAI(FinetuneJob):

+class FinetuneJobOpenAI(FinetuneJob):
     def __init__(self, *args, **kwargs):
         self.provider_file_id = None  # TODO: Can we get this using the job_id?
         self.provider_job_id = None

@@ -118,12 +78,12 @@ def status(self) -> TrainingStatus:


 def finetune_openai(
-        job: FinetuneJobOpenAI,
-        model: str,
-        train_data: List[Dict[str, Any]],
-        train_kwargs: Optional[Dict[str, Any]]=None,
-        train_method: TrainingMethod = TrainingMethod.SFT,
-    ) -> str:
+    job: FinetuneJobOpenAI,
+    model: str,
+    train_data: List[Dict[str, Any]],
+    train_kwargs: Optional[Dict[str, Any]] = None,
+    train_method: TrainingMethod = TrainingMethod.SFT,
+) -> str:
     train_kwargs = train_kwargs or {}
     train_method = TrainingMethod.SFT  # Note: This could be an argument; ignoring method

@@ -171,10 +131,12 @@ def finetune_openai(

     return model

+
 _SUPPORTED_TRAINING_METHODS = [
     TrainingMethod.SFT,
 ]

+
 def _get_training_status(job_id: str) -> Union[TrainingStatus, Exception]:
     # TODO: Should this type be shared across all fine-tune clients?
     provider_status_to_training_status = {

@@ -228,10 +190,7 @@ def _is_terminal_training_status(status: TrainingStatus) -> bool:
     ]


-def _validate_data(
-    data: Dict[str, str],
-    train_method: TrainingMethod
-) -> Optional[Exception]:
+def _validate_data(data: Dict[str, str], train_method: TrainingMethod) -> Optional[Exception]:
     # Check if this train method is supported
     if train_method not in _SUPPORTED_TRAINING_METHODS:
         err_msg = f"OpenAI does not support the training method {train_method}."

@@ -241,20 +200,17 @@


 def _convert_data(
-        data: List[Dict[str, str]],
-        system_prompt: Optional[str]=None,
-    ) -> Union[List[Dict[str, Any]], Exception]:
+    data: List[Dict[str, str]],
+    system_prompt: Optional[str] = None,
+) -> Union[List[Dict[str, Any]], Exception]:
     # Item-wise conversion function
     def _row_converter(d):
-        messages = [
-            {"role": "user", "content": d["prompt"]},
-            {"role": "assistant", "content": d["completion"]}
-        ]
+        messages = [{"role": "user", "content": d["prompt"]}, {"role": "assistant", "content": d["completion"]}]
         if system_prompt:
             messages.insert(0, {"role": "system", "content": system_prompt})
         messages_dict = {"messages": messages}
         return messages_dict
-
+
     # Convert the data to the OpenAI format; validate the converted data
     converted_data = list(map(_row_converter, data))
     openai_data_validation(converted_data)

@@ -270,11 +226,7 @@ def _upload_data(data_path: str) -> str:
     return provider_file.id


-def _start_remote_training(
-    train_file_id: str,
-    model: id,
-    train_kwargs: Optional[Dict[str, Any]]=None
-) -> str:
+def _start_remote_training(train_file_id: str, model: id, train_kwargs: Optional[Dict[str, Any]] = None) -> str:
     train_kwargs = train_kwargs or {}
     provider_job = openai.fine_tuning.jobs.create(
         model=model,

@@ -286,7 +238,7 @@

 def _wait_for_job(
     job: FinetuneJobOpenAI,
-    poll_frequency: int=60,
+    poll_frequency: int = 60,
 ):
     while not _is_terminal_training_status(job.status()):
         time.sleep(poll_frequency)

@@ -304,6 +256,7 @@ def _get_trained_model(job):
     finetuned_model = provider_job.fine_tuned_model
     return finetuned_model

+
 # Adapted from https://cookbook.openai.com/examples/chat_finetuning_data_prep
 def openai_data_validation(dataset: List[dict[str, Any]]):
     format_errors = defaultdict(int)

@@ -364,7 +317,9 @@ def check_message_lengths(dataset: List[dict[str, Any]]) -> list[int]:
     n_too_long = sum([length > 16385 for length in convo_lens])

     if n_too_long > 0:
-        logger.info(f"There are {n_too_long} examples that may be over the 16,385 token limit, they will be truncated during fine-tuning.")
+        logger.info(
+            f"There are {n_too_long} examples that may be over the 16,385 token limit, they will be truncated during fine-tuning."
+        )

     if n_missing_system > 0:
         logger.info(f"There are {n_missing_system} examples that are missing a system message.")

@@ -377,6 +332,7 @@ def check_message_lengths(dataset: List[dict[str, Any]]) -> list[int]:

 def num_tokens_from_messages(messages, tokens_per_message=3, tokens_per_name=1):
     import tiktoken
+
     encoding = tiktoken.get_encoding("cl100k_base")

     num_tokens = 0

@@ -392,6 +348,7 @@ def num_tokens_from_messages(messages, tokens_per_message=3, tokens_per_name=1):

 def num_assistant_tokens_from_messages(messages):
     import tiktoken
+
     encoding = tiktoken.get_encoding("cl100k_base")

     num_tokens = 0
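The one substantive (non-lint) change in this file: the hardcoded `_MODEL_IDS` list is deleted, and `is_openai_model` now asks the API which models exist. A sketch of the new lookup path, noting the trade-off it implies (the check now needs a valid `OPENAI_API_KEY` and a network round-trip on every call); the fine-tuned model name below is hypothetical:

import re
import openai

client = openai.OpenAI()  # reads OPENAI_API_KEY from the environment
valid_model_names = [m.id for m in client.models.list().data]

# Base model names match directly; fine-tuned names match on the base
# model captured by the ft:<base>: prefix, as in the diff above.
name = "ft:gpt-4o-mini-2024-07-18:my-org::abc123"  # hypothetical
match = re.match(r"ft:([^:]+):", name)
print(match and match.group(1) in valid_model_names)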
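The validation helpers at the end of the diff follow the linked OpenAI chat fine-tuning cookbook: messages are counted with tiktoken's cl100k_base encoding using a fixed per-message and per-name overhead, and conversations over 16,385 tokens get flagged. A simplified standalone sketch of that counting scheme (the message list is illustrative, and the per-model constants are the cookbook's defaults):

import tiktoken

encoding = tiktoken.get_encoding("cl100k_base")

def count_tokens(messages, tokens_per_message=3, tokens_per_name=1):
    # Fixed overhead per message, plus the encoded length of every field,
    # plus extra if a "name" field is present.
    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for key, value in message.items():
            num_tokens += len(encoding.encode(value))
            if key == "name":
                num_tokens += tokens_per_name
    return num_tokens

messages = [
    {"role": "user", "content": "2 + 2 ="},
    {"role": "assistant", "content": "4"},
]
print(count_tokens(messages))  # far below the 16,385-token flag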

0 commit comments
