Skip to content

Commit 16ba98a

Browse files
Some fixes to callback (#1696)
1 parent c9a8cd4 commit 16ba98a

File tree

4 files changed

+82
-87
lines changed

4 files changed

+82
-87
lines changed

dspy/clients/lm.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,22 @@
1-
from concurrent.futures import ThreadPoolExecutor
2-
from datetime import datetime
31
import functools
42
import os
3+
import uuid
4+
from concurrent.futures import ThreadPoolExecutor
5+
from datetime import datetime
56
from pathlib import Path
67
from typing import Any, Dict, List, Optional
8+
9+
import litellm
710
import ujson
8-
import uuid
11+
from litellm.caching import Cache
912

10-
from dspy.utils.logging import logger
1113
from dspy.clients.finetune import FinetuneJob, TrainingMethod
1214
from dspy.clients.lm_finetune_utils import (
13-
get_provider_finetune_job_class,
1415
execute_finetune_job,
16+
get_provider_finetune_job_class,
1517
)
16-
1718
from dspy.utils.callback import with_callbacks
18-
import litellm
19-
from litellm.caching import Cache
20-
19+
from dspy.utils.logging import logger
2120

2221
DISK_CACHE_DIR = os.environ.get("DSPY_CACHEDIR") or os.path.join(Path.home(), ".dspy_cache")
2322
litellm.cache = Cache(disk_cache_dir=DISK_CACHE_DIR, type="disk")

dspy/utils/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
from .callback import *
2-
from .dummies import *
3-
from .logging import *
1+
from dspy.utils.callback import BaseCallback, with_callbacks
2+
from dspy.utils.dummies import *
3+
from dspy.utils.logging import *

dspy/utils/callback.py

Lines changed: 53 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,16 @@
1313

1414

1515
class BaseCallback:
16-
"""
17-
A base class for defining callback handlers for DSPy components.
16+
"""A base class for defining callback handlers for DSPy components.
1817
1918
To use a callback, subclass this class and implement the desired handlers. Each handler
20-
will be called at the appropriate time before/after the execution of the corresponding component.
19+
will be called at the appropriate time before/after the execution of the corresponding component. For example, if
20+
you want to print a message before and after an LM is called, implement `the on_llm_start` and `on_lm_end` handler.
21+
Users can set the callback globally using `dspy.settings.configure` or locally by passing it to the component
22+
constructor.
23+
2124
22-
For example, if you want to print a message before and after an LM is called, implement
23-
the on_llm_start and on_lm_end handler and set the callback to the global settings using `dspy.settings.configure`.
25+
Example 1: Set a global callback using `dspy.settings.configure`.
2426
2527
```
2628
import dspy
@@ -45,19 +47,18 @@ def on_lm_end(self, call_id, outputs, exception):
4547
# > LM is finished with outputs: {'answer': '42'}
4648
```
4749
48-
Another way to set the callback is to pass it directly to the component constructor.
49-
In this case, the callback will only be triggered for that specific instance.
50+
Example 2: Set a local callback by passing it to the component constructor.
5051
5152
```
52-
lm = dspy.LM("gpt-3.5-turbo", callbacks=[LoggingCallback()])
53-
lm(question="What is the meaning of life?")
53+
lm_1 = dspy.LM("gpt-3.5-turbo", callbacks=[LoggingCallback()])
54+
lm_1(question="What is the meaning of life?")
5455
5556
# > LM is called with inputs: {'question': 'What is the meaning of life?'}
5657
# > LM is finished with outputs: {'answer': '42'}
5758
5859
lm_2 = dspy.LM("gpt-3.5-turbo")
5960
lm_2(question="What is the meaning of life?")
60-
# No logging here
61+
# No logging here because only `lm_1` has the callback set.
6162
```
6263
"""
6364

@@ -67,8 +68,7 @@ def on_module_start(
6768
instance: Any,
6869
inputs: Dict[str, Any],
6970
):
70-
"""
71-
A handler triggered when forward() method of a module (subclass of dspy.Module) is called.
71+
"""A handler triggered when forward() method of a module (subclass of dspy.Module) is called.
7272
7373
Args:
7474
call_id: A unique identifier for the call. Can be used to connect start/end handlers.
@@ -84,8 +84,7 @@ def on_module_end(
8484
outputs: Optional[Any],
8585
exception: Optional[Exception] = None,
8686
):
87-
"""
88-
A handler triggered after forward() method of a module (subclass of dspy.Module) is executed.
87+
"""A handler triggered after forward() method of a module (subclass of dspy.Module) is executed.
8988
9089
Args:
9190
call_id: A unique identifier for the call. Can be used to connect start/end handlers.
@@ -101,8 +100,7 @@ def on_lm_start(
101100
instance: Any,
102101
inputs: Dict[str, Any],
103102
):
104-
"""
105-
A handler triggered when __call__ method of dspy.LM instance is called.
103+
"""A handler triggered when __call__ method of dspy.LM instance is called.
106104
107105
Args:
108106
call_id: A unique identifier for the call. Can be used to connect start/end handlers.
@@ -118,8 +116,7 @@ def on_lm_end(
118116
outputs: Optional[Dict[str, Any]],
119117
exception: Optional[Exception] = None,
120118
):
121-
"""
122-
A handler triggered after __call__ method of dspy.LM instance is executed.
119+
"""A handler triggered after __call__ method of dspy.LM instance is executed.
123120
124121
Args:
125122
call_id: A unique identifier for the call. Can be used to connect start/end handlers.
@@ -129,14 +126,13 @@ def on_lm_end(
129126
"""
130127
pass
131128

132-
def on_format_start(
129+
def on_adapter_format_start(
133130
self,
134131
call_id: str,
135132
instance: Any,
136133
inputs: Dict[str, Any],
137134
):
138-
"""
139-
A handler triggered when format() method of an adapter (subclass of dspy.Adapter) is called.
135+
"""A handler triggered when format() method of an adapter (subclass of dspy.Adapter) is called.
140136
141137
Args:
142138
call_id: A unique identifier for the call. Can be used to connect start/end handlers.
@@ -146,14 +142,13 @@ def on_format_start(
146142
"""
147143
pass
148144

149-
def on_format_end(
145+
def on_adapter_format_end(
150146
self,
151147
call_id: str,
152148
outputs: Optional[Dict[str, Any]],
153149
exception: Optional[Exception] = None,
154150
):
155-
"""
156-
A handler triggered after format() method of dspy.LM instance is executed.
151+
"""A handler triggered after format() method of an adapter (subclass of dspy.Adapter) is called..
157152
158153
Args:
159154
call_id: A unique identifier for the call. Can be used to connect start/end handlers.
@@ -163,14 +158,13 @@ def on_format_end(
163158
"""
164159
pass
165160

166-
def on_parse_start(
161+
def on_adapter_parse_start(
167162
self,
168163
call_id: str,
169164
instance: Any,
170165
inputs: Dict[str, Any],
171166
):
172-
"""
173-
A handler triggered when parse() method of an adapter (subclass of dspy.Adapter) is called.
167+
"""A handler triggered when parse() method of an adapter (subclass of dspy.Adapter) is called.
174168
175169
Args:
176170
call_id: A unique identifier for the call. Can be used to connect start/end handlers.
@@ -180,14 +174,13 @@ def on_parse_start(
180174
"""
181175
pass
182176

183-
def on_parse_end(
177+
def on_adapter_parse_end(
184178
self,
185179
call_id: str,
186180
outputs: Optional[Dict[str, Any]],
187181
exception: Optional[Exception] = None,
188182
):
189-
"""
190-
A handler triggered after parse() method of dspy.LM instance is executed.
183+
"""A handler triggered after parse() method of an adapter (subclass of dspy.Adapter) is called.
191184
192185
Args:
193186
call_id: A unique identifier for the call. Can be used to connect start/end handlers.
@@ -200,23 +193,23 @@ def on_parse_end(
200193

201194
def with_callbacks(fn):
202195
@functools.wraps(fn)
203-
def wrapper(self, *args, **kwargs):
204-
# Combine global and local (per-instance) callbacks
205-
callbacks = dspy.settings.get("callbacks", []) + getattr(self, "callbacks", [])
196+
def wrapper(instance, *args, **kwargs):
197+
# Combine global and local (per-instance) callbacks.
198+
callbacks = dspy.settings.get("callbacks", []) + getattr(instance, "callbacks", [])
206199

207-
# if no callbacks are provided, just call the function
200+
# If no callbacks are provided, just call the function
208201
if not callbacks:
209-
return fn(self, *args, **kwargs)
202+
return fn(instance, *args, **kwargs)
210203

211-
# Generate call ID to connect start/end handlers if needed
204+
# Generate call ID as the unique identifier for the call, this is useful for instrumentation.
212205
call_id = uuid.uuid4().hex
213206

214-
inputs = inspect.getcallargs(fn, self, *args, **kwargs)
207+
inputs = inspect.getcallargs(fn, instance, *args, **kwargs)
215208
inputs.pop("self") # Not logging self as input
216209

217210
for callback in callbacks:
218211
try:
219-
_get_on_start_handler(callback, self, fn)(call_id=call_id, instance=self, inputs=inputs)
212+
_get_on_start_handler(callback, instance, fn)(call_id=call_id, instance=instance, inputs=inputs)
220213

221214
except Exception as e:
222215
logger.warning(f"Error when calling callback {callback}: {e}")
@@ -225,58 +218,59 @@ def wrapper(self, *args, **kwargs):
225218
exception = None
226219
try:
227220
parent_call_id = ACTIVE_CALL_ID.get()
228-
# Active ID must be set right before the function is called,
229-
# not before calling the callbacks.
221+
# Active ID must be set right before the function is called, not before calling the callbacks.
230222
ACTIVE_CALL_ID.set(call_id)
231-
results = fn(self, *args, **kwargs)
223+
results = fn(instance, *args, **kwargs)
232224
return results
233225
except Exception as e:
234226
exception = e
235227
raise exception
236228
finally:
229+
# Execute the end handlers even if the function call raises an exception.
237230
ACTIVE_CALL_ID.set(parent_call_id)
238231
for callback in callbacks:
239232
try:
240-
_get_on_end_handler(callback, self, fn)(
233+
_get_on_end_handler(callback, instance, fn)(
241234
call_id=call_id,
242235
outputs=results,
243236
exception=exception,
244237
)
245238
except Exception as e:
246-
logger.warning(f"Error when calling callback {callback}: {e}")
239+
logger.warning(
240+
f"Error when applying callback {callback}'s end handler on function {fn.__name__}: {e}."
241+
)
247242

248243
return wrapper
249244

250245

251246
def _get_on_start_handler(callback: BaseCallback, instance: Any, fn: Callable) -> Callable:
252-
"""
253-
Selects the appropriate on_start handler of the callback
254-
based on the instance and function name.
255-
"""
256-
if isinstance(instance, (dspy.LM)):
247+
"""Selects the appropriate on_start handler of the callback based on the instance and function name."""
248+
if isinstance(instance, dspy.LM):
257249
return callback.on_lm_start
258-
elif isinstance(instance, (dspy.Adapter)):
250+
251+
if isinstance(instance, dspy.Adapter):
259252
if fn.__name__ == "format":
260-
return callback.on_format_start
253+
return callback.on_adapter_format_start
261254
elif fn.__name__ == "parse":
262-
return callback.on_parse_start
255+
return callback.on_adapter_parse_start
256+
else:
257+
raise ValueError(f"Unsupported adapter method for using callback: {fn.__name__}.")
263258

264259
# We treat everything else as a module.
265260
return callback.on_module_start
266261

267262

268263
def _get_on_end_handler(callback: BaseCallback, instance: Any, fn: Callable) -> Callable:
269-
"""
270-
Selects the appropriate on_end handler of the callback
271-
based on the instance and function name.
272-
"""
264+
"""Selects the appropriate on_end handler of the callback based on the instance and function name."""
273265
if isinstance(instance, (dspy.LM)):
274266
return callback.on_lm_end
275-
elif isinstance(instance, (dspy.Adapter)):
267+
268+
if isinstance(instance, (dspy.Adapter)):
276269
if fn.__name__ == "format":
277-
return callback.on_format_end
270+
return callback.on_adapter_format_end
278271
elif fn.__name__ == "parse":
279-
return callback.on_parse_end
280-
272+
return callback.on_adapter_parse_end
273+
else:
274+
raise ValueError(f"Unsupported adapter method for using callback: {fn.__name__}.")
281275
# We treat everything else as a module.
282276
return callback.on_module_end

tests/callback/test_callback.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ def reset_settings():
1818

1919

2020
class MyCallback(BaseCallback):
21+
"""A simple callback that records the calls."""
22+
2123
def __init__(self):
2224
self.calls = []
2325

@@ -33,17 +35,17 @@ def on_lm_start(self, call_id, instance, inputs):
3335
def on_lm_end(self, call_id, outputs, exception):
3436
self.calls.append({"handler": "on_lm_end", "outputs": outputs, "exception": exception})
3537

36-
def on_format_start(self, call_id, instance, inputs):
37-
self.calls.append({"handler": "on_format_start", "instance": instance, "inputs": inputs})
38+
def on_adapter_format_start(self, call_id, instance, inputs):
39+
self.calls.append({"handler": "on_adapter_format_start", "instance": instance, "inputs": inputs})
3840

39-
def on_format_end(self, call_id, outputs, exception):
40-
self.calls.append({"handler": "on_format_end", "outputs": outputs, "exception": exception})
41+
def on_adapter_format_end(self, call_id, outputs, exception):
42+
self.calls.append({"handler": "on_adapter_format_end", "outputs": outputs, "exception": exception})
4143

42-
def on_parse_start(self, call_id, instance, inputs):
43-
self.calls.append({"handler": "on_parse_start", "instance": instance, "inputs": inputs})
44+
def on_adapter_parse_start(self, call_id, instance, inputs):
45+
self.calls.append({"handler": "on_adapter_parse_start", "instance": instance, "inputs": inputs})
4446

45-
def on_parse_end(self, call_id, outputs, exception):
46-
self.calls.append({"handler": "on_parse_end", "outputs": outputs, "exception": exception})
47+
def on_adapter_parse_end(self, call_id, outputs, exception):
48+
self.calls.append({"handler": "on_adapter_parse_end", "outputs": outputs, "exception": exception})
4749

4850

4951
@pytest.mark.parametrize(
@@ -163,17 +165,17 @@ def test_callback_complex_module():
163165
assert [call["handler"] for call in callback.calls] == [
164166
"on_module_start",
165167
"on_module_start",
166-
"on_format_start",
167-
"on_format_end",
168+
"on_adapter_format_start",
169+
"on_adapter_format_end",
168170
"on_lm_start",
169171
"on_lm_end",
170172
# Parsing will run per output (n=3)
171-
"on_parse_start",
172-
"on_parse_end",
173-
"on_parse_start",
174-
"on_parse_end",
175-
"on_parse_start",
176-
"on_parse_end",
173+
"on_adapter_parse_start",
174+
"on_adapter_parse_end",
175+
"on_adapter_parse_start",
176+
"on_adapter_parse_end",
177+
"on_adapter_parse_start",
178+
"on_adapter_parse_end",
177179
"on_module_end",
178180
"on_module_end",
179181
]

0 commit comments

Comments
 (0)