88import ujson
99
1010from dspy .clients .provider import Provider , TrainingJob
11- from dspy .clients .utils_finetune import DataFormat , get_finetune_directory
11+ from dspy .clients .utils_finetune import TrainDataFormat , get_finetune_directory
1212
1313if TYPE_CHECKING :
1414 from databricks .sdk import WorkspaceClient
@@ -50,7 +50,7 @@ def is_provider_model(model: str) -> bool:
5050 @staticmethod
5151 def deploy_finetuned_model (
5252 model : str ,
53- data_format : Optional [DataFormat ] = None ,
53+ data_format : Optional [TrainDataFormat ] = None ,
5454 databricks_host : Optional [str ] = None ,
5555 databricks_token : Optional [str ] = None ,
5656 deploy_timeout : int = 900 ,
@@ -148,11 +148,11 @@ def deploy_finetuned_model(
148148 num_retries = deploy_timeout // 60
149149 for _ in range (num_retries ):
150150 try :
151- if data_format == DataFormat . chat :
151+ if data_format == TrainDataFormat . CHAT :
152152 client .chat .completions .create (
153153 messages = [{"role" : "user" , "content" : "hi" }], model = model_name , max_tokens = 1
154154 )
155- elif data_format == DataFormat . completion :
155+ elif data_format == TrainDataFormat . COMPLETION :
156156 client .completions .create (prompt = "hi" , model = model_name , max_tokens = 1 )
157157 logger .info (f"Databricks model serving endpoint { model_name } is ready!" )
158158 return
@@ -169,17 +169,17 @@ def finetune(
169169 job : TrainingJobDatabricks ,
170170 model : str ,
171171 train_data : List [Dict [str , Any ]],
172+ train_data_format : Optional [Union [TrainDataFormat , str ]] = "chat" ,
172173 train_kwargs : Optional [Dict [str , Any ]] = None ,
173- data_format : Optional [Union [DataFormat , str ]] = "chat" ,
174174 ) -> str :
175175 if isinstance (data_format , str ):
176176 if data_format == "chat" :
177- data_format = DataFormat . chat
177+ data_format = TrainDataFormat . CHAT
178178 elif data_format == "completion" :
179- data_format = DataFormat . completion
179+ data_format = TrainDataFormat . COMPLETION
180180 else :
181181 raise ValueError (
182- f"String `data_format ` must be one of 'chat' or 'completion', but received: { data_format } ."
182+ f"String `train_data_format ` must be one of 'chat' or 'completion', but received: { data_format } ."
183183 )
184184
185185 if "train_data_path" not in train_kwargs :
@@ -243,7 +243,7 @@ def finetune(
243243 return f"databricks/{ job .endpoint_name } "
244244
245245 @staticmethod
246- def upload_data (train_data : List [Dict [str , Any ]], databricks_unity_catalog_path : str , data_format : DataFormat ):
246+ def upload_data (train_data : List [Dict [str , Any ]], databricks_unity_catalog_path : str , data_format : TrainDataFormat ):
247247 logger .info ("Uploading finetuning data to Databricks Unity Catalog..." )
248248 file_path = _save_data_to_local_file (train_data , data_format )
249249
@@ -303,7 +303,7 @@ def _create_directory_in_databricks_unity_catalog(w: "WorkspaceClient", databric
303303 logger .info (f"Successfully created directory { databricks_unity_catalog_path } in Databricks Unity Catalog!" )
304304
305305
306- def _save_data_to_local_file (train_data : List [Dict [str , Any ]], data_format : DataFormat ):
306+ def _save_data_to_local_file (train_data : List [Dict [str , Any ]], data_format : TrainDataFormat ):
307307 import uuid
308308
309309 file_name = f"finetuning_{ uuid .uuid4 ()} .jsonl"
@@ -313,9 +313,9 @@ def _save_data_to_local_file(train_data: List[Dict[str, Any]], data_format: Data
313313 file_path = os .path .abspath (file_path )
314314 with open (file_path , "w" ) as f :
315315 for item in train_data :
316- if data_format == DataFormat . chat :
316+ if data_format == TrainDataFormat . CHAT :
317317 _validate_chat_data (item )
318- elif data_format == DataFormat . completion :
318+ elif data_format == TrainDataFormat . COMPLETION :
319319 _validate_completion_data (item )
320320
321321 f .write (ujson .dumps (item ) + "\n " )
0 commit comments