From 901ddcf43c43172237e36e8cdc9a69b50a223c04 Mon Sep 17 00:00:00 2001 From: Emile Riberdy Date: Mon, 10 Nov 2025 11:20:48 -0500 Subject: [PATCH 1/3] Tests for streamify status message settings leak --- tests/streaming/test_streaming.py | 94 +++++++++++++++++++++++++++++++ 1 file changed, 94 insertions(+) diff --git a/tests/streaming/test_streaming.py b/tests/streaming/test_streaming.py index 89a84250ff..92834d0ad2 100644 --- a/tests/streaming/test_streaming.py +++ b/tests/streaming/test_streaming.py @@ -134,6 +134,100 @@ def module_start_status_message(self, instance, inputs): assert status_messages[2].message == "Predict starting!" +@pytest.mark.anyio +async def test_default_then_custom_status_message_provider(): + class MyProgram(dspy.Module): + def __init__(self): + self.generate_question = dspy.Tool(lambda x: f"What color is the {x}?", name="generate_question") + self.predict = dspy.Predict("question->answer") + + def __call__(self, x: str): + question = self.generate_question(x=x) + return self.predict(question=question) + + class MyStatusMessageProvider(StatusMessageProvider): + def tool_start_status_message(self, instance, inputs): + return "Tool starting!" + + def tool_end_status_message(self, outputs): + return "Tool finished!" + + def module_start_status_message(self, instance, inputs): + if isinstance(instance, dspy.Predict): + return "Predict starting!" + + lm = dspy.utils.DummyLM([{"answer": "red"}, {"answer": "blue"}]) + with dspy.context(lm=lm): + program = dspy.streamify(MyProgram()) + output = program("sky") + + status_messages = [] + async for value in output: + if isinstance(value, StatusMessage): + status_messages.append(value) + + assert len(status_messages) == 2 + assert status_messages[0].message == "Calling tool generate_question..." + assert status_messages[1].message == "Tool calling finished! Querying the LLM with tool calling results..." + + program = dspy.streamify(MyProgram(), status_message_provider=MyStatusMessageProvider()) + output = program("sky") + status_messages = [] + async for value in output: + if isinstance(value, StatusMessage): + status_messages.append(value) + assert len(status_messages) == 3 + assert status_messages[0].message == "Tool starting!" + assert status_messages[1].message == "Tool finished!" + assert status_messages[2].message == "Predict starting!" + + +@pytest.mark.anyio +async def test_custom_then_default_status_message_provider(): + class MyProgram(dspy.Module): + def __init__(self): + self.generate_question = dspy.Tool(lambda x: f"What color is the {x}?", name="generate_question") + self.predict = dspy.Predict("question->answer") + + def __call__(self, x: str): + question = self.generate_question(x=x) + return self.predict(question=question) + + class MyStatusMessageProvider(StatusMessageProvider): + def tool_start_status_message(self, instance, inputs): + return "Tool starting!" + + def tool_end_status_message(self, outputs): + return "Tool finished!" + + def module_start_status_message(self, instance, inputs): + if isinstance(instance, dspy.Predict): + return "Predict starting!" + + lm = dspy.utils.DummyLM([{"answer": "red"}, {"answer": "blue"}]) + with dspy.context(lm=lm): + program = dspy.streamify(MyProgram(), status_message_provider=MyStatusMessageProvider()) + output = program("sky") + status_messages = [] + async for value in output: + if isinstance(value, StatusMessage): + status_messages.append(value) + assert len(status_messages) == 3 + assert status_messages[0].message == "Tool starting!" + assert status_messages[1].message == "Tool finished!" + assert status_messages[2].message == "Predict starting!" + + program = dspy.streamify(MyProgram()) + output = program("sky") + status_messages = [] + async for value in output: + if isinstance(value, StatusMessage): + status_messages.append(value) + assert len(status_messages) == 2 + assert status_messages[0].message == "Calling tool generate_question..." + assert status_messages[1].message == "Tool calling finished! Querying the LLM with tool calling results..." + + @pytest.mark.llm_call @pytest.mark.anyio async def test_stream_listener_chat_adapter(lm_for_test): From 2497b19746429770239f2100d50f5876360dc27a Mon Sep 17 00:00:00 2001 From: Gabriel Lesperance Date: Tue, 18 Nov 2025 13:57:35 -0500 Subject: [PATCH 2/3] fix: streamify should deepcopy settings before appending a callback --- dspy/streaming/streamify.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dspy/streaming/streamify.py b/dspy/streaming/streamify.py index d8ec56eb2f..8565f98a72 100644 --- a/dspy/streaming/streamify.py +++ b/dspy/streaming/streamify.py @@ -3,6 +3,7 @@ import logging import threading from asyncio import iscoroutinefunction +from copy import deepcopy from queue import Queue from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Generator @@ -161,7 +162,7 @@ async def use_streaming(): elif not iscoroutinefunction(program): program = asyncify(program) - callbacks = settings.callbacks + callbacks = deepcopy(settings.callbacks) status_streaming_callback = StatusStreamingCallback(status_message_provider) if not any(isinstance(c, StatusStreamingCallback) for c in callbacks): callbacks.append(status_streaming_callback) From 04f7088d19c02de5e1d5b1411d8a3864c3df24fa Mon Sep 17 00:00:00 2001 From: Gabriel Lesperance Date: Fri, 21 Nov 2025 11:45:05 -0500 Subject: [PATCH 3/3] fix: shallow copy instead of deepcopy --- dspy/streaming/streamify.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dspy/streaming/streamify.py b/dspy/streaming/streamify.py index 8565f98a72..6ae9cc7e56 100644 --- a/dspy/streaming/streamify.py +++ b/dspy/streaming/streamify.py @@ -3,7 +3,6 @@ import logging import threading from asyncio import iscoroutinefunction -from copy import deepcopy from queue import Queue from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Generator @@ -162,7 +161,7 @@ async def use_streaming(): elif not iscoroutinefunction(program): program = asyncify(program) - callbacks = deepcopy(settings.callbacks) + callbacks = list(settings.callbacks) status_streaming_callback = StatusStreamingCallback(status_message_provider) if not any(isinstance(c, StatusStreamingCallback) for c in callbacks): callbacks.append(status_streaming_callback)