Skip to content

feat: claude citation support with BedrockModel #631

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/strands/agent/agent_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,4 @@ def __str__(self) -> str:
for item in content_array:
if isinstance(item, dict) and "text" in item:
result += item.get("text", "") + "\n"

return result
15 changes: 15 additions & 0 deletions src/strands/event_loop/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any, AsyncGenerator, AsyncIterable, Optional

from ..models.model import Model
from ..types.citations import CitationsContentBlock
from ..types.content import ContentBlock, Message, Messages
from ..types.streaming import (
ContentBlockDeltaEvent,
Expand Down Expand Up @@ -130,6 +131,13 @@ def handle_content_block_delta(
state["text"] += delta_content["text"]
callback_event["callback"] = {"data": delta_content["text"], "delta": delta_content}

elif "citation" in delta_content:
if "citationsContent" not in state:
state["citationsContent"] = []

state["citationsContent"].append(delta_content["citation"])
callback_event["callback"] = {"citation_metadata": delta_content["citation"], "delta": delta_content}

elif "reasoningContent" in delta_content:
if "text" in delta_content["reasoningContent"]:
if "reasoningText" not in state:
Expand Down Expand Up @@ -170,6 +178,7 @@ def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]:
current_tool_use = state["current_tool_use"]
text = state["text"]
reasoning_text = state["reasoningText"]
citations_content = state["citationsContent"]

if current_tool_use:
if "input" not in current_tool_use:
Expand All @@ -194,6 +203,10 @@ def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]:
elif text:
content.append({"text": text})
state["text"] = ""
if citations_content:
citations_block: CitationsContentBlock = {"citations": citations_content}
content.append({"citationsContent": citations_block})
state["citationsContent"] = []

elif reasoning_text:
content_block: ContentBlock = {
Expand Down Expand Up @@ -267,6 +280,8 @@ async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[d
"text": "",
"current_tool_use": {},
"reasoningText": "",
"signature": "",
"citationsContent": [],
}
state["content"] = state["message"]["content"]

Expand Down
28 changes: 25 additions & 3 deletions src/strands/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import json
import logging
import os
from typing import Any, AsyncGenerator, Callable, Iterable, Literal, Optional, Type, TypeVar, Union
from typing import Any, AsyncGenerator, Callable, Iterable, Literal, Optional, Type, TypeVar, Union, cast

import boto3
from botocore.config import Config as BotocoreConfig
Expand All @@ -18,7 +18,10 @@
from ..event_loop import streaming
from ..tools import convert_pydantic_to_tool_spec
from ..types.content import ContentBlock, Message, Messages
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
from ..types.exceptions import (
ContextWindowOverflowException,
ModelThrottledException,
)
from ..types.streaming import StreamEvent
from ..types.tools import ToolResult, ToolSpec
from .model import Model
Expand Down Expand Up @@ -510,7 +513,7 @@ def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Itera
yield {"messageStart": {"role": response["output"]["message"]["role"]}}

# Process content blocks
for content in response["output"]["message"]["content"]:
for content in cast(list[ContentBlock], response["output"]["message"]["content"]):
# Yield contentBlockStart event if needed
if "toolUse" in content:
yield {
Expand Down Expand Up @@ -553,6 +556,25 @@ def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Itera
}
}
}
elif "citationsContent" in content:
# For non-streaming citations, emit text and metadata deltas in sequence
# to match streaming behavior where they flow naturally
if "content" in content["citationsContent"]:
text_content = "".join([content["text"] for content in content["citationsContent"]["content"]])
yield {
"contentBlockDelta": {"delta": {"text": text_content}},
}

for citation in content["citationsContent"]["citations"]:
# Then emit citation metadata (for structure)
from ..types.streaming import CitationsDelta
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lets be consistent and just import this at the top


citation_metadata: CitationsDelta = {
"title": citation["title"],
"location": citation["location"],
"sourceContent": citation["sourceContent"],
}
yield {"contentBlockDelta": {"delta": {"citation": citation_metadata}}}

# Yield contentBlockStop event
yield {"contentBlockStop": {}}
Expand Down
152 changes: 152 additions & 0 deletions src/strands/types/citations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
"""Citation type definitions for the SDK.

These types are modeled after the Bedrock API.
"""

from typing import List, Union

from typing_extensions import TypedDict


class CitationsConfig(TypedDict):
"""Configuration for enabling citations on documents.

Attributes:
enabled: Whether citations are enabled for this document.
"""

enabled: bool


class DocumentCharLocation(TypedDict, total=False):
"""Specifies a character-level location within a document.

Provides precise positioning information for cited content using
start and end character indices.

Attributes:
documentIndex: The index of the document within the array of documents
provided in the request. Minimum value of 0.
start: The starting character position of the cited content within
the document. Minimum value of 0.
end: The ending character position of the cited content within
the document. Minimum value of 0.
"""

documentIndex: int
start: int
end: int


class DocumentChunkLocation(TypedDict, total=False):
"""Specifies a chunk-level location within a document.

Provides positioning information for cited content using logical
document segments or chunks.

Attributes:
documentIndex: The index of the document within the array of documents
provided in the request. Minimum value of 0.
start: The starting chunk identifier or index of the cited content
within the document. Minimum value of 0.
end: The ending chunk identifier or index of the cited content
within the document. Minimum value of 0.
"""

documentIndex: int
start: int
end: int


class DocumentPageLocation(TypedDict, total=False):
"""Specifies a page-level location within a document.

Provides positioning information for cited content using page numbers.

Attributes:
documentIndex: The index of the document within the array of documents
provided in the request. Minimum value of 0.
start: The starting page number of the cited content within
the document. Minimum value of 0.
end: The ending page number of the cited content within
the document. Minimum value of 0.
"""

documentIndex: int
start: int
end: int


# Union type for citation locations
CitationLocation = Union[DocumentCharLocation, DocumentChunkLocation, DocumentPageLocation]


class CitationSourceContent(TypedDict, total=False):
"""Contains the actual text content from a source document.

Contains the actual text content from a source document that is being
cited or referenced in the model's response.

Note:
This is a UNION type, so only one of the members can be specified.

Attributes:
text: The text content from the source document that is being cited.
"""

text: str


class CitationGeneratedContent(TypedDict, total=False):
"""Contains the generated text content that corresponds to a citation.

Contains the generated text content that corresponds to or is supported
by a citation from a source document.

Note:
This is a UNION type, so only one of the members can be specified.

Attributes:
text: The text content that was generated by the model and is
supported by the associated citation.
"""

text: str


class Citation(TypedDict, total=False):
"""Contains information about a citation that references a source document.

Citations provide traceability between the model's generated response
and the source documents that informed that response.

Attributes:
location: The precise location within the source document where the
cited content can be found, including character positions, page
numbers, or chunk identifiers.
sourceContent: The specific content from the source document that was
referenced or cited in the generated response.
title: The title or identifier of the source document being cited.
"""

location: CitationLocation
sourceContent: List[CitationSourceContent]
title: str


class CitationsContentBlock(TypedDict, total=False):
"""A content block containing generated text and associated citations.

This block type is returned when document citations are enabled, providing
traceability between the generated content and the source documents that
informed the response.

Attributes:
citations: An array of citations that reference the source documents
used to generate the associated content.
content: The generated content that is supported by the associated
citations.
"""

citations: List[Citation]
content: List[CitationGeneratedContent]
3 changes: 3 additions & 0 deletions src/strands/types/content.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from typing_extensions import TypedDict

from .citations import CitationsContentBlock
from .media import DocumentContent, ImageContent, VideoContent
from .tools import ToolResult, ToolUse

Expand Down Expand Up @@ -83,6 +84,7 @@ class ContentBlock(TypedDict, total=False):
toolResult: The result for a tool request that a model makes.
toolUse: Information about a tool use request from a model.
video: Video to include in the message.
citationsContent: Contains the citations for a document.
"""

cachePoint: CachePoint
Expand All @@ -94,6 +96,7 @@ class ContentBlock(TypedDict, total=False):
toolResult: ToolResult
toolUse: ToolUse
video: VideoContent
citationsContent: CitationsContentBlock


class SystemContentBlock(TypedDict, total=False):
Expand Down
8 changes: 6 additions & 2 deletions src/strands/types/media.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
- Bedrock docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Types_Amazon_Bedrock_Runtime.html
"""

from typing import Literal
from typing import Literal, Optional

from typing_extensions import TypedDict

from .citations import CitationsConfig

DocumentFormat = Literal["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"]
"""Supported document formats."""

Expand All @@ -23,7 +25,7 @@ class DocumentSource(TypedDict):
bytes: bytes


class DocumentContent(TypedDict):
class DocumentContent(TypedDict, total=False):
"""A document to include in a message.

Attributes:
Expand All @@ -35,6 +37,8 @@ class DocumentContent(TypedDict):
format: Literal["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"]
name: str
source: DocumentSource
citations: Optional[CitationsConfig]
context: Optional[str]


ImageFormat = Literal["png", "jpeg", "gif", "webp"]
Expand Down
37 changes: 37 additions & 0 deletions src/strands/types/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from typing_extensions import TypedDict

from .citations import CitationLocation
from .content import ContentBlockStart, Role
from .event_loop import Metrics, StopReason, Usage
from .guardrails import Trace
Expand Down Expand Up @@ -57,6 +58,41 @@ class ContentBlockDeltaToolUse(TypedDict):
input: str


class CitationSourceContentDelta(TypedDict, total=False):
"""Contains incremental updates to source content text during streaming.

Allows clients to build up the cited content progressively during
streaming responses.

Attributes:
text: An incremental update to the text content from the source
document that is being cited.
"""

text: str


class CitationsDelta(TypedDict, total=False):
"""Contains incremental updates to citation information during streaming.

This allows clients to build up citation data progressively as the
response is generated.

Attributes:
location: Specifies the precise location within a source document
where cited content can be found. This can include character-level
positions, page numbers, or document chunks depending on the
document type and indexing method.
sourceContent: The specific content from the source document that was
referenced or cited in the generated response.
title: The title or identifier of the source document being cited.
"""

location: CitationLocation
sourceContent: list[CitationSourceContentDelta]
title: str


class ReasoningContentBlockDelta(TypedDict, total=False):
"""Delta for reasoning content block in a streaming response.

Expand All @@ -83,6 +119,7 @@ class ContentBlockDelta(TypedDict, total=False):
reasoningContent: ReasoningContentBlockDelta
text: str
toolUse: ContentBlockDeltaToolUse
citation: CitationsDelta


class ContentBlockDeltaEvent(TypedDict, total=False):
Expand Down
Loading
Loading