-
Notifications
You must be signed in to change notification settings - Fork 425
hooks - before tool call event - interrupt #987
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
11114e2
4dd5489
7799409
e6f0529
a163acf
4149a8d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -39,6 +39,7 @@ | |
BeforeInvocationEvent, | ||
HookProvider, | ||
HookRegistry, | ||
Interrupt, | ||
MessageAddedEvent, | ||
) | ||
from ..models.bedrock import BedrockModel | ||
|
@@ -54,6 +55,7 @@ | |
from ..types.agent import AgentInput | ||
from ..types.content import ContentBlock, Message, Messages | ||
from ..types.exceptions import ContextWindowOverflowException | ||
from ..types.interrupt import InterruptContent | ||
from ..types.tools import ToolResult, ToolUse | ||
from ..types.traces import AttributeValue | ||
from .agent_result import AgentResult | ||
|
@@ -349,6 +351,9 @@ def __init__( | |
self.hooks.add_hook(hook) | ||
self.hooks.invoke_callbacks(AgentInitializedEvent(agent=self)) | ||
|
||
# Map of active interrupt instances | ||
self._interrupts: dict[tuple[str, str], Interrupt] = {} | ||
|
||
|
||
@property | ||
def tool(self) -> ToolCaller: | ||
"""Call tool as a function. | ||
|
@@ -567,6 +572,8 @@ async def stream_async( | |
yield event["data"] | ||
``` | ||
""" | ||
self._resume_interrupt(prompt) | ||
|
||
callback_handler = kwargs.get("callback_handler", self.callback_handler) | ||
|
||
# Process input and get message to add (if any) | ||
|
@@ -596,6 +603,56 @@ async def stream_async( | |
self._end_agent_trace_span(error=e) | ||
raise | ||
|
||
def _resume_interrupt(self, prompt: AgentInput) -> None: | ||
"""Configure the agent interrupt state if resuming from an interrupt event. | ||
|
||
Args: | ||
messages: Agent's message history. | ||
prompt: User responses if resuming from interrupt. | ||
|
||
Raises: | ||
TypeError: If interrupts are detected but user did not provide responses. | ||
ValueError: If any interrupts are missing corresponding responses. | ||
""" | ||
# Currently, users can only interrupt tool calls. Thus, to determine if the agent was interrupted in the | ||
# previous invocation, we look for tool results with interrupt context in the messages array. | ||
|
||
|
||
message = self.messages[-1] if self.messages else None | ||
if not message or message["role"] != "user": | ||
return | ||
|
||
tool_results = [ | ||
content["toolResult"] | ||
for content in message["content"] | ||
if "toolResult" in content and content["toolResult"]["status"] == "error" | ||
|
||
] | ||
reasons = [ | ||
tool_result["content"][0]["json"]["interrupt"] | ||
for tool_result in tool_results | ||
if "json" in tool_result["content"][0] and "interrupt" in tool_result["content"][0]["json"] | ||
] | ||
if not reasons: | ||
return | ||
|
||
|
||
if not isinstance(prompt, list): | ||
raise TypeError( | ||
f"prompt_type=<{type(prompt)}> | must resume from interrupt with list of interruptResponse's" | ||
) | ||
|
||
responses = [ | ||
cast(InterruptContent, content)["interruptResponse"] | ||
for content in prompt | ||
if isinstance(content, dict) and "interruptResponse" in content | ||
] | ||
|
||
reasons_map = {(reason["name"], reason["event_name"]): reason for reason in reasons} | ||
responses_map = {(response["name"], response["event_name"]): response for response in responses} | ||
missing_keys = reasons_map.keys() - responses_map.keys() | ||
if missing_keys: | ||
|
||
raise ValueError(f"interrupts=<{list(missing_keys)}> | missing responses for interrupts") | ||
|
||
self._interrupts = {key: Interrupt(**{**reasons_map[key], **responses_map[key]}) for key in responses_map} | ||
|
||
|
||
async def _run_loop(self, messages: Messages, invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]: | ||
"""Execute the agent's event loop with the given message and parameters. | ||
|
||
|
@@ -671,6 +728,10 @@ async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> A | |
yield event | ||
|
||
def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages: | ||
if self._interrupts: | ||
# Do not add interrupt responses to the messages as these are not to be processed by the model | ||
return [] | ||
|
||
messages: Messages | None = None | ||
if prompt is not None: | ||
if isinstance(prompt, str): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,7 @@ | |
from dataclasses import dataclass | ||
from typing import Any | ||
|
||
from ..hooks import Interrupt | ||
from ..telemetry.metrics import EventLoopMetrics | ||
from ..types.content import Message | ||
from ..types.streaming import StopReason | ||
|
@@ -20,12 +21,14 @@ class AgentResult: | |
message: The last message generated by the agent. | ||
metrics: Performance metrics collected during processing. | ||
state: Additional state information from the event loop. | ||
interrupts: List of interrupts if raised by user. | ||
""" | ||
|
||
stop_reason: StopReason | ||
message: Message | ||
metrics: EventLoopMetrics | ||
state: Any | ||
interrupts: list[Interrupt] | None = None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: consider Sequence[Interrupt]? Might be more flexible in the future if we have some HumanInterrupt or EnhancedInterrupt |
||
|
||
def __str__(self) -> str: | ||
"""Get the agent's last message as a string. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,6 +27,7 @@ | |
ModelStopReason, | ||
StartEvent, | ||
StartEventLoopEvent, | ||
ToolInterruptEvent, | ||
ToolResultMessageEvent, | ||
TypedEvent, | ||
) | ||
|
@@ -106,13 +107,18 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> | |
) | ||
invocation_state["event_loop_cycle_span"] = cycle_span | ||
|
||
model_events = _handle_model_execution(agent, cycle_span, cycle_trace, invocation_state, tracer) | ||
async for model_event in model_events: | ||
if not isinstance(model_event, ModelStopReason): | ||
yield model_event | ||
if agent._interrupts: | ||
stop_reason: StopReason = "tool_use" | ||
message = agent.messages[-2] | ||
|
||
|
||
stop_reason, message, *_ = model_event["stop"] | ||
yield ModelMessageEvent(message=message) | ||
else: | ||
model_events = _handle_model_execution(agent, cycle_span, cycle_trace, invocation_state, tracer) | ||
async for model_event in model_events: | ||
if not isinstance(model_event, ModelStopReason): | ||
yield model_event | ||
|
||
stop_reason, message, *_ = model_event["stop"] | ||
yield ModelMessageEvent(message=message) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We skip over model invocation if resuming from an interrupt and head straight to tool execution. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you add that as a comment in the code |
||
|
||
try: | ||
if stop_reason == "max_tokens": | ||
|
@@ -371,37 +377,84 @@ async def _handle_tool_execution( | |
|
||
validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids) | ||
tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids] | ||
|
||
# Filter to only the interrupted tools when resuming from interrupt | ||
if agent._interrupts: | ||
tool_names = {name for name, _ in agent._interrupts.keys()} | ||
tool_uses = [tool_use for tool_use in tool_uses if tool_use["name"] in tool_names] | ||
|
||
if not tool_uses: | ||
yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"]) | ||
return | ||
|
||
tool_interrupts = [] | ||
tool_events = agent.tool_executor._execute( | ||
agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state | ||
) | ||
async for tool_event in tool_events: | ||
if isinstance(tool_event, ToolInterruptEvent): | ||
tool_interrupts.append(tool_event["tool_interrupt_event"]["interrupt"]) | ||
|
||
yield tool_event | ||
|
||
# Store parent cycle ID for the next cycle | ||
invocation_state["event_loop_parent_cycle_id"] = invocation_state["event_loop_cycle_id"] | ||
|
||
tool_result_message: Message = { | ||
"role": "user", | ||
"content": [{"toolResult": result} for result in tool_results], | ||
} | ||
tool_result_message = _convert_tool_results_to_message(agent, tool_results) | ||
|
||
|
||
agent.messages.append(tool_result_message) | ||
agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=tool_result_message)) | ||
yield ToolResultMessageEvent(message=tool_result_message) | ||
|
||
if cycle_span: | ||
tracer = get_tracer() | ||
tracer.end_event_loop_cycle_span(span=cycle_span, message=message, tool_result_message=tool_result_message) | ||
|
||
if tool_interrupts: | ||
agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace) | ||
yield EventLoopStopEvent( | ||
"interrupt", message, agent.event_loop_metrics, invocation_state["request_state"], tool_interrupts | ||
) | ||
return | ||
|
||
if invocation_state["request_state"].get("stop_event_loop", False): | ||
agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace) | ||
yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"]) | ||
return | ||
|
||
agent._interrupts = {} | ||
|
||
events = recurse_event_loop(agent=agent, invocation_state=invocation_state) | ||
async for event in events: | ||
yield event | ||
|
||
|
||
def _convert_tool_results_to_message(agent: "Agent", results: list[ToolResult]) -> Message: | ||
"""Convert tool results to a message. | ||
|
||
For normal execution, we create a new user message with the tool results and append it to the agent's message | ||
history. When resuming from an interrupt, we instead extend the existing results message in history with the | ||
resumed results. | ||
|
||
Args: | ||
agent: The agent instance containing interrupt state and message history. | ||
results: List of tool results to convert or extend into a message. | ||
|
||
Returns: | ||
Tool results message. | ||
""" | ||
if not agent._interrupts: | ||
message: Message = { | ||
"role": "user", | ||
"content": [{"toolResult": result} for result in results], | ||
} | ||
agent.messages.append(message) | ||
|
||
return message | ||
|
||
message = agent.messages[-1] | ||
|
||
|
||
results_map = {result["toolUseId"]: result for result in results} | ||
for content in message["content"]: | ||
tool_use_id = content["toolResult"]["toolUseId"] | ||
content["toolResult"] = results_map.get(tool_use_id, content["toolResult"]) | ||
|
||
return message |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -39,6 +39,7 @@ def log_end(self, event: AfterInvocationEvent) -> None: | |
BeforeToolCallEvent, | ||
MessageAddedEvent, | ||
) | ||
from .interrupt import Interrupt, InterruptEvent, InterruptException | ||
from .registry import BaseHookEvent, HookCallback, HookEvent, HookProvider, HookRegistry | ||
|
||
__all__ = [ | ||
|
@@ -56,4 +57,7 @@ def log_end(self, event: AfterInvocationEvent) -> None: | |
"HookRegistry", | ||
"HookEvent", | ||
"BaseHookEvent", | ||
"Interrupt", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should these be experimental until we finalize the hook? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The interface for this is an interrupt instance variable on the BeforeToolCallEvent. To make experimental, we could have customers register their BeforeToolCallEvent hook with an interrupt enable flag to turn the feature on? What do you think about that? |
||
"InterruptEvent", | ||
"InterruptException", | ||
] |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,6 +9,7 @@ | |
from ..types.content import Message | ||
from ..types.streaming import StopReason | ||
from ..types.tools import AgentTool, ToolResult, ToolUse | ||
from .interrupt import InterruptEvent | ||
from .registry import HookEvent | ||
|
||
|
||
|
@@ -84,7 +85,7 @@ class MessageAddedEvent(HookEvent): | |
|
||
|
||
@dataclass | ||
class BeforeToolCallEvent(HookEvent): | ||
class BeforeToolCallEvent(HookEvent, InterruptEvent): | ||
"""Event triggered before a tool is invoked. | ||
|
||
This event is fired just before the agent executes a tool, allowing hook | ||
|
@@ -108,7 +109,7 @@ class BeforeToolCallEvent(HookEvent): | |
cancel_tool: bool | str = False | ||
|
||
def _can_write(self, name: str) -> bool: | ||
return name in ["cancel_tool", "selected_tool", "tool_use"] | ||
return name in ["interrupt", "cancel_tool", "selected_tool", "tool_use"] | ||
|
||
|
||
|
||
@dataclass | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
"""Human-in-the-loop interrupt system for agent workflows.""" | ||
|
||
from dataclasses import asdict, dataclass | ||
from typing import TYPE_CHECKING, Any | ||
|
||
from ..types.tools import ToolResultContent | ||
|
||
if TYPE_CHECKING: | ||
from ..agent import Agent | ||
|
||
|
||
@dataclass | ||
class Interrupt: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should this be in types? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I personally prefer it here to keep things centralized. I get the "dumping grounds" vibe when thinking about moving the definition into types. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this interrupt going to be used outside of hooks? If so, I don't think it belongs under the hooks module |
||
"""Represents an interrupt that can pause agent execution for human-in-the-loop workflows. | ||
Attributes: | ||
name: Unique identifier for the interrupt. | ||
event_name: Name of the hook event under which the interrupt was triggered. | ||
reasons: User provided reasons for raising the interrupt. | ||
response: Human response provided when resuming the agent after an interrupt. | ||
activated: Whether the interrupt is currently active. | ||
""" | ||
|
||
name: str | ||
event_name: str | ||
|
||
reasons: list[Any] | ||
|
||
response: Any = None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems to indicate that N things can trigger an interrupt, but they must all use the same response to resolve the interrupt - am I understanding this correctly? If so, that seems to indicate that all interrupts for a specific name/event/name need to coordinate their Responses to be compatiable? I'm thinking specifically one interrupt asking:
These are two interrupts which take different resolutions I think. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we want one reason per interrupt, we would need to add on another identifier. There are a few options:
I think the index approach is probably the most robust. But these issues are the reason why I took the approach of setting up a |
||
activated: bool = False | ||
|
||
def __call__(self, reason: Any) -> Any: | ||
"""Trigger the interrupt with a reason. | ||
Args: | ||
reason: User provided reason for the interrupt. | ||
Returns: | ||
The response from a human user when resuming from an interrupt state. | ||
Raises: | ||
InterruptException: If human input is required. | ||
""" | ||
if self.response: | ||
self.activated = False | ||
return self.response | ||
|
||
self.reasons.append(reason) | ||
self.activated = True | ||
raise InterruptException(self) | ||
|
||
def to_tool_result_content(self) -> list[ToolResultContent]: | ||
|
||
"""Convert the interrupt to tool result content if there are reasons. | ||
Returns: | ||
Tool result content. | ||
""" | ||
if self.reasons: | ||
return [ | ||
{"json": {"interrupt": {"name": self.name, "event_name": self.event_name, "reasons": self.reasons}}}, | ||
] | ||
|
||
return [] | ||
|
||
@classmethod | ||
def from_agent(cls, name: str, event_name: str, agent: "Agent") -> "Interrupt": | ||
"""Initialize an interrupt from agent state. | ||
Creates an interrupt instance from stored agent state, which will be | ||
populated with the human response when resuming. | ||
Args: | ||
name: Unique identifier for the interrupt. | ||
event_name: Name of the hook event under which the interrupt was triggered. | ||
agent: The agent instance containing interrupt state. | ||
Returns: | ||
An Interrupt instance initialized from agent state. | ||
""" | ||
interrupt = agent._interrupts.get((name, event_name)) | ||
params = asdict(interrupt) if interrupt else {"name": name, "event_name": event_name, "reasons": []} | ||
|
||
return cls(**params) | ||
|
||
|
||
class InterruptException(Exception): | ||
|
||
"""Exception raised when human input is required.""" | ||
|
||
def __init__(self, interrupt: Interrupt) -> None: | ||
"""Initialize the exception with an interrupt instance. | ||
Args: | ||
interrupt: The interrupt that triggered this exception. | ||
""" | ||
self.interrupt = interrupt | ||
|
||
|
||
@dataclass | ||
class InterruptEvent: | ||
|
||
"""Interface that adds interrupt support to hook events. | ||
Attributes: | ||
interrupt: The interrupt instance associated with this event. | ||
""" | ||
|
||
interrupt: Interrupt |
Uh oh!
There was an error while loading. Please reload this page.