77
88import orjson
99import uvicorn
10- from fastapi import BackgroundTasks , Depends , FastAPI
10+ from fastapi import BackgroundTasks , Depends , FastAPI , Request
11+ from fastapi .responses import Response , StreamingResponse
1112from model_engine_server .common .concurrency_limiter import MultiprocessingConcurrencyLimiter
1213from model_engine_server .common .dtos .tasks import EndpointPredictV1Request
1314from model_engine_server .core .loggers import logger_name , make_logger
1415from model_engine_server .inference .forwarding .forwarding import (
1516 Forwarder ,
1617 LoadForwarder ,
18+ LoadPassthroughForwarder ,
1719 LoadStreamingForwarder ,
20+ PassthroughForwarder ,
1821 StreamingForwarder ,
1922 load_named_config ,
2023)
@@ -40,6 +43,8 @@ def get_forwarder_loader(destination_path: Optional[str] = None) -> LoadForwarde
4043 del config ["extra_routes" ]
4144 if destination_path :
4245 config ["predict_route" ] = destination_path
46+ if "forwarder_type" in config :
47+ del config ["forwarder_type" ]
4348 forwarder_loader = LoadForwarder (** config )
4449 return forwarder_loader
4550
@@ -52,10 +57,40 @@ def get_streaming_forwarder_loader(
5257 del config ["extra_routes" ]
5358 if destination_path :
5459 config ["predict_route" ] = destination_path
60+ if "forwarder_type" in config :
61+ del config ["forwarder_type" ]
5562 streaming_forwarder_loader = LoadStreamingForwarder (** config )
5663 return streaming_forwarder_loader
5764
5865
def get_stream_passthrough_forwarder_loader(
    destination_path: Optional[str] = None,
) -> LoadPassthroughForwarder:
    """Build a passthrough-forwarder loader from the ``stream`` config section.

    Copies ``user_port``, ``user_hostname`` and ``healthcheck_route`` from the
    ``stream`` section of ``get_config()`` into the loader's keyword arguments.

    Args:
        destination_path: When truthy, handed to the loader as
            ``passthrough_route``; otherwise that argument is omitted.

    Returns:
        The ``LoadPassthroughForwarder`` built from those settings.

    Raises:
        KeyError: If the ``stream`` section is missing (it defaults to an
            empty dict) or does not contain one of the three copied keys.
    """
    stream_section = get_config().get("stream", {})
    loader_kwargs = {
        name: stream_section[name]
        for name in ("user_port", "user_hostname", "healthcheck_route")
    }
    if destination_path:
        loader_kwargs["passthrough_route"] = destination_path
    return LoadPassthroughForwarder(**loader_kwargs)
78+
79+
def get_sync_passthrough_forwarder_loader(
    destination_path: Optional[str] = None,
) -> LoadPassthroughForwarder:
    """Build a passthrough-forwarder loader from the ``sync`` config section.

    Copies ``user_port``, ``user_hostname`` and ``healthcheck_route`` from the
    ``sync`` section of ``get_config()`` into the loader's keyword arguments.

    Args:
        destination_path: When truthy, handed to the loader as
            ``passthrough_route``; otherwise that argument is omitted.

    Returns:
        The ``LoadPassthroughForwarder`` built from those settings.

    Raises:
        KeyError: If the ``sync`` section is missing (it defaults to an
            empty dict) or does not contain one of the three copied keys.
    """
    sync_section = get_config().get("sync", {})
    loader_kwargs = {
        name: sync_section[name]
        for name in ("user_port", "user_hostname", "healthcheck_route")
    }
    if destination_path:
        loader_kwargs["passthrough_route"] = destination_path
    return LoadPassthroughForwarder(**loader_kwargs)
92+
93+
5994@lru_cache ()
6095def get_concurrency_limiter () -> MultiprocessingConcurrencyLimiter :
6196 config = get_config ()
@@ -75,6 +110,41 @@ def load_streaming_forwarder(destination_path: Optional[str] = None) -> Streamin
75110 return get_streaming_forwarder_loader (destination_path ).load (None , None )
76111
77112
@lru_cache()
def load_stream_passthrough_forwarder(
    destination_path: Optional[str] = None,
) -> PassthroughForwarder:
    """Return the stream passthrough forwarder for ``destination_path``.

    Results are memoized by ``lru_cache``, so repeated calls with the same
    ``destination_path`` return the same forwarder object.

    Args:
        destination_path: Forwarded unchanged to
            ``get_stream_passthrough_forwarder_loader``.

    Returns:
        Whatever the loader's ``load(None, None)`` call produces.
    """
    loader = get_stream_passthrough_forwarder_loader(destination_path)
    return loader.load(None, None)
118+
119+
@lru_cache()
def load_sync_passthrough_forwarder(destination_path: Optional[str] = None) -> PassthroughForwarder:
    """Return the sync passthrough forwarder for ``destination_path``.

    Results are memoized by ``lru_cache``, so repeated calls with the same
    ``destination_path`` return the same forwarder object.

    Args:
        destination_path: Forwarded unchanged to
            ``get_sync_passthrough_forwarder_loader``.

    Returns:
        Whatever the loader's ``load(None, None)`` call produces.
    """
    loader = get_sync_passthrough_forwarder_loader(destination_path)
    return loader.load(None, None)
123+
124+
# Lower-cased names of upstream response headers that are dropped before the
# response is relayed to the caller.
HOP_BY_HOP_HEADERS: list[str] = [
    "proxy-authenticate",
    "proxy-authorization",
    "content-length",
    "content-encoding",
]


def sanitize_response_headers(headers: dict, force_cache_bust: bool = False) -> dict:
    """Return a cleaned, lower-cased copy of upstream response headers.

    Args:
        headers: Mapping of header name to value. Names are lower-cased in
            the result; the input mapping is not modified.
        force_cache_bust: When true, ``cache-control`` is overwritten with
            ``no-store`` and any ``etag`` entry is removed.

    Returns:
        A new dict without the entries named in ``HOP_BY_HOP_HEADERS``.
    """
    blocked = set(HOP_BY_HOP_HEADERS)
    lowered_items = ((name.lower(), value) for name, value in headers.items())
    # Filter while building instead of building and then deleting.
    sanitized = {name: value for name, value in lowered_items if name not in blocked}

    if not force_cache_bust:
        return sanitized

    # Force clients to refetch resources.
    sanitized["cache-control"] = "no-store"
    sanitized.pop("etag", None)
    return sanitized
146+
147+
78148async def predict (
79149 request : EndpointPredictV1Request ,
80150 background_tasks : BackgroundTasks ,
@@ -123,6 +193,35 @@ async def event_generator():
123193 return EventSourceResponse (event_generator ())
124194
125195
async def passthrough_stream(
    request: Request,
    forwarder: PassthroughForwarder = Depends(load_stream_passthrough_forwarder),
    limiter: MultiprocessingConcurrencyLimiter = Depends(get_concurrency_limiter),
):
    """Relay ``request`` through ``forwarder`` and stream the reply back.

    Args:
        request: The incoming request, handed as-is to
            ``forwarder.forward_stream``.
        forwarder: Passthrough forwarder to use. The default dependency
            resolves to a loaded ``PassthroughForwarder`` (not the loader
            object), matching the annotation and the ``forward_stream`` call.
        limiter: Concurrency limiter entered as a context manager around the
            forwarding call.

    Returns:
        A ``StreamingResponse`` carrying the upstream status code, the
        sanitized upstream headers, and the remaining upstream chunks.
    """
    # NOTE(review): the ``with`` block exits when this function returns, which
    # appears to be before the streamed body is consumed -- confirm the
    # limiter is only meant to cover the upstream handshake.
    with limiter:
        response = forwarder.forward_stream(request)
        # The first item yielded by the forwarder is the (headers, status)
        # pair; every later item is treated as a body chunk.
        headers, status = await anext(response)
        headers = sanitize_response_headers(headers)

        async def content_generator():
            async for chunk in response:
                yield chunk

        return StreamingResponse(content_generator(), headers=headers, status_code=status)
211+
212+
async def passthrough_sync(
    request: Request,
    forwarder: PassthroughForwarder = Depends(load_sync_passthrough_forwarder),
    limiter: MultiprocessingConcurrencyLimiter = Depends(get_concurrency_limiter),
):
    """Relay ``request`` through ``forwarder`` and return the full reply.

    Args:
        request: The incoming request, handed as-is to
            ``forwarder.forward_sync``.
        forwarder: Passthrough forwarder to use. The default dependency
            resolves to a loaded ``PassthroughForwarder`` (not the loader
            object), matching the annotation and the ``forward_sync`` call.
        limiter: Concurrency limiter entered as a context manager for the
            whole upstream exchange, including reading the body.

    Returns:
        A ``Response`` with the upstream body, the upstream status, and the
        sanitized upstream headers.
    """
    with limiter:
        response = await forwarder.forward_sync(request)
        headers = sanitize_response_headers(response.headers)
        # The body is read in full before replying; nothing is streamed here.
        content = await response.read()
        return Response(content=content, status_code=response.status, headers=headers)
223+
224+
126225async def serve_http (app : FastAPI , ** uvicorn_kwargs : Any ): # pragma: no cover
127226 logger .info ("Available routes are:" )
128227 for route in app .routes :
@@ -177,7 +276,7 @@ async def init_app():
177276 def healthcheck ():
178277 return "OK"
179278
180- def add_extra_routes (app : FastAPI ):
279+ def add_extra_sync_or_stream_routes (app : FastAPI ):
181280 """Read extra_routes from config and dynamically add routes to app"""
182281 config = get_config ()
183282 sync_forwarders : Dict [str , Forwarder ] = dict ()
@@ -224,6 +323,65 @@ async def predict_or_stream(
224323 methods = ["POST" ],
225324 )
226325
def add_stream_passthrough_routes(app: FastAPI):
    """Register a streaming passthrough endpoint for each ``stream`` extra route.

    Routes are read from ``config["stream"]["extra_routes"]`` (missing keys
    mean no routes). One forwarder is loaded per route and each route is
    registered once for all listed HTTP methods.

    Args:
        app: The FastAPI application that receives the routes.
    """
    config = get_config()

    passthrough_forwarders: Dict[str, PassthroughForwarder] = {
        route: load_stream_passthrough_forwarder(route)
        for route in config.get("stream", {}).get("extra_routes", [])
    }

    def make_forwarder_dependency(forwarder: PassthroughForwarder):
        # The dependency takes no parameters on purpose: FastAPI resolves a
        # dependency's parameters from the request, so a ``route=route``
        # default would let a ``?route=...`` query string override the lookup.
        def get_passthrough_forwarder() -> PassthroughForwarder:
            return forwarder

        return get_passthrough_forwarder

    for route, route_forwarder in passthrough_forwarders.items():

        async def passthrough_route(
            request: Request,
            passthrough_forwarder: PassthroughForwarder = Depends(
                make_forwarder_dependency(route_forwarder)
            ),
            limiter: MultiprocessingConcurrencyLimiter = Depends(get_concurrency_limiter),
        ):
            return await passthrough_stream(request, passthrough_forwarder, limiter)

        app.add_api_route(
            path=route,
            endpoint=passthrough_route,
            methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"],
        )
350+
def add_sync_passthrough_routes(app: FastAPI):
    """Register a sync passthrough endpoint for each ``sync`` extra route.

    Routes are read from ``config["sync"]["extra_routes"]`` (missing keys
    mean no routes). One forwarder is loaded per route and each route is
    registered once for all listed HTTP methods.

    Args:
        app: The FastAPI application that receives the routes.
    """
    config = get_config()

    passthrough_forwarders: Dict[str, PassthroughForwarder] = {
        route: load_sync_passthrough_forwarder(route)
        for route in config.get("sync", {}).get("extra_routes", [])
    }

    def make_forwarder_dependency(forwarder: PassthroughForwarder):
        # The dependency takes no parameters on purpose: FastAPI resolves a
        # dependency's parameters from the request, so a ``route=route``
        # default would let a ``?route=...`` query string override the lookup.
        def get_passthrough_forwarder() -> PassthroughForwarder:
            return forwarder

        return get_passthrough_forwarder

    for route, route_forwarder in passthrough_forwarders.items():

        async def passthrough_route(
            request: Request,
            passthrough_forwarder: PassthroughForwarder = Depends(
                make_forwarder_dependency(route_forwarder)
            ),
            limiter: MultiprocessingConcurrencyLimiter = Depends(get_concurrency_limiter),
        ):
            return await passthrough_sync(request, passthrough_forwarder, limiter)

        app.add_api_route(
            path=route,
            endpoint=passthrough_route,
            methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"],
        )
375+
def add_extra_routes(app: FastAPI):
    """Register the configured extra routes using the matching strategy.

    The ``stream`` section is checked first, then ``sync``; the first one
    whose ``forwarder_type`` equals ``"passthrough"`` decides which
    passthrough registrar runs. If neither matches, the regular
    sync-or-stream registrar is used. Exactly one registrar is called.

    Args:
        app: The FastAPI application that receives the routes.
    """
    config = get_config()

    def uses_passthrough(section: str) -> bool:
        return config.get(section, {}).get("forwarder_type") == "passthrough"

    # ``sync`` is only inspected when ``stream`` is not a passthrough.
    if uses_passthrough("stream"):
        register = add_stream_passthrough_routes
    elif uses_passthrough("sync"):
        register = add_sync_passthrough_routes
    else:
        register = add_extra_sync_or_stream_routes
    register(app)
384+
227385 app .add_api_route (path = "/healthz" , endpoint = healthcheck , methods = ["GET" ])
228386 app .add_api_route (path = "/readyz" , endpoint = healthcheck , methods = ["GET" ])
229387 app .add_api_route (path = "/predict" , endpoint = predict , methods = ["POST" ])
0 commit comments