Skip to content

Commit 546eeff

Browse files
authored
[MLI-4665] Update http forwarder for model engine (#714)
* initial code not clean * adding /predict and /stream like the other routes? * revisions * fix * comments * add default paths in passthrough also * fix and tested * revisions * reformat * fix for unit test * remove comments * revise
1 parent 369cb3f commit 546eeff

File tree

3 files changed

+80
-7
lines changed

3 files changed

+80
-7
lines changed

model-engine/README.md

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,41 @@ Run `mypy . --install-types` to set up mypy.
4242
Most of the business logic in Model Engine should contain unit tests, located in
4343
[`tests/unit`](./tests/unit). To run the tests, run `pytest`.
4444

45+
### Testing the http_forwarder
46+
47+
First, start an endpoint listening on port 5005 (for example, a vLLM server):
48+
```sh
49+
(llm-engine-vllm) ➜ vllm git:(dmchoi/vllm_batch_upgrade) ✗ export IMAGE=692474966980.dkr.ecr.us-west-2.amazonaws.com/vllm:0.10.1.1-rc2
50+
(llm-engine-vllm) ➜ vllm git:(dmchoi/vllm_batch_upgrade) ✗ export MODEL=meta-llama/Meta-Llama-3.1-8B-Instruct && export MODEL_PATH=/data/model_files/$MODEL
51+
(llm-engine-vllm) ➜ vllm git:(dmchoi/vllm_batch_upgrade) ✗ export REPO_PATH=/mnt/home/dmchoi/repos/scale
52+
(llm-engine-vllm) ➜ vllm git:(dmchoi/vllm_batch_upgrade) ✗ docker kill vllm; docker rm vllm; docker run \
53+
--runtime nvidia \
54+
--shm-size=16gb \
55+
--gpus '"device=0,1,2,3"' \
56+
-v $MODEL_PATH:/workspace/model_files:ro \
57+
-v ${REPO_PATH}/llm-engine/model-engine/model_engine_server/inference/vllm/vllm_server.py:/workspace/vllm_server.py \
58+
-p 5005:5005 \
59+
--name vllm \
60+
${IMAGE} \
61+
python -m vllm_server --model model_files --port 5005 --disable-log-requests --max-model-len 4096 --max-num-seqs 16 --enforce-eager
62+
```
63+
64+
Then run the forwarder locally like this:
65+
```sh
66+
GIT_TAG=test python model_engine_server/inference/forwarding/http_forwarder.py \
67+
--config model_engine_server/inference/configs/service--http_forwarder.yaml \
68+
--num-workers 1 \
69+
--set "forwarder.sync.extra_routes=['/v1/chat/completions','/v1/completions']" \
70+
--set "forwarder.stream.extra_routes=['/v1/chat/completions','/v1/completions']" \
71+
--set "forwarder.sync.healthcheck_route=/health" \
72+
--set "forwarder.stream.healthcheck_route=/health"
73+
```
74+
75+
Then send a request to the forwarder (on port 5000 in this example) like this:
76+
```sh
77+
curl -X POST localhost:5000/v1/chat/completions -H "Content-Type: application/json" -d "{\"args\": {\"model\":\"$MODEL\", \"messages\":[{\"role\": \"system\", \"content\": \"Hey, what's the temperature in Paris right now?\"}],\"max_tokens\":100,\"temperature\":0.2,\"guided_regex\":\"Sean.*\"}}"
78+
```
79+
4580
## Generating OpenAI types
4681
We've decided to make our V2 APIs OpenAI compatible. We generate the
4782
corresponding Pydantic models:

model-engine/model_engine_server/domain/entities/model_bundle_entity.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ class RunnableImageLike(BaseModel, ABC):
163163
protocol: Literal["http"] # TODO: add support for other protocols (e.g. grpc)
164164
readiness_initial_delay_seconds: int = 120
165165
extra_routes: List[str] = Field(default_factory=list)
166+
routes: List[str] = Field(default_factory=list)
166167
forwarder_type: Optional[str] = ForwarderType.DEFAULT.value
167168
worker_command: Optional[List[str]] = None
168169
worker_env: Optional[Dict[str, str]] = None

model-engine/model_engine_server/inference/forwarding/http_forwarder.py

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ def get_forwarder_loader(destination_path: Optional[str] = None) -> LoadForwarde
4141
config = get_config()["sync"]
4242
if "extra_routes" in config:
4343
del config["extra_routes"]
44+
if "routes" in config:
45+
del config["routes"]
4446
if destination_path:
4547
config["predict_route"] = destination_path
4648
if "forwarder_type" in config:
@@ -55,6 +57,8 @@ def get_streaming_forwarder_loader(
5557
config = get_config()["stream"]
5658
if "extra_routes" in config:
5759
del config["extra_routes"]
60+
if "routes" in config:
61+
del config["routes"]
5862
if destination_path:
5963
config["predict_route"] = destination_path
6064
if "forwarder_type" in config:
@@ -276,14 +280,34 @@ async def init_app():
276280
def healthcheck():
277281
return "OK"
278282

279-
def add_extra_sync_or_stream_routes(app: FastAPI):
280-
"""Read extra_routes from config and dynamically add routes to app"""
283+
def add_sync_or_stream_routes(app: FastAPI):
284+
"""Read routes from config (both old extra_routes and new routes field) and dynamically add routes to app"""
281285
config = get_config()
282286
sync_forwarders: Dict[str, Forwarder] = dict()
283287
stream_forwarders: Dict[str, StreamingForwarder] = dict()
284-
for route in config.get("sync", {}).get("extra_routes", []):
288+
289+
# Gather all sync routes from extra_routes and routes fields
290+
sync_routes_to_add = set()
291+
sync_routes_to_add.update(config.get("sync", {}).get("extra_routes", []))
292+
sync_routes_to_add.update(config.get("sync", {}).get("routes", []))
293+
294+
# predict_route = config.get("sync", {}).get("predict_route", None)
295+
# if predict_route:
296+
# sync_routes_to_add.add(predict_route)
297+
298+
# Gather all stream routes from extra_routes and routes fields
299+
stream_routes_to_add = set()
300+
stream_routes_to_add.update(config.get("stream", {}).get("extra_routes", []))
301+
stream_routes_to_add.update(config.get("stream", {}).get("routes", []))
302+
303+
# stream_predict_route = config.get("stream", {}).get("predict_route", None)
304+
# if stream_predict_route:
305+
# stream_routes_to_add.add(stream_predict_route)
306+
307+
# Load forwarders for all routes
308+
for route in sync_routes_to_add:
285309
sync_forwarders[route] = load_forwarder(route)
286-
for route in config.get("stream", {}).get("extra_routes", []):
310+
for route in stream_routes_to_add:
287311
stream_forwarders[route] = load_streaming_forwarder(route)
288312

289313
all_routes = set(list(sync_forwarders.keys()) + list(stream_forwarders.keys()))
@@ -327,7 +351,14 @@ def add_stream_passthrough_routes(app: FastAPI):
327351
config = get_config()
328352

329353
passthrough_forwarders: Dict[str, PassthroughForwarder] = dict()
330-
for route in config.get("stream", {}).get("extra_routes", []):
354+
355+
# Gather all routes from extra_routes and routes fields
356+
stream_passthrough_routes_to_add = set()
357+
stream_passthrough_routes_to_add.update(config.get("stream", {}).get("extra_routes", []))
358+
stream_passthrough_routes_to_add.update(config.get("stream", {}).get("routes", []))
359+
360+
# Load passthrough forwarders for all routes
361+
for route in stream_passthrough_routes_to_add:
331362
passthrough_forwarders[route] = load_stream_passthrough_forwarder(route)
332363

333364
for route in passthrough_forwarders:
@@ -352,7 +383,13 @@ def add_sync_passthrough_routes(app: FastAPI):
352383
config = get_config()
353384

354385
passthrough_forwarders: Dict[str, PassthroughForwarder] = dict()
355-
for route in config.get("sync", {}).get("extra_routes", []):
386+
387+
# Gather all routes from both the legacy extra_routes field and the new routes field
388+
sync_passthrough_routes_to_add = set()
389+
sync_passthrough_routes_to_add.update(config.get("sync", {}).get("extra_routes", []))
390+
sync_passthrough_routes_to_add.update(config.get("sync", {}).get("routes", []))
391+
392+
for route in sync_passthrough_routes_to_add:
356393
passthrough_forwarders[route] = load_sync_passthrough_forwarder(route)
357394

358395
for route in passthrough_forwarders:
@@ -380,7 +417,7 @@ def add_extra_routes(app: FastAPI):
380417
elif config.get("sync", {}).get("forwarder_type") == "passthrough":
381418
add_sync_passthrough_routes(app)
382419
else:
383-
add_extra_sync_or_stream_routes(app)
420+
add_sync_or_stream_routes(app)
384421

385422
app.add_api_route(path="/healthz", endpoint=healthcheck, methods=["GET"])
386423
app.add_api_route(path="/readyz", endpoint=healthcheck, methods=["GET"])

0 commit comments

Comments
 (0)