1313
1414
1515class BaseCallback :
16- """
17- A base class for defining callback handlers for DSPy components.
16+ """A base class for defining callback handlers for DSPy components.
1817
1918 To use a callback, subclass this class and implement the desired handlers. Each handler
20- will be called at the appropriate time before/after the execution of the corresponding component.
19+ will be called at the appropriate time before/after the execution of the corresponding component. For example, if
20+ you want to print a message before and after an LM is called, implement `the on_llm_start` and `on_lm_end` handler.
21+ Users can set the callback globally using `dspy.settings.configure` or locally by passing it to the component
22+ constructor.
23+
2124
22- For example, if you want to print a message before and after an LM is called, implement
23- the on_llm_start and on_lm_end handler and set the callback to the global settings using `dspy.settings.configure`.
25+ Example 1: Set a global callback using `dspy.settings.configure`.
2426
2527 ```
2628 import dspy
@@ -45,19 +47,18 @@ def on_lm_end(self, call_id, outputs, exception):
4547 # > LM is finished with outputs: {'answer': '42'}
4648 ```
4749
48- Another way to set the callback is to pass it directly to the component constructor.
49- In this case, the callback will only be triggered for that specific instance.
50+ Example 2: Set a local callback by passing it to the component constructor.
5051
5152 ```
52- lm = dspy.LM("gpt-3.5-turbo", callbacks=[LoggingCallback()])
53- lm (question="What is the meaning of life?")
53+ lm_1 = dspy.LM("gpt-3.5-turbo", callbacks=[LoggingCallback()])
54+ lm_1 (question="What is the meaning of life?")
5455
5556 # > LM is called with inputs: {'question': 'What is the meaning of life?'}
5657 # > LM is finished with outputs: {'answer': '42'}
5758
5859 lm_2 = dspy.LM("gpt-3.5-turbo")
5960 lm_2(question="What is the meaning of life?")
60- # No logging here
61+ # No logging here because only `lm_1` has the callback set.
6162 ```
6263 """
6364
@@ -67,8 +68,7 @@ def on_module_start(
6768 instance : Any ,
6869 inputs : Dict [str , Any ],
6970 ):
70- """
71- A handler triggered when forward() method of a module (subclass of dspy.Module) is called.
71+ """A handler triggered when forward() method of a module (subclass of dspy.Module) is called.
7272
7373 Args:
7474 call_id: A unique identifier for the call. Can be used to connect start/end handlers.
@@ -84,8 +84,7 @@ def on_module_end(
8484 outputs : Optional [Any ],
8585 exception : Optional [Exception ] = None ,
8686 ):
87- """
88- A handler triggered after forward() method of a module (subclass of dspy.Module) is executed.
87+ """A handler triggered after forward() method of a module (subclass of dspy.Module) is executed.
8988
9089 Args:
9190 call_id: A unique identifier for the call. Can be used to connect start/end handlers.
@@ -101,8 +100,7 @@ def on_lm_start(
101100 instance : Any ,
102101 inputs : Dict [str , Any ],
103102 ):
104- """
105- A handler triggered when __call__ method of dspy.LM instance is called.
103+ """A handler triggered when __call__ method of dspy.LM instance is called.
106104
107105 Args:
108106 call_id: A unique identifier for the call. Can be used to connect start/end handlers.
@@ -118,8 +116,7 @@ def on_lm_end(
118116 outputs : Optional [Dict [str , Any ]],
119117 exception : Optional [Exception ] = None ,
120118 ):
121- """
122- A handler triggered after __call__ method of dspy.LM instance is executed.
119+ """A handler triggered after __call__ method of dspy.LM instance is executed.
123120
124121 Args:
125122 call_id: A unique identifier for the call. Can be used to connect start/end handlers.
@@ -129,14 +126,13 @@ def on_lm_end(
129126 """
130127 pass
131128
132- def on_format_start (
129+ def on_adapter_format_start (
133130 self ,
134131 call_id : str ,
135132 instance : Any ,
136133 inputs : Dict [str , Any ],
137134 ):
138- """
139- A handler triggered when format() method of an adapter (subclass of dspy.Adapter) is called.
135+ """A handler triggered when format() method of an adapter (subclass of dspy.Adapter) is called.
140136
141137 Args:
142138 call_id: A unique identifier for the call. Can be used to connect start/end handlers.
@@ -146,14 +142,13 @@ def on_format_start(
146142 """
147143 pass
148144
149- def on_format_end (
145+ def on_adapter_format_end (
150146 self ,
151147 call_id : str ,
152148 outputs : Optional [Dict [str , Any ]],
153149 exception : Optional [Exception ] = None ,
154150 ):
155- """
156- A handler triggered after format() method of dspy.LM instance is executed.
151+ """A handler triggered after format() method of an adapter (subclass of dspy.Adapter) is called..
157152
158153 Args:
159154 call_id: A unique identifier for the call. Can be used to connect start/end handlers.
@@ -163,14 +158,13 @@ def on_format_end(
163158 """
164159 pass
165160
166- def on_parse_start (
161+ def on_adapter_parse_start (
167162 self ,
168163 call_id : str ,
169164 instance : Any ,
170165 inputs : Dict [str , Any ],
171166 ):
172- """
173- A handler triggered when parse() method of an adapter (subclass of dspy.Adapter) is called.
167+ """A handler triggered when parse() method of an adapter (subclass of dspy.Adapter) is called.
174168
175169 Args:
176170 call_id: A unique identifier for the call. Can be used to connect start/end handlers.
@@ -180,14 +174,13 @@ def on_parse_start(
180174 """
181175 pass
182176
183- def on_parse_end (
177+ def on_adapter_parse_end (
184178 self ,
185179 call_id : str ,
186180 outputs : Optional [Dict [str , Any ]],
187181 exception : Optional [Exception ] = None ,
188182 ):
189- """
190- A handler triggered after parse() method of dspy.LM instance is executed.
183+ """A handler triggered after parse() method of an adapter (subclass of dspy.Adapter) is called.
191184
192185 Args:
193186 call_id: A unique identifier for the call. Can be used to connect start/end handlers.
@@ -200,23 +193,23 @@ def on_parse_end(
200193
201194def with_callbacks (fn ):
202195 @functools .wraps (fn )
203- def wrapper (self , * args , ** kwargs ):
204- # Combine global and local (per-instance) callbacks
205- callbacks = dspy .settings .get ("callbacks" , []) + getattr (self , "callbacks" , [])
196+ def wrapper (instance , * args , ** kwargs ):
197+ # Combine global and local (per-instance) callbacks.
198+ callbacks = dspy .settings .get ("callbacks" , []) + getattr (instance , "callbacks" , [])
206199
207- # if no callbacks are provided, just call the function
200+ # If no callbacks are provided, just call the function
208201 if not callbacks :
209- return fn (self , * args , ** kwargs )
202+ return fn (instance , * args , ** kwargs )
210203
211- # Generate call ID to connect start/end handlers if needed
204+ # Generate call ID as the unique identifier for the call, this is useful for instrumentation.
212205 call_id = uuid .uuid4 ().hex
213206
214- inputs = inspect .getcallargs (fn , self , * args , ** kwargs )
207+ inputs = inspect .getcallargs (fn , instance , * args , ** kwargs )
215208 inputs .pop ("self" ) # Not logging self as input
216209
217210 for callback in callbacks :
218211 try :
219- _get_on_start_handler (callback , self , fn )(call_id = call_id , instance = self , inputs = inputs )
212+ _get_on_start_handler (callback , instance , fn )(call_id = call_id , instance = instance , inputs = inputs )
220213
221214 except Exception as e :
222215 logger .warning (f"Error when calling callback { callback } : { e } " )
@@ -225,58 +218,59 @@ def wrapper(self, *args, **kwargs):
225218 exception = None
226219 try :
227220 parent_call_id = ACTIVE_CALL_ID .get ()
228- # Active ID must be set right before the function is called,
229- # not before calling the callbacks.
221+ # Active ID must be set right before the function is called, not before calling the callbacks.
230222 ACTIVE_CALL_ID .set (call_id )
231- results = fn (self , * args , ** kwargs )
223+ results = fn (instance , * args , ** kwargs )
232224 return results
233225 except Exception as e :
234226 exception = e
235227 raise exception
236228 finally :
229+ # Execute the end handlers even if the function call raises an exception.
237230 ACTIVE_CALL_ID .set (parent_call_id )
238231 for callback in callbacks :
239232 try :
240- _get_on_end_handler (callback , self , fn )(
233+ _get_on_end_handler (callback , instance , fn )(
241234 call_id = call_id ,
242235 outputs = results ,
243236 exception = exception ,
244237 )
245238 except Exception as e :
246- logger .warning (f"Error when calling callback { callback } : { e } " )
239+ logger .warning (
240+ f"Error when applying callback { callback } 's end handler on function { fn .__name__ } : { e } ."
241+ )
247242
248243 return wrapper
249244
250245
251246def _get_on_start_handler (callback : BaseCallback , instance : Any , fn : Callable ) -> Callable :
252- """
253- Selects the appropriate on_start handler of the callback
254- based on the instance and function name.
255- """
256- if isinstance (instance , (dspy .LM )):
247+ """Selects the appropriate on_start handler of the callback based on the instance and function name."""
248+ if isinstance (instance , dspy .LM ):
257249 return callback .on_lm_start
258- elif isinstance (instance , (dspy .Adapter )):
250+
251+ if isinstance (instance , dspy .Adapter ):
259252 if fn .__name__ == "format" :
260- return callback .on_format_start
253+ return callback .on_adapter_format_start
261254 elif fn .__name__ == "parse" :
262- return callback .on_parse_start
255+ return callback .on_adapter_parse_start
256+ else :
257+ raise ValueError (f"Unsupported adapter method for using callback: { fn .__name__ } ." )
263258
264259 # We treat everything else as a module.
265260 return callback .on_module_start
266261
267262
268263def _get_on_end_handler (callback : BaseCallback , instance : Any , fn : Callable ) -> Callable :
269- """
270- Selects the appropriate on_end handler of the callback
271- based on the instance and function name.
272- """
264+ """Selects the appropriate on_end handler of the callback based on the instance and function name."""
273265 if isinstance (instance , (dspy .LM )):
274266 return callback .on_lm_end
275- elif isinstance (instance , (dspy .Adapter )):
267+
268+ if isinstance (instance , (dspy .Adapter )):
276269 if fn .__name__ == "format" :
277- return callback .on_format_end
270+ return callback .on_adapter_format_end
278271 elif fn .__name__ == "parse" :
279- return callback .on_parse_end
280-
272+ return callback .on_adapter_parse_end
273+ else :
274+ raise ValueError (f"Unsupported adapter method for using callback: { fn .__name__ } ." )
281275 # We treat everything else as a module.
282276 return callback .on_module_end
0 commit comments