-
Notifications
You must be signed in to change notification settings - Fork 425
hooks - before tool call event - interrupt #987
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
11114e2
4dd5489
7799409
e6f0529
a163acf
4149a8d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,7 @@ | |
from dataclasses import dataclass | ||
from typing import Any | ||
|
||
from ..hooks import Interrupt | ||
from ..telemetry.metrics import EventLoopMetrics | ||
from ..types.content import Message | ||
from ..types.streaming import StopReason | ||
|
@@ -20,12 +21,14 @@ class AgentResult: | |
message: The last message generated by the agent. | ||
metrics: Performance metrics collected during processing. | ||
state: Additional state information from the event loop. | ||
interrupts: List of interrupts if raised by user. | ||
""" | ||
|
||
stop_reason: StopReason | ||
message: Message | ||
metrics: EventLoopMetrics | ||
state: Any | ||
interrupts: list[Interrupt] | None = None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: consider Sequence[Interrupt]? Might be more flexible in the future if we have some HumanInterrupt or EnhancedInterrupt |
||
|
||
def __str__(self) -> str: | ||
"""Get the agent's last message as a string. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
"""<TODO>.""" | ||
|
||
from typing import Any | ||
|
||
from ..hooks import Interrupt | ||
|
||
|
||
class InterruptState: | ||
"""<TODO>.""" | ||
|
||
def __init__( | ||
self, | ||
activated: bool = False, | ||
context: dict[str, Any] | None = None, | ||
interrupts: dict[str, Interrupt] | None = None, | ||
) -> None: | ||
"""<TODO>.""" | ||
self.activated = activated | ||
self.context = context or {} | ||
self.interrupts = interrupts or {} | ||
|
||
def __contains__(self, interrupt_id: str) -> bool: | ||
"""<TODO>.""" | ||
return interrupt_id in self.interrupts | ||
|
||
def __getitem__(self, interrupt_id: str) -> Interrupt: | ||
"""<TODO>.""" | ||
return self.interrupts[interrupt_id] | ||
|
||
def __setitem__(self, interrupt_id: str, interrupt: Interrupt) -> None: | ||
"""<TODO>.""" | ||
self.interrupts[interrupt_id] = interrupt | ||
|
||
def set(self, context: dict[str, Any] | None = None) -> None: | ||
"""<TODO>.""" | ||
self.activated = True | ||
self.context = context or {} | ||
|
||
def clear(self) -> None: | ||
"""<TODO>.""" | ||
self.activated = False | ||
self.context = {} | ||
self.interrupts = {} | ||
|
||
def to_dict(self) -> dict[str, Any]: | ||
"""<TODO>.""" | ||
return { | ||
"activated": self.activated, | ||
"context": self.context, | ||
"interrupts": {interrupt_id: interrupt.to_dict() for interrupt_id, interrupt in self.interrupts.items()}, | ||
} | ||
|
||
@classmethod | ||
def from_dict(cls, data: dict[str, Any]) -> "InterruptState": | ||
"""<TODO>.""" | ||
return cls( | ||
activated=data["activated"], | ||
context=data["context"], | ||
interrupts={ | ||
interrupt_id: Interrupt(**interrupt_data) for interrupt_id, interrupt_data in data["interrupts"].items() | ||
}, | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,18 +15,20 @@ | |
|
||
from opentelemetry import trace as trace_api | ||
|
||
from ..hooks import AfterModelCallEvent, BeforeModelCallEvent, MessageAddedEvent | ||
from ..hooks import AfterModelCallEvent, BeforeModelCallEvent, Interrupt, MessageAddedEvent | ||
from ..telemetry.metrics import Trace | ||
from ..telemetry.tracer import Tracer, get_tracer | ||
from ..tools._validator import validate_and_prepare_tools | ||
from ..types._events import ( | ||
EventLoopStopEvent, | ||
EventLoopThrottleEvent, | ||
ForceStopEvent, | ||
InterruptMessageEvent, | ||
ModelMessageEvent, | ||
ModelStopReason, | ||
StartEvent, | ||
StartEventLoopEvent, | ||
ToolInterruptEvent, | ||
ToolResultMessageEvent, | ||
TypedEvent, | ||
) | ||
|
@@ -106,13 +108,18 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> | |
) | ||
invocation_state["event_loop_cycle_span"] = cycle_span | ||
|
||
model_events = _handle_model_execution(agent, cycle_span, cycle_trace, invocation_state, tracer) | ||
async for model_event in model_events: | ||
if not isinstance(model_event, ModelStopReason): | ||
yield model_event | ||
if agent.interrupt_state.activated: | ||
stop_reason: StopReason = "tool_use" | ||
message = agent.interrupt_state.context["tool_use_message"] | ||
|
||
stop_reason, message, *_ = model_event["stop"] | ||
yield ModelMessageEvent(message=message) | ||
else: | ||
model_events = _handle_model_execution(agent, cycle_span, cycle_trace, invocation_state, tracer) | ||
async for model_event in model_events: | ||
if not isinstance(model_event, ModelStopReason): | ||
yield model_event | ||
|
||
stop_reason, message, *_ = model_event["stop"] | ||
yield ModelMessageEvent(message=message) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We skip over model invocation if resuming from an interrupt and head straight to tool execution. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you add that as a comment in the code |
||
|
||
try: | ||
if stop_reason == "max_tokens": | ||
|
@@ -142,6 +149,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> | |
cycle_span=cycle_span, | ||
cycle_start_time=cycle_start_time, | ||
invocation_state=invocation_state, | ||
tracer=tracer, | ||
) | ||
async for tool_event in tool_events: | ||
yield tool_event | ||
|
@@ -345,6 +353,7 @@ async def _handle_tool_execution( | |
cycle_span: Any, | ||
cycle_start_time: float, | ||
invocation_state: dict[str, Any], | ||
tracer: Tracer, | ||
) -> AsyncGenerator[TypedEvent, None]: | ||
"""Handles the execution of tools requested by the model during an event loop cycle. | ||
|
||
|
@@ -356,6 +365,7 @@ async def _handle_tool_execution( | |
cycle_span: Span object for tracing the cycle (type may vary). | ||
cycle_start_time: Start time of the current cycle. | ||
invocation_state: Additional keyword arguments, including request state. | ||
tracer: Tracer instance for span management. | ||
|
||
Yields: | ||
Tool stream events along with events yielded from a recursive call to the event loop. The last event is a tuple | ||
|
@@ -371,30 +381,59 @@ async def _handle_tool_execution( | |
|
||
validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids) | ||
tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids] | ||
|
||
if agent.interrupt_state.activated: | ||
tool_results.extend(agent.interrupt_state.context["tool_results"]) | ||
|
||
# Filter to only the interrupted tools when resuming from interrupt (tool uses without results) | ||
tool_use_ids = {tool_result["toolUseId"] for tool_result in tool_results} | ||
tool_uses = [tool_use for tool_use in tool_uses if tool_use["toolUseId"] not in tool_use_ids] | ||
|
||
if not tool_uses: | ||
yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"]) | ||
return | ||
|
||
interrupts = [] | ||
tool_events = agent.tool_executor._execute( | ||
agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state | ||
) | ||
async for tool_event in tool_events: | ||
if isinstance(tool_event, ToolInterruptEvent): | ||
interrupts.extend(tool_event["tool_interrupt_event"]["interrupts"]) | ||
|
||
yield tool_event | ||
|
||
# Store parent cycle ID for the next cycle | ||
invocation_state["event_loop_parent_cycle_id"] = invocation_state["event_loop_cycle_id"] | ||
|
||
if interrupts: | ||
events = _handle_interrupts( | ||
message, | ||
agent, | ||
interrupts, | ||
tool_results, | ||
cycle_trace, | ||
cycle_span, | ||
cycle_start_time, | ||
invocation_state, | ||
tracer, | ||
) | ||
async for event in events: | ||
yield event | ||
|
||
return | ||
|
||
agent.interrupt_state.clear() | ||
|
||
tool_result_message: Message = { | ||
"role": "user", | ||
"content": [{"toolResult": result} for result in tool_results], | ||
"content": [{"toolResult": tool_result} for tool_result in tool_results], | ||
} | ||
|
||
agent.messages.append(tool_result_message) | ||
agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=tool_result_message)) | ||
yield ToolResultMessageEvent(message=tool_result_message) | ||
agent.messages.append(tool_result_message) | ||
agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=message)) | ||
|
||
if cycle_span: | ||
tracer = get_tracer() | ||
tracer.end_event_loop_cycle_span(span=cycle_span, message=message, tool_result_message=tool_result_message) | ||
|
||
if invocation_state["request_state"].get("stop_event_loop", False): | ||
|
@@ -405,3 +444,40 @@ async def _handle_tool_execution( | |
events = recurse_event_loop(agent=agent, invocation_state=invocation_state) | ||
async for event in events: | ||
yield event | ||
|
||
|
||
async def _handle_interrupts( | ||
tool_use_message: Message, | ||
agent: "Agent", | ||
interrupts: list[Interrupt], | ||
tool_results: list[ToolResult], | ||
cycle_trace: Trace, | ||
cycle_span: Any, | ||
cycle_start_time: float, | ||
invocation_state: dict[str, Any], | ||
tracer: Tracer, | ||
) -> AsyncGenerator[TypedEvent, None]: | ||
"""<TODO>.""" | ||
for interrupt in interrupts: | ||
agent.interrupt_state[interrupt.id_] = interrupt | ||
|
||
interrupt_message: Message = { | ||
"role": "strands", | ||
"content": [interrupt.to_reason_content() for interrupt in interrupts], | ||
} | ||
yield InterruptMessageEvent(message=interrupt_message) | ||
|
||
agent.messages.append(interrupt_message) | ||
agent.interrupt_state.set(context={"tool_use_message": tool_use_message, "tool_results": tool_results}) | ||
agent.hooks.invoke_callbacks(MessageAddedEvent(agent, message=interrupt_message)) | ||
|
||
agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace) | ||
yield EventLoopStopEvent( | ||
"interrupt", | ||
tool_use_message, | ||
agent.event_loop_metrics, | ||
invocation_state["request_state"], | ||
interrupts, | ||
) | ||
if cycle_span: | ||
tracer.end_event_loop_cycle_span(span=cycle_span, message=tool_use_message) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -39,6 +39,7 @@ def log_end(self, event: AfterInvocationEvent) -> None: | |
BeforeToolCallEvent, | ||
MessageAddedEvent, | ||
) | ||
from .interrupt import Interrupt, InterruptException, InterruptHookEvent | ||
from .registry import BaseHookEvent, HookCallback, HookEvent, HookProvider, HookRegistry | ||
|
||
__all__ = [ | ||
|
@@ -56,4 +57,7 @@ def log_end(self, event: AfterInvocationEvent) -> None: | |
"HookRegistry", | ||
"HookEvent", | ||
"BaseHookEvent", | ||
"Interrupt", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should these be experimental until we finalize the hook? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The interface for this is an interrupt instance variable on the BeforeToolCallEvent. To make experimental, we could have customers register their BeforeToolCallEvent hook with an interrupt enable flag to turn the feature on? What do you think about that? |
||
"InterruptException", | ||
"InterruptHookEvent", | ||
] |
Uh oh!
There was an error while loading. Please reload this page.