Skip to content

Commit aa03b3d

Browse files
authored
feat: Implement typed events internally (#745)
Step 1/N for implementing typed-events; first just preserve the existing behaviors with no changes to the public api. A follow-up change will update how we invoke callbacks and pass invocation_state around, while this one just adds typed classes for events internally. --------- Co-authored-by: Mackenzie Zastrow <[email protected]>
1 parent 0fac648 commit aa03b3d

File tree

7 files changed

+529
-94
lines changed

7 files changed

+529
-94
lines changed

src/strands/agent/agent.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
from ..tools.executors._executor import ToolExecutor
5151
from ..tools.registry import ToolRegistry
5252
from ..tools.watcher import ToolWatcher
53+
from ..types._events import InitEventLoopEvent
5354
from ..types.agent import AgentInput
5455
from ..types.content import ContentBlock, Message, Messages
5556
from ..types.exceptions import ContextWindowOverflowException
@@ -604,7 +605,7 @@ async def _run_loop(
604605
self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self))
605606

606607
try:
607-
yield {"callback": {"init_event_loop": True, **invocation_state}}
608+
yield InitEventLoopEvent(invocation_state)
608609

609610
for message in messages:
610611
self._append_message(message)

src/strands/event_loop/event_loop.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,15 @@
2525
from ..telemetry.metrics import Trace
2626
from ..telemetry.tracer import get_tracer
2727
from ..tools._validator import validate_and_prepare_tools
28+
from ..types._events import (
29+
EventLoopStopEvent,
30+
EventLoopThrottleEvent,
31+
ForceStopEvent,
32+
ModelMessageEvent,
33+
StartEvent,
34+
StartEventLoopEvent,
35+
ToolResultMessageEvent,
36+
)
2837
from ..types.content import Message
2938
from ..types.exceptions import (
3039
ContextWindowOverflowException,
@@ -91,8 +100,8 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
91100
cycle_start_time, cycle_trace = agent.event_loop_metrics.start_cycle(attributes=attributes)
92101
invocation_state["event_loop_cycle_trace"] = cycle_trace
93102

94-
yield {"callback": {"start": True}}
95-
yield {"callback": {"start_event_loop": True}}
103+
yield StartEvent()
104+
yield StartEventLoopEvent()
96105

97106
# Create tracer span for this event loop cycle
98107
tracer = get_tracer()
@@ -175,7 +184,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
175184

176185
if isinstance(e, ModelThrottledException):
177186
if attempt + 1 == MAX_ATTEMPTS:
178-
yield {"callback": {"force_stop": True, "force_stop_reason": str(e)}}
187+
yield ForceStopEvent(reason=e)
179188
raise e
180189

181190
logger.debug(
@@ -189,7 +198,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
189198
time.sleep(current_delay)
190199
current_delay = min(current_delay * 2, MAX_DELAY)
191200

192-
yield {"callback": {"event_loop_throttled_delay": current_delay, **invocation_state}}
201+
yield EventLoopThrottleEvent(delay=current_delay, invocation_state=invocation_state)
193202
else:
194203
raise e
195204

@@ -201,7 +210,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
201210
# Add the response message to the conversation
202211
agent.messages.append(message)
203212
agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=message))
204-
yield {"callback": {"message": message}}
213+
yield ModelMessageEvent(message=message)
205214

206215
# Update metrics
207216
agent.event_loop_metrics.update_usage(usage)
@@ -235,8 +244,8 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
235244
cycle_start_time=cycle_start_time,
236245
invocation_state=invocation_state,
237246
)
238-
async for event in events:
239-
yield event
247+
async for typed_event in events:
248+
yield typed_event
240249

241250
return
242251

@@ -264,11 +273,11 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
264273
tracer.end_span_with_error(cycle_span, str(e), e)
265274

266275
# Handle any other exceptions
267-
yield {"callback": {"force_stop": True, "force_stop_reason": str(e)}}
276+
yield ForceStopEvent(reason=e)
268277
logger.exception("cycle failed")
269278
raise EventLoopException(e, invocation_state["request_state"]) from e
270279

271-
yield {"stop": (stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])}
280+
yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])
272281

273282

274283
async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
@@ -295,7 +304,7 @@ async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -
295304
recursive_trace = Trace("Recursive call", parent_id=cycle_trace.id)
296305
cycle_trace.add_child(recursive_trace)
297306

298-
yield {"callback": {"start": True}}
307+
yield StartEvent()
299308

300309
events = event_loop_cycle(agent=agent, invocation_state=invocation_state)
301310
async for event in events:
@@ -339,7 +348,7 @@ async def _handle_tool_execution(
339348
validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids)
340349
tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids]
341350
if not tool_uses:
342-
yield {"stop": (stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])}
351+
yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])
343352
return
344353

345354
tool_events = agent.tool_executor._execute(
@@ -358,15 +367,15 @@ async def _handle_tool_execution(
358367

359368
agent.messages.append(tool_result_message)
360369
agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=tool_result_message))
361-
yield {"callback": {"message": tool_result_message}}
370+
yield ToolResultMessageEvent(message=message)
362371

363372
if cycle_span:
364373
tracer = get_tracer()
365374
tracer.end_event_loop_cycle_span(span=cycle_span, message=message, tool_result_message=tool_result_message)
366375

367376
if invocation_state["request_state"].get("stop_event_loop", False):
368377
agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace)
369-
yield {"stop": (stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])}
378+
yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])
370379
return
371380

372381
events = recurse_event_loop(agent=agent, invocation_state=invocation_state)

src/strands/event_loop/streaming.py

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,16 @@
55
from typing import Any, AsyncGenerator, AsyncIterable, Optional
66

77
from ..models.model import Model
8+
from ..types._events import (
9+
ModelStopReason,
10+
ModelStreamChunkEvent,
11+
ModelStreamEvent,
12+
ReasoningSignatureStreamEvent,
13+
ReasoningTextStreamEvent,
14+
TextStreamEvent,
15+
ToolUseStreamEvent,
16+
TypedEvent,
17+
)
818
from ..types.content import ContentBlock, Message, Messages
919
from ..types.streaming import (
1020
ContentBlockDeltaEvent,
@@ -105,7 +115,7 @@ def handle_content_block_start(event: ContentBlockStartEvent) -> dict[str, Any]:
105115

106116
def handle_content_block_delta(
107117
event: ContentBlockDeltaEvent, state: dict[str, Any]
108-
) -> tuple[dict[str, Any], dict[str, Any]]:
118+
) -> tuple[dict[str, Any], ModelStreamEvent]:
109119
"""Handles content block delta updates by appending text, tool input, or reasoning content to the state.
110120
111121
Args:
@@ -117,43 +127,41 @@ def handle_content_block_delta(
117127
"""
118128
delta_content = event["delta"]
119129

120-
callback_event = {}
130+
typed_event: ModelStreamEvent = ModelStreamEvent({})
121131

122132
if "toolUse" in delta_content:
123133
if "input" not in state["current_tool_use"]:
124134
state["current_tool_use"]["input"] = ""
125135

126136
state["current_tool_use"]["input"] += delta_content["toolUse"]["input"]
127-
callback_event["callback"] = {"delta": delta_content, "current_tool_use": state["current_tool_use"]}
137+
typed_event = ToolUseStreamEvent(delta_content, state["current_tool_use"])
128138

129139
elif "text" in delta_content:
130140
state["text"] += delta_content["text"]
131-
callback_event["callback"] = {"data": delta_content["text"], "delta": delta_content}
141+
typed_event = TextStreamEvent(text=delta_content["text"], delta=delta_content)
132142

133143
elif "reasoningContent" in delta_content:
134144
if "text" in delta_content["reasoningContent"]:
135145
if "reasoningText" not in state:
136146
state["reasoningText"] = ""
137147

138148
state["reasoningText"] += delta_content["reasoningContent"]["text"]
139-
callback_event["callback"] = {
140-
"reasoningText": delta_content["reasoningContent"]["text"],
141-
"delta": delta_content,
142-
"reasoning": True,
143-
}
149+
typed_event = ReasoningTextStreamEvent(
150+
reasoning_text=delta_content["reasoningContent"]["text"],
151+
delta=delta_content,
152+
)
144153

145154
elif "signature" in delta_content["reasoningContent"]:
146155
if "signature" not in state:
147156
state["signature"] = ""
148157

149158
state["signature"] += delta_content["reasoningContent"]["signature"]
150-
callback_event["callback"] = {
151-
"reasoning_signature": delta_content["reasoningContent"]["signature"],
152-
"delta": delta_content,
153-
"reasoning": True,
154-
}
159+
typed_event = ReasoningSignatureStreamEvent(
160+
reasoning_signature=delta_content["reasoningContent"]["signature"],
161+
delta=delta_content,
162+
)
155163

156-
return state, callback_event
164+
return state, typed_event
157165

158166

159167
def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]:
@@ -251,7 +259,7 @@ def extract_usage_metrics(event: MetadataEvent) -> tuple[Usage, Metrics]:
251259
return usage, metrics
252260

253261

254-
async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[dict[str, Any], None]:
262+
async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[TypedEvent, None]:
255263
"""Processes the response stream from the API, constructing the final message and extracting usage metrics.
256264
257265
Args:
@@ -274,14 +282,14 @@ async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[d
274282
metrics: Metrics = Metrics(latencyMs=0)
275283

276284
async for chunk in chunks:
277-
yield {"callback": {"event": chunk}}
285+
yield ModelStreamChunkEvent(chunk=chunk)
278286
if "messageStart" in chunk:
279287
state["message"] = handle_message_start(chunk["messageStart"], state["message"])
280288
elif "contentBlockStart" in chunk:
281289
state["current_tool_use"] = handle_content_block_start(chunk["contentBlockStart"])
282290
elif "contentBlockDelta" in chunk:
283-
state, callback_event = handle_content_block_delta(chunk["contentBlockDelta"], state)
284-
yield callback_event
291+
state, typed_event = handle_content_block_delta(chunk["contentBlockDelta"], state)
292+
yield typed_event
285293
elif "contentBlockStop" in chunk:
286294
state = handle_content_block_stop(state)
287295
elif "messageStop" in chunk:
@@ -291,15 +299,15 @@ async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[d
291299
elif "redactContent" in chunk:
292300
handle_redact_content(chunk["redactContent"], state)
293301

294-
yield {"stop": (stop_reason, state["message"], usage, metrics)}
302+
yield ModelStopReason(stop_reason=stop_reason, message=state["message"], usage=usage, metrics=metrics)
295303

296304

297305
async def stream_messages(
298306
model: Model,
299307
system_prompt: Optional[str],
300308
messages: Messages,
301309
tool_specs: list[ToolSpec],
302-
) -> AsyncGenerator[dict[str, Any], None]:
310+
) -> AsyncGenerator[TypedEvent, None]:
303311
"""Streams messages to the model and processes the response.
304312
305313
Args:

0 commit comments

Comments
 (0)