|
28 | 28 |
|
29 | 29 | logger = logging.getLogger(__name__)
|
30 | 30 |
|
31 |
| -DEFAULT_BEDROCK_MODEL_ID = "us.anthropic.claude-sonnet-4-20250514-v1:0" |
32 | 31 | DEFAULT_BEDROCK_REGION = "us-west-2"
|
| 32 | +DEFAULT_BEDROCK_MODEL_ID = "us.anthropic.claude-sonnet-4-20250514-v1:0" |
33 | 33 |
|
34 | 34 | BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES = [
|
35 | 35 | "Input is too long for requested model",
|
@@ -133,8 +133,13 @@ def __init__(
|
133 | 133 |
|
134 | 134 | session = boto_session or boto3.Session()
|
135 | 135 | resolved_region = region_name or session.region_name or os.environ.get("AWS_REGION") or DEFAULT_BEDROCK_REGION
|
136 |
| - self.config = BedrockModel.BedrockConfig(model_id=self._get_default_model_for_region(resolved_region)) |
137 | 136 |
|
| 137 | + # get default model id based on resolved region |
| 138 | + resolved_model_id = self._get_default_model_for_region(resolved_region) |
| 139 | + if resolved_model_id == "": |
| 140 | + raise ValueError("default model {} is not available in {} region. Specify another model".format(DEFAULT_BEDROCK_MODEL_ID, resolved_region)) |
| 141 | + |
| 142 | + self.config = BedrockModel.BedrockConfig(model_id=resolved_model_id) |
138 | 143 | self.update_config(**model_config)
|
139 | 144 |
|
140 | 145 | logger.debug("config=<%s> | initializing", self.config)
|
@@ -352,18 +357,15 @@ def _generate_redaction_events(self) -> list[StreamEvent]:
|
352 | 357 | return events
|
353 | 358 |
|
def _get_default_model_for_region(self, region: str) -> str:
    """Resolve the inference-profile id for the default model in ``region``.

    Lists the region's Bedrock inference profiles and returns the first
    profile id that serves the default model.

    Args:
        region: AWS region name used to construct the Bedrock client.

    Returns:
        The matching inference-profile id, or ``""`` when the default model
        is not available in the region (callers treat the empty string as
        "not found" and raise with a descriptive message).
    """
    client = boto3.client("bedrock", region_name=region)

    # DEFAULT_BEDROCK_MODEL_ID is hard-prefixed with the "us." geography,
    # but other geographies publish the same model under "eu."/"apac."
    # profile ids, so a raw substring test would never match there.  Match
    # on the geography-agnostic remainder of the id instead (falling back
    # to the full id if it carries no prefix).
    base_model_id = DEFAULT_BEDROCK_MODEL_ID.partition(".")[2] or DEFAULT_BEDROCK_MODEL_ID

    # list_inference_profiles is paginated; walk every page before
    # concluding the model is unavailable in this region.
    next_token = None
    while True:
        kwargs = {"nextToken": next_token} if next_token else {}
        response = client.list_inference_profiles(**kwargs)
        for profile in response["inferenceProfileSummaries"]:
            if base_model_id in profile["inferenceProfileId"]:
                return profile["inferenceProfileId"]
        next_token = response.get("nextToken")
        if not next_token:
            return ""
367 | 369 |
|
368 | 370 | @override
|
369 | 371 | async def stream(
|
|
0 commit comments