Skip to content

Commit 2383b11

Browse files
committed
feat: 添加agent adk server中间件
(cherry picked from commit 76d6280602efd51249150aa5f45c440571bb9a0d)
1 parent 77b0633 commit 2383b11

File tree

3 files changed

+212
-0
lines changed

3 files changed

+212
-0
lines changed

agentkit/apps/agent_server_app/agent_server_app.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from fastapi import Request
2020
from fastapi import HTTPException
2121
from fastapi.responses import StreamingResponse
22+
from opentelemetry import trace
2223
from google.adk.agents.base_agent import BaseAgent
2324
from google.adk.artifacts.in_memory_artifact_service import (
2425
InMemoryArtifactService,
@@ -41,6 +42,8 @@
4142
from veadk.memory.short_term_memory import ShortTermMemory
4243

4344
from agentkit.apps.base_app import BaseAgentkitApp
45+
from agentkit.apps.agent_server_app.telemetry import telemetry
46+
from agentkit.apps.agent_server_app.middleware import AgentkitTelemetryHTTPMiddleware
4447

4548

4649
class AgentKitAgentLoader(BaseAgentLoader):
@@ -87,7 +90,13 @@ def __init__(
8790

8891
self.app = self.server.get_fast_api_app()
8992

93+
# Attach ASGI middleware for unified telemetry across all routes
94+
self.app.add_middleware(AgentkitTelemetryHTTPMiddleware)
95+
9096
async def _invoke_compat(request: Request):
97+
# Use current request span from middleware for telemetry
98+
span = trace.get_current_span()
99+
91100
# Extract headers (fallback keys supported)
92101
headers = request.headers
93102
user_id = (
@@ -126,6 +135,14 @@ async def _invoke_compat(request: Request):
126135
text = ""
127136
content = types.UserContent(parts=[types.Part(text=text or "")])
128137

138+
# trace request attributes on current span
139+
telemetry.trace_agent_server(
140+
func_name="_invoke_compat",
141+
span=span,
142+
headers=dict(headers),
143+
text=text or "",
144+
)
145+
129146
# Ensure session exists
130147
session = await self.server.session_service.get_session(
131148
app_name=app_name, user_id=user_id, session_id=session_id
@@ -154,8 +171,11 @@ async def event_generator():
154171
)
155172
+ "\n\n"
156173
)
174+
# finish span on successful end of stream handled by middleware
175+
pass
157176
except Exception as e:
158177
yield f'data: {{"error": "{str(e)}"}}\n\n'
178+
telemetry.trace_agent_server_finish(func_result="", exception=e)
159179

160180
return StreamingResponse(
161181
event_generator(),
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Callable
16+
17+
from opentelemetry import trace
18+
from opentelemetry import context as context_api
19+
20+
from agentkit.apps.agent_server_app.telemetry import telemetry
21+
22+
23+
trace_paths = ["/run_sse", "/run", "/invoke"]
24+
25+
26+
class AgentkitTelemetryHTTPMiddleware:
27+
def __init__(self, app: Callable):
28+
self.app = app
29+
30+
async def __call__(self, scope, receive, send):
31+
print(f"test: {scope}")
32+
if scope["type"] != "http":
33+
return await self.app(scope, receive, send)
34+
35+
method = scope.get("method", "")
36+
path = scope.get("path", "")
37+
headers_list = scope.get("headers", [])
38+
headers = {k.decode("latin-1"): v.decode("latin-1") for k, v in headers_list}
39+
40+
span = telemetry.tracer.start_span(name="agent_server_request")
41+
ctx = trace.set_span_in_context(span)
42+
context_api.attach(ctx)
43+
44+
user_id = headers.get("user_id") or headers.get("x-user-id") or ""
45+
session_id = headers.get("session_id") or headers.get("x-session-id") or ""
46+
headers_like = {"user_id": user_id, "session_id": session_id}
47+
48+
telemetry.trace_agent_server(
49+
func_name=f"{method} {path}",
50+
span=span,
51+
headers=headers_like,
52+
text="", # do not consume body in middleware
53+
)
54+
55+
async def send_wrapper(message):
56+
try:
57+
if message.get("type") == "http.response.body":
58+
more_body = message.get("more_body", False)
59+
if not more_body:
60+
telemetry.trace_agent_server_finish(
61+
func_result="", exception=None
62+
)
63+
elif message.get("type") == "http.response.start":
64+
# could record status code if needed
65+
pass
66+
finally:
67+
await send(message)
68+
69+
try:
70+
await self.app(scope, receive, send_wrapper)
71+
except Exception as e:
72+
telemetry.trace_agent_server_finish(func_result="", exception=e)
73+
raise
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import logging
16+
import time
17+
from typing import Optional
18+
19+
from opentelemetry import trace
20+
from opentelemetry.trace import get_tracer
21+
from opentelemetry.metrics import get_meter
22+
from opentelemetry.trace.span import Span
23+
24+
from agentkit.apps.utils import safe_serialize_to_json_string
25+
26+
_GEN_AI_CLIENT_OPERATION_DURATION_BUCKETS = [
27+
0.01,
28+
0.02,
29+
0.04,
30+
0.08,
31+
0.16,
32+
0.32,
33+
0.64,
34+
1.28,
35+
2.56,
36+
5.12,
37+
10.24,
38+
20.48,
39+
40.96,
40+
81.92,
41+
163.84,
42+
]
43+
44+
logger = logging.getLogger("agentkit." + __name__)
45+
46+
47+
class Telemetry:
48+
def __init__(self):
49+
self.tracer = get_tracer("agentkit.agent_server_app")
50+
self.meter = get_meter("agentkit.agent_server_app")
51+
self.latency_histogram = self.meter.create_histogram(
52+
name="agentkit_runtime_operation_latency",
53+
description="operation latency",
54+
unit="s",
55+
explicit_bucket_boundaries_advisory=_GEN_AI_CLIENT_OPERATION_DURATION_BUCKETS,
56+
)
57+
58+
def trace_agent_server(
59+
self,
60+
func_name: str,
61+
span: Span,
62+
headers: dict,
63+
text: str,
64+
) -> None:
65+
span.set_attribute(key="gen_ai.system", value="agentkit")
66+
span.set_attribute(key="gen_ai.func_name", value=func_name)
67+
68+
span.set_attribute(
69+
key="gen_ai.request.headers",
70+
value=safe_serialize_to_json_string(headers),
71+
)
72+
73+
session_id = headers.get("session_id") or headers.get("x-session-id") or ""
74+
if session_id:
75+
span.set_attribute(key="gen_ai.session.id", value=session_id)
76+
user_id = headers.get("user_id") or headers.get("x-user-id") or ""
77+
if user_id:
78+
span.set_attribute(key="gen_ai.user.id", value=user_id)
79+
80+
span.set_attribute(
81+
key="gen_ai.input", value=safe_serialize_to_json_string(text)
82+
)
83+
84+
span.set_attribute(key="gen_ai.span.kind", value="workflow")
85+
span.set_attribute(key="gen_ai.operation.name", value="invoke_agent")
86+
span.set_attribute(key="gen_ai.operation.type", value="agent_server")
87+
88+
def trace_agent_server_finish(
89+
self,
90+
func_result: str,
91+
exception: Optional[Exception],
92+
) -> None:
93+
span = trace.get_current_span()
94+
if span and span.is_recording():
95+
span.set_attribute(key="gen_ai.output", value=func_result)
96+
attributes = {
97+
"gen_ai_operation_name": "invoke_agent",
98+
"gen_ai_operation_type": "agent_server",
99+
}
100+
if exception:
101+
self.handle_exception(span, exception)
102+
attributes["error_type"] = exception.__class__.__name__
103+
104+
if hasattr(span, "start_time") and self.latency_histogram:
105+
duration = (time.time_ns() - span.start_time) / 1e9 # type: ignore
106+
self.latency_histogram.record(duration, attributes)
107+
span.end()
108+
109+
@staticmethod
110+
def handle_exception(span: trace.Span, exception: Exception) -> None:
111+
status = trace.Status(
112+
status_code=trace.StatusCode.ERROR,
113+
description=f"{type(exception).__name__}: {exception}",
114+
)
115+
span.set_status(status)
116+
span.record_exception(exception)
117+
118+
119+
telemetry = Telemetry()

0 commit comments

Comments
 (0)