Commit c85464c
fix(event_loop): raise dedicated exception when encountering max tokens stop reason (#576)

* fix(event_loop): raise dedicated exception when encountering max tokens stop reason
* fix: update integ tests
* fix: rename exception message, add to exception, move earlier in cycle
* Update tests_integ/test_max_tokens_reached.py (Co-authored-by: Nick Clegg <[email protected]>)
* Update tests_integ/test_max_tokens_reached.py (Co-authored-by: Nick Clegg <[email protected]>)
* linting

Co-authored-by: Nick Clegg <[email protected]>

1 parent 8b1de4d · commit c85464c

File tree: 4 files changed, +116 -3 lines changed


src/strands/event_loop/event_loop.py

Lines changed: 24 additions & 2 deletions

@@ -28,7 +28,12 @@
 from ..telemetry.tracer import get_tracer
 from ..tools.executor import run_tools, validate_and_prepare_tools
 from ..types.content import Message
-from ..types.exceptions import ContextWindowOverflowException, EventLoopException, ModelThrottledException
+from ..types.exceptions import (
+    ContextWindowOverflowException,
+    EventLoopException,
+    MaxTokensReachedException,
+    ModelThrottledException,
+)
 from ..types.streaming import Metrics, StopReason
 from ..types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolGenerator, ToolResult, ToolUse
 from .streaming import stream_messages

@@ -187,6 +192,22 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
             raise e

     try:
+        if stop_reason == "max_tokens":
+            """
+            Handle max_tokens limit reached by the model.
+
+            When the model reaches its maximum token limit, this represents a potentially unrecoverable
+            state where the model's response was truncated. By default, Strands fails hard with a
+            MaxTokensReachedException to maintain consistency with other failure types.
+            """
+            raise MaxTokensReachedException(
+                message=(
+                    "Agent has reached an unrecoverable state due to max_tokens limit. "
+                    "For more information see: "
+                    "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception"
+                ),
+                incomplete_message=message,
+            )
         # Add message in trace and mark the end of the stream messages trace
         stream_trace.add_message(message)
         stream_trace.end()

@@ -231,7 +252,8 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
             # Don't yield or log the exception - we already did it when we
             # raised the exception and we don't need that duplication.
             raise
-    except ContextWindowOverflowException as e:
+    except (ContextWindowOverflowException, MaxTokensReachedException) as e:
+        # Special cased exceptions which we want to bubble up rather than get wrapped in an EventLoopException
         if cycle_span:
             tracer.end_span_with_error(cycle_span, str(e), e)
         raise e
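
Because the new exception is special-cased to bubble up instead of being wrapped in an EventLoopException, calling code can catch it directly. Below is a minimal sketch of one possible recovery strategy at the application level, assuming the Bedrock model and agent call pattern used in the integration test later in this commit; the ask_with_retry helper and the retry-with-a-larger-budget policy are illustrative assumptions, not part of the commit.

    from strands import Agent
    from strands.models.bedrock import BedrockModel
    from strands.types.exceptions import MaxTokensReachedException


    # Hypothetical helper (not part of the commit): retry once with a larger
    # output budget when the first attempt is truncated.
    def ask_with_retry(prompt: str):
        try:
            agent = Agent(model=BedrockModel(max_tokens=100))
            return agent(prompt)
        except MaxTokensReachedException as e:
            # The truncated assistant message is available on the exception,
            # but it has NOT been appended to agent.messages.
            print("Truncated response:", e.incomplete_message)
            retry_agent = Agent(model=BedrockModel(max_tokens=4000))
            return retry_agent(prompt)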

src/strands/types/exceptions.py

Lines changed: 21 additions & 0 deletions

@@ -2,6 +2,8 @@

 from typing import Any

+from strands.types.content import Message
+

 class EventLoopException(Exception):
     """Exception raised by the event loop."""

@@ -18,6 +20,25 @@ def __init__(self, original_exception: Exception, request_state: Any = None) ->
         super().__init__(str(original_exception))


+class MaxTokensReachedException(Exception):
+    """Exception raised when the model reaches its maximum token generation limit.
+
+    This exception is raised when the model stops generating tokens because it has reached the maximum number of
+    tokens allowed for output generation. This can occur when the model's max_tokens parameter is set too low for
+    the complexity of the response, or when the model naturally reaches its configured output limit during generation.
+    """
+
+    def __init__(self, message: str, incomplete_message: Message):
+        """Initialize the exception with an error message and the incomplete message object.
+
+        Args:
+            message: The error message describing the token limit issue
+            incomplete_message: The valid Message object with incomplete content due to token limits
+        """
+        self.incomplete_message = incomplete_message
+        super().__init__(message)
+
+
 class ContextWindowOverflowException(Exception):
     """Exception raised when the context window is exceeded.

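For reference, a minimal sketch of the new exception's surface as defined above: the human-readable message goes to the base Exception, while the truncated Message stays accessible on incomplete_message. The example Message payload below is an assumed illustration, not taken from the commit.

    from strands.types.content import Message
    from strands.types.exceptions import MaxTokensReachedException

    # Assumed example of a truncated assistant message (illustrative payload shape).
    truncated: Message = {"role": "assistant", "content": [{"text": "Once upon a ti"}]}

    exc = MaxTokensReachedException(
        message="Agent has reached an unrecoverable state due to max_tokens limit.",
        incomplete_message=truncated,
    )

    assert str(exc) == "Agent has reached an unrecoverable state due to max_tokens limit."
    assert exc.incomplete_message is truncated  # preserved for inspection or recovery
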
tests/strands/event_loop/test_event_loop.py

Lines changed: 51 additions & 1 deletion

@@ -19,7 +19,12 @@
 )
 from strands.telemetry.metrics import EventLoopMetrics
 from strands.tools.registry import ToolRegistry
-from strands.types.exceptions import ContextWindowOverflowException, EventLoopException, ModelThrottledException
+from strands.types.exceptions import (
+    ContextWindowOverflowException,
+    EventLoopException,
+    MaxTokensReachedException,
+    ModelThrottledException,
+)
 from tests.fixtures.mock_hook_provider import MockHookProvider


@@ -556,6 +561,51 @@ async def test_event_loop_tracing_with_model_error(
     mock_tracer.end_span_with_error.assert_called_once_with(model_span, "Input too long", model.stream.side_effect)


+@pytest.mark.asyncio
+async def test_event_loop_cycle_max_tokens_exception(
+    agent,
+    model,
+    agenerator,
+    alist,
+):
+    """Test that max_tokens stop reason raises MaxTokensReachedException."""
+
+    # Note the empty toolUse to handle case raised in https://github.com/strands-agents/sdk-python/issues/495
+    model.stream.return_value = agenerator(
+        [
+            {
+                "contentBlockStart": {
+                    "start": {
+                        "toolUse": {},
+                    },
+                },
+            },
+            {"contentBlockStop": {}},
+            {"messageStop": {"stopReason": "max_tokens"}},
+        ]
+    )
+
+    # Call event_loop_cycle, expecting it to raise MaxTokensReachedException
+    with pytest.raises(MaxTokensReachedException) as exc_info:
+        stream = strands.event_loop.event_loop.event_loop_cycle(
+            agent=agent,
+            invocation_state={},
+        )
+        await alist(stream)
+
+    # Verify the exception message contains the expected content
+    expected_message = (
+        "Agent has reached an unrecoverable state due to max_tokens limit. "
+        "For more information see: "
+        "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception"
+    )
+    assert str(exc_info.value) == expected_message
+
+    # Verify that the message has not been appended to the messages array
+    assert len(agent.messages) == 1
+    assert exc_info.value.incomplete_message not in agent.messages
+
+
 @patch("strands.event_loop.event_loop.get_tracer")
 @pytest.mark.asyncio
 async def test_event_loop_tracing_with_tool_execution(
tests_integ/test_max_tokens_reached.py

Lines changed: 20 additions & 0 deletions

@@ -0,0 +1,20 @@
+import pytest
+
+from strands import Agent, tool
+from strands.models.bedrock import BedrockModel
+from strands.types.exceptions import MaxTokensReachedException
+
+
+@tool
+def story_tool(story: str) -> str:
+    return story
+
+
+def test_context_window_overflow():
+    model = BedrockModel(max_tokens=100)
+    agent = Agent(model=model, tools=[story_tool])
+
+    with pytest.raises(MaxTokensReachedException):
+        agent("Tell me a story!")
+
+    assert len(agent.messages) == 1
