generated from amazon-archives/__template_Apache-2.0
-
Notifications
You must be signed in to change notification settings - Fork 86
Expand file tree
/
Copy pathmodel_utils.py
More file actions
90 lines (68 loc) · 2.63 KB
/
model_utils.py
File metadata and controls
90 lines (68 loc) · 2.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
"""Utilities for loading model providers in strands."""
import importlib
import json
import os
import pathlib
from typing import Any
from botocore.config import Config
from strands.types.models import Model
# Default model configuration.
# This is the configuration load_config() returns when it is given an empty
# string or "{}". The two token limits can be overridden through the
# STRANDS_MAX_TOKENS and STRANDS_BUDGET_TOKENS environment variables. They are
# read once, when this module is imported, and must parse as integers.
DEFAULT_MODEL_CONFIG = dict(
    model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0",
    max_tokens=int(os.getenv("STRANDS_MAX_TOKENS", "64000")),
    # botocore client settings: 900-second read/connect timeouts and up to
    # 3 attempts using botocore's "adaptive" retry mode.
    boto_client_config=Config(
        read_timeout=900,
        connect_timeout=900,
        retries={"max_attempts": 3, "mode": "adaptive"},
    ),
    # Extra fields attached to each model request: "thinking" is enabled
    # with a token budget.
    additional_request_fields=dict(
        thinking=dict(
            type="enabled",
            budget_tokens=int(os.getenv("STRANDS_BUDGET_TOKENS", "2048")),
        ),
    ),
)
def load_path(name: str) -> pathlib.Path:
    """Resolve the source file of the model provider called *name*.

    Two directories are tried, in order, for a file named ``<name>.py``:

    1. ``.models`` under the current working directory (project-local
       overrides);
    2. ``../models`` relative to the directory containing this file (the
       built-in providers).

    Args:
        name: Name of the model provider (e.g., bedrock).

    Returns:
        The path of the first existing candidate file.

    Raises:
        ImportError: If neither directory contains ``<name>.py``.
    """
    filename = f"{name}.py"
    candidates = (
        pathlib.Path.cwd() / ".models" / filename,
        pathlib.Path(__file__).parent / ".." / "models" / filename,
    )
    for candidate in candidates:
        if candidate.exists():
            return candidate
    raise ImportError(f"model_provider=<{name}> | does not exist")
def load_config(config: str) -> dict[str, Any]:
    """Load model configuration from a JSON string or file.

    Args:
        config: A JSON string or path to a JSON file containing model configuration.
            If empty string or '{}', the default config is used.

    Returns:
        The parsed configuration. In the default case a fresh deep copy of
        ``DEFAULT_MODEL_CONFIG`` is returned, so callers may modify the result
        without altering the module-level defaults seen by later calls.
    """
    import copy

    if not config or config == "{}":
        # Hand out a copy: returning the module-level dict itself would let one
        # caller's mutation (including of nested dicts) leak into every later call.
        return copy.deepcopy(DEFAULT_MODEL_CONFIG)
    if config.endswith(".json"):
        # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale encoding.
        with open(config, encoding="utf-8") as fp:
            return json.load(fp)
    return json.loads(config)
def load_model(path: pathlib.Path, config: dict[str, Any]) -> Model:
    """Dynamically load and instantiate a model provider from a Python module.

    Executes the module at the specified path and calls its ``instance``
    function with the provided configuration unpacked as keyword arguments.
    The module is executed anew on every call and is not registered in
    ``sys.modules``.

    Args:
        path: Path to the Python module containing the model provider implementation.
        config: Configuration to pass to the model provider's instance function.

    Returns:
        An instantiated model provider.

    Raises:
        ImportError: If no import spec or loader can be created for ``path``.
            Exceptions raised while executing the module, or by its
            ``instance`` function, propagate unchanged.
    """
    # ``import importlib`` alone does not guarantee that the ``util`` submodule
    # is loaded and available as an attribute; import it explicitly.
    import importlib.util

    spec = importlib.util.spec_from_file_location(path.stem, str(path))
    if spec is None or spec.loader is None:
        # spec_from_file_location returns None when no loader handles the file's suffix.
        raise ImportError(f"model_provider=<{path.stem}> | cannot be loaded from {path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module.instance(**config)