2
2
import multiprocessing
3
3
import socket
4
4
import time
5
+ import os
6
+ import signal
7
+ import atexit
8
+ import sys
9
+ import threading
10
+ import coverage
5
11
from typing import AsyncGenerator , Generator
6
12
from mcp .client .session import ClientSession
7
13
from mcp .client .sse import sse_client
@@ -32,6 +38,40 @@ def server_url(server_port: int) -> str:
32
38
33
39
34
40
def run_server (server_port : int ) -> None :
41
+ # Initialize coverage for subprocesses
42
+ cov = None
43
+ if "COVERAGE_PROCESS_START" in os .environ :
44
+ cov = coverage .Coverage (source = ["fastapi_mcp" ])
45
+ cov .start ()
46
+
47
+ # Create a function to save coverage data at exit
48
+ def cleanup ():
49
+ if cov :
50
+ cov .stop ()
51
+ cov .save ()
52
+
53
+ # Register multiple cleanup mechanisms to ensure coverage data is saved
54
+ atexit .register (cleanup )
55
+
56
+ # Setup signal handler for clean termination
57
+ def handle_signal (signum , frame ):
58
+ cleanup ()
59
+ sys .exit (0 )
60
+
61
+ signal .signal (signal .SIGTERM , handle_signal )
62
+
63
+ # Backup thread to ensure coverage is written if process is terminated abruptly
64
+ def periodic_save ():
65
+ while True :
66
+ time .sleep (1.0 )
67
+ if cov :
68
+ cov .save ()
69
+
70
+ save_thread = threading .Thread (target = periodic_save )
71
+ save_thread .daemon = True
72
+ save_thread .start ()
73
+
74
+ # Configure the server
35
75
fastapi = make_simple_fastapi_app ()
36
76
mcp = FastApiMCP (
37
77
fastapi ,
@@ -40,16 +80,26 @@ def run_server(server_port: int) -> None:
40
80
)
41
81
mcp .mount ()
42
82
83
+ # Start the server
43
84
server = uvicorn .Server (config = uvicorn .Config (app = fastapi , host = HOST , port = server_port , log_level = "error" ))
44
85
server .run ()
45
86
46
87
# Give server time to start
47
88
while not server .started :
48
89
time .sleep (0.5 )
49
90
91
+ # Ensure coverage is saved if exiting the normal way
92
+ if cov :
93
+ cov .stop ()
94
+ cov .save ()
95
+
50
96
51
97
@pytest .fixture ()
52
98
def server (server_port : int ) -> Generator [None , None , None ]:
99
+ # Ensure COVERAGE_PROCESS_START is set in the environment for subprocesses
100
+ coverage_rc = os .path .abspath (".coveragerc" )
101
+ os .environ ["COVERAGE_PROCESS_START" ] = coverage_rc
102
+
53
103
proc = multiprocessing .Process (target = run_server , kwargs = {"server_port" : server_port }, daemon = True )
54
104
proc .start ()
55
105
@@ -69,11 +119,18 @@ def server(server_port: int) -> Generator[None, None, None]:
69
119
70
120
yield
71
121
72
- # Signal the server to stop
73
- proc .kill ()
74
- proc .join (timeout = 2 )
122
+ # Signal the server to stop - added graceful shutdown before kill
123
+ try :
124
+ proc .terminate ()
125
+ proc .join (timeout = 2 )
126
+ except (OSError , AttributeError ):
127
+ pass
128
+
75
129
if proc .is_alive ():
76
- raise RuntimeError ("server process failed to terminate" )
130
+ proc .kill ()
131
+ proc .join (timeout = 2 )
132
+ if proc .is_alive ():
133
+ raise RuntimeError ("server process failed to terminate" )
77
134
78
135
79
136
@pytest .fixture ()
0 commit comments