diff --git a/src/strands/telemetry/metrics.py b/src/strands/telemetry/metrics.py
index 332ab2ae3..8b2c9ddf4 100644
--- a/src/strands/telemetry/metrics.py
+++ b/src/strands/telemetry/metrics.py
@@ -11,7 +11,7 @@
 
 from ..telemetry import metrics_constants as constants
 from ..types.content import Message
-from ..types.streaming import Metrics, Usage
+from ..types.event_loop import Metrics, Usage
 from ..types.tools import ToolUse
 
 logger = logging.getLogger(__name__)
@@ -264,6 +264,21 @@ def update_usage(self, usage: Usage) -> None:
         self.accumulated_usage["outputTokens"] += usage["outputTokens"]
         self.accumulated_usage["totalTokens"] += usage["totalTokens"]
 
+        # Handle optional cached token metrics
+        if "cacheReadInputTokens" in usage:
+            cache_read_tokens = usage["cacheReadInputTokens"]
+            self._metrics_client.event_loop_cache_read_input_tokens.record(cache_read_tokens)
+            self.accumulated_usage["cacheReadInputTokens"] = (
+                self.accumulated_usage.get("cacheReadInputTokens", 0) + cache_read_tokens
+            )
+
+        if "cacheWriteInputTokens" in usage:
+            cache_write_tokens = usage["cacheWriteInputTokens"]
+            self._metrics_client.event_loop_cache_write_input_tokens.record(cache_write_tokens)
+            self.accumulated_usage["cacheWriteInputTokens"] = (
+                self.accumulated_usage.get("cacheWriteInputTokens", 0) + cache_write_tokens
+            )
+
     def update_metrics(self, metrics: Metrics) -> None:
         """Update the accumulated performance metrics with new metrics data.
 
@@ -325,11 +340,21 @@ def _metrics_summary_to_lines(event_loop_metrics: EventLoopMetrics, allowed_name
         f"├─ Cycles: total={summary['total_cycles']}, avg_time={summary['average_cycle_time']:.3f}s, "
         f"total_time={summary['total_duration']:.3f}s"
     )
-    yield (
-        f"├─ Tokens: in={summary['accumulated_usage']['inputTokens']}, "
-        f"out={summary['accumulated_usage']['outputTokens']}, "
-        f"total={summary['accumulated_usage']['totalTokens']}"
-    )
+
+    # Build token display with optional cached tokens
+    token_parts = [
+        f"in={summary['accumulated_usage']['inputTokens']}",
+        f"out={summary['accumulated_usage']['outputTokens']}",
+        f"total={summary['accumulated_usage']['totalTokens']}",
+    ]
+
+    # Add cached token info if present
+    if summary["accumulated_usage"].get("cacheReadInputTokens"):
+        token_parts.append(f"cache_read={summary['accumulated_usage']['cacheReadInputTokens']}")
+    if summary["accumulated_usage"].get("cacheWriteInputTokens"):
+        token_parts.append(f"cache_write={summary['accumulated_usage']['cacheWriteInputTokens']}")
+
+    yield f"├─ Tokens: {', '.join(token_parts)}"
     yield f"├─ Bedrock Latency: {summary['accumulated_metrics']['latencyMs']}ms"
     yield "├─ Tool Usage:"
 
@@ -421,6 +446,8 @@ class MetricsClient:
     event_loop_latency: Histogram
     event_loop_input_tokens: Histogram
     event_loop_output_tokens: Histogram
+    event_loop_cache_read_input_tokens: Histogram
+    event_loop_cache_write_input_tokens: Histogram
 
     tool_call_count: Counter
     tool_success_count: Counter
@@ -474,3 +501,9 @@ def create_instruments(self) -> None:
         self.event_loop_output_tokens = self.meter.create_histogram(
             name=constants.STRANDS_EVENT_LOOP_OUTPUT_TOKENS, unit="token"
         )
+        self.event_loop_cache_read_input_tokens = self.meter.create_histogram(
+            name=constants.STRANDS_EVENT_LOOP_CACHE_READ_INPUT_TOKENS, unit="token"
+        )
+        self.event_loop_cache_write_input_tokens = self.meter.create_histogram(
+            name=constants.STRANDS_EVENT_LOOP_CACHE_WRITE_INPUT_TOKENS, unit="token"
+        )
diff --git a/src/strands/telemetry/metrics_constants.py b/src/strands/telemetry/metrics_constants.py
index b622eebff..f8fac34da 100644
--- a/src/strands/telemetry/metrics_constants.py
+++ b/src/strands/telemetry/metrics_constants.py
@@ -13,3 +13,5 @@
 STRANDS_EVENT_LOOP_CYCLE_DURATION = "strands.event_loop.cycle_duration"
 STRANDS_EVENT_LOOP_INPUT_TOKENS = "strands.event_loop.input.tokens"
 STRANDS_EVENT_LOOP_OUTPUT_TOKENS = "strands.event_loop.output.tokens"
+STRANDS_EVENT_LOOP_CACHE_READ_INPUT_TOKENS = "strands.event_loop.cache_read.input.tokens"
+STRANDS_EVENT_LOOP_CACHE_WRITE_INPUT_TOKENS = "strands.event_loop.cache_write.input.tokens"
diff --git a/src/strands/types/event_loop.py b/src/strands/types/event_loop.py
index 7be33b6fd..2c240972b 100644
--- a/src/strands/types/event_loop.py
+++ b/src/strands/types/event_loop.py
@@ -2,21 +2,25 @@
 
 from typing import Literal
 
-from typing_extensions import TypedDict
+from typing_extensions import Required, TypedDict
 
 
-class Usage(TypedDict):
+class Usage(TypedDict, total=False):
     """Token usage information for model interactions.
 
     Attributes:
-        inputTokens: Number of tokens sent in the request to the model..
+        inputTokens: Number of tokens sent in the request to the model.
        outputTokens: Number of tokens that the model generated for the request.
        totalTokens: Total number of tokens (input + output).
+        cacheReadInputTokens: Number of tokens read from cache (optional).
+        cacheWriteInputTokens: Number of tokens written to cache (optional).
     """
 
-    inputTokens: int
-    outputTokens: int
-    totalTokens: int
+    inputTokens: Required[int]
+    outputTokens: Required[int]
+    totalTokens: Required[int]
+    cacheReadInputTokens: int
+    cacheWriteInputTokens: int
 
 
 class Metrics(TypedDict):
diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py
index 921fd91de..18da3e3ed 100644
--- a/tests/strands/event_loop/test_streaming.py
+++ b/tests/strands/event_loop/test_streaming.py
@@ -260,6 +260,18 @@ def test_extract_usage_metrics():
     assert tru_usage == exp_usage and tru_metrics == exp_metrics
 
 
+def test_extract_usage_metrics_with_cache_tokens():
+    event = {
+        "usage": {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0, "cacheReadInputTokens": 0},
+        "metrics": {"latencyMs": 0},
+    }
+
+    tru_usage, tru_metrics = strands.event_loop.streaming.extract_usage_metrics(event)
+    exp_usage, exp_metrics = event["usage"], event["metrics"]
+
+    assert tru_usage == exp_usage and tru_metrics == exp_metrics
+
+
 @pytest.mark.parametrize(
     ("response", "exp_events"),
     [
diff --git a/tests/strands/telemetry/test_metrics.py b/tests/strands/telemetry/test_metrics.py
index 215e1efde..89a612e72 100644
--- a/tests/strands/telemetry/test_metrics.py
+++ b/tests/strands/telemetry/test_metrics.py
@@ -90,6 +90,7 @@ def usage(request):
         "inputTokens": 1,
         "outputTokens": 2,
         "totalTokens": 3,
+        "cacheWriteInputTokens": 10,
     }
     if hasattr(request, "param"):
         params.update(request.param)
@@ -315,17 +316,14 @@ def test_event_loop_metrics_update_usage(usage, event_loop_metrics, mock_get_meter_provider):
         event_loop_metrics.update_usage(usage)
 
     tru_usage = event_loop_metrics.accumulated_usage
-    exp_usage = Usage(
-        inputTokens=3,
-        outputTokens=6,
-        totalTokens=9,
-    )
+    exp_usage = Usage(inputTokens=3, outputTokens=6, totalTokens=9, cacheWriteInputTokens=30)
 
     assert tru_usage == exp_usage
     mock_get_meter_provider.return_value.get_meter.assert_called()
     metrics_client = event_loop_metrics._metrics_client
     metrics_client.event_loop_input_tokens.record.assert_called()
     metrics_client.event_loop_output_tokens.record.assert_called()
+    metrics_client.event_loop_cache_write_input_tokens.record.assert_called()
 
 
 def test_event_loop_metrics_update_metrics(metrics, event_loop_metrics, mock_get_meter_provider):
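
Illustrative only, not part of the diff: a minimal sketch of how the optional cache-token fields behave after this change, assuming the strands package is installed with the Usage type laid out as above. The accumulation mirrors the defensive .get(..., 0) pattern used in update_usage; the variable names below are hypothetical.

from strands.types.event_loop import Usage

# The required fields alone still form a valid Usage, because the cache
# counters are optional keys on a total=False TypedDict.
base = Usage(inputTokens=100, outputTokens=50, totalTokens=150)

# A response that hit the prompt cache can also report the optional counters.
cached = Usage(
    inputTokens=100,
    outputTokens=50,
    totalTokens=150,
    cacheReadInputTokens=40,
    cacheWriteInputTokens=10,
)

# Accumulate the way update_usage does: always add the required fields, and
# only touch a cache counter when it is present, defaulting the running
# total to 0 so older payloads without cache data keep working.
accumulated = Usage(inputTokens=0, outputTokens=0, totalTokens=0)
for usage in (base, cached):
    accumulated["inputTokens"] += usage["inputTokens"]
    accumulated["outputTokens"] += usage["outputTokens"]
    accumulated["totalTokens"] += usage["totalTokens"]
    if "cacheReadInputTokens" in usage:
        accumulated["cacheReadInputTokens"] = (
            accumulated.get("cacheReadInputTokens", 0) + usage["cacheReadInputTokens"]
        )
    if "cacheWriteInputTokens" in usage:
        accumulated["cacheWriteInputTokens"] = (
            accumulated.get("cacheWriteInputTokens", 0) + usage["cacheWriteInputTokens"]
        )

print(accumulated)
# {'inputTokens': 200, 'outputTokens': 100, 'totalTokens': 300,
#  'cacheReadInputTokens': 40, 'cacheWriteInputTokens': 10}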