|
22 | 22 | DocumentUrl, |
23 | 23 | FinishReason, |
24 | 24 | ImageUrl, |
25 | | - ModelHTTPError, |
26 | 25 | ModelMessage, |
27 | 26 | ModelProfileSpec, |
28 | 27 | ModelRequest, |
|
41 | 40 | usage, |
42 | 41 | ) |
43 | 42 | from pydantic_ai._run_context import RunContext |
44 | | -from pydantic_ai.exceptions import UserError |
| 43 | +from pydantic_ai.exceptions import ModelHTTPError, UserError |
45 | 44 | from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse, download_item |
46 | 45 | from pydantic_ai.providers import Provider, infer_provider |
47 | 46 | from pydantic_ai.providers.bedrock import BedrockModelProfile |
|
61 | 60 | ConverseStreamMetadataEventTypeDef, |
62 | 61 | ConverseStreamOutputTypeDef, |
63 | 62 | ConverseStreamResponseTypeDef, |
| 63 | + CountTokensRequestTypeDef, |
64 | 64 | DocumentBlockTypeDef, |
65 | 65 | GuardrailConfigurationTypeDef, |
66 | 66 | ImageBlockTypeDef, |
|
77 | 77 | VideoBlockTypeDef, |
78 | 78 | ) |
79 | 79 |
|
80 | | - |
81 | 80 | LatestBedrockModelNames = Literal[ |
82 | 81 | 'amazon.titan-tg1-large', |
83 | 82 | 'amazon.titan-text-lite-v1', |
|
106 | 105 | 'us.anthropic.claude-opus-4-20250514-v1:0', |
107 | 106 | 'anthropic.claude-sonnet-4-20250514-v1:0', |
108 | 107 | 'us.anthropic.claude-sonnet-4-20250514-v1:0', |
| 108 | + 'eu.anthropic.claude-sonnet-4-20250514-v1:0', |
| 109 | + 'anthropic.claude-sonnet-4-5-20250929-v1:0', |
| 110 | + 'us.anthropic.claude-sonnet-4-5-20250929-v1:0', |
| 111 | + 'eu.anthropic.claude-sonnet-4-5-20250929-v1:0', |
| 112 | + 'anthropic.claude-haiku-4-5-20251001-v1:0', |
| 113 | + 'us.anthropic.claude-haiku-4-5-20251001-v1:0', |
| 114 | + 'eu.anthropic.claude-haiku-4-5-20251001-v1:0', |
109 | 115 | 'cohere.command-text-v14', |
110 | 116 | 'cohere.command-r-v1:0', |
111 | 117 | 'cohere.command-r-plus-v1:0', |
|
136 | 142 | See [the Bedrock docs](https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html) for a full list. |
137 | 143 | """ |
138 | 144 |
|
139 | | - |
140 | 145 | P = ParamSpec('P') |
141 | 146 | T = typing.TypeVar('T') |
142 | 147 |
|
|
149 | 154 | 'tool_use': 'tool_call', |
150 | 155 | } |
151 | 156 |
|
_AWS_BEDROCK_INFERENCE_GEO_PREFIXES: tuple[str, ...] = (
    'us.',
    'eu.',
    'apac.',
    'jp.',
    'au.',
    'ca.',
    'global.',
    'us-gov.',
)
"""Geo prefixes for Bedrock inference profile IDs (e.g., 'eu.', 'us.', 'global.').

Used to strip the geo prefix so we can pass a pure foundation model ID/ARN to CountTokens,
which does not accept profile IDs. Every entry ends with a dot so that only a complete
leading segment of the model ID is matched. Extend this tuple if new geos appear.
"""
| 163 | + |
152 | 164 |
|
153 | 165 | class BedrockModelSettings(ModelSettings, total=False): |
154 | 166 | """Settings for Bedrock models. |
@@ -275,6 +287,34 @@ async def request( |
275 | 287 | model_response = await self._process_response(response) |
276 | 288 | return model_response |
277 | 289 |
|
async def count_tokens(
    self,
    messages: list[ModelMessage],
    model_settings: ModelSettings | None,
    model_request_parameters: ModelRequestParameters,
) -> usage.RequestUsage:
    """Count the input tokens for the given messages; only some models support this.

    Check the actual supported models on <https://docs.aws.amazon.com/bedrock/latest/userguide/count-tokens.html>

    Args:
        messages: The conversation to count tokens for.
        model_settings: Optional settings, passed through `prepare_request`.
        model_request_parameters: Request parameters, passed through `prepare_request`
            and then used when mapping the messages.

    Returns:
        A `RequestUsage` whose `input_tokens` is the `inputTokens` value of the response.

    Raises:
        ModelHTTPError: If the client raises a `ClientError`; the status code is read from
            the response metadata and falls back to 500 when it is absent.
    """
    _, prepared_parameters = self.prepare_request(model_settings, model_request_parameters)
    system_blocks, converse_messages = await self._map_messages(messages, prepared_parameters)

    # The model ID is sent without its inference geo prefix (see `_remove_inference_geo_prefix`).
    params: CountTokensRequestTypeDef = {
        'modelId': self._remove_inference_geo_prefix(self.model_name),
        'input': {
            'converse': {
                'messages': converse_messages,
                'system': system_blocks,
            },
        },
    }

    try:
        # The client call is synchronous, so it is run in a worker thread.
        response = await anyio.to_thread.run_sync(functools.partial(self.client.count_tokens, **params))
    except ClientError as e:
        metadata = e.response.get('ResponseMetadata', {})
        raise ModelHTTPError(
            status_code=metadata.get('HTTPStatusCode', 500),
            model_name=self.model_name,
            body=e.response,
        ) from e
    else:
        return usage.RequestUsage(input_tokens=response['inputTokens'])
| 317 | + |
278 | 318 | @asynccontextmanager |
279 | 319 | async def request_stream( |
280 | 320 | self, |
@@ -642,6 +682,14 @@ def _map_tool_call(t: ToolCallPart) -> ContentBlockOutputTypeDef: |
642 | 682 | 'toolUse': {'toolUseId': _utils.guard_tool_call_id(t=t), 'name': t.tool_name, 'input': t.args_as_dict()} |
643 | 683 | } |
644 | 684 |
|
@staticmethod
def _remove_inference_geo_prefix(model_name: BedrockModelName) -> BedrockModelName:
    """Return `model_name` without its leading inference geo prefix, if it has one.

    The first entry of `_AWS_BEDROCK_INFERENCE_GEO_PREFIXES` that the name starts with is
    stripped; a name matching none of them is returned unchanged.
    """
    matched_prefix = next(
        (geo for geo in _AWS_BEDROCK_INFERENCE_GEO_PREFIXES if model_name.startswith(geo)),
        None,
    )
    if matched_prefix is None:
        return model_name
    return model_name[len(matched_prefix) :]
| 692 | + |
645 | 693 |
|
646 | 694 | @dataclass |
647 | 695 | class BedrockStreamedResponse(StreamedResponse): |
|
0 commit comments