
Commit a0dfb96

ronakrm and claude committed
Add Anthropic prompt caching support with CachePoint
This implementation adds prompt caching support for Anthropic models, allowing users to cache parts of prompts (system prompts, long context, tools) to reduce costs by ~90% for cached tokens.

Key changes:
- Add CachePoint class to mark cache boundaries in prompts
- Implement cache control in AnthropicModel using BetaCacheControlEphemeralParam
- Add cache metrics mapping (cache_creation_input_tokens → cache_write_tokens)
- Add comprehensive tests for CachePoint functionality
- Add a working example demonstrating prompt caching usage
- Add CachePoint filtering in OpenAI models for compatibility

The implementation is Anthropic-only (the Bedrock complexity from the original PR pydantic#2560 has been removed) for a cleaner, more maintainable solution.

Related to pydantic#2560

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
1 parent 41336ac commit a0dfb96
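
In brief, the new API works like this: a minimal sketch based on the example file in this commit (the questions and the short placeholder context are illustrative; real caching needs the context to exceed Anthropic's 1024-token minimum).

import asyncio

from pydantic_ai import Agent, CachePoint


async def main() -> None:
    agent = Agent('anthropic:claude-sonnet-4-5')
    result = await agent.run(
        [
            'very long, reusable context...',  # cached after the first request
            CachePoint(),  # everything before this marker is cached
            'A short question that changes per request',
        ]
    )
    print(result.output)
    print(result.usage())  # includes cache_write_tokens / cache_read_tokens


asyncio.run(main())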

File tree

6 files changed: +310 -3 lines changed

6 files changed

+310
-3
lines changed
pydantic_ai_examples/anthropic_prompt_caching.py

Lines changed: 152 additions & 0 deletions (new file)

#!/usr/bin/env python3
"""Example demonstrating Anthropic prompt caching.

This example shows how to use CachePoint to reduce costs by caching:
- Long system prompts
- Large context (like documentation)
- Tool definitions

Run with: uv run -m pydantic_ai_examples.anthropic_prompt_caching
"""

from pydantic_ai import Agent, CachePoint

# Sample long context to demonstrate caching
# Need at least 1024 tokens - repeating 10x to be safe
LONG_CONTEXT = (
    """
# Product Documentation

## Overview
Our API provides comprehensive data access with the following features:

### Authentication
All requests require a Bearer token in the Authorization header.
Rate limits: 1000 requests/hour for standard tier.

### Endpoints

#### GET /api/users
Returns a list of users with pagination support.
Parameters:
- page: Page number (default: 1)
- limit: Items per page (default: 20, max: 100)
- filter: Optional filter expression

#### GET /api/products
Returns product catalog with detailed specifications.
Parameters:
- category: Filter by category
- in_stock: Boolean, filter available items
- sort: Sort order (price_asc, price_desc, name)

#### POST /api/orders
Create a new order. Requires authentication.
Request body:
- user_id: Integer, required
- items: Array of {product_id, quantity}
- shipping_address: Object with address details

#### Error Handling
Standard HTTP status codes are used:
- 200: Success
- 400: Bad request
- 401: Unauthorized
- 404: Not found
- 500: Server error

## Best Practices
1. Always handle rate limiting with exponential backoff
2. Cache responses where appropriate
3. Use pagination for large datasets
4. Validate input before submission
5. Monitor API usage through dashboard

## Code Examples
See detailed examples in our GitHub repository.
"""
    * 10
)  # Repeat 10x to ensure we exceed Anthropic's minimum cache size (1024 tokens)


async def main() -> None:
    """Demonstrate prompt caching with Anthropic."""
    print('=== Anthropic Prompt Caching Demo ===\n')

    agent = Agent(
        'anthropic:claude-sonnet-4-5',
        system_prompt='You are a helpful API documentation assistant.',
    )

    # First request with cache point - this will write to cache
    print('First request (will cache context)...')
    result1 = await agent.run(
        [
            LONG_CONTEXT,
            CachePoint(),  # Everything before this will be cached
            'What authentication method does the API use?',
        ]
    )

    print(f'Response: {result1.output}\n')
    usage1 = result1.usage()
    print(f'Usage: {usage1}')
    if usage1.cache_write_tokens:
        print(
            f'  Cache write tokens: {usage1.cache_write_tokens} (tokens written to cache)'
        )
    print()

    # Second request with same cached context - should use cache
    print('Second request (should read from cache)...')
    result2 = await agent.run(
        [
            LONG_CONTEXT,
            CachePoint(),  # Same content, should hit cache
            'What are the available API endpoints?',
        ]
    )

    print(f'Response: {result2.output}\n')
    usage2 = result2.usage()
    print(f'Usage: {usage2}')
    if usage2.cache_read_tokens:
        print(
            f'  Cache read tokens: {usage2.cache_read_tokens} (tokens read from cache)'
        )
        print(
            f'  Cache savings: ~{usage2.cache_read_tokens * 0.9:.0f} token-equivalents (90% discount)'
        )
    print()

    # Third request with different question, same cache
    print('Third request (should also read from cache)...')
    result3 = await agent.run(
        [
            LONG_CONTEXT,
            CachePoint(),
            'How should I handle rate limiting?',
        ]
    )

    print(f'Response: {result3.output}\n')
    usage3 = result3.usage()
    print(f'Usage: {usage3}')
    if usage3.cache_read_tokens:
        print(f'  Cache read tokens: {usage3.cache_read_tokens}')
    print()

    print('=== Summary ===')
    total_usage = usage1 + usage2 + usage3
    print(f'Total input tokens: {total_usage.input_tokens}')
    print(f'Total cache write: {total_usage.cache_write_tokens}')
    print(f'Total cache read: {total_usage.cache_read_tokens}')
    if total_usage.cache_read_tokens:
        savings = total_usage.cache_read_tokens * 0.9
        print(f'Estimated savings: ~{savings:.0f} token-equivalents')


if __name__ == '__main__':
    import asyncio

    asyncio.run(main())

pydantic_ai_slim/pydantic_ai/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -42,6 +42,7 @@
     BinaryImage,
     BuiltinToolCallPart,
     BuiltinToolReturnPart,
+    CachePoint,
     DocumentFormat,
     DocumentMediaType,
     DocumentUrl,
@@ -141,6 +142,7 @@
     'BinaryContent',
     'BuiltinToolCallPart',
     'BuiltinToolReturnPart',
+    'CachePoint',
     'DocumentFormat',
     'DocumentMediaType',
     'DocumentUrl',

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 13 additions & 1 deletion
@@ -612,8 +612,20 @@ def __init__(
             raise ValueError('`BinaryImage` must be have a media type that starts with "image/"')  # pragma: no cover


+@dataclass
+class CachePoint:
+    """A cache point marker for prompt caching.
+
+    Can be inserted into UserPromptPart.content to mark cache boundaries.
+    Models that don't support caching will filter these out.
+    """
+
+    kind: Literal['cache-point'] = 'cache-point'
+    """Type identifier, this is available on all parts as a discriminator."""
+
+
 MultiModalContent = ImageUrl | AudioUrl | DocumentUrl | VideoUrl | BinaryContent
-UserContent: TypeAlias = str | MultiModalContent
+UserContent: TypeAlias = str | MultiModalContent | CachePoint


 @dataclass(repr=False)
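
Since CachePoint is a plain dataclass with a `kind` discriminator, downstream code can branch on it with an isinstance check. A minimal sketch (the content values are illustrative):

from pydantic_ai.messages import CachePoint, UserContent

content: list[UserContent] = ['shared context', CachePoint(), 'question']

for item in content:
    if isinstance(item, CachePoint):
        # Carries no payload beyond its discriminator; models that don't
        # support caching simply skip it (see the OpenAI mapping below).
        assert item.kind == 'cache-point'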

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 27 additions & 2 deletions
@@ -19,6 +19,7 @@
     BinaryContent,
     BuiltinToolCallPart,
     BuiltinToolReturnPart,
+    CachePoint,
     DocumentUrl,
     FilePart,
     FinishReason,
@@ -58,6 +59,7 @@
 from anthropic.types.beta import (
     BetaBase64PDFBlockParam,
     BetaBase64PDFSourceParam,
+    BetaCacheControlEphemeralParam,
     BetaCitationsDelta,
     BetaCodeExecutionTool20250522Param,
     BetaCodeExecutionToolResultBlock,
@@ -477,7 +479,10 @@ async def _map_message(  # noqa: C901
                 system_prompt_parts.append(request_part.content)
             elif isinstance(request_part, UserPromptPart):
                 async for content in self._map_user_prompt(request_part):
-                    user_content_params.append(content)
+                    if isinstance(content, CachePoint):
+                        self._add_cache_control_to_last_param(user_content_params)
+                    else:
+                        user_content_params.append(content)
             elif isinstance(request_part, ToolReturnPart):
                 tool_result_block_param = BetaToolResultBlockParam(
                     tool_use_id=_guard_tool_call_id(t=request_part),
@@ -639,10 +644,26 @@ async def _map_message(  # noqa: C901
         system_prompt = '\n\n'.join(system_prompt_parts)
         return system_prompt, anthropic_messages

+    @staticmethod
+    def _add_cache_control_to_last_param(params: list[BetaContentBlockParam]) -> None:
+        """Add cache control to the last content block param."""
+        if not params:
+            raise UserError(
+                'CachePoint cannot be the first content in a user message - there must be previous content to attach the CachePoint to.'
+            )
+
+        # Only certain types support cache_control
+        cacheable_types = {'text', 'tool_use', 'server_tool_use', 'image', 'tool_result'}
+        if params[-1]['type'] not in cacheable_types:
+            raise UserError(f'Cache control not supported for param type: {params[-1]["type"]}')
+
+        # Add cache_control to the last param
+        params[-1]['cache_control'] = BetaCacheControlEphemeralParam(type='ephemeral')
+
     @staticmethod
     async def _map_user_prompt(
         part: UserPromptPart,
-    ) -> AsyncGenerator[BetaContentBlockParam]:
+    ) -> AsyncGenerator[BetaContentBlockParam | CachePoint]:
         if isinstance(part.content, str):
             if part.content:  # Only yield non-empty text
                 yield BetaTextBlockParam(text=part.content, type='text')
@@ -651,6 +672,8 @@ async def _map_user_prompt(
                 if isinstance(item, str):
                     if item:  # Only yield non-empty text
                         yield BetaTextBlockParam(text=item, type='text')
+                elif isinstance(item, CachePoint):
+                    yield item
                 elif isinstance(item, BinaryContent):
                     if item.is_image:
                         yield BetaImageBlockParam(
@@ -717,6 +740,8 @@ def _map_usage(
         key: value for key, value in response_usage.model_dump().items() if isinstance(value, int)
     }

+    # Note: genai-prices already extracts cache_creation_input_tokens and cache_read_input_tokens
+    # from the Anthropic response and maps them to cache_write_tokens and cache_read_tokens
     return usage.RequestUsage.extract(
         dict(model=model, usage=details),
         provider=provider,
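
The effect of `_add_cache_control_to_last_param` is easiest to see on the raw block params: the CachePoint itself never reaches the API; the preceding block gains an ephemeral `cache_control` entry instead. A sketch, with plain dicts standing in for the Beta*Param typed dicts and illustrative content:

# User content: ['long context', CachePoint(), 'What changed?']
# After mapping, the blocks sent to Anthropic look roughly like:
params = [
    {'type': 'text', 'text': 'long context', 'cache_control': {'type': 'ephemeral'}},
    {'type': 'text', 'text': 'What changed?'},
]

On the response side, Anthropic reports cache_creation_input_tokens and cache_read_input_tokens; per the note above, genai-prices surfaces these as cache_write_tokens and cache_read_tokens on the usage object.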

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 7 additions & 0 deletions
@@ -26,6 +26,7 @@
     BinaryImage,
     BuiltinToolCallPart,
     BuiltinToolReturnPart,
+    CachePoint,
     DocumentUrl,
     FilePart,
     FinishReason,
@@ -860,6 +861,9 @@ async def _map_user_prompt(self, part: UserPromptPart) -> chat.ChatCompletionUse
                 )
             elif isinstance(item, VideoUrl):  # pragma: no cover
                 raise NotImplementedError('VideoUrl is not supported for OpenAI')
+            elif isinstance(item, CachePoint):
+                # OpenAI doesn't support prompt caching via CachePoint, so we filter it out
+                pass
             else:
                 assert_never(item)
         return chat.ChatCompletionUserMessageParam(role='user', content=content)
@@ -1673,6 +1677,9 @@ async def _map_user_prompt(part: UserPromptPart) -> responses.EasyInputMessagePa
                 )
             elif isinstance(item, VideoUrl):  # pragma: no cover
                 raise NotImplementedError('VideoUrl is not supported for OpenAI.')
+            elif isinstance(item, CachePoint):
+                # OpenAI doesn't support prompt caching via CachePoint, so we filter it out
+                pass
             else:
                 assert_never(item)
         return responses.EasyInputMessageParam(role='user', content=content)
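
One consequence of this filtering is portability: a prompt list containing CachePoint runs unchanged on OpenAI models, where the marker is simply dropped. A minimal sketch (the model name and prompt are illustrative):

import asyncio

from pydantic_ai import Agent, CachePoint


async def main() -> None:
    agent = Agent('openai:gpt-4o')  # any OpenAI model behaves the same here
    # The CachePoint is silently dropped by the OpenAI mappers, so the
    # same prompt list works when switching from an Anthropic model.
    result = await agent.run(['long shared context', CachePoint(), 'question'])
    print(result.output)


asyncio.run(main())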
