-from collections import defaultdict
 import re
 import time
+from collections import defaultdict
 from typing import Any, Dict, List, Optional, Union
 
 import openai
 
-from dspy.utils.logging import logger
 from dspy.clients.finetune import (
     FinetuneJob,
     TrainingMethod,
     TrainingStatus,
-    validate_finetune_data,
     save_data,
+    validate_finetune_data,
 )
+from dspy.utils.logging import logger
 
 # Provider name
 PROVIDER_OPENAI = "openai"
 
-# List of model IDs
-_MODEL_IDS = [
-    "gpt-4o",
-    "gpt-4o-2024-08-06",
-    "gpt-4o-2024-05-13",
-    "chatgpt-4o-latest",
-    "gpt-4o-mini",
-    "gpt-4o-mini-2024-07-18",
-    "gpt-4o-realtime-preview",
-    "gpt-4o-realtime-preview-2024-10-01",
-    "o1-preview",
-    "o1-preview-2024-09-12",
-    "o1-mini",
-    "o1-mini-2024-09-12",
-    "gpt-4-turbo",
-    "gpt-4-turbo-2024-04-09",
-    "gpt-4-turbo-preview",
-    "gpt-4-0125-preview",
-    "gpt-4-1106-preview",
-    "gpt-4",
-    "gpt-4-0613",
-    "gpt-4-0314",
-    "gpt-3.5-turbo-0125",
-    "gpt-3.5-turbo",
-    "gpt-3.5-turbo-1106",
-    "gpt-3.5-turbo-instruct",
-    "dall-e-3",
-    "dall-e-2",
-    "tts-1",
-    "tts-1-hd",
-    "text-embedding-3-large",
-    "text-embedding-3-small",
-    "text-embedding-ada-002",
-    "omni-moderation-latest",
-    "omni-moderation-2024-09-26",
-    "text-moderation-latest",
-    "text-moderation-stable",
-    "text-moderation-007",
-    "babbage-002",
-    "davinci-002"
-]
-
 
 def is_openai_model(model: str) -> bool:
     """Check if the model is an OpenAI model."""
     # Filter the provider_prefix, if exists
     provider_prefix = f"{PROVIDER_OPENAI}/"
     if model.startswith(provider_prefix):
-        model = model[len(provider_prefix):]
+        model = model[len(provider_prefix) :]
 
+    client = openai.OpenAI()
+    valid_model_names = [model.id for model in client.models.list().data]
     # Check if the model is a base OpenAI model
-    if model in _MODEL_IDS:
+    if model in valid_model_names:
         return True
 
     # Check if the model is a fine-tuned OpenAI model. Fine-tuned OpenAI models
@@ -77,15 +37,15 @@ def is_openai_model(model: str) -> bool:
     # base model name.
     # TODO: This part can be updated to match the actual fine-tuned model names
     # by making a call to the OpenAI API to be more exact, but this might
-    # require an API key with the right permissions.
+    # require an API key with the right permissions.
     match = re.match(r"ft:([^:]+):", model)
-    if match and match.group(1) in _MODEL_IDS:
+    if match and match.group(1) in valid_model_names:
        return True
 
     return False
 
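In the rewritten check, a model name passes either because it appears in the account's live model list or because it is a fine-tuned ID whose base model does. A minimal usage sketch, assuming OPENAI_API_KEY is configured and using a made-up fine-tuned ID:

    is_openai_model("openai/gpt-4o-mini")                       # prefix stripped, then matched against models.list()
    is_openai_model("ft:gpt-4o-mini-2024-07-18:acme::abc123")   # r"ft:([^:]+):" extracts the base model name
    is_openai_model("claude-3-5-sonnet")                        # neither a listed model nor an ft: ID, so False
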
-class FinetuneJobOpenAI(FinetuneJob):
 
+class FinetuneJobOpenAI(FinetuneJob):
     def __init__(self, *args, **kwargs):
         self.provider_file_id = None  # TODO: Can we get this using the job_id?
         self.provider_job_id = None
@@ -118,12 +78,12 @@ def status(self) -> TrainingStatus:
 
 
 def finetune_openai(
-    job: FinetuneJobOpenAI,
-    model: str,
-    train_data: List[Dict[str, Any]],
-    train_kwargs: Optional[Dict[str, Any]]=None,
-    train_method: TrainingMethod = TrainingMethod.SFT,
-) -> str:
+    job: FinetuneJobOpenAI,
+    model: str,
+    train_data: List[Dict[str, Any]],
+    train_kwargs: Optional[Dict[str, Any]] = None,
+    train_method: TrainingMethod = TrainingMethod.SFT,
+) -> str:
     train_kwargs = train_kwargs or {}
     train_method = TrainingMethod.SFT  # Note: This could be an argument; ignoring method
 
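With the reformatted signature, a call reads naturally with keyword arguments. A hedged sketch of such a call, using the prompt/completion record shape that _convert_data further down expects; the bare FinetuneJobOpenAI constructor call and the hyperparameters entry in train_kwargs are assumptions, the latter mirroring the public OpenAI fine-tuning API:

    train_data = [
        {"prompt": "What is 2 + 2?", "completion": "4"},
        {"prompt": "Name a prime number.", "completion": "7"},
    ]
    job = FinetuneJobOpenAI()  # constructor arguments omitted for illustration
    model_id = finetune_openai(
        job,
        model="gpt-4o-mini-2024-07-18",
        train_data=train_data,
        train_kwargs={"hyperparameters": {"n_epochs": 1}},  # assumed to be forwarded to the provider job
    )
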
@@ -171,10 +131,12 @@ def finetune_openai(
 
     return model
 
+
 _SUPPORTED_TRAINING_METHODS = [
     TrainingMethod.SFT,
 ]
 
+
 def _get_training_status(job_id: str) -> Union[TrainingStatus, Exception]:
     # TODO: Should this type be shared across all fine-tune clients?
     provider_status_to_training_status = {
@@ -228,10 +190,7 @@ def _is_terminal_training_status(status: TrainingStatus) -> bool:
     ]
 
 
-def _validate_data(
-    data: Dict[str, str],
-    train_method: TrainingMethod
-) -> Optional[Exception]:
+def _validate_data(data: Dict[str, str], train_method: TrainingMethod) -> Optional[Exception]:
     # Check if this train method is supported
     if train_method not in _SUPPORTED_TRAINING_METHODS:
         err_msg = f"OpenAI does not support the training method {train_method}."
@@ -241,20 +200,17 @@ def _validate_data(
 
 
 def _convert_data(
-    data: List[Dict[str, str]],
-    system_prompt: Optional[str]=None,
-) -> Union[List[Dict[str, Any]], Exception]:
+    data: List[Dict[str, str]],
+    system_prompt: Optional[str] = None,
+) -> Union[List[Dict[str, Any]], Exception]:
     # Item-wise conversion function
     def _row_converter(d):
-        messages = [
-            {"role": "user", "content": d["prompt"]},
-            {"role": "assistant", "content": d["completion"]}
-        ]
+        messages = [{"role": "user", "content": d["prompt"]}, {"role": "assistant", "content": d["completion"]}]
         if system_prompt:
             messages.insert(0, {"role": "system", "content": system_prompt})
         messages_dict = {"messages": messages}
         return messages_dict
-
+
     # Convert the data to the OpenAI format; validate the converted data
     converted_data = list(map(_row_converter, data))
     openai_data_validation(converted_data)
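For a concrete picture of the collapsed _row_converter, a single prompt/completion record becomes one chat-format training example; this assumes the function returns the converted list, as its signature suggests:

    _convert_data([{"prompt": "What is 2 + 2?", "completion": "4"}], system_prompt="You are terse.")
    # -> [{"messages": [
    #         {"role": "system", "content": "You are terse."},
    #         {"role": "user", "content": "What is 2 + 2?"},
    #         {"role": "assistant", "content": "4"},
    #     ]}]
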
@@ -270,11 +226,7 @@ def _upload_data(data_path: str) -> str:
     return provider_file.id
 
 
-def _start_remote_training(
-    train_file_id: str,
-    model: id,
-    train_kwargs: Optional[Dict[str, Any]]=None
-) -> str:
+def _start_remote_training(train_file_id: str, model: id, train_kwargs: Optional[Dict[str, Any]] = None) -> str:
     train_kwargs = train_kwargs or {}
     provider_job = openai.fine_tuning.jobs.create(
         model=model,
@@ -286,7 +238,7 @@ def _start_remote_training(
 
 def _wait_for_job(
     job: FinetuneJobOpenAI,
-    poll_frequency: int = 60,
+    poll_frequency: int = 60,
 ):
     while not _is_terminal_training_status(job.status()):
         time.sleep(poll_frequency)
@@ -304,6 +256,7 @@ def _get_trained_model(job):
     finetuned_model = provider_job.fine_tuned_model
     return finetuned_model
 
+
 # Adapted from https://cookbook.openai.com/examples/chat_finetuning_data_prep
 def openai_data_validation(dataset: List[dict[str, Any]]):
     format_errors = defaultdict(int)
@@ -364,7 +317,9 @@ def check_message_lengths(dataset: List[dict[str, Any]]) -> list[int]:
     n_too_long = sum([length > 16385 for length in convo_lens])
 
     if n_too_long > 0:
-        logger.info(f"There are {n_too_long} examples that may be over the 16,385 token limit; they will be truncated during fine-tuning.")
+        logger.info(
+            f"There are {n_too_long} examples that may be over the 16,385 token limit; they will be truncated during fine-tuning."
+        )
 
     if n_missing_system > 0:
         logger.info(f"There are {n_missing_system} examples that are missing a system message.")
@@ -377,6 +332,7 @@ def check_message_lengths(dataset: List[dict[str, Any]]) -> list[int]:
 
 def num_tokens_from_messages(messages, tokens_per_message=3, tokens_per_name=1):
     import tiktoken
+
     encoding = tiktoken.get_encoding("cl100k_base")
 
     num_tokens = 0
@@ -392,6 +348,7 @@ def num_tokens_from_messages(messages, tokens_per_message=3, tokens_per_name=1):
 
 def num_assistant_tokens_from_messages(messages):
     import tiktoken
+
     encoding = tiktoken.get_encoding("cl100k_base")
 
     num_tokens = 0
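The two tiktoken helpers feed the length checks above; a small sketch of how they could be exercised on one converted example (actual counts depend on the cl100k_base encoding, so no exact numbers are claimed):

    example_messages = [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "4"},
    ]
    total_tokens = num_tokens_from_messages(example_messages)                # content tokens plus per-message overhead
    assistant_tokens = num_assistant_tokens_from_messages(example_messages)  # tokens contributed by assistant turns only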