|
6 | 6 | import uuid |
7 | 7 | from datetime import datetime |
8 | 8 | from hashlib import sha256 |
9 | | -from typing import Any, Dict, List, Literal, Optional, cast, TYPE_CHECKING |
| 9 | +from typing import Any, Dict, List, Literal, Optional, cast |
10 | 10 |
|
11 | 11 | import litellm |
12 | 12 | import pydantic |
|
21 | 21 | from dspy.clients.provider import Provider, TrainingJob |
22 | 22 | from dspy.clients.utils_finetune import TrainDataFormat |
23 | 23 | from dspy.utils.callback import BaseCallback, with_callbacks |
24 | | -if TYPE_CHECKING: |
25 | | - from dspy.adapters.base import Adapter |
26 | 24 |
|
27 | 25 | from .base_lm import BaseLM |
28 | 26 |
|
@@ -142,17 +140,20 @@ def __call__(self, prompt=None, messages=None, **kwargs): |
142 | 140 |
|
143 | 141 | # Logging, with removed api key & where `cost` is None on cache hit. |
144 | 142 | kwargs = {k: v for k, v in kwargs.items() if not k.startswith("api_")} |
145 | | - entry = dict(prompt=prompt, messages=messages, kwargs=kwargs, response=response) |
146 | | - entry = dict(**entry, outputs=outputs, usage=dict(response["usage"])) |
147 | | - entry = dict(**entry, cost=response.get("_hidden_params", {}).get("response_cost")) |
148 | | - entry = dict( |
149 | | - **entry, |
150 | | - timestamp=datetime.now().isoformat(), |
151 | | - uuid=str(uuid.uuid4()), |
152 | | - model=self.model, |
153 | | - response_model=response["model"], |
154 | | - model_type=self.model_type, |
155 | | - ) |
| 143 | + entry = { |
| 144 | + "prompt": prompt, |
| 145 | + "messages": messages, |
| 146 | + "kwargs": kwargs, |
| 147 | + "response": response, |
| 148 | + "outputs": outputs, |
| 149 | + "usage": dict(response["usage"]), |
| 150 | + "cost": response.get("_hidden_params", {}).get("response_cost"), |
| 151 | + "timestamp": datetime.now().isoformat(), |
| 152 | + "uuid": str(uuid.uuid4()), |
| 153 | + "model": self.model, |
| 154 | + "response_model": response["model"], |
| 155 | + "model_type": self.model_type, |
| 156 | + } |
156 | 157 | self.history.append(entry) |
157 | 158 | self.update_global_history(entry) |
158 | 159 |
|
@@ -216,38 +217,8 @@ def _run_finetune_job(self, job: TrainingJob): |
def infer_provider(self) -> Provider:
    """Select the provider implementation matching ``self.model``.

    Returns:
        An ``OpenAIProvider`` instance when the model string is recognized
        as an OpenAI model, otherwise a generic ``Provider``.
    """
    # Dispatch on the model-string check; only OpenAI gets a specialized provider here.
    provider_cls = OpenAIProvider if OpenAIProvider.is_provider_model(self.model) else Provider
    return provider_cls()
222 | 221 |
|
223 | | - def infer_adapter(self) -> "Adapter": |
224 | | - import dspy |
225 | | - |
226 | | - if dspy.settings.adapter: |
227 | | - return dspy.settings.adapter |
228 | | - |
229 | | - model_type_to_adapter = { |
230 | | - "chat": dspy.ChatAdapter(), |
231 | | - } |
232 | | - model_type = self.model_type |
233 | | - return model_type_to_adapter[model_type] |
234 | | - |
235 | | - def copy(self, **kwargs): |
236 | | - """Returns a copy of the language model with possibly updated parameters.""" |
237 | | - |
238 | | - import copy |
239 | | - |
240 | | - new_instance = copy.deepcopy(self) |
241 | | - new_instance.history = [] |
242 | | - |
243 | | - for key, value in kwargs.items(): |
244 | | - if hasattr(self, key): |
245 | | - setattr(new_instance, key, value) |
246 | | - if (key in self.kwargs) or (not hasattr(self, key)): |
247 | | - new_instance.kwargs[key] = value |
248 | | - |
249 | | - return new_instance |
250 | | - |
def dump_state(self):
    """Serialize this LM's persistent configuration.

    Returns:
        A dict containing the fixed set of configuration attributes
        (model, model_type, cache settings, retry count, finetuning model,
        launch/train kwargs) merged with the free-form ``self.kwargs``
        (which takes precedence on key collisions, per dict union).
    """
    persisted_attrs = (
        "model",
        "model_type",
        "cache",
        "cache_in_memory",
        "num_retries",
        "finetuning_model",
        "launch_kwargs",
        "train_kwargs",
    )
    state = {}
    for attr in persisted_attrs:
        state[attr] = getattr(self, attr)
    # `|` keeps the same merge semantics as the original dict-union expression.
    return state | self.kwargs
|
0 commit comments