Skip to content

Commit 365b67b

Browse files
Add BedrockConverseModel.count_tokens so it works with UsageLimits.count_tokens_before_request (pydantic#3367)
Co-authored-by: Douwe Maan <[email protected]>
1 parent 41336ac commit 365b67b

File tree

9 files changed (+287 −24 lines changed)

9 files changed (+287 −24 lines changed)

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -103,6 +103,13 @@
103103
'bedrock:us.anthropic.claude-opus-4-20250514-v1:0',
104104
'bedrock:anthropic.claude-sonnet-4-20250514-v1:0',
105105
'bedrock:us.anthropic.claude-sonnet-4-20250514-v1:0',
106+
'bedrock:eu.anthropic.claude-sonnet-4-20250514-v1:0',
107+
'bedrock:anthropic.claude-sonnet-4-5-20250929-v1:0',
108+
'bedrock:us.anthropic.claude-sonnet-4-5-20250929-v1:0',
109+
'bedrock:eu.anthropic.claude-sonnet-4-5-20250929-v1:0',
110+
'bedrock:anthropic.claude-haiku-4-5-20251001-v1:0',
111+
'bedrock:us.anthropic.claude-haiku-4-5-20251001-v1:0',
112+
'bedrock:eu.anthropic.claude-haiku-4-5-20251001-v1:0',
106113
'bedrock:cohere.command-text-v14',
107114
'bedrock:cohere.command-r-v1:0',
108115
'bedrock:cohere.command-r-plus-v1:0',

pydantic_ai_slim/pydantic_ai/models/bedrock.py

Lines changed: 52 additions & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -22,7 +22,6 @@
2222
DocumentUrl,
2323
FinishReason,
2424
ImageUrl,
25-
ModelHTTPError,
2625
ModelMessage,
2726
ModelProfileSpec,
2827
ModelRequest,
@@ -41,7 +40,7 @@
4140
usage,
4241
)
4342
from pydantic_ai._run_context import RunContext
44-
from pydantic_ai.exceptions import UserError
43+
from pydantic_ai.exceptions import ModelHTTPError, UserError
4544
from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse, download_item
4645
from pydantic_ai.providers import Provider, infer_provider
4746
from pydantic_ai.providers.bedrock import BedrockModelProfile
@@ -61,6 +60,7 @@
6160
ConverseStreamMetadataEventTypeDef,
6261
ConverseStreamOutputTypeDef,
6362
ConverseStreamResponseTypeDef,
63+
CountTokensRequestTypeDef,
6464
DocumentBlockTypeDef,
6565
GuardrailConfigurationTypeDef,
6666
ImageBlockTypeDef,
@@ -77,7 +77,6 @@
7777
VideoBlockTypeDef,
7878
)
7979

80-
8180
LatestBedrockModelNames = Literal[
8281
'amazon.titan-tg1-large',
8382
'amazon.titan-text-lite-v1',
@@ -106,6 +105,13 @@
106105
'us.anthropic.claude-opus-4-20250514-v1:0',
107106
'anthropic.claude-sonnet-4-20250514-v1:0',
108107
'us.anthropic.claude-sonnet-4-20250514-v1:0',
108+
'eu.anthropic.claude-sonnet-4-20250514-v1:0',
109+
'anthropic.claude-sonnet-4-5-20250929-v1:0',
110+
'us.anthropic.claude-sonnet-4-5-20250929-v1:0',
111+
'eu.anthropic.claude-sonnet-4-5-20250929-v1:0',
112+
'anthropic.claude-haiku-4-5-20251001-v1:0',
113+
'us.anthropic.claude-haiku-4-5-20251001-v1:0',
114+
'eu.anthropic.claude-haiku-4-5-20251001-v1:0',
109115
'cohere.command-text-v14',
110116
'cohere.command-r-v1:0',
111117
'cohere.command-r-plus-v1:0',
@@ -136,7 +142,6 @@
136142
See [the Bedrock docs](https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html) for a full list.
137143
"""
138144

139-
140145
P = ParamSpec('P')
141146
T = typing.TypeVar('T')
142147

@@ -149,6 +154,13 @@
149154
'tool_use': 'tool_call',
150155
}
151156

157+
_AWS_BEDROCK_INFERENCE_GEO_PREFIXES: tuple[str, ...] = ('us.', 'eu.', 'apac.', 'jp.', 'au.', 'ca.')
158+
"""Geo prefixes for Bedrock inference profile IDs (e.g., 'eu.', 'us.').
159+
160+
Used to strip the geo prefix so we can pass a pure foundation model ID/ARN to CountTokens,
161+
which does not accept profile IDs. Extend if new geos appear (e.g., 'global.', 'us-gov.').
162+
"""
163+
152164

153165
class BedrockModelSettings(ModelSettings, total=False):
154166
"""Settings for Bedrock models.
@@ -275,6 +287,34 @@ async def request(
275287
model_response = await self._process_response(response)
276288
return model_response
277289

290+
async def count_tokens(
291+
self,
292+
messages: list[ModelMessage],
293+
model_settings: ModelSettings | None,
294+
model_request_parameters: ModelRequestParameters,
295+
) -> usage.RequestUsage:
296+
"""Count the number of tokens, works with limited models.
297+
298+
Check the actual supported models on <https://docs.aws.amazon.com/bedrock/latest/userguide/count-tokens.html>
299+
"""
300+
model_settings, model_request_parameters = self.prepare_request(model_settings, model_request_parameters)
301+
system_prompt, bedrock_messages = await self._map_messages(messages, model_request_parameters)
302+
params: CountTokensRequestTypeDef = {
303+
'modelId': self._remove_inference_geo_prefix(self.model_name),
304+
'input': {
305+
'converse': {
306+
'messages': bedrock_messages,
307+
'system': system_prompt,
308+
},
309+
},
310+
}
311+
try:
312+
response = await anyio.to_thread.run_sync(functools.partial(self.client.count_tokens, **params))
313+
except ClientError as e:
314+
status_code = e.response.get('ResponseMetadata', {}).get('HTTPStatusCode', 500)
315+
raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.response) from e
316+
return usage.RequestUsage(input_tokens=response['inputTokens'])
317+
278318
@asynccontextmanager
279319
async def request_stream(
280320
self,
@@ -642,6 +682,14 @@ def _map_tool_call(t: ToolCallPart) -> ContentBlockOutputTypeDef:
642682
'toolUse': {'toolUseId': _utils.guard_tool_call_id(t=t), 'name': t.tool_name, 'input': t.args_as_dict()}
643683
}
644684

685+
@staticmethod
686+
def _remove_inference_geo_prefix(model_name: BedrockModelName) -> BedrockModelName:
687+
"""Remove inference geographic prefix from model ID if present."""
688+
for prefix in _AWS_BEDROCK_INFERENCE_GEO_PREFIXES:
689+
if model_name.startswith(prefix):
690+
return model_name.removeprefix(prefix)
691+
return model_name
692+
645693

646694
@dataclass
647695
class BedrockStreamedResponse(StreamedResponse):

pydantic_ai_slim/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -74,7 +74,7 @@ google = ["google-genai>=1.46.0"]
7474
anthropic = ["anthropic>=0.70.0"]
7575
groq = ["groq>=0.25.0"]
7676
mistral = ["mistralai>=1.9.10"]
77-
bedrock = ["boto3>=1.39.0"]
77+
bedrock = ["boto3>=1.40.14"]
7878
huggingface = ["huggingface-hub[inference]>=0.33.5"]
7979
outlines-transformers = ["outlines[transformers]>=1.0.0, <1.3.0; (sys_platform != 'darwin' or platform_machine != 'x86_64')", "transformers>=4.0.0", "pillow", "torch; (sys_platform != 'darwin' or platform_machine != 'x86_64')"]
8080
outlines-llamacpp = ["outlines[llamacpp]>=1.0.0, <1.3.0"]

tests/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -424,6 +424,7 @@ def bedrock_provider():
424424
region_name=os.getenv('AWS_REGION', 'us-east-1'),
425425
aws_access_key_id=os.getenv('AWS_ACCESS_KEY_ID', 'AKIA6666666666666666'),
426426
aws_secret_access_key=os.getenv('AWS_SECRET_ACCESS_KEY', '6666666666666666666666666666666666666666'),
427+
aws_session_token=os.getenv('AWS_SESSION_TOKEN', None),
427428
)
428429
yield BedrockProvider(bedrock_client=bedrock_client)
429430
bedrock_client.close()
Lines changed: 31 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -0,0 +1,31 @@
1+
interactions:
2+
- request:
3+
body: '{"input": {"converse": {"messages": [{"role": "user", "content": [{"text": "hello"}]}], "system": []}}}'
4+
headers:
5+
amz-sdk-invocation-id:
6+
- !!binary |
7+
ODdjZWFjMTYtN2U4OC00YTMzLTg5Y2QtZDUwNWM4N2YzNmNk
8+
amz-sdk-request:
9+
- !!binary |
10+
YXR0ZW1wdD0x
11+
content-length:
12+
- '103'
13+
content-type:
14+
- !!binary |
15+
YXBwbGljYXRpb24vanNvbg==
16+
method: POST
17+
uri: https://bedrock-runtime.us-east-1.amazonaws.com/model/does-not-exist-model-v1%3A0/count-tokens
18+
response:
19+
headers:
20+
connection:
21+
- keep-alive
22+
content-length:
23+
- '55'
24+
content-type:
25+
- application/json
26+
parsed_body:
27+
message: The provided model identifier is invalid.
28+
status:
29+
code: 400
30+
message: Bad Request
31+
version: 1
Lines changed: 32 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -0,0 +1,32 @@
1+
interactions:
2+
- request:
3+
body: '{"input": {"converse": {"messages": [{"role": "user", "content": [{"text": "The quick brown fox jumps over the
4+
lazydog."}]}], "system": []}}}'
5+
headers:
6+
amz-sdk-invocation-id:
7+
- !!binary |
8+
ZDYxNmVkOTktYzgwMi00MDE0LTljZGUtYWFjMjk5N2I2MDFj
9+
amz-sdk-request:
10+
- !!binary |
11+
YXR0ZW1wdD0x
12+
content-length:
13+
- '141'
14+
content-type:
15+
- !!binary |
16+
YXBwbGljYXRpb24vanNvbg==
17+
method: POST
18+
uri: https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-sonnet-4-20250514-v1%3A0/count-tokens
19+
response:
20+
headers:
21+
connection:
22+
- keep-alive
23+
content-length:
24+
- '18'
25+
content-type:
26+
- application/json
27+
parsed_body:
28+
inputTokens: 19
29+
status:
30+
code: 200
31+
message: OK
32+
version: 1
Lines changed: 78 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -0,0 +1,78 @@
1+
interactions:
2+
- request:
3+
body: '{"input": {"converse": {"messages": [{"role": "user", "content": [{"text": "The quick brown fox jumps over the
4+
lazydog."}]}], "system": []}}}'
5+
headers:
6+
amz-sdk-invocation-id:
7+
- !!binary |
8+
OWQ3NzFhZmItYTkwYi00N2E4LWFkNjMtZmI5OTJhZDEyN2E4
9+
amz-sdk-request:
10+
- !!binary |
11+
YXR0ZW1wdD0x
12+
content-length:
13+
- '141'
14+
content-type:
15+
- !!binary |
16+
YXBwbGljYXRpb24vanNvbg==
17+
method: POST
18+
uri: https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-sonnet-4-20250514-v1%3A0/count-tokens
19+
response:
20+
headers:
21+
connection:
22+
- keep-alive
23+
content-length:
24+
- '18'
25+
content-type:
26+
- application/json
27+
parsed_body:
28+
inputTokens: 19
29+
status:
30+
code: 200
31+
message: OK
32+
- request:
33+
body: '{"messages": [{"role": "user", "content": [{"text": "The quick brown fox jumps over the lazydog."}]}], "system":
34+
[], "inferenceConfig": {}}'
35+
headers:
36+
amz-sdk-invocation-id:
37+
- !!binary |
38+
MWMwNDdlYWEtOWIxMy00YjAyLWI3ZjMtMjZkNjQ2MDEzOTY2
39+
amz-sdk-request:
40+
- !!binary |
41+
YXR0ZW1wdD0x
42+
content-length:
43+
- '139'
44+
content-type:
45+
- !!binary |
46+
YXBwbGljYXRpb24vanNvbg==
47+
method: POST
48+
uri: https://bedrock-runtime.us-east-1.amazonaws.com/model/us.anthropic.claude-sonnet-4-20250514-v1%3A0/converse
49+
response:
50+
headers:
51+
connection:
52+
- keep-alive
53+
content-length:
54+
- '785'
55+
content-type:
56+
- application/json
57+
parsed_body:
58+
metrics:
59+
latencyMs: 2333
60+
output:
61+
message:
62+
content:
63+
- text: "I notice there's a small typo in your message - it should be \"lazy dog\" (two words) rather than \"lazydog.\"\n\nThe corrected version is: \"The quick brown fox jumps over the lazy dog.\"\n\nThis is a famous pangram - a sentence that contains every letter of the English alphabet at least once. It's commonly used for testing typewriters, keyboards, fonts, and other applications where you want to display all the letters.\n\nIs there something specific you'd like to know about this phrase, or were you perhaps testing something?"
64+
role: assistant
65+
stopReason: end_turn
66+
usage:
67+
cacheReadInputTokenCount: 0
68+
cacheReadInputTokens: 0
69+
cacheWriteInputTokenCount: 0
70+
cacheWriteInputTokens: 0
71+
inputTokens: 19
72+
outputTokens: 108
73+
serverToolUsage: {}
74+
totalTokens: 127
75+
status:
76+
code: 200
77+
message: OK
78+
version: 1

tests/models/test_bedrock.py

Lines changed: 69 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -14,7 +14,6 @@
1414
FunctionToolCallEvent,
1515
FunctionToolResultEvent,
1616
ImageUrl,
17-
ModelHTTPError,
1817
ModelRequest,
1918
ModelResponse,
2019
PartDeltaEvent,
@@ -33,12 +32,12 @@
3332
VideoUrl,
3433
)
3534
from pydantic_ai.agent import Agent
36-
from pydantic_ai.exceptions import ModelRetry
35+
from pydantic_ai.exceptions import ModelHTTPError, ModelRetry, UsageLimitExceeded
3736
from pydantic_ai.messages import AgentStreamEvent
3837
from pydantic_ai.models import ModelRequestParameters
3938
from pydantic_ai.run import AgentRunResult, AgentRunResultEvent
4039
from pydantic_ai.tools import ToolDefinition
41-
from pydantic_ai.usage import RequestUsage, RunUsage
40+
from pydantic_ai.usage import RequestUsage, RunUsage, UsageLimits
4241

4342
from ..conftest import IsDatetime, IsInstance, IsStr, try_import
4443

@@ -98,6 +97,73 @@ async def test_bedrock_model(allow_model_requests: None, bedrock_provider: Bedro
9897
)
9998

10099

100+
async def test_bedrock_model_usage_limit_exceeded(
101+
allow_model_requests: None,
102+
bedrock_provider: BedrockProvider,
103+
):
104+
model = BedrockConverseModel('us.anthropic.claude-sonnet-4-20250514-v1:0', provider=bedrock_provider)
105+
agent = Agent(model=model)
106+
107+
with pytest.raises(
108+
UsageLimitExceeded,
109+
match='The next request would exceed the input_tokens_limit of 18 \\(input_tokens=19\\)',
110+
):
111+
await agent.run(
112+
'The quick brown fox jumps over the lazydog.',
113+
usage_limits=UsageLimits(input_tokens_limit=18, count_tokens_before_request=True),
114+
)
115+
116+
117+
async def test_bedrock_model_usage_limit_not_exceeded(
118+
allow_model_requests: None,
119+
bedrock_provider: BedrockProvider,
120+
):
121+
model = BedrockConverseModel('us.anthropic.claude-sonnet-4-20250514-v1:0', provider=bedrock_provider)
122+
agent = Agent(model=model)
123+
124+
result = await agent.run(
125+
'The quick brown fox jumps over the lazydog.',
126+
usage_limits=UsageLimits(input_tokens_limit=25, count_tokens_before_request=True),
127+
)
128+
129+
assert result.output == snapshot(
130+
'I notice there\'s a small typo in your message - it should be "lazy dog" (two words) rather than '
131+
'"lazydog."\n\nThe corrected version is: "The quick brown fox jumps over the lazy dog."\n\n'
132+
'This is a famous pangram - a sentence that contains every letter of the English alphabet at least once. '
133+
"It's commonly used for testing typewriters, keyboards, fonts, and other applications where you want to "
134+
"display all the letters.\n\nIs there something specific you'd like to know about this phrase, or were you "
135+
'perhaps testing something?'
136+
)
137+
138+
139+
@pytest.mark.vcr()
140+
async def test_bedrock_count_tokens_error(allow_model_requests: None, bedrock_provider: BedrockProvider):
141+
"""Test that errors convert to ModelHTTPError."""
142+
model_id = 'us.does-not-exist-model-v1:0'
143+
model = BedrockConverseModel(model_id, provider=bedrock_provider)
144+
agent = Agent(model)
145+
146+
with pytest.raises(ModelHTTPError) as exc_info:
147+
await agent.run('hello', usage_limits=UsageLimits(input_tokens_limit=20, count_tokens_before_request=True))
148+
149+
assert exc_info.value.status_code == 400
150+
assert exc_info.value.model_name == model_id
151+
assert exc_info.value.body.get('Error', {}).get('Message') == 'The provided model identifier is invalid.' # type: ignore[union-attr]
152+
153+
154+
@pytest.mark.parametrize(
155+
('model_name', 'expected'),
156+
[
157+
('us.anthropic.claude-sonnet-4-20250514-v1:0', 'anthropic.claude-sonnet-4-20250514-v1:0'),
158+
('eu.amazon.nova-micro-v1:0', 'amazon.nova-micro-v1:0'),
159+
('apac.meta.llama3-8b-instruct-v1:0', 'meta.llama3-8b-instruct-v1:0'),
160+
('anthropic.claude-3-7-sonnet-20250219-v1:0', 'anthropic.claude-3-7-sonnet-20250219-v1:0'),
161+
],
162+
)
163+
def test_remove_inference_geo_prefix(model_name: str, expected: str):
164+
assert BedrockConverseModel._remove_inference_geo_prefix(model_name) == expected # pyright: ignore[reportPrivateUsage]
165+
166+
101167
async def test_bedrock_model_structured_output(allow_model_requests: None, bedrock_provider: BedrockProvider):
102168
model = BedrockConverseModel('us.amazon.nova-micro-v1:0', provider=bedrock_provider)
103169
agent = Agent(model=model, system_prompt='You are a helpful chatbot.', retries=5)

0 commit comments

Comments (0)