Skip to content

Commit 0535c6a

Browse files
authored
fix(mcp): add mcp.server parent span wrapper for FastMCP tool calls (#3382)
1 parent 8287b30 commit 0535c6a

File tree

8 files changed

+1078
-133
lines changed

8 files changed

+1078
-133
lines changed

packages/opentelemetry-instrumentation-mcp/opentelemetry/instrumentation/mcp/fastmcp_instrumentation.py

Lines changed: 57 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -57,54 +57,68 @@ async def traced_method(wrapped, instance, args, kwargs):
5757
tool_arguments = args[1] if len(args) > 1 else {}
5858

5959
entity_name = tool_key if tool_key else "unknown_tool"
60-
span_name = f"{entity_name}.tool"
6160

62-
with self._tracer.start_as_current_span(span_name) as span:
63-
span.set_attribute(SpanAttributes.TRACELOOP_SPAN_KIND, TraceloopSpanKindValues.TOOL.value)
64-
span.set_attribute(SpanAttributes.TRACELOOP_ENTITY_NAME, entity_name)
61+
# Create parent server.mcp span
62+
with self._tracer.start_as_current_span("mcp.server") as mcp_span:
63+
mcp_span.set_attribute(SpanAttributes.TRACELOOP_SPAN_KIND, "server")
64+
mcp_span.set_attribute(SpanAttributes.TRACELOOP_ENTITY_NAME, "mcp.server")
6565

66-
if self._should_send_prompts():
67-
try:
68-
input_data = {
69-
"tool_name": entity_name,
70-
"arguments": tool_arguments
71-
}
72-
json_input = json.dumps(input_data, cls=self._get_json_encoder())
73-
truncated_input = self._truncate_json_if_needed(json_input)
74-
span.set_attribute(SpanAttributes.TRACELOOP_ENTITY_INPUT, truncated_input)
75-
except (TypeError, ValueError):
76-
pass # Skip input logging if serialization fails
77-
78-
try:
79-
result = await wrapped(*args, **kwargs)
80-
81-
# Add output in traceloop format
82-
if self._should_send_prompts() and result:
66+
# Create nested tool span
67+
span_name = f"{entity_name}.tool"
68+
with self._tracer.start_as_current_span(span_name) as tool_span:
69+
tool_span.set_attribute(SpanAttributes.TRACELOOP_SPAN_KIND, TraceloopSpanKindValues.TOOL.value)
70+
tool_span.set_attribute(SpanAttributes.TRACELOOP_ENTITY_NAME, entity_name)
71+
72+
if self._should_send_prompts():
8373
try:
84-
# Convert FastMCP Content objects to serializable format
85-
output_data = []
86-
for item in result:
87-
if hasattr(item, 'text'):
88-
output_data.append({"type": "text", "content": item.text})
89-
elif hasattr(item, '__dict__'):
90-
output_data.append(item.__dict__)
91-
else:
92-
output_data.append(str(item))
93-
94-
json_output = json.dumps(output_data, cls=self._get_json_encoder())
95-
truncated_output = self._truncate_json_if_needed(json_output)
96-
span.set_attribute(SpanAttributes.TRACELOOP_ENTITY_OUTPUT, truncated_output)
74+
input_data = {
75+
"tool_name": entity_name,
76+
"arguments": tool_arguments
77+
}
78+
json_input = json.dumps(input_data, cls=self._get_json_encoder())
79+
truncated_input = self._truncate_json_if_needed(json_input)
80+
tool_span.set_attribute(SpanAttributes.TRACELOOP_ENTITY_INPUT, truncated_input)
9781
except (TypeError, ValueError):
98-
pass # Skip output logging if serialization fails
82+
pass # Skip input logging if serialization fails
9983

100-
span.set_status(Status(StatusCode.OK))
101-
return result
102-
103-
except Exception as e:
104-
span.set_attribute(ERROR_TYPE, type(e).__name__)
105-
span.record_exception(e)
106-
span.set_status(Status(StatusCode.ERROR, str(e)))
107-
raise
84+
try:
85+
result = await wrapped(*args, **kwargs)
86+
87+
# Add output in traceloop format to tool span
88+
if self._should_send_prompts() and result:
89+
try:
90+
# Convert FastMCP Content objects to serializable format
91+
output_data = []
92+
for item in result:
93+
if hasattr(item, 'text'):
94+
output_data.append({"type": "text", "content": item.text})
95+
elif hasattr(item, '__dict__'):
96+
output_data.append(item.__dict__)
97+
else:
98+
output_data.append(str(item))
99+
100+
json_output = json.dumps(output_data, cls=self._get_json_encoder())
101+
truncated_output = self._truncate_json_if_needed(json_output)
102+
tool_span.set_attribute(SpanAttributes.TRACELOOP_ENTITY_OUTPUT, truncated_output)
103+
104+
# Also add response to MCP span
105+
mcp_span.set_attribute(SpanAttributes.MCP_RESPONSE_VALUE, truncated_output)
106+
except (TypeError, ValueError):
107+
pass # Skip output logging if serialization fails
108+
109+
tool_span.set_status(Status(StatusCode.OK))
110+
mcp_span.set_status(Status(StatusCode.OK))
111+
return result
112+
113+
except Exception as e:
114+
tool_span.set_attribute(ERROR_TYPE, type(e).__name__)
115+
tool_span.record_exception(e)
116+
tool_span.set_status(Status(StatusCode.ERROR, str(e)))
117+
118+
mcp_span.set_attribute(ERROR_TYPE, type(e).__name__)
119+
mcp_span.record_exception(e)
120+
mcp_span.set_status(Status(StatusCode.ERROR, str(e)))
121+
raise
108122

109123
return traced_method
110124

packages/opentelemetry-instrumentation-mcp/opentelemetry/instrumentation/mcp/instrumentation.py

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,16 @@
1717
from opentelemetry.semconv.attributes.error_attributes import ERROR_TYPE
1818

1919
from opentelemetry.instrumentation.mcp.version import __version__
20-
from opentelemetry.instrumentation.mcp.utils import dont_throw
20+
from opentelemetry.instrumentation.mcp.utils import dont_throw, Config
2121
from opentelemetry.instrumentation.mcp.fastmcp_instrumentation import FastMCPInstrumentor
2222

2323
_instruments = ("mcp >= 1.6.0",)
2424

2525

2626
class McpInstrumentor(BaseInstrumentor):
27-
def __init__(self):
27+
def __init__(self, exception_logger=None):
2828
super().__init__()
29+
Config.exception_logger = exception_logger
2930
self._fastmcp_instrumentor = FastMCPInstrumentor()
3031

3132
def instrumentation_dependencies(self) -> Collection[str]:
@@ -38,6 +39,20 @@ def _instrument(self, **kwargs):
3839
# Instrument FastMCP
3940
self._fastmcp_instrumentor.instrument(tracer)
4041

42+
# Instrument FastMCP Client to create a session-level span
43+
register_post_import_hook(
44+
lambda _: wrap_function_wrapper(
45+
"fastmcp.client", "Client.__aenter__", self._fastmcp_client_enter_wrapper(tracer)
46+
),
47+
"fastmcp.client",
48+
)
49+
register_post_import_hook(
50+
lambda _: wrap_function_wrapper(
51+
"fastmcp.client", "Client.__aexit__", self._fastmcp_client_exit_wrapper(tracer)
52+
),
53+
"fastmcp.client",
54+
)
55+
4156
register_post_import_hook(
4257
lambda _: wrap_function_wrapper(
4358
"mcp.client.sse", "sse_client", self._transport_wrapper(tracer)
@@ -181,6 +196,52 @@ async def traced_method(wrapped, instance, args, kwargs):
181196

182197
return traced_method
183198

199+
def _fastmcp_client_enter_wrapper(self, tracer):
200+
"""Wrapper for FastMCP Client.__aenter__ to start a session trace"""
201+
@dont_throw
202+
async def traced_method(wrapped, instance, args, kwargs):
203+
# Start a root span for the MCP client session and make it current
204+
span_context_manager = tracer.start_as_current_span("mcp.client.session")
205+
span = span_context_manager.__enter__()
206+
span.set_attribute(SpanAttributes.TRACELOOP_SPAN_KIND, "session")
207+
span.set_attribute(SpanAttributes.TRACELOOP_ENTITY_NAME, "mcp.client.session")
208+
209+
# Store the span context manager on the instance to properly exit it later
210+
setattr(instance, '_tracing_session_context_manager', span_context_manager)
211+
212+
try:
213+
# Call the original method
214+
result = await wrapped(*args, **kwargs)
215+
return result
216+
except Exception as e:
217+
span.set_attribute(ERROR_TYPE, type(e).__name__)
218+
span.record_exception(e)
219+
span.set_status(Status(StatusCode.ERROR, str(e)))
220+
raise
221+
return traced_method
222+
223+
def _fastmcp_client_exit_wrapper(self, tracer):
224+
"""Wrapper for FastMCP Client.__aexit__ to end the session trace"""
225+
@dont_throw
226+
async def traced_method(wrapped, instance, args, kwargs):
227+
try:
228+
# Call the original method first
229+
result = await wrapped(*args, **kwargs)
230+
231+
# End the session span context manager
232+
context_manager = getattr(instance, '_tracing_session_context_manager', None)
233+
if context_manager:
234+
context_manager.__exit__(None, None, None)
235+
236+
return result
237+
except Exception as e:
238+
# End the session span context manager with exception info
239+
context_manager = getattr(instance, '_tracing_session_context_manager', None)
240+
if context_manager:
241+
context_manager.__exit__(type(e), e, e.__traceback__)
242+
raise
243+
return traced_method
244+
184245
async def _handle_tool_call(self, tracer, method, params, args, kwargs, wrapped):
185246
"""Handle tools/call with tool semantics"""
186247
# Extract the actual tool name

packages/opentelemetry-instrumentation-mcp/poetry.lock

Lines changed: 36 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

packages/opentelemetry-instrumentation-mcp/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ pytest = "^8.2.2"
3939
pytest-sugar = "1.0.0"
4040
pytest-recording = "^0.13.1"
4141
opentelemetry-sdk = "^1.27.0"
42+
pytest-asyncio = "^1.2.0"
4243

4344
[build-system]
4445
requires = ["poetry-core"]
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
async def test_fastmcp_server_mcp_parent_span(span_exporter, tracer_provider) -> None:
2+
"""Test that FastMCP tool calls have mcp.server as parent span."""
3+
from fastmcp import FastMCP, Client
4+
5+
# Create a simple FastMCP server
6+
server = FastMCP("test-server")
7+
8+
@server.tool()
9+
async def test_tool(x: int) -> int:
10+
"""A simple test tool."""
11+
return x * 2
12+
13+
# Use in-memory client to connect to the server
14+
async with Client(server) as client:
15+
# Test tool calling
16+
result = await client.call_tool("test_tool", {"x": 5})
17+
assert len(result) == 1
18+
assert result[0].text == "10"
19+
20+
# Get the finished spans
21+
spans = span_exporter.get_finished_spans()
22+
23+
# Debug: Print span details with parent info
24+
print(f"\nTotal spans: {len(spans)}")
25+
for i, span in enumerate(spans):
26+
parent_id = span.parent.span_id if span.parent else "None"
27+
print(f"Span {i}: name='{span.name}', span_id={span.get_span_context().span_id}, "
28+
f"parent_id={parent_id}, trace_id={span.get_span_context().trace_id}")
29+
30+
# Look specifically for mcp.server and tool spans
31+
server_mcp_spans = [span for span in spans if span.name == 'mcp.server']
32+
tool_spans = [span for span in spans if span.name.endswith('.tool')]
33+
34+
print(f"\nMCP Server spans: {len(server_mcp_spans)}")
35+
print(f"Tool spans: {len(tool_spans)}")
36+
37+
# Check if we have the expected spans
38+
assert len(server_mcp_spans) >= 1, f"Expected at least 1 mcp.server span, found {len(server_mcp_spans)}"
39+
assert len(tool_spans) >= 1, f"Expected at least 1 tool span, found {len(tool_spans)}"
40+
41+
# Find server-side spans (should be in same trace)
42+
server_side_spans = []
43+
for server_span in server_mcp_spans:
44+
for tool_span in tool_spans:
45+
if (server_span.get_span_context().trace_id == tool_span.get_span_context().trace_id and
46+
tool_span.parent and
47+
tool_span.parent.span_id == server_span.get_span_context().span_id):
48+
server_side_spans.append((server_span, tool_span))
49+
break
50+
51+
print(f"\nFound {len(server_side_spans)} server-side span pairs")
52+
53+
# Verify we found at least one proper parent-child relationship
54+
assert len(server_side_spans) >= 1, "Expected at least one mcp.server span to be parent of a tool span"
55+
56+
# Check the specific parent-child relationship
57+
server_span, tool_span = server_side_spans[0]
58+
assert tool_span.parent.span_id == server_span.get_span_context().span_id, \
59+
"Tool span should be child of mcp.server span"
60+
assert server_span.get_span_context().trace_id == tool_span.get_span_context().trace_id, \
61+
"Parent and child should be in same trace"
62+
63+
# Verify MCP server span attributes
64+
assert server_span.attributes.get('traceloop.span.kind') == 'server', \
65+
"Server span should have server span kind"
66+
assert server_span.attributes.get('traceloop.entity.name') == 'mcp.server', \
67+
"Server span should have mcp.server entity name"
68+
69+
# Verify tool span attributes
70+
assert tool_span.attributes.get('traceloop.span.kind') == 'tool', \
71+
"Tool span should have tool span kind"
72+
assert tool_span.attributes.get('traceloop.entity.name') == 'test_tool', \
73+
"Tool span should have correct entity name"

0 commit comments

Comments
 (0)