Skip to content

Commit 0b8898e

Browse files
authored
Merge branch 'strands-agents:main' into feature/vincilb/config-loader
2 parents 03221ce + cfcf93d commit 0b8898e

File tree

7 files changed

+103
-75
lines changed

7 files changed

+103
-75
lines changed

pyproject.toml

Lines changed: 7 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -70,6 +70,8 @@ docs = [
7070
]
7171
litellm = [
7272
"litellm>=1.73.1,<2.0.0",
73+
# https://github.com/BerriAI/litellm/issues/13711
74+
"openai<1.100.0",
7375
]
7476
llamaapi = [
7577
"llama-api-client>=0.1.0,<1.0.0",
@@ -93,7 +95,9 @@ writer = [
9395
sagemaker = [
9496
"boto3>=1.26.0,<2.0.0",
9597
"botocore>=1.29.0,<2.0.0",
96-
"boto3-stubs[sagemaker-runtime]>=1.26.0,<2.0.0"
98+
"boto3-stubs[sagemaker-runtime]>=1.26.0,<2.0.0",
99+
# uses OpenAI as part of the implementation
100+
"openai>=1.68.0,<2.0.0",
97101
]
98102

99103
a2a = [
@@ -105,50 +109,7 @@ a2a = [
105109
"starlette>=0.46.2,<1.0.0",
106110
]
107111
all = [
108-
# anthropic
109-
"anthropic>=0.21.0,<1.0.0",
110-
111-
# dev
112-
"commitizen>=4.4.0,<5.0.0",
113-
"hatch>=1.0.0,<2.0.0",
114-
"moto>=5.1.0,<6.0.0",
115-
"mypy>=1.15.0,<2.0.0",
116-
"pre-commit>=3.2.0,<4.2.0",
117-
"pytest>=8.0.0,<9.0.0",
118-
"pytest-asyncio>=0.26.0,<0.27.0",
119-
"pytest-cov>=4.1.0,<5.0.0",
120-
"pytest-xdist>=3.0.0,<4.0.0",
121-
"ruff>=0.4.4,<0.5.0",
122-
123-
# docs
124-
"sphinx>=5.0.0,<6.0.0",
125-
"sphinx-rtd-theme>=1.0.0,<2.0.0",
126-
"sphinx-autodoc-typehints>=1.12.0,<2.0.0",
127-
128-
# litellm
129-
"litellm>=1.72.6,<1.73.0",
130-
131-
# llama
132-
"llama-api-client>=0.1.0,<1.0.0",
133-
134-
# mistral
135-
"mistralai>=1.8.2",
136-
137-
# ollama
138-
"ollama>=0.4.8,<1.0.0",
139-
140-
# openai
141-
"openai>=1.68.0,<2.0.0",
142-
143-
# otel
144-
"opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0",
145-
146-
# a2a
147-
"a2a-sdk[sql]>=0.3.0,<0.4.0",
148-
"uvicorn>=0.34.2,<1.0.0",
149-
"httpx>=0.28.1,<1.0.0",
150-
"fastapi>=0.115.12,<1.0.0",
151-
"starlette>=0.46.2,<1.0.0",
112+
"strands-agents[a2a,anthropic,dev,docs,litellm,llamaapi,mistral,ollama,openai,otel]",
152113
]
153114

154115
[tool.hatch.version]
@@ -160,7 +121,7 @@ features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mis
160121
dependencies = [
161122
"mypy>=1.15.0,<2.0.0",
162123
"ruff>=0.11.6,<0.12.0",
163-
"strands-agents @ {root:uri}"
124+
"strands-agents @ {root:uri}",
164125
]
165126

166127
[tool.hatch.envs.hatch-static-analysis.scripts]

src/strands/event_loop/streaming.py

Lines changed: 13 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -40,10 +40,12 @@ def remove_blank_messages_content_text(messages: Messages) -> Messages:
4040
# only modify assistant messages
4141
if "role" in message and message["role"] != "assistant":
4242
continue
43-
4443
if "content" in message:
4544
content = message["content"]
4645
has_tool_use = any("toolUse" in item for item in content)
46+
if len(content) == 0:
47+
content.append({"text": "[blank text]"})
48+
continue
4749

4850
if has_tool_use:
4951
# Remove blank 'text' items for assistant messages
@@ -194,16 +196,18 @@ def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]:
194196
state["text"] = ""
195197

196198
elif reasoning_text:
197-
content.append(
198-
{
199-
"reasoningContent": {
200-
"reasoningText": {
201-
"text": state["reasoningText"],
202-
"signature": state["signature"],
203-
}
199+
content_block: ContentBlock = {
200+
"reasoningContent": {
201+
"reasoningText": {
202+
"text": state["reasoningText"],
204203
}
205204
}
206-
)
205+
}
206+
207+
if "signature" in state:
208+
content_block["reasoningContent"]["reasoningText"]["signature"] = state["signature"]
209+
210+
content.append(content_block)
207211
state["reasoningText"] = ""
208212

209213
return state
@@ -263,7 +267,6 @@ async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[d
263267
"text": "",
264268
"current_tool_use": {},
265269
"reasoningText": "",
266-
"signature": "",
267270
}
268271
state["content"] = state["message"]["content"]
269272

@@ -272,7 +275,6 @@ async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[d
272275

273276
async for chunk in chunks:
274277
yield {"callback": {"event": chunk}}
275-
276278
if "messageStart" in chunk:
277279
state["message"] = handle_message_start(chunk["messageStart"], state["message"])
278280
elif "contentBlockStart" in chunk:
@@ -312,7 +314,6 @@ async def stream_messages(
312314
logger.debug("model=<%s> | streaming messages", model)
313315

314316
messages = remove_blank_messages_content_text(messages)
315-
316317
chunks = model.stream(messages, tool_specs if tool_specs else None, system_prompt)
317318

318319
async for event in process_stream(chunks):

src/strands/telemetry/metrics.py

Lines changed: 39 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
1111

1212
from ..telemetry import metrics_constants as constants
1313
from ..types.content import Message
14-
from ..types.streaming import Metrics, Usage
14+
from ..types.event_loop import Metrics, Usage
1515
from ..types.tools import ToolUse
1616

1717
logger = logging.getLogger(__name__)
@@ -264,6 +264,21 @@ def update_usage(self, usage: Usage) -> None:
264264
self.accumulated_usage["outputTokens"] += usage["outputTokens"]
265265
self.accumulated_usage["totalTokens"] += usage["totalTokens"]
266266

267+
# Handle optional cached token metrics
268+
if "cacheReadInputTokens" in usage:
269+
cache_read_tokens = usage["cacheReadInputTokens"]
270+
self._metrics_client.event_loop_cache_read_input_tokens.record(cache_read_tokens)
271+
self.accumulated_usage["cacheReadInputTokens"] = (
272+
self.accumulated_usage.get("cacheReadInputTokens", 0) + cache_read_tokens
273+
)
274+
275+
if "cacheWriteInputTokens" in usage:
276+
cache_write_tokens = usage["cacheWriteInputTokens"]
277+
self._metrics_client.event_loop_cache_write_input_tokens.record(cache_write_tokens)
278+
self.accumulated_usage["cacheWriteInputTokens"] = (
279+
self.accumulated_usage.get("cacheWriteInputTokens", 0) + cache_write_tokens
280+
)
281+
267282
def update_metrics(self, metrics: Metrics) -> None:
268283
"""Update the accumulated performance metrics with new metrics data.
269284
@@ -325,11 +340,21 @@ def _metrics_summary_to_lines(event_loop_metrics: EventLoopMetrics, allowed_name
325340
f"├─ Cycles: total={summary['total_cycles']}, avg_time={summary['average_cycle_time']:.3f}s, "
326341
f"total_time={summary['total_duration']:.3f}s"
327342
)
328-
yield (
329-
f"├─ Tokens: in={summary['accumulated_usage']['inputTokens']}, "
330-
f"out={summary['accumulated_usage']['outputTokens']}, "
331-
f"total={summary['accumulated_usage']['totalTokens']}"
332-
)
343+
344+
# Build token display with optional cached tokens
345+
token_parts = [
346+
f"in={summary['accumulated_usage']['inputTokens']}",
347+
f"out={summary['accumulated_usage']['outputTokens']}",
348+
f"total={summary['accumulated_usage']['totalTokens']}",
349+
]
350+
351+
# Add cached token info if present
352+
if summary["accumulated_usage"].get("cacheReadInputTokens"):
353+
token_parts.append(f"cache_read_input_tokens={summary['accumulated_usage']['cacheReadInputTokens']}")
354+
if summary["accumulated_usage"].get("cacheWriteInputTokens"):
355+
token_parts.append(f"cache_write_input_tokens={summary['accumulated_usage']['cacheWriteInputTokens']}")
356+
357+
yield f"├─ Tokens: {', '.join(token_parts)}"
333358
yield f"├─ Bedrock Latency: {summary['accumulated_metrics']['latencyMs']}ms"
334359

335360
yield "├─ Tool Usage:"
@@ -421,6 +446,8 @@ class MetricsClient:
421446
event_loop_latency: Histogram
422447
event_loop_input_tokens: Histogram
423448
event_loop_output_tokens: Histogram
449+
event_loop_cache_read_input_tokens: Histogram
450+
event_loop_cache_write_input_tokens: Histogram
424451

425452
tool_call_count: Counter
426453
tool_success_count: Counter
@@ -474,3 +501,9 @@ def create_instruments(self) -> None:
474501
self.event_loop_output_tokens = self.meter.create_histogram(
475502
name=constants.STRANDS_EVENT_LOOP_OUTPUT_TOKENS, unit="token"
476503
)
504+
self.event_loop_cache_read_input_tokens = self.meter.create_histogram(
505+
name=constants.STRANDS_EVENT_LOOP_CACHE_READ_INPUT_TOKENS, unit="token"
506+
)
507+
self.event_loop_cache_write_input_tokens = self.meter.create_histogram(
508+
name=constants.STRANDS_EVENT_LOOP_CACHE_WRITE_INPUT_TOKENS, unit="token"
509+
)

src/strands/telemetry/metrics_constants.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -13,3 +13,5 @@
1313
STRANDS_EVENT_LOOP_CYCLE_DURATION = "strands.event_loop.cycle_duration"
1414
STRANDS_EVENT_LOOP_INPUT_TOKENS = "strands.event_loop.input.tokens"
1515
STRANDS_EVENT_LOOP_OUTPUT_TOKENS = "strands.event_loop.output.tokens"
16+
STRANDS_EVENT_LOOP_CACHE_READ_INPUT_TOKENS = "strands.event_loop.cache_read.input.tokens"
17+
STRANDS_EVENT_LOOP_CACHE_WRITE_INPUT_TOKENS = "strands.event_loop.cache_write.input.tokens"

src/strands/types/event_loop.py

Lines changed: 10 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -2,21 +2,25 @@
22

33
from typing import Literal
44

5-
from typing_extensions import TypedDict
5+
from typing_extensions import Required, TypedDict
66

77

8-
class Usage(TypedDict):
8+
class Usage(TypedDict, total=False):
99
"""Token usage information for model interactions.
1010
1111
Attributes:
12-
inputTokens: Number of tokens sent in the request to the model..
12+
inputTokens: Number of tokens sent in the request to the model.
1313
outputTokens: Number of tokens that the model generated for the request.
1414
totalTokens: Total number of tokens (input + output).
15+
cacheReadInputTokens: Number of tokens read from cache (optional).
16+
cacheWriteInputTokens: Number of tokens written to cache (optional).
1517
"""
1618

17-
inputTokens: int
18-
outputTokens: int
19-
totalTokens: int
19+
inputTokens: Required[int]
20+
outputTokens: Required[int]
21+
totalTokens: Required[int]
22+
cacheReadInputTokens: int
23+
cacheWriteInputTokens: int
2024

2125

2226
class Metrics(TypedDict):

tests/strands/event_loop/test_streaming.py

Lines changed: 29 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -26,13 +26,15 @@ def moto_autouse(moto_env, moto_mock_aws):
2626
{"role": "assistant", "content": [{"text": "a"}, {"text": " \n"}, {"toolUse": {}}]},
2727
{"role": "assistant", "content": [{"text": ""}, {"toolUse": {}}]},
2828
{"role": "assistant", "content": [{"text": "a"}, {"text": " \n"}]},
29+
{"role": "assistant", "content": []},
2930
{"role": "assistant"},
3031
{"role": "user", "content": [{"text": " \n"}]},
3132
],
3233
[
3334
{"role": "assistant", "content": [{"text": "a"}, {"toolUse": {}}]},
3435
{"role": "assistant", "content": [{"toolUse": {}}]},
3536
{"role": "assistant", "content": [{"text": "a"}, {"text": "[blank text]"}]},
37+
{"role": "assistant", "content": [{"text": "[blank text]"}]},
3638
{"role": "assistant"},
3739
{"role": "user", "content": [{"text": " \n"}]},
3840
],
@@ -216,6 +218,21 @@ def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_up
216218
"signature": "123",
217219
},
218220
),
221+
# Reasoning without signature
222+
(
223+
{
224+
"content": [],
225+
"current_tool_use": {},
226+
"text": "",
227+
"reasoningText": "test",
228+
},
229+
{
230+
"content": [{"reasoningContent": {"reasoningText": {"text": "test"}}}],
231+
"current_tool_use": {},
232+
"text": "",
233+
"reasoningText": "",
234+
},
235+
),
219236
# Empty
220237
(
221238
{
@@ -260,6 +277,18 @@ def test_extract_usage_metrics():
260277
assert tru_usage == exp_usage and tru_metrics == exp_metrics
261278

262279

280+
def test_extract_usage_metrics_with_cache_tokens():
281+
event = {
282+
"usage": {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0, "cacheReadInputTokens": 0},
283+
"metrics": {"latencyMs": 0},
284+
}
285+
286+
tru_usage, tru_metrics = strands.event_loop.streaming.extract_usage_metrics(event)
287+
exp_usage, exp_metrics = event["usage"], event["metrics"]
288+
289+
assert tru_usage == exp_usage and tru_metrics == exp_metrics
290+
291+
263292
@pytest.mark.parametrize(
264293
("response", "exp_events"),
265294
[

tests/strands/telemetry/test_metrics.py

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -90,6 +90,7 @@ def usage(request):
9090
"inputTokens": 1,
9191
"outputTokens": 2,
9292
"totalTokens": 3,
93+
"cacheWriteInputTokens": 2,
9394
}
9495
if hasattr(request, "param"):
9596
params.update(request.param)
@@ -315,17 +316,14 @@ def test_event_loop_metrics_update_usage(usage, event_loop_metrics, mock_get_met
315316
event_loop_metrics.update_usage(usage)
316317

317318
tru_usage = event_loop_metrics.accumulated_usage
318-
exp_usage = Usage(
319-
inputTokens=3,
320-
outputTokens=6,
321-
totalTokens=9,
322-
)
319+
exp_usage = Usage(inputTokens=3, outputTokens=6, totalTokens=9, cacheWriteInputTokens=6)
323320

324321
assert tru_usage == exp_usage
325322
mock_get_meter_provider.return_value.get_meter.assert_called()
326323
metrics_client = event_loop_metrics._metrics_client
327324
metrics_client.event_loop_input_tokens.record.assert_called()
328325
metrics_client.event_loop_output_tokens.record.assert_called()
326+
metrics_client.event_loop_cache_write_input_tokens.record.assert_called()
329327

330328

331329
def test_event_loop_metrics_update_metrics(metrics, event_loop_metrics, mock_get_meter_provider):

0 commit comments

Comments (0)