Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion dspy/streaming/streamify.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import threading
from asyncio import iscoroutinefunction
from copy import deepcopy
from queue import Queue
from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Generator

Expand Down Expand Up @@ -161,7 +162,7 @@ async def use_streaming():
elif not iscoroutinefunction(program):
program = asyncify(program)

callbacks = settings.callbacks
callbacks = deepcopy(settings.callbacks)
status_streaming_callback = StatusStreamingCallback(status_message_provider)
if not any(isinstance(c, StatusStreamingCallback) for c in callbacks):
callbacks.append(status_streaming_callback)
Expand Down
94 changes: 94 additions & 0 deletions tests/streaming/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,100 @@ def module_start_status_message(self, instance, inputs):
assert status_messages[2].message == "Predict starting!"


@pytest.mark.anyio
async def test_default_then_custom_status_message_provider():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may want to refactor the testing a bit - instead of default => custom, we can put up two threads that apply different MyStatusMessageProvider, and verify that both threads work as expected. Then at the end we verify that the callback is not being modified.

class MyProgram(dspy.Module):
def __init__(self):
self.generate_question = dspy.Tool(lambda x: f"What color is the {x}?", name="generate_question")
self.predict = dspy.Predict("question->answer")

def __call__(self, x: str):
question = self.generate_question(x=x)
return self.predict(question=question)

class MyStatusMessageProvider(StatusMessageProvider):
def tool_start_status_message(self, instance, inputs):
return "Tool starting!"

def tool_end_status_message(self, outputs):
return "Tool finished!"

def module_start_status_message(self, instance, inputs):
if isinstance(instance, dspy.Predict):
return "Predict starting!"

lm = dspy.utils.DummyLM([{"answer": "red"}, {"answer": "blue"}])
with dspy.context(lm=lm):
program = dspy.streamify(MyProgram())
output = program("sky")

status_messages = []
async for value in output:
if isinstance(value, StatusMessage):
status_messages.append(value)

assert len(status_messages) == 2
assert status_messages[0].message == "Calling tool generate_question..."
assert status_messages[1].message == "Tool calling finished! Querying the LLM with tool calling results..."

program = dspy.streamify(MyProgram(), status_message_provider=MyStatusMessageProvider())
output = program("sky")
status_messages = []
async for value in output:
if isinstance(value, StatusMessage):
status_messages.append(value)
assert len(status_messages) == 3
assert status_messages[0].message == "Tool starting!"
assert status_messages[1].message == "Tool finished!"
assert status_messages[2].message == "Predict starting!"


@pytest.mark.anyio
async def test_custom_then_default_status_message_provider():
class MyProgram(dspy.Module):
def __init__(self):
self.generate_question = dspy.Tool(lambda x: f"What color is the {x}?", name="generate_question")
self.predict = dspy.Predict("question->answer")

def __call__(self, x: str):
question = self.generate_question(x=x)
return self.predict(question=question)

class MyStatusMessageProvider(StatusMessageProvider):
def tool_start_status_message(self, instance, inputs):
return "Tool starting!"

def tool_end_status_message(self, outputs):
return "Tool finished!"

def module_start_status_message(self, instance, inputs):
if isinstance(instance, dspy.Predict):
return "Predict starting!"

lm = dspy.utils.DummyLM([{"answer": "red"}, {"answer": "blue"}])
with dspy.context(lm=lm):
program = dspy.streamify(MyProgram(), status_message_provider=MyStatusMessageProvider())
output = program("sky")
status_messages = []
async for value in output:
if isinstance(value, StatusMessage):
status_messages.append(value)
assert len(status_messages) == 3
assert status_messages[0].message == "Tool starting!"
assert status_messages[1].message == "Tool finished!"
assert status_messages[2].message == "Predict starting!"

program = dspy.streamify(MyProgram())
output = program("sky")
status_messages = []
async for value in output:
if isinstance(value, StatusMessage):
status_messages.append(value)
assert len(status_messages) == 2
assert status_messages[0].message == "Calling tool generate_question..."
assert status_messages[1].message == "Tool calling finished! Querying the LLM with tool calling results..."


@pytest.mark.llm_call
@pytest.mark.anyio
async def test_stream_listener_chat_adapter(lm_for_test):
Expand Down