diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py
index 005eed3df..7237b515a 100644
--- a/src/strands/models/litellm.py
+++ b/src/strands/models/litellm.py
@@ -139,9 +139,10 @@ async def stream(
         logger.debug("got response from model")
 
         yield self.format_chunk({"chunk_type": "message_start"})
-        yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
 
         tool_calls: dict[int, list[Any]] = {}
+        started_reasoning = False
+        started_text = False
 
         async for event in response:
             # Defensive: skip events with empty or missing choices
@@ -149,12 +150,11 @@
                 continue
             choice = event.choices[0]
 
-            if choice.delta.content:
-                yield self.format_chunk(
-                    {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}
-                )
-
             if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content:
+                if not started_reasoning:
+                    yield self.format_chunk({"chunk_type": "content_start", "data_type": "reasoning_content"})
+                    started_reasoning = True
+
                 yield self.format_chunk(
                     {
                         "chunk_type": "content_delta",
@@ -163,14 +163,27 @@
                     }
                 )
 
+            if choice.delta.content:
+                if started_reasoning:
+                    yield self.format_chunk({"chunk_type": "content_stop", "data_type": "reasoning_content"})
+                    started_reasoning = False
+
+                if not started_text:
+                    yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
+                    started_text = True
+
+                yield self.format_chunk(
+                    {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}
+                )
+
             for tool_call in choice.delta.tool_calls or []:
                 tool_calls.setdefault(tool_call.index, []).append(tool_call)
 
             if choice.finish_reason:
+                if started_text:
+                    yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})
                 break
 
-        yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})
-
         for tool_deltas in tool_calls.values():
             yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]})
 
diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py
index bc81fc819..d3f7a6194 100644
--- a/tests/strands/models/test_litellm.py
+++ b/tests/strands/models/test_litellm.py
@@ -182,6 +182,8 @@ async def test_stream(litellm_acompletion, api_key, model_id, model, agenerator,
         {"messageStart": {"role": "assistant"}},
         {"contentBlockStart": {"start": {}}},
         {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "\nI'm thinking"}}}},
+        {"contentBlockStop": {}},
+        {"contentBlockStart": {"start": {}}},
         {"contentBlockDelta": {"delta": {"text": "I'll calculate"}}},
         {"contentBlockDelta": {"delta": {"text": "that for you"}}},
         {"contentBlockStop": {}},
@@ -251,8 +253,6 @@ async def test_stream_empty(litellm_acompletion, api_key, model_id, model, agene
 
     tru_events = await alist(response)
     exp_events = [
         {"messageStart": {"role": "assistant"}},
-        {"contentBlockStart": {"start": {}}},
-        {"contentBlockStop": {}},
         {"messageStop": {"stopReason": "end_turn"}},
     ]
diff --git a/tests_integ/models/providers.py b/tests_integ/models/providers.py
index c1f442b2a..1b69e83f0 100644
--- a/tests_integ/models/providers.py
+++ b/tests_integ/models/providers.py
@@ -67,10 +67,26 @@ def __init__(self):
             "api_key": os.getenv("ANTHROPIC_API_KEY"),
         },
         model_id="claude-3-7-sonnet-20250219",
-        max_tokens=512,
+        max_tokens=2048,
+        params={
+            "thinking": {
+                "type": "enabled",
+                "budget_tokens": 1024,
+            },
+        },
+    ),
+)
+bedrock = ProviderInfo(
+    id="bedrock",
+    factory=lambda: BedrockModel(
+        additional_request_fields={
+            "thinking": {
+                "type": "enabled",
+                "budget_tokens": 1024,
+            },
+        },
     ),
 )
-bedrock = ProviderInfo(id="bedrock", factory=lambda: BedrockModel())
 cohere = ProviderInfo(
     id="cohere",
     environment_variable="COHERE_API_KEY",
@@ -84,7 +100,16 @@ def __init__(self):
     ),
 )
 litellm = ProviderInfo(
-    id="litellm", factory=lambda: LiteLLMModel(model_id="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0")
+    id="litellm",
+    factory=lambda: LiteLLMModel(
+        model_id="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0",
+        params={
+            "thinking": {
+                "budget_tokens": 1024,
+                "type": "enabled",
+            },
+        },
+    ),
 )
 llama = ProviderInfo(
     id="llama",
@@ -133,7 +158,12 @@ def __init__(self):
     factory=lambda: GeminiModel(
         api_key=os.getenv("GOOGLE_API_KEY"),
         model_id="gemini-2.5-flash",
-        params={"temperature": 0.7},
+        params={
+            "temperature": 0.7,
+            "thinking_config": {
+                "include_thoughts": True,
+            },
+        },
     ),
 )
 
diff --git a/tests_integ/models/test_conformance.py b/tests_integ/models/test_conformance.py
index eaef1eb88..439412adf 100644
--- a/tests_integ/models/test_conformance.py
+++ b/tests_integ/models/test_conformance.py
@@ -5,7 +5,7 @@
 
 from strands import Agent
 from strands.models import Model
-from tests_integ.models.providers import ProviderInfo, all_providers, cohere, llama, mistral
+from tests_integ.models.providers import ProviderInfo, all_providers, cohere, gemini, llama, mistral, openai, writer
 
 
 def get_models():
@@ -60,3 +60,13 @@ class Weather(BaseModel):
 
     assert len(result.time) > 0
     assert len(result.weather) > 0
+
+
+def test_stream_reasoning(skip_for, model):
+    skip_for([cohere, gemini, llama, mistral, openai, writer], "reasoning is not supported")
+
+    agent = Agent(model)
+    result = agent("Please reason about the equation 2+2.")
+
+    assert "reasoningContent" in result.message["content"][0]
+    assert result.message["content"][0]["reasoningContent"]["reasoningText"]["text"]