|
29 | 29 | logger = logging.getLogger(__name__)
|
30 | 30 |
|
31 | 31 | DEFAULT_BEDROCK_REGION = "us-west-2"
|
32 |
| -DEFAULT_BEDROCK_MODEL_ID = "us.anthropic.claude-sonnet-4-20250514-v1:0" |
| 32 | +DEFAULT_BEDROCK_MODEL_ID = "anthropic.claude-sonnet-4-20250514-v1:0" |
33 | 33 |
|
34 | 34 | BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES = [
|
35 | 35 | "Input is too long for requested model",
|
@@ -137,7 +137,11 @@ def __init__(
|
137 | 137 | # get default model id based on resolved region
|
138 | 138 | resolved_model_id = self._get_default_model_for_region(resolved_region)
|
139 | 139 | if resolved_model_id == "":
|
140 |
| - raise ValueError("default model {} is not available in {} region. Specify another model".format(DEFAULT_BEDROCK_MODEL_ID, resolved_region)) |
| 140 | + raise ValueError( |
| 141 | + "default model {} is not available in {} region. Specify another model".format( |
| 142 | + DEFAULT_BEDROCK_MODEL_ID, resolved_region |
| 143 | + ) |
| 144 | + ) |
141 | 145 |
|
142 | 146 | self.config = BedrockModel.BedrockConfig(model_id=resolved_model_id)
|
143 | 147 | self.update_config(**model_config)
|
@@ -357,15 +361,18 @@ def _generate_redaction_events(self) -> list[StreamEvent]:
|
357 | 361 | return events
|
358 | 362 |
|
359 | 363 | def _get_default_model_for_region(self, region: str) -> str:
|
360 |
| - client = boto3.client("bedrock", region_name=region) |
361 |
| - response = client.list_inference_profiles() |
362 |
| - inferenceProfileSummary = response["inferenceProfileSummaries"] |
363 |
| - |
364 |
| - for profile in inferenceProfileSummary: |
365 |
| - if DEFAULT_BEDROCK_MODEL_ID in profile["inferenceProfileId"]: |
366 |
| - return profile["inferenceProfileId"] |
367 |
| - |
368 |
| - return "" |
| 364 | + try: |
| 365 | + client = boto3.client("bedrock", region_name=region) |
| 366 | + response = client.list_inference_profiles() |
| 367 | + inference_profile_summary = response["inferenceProfileSummaries"] |
| 368 | + |
| 369 | + for profile in inference_profile_summary: |
| 370 | + if DEFAULT_BEDROCK_MODEL_ID in profile["inferenceProfileId"]: |
| 371 | + return str(profile["inferenceProfileId"]) |
| 372 | + |
| 373 | + return "" |
| 374 | + except ClientError as e: |
| 375 | + raise e |
369 | 376 |
|
370 | 377 | @override
|
371 | 378 | async def stream(
|
|
0 commit comments