Commit de62faa

PSP-245 MCP/Passthrough Route Forwarder (#706)
* PSP-245 MCP Route Forwarder: adds a forwarder for /mcp routes, letting us forward MCP requests to containers running a service that accepts MCP-compatible requests. https://linear.app/scale-epd/issue/PSP-245/mcp-model-engine-integration
* Renaming to make Forwarder more generic
* Make Passthrough Routes an array
* Rename helper function
* Add ability to pass forwarder type
* Fix mock objects causing failing tests
* Add sync forwarder
* Passing through destination path
* Fix typo in config
* Add back status
* Sanitizing Headers
1 parent fe36be6 commit de62faa

File tree

13 files changed: +956 −12 lines changed


.ruff.toml

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
 # Same as Black.
 line-length = 100
-
-ignore = ["E501"]
+target-version = "py310"
+lint.ignore = ["E501"]
 exclude = ["gen", "alembic"]

charts/model-engine/Chart.yaml

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@ type: application
 # This is the chart version. This version number should be incremented each time you make changes
 # to the chart and its templates, including the app version.
 # Versions are expected to follow Semantic Versioning (https://semver.org/)
-version: 0.1.9
+version: 0.1.10

 # This is the version number of the application being deployed. This version number should be
 # incremented each time you make changes to the application. Versions are not expected to

charts/model-engine/templates/service_template_config_map.yaml

Lines changed: 8 additions & 0 deletions
@@ -184,6 +184,10 @@ data:
 - "forwarder.sync.extra_routes=${FORWARDER_EXTRA_ROUTES}"
 - --set
 - "forwarder.stream.extra_routes=${FORWARDER_EXTRA_ROUTES}"
+- --set
+- "forwarder.sync.forwarder_type=${FORWARDER_TYPE}"
+- --set
+- "forwarder.stream.forwarder_type=${FORWARDER_TYPE}"
 {{- $sync_forwarder_template_env | nindent 14 }}
 readinessProbe:
   httpGet:
@@ -616,6 +620,10 @@ data:
 - "forwarder.sync.extra_routes=${FORWARDER_EXTRA_ROUTES}"
 - --set
 - "forwarder.stream.extra_routes=${FORWARDER_EXTRA_ROUTES}"
+- --set
+- "forwarder.sync.forwarder_type=${FORWARDER_TYPE}"
+- --set
+- "forwarder.stream.forwarder_type=${FORWARDER_TYPE}"
 {{- $sync_forwarder_template_env | nindent 16 }}
 readinessProbe:
   httpGet:
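Both hunks add the same pair of --set overrides (at different nindent levels), passing the endpoint's FORWARDER_TYPE down to the forwarder config alongside the existing extra_routes override. As a rough illustration of how dotted overrides of this form map onto a nested config, here is a small Python sketch; apply_set_overrides is a hypothetical helper written for this note, not the config loader the forwarder actually uses.

# Minimal sketch: apply "a.b.c=value" style overrides to a nested config dict.
# apply_set_overrides is a hypothetical helper, not part of model-engine.
from typing import Any, Dict, List


def apply_set_overrides(config: Dict[str, Any], overrides: List[str]) -> Dict[str, Any]:
    for override in overrides:
        dotted_key, _, value = override.partition("=")
        keys = dotted_key.split(".")
        node = config
        for key in keys[:-1]:
            node = node.setdefault(key, {})
        node[keys[-1]] = value  # values stay strings in this sketch
    return config


config = {"forwarder": {"sync": {}, "stream": {}}}
apply_set_overrides(
    config,
    [
        "forwarder.sync.forwarder_type=passthrough",
        "forwarder.stream.forwarder_type=passthrough",
    ],
)
# config["forwarder"]["stream"]["forwarder_type"] == "passthrough"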

model-engine/model_engine_server/domain/entities/model_bundle_entity.py

Lines changed: 6 additions & 0 deletions
@@ -31,6 +31,11 @@ class ModelBundleFrameworkType(str, Enum):
     CUSTOM = "custom_base_image"


+class ForwarderType(str, Enum):
+    PASSTHROUGH = "passthrough"
+    DEFAULT = "default"
+
+
 class ModelBundleEnvironmentParams(BaseModel):
     """
     This is the entity-layer class for the Model Bundle environment parameters. Being an
@@ -158,6 +163,7 @@ class RunnableImageLike(BaseModel, ABC):
     protocol: Literal["http"]  # TODO: add support for other protocols (e.g. grpc)
     readiness_initial_delay_seconds: int = 120
     extra_routes: List[str] = Field(default_factory=list)
+    forwarder_type: Optional[ForwarderType] = ForwarderType.DEFAULT
     worker_command: Optional[List[str]] = None
     worker_env: Optional[Dict[str, str]] = None
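A self-contained sketch of how the new field behaves on a Pydantic model; BundleFlavorSketch below is a simplified stand-in for illustration, not the real RunnableImageLike (which carries many more fields):

# Simplified stand-in for RunnableImageLike, showing the ForwarderType default.
from enum import Enum
from typing import List, Optional

from pydantic import BaseModel, Field


class ForwarderType(str, Enum):
    PASSTHROUGH = "passthrough"
    DEFAULT = "default"


class BundleFlavorSketch(BaseModel):  # hypothetical, for illustration only
    extra_routes: List[str] = Field(default_factory=list)
    forwarder_type: Optional[ForwarderType] = ForwarderType.DEFAULT


flavor = BundleFlavorSketch(extra_routes=["/mcp"], forwarder_type="passthrough")
assert flavor.forwarder_type is ForwarderType.PASSTHROUGH  # str values coerce to the enum
assert BundleFlavorSketch().forwarder_type is ForwarderType.DEFAULT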

model-engine/model_engine_server/inference/configs/service--http_forwarder.yaml

Lines changed: 1 addition & 0 deletions
@@ -19,4 +19,5 @@ forwarder:
 model_engine_unwrap: true
 serialize_results_as_string: false
 extra_routes: []
+
 max_concurrency: 100
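For orientation, a sketch (as a Python dict) of the shape the forwarder config takes when an endpoint opts into passthrough forwarding; the key names mirror what the passthrough loader reads, but the values shown here are illustrative assumptions, not defaults shipped in this file:

# Illustrative only: roughly what a parsed "stream" section might look like
# for an endpoint that opts into passthrough forwarding of an /mcp route.
stream_config = {
    "user_port": 5005,                 # assumed port
    "user_hostname": "localhost",
    "healthcheck_route": "/readyz",    # assumed route
    "forwarder_type": "passthrough",   # assumed opt-in value
    "extra_routes": ["/mcp"],          # assumed passthrough route
}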

model-engine/model_engine_server/inference/forwarding/forwarding.py

Lines changed: 97 additions & 0 deletions
@@ -627,6 +627,103 @@ def endpoint(route: str) -> str:
        )


+@dataclass
+class PassthroughForwarder(ModelEngineSerializationMixin):
+    passthrough_endpoint: str
+
+    async def _make_request(
+        self, request: Any, aioclient: aiohttp.ClientSession
+    ) -> aiohttp.ClientResponse:
+        headers: dict[str, str] = dict(request.headers)
+        excluded_headers: set[str] = {
+            "host",
+            "content-length",
+            "connection",
+        }
+        headers = {k: v for k, v in headers.items() if k.lower() not in excluded_headers}
+        url = request.url
+        target_url: str = f"{self.passthrough_endpoint.rstrip('/')}"
+
+        if url.query:
+            target_url = f"{target_url}?{url.query}"
+
+        return await aioclient.request(
+            method=request.method,
+            url=target_url,
+            data=await request.body() if request.method in ["POST", "PUT", "PATCH"] else None,
+            headers=headers,
+        )
+
+    async def forward_stream(self, request: Any):
+        async with aiohttp.ClientSession() as aioclient:
+            response = await self._make_request(request, aioclient)
+            response_headers = response.headers
+            yield (response_headers, response.status)
+
+            if response.status != 200:
+                yield await response.read()
+
+            async for chunk in response.content.iter_chunks():
+                yield chunk[0]
+
+            yield await response.read()
+
+    async def forward_sync(self, request: Any):
+        async with aiohttp.ClientSession() as aioclient:
+            response = await self._make_request(request, aioclient)
+            return response
+
+
+@dataclass(frozen=True)
+class LoadPassthroughForwarder:
+    user_port: int = DEFAULT_PORT
+    user_hostname: str = "localhost"
+    healthcheck_route: str = "/health"
+    passthrough_route: str = ""
+
+    def load(self, resources: Optional[Path], cache: Any) -> PassthroughForwarder:
+        if len(self.healthcheck_route) == 0:
+            raise ValueError("healthcheck route must be non-empty!")
+
+        if not self.healthcheck_route.startswith("/"):
+            raise ValueError(f"healthcheck route must start with /: {self.healthcheck_route=}")
+
+        if not (1 <= self.user_port <= 65535):
+            raise ValueError(f"Invalid port value: {self.user_port=}")
+
+        if len(self.user_hostname) == 0:
+            raise ValueError("hostname must be non-empty!")
+
+        if self.user_hostname != "localhost":
+            raise NotImplementedError(
+                "Currently only localhost-based user-code services are supported with forwarders! "
+                f"Cannot handle {self.user_hostname=}"
+            )
+
+        def endpoint(route: str) -> str:
+            return f"http://{self.user_hostname}:{self.user_port}{route}"
+
+        passthrough_endpoint: str = endpoint(self.passthrough_route)
+        hc: str = endpoint(self.healthcheck_route)
+
+        logger.info(f"Forwarding to user-defined service at: {self.user_hostname}:{self.user_port}")
+        logger.info(f"Passthrough endpoint: {passthrough_endpoint}")
+        logger.info(f"Healthcheck endpoint: {hc}")
+
+        while True:
+            try:
+                if requests.get(hc).status_code == 200:
+                    break
+            except requests.exceptions.ConnectionError:
+                pass
+
+            logger.info(f"Waiting for user-defined service to be ready at {hc}...")
+            time.sleep(1)
+
+        logger.info(f"Creating PassthroughForwarder with endpoint: {passthrough_endpoint}")
+        return PassthroughForwarder(passthrough_endpoint=passthrough_endpoint)
+
+
 def load_named_config(config_uri, config_overrides=None):
     with open(config_uri, "rt") as rt:
         if config_uri.endswith(".json"):
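A standalone sketch of the request translation that PassthroughForwarder._make_request performs: drop the host/content-length/connection headers and append the caller's query string to the configured passthrough endpoint. build_passthrough_request below is an illustrative rewrite for this note, not the method itself.

# Standalone illustration of the header filtering and URL building done in
# PassthroughForwarder._make_request. Not the module's own code.
from typing import Dict, Tuple

EXCLUDED_HEADERS = {"host", "content-length", "connection"}


def build_passthrough_request(
    passthrough_endpoint: str, query: str, headers: Dict[str, str]
) -> Tuple[str, Dict[str, str]]:
    clean = {k: v for k, v in headers.items() if k.lower() not in EXCLUDED_HEADERS}
    target_url = passthrough_endpoint.rstrip("/")
    if query:
        target_url = f"{target_url}?{query}"
    return target_url, clean


url, hdrs = build_passthrough_request(
    "http://localhost:5005/mcp/",
    "session_id=abc",
    {"Host": "model-engine", "Content-Length": "42", "Authorization": "Bearer token"},
)
# url == "http://localhost:5005/mcp?session_id=abc"; only Authorization survives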

model-engine/model_engine_server/inference/forwarding/http_forwarder.py

Lines changed: 160 additions & 2 deletions
@@ -7,14 +7,17 @@

 import orjson
 import uvicorn
-from fastapi import BackgroundTasks, Depends, FastAPI
+from fastapi import BackgroundTasks, Depends, FastAPI, Request
+from fastapi.responses import Response, StreamingResponse
 from model_engine_server.common.concurrency_limiter import MultiprocessingConcurrencyLimiter
 from model_engine_server.common.dtos.tasks import EndpointPredictV1Request
 from model_engine_server.core.loggers import logger_name, make_logger
 from model_engine_server.inference.forwarding.forwarding import (
     Forwarder,
     LoadForwarder,
+    LoadPassthroughForwarder,
     LoadStreamingForwarder,
+    PassthroughForwarder,
     StreamingForwarder,
     load_named_config,
 )
@@ -40,6 +43,8 @@ def get_forwarder_loader(destination_path: Optional[str] = None) -> LoadForwarde
         del config["extra_routes"]
     if destination_path:
         config["predict_route"] = destination_path
+    if "forwarder_type" in config:
+        del config["forwarder_type"]
     forwarder_loader = LoadForwarder(**config)
     return forwarder_loader

@@ -52,10 +57,40 @@
         del config["extra_routes"]
     if destination_path:
         config["predict_route"] = destination_path
+    if "forwarder_type" in config:
+        del config["forwarder_type"]
     streaming_forwarder_loader = LoadStreamingForwarder(**config)
     return streaming_forwarder_loader


+def get_stream_passthrough_forwarder_loader(
+    destination_path: Optional[str] = None,
+) -> LoadPassthroughForwarder:
+    config = {}
+    stream_config = get_config().get("stream", {})
+    for key in ["user_port", "user_hostname", "healthcheck_route"]:
+        config[key] = stream_config[key]
+    if destination_path:
+        config["passthrough_route"] = destination_path
+
+    passthrough_forwarder_loader = LoadPassthroughForwarder(**config)
+    return passthrough_forwarder_loader
+
+
+def get_sync_passthrough_forwarder_loader(
+    destination_path: Optional[str] = None,
+) -> LoadPassthroughForwarder:
+    config = {}
+    sync_config = get_config().get("sync", {})
+    for key in ["user_port", "user_hostname", "healthcheck_route"]:
+        config[key] = sync_config[key]
+    if destination_path:
+        config["passthrough_route"] = destination_path
+
+    passthrough_forwarder_loader = LoadPassthroughForwarder(**config)
+    return passthrough_forwarder_loader
+
+
 @lru_cache()
 def get_concurrency_limiter() -> MultiprocessingConcurrencyLimiter:
     config = get_config()
@@ -75,6 +110,41 @@ def load_streaming_forwarder(destination_path: Optional[str] = None) -> Streamin
     return get_streaming_forwarder_loader(destination_path).load(None, None)


+@lru_cache()
+def load_stream_passthrough_forwarder(
+    destination_path: Optional[str] = None,
+) -> PassthroughForwarder:
+    return get_stream_passthrough_forwarder_loader(destination_path).load(None, None)
+
+
+@lru_cache()
+def load_sync_passthrough_forwarder(destination_path: Optional[str] = None) -> PassthroughForwarder:
+    return get_sync_passthrough_forwarder_loader(destination_path).load(None, None)
+
+
+HOP_BY_HOP_HEADERS: list[str] = [
+    "proxy-authenticate",
+    "proxy-authorization",
+    "content-length",
+    "content-encoding",
+]
+
+
+def sanitize_response_headers(headers: dict, force_cache_bust: bool = False) -> dict:
+    lower_headers = {k.lower(): v for k, v in headers.items()}
+    # Delete hop by hop headers that should not be forwarded
+    for header in HOP_BY_HOP_HEADERS:
+        if header in lower_headers:
+            del lower_headers[header]
+
+    if force_cache_bust:
+        # force clients to refetch resources
+        lower_headers["cache-control"] = "no-store"
+        if "etag" in lower_headers:
+            del lower_headers["etag"]
+    return lower_headers
+
+
 async def predict(
     request: EndpointPredictV1Request,
     background_tasks: BackgroundTasks,
@@ -123,6 +193,35 @@ async def event_generator():
     return EventSourceResponse(event_generator())


+async def passthrough_stream(
+    request: Request,
+    forwarder: PassthroughForwarder = Depends(get_stream_passthrough_forwarder_loader),
+    limiter: MultiprocessingConcurrencyLimiter = Depends(get_concurrency_limiter),
+):
+    with limiter:
+        response = forwarder.forward_stream(request)
+        headers, status = await anext(response)
+        headers = sanitize_response_headers(headers)
+
+        async def content_generator():
+            async for chunk in response:
+                yield chunk
+
+        return StreamingResponse(content_generator(), headers=headers, status_code=status)
+
+
+async def passthrough_sync(
+    request: Request,
+    forwarder: PassthroughForwarder = Depends(get_sync_passthrough_forwarder_loader),
+    limiter: MultiprocessingConcurrencyLimiter = Depends(get_concurrency_limiter),
+):
+    with limiter:
+        response = await forwarder.forward_sync(request)
+        headers = sanitize_response_headers(response.headers)
+        content = await response.read()
+        return Response(content=content, status_code=response.status, headers=headers)
+
+
 async def serve_http(app: FastAPI, **uvicorn_kwargs: Any):  # pragma: no cover
     logger.info("Available routes are:")
     for route in app.routes:
@@ -177,7 +276,7 @@ async def init_app():
     def healthcheck():
         return "OK"

-    def add_extra_routes(app: FastAPI):
+    def add_extra_sync_or_stream_routes(app: FastAPI):
         """Read extra_routes from config and dynamically add routes to app"""
         config = get_config()
         sync_forwarders: Dict[str, Forwarder] = dict()
@@ -224,6 +323,65 @@ async def predict_or_stream(
             methods=["POST"],
         )

+    def add_stream_passthrough_routes(app: FastAPI):
+        config = get_config()
+
+        passthrough_forwarders: Dict[str, PassthroughForwarder] = dict()
+        for route in config.get("stream", {}).get("extra_routes", []):
+            passthrough_forwarders[route] = load_stream_passthrough_forwarder(route)
+
+        for route in passthrough_forwarders:
+
+            def get_passthrough_forwarder(route=route):
+                return passthrough_forwarders.get(route)
+
+            async def passthrough_route(
+                request: Request,
+                passthrough_forwarder: PassthroughForwarder = Depends(get_passthrough_forwarder),
+                limiter=Depends(get_concurrency_limiter),
+            ):
+                return await passthrough_stream(request, passthrough_forwarder, limiter)
+
+            app.add_api_route(
+                path=route,
+                endpoint=passthrough_route,
+                methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"],
+            )
+
+    def add_sync_passthrough_routes(app: FastAPI):
+        config = get_config()
+
+        passthrough_forwarders: Dict[str, PassthroughForwarder] = dict()
+        for route in config.get("sync", {}).get("extra_routes", []):
+            passthrough_forwarders[route] = load_sync_passthrough_forwarder(route)
+
+        for route in passthrough_forwarders:
+
+            def get_passthrough_forwarder(route=route):
+                return passthrough_forwarders.get(route)
+
+            async def passthrough_route(
+                request: Request,
+                passthrough_forwarder: PassthroughForwarder = Depends(get_passthrough_forwarder),
+                limiter=Depends(get_concurrency_limiter),
+            ):
+                return await passthrough_sync(request, passthrough_forwarder, limiter)
+
+            app.add_api_route(
+                path=route,
+                endpoint=passthrough_route,
+                methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"],
+            )
+
+    def add_extra_routes(app: FastAPI):
+        config = get_config()
+        if config.get("stream", {}).get("forwarder_type") == "passthrough":
+            add_stream_passthrough_routes(app)
+        elif config.get("sync", {}).get("forwarder_type") == "passthrough":
+            add_sync_passthrough_routes(app)
+        else:
+            add_extra_sync_or_stream_routes(app)
+
     app.add_api_route(path="/healthz", endpoint=healthcheck, methods=["GET"])
     app.add_api_route(path="/readyz", endpoint=healthcheck, methods=["GET"])
     app.add_api_route(path="/predict", endpoint=predict, methods=["POST"])
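A quick example of what sanitize_response_headers does to upstream response headers (assumes the model-engine package is importable; the header values are made up):

# Example input/output for sanitize_response_headers.
from model_engine_server.inference.forwarding.http_forwarder import sanitize_response_headers

upstream = {
    "Content-Length": "128",     # hop-by-hop: dropped
    "Content-Encoding": "gzip",  # hop-by-hop: dropped
    "Content-Type": "application/json",
    "ETag": "abc123",
}

print(sanitize_response_headers(upstream))
# {'content-type': 'application/json', 'etag': 'abc123'}

print(sanitize_response_headers(upstream, force_cache_bust=True))
# {'content-type': 'application/json', 'cache-control': 'no-store'}  (etag removed)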

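One detail worth calling out in add_stream_passthrough_routes and add_sync_passthrough_routes: the per-route dependency is declared as def get_passthrough_forwarder(route=route) so that each registered route captures its own value of route. Without the default-argument binding, every closure would see the loop variable's final value. A minimal demonstration of the difference:

# Why get_passthrough_forwarder takes route=route as a default argument:
# defaults are evaluated at definition time, so each closure keeps its own route.
makers_late = [lambda: route for route in ["/mcp", "/sse"]]
makers_bound = [lambda route=route: route for route in ["/mcp", "/sse"]]

print([m() for m in makers_late])   # ['/sse', '/sse'] - late binding
print([m() for m in makers_bound])  # ['/mcp', '/sse'] - value captured per iteration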