Skip to content

Commit 6ecc92d

Browse files
simplify dspy.LM (#7940)
1 parent ef32f66 commit 6ecc92d

File tree

3 files changed

+37
-54
lines changed

3 files changed

+37
-54
lines changed

dspy/clients/base_lm.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,22 @@ def __init__(self, model, model_type='chat', temperature=0.0, max_tokens=1000, c
1616
def __call__(self, prompt=None, messages=None, **kwargs):
1717
pass
1818

19+
def copy(self, **kwargs):
20+
"""Returns a copy of the language model with possibly updated parameters."""
21+
22+
import copy
23+
24+
new_instance = copy.deepcopy(self)
25+
new_instance.history = []
26+
27+
for key, value in kwargs.items():
28+
if hasattr(self, key):
29+
setattr(new_instance, key, value)
30+
if (key in self.kwargs) or (not hasattr(self, key)):
31+
new_instance.kwargs[key] = value
32+
33+
return new_instance
34+
1935
def inspect_history(self, n: int = 1):
2036
_inspect_history(self.history, n)
2137

dspy/clients/lm.py

Lines changed: 15 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import uuid
77
from datetime import datetime
88
from hashlib import sha256
9-
from typing import Any, Dict, List, Literal, Optional, cast, TYPE_CHECKING
9+
from typing import Any, Dict, List, Literal, Optional, cast
1010

1111
import litellm
1212
import pydantic
@@ -21,8 +21,6 @@
2121
from dspy.clients.provider import Provider, TrainingJob
2222
from dspy.clients.utils_finetune import TrainDataFormat
2323
from dspy.utils.callback import BaseCallback, with_callbacks
24-
if TYPE_CHECKING:
25-
from dspy.adapters.base import Adapter
2624

2725
from .base_lm import BaseLM
2826

@@ -142,17 +140,20 @@ def __call__(self, prompt=None, messages=None, **kwargs):
142140

143141
# Logging, with removed api key & where `cost` is None on cache hit.
144142
kwargs = {k: v for k, v in kwargs.items() if not k.startswith("api_")}
145-
entry = dict(prompt=prompt, messages=messages, kwargs=kwargs, response=response)
146-
entry = dict(**entry, outputs=outputs, usage=dict(response["usage"]))
147-
entry = dict(**entry, cost=response.get("_hidden_params", {}).get("response_cost"))
148-
entry = dict(
149-
**entry,
150-
timestamp=datetime.now().isoformat(),
151-
uuid=str(uuid.uuid4()),
152-
model=self.model,
153-
response_model=response["model"],
154-
model_type=self.model_type,
155-
)
143+
entry = {
144+
"prompt": prompt,
145+
"messages": messages,
146+
"kwargs": kwargs,
147+
"response": response,
148+
"outputs": outputs,
149+
"usage": dict(response["usage"]),
150+
"cost": response.get("_hidden_params", {}).get("response_cost"),
151+
"timestamp": datetime.now().isoformat(),
152+
"uuid": str(uuid.uuid4()),
153+
"model": self.model,
154+
"response_model": response["model"],
155+
"model_type": self.model_type,
156+
}
156157
self.history.append(entry)
157158
self.update_global_history(entry)
158159

@@ -216,38 +217,8 @@ def _run_finetune_job(self, job: TrainingJob):
216217
def infer_provider(self) -> Provider:
217218
if OpenAIProvider.is_provider_model(self.model):
218219
return OpenAIProvider()
219-
# TODO(PR): Keeping this function here will require us to import all
220-
# providers in this file. Is this okay?
221220
return Provider()
222221

223-
def infer_adapter(self) -> "Adapter":
224-
import dspy
225-
226-
if dspy.settings.adapter:
227-
return dspy.settings.adapter
228-
229-
model_type_to_adapter = {
230-
"chat": dspy.ChatAdapter(),
231-
}
232-
model_type = self.model_type
233-
return model_type_to_adapter[model_type]
234-
235-
def copy(self, **kwargs):
236-
"""Returns a copy of the language model with possibly updated parameters."""
237-
238-
import copy
239-
240-
new_instance = copy.deepcopy(self)
241-
new_instance.history = []
242-
243-
for key, value in kwargs.items():
244-
if hasattr(self, key):
245-
setattr(new_instance, key, value)
246-
if (key in self.kwargs) or (not hasattr(self, key)):
247-
new_instance.kwargs[key] = value
248-
249-
return new_instance
250-
251222
def dump_state(self):
252223
state_keys = ["model", "model_type", "cache", "cache_in_memory", "num_retries", "finetuning_model", "launch_kwargs", "train_kwargs"]
253224
return { key: getattr(self, key) for key in state_keys } | self.kwargs

dspy/teleprompt/bootstrap_finetune.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44

55
import dspy
66
from dspy.adapters.base import Adapter
7+
from dspy.adapters.chat_adapter import ChatAdapter
78
from dspy.clients.lm import LM
89
from dspy.clients.utils_finetune import infer_data_format
10+
from dspy.dsp.utils.settings import settings
911
from dspy.evaluate.evaluate import Evaluate
1012
from dspy.predict.predict import Predict
1113
from dspy.primitives.example import Example
@@ -160,7 +162,7 @@ def _prepare_finetune_data(self, trace_data: List[Dict[str, Any]], lm: LM, pred_
160162
logger.info(f"After filtering with the metric, {len(trace_data)} examples remain")
161163

162164
data = []
163-
adapter = self.adapter[lm] or lm.infer_adapter()
165+
adapter = self.adapter[lm] or settings.adapter or ChatAdapter()
164166
data_format = infer_data_format(adapter)
165167
for item in trace_data:
166168
for pred_ind, _ in enumerate(item["trace"]):
@@ -181,18 +183,12 @@ def _prepare_finetune_data(self, trace_data: List[Dict[str, Any]], lm: LM, pred_
181183
def build_call_data_from_trace(
182184
trace: List[Dict],
183185
pred_ind: int,
184-
adapter: Optional[Adapter] = None,
186+
adapter: Adapter,
185187
exclude_demos: bool = False,
186188
) -> Dict[str, List[Dict[str, Any]]]:
187189
# Find data that's relevant to the predictor
188190
pred, inputs, outputs = trace[pred_ind] # assuming that the order is kept
189191

190-
if not adapter:
191-
# TODO(feature): A trace is collected using a particular adapter. It
192-
# would be nice to get this adapter information from the trace (e.g.
193-
# pred.lm.adapter) as opposed to using the inference method below.
194-
adapter = pred.lm.infer_adapter()
195-
196192
demos = [] if exclude_demos else pred.demos
197193
call_data = adapter.format_finetune_data(
198194
signature=pred.signature,
@@ -209,8 +205,8 @@ def bootstrap_trace_data(
209205
metric: Optional[Callable] = None,
210206
num_threads=6,
211207
) -> List[Dict[str, Any]]:
212-
# Return a list of dicts with the following keys:
213-
# example_ind, example, prediction, trace, and score (if metric != None)
208+
# Return a list of dicts with the following keys: example_ind, example, prediction, trace, and score
209+
# (if metric != None)
214210
evaluator = Evaluate(
215211
devset=dataset,
216212
num_threads=num_threads,

0 commit comments

Comments (0)