
Commit da85235

[MLI-4665] Update vllm upgrade process (#713)
* initial changes
* reverting some forwarder changes that aren't needed
* remove some other unneeded stuff
* not sure
* adding cpu
* add column
* add file for db model change
* update readme instructions
* fix column name
* reformat
* remove unused commits
* fix
* fix readme
* fix types
* leave the existing variables for backwards compatibility
* edit types
* remove EXTRA_ROUTES completely; it's not used by the async or triton-enhanced routes, and we replaced it with ROUTES for the sync and streaming routes
* add FORWARDER_SYNC_ROUTES and FORWARDER_STREAMING_ROUTES to the triton-enhanced and LWS ones that need them based on type; they won't get used, though
* change to pass unit tests
* update orm
* test change
* change test bundle
* add debug logs
* trying to fix
* changes
* cleanup debug code
* reformat
* remove 1
* remove 2
* revert 3
* reorder params
1 parent 93f1136 commit da85235

11 files changed (+105, -244 lines)

charts/model-engine/templates/service_template_config_map.yaml

Lines changed: 6 additions & 14 deletions
@@ -181,10 +181,6 @@ data:
   - --set
   - "forwarder.stream.healthcheck_route=${HEALTHCHECK_ROUTE}"
   - --set
-  - "forwarder.sync.extra_routes=${FORWARDER_EXTRA_ROUTES}"
-  - --set
-  - "forwarder.stream.extra_routes=${FORWARDER_EXTRA_ROUTES}"
-  - --set
   - "forwarder.sync.forwarder_type=${FORWARDER_TYPE}"
   - --set
   - "forwarder.stream.forwarder_type=${FORWARDER_TYPE}"
@@ -370,7 +366,7 @@ data:
   name: {{ $service_template_aws_config_map_name }}
   {{- else }}
   name: {{ $aws_config_map_name }}
-  {{- end }}
+  {{- end }}
   {{- end }}
   - name: user-config
   configMap:
@@ -487,15 +483,15 @@ data:
   threshold: "${CONCURRENCY}"
   metricName: request_concurrency_average
   query: sum(rate(istio_request_duration_milliseconds_sum{destination_workload="${RESOURCE_NAME}"}[2m])) / 1000
-  serverAddress: ${PROMETHEUS_SERVER_ADDRESS}
+  serverAddress: ${PROMETHEUS_SERVER_ADDRESS}
   {{- range $device := tuple "gpu" }}
   {{- range $mode := tuple "streaming"}}
   leader-worker-set-{{ $mode }}-{{ $device }}.yaml: |-
   apiVersion: leaderworkerset.x-k8s.io/v1
   kind: LeaderWorkerSet
   metadata:
-  name: ${RESOURCE_NAME}
-  namespace: ${NAMESPACE}
+  name: ${RESOURCE_NAME}
+  namespace: ${NAMESPACE}
   labels:
   {{- $service_template_labels | nindent 8 }}
   spec:
@@ -617,10 +613,6 @@ data:
   - --set
   - "forwarder.stream.healthcheck_route=${HEALTHCHECK_ROUTE}"
   - --set
-  - "forwarder.sync.extra_routes=${FORWARDER_EXTRA_ROUTES}"
-  - --set
-  - "forwarder.stream.extra_routes=${FORWARDER_EXTRA_ROUTES}"
-  - --set
   - "forwarder.sync.forwarder_type=${FORWARDER_TYPE}"
   - --set
   - "forwarder.stream.forwarder_type=${FORWARDER_TYPE}"
@@ -748,7 +740,7 @@ data:
   name: {{ $service_template_aws_config_map_name }}
   {{- else }}
   name: {{ $aws_config_map_name }}
-  {{- end }}
+  {{- end }}
   {{- end }}
   - name: user-config
   configMap:
@@ -856,7 +848,7 @@ data:
   name: {{ $service_template_aws_config_map_name }}
   {{- else }}
   name: {{ $aws_config_map_name }}
-  {{- end }}
+  {{- end }}
   {{- end }}
   - name: user-config
   configMap:
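The `${...}` placeholders above (for example `${FORWARDER_TYPE}` and `${RESOURCE_NAME}`) are filled in when the service template is rendered for a particular endpoint. As a generic illustration only, not this repo's actual rendering code, here is a minimal sketch of that kind of substitution using Python's `string.Template`; the `render_template` helper and the sample value are assumptions.

```
# Hypothetical sketch of ${VAR} substitution in a service template.
# render_template and the sample value are assumptions, not code from this repo.
from string import Template


def render_template(template_text: str, values: dict) -> str:
    # safe_substitute leaves unknown ${...} placeholders untouched.
    return Template(template_text).safe_substitute(values)


snippet = '- "forwarder.sync.forwarder_type=${FORWARDER_TYPE}"'
print(render_template(snippet, {"FORWARDER_TYPE": "sync"}))
```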

model-engine/model_engine_server/db/migrations/README

Lines changed: 4 additions & 1 deletion
@@ -4,7 +4,7 @@ We introduce alembic by
 1. dumping the current db schemas into 'initial.sql' via pg_dump
 
 ```
-pg_dump -h $HOST -U postgres -O -s -d $DB_NAME -n hosted_model_inference -n model -f initial.sql
+pg_dump -h $HOST -U postgres -O -s -d $DB_NAME -n hosted_model_inference -n model -f initial.sql
 ```
 
 2. writing an initial revision that reads and applies intial.sql script
@@ -19,6 +19,9 @@ alembic revision -m “initial”
 alembic stamp fa3267c80731
 ```
 
+# Steps to make generic database schema changes
+
+Steps can be found here: https://alembic.sqlalchemy.org/en/latest/tutorial.html#running-our-second-migration
 
 # Test db migration from scratch
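For the generic schema-change workflow the new README section points to, a rough sketch of the usual Alembic loop is below. It is an illustration only, not part of this commit; the `alembic.ini` path and the revision message are assumptions.

```
# Illustration only (not from this commit): the usual Alembic loop for a
# schema change, driven through Alembic's Python API instead of the CLI.
from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")  # assumed location of the Alembic config

# 1. Generate a new (empty) revision file, then hand-edit upgrade()/downgrade().
command.revision(cfg, message="describe the schema change here")

# 2. Apply all pending revisions to the target database.
command.upgrade(cfg, "head")

# 3. Roll back the most recent revision if needed.
command.downgrade(cfg, "-1")
```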

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+ """add routes column
+
+ Revision ID: 221aa19d3f32
+ Revises: e580182d6bfd
+ Create Date: 2025-09-25 19:40:24.927198
+
+ """
+ import sqlalchemy as sa
+ from alembic import op
+
+ # revision identifiers, used by Alembic.
+ revision = '221aa19d3f32'
+ down_revision = 'e580182d6bfd'
+ branch_labels = None
+ depends_on = None
+
+
+ def upgrade() -> None:
+     op.add_column(
+         'bundles',
+         sa.Column('runnable_image_routes', sa.ARRAY(sa.Text), nullable=True),
+         schema='hosted_model_inference',
+     )
+
+
+ def downgrade() -> None:
+     op.drop_column(
+         'bundles',
+         'runnable_image_routes',
+         schema='hosted_model_inference',
+     )
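As a quick sanity check after running this migration, something like the following can confirm the new column exists. This is a hedged sketch, not part of the commit; the connection URL is a placeholder.

```
# Hedged sketch: verify the migrated column is present.
# The connection URL is a placeholder; use your own database settings.
import sqlalchemy as sa

engine = sa.create_engine("postgresql://postgres@localhost/llm_engine")
columns = sa.inspect(engine).get_columns("bundles", schema="hosted_model_inference")
assert any(col["name"] == "runnable_image_routes" for col in columns)
```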

model-engine/model_engine_server/db/models/hosted_model_inference.py

Lines changed: 3 additions & 0 deletions
@@ -146,6 +146,7 @@ class Bundle(Base):
     runnable_image_env = Column(JSON, nullable=True)
     runnable_image_protocol = Column(Text, nullable=True)
     runnable_image_readiness_initial_delay_seconds = Column(Integer, nullable=True)
+    runnable_image_routes = Column(ARRAY(Text), nullable=True)
     runnable_image_extra_routes = Column(ARRAY(Text), nullable=True)
     runnable_image_forwarder_type = Column(Text, nullable=True)
     runnable_image_worker_command = Column(ARRAY(Text), nullable=True)
@@ -209,6 +210,7 @@ def __init__(
         runnable_image_env: Optional[Dict[str, Any]] = None,
         runnable_image_protocol: Optional[str] = None,
         runnable_image_readiness_initial_delay_seconds: Optional[int] = None,
+        runnable_image_routes: Optional[List[str]] = None,
         runnable_image_extra_routes: Optional[List[str]] = None,
         runnable_image_forwarder_type: Optional[str] = None,
         runnable_image_worker_command: Optional[List[str]] = None,
@@ -268,6 +270,7 @@ def __init__(
         self.runnable_image_healthcheck_route = runnable_image_healthcheck_route
         self.runnable_image_env = runnable_image_env
         self.runnable_image_protocol = runnable_image_protocol
+        self.runnable_image_routes = runnable_image_routes
         self.runnable_image_extra_routes = runnable_image_extra_routes
         self.runnable_image_forwarder_type = runnable_image_forwarder_type
         self.runnable_image_worker_command = runnable_image_worker_command
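The new `runnable_image_routes` column sits alongside the legacy `runnable_image_extra_routes`, which the commit keeps for backwards compatibility. A minimal sketch of how a caller might reconcile the two is below; `resolve_routes` is a hypothetical helper for illustration, not code from this repository.

```
# Hypothetical helper (not from this repo): prefer the new routes column,
# fall back to the legacy extra_routes column for older bundles.
from typing import List, Optional


def resolve_routes(
    runnable_image_routes: Optional[List[str]],
    runnable_image_extra_routes: Optional[List[str]],
) -> List[str]:
    if runnable_image_routes:
        return runnable_image_routes
    return runnable_image_extra_routes or []


print(resolve_routes(None, ["/v1/chat/completions"]))  # falls back to legacy
```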

model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py

Lines changed: 2 additions & 2 deletions
@@ -1019,7 +1019,7 @@ async def create_vllm_bundle(
             healthcheck_route="/health",
             predict_route="/predict",
             streaming_predict_route="/stream",
-            extra_routes=[
+            routes=[
                 OPENAI_CHAT_COMPLETION_PATH,
                 OPENAI_COMPLETION_PATH,
             ],
@@ -1101,7 +1101,7 @@ async def create_vllm_multinode_bundle(
             healthcheck_route="/health",
             predict_route="/predict",
             streaming_predict_route="/stream",
-            extra_routes=[OPENAI_CHAT_COMPLETION_PATH, OPENAI_COMPLETION_PATH],
+            routes=[OPENAI_CHAT_COMPLETION_PATH, OPENAI_COMPLETION_PATH],
             env=common_vllm_envs,
             worker_command=worker_command,
             worker_env=common_vllm_envs,
Lines changed: 5 additions & 176 deletions
@@ -1,33 +1,14 @@
 import asyncio
 import code
-import json
 import os
-import signal
 import subprocess
 import traceback
 from logging import Logger
-from typing import AsyncGenerator, Dict, List, Optional

-import vllm.envs as envs
-from fastapi import APIRouter, BackgroundTasks, Request
-from fastapi.responses import Response, StreamingResponse
-from vllm.engine.async_llm_engine import AsyncEngineDeadError
 from vllm.engine.protocol import EngineClient
-from vllm.entrypoints.launcher import serve_http
-from vllm.entrypoints.openai.api_server import (
-    build_app,
-    build_async_engine_client,
-    init_app_state,
-    load_log_config,
-    maybe_register_tokenizer_info_endpoint,
-    setup_server,
-)
+from vllm.entrypoints.openai.api_server import run_server
 from vllm.entrypoints.openai.cli_args import make_arg_parser
-from vllm.entrypoints.openai.tool_parsers import ToolParserManager
-from vllm.outputs import CompletionOutput
-from vllm.sampling_params import SamplingParams
-from vllm.sequence import Logprob
-from vllm.utils import FlexibleArgumentParser, random_uuid
+from vllm.utils import FlexibleArgumentParser

 logger = Logger("vllm_server")

@@ -36,88 +17,8 @@
 TIMEOUT_KEEP_ALIVE = 5 # seconds.
 TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds

-router = APIRouter()
-
-
-@router.post("/predict")
-@router.post("/stream")
-async def generate(request: Request) -> Response:
-    """Generate completion for the request.
-
-    The request should be a JSON object with the following fields:
-    - prompt: the prompt to use for the generation.
-    - stream: whether to stream the results or not.
-    - other fields: the sampling parameters (See `SamplingParams` for details).
-    """
-    # check health before accepting request and fail fast if engine isn't healthy
-    try:
-        await engine_client.check_health()
-
-        request_dict = await request.json()
-        prompt = request_dict.pop("prompt")
-        stream = request_dict.pop("stream", False)
-
-        sampling_params = SamplingParams(**request_dict)
-
-        request_id = random_uuid()
-
-        results_generator = engine_client.generate(prompt, sampling_params, request_id)
-
-        async def abort_request() -> None:
-            await engine_client.abort(request_id)
-
-        if stream:
-            # Streaming case
-            async def stream_results() -> AsyncGenerator[str, None]:
-                last_output_text = ""
-                async for request_output in results_generator:
-                    log_probs = format_logprobs(request_output)
-                    ret = {
-                        "text": request_output.outputs[-1].text[len(last_output_text) :],
-                        "count_prompt_tokens": len(request_output.prompt_token_ids),
-                        "count_output_tokens": len(request_output.outputs[0].token_ids),
-                        "log_probs": (
-                            log_probs[-1] if log_probs and sampling_params.logprobs else None
-                        ),
-                        "finished": request_output.finished,
-                    }
-                    last_output_text = request_output.outputs[-1].text
-                    yield f"data:{json.dumps(ret)}\n\n"
-
-            background_tasks = BackgroundTasks()
-            # Abort the request if the client disconnects.
-            background_tasks.add_task(abort_request)
-
-            return StreamingResponse(stream_results(), background=background_tasks)
-
-        # Non-streaming case
-        final_output = None
-        tokens = []
-        last_output_text = ""
-        async for request_output in results_generator:
-            tokens.append(request_output.outputs[-1].text[len(last_output_text) :])
-            last_output_text = request_output.outputs[-1].text
-            if await request.is_disconnected():
-                # Abort the request if the client disconnects.
-                await engine_client.abort(request_id)
-                return Response(status_code=499)
-            final_output = request_output
-
-        assert final_output is not None
-        prompt = final_output.prompt
-        ret = {
-            "text": final_output.outputs[0].text,
-            "count_prompt_tokens": len(final_output.prompt_token_ids),
-            "count_output_tokens": len(final_output.outputs[0].token_ids),
-            "log_probs": format_logprobs(final_output),
-            "tokens": tokens,
-        }
-        return Response(content=json.dumps(ret))
-
-    except AsyncEngineDeadError as e:
-        logger.error(f"The vllm engine is dead, exiting the pod: {e}")
-        os.kill(os.getpid(), signal.SIGINT)
-        raise e
+# Legacy endpoints /predit and /stream removed - using vLLM's native OpenAI-compatible endpoints instead
+# All requests now go through /v1/completions, /v1/chat/completions, etc.


 def get_gpu_free_memory():
@@ -171,90 +72,18 @@ def debug(sig, frame):
     i.interact(message)


-def format_logprobs(
-    request_output: CompletionOutput,
-) -> Optional[List[Dict[int, float]]]:
-    """Given a request output, format the logprobs if they exist."""
-    output_logprobs = request_output.outputs[0].logprobs
-    if output_logprobs is None:
-        return None
-
-    def extract_logprobs(logprobs: Dict[int, Logprob]) -> Dict[int, float]:
-        return {k: v.logprob for k, v in logprobs.items()}
-
-    return [extract_logprobs(logprobs) for logprobs in output_logprobs]
-
-
 def parse_args(parser: FlexibleArgumentParser):
     parser = make_arg_parser(parser)
     parser.add_argument("--attention-backend", type=str, help="The attention backend to use")
     return parser.parse_args()


-async def run_server(args, **uvicorn_kwargs) -> None:
-    """Run a single-worker API server."""
-    listen_address, sock = setup_server(args)
-    await run_server_worker(listen_address, sock, args, **uvicorn_kwargs)
-
-
-async def run_server_worker(
-    listen_address, sock, args, client_config=None, **uvicorn_kwargs
-) -> None:
-    """Run a single API server worker."""
-
-    if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
-        ToolParserManager.import_tool_parser(args.tool_parser_plugin)
-
-    server_index = client_config.get("client_index", 0) if client_config else 0
-
-    # Load logging config for uvicorn if specified
-    log_config = load_log_config(args.log_config_file)
-    if log_config is not None:
-        uvicorn_kwargs["log_config"] = log_config
-
-    global engine_client
-
-    async with build_async_engine_client(args, client_config=client_config) as engine_client:
-        maybe_register_tokenizer_info_endpoint(args)
-        app = build_app(args)
-
-        vllm_config = await engine_client.get_vllm_config()
-        await init_app_state(engine_client, vllm_config, app.state, args)
-        app.include_router(router)
-
-        logger.info("Starting vLLM API server %d on %s", server_index, listen_address)
-        shutdown_task = await serve_http(
-            app,
-            sock=sock,
-            enable_ssl_refresh=args.enable_ssl_refresh,
-            host=args.host,
-            port=args.port,
-            log_level=args.uvicorn_log_level,
-            # NOTE: When the 'disable_uvicorn_access_log' value is True,
-            # no access log will be output.
-            access_log=not args.disable_uvicorn_access_log,
-            timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE,
-            ssl_keyfile=args.ssl_keyfile,
-            ssl_certfile=args.ssl_certfile,
-            ssl_ca_certs=args.ssl_ca_certs,
-            ssl_cert_reqs=args.ssl_cert_reqs,
-            h11_max_incomplete_event_size=args.h11_max_incomplete_event_size,
-            h11_max_header_count=args.h11_max_header_count,
-            **uvicorn_kwargs,
-        )
-
-    # NB: Await server shutdown only after the backend context is exited
-    try:
-        await shutdown_task
-    finally:
-        sock.close()
-
-
 if __name__ == "__main__":
     check_unknown_startup_memory_usage()

     parser = FlexibleArgumentParser()
     args = parse_args(parser)
     if args.attention_backend is not None:
         os.environ["VLLM_ATTENTION_BACKEND"] = args.attention_backend
+    # Using vllm's run_server
     asyncio.run(run_server(args))
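Since the wrapper now defers entirely to vLLM's `run_server`, clients talk to the native OpenAI-compatible endpoints. A hedged example call is below; the host, port, and model name are placeholders, not values from this commit.

```
# Hedged example: calling the OpenAI-compatible chat completions endpoint the
# new server exposes. Host, port, and model name are placeholders.
import json
import urllib.request

payload = {
    "model": "my-model",  # placeholder model name
    "messages": [{"role": "user", "content": "Hello!"}],
}
req = urllib.request.Request(
    "http://localhost:5005/v1/chat/completions",  # placeholder host/port
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read())["choices"][0]["message"]["content"])
```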

model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py

Lines changed: 1 addition & 0 deletions
@@ -227,6 +227,7 @@ async def streaming_predict(
             if predict_request.num_retries is None
             else predict_request.num_retries
         )
+
         response = self.make_request_with_retries(
             request_url=deployment_url,
             payload_json=predict_request.model_dump(exclude_none=True),

model-engine/model_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py

Lines changed: 1 addition & 0 deletions
@@ -253,6 +253,7 @@ async def predict(
             if predict_request.num_retries is None
             else predict_request.num_retries
         )
+
         response = await self.make_request_with_retries(
             request_url=deployment_url,
             payload_json=predict_request.model_dump(exclude_none=True),
