Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,15 @@
from ..types.agent import AgentInput
from ..types.content import ContentBlock, Message, Messages
from ..types.exceptions import ContextWindowOverflowException
from ..types.interrupt import InterruptResponseContent
from ..types.tools import ToolResult, ToolUse
from ..types.traces import AttributeValue
from .agent_result import AgentResult
from .conversation_manager import (
ConversationManager,
SlidingWindowConversationManager,
)
from .interrupt import InterruptState
from .state import AgentState

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -337,6 +339,8 @@ def __init__(

self.hooks = HookRegistry()

self.interrupt_state = InterruptState()

# Initialize session management functionality
self._session_manager = session_manager
if self._session_manager:
Expand Down Expand Up @@ -671,6 +675,17 @@ async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> A
yield event

def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages:
if self.interrupt_state.activated:
responses = [
cast(InterruptResponseContent, content)["interruptResponse"]
for content in cast(list[ContentBlock], prompt)
]

for response in responses:
self.interrupt_state[response["interruptId"]].response = response["response"]

return [{"role": "strands", "content": cast(list[ContentBlock], prompt)}]

messages: Messages | None = None
if prompt is not None:
if isinstance(prompt, str):
Expand Down
3 changes: 3 additions & 0 deletions src/strands/agent/agent_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from dataclasses import dataclass
from typing import Any

from ..hooks import Interrupt
from ..telemetry.metrics import EventLoopMetrics
from ..types.content import Message
from ..types.streaming import StopReason
Expand All @@ -20,12 +21,14 @@ class AgentResult:
message: The last message generated by the agent.
metrics: Performance metrics collected during processing.
state: Additional state information from the event loop.
interrupts: List of interrupts if raised by user.
"""

stop_reason: StopReason
message: Message
metrics: EventLoopMetrics
state: Any
interrupts: list[Interrupt] | None = None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: consider Sequence[Interrupt]? Might be more flexible in the future if we have some HumanInterrupt or EnhancedInterrupt


def __str__(self) -> str:
"""Get the agent's last message as a string.
Expand Down
62 changes: 62 additions & 0 deletions src/strands/agent/interrupt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""<TODO>."""

from typing import Any

from ..hooks import Interrupt


class InterruptState:
"""<TODO>."""

def __init__(
self,
activated: bool = False,
context: dict[str, Any] | None = None,
interrupts: dict[str, Interrupt] | None = None,
) -> None:
"""<TODO>."""
self.activated = activated
self.context = context or {}
self.interrupts = interrupts or {}

def __contains__(self, interrupt_id: str) -> bool:
"""<TODO>."""
return interrupt_id in self.interrupts

def __getitem__(self, interrupt_id: str) -> Interrupt:
"""<TODO>."""
return self.interrupts[interrupt_id]

def __setitem__(self, interrupt_id: str, interrupt: Interrupt) -> None:
"""<TODO>."""
self.interrupts[interrupt_id] = interrupt

def set(self, context: dict[str, Any] | None = None) -> None:
"""<TODO>."""
self.activated = True
self.context = context or {}

def clear(self) -> None:
"""<TODO>."""
self.activated = False
self.context = {}
self.interrupts = {}

def to_dict(self) -> dict[str, Any]:
"""<TODO>."""
return {
"activated": self.activated,
"context": self.context,
"interrupts": {interrupt_id: interrupt.to_dict() for interrupt_id, interrupt in self.interrupts.items()},
}

@classmethod
def from_dict(cls, data: dict[str, Any]) -> "InterruptState":
"""<TODO>."""
return cls(
activated=data["activated"],
context=data["context"],
interrupts={
interrupt_id: Interrupt(**interrupt_data) for interrupt_id, interrupt_data in data["interrupts"].items()
},
)
100 changes: 88 additions & 12 deletions src/strands/event_loop/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,20 @@

from opentelemetry import trace as trace_api

from ..hooks import AfterModelCallEvent, BeforeModelCallEvent, MessageAddedEvent
from ..hooks import AfterModelCallEvent, BeforeModelCallEvent, Interrupt, MessageAddedEvent
from ..telemetry.metrics import Trace
from ..telemetry.tracer import Tracer, get_tracer
from ..tools._validator import validate_and_prepare_tools
from ..types._events import (
EventLoopStopEvent,
EventLoopThrottleEvent,
ForceStopEvent,
InterruptMessageEvent,
ModelMessageEvent,
ModelStopReason,
StartEvent,
StartEventLoopEvent,
ToolInterruptEvent,
ToolResultMessageEvent,
TypedEvent,
)
Expand Down Expand Up @@ -106,13 +108,18 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
)
invocation_state["event_loop_cycle_span"] = cycle_span

model_events = _handle_model_execution(agent, cycle_span, cycle_trace, invocation_state, tracer)
async for model_event in model_events:
if not isinstance(model_event, ModelStopReason):
yield model_event
if agent.interrupt_state.activated:
stop_reason: StopReason = "tool_use"
message = agent.interrupt_state.context["tool_use_message"]

stop_reason, message, *_ = model_event["stop"]
yield ModelMessageEvent(message=message)
else:
model_events = _handle_model_execution(agent, cycle_span, cycle_trace, invocation_state, tracer)
async for model_event in model_events:
if not isinstance(model_event, ModelStopReason):
yield model_event

stop_reason, message, *_ = model_event["stop"]
yield ModelMessageEvent(message=message)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We skip over model invocation if resuming from an interrupt and head straight to tool execution.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add that as a comment in the code


try:
if stop_reason == "max_tokens":
Expand Down Expand Up @@ -142,6 +149,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
cycle_span=cycle_span,
cycle_start_time=cycle_start_time,
invocation_state=invocation_state,
tracer=tracer,
)
async for tool_event in tool_events:
yield tool_event
Expand Down Expand Up @@ -345,6 +353,7 @@ async def _handle_tool_execution(
cycle_span: Any,
cycle_start_time: float,
invocation_state: dict[str, Any],
tracer: Tracer,
) -> AsyncGenerator[TypedEvent, None]:
"""Handles the execution of tools requested by the model during an event loop cycle.

Expand All @@ -356,6 +365,7 @@ async def _handle_tool_execution(
cycle_span: Span object for tracing the cycle (type may vary).
cycle_start_time: Start time of the current cycle.
invocation_state: Additional keyword arguments, including request state.
tracer: Tracer instance for span management.

Yields:
Tool stream events along with events yielded from a recursive call to the event loop. The last event is a tuple
Expand All @@ -371,30 +381,59 @@ async def _handle_tool_execution(

validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids)
tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids]

if agent.interrupt_state.activated:
tool_results.extend(agent.interrupt_state.context["tool_results"])

# Filter to only the interrupted tools when resuming from interrupt (tool uses without results)
tool_use_ids = {tool_result["toolUseId"] for tool_result in tool_results}
tool_uses = [tool_use for tool_use in tool_uses if tool_use["toolUseId"] not in tool_use_ids]

if not tool_uses:
yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])
return

interrupts = []
tool_events = agent.tool_executor._execute(
agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state
)
async for tool_event in tool_events:
if isinstance(tool_event, ToolInterruptEvent):
interrupts.extend(tool_event["tool_interrupt_event"]["interrupts"])

yield tool_event

# Store parent cycle ID for the next cycle
invocation_state["event_loop_parent_cycle_id"] = invocation_state["event_loop_cycle_id"]

if interrupts:
events = _handle_interrupts(
message,
agent,
interrupts,
tool_results,
cycle_trace,
cycle_span,
cycle_start_time,
invocation_state,
tracer,
)
async for event in events:
yield event

return

agent.interrupt_state.clear()

tool_result_message: Message = {
"role": "user",
"content": [{"toolResult": result} for result in tool_results],
"content": [{"toolResult": tool_result} for tool_result in tool_results],
}

agent.messages.append(tool_result_message)
agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=tool_result_message))
yield ToolResultMessageEvent(message=tool_result_message)
agent.messages.append(tool_result_message)
agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=message))

if cycle_span:
tracer = get_tracer()
tracer.end_event_loop_cycle_span(span=cycle_span, message=message, tool_result_message=tool_result_message)

if invocation_state["request_state"].get("stop_event_loop", False):
Expand All @@ -405,3 +444,40 @@ async def _handle_tool_execution(
events = recurse_event_loop(agent=agent, invocation_state=invocation_state)
async for event in events:
yield event


async def _handle_interrupts(
tool_use_message: Message,
agent: "Agent",
interrupts: list[Interrupt],
tool_results: list[ToolResult],
cycle_trace: Trace,
cycle_span: Any,
cycle_start_time: float,
invocation_state: dict[str, Any],
tracer: Tracer,
) -> AsyncGenerator[TypedEvent, None]:
"""<TODO>."""
for interrupt in interrupts:
agent.interrupt_state[interrupt.id_] = interrupt

interrupt_message: Message = {
"role": "strands",
"content": [interrupt.to_reason_content() for interrupt in interrupts],
}
yield InterruptMessageEvent(message=interrupt_message)

agent.messages.append(interrupt_message)
agent.interrupt_state.set(context={"tool_use_message": tool_use_message, "tool_results": tool_results})
agent.hooks.invoke_callbacks(MessageAddedEvent(agent, message=interrupt_message))

agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace)
yield EventLoopStopEvent(
"interrupt",
tool_use_message,
agent.event_loop_metrics,
invocation_state["request_state"],
interrupts,
)
if cycle_span:
tracer.end_event_loop_cycle_span(span=cycle_span, message=tool_use_message)
16 changes: 16 additions & 0 deletions src/strands/event_loop/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,20 @@
logger = logging.getLogger(__name__)


def remove_strands_messages(messages: Messages) -> Messages:
"""Remove messages from the conversation that are not from the user or the assistant.

TODO.

Args:
messages: Conversation messages to update.

Returns:
Updated messages.
"""
return [message for message in messages if message["role"] != "strands"]


def remove_blank_messages_content_text(messages: Messages) -> Messages:
"""Remove or replace blank text in message content.

Expand Down Expand Up @@ -345,7 +359,9 @@ async def stream_messages(
"""
logger.debug("model=<%s> | streaming messages", model)

messages = remove_strands_messages(messages)
messages = remove_blank_messages_content_text(messages)

chunks = model.stream(messages, tool_specs if tool_specs else None, system_prompt)

async for event in process_stream(chunks):
Expand Down
4 changes: 4 additions & 0 deletions src/strands/hooks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def log_end(self, event: AfterInvocationEvent) -> None:
BeforeToolCallEvent,
MessageAddedEvent,
)
from .interrupt import Interrupt, InterruptException, InterruptHookEvent
from .registry import BaseHookEvent, HookCallback, HookEvent, HookProvider, HookRegistry

__all__ = [
Expand All @@ -56,4 +57,7 @@ def log_end(self, event: AfterInvocationEvent) -> None:
"HookRegistry",
"HookEvent",
"BaseHookEvent",
"Interrupt",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should these be experimental until we finalize the hook?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The interface for this is an interrupt instance variable on the BeforeToolCallEvent. To make experimental, we could have customers register their BeforeToolCallEvent hook with an interrupt enable flag to turn the feature on? What do you think about that?

"InterruptException",
"InterruptHookEvent",
]
20 changes: 19 additions & 1 deletion src/strands/hooks/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,17 @@
This module defines the events that are emitted as Agents run through the lifecycle of a request.
"""

import json
import uuid
from dataclasses import dataclass
from typing import Any, Optional

from typing_extensions import override

from ..types.content import Message
from ..types.streaming import StopReason
from ..types.tools import AgentTool, ToolResult, ToolUse
from .interrupt import InterruptHookEvent
from .registry import HookEvent


Expand Down Expand Up @@ -84,7 +89,7 @@ class MessageAddedEvent(HookEvent):


@dataclass
class BeforeToolCallEvent(HookEvent):
class BeforeToolCallEvent(HookEvent, InterruptHookEvent):
"""Event triggered before a tool is invoked.

This event is fired just before the agent executes a tool, allowing hook
Expand All @@ -110,6 +115,19 @@ class BeforeToolCallEvent(HookEvent):
def _can_write(self, name: str) -> bool:
return name in ["cancel_tool", "selected_tool", "tool_use"]

@override
def interrupt_id(self, name: str, reason: Any) -> str:
"""Unique id for the interrupt.

Args:
name: User defined name for the interrupt.
reason: User provided reason for the interrupt.

Returns:
Interrupt id.
"""
return f"v1:{self.tool_use['toolUseId']}:{uuid.uuid5(uuid.NAMESPACE_OID, f'{name}:{json.dumps(reason)}')}"


@dataclass
class AfterToolCallEvent(HookEvent):
Expand Down
Loading
Loading