diff --git a/dspy/streaming/streamify.py b/dspy/streaming/streamify.py
index d8ec56eb2f..6ae9cc7e56 100644
--- a/dspy/streaming/streamify.py
+++ b/dspy/streaming/streamify.py
@@ -161,7 +161,7 @@ async def use_streaming():
     elif not iscoroutinefunction(program):
         program = asyncify(program)
 
-    callbacks = settings.callbacks
+    callbacks = list(settings.callbacks)
     status_streaming_callback = StatusStreamingCallback(status_message_provider)
     if not any(isinstance(c, StatusStreamingCallback) for c in callbacks):
         callbacks.append(status_streaming_callback)
diff --git a/tests/streaming/test_streaming.py b/tests/streaming/test_streaming.py
index 89a84250ff..92834d0ad2 100644
--- a/tests/streaming/test_streaming.py
+++ b/tests/streaming/test_streaming.py
@@ -134,6 +134,100 @@ def module_start_status_message(self, instance, inputs):
     assert status_messages[2].message == "Predict starting!"
 
 
+@pytest.mark.anyio
+async def test_default_then_custom_status_message_provider():
+    class MyProgram(dspy.Module):
+        def __init__(self):
+            self.generate_question = dspy.Tool(lambda x: f"What color is the {x}?", name="generate_question")
+            self.predict = dspy.Predict("question->answer")
+
+        def __call__(self, x: str):
+            question = self.generate_question(x=x)
+            return self.predict(question=question)
+
+    class MyStatusMessageProvider(StatusMessageProvider):
+        def tool_start_status_message(self, instance, inputs):
+            return "Tool starting!"
+
+        def tool_end_status_message(self, outputs):
+            return "Tool finished!"
+
+        def module_start_status_message(self, instance, inputs):
+            if isinstance(instance, dspy.Predict):
+                return "Predict starting!"
+
+    lm = dspy.utils.DummyLM([{"answer": "red"}, {"answer": "blue"}])
+    with dspy.context(lm=lm):
+        program = dspy.streamify(MyProgram())
+        output = program("sky")
+
+        status_messages = []
+        async for value in output:
+            if isinstance(value, StatusMessage):
+                status_messages.append(value)
+
+        assert len(status_messages) == 2
+        assert status_messages[0].message == "Calling tool generate_question..."
+        assert status_messages[1].message == "Tool calling finished! Querying the LLM with tool calling results..."
+
+        program = dspy.streamify(MyProgram(), status_message_provider=MyStatusMessageProvider())
+        output = program("sky")
+        status_messages = []
+        async for value in output:
+            if isinstance(value, StatusMessage):
+                status_messages.append(value)
+        assert len(status_messages) == 3
+        assert status_messages[0].message == "Tool starting!"
+        assert status_messages[1].message == "Tool finished!"
+        assert status_messages[2].message == "Predict starting!"
+
+
+@pytest.mark.anyio
+async def test_custom_then_default_status_message_provider():
+    class MyProgram(dspy.Module):
+        def __init__(self):
+            self.generate_question = dspy.Tool(lambda x: f"What color is the {x}?", name="generate_question")
+            self.predict = dspy.Predict("question->answer")
+
+        def __call__(self, x: str):
+            question = self.generate_question(x=x)
+            return self.predict(question=question)
+
+    class MyStatusMessageProvider(StatusMessageProvider):
+        def tool_start_status_message(self, instance, inputs):
+            return "Tool starting!"
+
+        def tool_end_status_message(self, outputs):
+            return "Tool finished!"
+
+        def module_start_status_message(self, instance, inputs):
+            if isinstance(instance, dspy.Predict):
+                return "Predict starting!"
+
+    lm = dspy.utils.DummyLM([{"answer": "red"}, {"answer": "blue"}])
+    with dspy.context(lm=lm):
+        program = dspy.streamify(MyProgram(), status_message_provider=MyStatusMessageProvider())
+        output = program("sky")
+        status_messages = []
+        async for value in output:
+            if isinstance(value, StatusMessage):
+                status_messages.append(value)
+        assert len(status_messages) == 3
+        assert status_messages[0].message == "Tool starting!"
+        assert status_messages[1].message == "Tool finished!"
+        assert status_messages[2].message == "Predict starting!"
+
+        program = dspy.streamify(MyProgram())
+        output = program("sky")
+        status_messages = []
+        async for value in output:
+            if isinstance(value, StatusMessage):
+                status_messages.append(value)
+        assert len(status_messages) == 2
+        assert status_messages[0].message == "Calling tool generate_question..."
+        assert status_messages[1].message == "Tool calling finished! Querying the LLM with tool calling results..."
+
+
 @pytest.mark.llm_call
 @pytest.mark.anyio
 async def test_stream_listener_chat_adapter(lm_for_test):