Skip to content

Commit 1e539b0

Browse files
dilarasoylu and okhat authored
Dev finetune update (#1698)
* Adapter updates * Client updates * Add provider * Add cache dir utils * Add BootstrapFinetune * Add BetterTogether draft * Prepare PR * Add AnyScale changes -- ruff * Teporarily remove BetterTogether * Remove BetterTogether import * Add comment * Replace OpenAI client call with library call * Add OpenAI models list to check valid models * Temporarily switch to print * Prepare BootstrapFinetune for BetterTogether * Add BetterTogether * Add dev notebook * Revamp ChainOfThoughtWithHint and adjust max auto valset of MIPROv2 * fix * ruff fixes * disable cot_hint tests * unsafe ruff fixes --------- Co-authored-by: Omar Khattab <[email protected]>
1 parent 1df86fa commit 1e539b0

28 files changed

+2656
-1057
lines changed

dsp/utils/settings_v2.py

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -73,7 +73,7 @@ def main():
7373
futures = {executor.submit(thread_wrapper, sample_program, parent_tid, arg) for arg in range(3)}
7474

7575
for future in as_completed(futures):
76-
res = future.result()
76+
future.result()
7777

7878
print(f"Main thread {parent_tid} config after threads: {dsp_settings._get_current_config()}")
7979

dspy/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -67,6 +67,8 @@
6767
BootstrapFewShot = dspy.teleprompt.BootstrapFewShot
6868
BootstrapFewShotWithRandomSearch = dspy.teleprompt.BootstrapFewShotWithRandomSearch
6969
BootstrapRS = dspy.teleprompt.BootstrapFewShotWithRandomSearch
70+
BootstrapFinetune = dspy.teleprompt.BootstrapFinetune
71+
BetterTogether = dspy.teleprompt.BetterTogether
7072
COPRO = dspy.teleprompt.COPRO
7173
MIPROv2 = dspy.teleprompt.MIPROv2
7274
Ensemble = dspy.teleprompt.Ensemble

dspy/adapters/base.py

Lines changed: 13 additions & 15 deletions
Original file line number · Diff line number · Diff line change
@@ -1,19 +1,8 @@
1-
import abc
2-
from dspy.utils.callback import with_callbacks
3-
4-
class Adapter:
5-
@abc.abstractmethod
6-
def format(self, signature, demos, inputs):
7-
"""
8-
Format the input data for the LLM.
9-
"""
1+
from abc import ABC, abstractmethod
102

11-
@abc.abstractmethod
12-
def parse(self, signature, completion):
13-
"""
14-
Parse the output data from the LLM.
15-
"""
3+
from dspy.utils.callback import with_callbacks
164

5+
class Adapter(ABC):
176
def __init__(self, callbacks=None):
187
self.callbacks = callbacks or []
198

@@ -31,7 +20,6 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True):
3120
outputs = lm(**inputs_, **lm_kwargs)
3221
values = []
3322

34-
3523
try:
3624
for output in outputs:
3725
value = self.parse(signature, output, _parse_values=_parse_values)
@@ -45,3 +33,13 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True):
4533
return JSONAdapter()(lm, lm_kwargs, signature, demos, inputs, _parse_values=_parse_values)
4634
raise e
4735

36+
@abstractmethod
37+
def format(self, signature, demos, inputs):
38+
raise NotImplementedError
39+
40+
@abstractmethod
41+
def parse(self, signature, completion, _parse_values):
42+
raise NotImplementedError
43+
44+
def format_finetune_data(self, signature, demos, inputs, outputs):
45+
raise NotImplementedError

dspy/adapters/chat_adapter.py

Lines changed: 13 additions & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -34,10 +34,6 @@ class FieldInfoWithName(NamedTuple):
3434
BuiltInCompletedOutputFieldInfo = FieldInfoWithName(name="completed", info=OutputField())
3535

3636
class ChatAdapter(Adapter):
37-
"""
38-
ChatAdapter is used to format and parse data for chat-based LLMs.
39-
"""
40-
4137
def format(self, signature: Signature, demos: list[dict[str, Any]], inputs: dict[str, Any]) -> list[dict[str, Any]]:
4238
messages: list[dict[str, Any]] = []
4339

@@ -90,6 +86,19 @@ def parse(self, signature, completion, _parse_values=True):
9086

9187
return fields
9288

89+
# TODO(PR): Looks ok?
90+
def format_finetune_data(self, signature, demos, inputs, outputs):
91+
# Get system + user messages
92+
messages = self.format(signature, demos, inputs)
93+
94+
# Add the assistant message
95+
role = "assistant"
96+
incomplete = False
97+
assistant_message = format_turn(signature, outputs, role, incomplete)
98+
messages.append(assistant_message)
99+
100+
# Wrap the messages in a dictionary with a "messages" key
101+
return dict(messages=messages)
93102
def format_turn(self, signature, values, role, incomplete=False):
94103
return format_turn(signature, values, role, incomplete)
95104

dspy/clients/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -1,4 +1,5 @@
11
from .lm import LM
2+
from .provider import Provider, TrainingJob
23
from .base_lm import BaseLM, inspect_history
34
from .embedding import Embedding
45
import litellm

dspy/clients/anyscale.py

Lines changed: 13 additions & 13 deletions
Original file line number · Diff line number · Diff line change
@@ -7,7 +7,7 @@
77

88
from dspy.clients.finetune import (
99
FinetuneJob,
10-
TrainingMethod,
10+
# TrainingMethod,
1111
save_data,
1212
)
1313
from dspy.clients.openai import openai_data_validation
@@ -32,7 +32,7 @@
3232
def is_anyscale_model(model: str) -> bool:
3333
"""Check if the model is an AnyScale model."""
3434
# TODO: This needs to be implemented to support fine-tuning
35-
logger.info("Is AnyScale model is not implemented, returning False as a default to not break lm.py")
35+
print("Is AnyScale model is not implemented, returning False as a default to not break lm.py")
3636
return False
3737

3838

@@ -103,9 +103,9 @@ def finetune_anyscale(
103103

104104
def wait_for_training(job_id):
105105
"""Wait for the training to complete."""
106-
logger.info("[Finetune] Waiting for training to complete...")
106+
print("[Finetune] Waiting for training to complete...")
107107
anyscale.job.wait(id=job_id)
108-
logger.info("[Finetune] Training completed.")
108+
print("[Finetune] Training completed.")
109109

110110

111111
def update_serve_model_config(lora_dynamic_path: str, serve_config_path: str):
@@ -126,7 +126,7 @@ def update_serve_model_config(lora_dynamic_path: str, serve_config_path: str):
126126

127127
def verify_dataset(dataset: List[dict[str, Any]]) -> bool:
128128
"""Verify the training arguments before starting training."""
129-
logger.info("[Finetune] Verifying dataset...")
129+
print("[Finetune] Verifying dataset...")
130130
dataset_validation = openai_data_validation(dataset)
131131

132132
if dataset_validation:
@@ -138,11 +138,11 @@ def verify_dataset(dataset: List[dict[str, Any]]) -> bool:
138138

139139
def submit_data(train_path: str, job_config: Dict[str, Any]):
140140
"""Upload the data to cloud storage."""
141-
logger.info("[Finetune] Submitting data to remote storage...")
141+
print("[Finetune] Submitting data to remote storage...")
142142
dataset_suffix = os.path.basename(train_path).split(".")[0]
143143
dataset_name = f"dataset-{job_config.get('name', dataset_suffix)}"
144144
train_path_remote = anyscale.llm.dataset.upload(train_path, name=dataset_name, cloud=job_config.get("cloud", None)).storage_uri
145-
logger.info(f"[Finetune] Data submitted. Remote train path: {train_path_remote}")
145+
print(f"[Finetune] Data submitted. Remote train path: {train_path_remote}")
146146

147147
return train_path_remote
148148

@@ -158,7 +158,7 @@ def generate_config_files(train_path: str, llmforge_config_path: str, job_config
158158
llmforge_config["train_path"] = train_path
159159
llmforge_config = {k: v for k, v in llmforge_config.items() if v is not None}
160160

161-
logger.info(f"Model config data: {llmforge_config}")
161+
print(f"Model config data: {llmforge_config}")
162162
yaml.safe_dump(llmforge_config, open(llmforge_config_path, "w"))
163163

164164
if not job_config_dict.get("env_vars", None):
@@ -176,21 +176,21 @@ def generate_config_files(train_path: str, llmforge_config_path: str, job_config
176176

177177

178178
def start_remote_training(job_config) -> str:
179-
logger.info("[Finetune] Starting remote training...")
179+
print("[Finetune] Starting remote training...")
180180
job_id: str = anyscale.job.submit(job_config)
181-
logger.info(f"[Finetune] Remote training started. Job ID: {job_id}")
181+
print(f"[Finetune] Remote training started. Job ID: {job_id}")
182182
return job_id
183183

184184

185185
def wait_for_training(job_id):
186-
logger.info("Waiting for training to complete")
186+
print("Waiting for training to complete")
187187
anyscale.job.wait(id=job_id, timeout_s=18000)
188188

189189

190190
def get_model_info(job_id):
191-
logger.info("[Finetune] Retrieving model information from Anyscale Models SDK...")
191+
print("[Finetune] Retrieving model information from Anyscale Models SDK...")
192192
info = anyscale.llm.model.get(job_id=job_id).to_dict()
193-
logger.info(f"[Finetune] Model info retrieved: {info}")
193+
print(f"[Finetune] Model info retrieved: {info}")
194194
return info
195195

196196
def read_jsonl(filename):

dspy/clients/finetune.py

Lines changed: 0 additions & 132 deletions
This file was deleted.

0 commit comments

Comments (0)