diff --git a/src/strands/agent/agent_result.py b/src/strands/agent/agent_result.py index e28e1c5b8..f3758c8d2 100644 --- a/src/strands/agent/agent_result.py +++ b/src/strands/agent/agent_result.py @@ -42,5 +42,4 @@ def __str__(self) -> str: for item in content_array: if isinstance(item, dict) and "text" in item: result += item.get("text", "") + "\n" - return result diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index 1f8c260a4..df8761b80 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -5,6 +5,7 @@ from typing import Any, AsyncGenerator, AsyncIterable, Optional from ..models.model import Model +from ..types.citations import CitationsContentBlock from ..types.content import ContentBlock, Message, Messages from ..types.streaming import ( ContentBlockDeltaEvent, @@ -130,6 +131,13 @@ def handle_content_block_delta( state["text"] += delta_content["text"] callback_event["callback"] = {"data": delta_content["text"], "delta": delta_content} + elif "citation" in delta_content: + if "citationsContent" not in state: + state["citationsContent"] = [] + + state["citationsContent"].append(delta_content["citation"]) + callback_event["callback"] = {"citation_metadata": delta_content["citation"], "delta": delta_content} + elif "reasoningContent" in delta_content: if "text" in delta_content["reasoningContent"]: if "reasoningText" not in state: @@ -170,6 +178,7 @@ def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]: current_tool_use = state["current_tool_use"] text = state["text"] reasoning_text = state["reasoningText"] + citations_content = state["citationsContent"] if current_tool_use: if "input" not in current_tool_use: @@ -194,6 +203,10 @@ def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]: elif text: content.append({"text": text}) state["text"] = "" + if citations_content: + citations_block: CitationsContentBlock = {"citations": citations_content} + content.append({"citationsContent": citations_block}) + state["citationsContent"] = [] elif reasoning_text: content_block: ContentBlock = { @@ -267,6 +280,8 @@ async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[d "text": "", "current_tool_use": {}, "reasoningText": "", + "signature": "", + "citationsContent": [], } state["content"] = state["message"]["content"] diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index ace35640a..2e5b48ca5 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -7,7 +7,7 @@ import json import logging import os -from typing import Any, AsyncGenerator, Callable, Iterable, Literal, Optional, Type, TypeVar, Union +from typing import Any, AsyncGenerator, Callable, Iterable, Literal, Optional, Type, TypeVar, Union, cast import boto3 from botocore.config import Config as BotocoreConfig @@ -18,7 +18,10 @@ from ..event_loop import streaming from ..tools import convert_pydantic_to_tool_spec from ..types.content import ContentBlock, Message, Messages -from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException +from ..types.exceptions import ( + ContextWindowOverflowException, + ModelThrottledException, +) from ..types.streaming import StreamEvent from ..types.tools import ToolResult, ToolSpec from .model import Model @@ -510,7 +513,7 @@ def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Itera yield {"messageStart": {"role": response["output"]["message"]["role"]}} # Process content blocks - for content in response["output"]["message"]["content"]: + for content in cast(list[ContentBlock], response["output"]["message"]["content"]): # Yield contentBlockStart event if needed if "toolUse" in content: yield { @@ -553,6 +556,25 @@ def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Itera } } } + elif "citationsContent" in content: + # For non-streaming citations, emit text and metadata deltas in sequence + # to match streaming behavior where they flow naturally + if "content" in content["citationsContent"]: + text_content = "".join([content["text"] for content in content["citationsContent"]["content"]]) + yield { + "contentBlockDelta": {"delta": {"text": text_content}}, + } + + for citation in content["citationsContent"]["citations"]: + # Then emit citation metadata (for structure) + from ..types.streaming import CitationsDelta + + citation_metadata: CitationsDelta = { + "title": citation["title"], + "location": citation["location"], + "sourceContent": citation["sourceContent"], + } + yield {"contentBlockDelta": {"delta": {"citation": citation_metadata}}} # Yield contentBlockStop event yield {"contentBlockStop": {}} diff --git a/src/strands/types/citations.py b/src/strands/types/citations.py new file mode 100644 index 000000000..b0e28f655 --- /dev/null +++ b/src/strands/types/citations.py @@ -0,0 +1,152 @@ +"""Citation type definitions for the SDK. + +These types are modeled after the Bedrock API. +""" + +from typing import List, Union + +from typing_extensions import TypedDict + + +class CitationsConfig(TypedDict): + """Configuration for enabling citations on documents. + + Attributes: + enabled: Whether citations are enabled for this document. + """ + + enabled: bool + + +class DocumentCharLocation(TypedDict, total=False): + """Specifies a character-level location within a document. + + Provides precise positioning information for cited content using + start and end character indices. + + Attributes: + documentIndex: The index of the document within the array of documents + provided in the request. Minimum value of 0. + start: The starting character position of the cited content within + the document. Minimum value of 0. + end: The ending character position of the cited content within + the document. Minimum value of 0. + """ + + documentIndex: int + start: int + end: int + + +class DocumentChunkLocation(TypedDict, total=False): + """Specifies a chunk-level location within a document. + + Provides positioning information for cited content using logical + document segments or chunks. + + Attributes: + documentIndex: The index of the document within the array of documents + provided in the request. Minimum value of 0. + start: The starting chunk identifier or index of the cited content + within the document. Minimum value of 0. + end: The ending chunk identifier or index of the cited content + within the document. Minimum value of 0. + """ + + documentIndex: int + start: int + end: int + + +class DocumentPageLocation(TypedDict, total=False): + """Specifies a page-level location within a document. + + Provides positioning information for cited content using page numbers. + + Attributes: + documentIndex: The index of the document within the array of documents + provided in the request. Minimum value of 0. + start: The starting page number of the cited content within + the document. Minimum value of 0. + end: The ending page number of the cited content within + the document. Minimum value of 0. + """ + + documentIndex: int + start: int + end: int + + +# Union type for citation locations +CitationLocation = Union[DocumentCharLocation, DocumentChunkLocation, DocumentPageLocation] + + +class CitationSourceContent(TypedDict, total=False): + """Contains the actual text content from a source document. + + Contains the actual text content from a source document that is being + cited or referenced in the model's response. + + Note: + This is a UNION type, so only one of the members can be specified. + + Attributes: + text: The text content from the source document that is being cited. + """ + + text: str + + +class CitationGeneratedContent(TypedDict, total=False): + """Contains the generated text content that corresponds to a citation. + + Contains the generated text content that corresponds to or is supported + by a citation from a source document. + + Note: + This is a UNION type, so only one of the members can be specified. + + Attributes: + text: The text content that was generated by the model and is + supported by the associated citation. + """ + + text: str + + +class Citation(TypedDict, total=False): + """Contains information about a citation that references a source document. + + Citations provide traceability between the model's generated response + and the source documents that informed that response. + + Attributes: + location: The precise location within the source document where the + cited content can be found, including character positions, page + numbers, or chunk identifiers. + sourceContent: The specific content from the source document that was + referenced or cited in the generated response. + title: The title or identifier of the source document being cited. + """ + + location: CitationLocation + sourceContent: List[CitationSourceContent] + title: str + + +class CitationsContentBlock(TypedDict, total=False): + """A content block containing generated text and associated citations. + + This block type is returned when document citations are enabled, providing + traceability between the generated content and the source documents that + informed the response. + + Attributes: + citations: An array of citations that reference the source documents + used to generate the associated content. + content: The generated content that is supported by the associated + citations. + """ + + citations: List[Citation] + content: List[CitationGeneratedContent] diff --git a/src/strands/types/content.py b/src/strands/types/content.py index 790e9094c..c3eddca4d 100644 --- a/src/strands/types/content.py +++ b/src/strands/types/content.py @@ -10,6 +10,7 @@ from typing_extensions import TypedDict +from .citations import CitationsContentBlock from .media import DocumentContent, ImageContent, VideoContent from .tools import ToolResult, ToolUse @@ -83,6 +84,7 @@ class ContentBlock(TypedDict, total=False): toolResult: The result for a tool request that a model makes. toolUse: Information about a tool use request from a model. video: Video to include in the message. + citationsContent: Contains the citations for a document. """ cachePoint: CachePoint @@ -94,6 +96,7 @@ class ContentBlock(TypedDict, total=False): toolResult: ToolResult toolUse: ToolUse video: VideoContent + citationsContent: CitationsContentBlock class SystemContentBlock(TypedDict, total=False): diff --git a/src/strands/types/media.py b/src/strands/types/media.py index 29b89e5c6..69cd60cf3 100644 --- a/src/strands/types/media.py +++ b/src/strands/types/media.py @@ -5,10 +5,12 @@ - Bedrock docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Types_Amazon_Bedrock_Runtime.html """ -from typing import Literal +from typing import Literal, Optional from typing_extensions import TypedDict +from .citations import CitationsConfig + DocumentFormat = Literal["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"] """Supported document formats.""" @@ -23,7 +25,7 @@ class DocumentSource(TypedDict): bytes: bytes -class DocumentContent(TypedDict): +class DocumentContent(TypedDict, total=False): """A document to include in a message. Attributes: @@ -35,6 +37,8 @@ class DocumentContent(TypedDict): format: Literal["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"] name: str source: DocumentSource + citations: Optional[CitationsConfig] + context: Optional[str] ImageFormat = Literal["png", "jpeg", "gif", "webp"] diff --git a/src/strands/types/streaming.py b/src/strands/types/streaming.py index 9c99b2108..dcfd541a8 100644 --- a/src/strands/types/streaming.py +++ b/src/strands/types/streaming.py @@ -9,6 +9,7 @@ from typing_extensions import TypedDict +from .citations import CitationLocation from .content import ContentBlockStart, Role from .event_loop import Metrics, StopReason, Usage from .guardrails import Trace @@ -57,6 +58,41 @@ class ContentBlockDeltaToolUse(TypedDict): input: str +class CitationSourceContentDelta(TypedDict, total=False): + """Contains incremental updates to source content text during streaming. + + Allows clients to build up the cited content progressively during + streaming responses. + + Attributes: + text: An incremental update to the text content from the source + document that is being cited. + """ + + text: str + + +class CitationsDelta(TypedDict, total=False): + """Contains incremental updates to citation information during streaming. + + This allows clients to build up citation data progressively as the + response is generated. + + Attributes: + location: Specifies the precise location within a source document + where cited content can be found. This can include character-level + positions, page numbers, or document chunks depending on the + document type and indexing method. + sourceContent: The specific content from the source document that was + referenced or cited in the generated response. + title: The title or identifier of the source document being cited. + """ + + location: CitationLocation + sourceContent: list[CitationSourceContentDelta] + title: str + + class ReasoningContentBlockDelta(TypedDict, total=False): """Delta for reasoning content block in a streaming response. @@ -83,6 +119,7 @@ class ContentBlockDelta(TypedDict, total=False): reasoningContent: ReasoningContentBlockDelta text: str toolUse: ContentBlockDeltaToolUse + citation: CitationsDelta class ContentBlockDeltaEvent(TypedDict, total=False): diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index 7760c498a..2c3fb54a8 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -163,12 +163,14 @@ def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_up "current_tool_use": {"toolUseId": "123", "name": "test", "input": '{"key": "value"}'}, "text": "", "reasoningText": "", + "citationsContent": [], }, { "content": [{"toolUse": {"toolUseId": "123", "name": "test", "input": {"key": "value"}}}], "current_tool_use": {}, "text": "", "reasoningText": "", + "citationsContent": [], }, ), # Tool Use - Missing input @@ -178,12 +180,14 @@ def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_up "current_tool_use": {"toolUseId": "123", "name": "test"}, "text": "", "reasoningText": "", + "citationsContent": [], }, { "content": [{"toolUse": {"toolUseId": "123", "name": "test", "input": {}}}], "current_tool_use": {}, "text": "", "reasoningText": "", + "citationsContent": [], }, ), # Text @@ -193,12 +197,31 @@ def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_up "current_tool_use": {}, "text": "test", "reasoningText": "", + "citationsContent": [], }, { "content": [{"text": "test"}], "current_tool_use": {}, "text": "", "reasoningText": "", + "citationsContent": [], + }, + ), + # Citations + ( + { + "content": [], + "current_tool_use": {}, + "text": "", + "reasoningText": "", + "citationsContent": [{"citations": [{"text": "test", "source": "test"}]}], + }, + { + "content": [], + "current_tool_use": {}, + "text": "", + "reasoningText": "", + "citationsContent": [{"citations": [{"text": "test", "source": "test"}]}], }, ), # Reasoning @@ -209,6 +232,7 @@ def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_up "text": "", "reasoningText": "test", "signature": "123", + "citationsContent": [], }, { "content": [{"reasoningContent": {"reasoningText": {"text": "test", "signature": "123"}}}], @@ -216,6 +240,7 @@ def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_up "text": "", "reasoningText": "", "signature": "123", + "citationsContent": [], }, ), # Reasoning without signature @@ -225,12 +250,14 @@ def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_up "current_tool_use": {}, "text": "", "reasoningText": "test", + "citationsContent": [], }, { "content": [{"reasoningContent": {"reasoningText": {"text": "test"}}}], "current_tool_use": {}, "text": "", "reasoningText": "", + "citationsContent": [], }, ), # Empty @@ -240,12 +267,14 @@ def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_up "current_tool_use": {}, "text": "", "reasoningText": "", + "citationsContent": [], }, { "content": [], "current_tool_use": {}, "text": "", "reasoningText": "", + "citationsContent": [], }, ), ], diff --git a/tests_integ/conftest.py b/tests_integ/conftest.py index 61c2bf9a1..26453e1f7 100644 --- a/tests_integ/conftest.py +++ b/tests_integ/conftest.py @@ -22,6 +22,13 @@ def yellow_img(pytestconfig): return fp.read() +@pytest.fixture +def letter_pdf(pytestconfig): + path = pytestconfig.rootdir / "tests_integ/letter.pdf" + with open(path, "rb") as fp: + return fp.read() + + ## Async diff --git a/tests_integ/letter.pdf b/tests_integ/letter.pdf new file mode 100644 index 000000000..d8c59f749 Binary files /dev/null and b/tests_integ/letter.pdf differ diff --git a/tests_integ/models/test_model_bedrock.py b/tests_integ/models/test_model_bedrock.py index bd40938c9..00107411a 100644 --- a/tests_integ/models/test_model_bedrock.py +++ b/tests_integ/models/test_model_bedrock.py @@ -4,6 +4,7 @@ import strands from strands import Agent from strands.models import BedrockModel +from strands.types.content import ContentBlock @pytest.fixture @@ -27,12 +28,20 @@ def non_streaming_model(): @pytest.fixture def streaming_agent(streaming_model, system_prompt): - return Agent(model=streaming_model, system_prompt=system_prompt, load_tools_from_directory=False) + return Agent( + model=streaming_model, + system_prompt=system_prompt, + load_tools_from_directory=False, + ) @pytest.fixture def non_streaming_agent(non_streaming_model, system_prompt): - return Agent(model=non_streaming_model, system_prompt=system_prompt, load_tools_from_directory=False) + return Agent( + model=non_streaming_model, + system_prompt=system_prompt, + load_tools_from_directory=False, + ) @pytest.fixture @@ -184,6 +193,42 @@ def test_invoke_multi_modal_input(streaming_agent, yellow_img): assert "yellow" in text +def test_document_citations(non_streaming_agent, letter_pdf): + content: list[ContentBlock] = [ + { + "document": { + "name": "letter to shareholders", + "source": {"bytes": letter_pdf}, + "citations": {"enabled": True}, + "context": "This is a letter to shareholders", + "format": "pdf", + }, + }, + {"text": "What does the document say about artificial intelligence? Use citations to back up your answer."}, + ] + non_streaming_agent(content) + + assert any("citationsContent" in content for content in non_streaming_agent.messages[-1]["content"]) + + +def test_document_citations_streaming(streaming_agent, letter_pdf): + content: list[ContentBlock] = [ + { + "document": { + "name": "letter to shareholders", + "source": {"bytes": letter_pdf}, + "citations": {"enabled": True}, + "context": "This is a letter to shareholders", + "format": "pdf", + }, + }, + {"text": "What does the document say about artificial intelligence? Use citations to back up your answer."}, + ] + streaming_agent(content) + + assert any("citationsContent" in content for content in streaming_agent.messages[-1]["content"]) + + def test_structured_output_multi_modal_input(streaming_agent, yellow_img, yellow_color): content = [ {"text": "Is this image red, blue, or yellow?"}, diff --git a/tests_integ/test_max_tokens_reached.py b/tests_integ/test_max_tokens_reached.py index bf5668349..66c5fe9ad 100644 --- a/tests_integ/test_max_tokens_reached.py +++ b/tests_integ/test_max_tokens_reached.py @@ -2,8 +2,8 @@ import pytest -from src.strands.agent import AgentResult from strands import Agent, tool +from strands.agent import AgentResult from strands.models.bedrock import BedrockModel from strands.types.exceptions import MaxTokensReachedException