Commit e46cbd4

Update vllm batch job to work with vllm > 0.5.0 (#550)
* Update vllm batch job to work with vllm > 0.5.0
* Fix test
* Add comments
1 parent dd4a0f9 commit e46cbd4

File tree

5 files changed: +211 additions, -35 deletions

model-engine/model_engine_server/common/dtos/llms.py

Lines changed: 11 additions & 0 deletions
@@ -1,5 +1,7 @@
 """
 DTOs for LLM APIs.
+
+Make sure to keep this in sync with inference/batch_inference/dto.py.
 """

 from typing import Any, Dict, List, Optional
@@ -553,6 +555,14 @@ class CreateBatchCompletionsEngineRequest(CreateBatchCompletionsRequest):
     hidden from the DTO exposed to the client.
     """

+    model_cfg: CreateBatchCompletionsModelConfig
+    """
+    Model configuration for the batch inference. Hardware configurations are inferred.
+
+    We rename model_config from api to model_cfg in engine since engine uses pydantic v2 which
+    reserves model_config as a keyword.
+    """
+
     max_gpu_memory_utilization: Optional[float] = Field(default=0.9, le=1.0)
     """
     Maximum GPU memory utilization for the batch inference. Default to 90%.
@@ -565,6 +575,7 @@ def from_api(request: CreateBatchCompletionsRequest) -> "CreateBatchCompletionsEngineRequest":
             output_data_path=request.output_data_path,
             content=request.content,
             model_config=request.model_config,
+            model_cfg=request.model_config,
             data_parallelism=request.data_parallelism,
             max_runtime_sec=request.max_runtime_sec,
             tool_config=request.tool_config,
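
For context on the `model_cfg` rename above: pydantic v2 reserves the attribute name `model_config` for the model's configuration, so a field by that name cannot be declared on a v2 BaseModel. A minimal sketch (not repo code) of the pattern the engine-side DTO uses, keeping the old `model_config` JSON key as an alias:

# Sketch only: the field type and payload here are placeholders.
from pydantic import BaseModel, Field


class EngineRequestSketch(BaseModel):
    # A field literally named `model_config` would fail at class-definition time
    # under pydantic v2, so the field is renamed and the original key is kept as
    # an alias for deserialization.
    model_cfg: dict = Field(alias="model_config")


# Payloads that still use the "model_config" key parse via the alias.
req = EngineRequestSketch.model_validate({"model_config": {"model": "llama"}})
print(req.model_cfg)  # {'model': 'llama'}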

model-engine/model_engine_server/inference/batch_inference/dto.py (new file)

Lines changed: 165 additions & 0 deletions
@@ -0,0 +1,165 @@
# This is a copy of model_engine_server.common.dtos.llm
# This is done to decouple the pydantic requirements since vllm requires pydantic >2
# while model engine is on 1.x
from enum import Enum
from typing import Dict, List, Optional

from pydantic import BaseModel, Field


class TokenOutput(BaseModel):
    token: str
    log_prob: float


class CompletionOutput(BaseModel):
    text: str
    num_prompt_tokens: int
    num_completion_tokens: int
    tokens: Optional[List[TokenOutput]] = None


class CreateBatchCompletionsRequestContent(BaseModel):
    prompts: List[str]
    max_new_tokens: int
    temperature: float = Field(ge=0.0, le=1.0)
    """
    Temperature of the sampling. Setting to 0 equals to greedy sampling.
    """
    stop_sequences: Optional[List[str]] = None
    """
    List of sequences to stop the completion at.
    """
    return_token_log_probs: Optional[bool] = False
    """
    Whether to return the log probabilities of the tokens.
    """
    presence_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0)
    """
    Only supported in vllm, lightllm
    Penalize new tokens based on whether they appear in the text so far. 0.0 means no penalty
    """
    frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0)
    """
    Only supported in vllm, lightllm
    Penalize new tokens based on their existing frequency in the text so far. 0.0 means no penalty
    """
    top_k: Optional[int] = Field(default=None, ge=-1)
    """
    Controls the number of top tokens to consider. -1 means consider all tokens.
    """
    top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0)
    """
    Controls the cumulative probability of the top tokens to consider. 1.0 means consider all tokens.
    """
    skip_special_tokens: Optional[bool] = True
    """
    Whether to skip special tokens in the output.
    """


class Quantization(str, Enum):
    BITSANDBYTES = "bitsandbytes"
    AWQ = "awq"


class CreateBatchCompletionsModelConfig(BaseModel):
    model: str
    checkpoint_path: Optional[str] = None
    """
    Path to the checkpoint to load the model from.
    """
    labels: Dict[str, str]
    """
    Labels to attach to the batch inference job.
    """
    num_shards: Optional[int] = 1
    """
    Suggested number of shards to distribute the model. When not specified, will infer the number of shards based on model config.
    System may decide to use a different number than the given value.
    """
    quantize: Optional[Quantization] = None
    """
    Whether to quantize the model.
    """
    seed: Optional[int] = None
    """
    Random seed for the model.
    """


class ToolConfig(BaseModel):
    """
    Configuration for tool use.
    NOTE: this config is highly experimental and signature will change significantly in future iterations.
    """

    name: str
    """
    Name of the tool to use for the batch inference.
    """
    max_iterations: Optional[int] = 10
    """
    Maximum number of iterations to run the tool.
    """
    execution_timeout_seconds: Optional[int] = 60
    """
    Maximum runtime of the tool in seconds.
    """
    should_retry_on_error: Optional[bool] = True
    """
    Whether to retry the tool on error.
    """


class CreateBatchCompletionsRequest(BaseModel):
    """
    Request object for batch completions.
    """

    input_data_path: Optional[str]
    output_data_path: str
    """
    Path to the output file. The output file will be a JSON file of type List[CompletionOutput].
    """
    content: Optional[CreateBatchCompletionsRequestContent] = None
    """
    Either `input_data_path` or `content` needs to be provided.
    When input_data_path is provided, the input file should be a JSON file of type BatchCompletionsRequestContent.
    """

    data_parallelism: Optional[int] = Field(default=1, ge=1, le=64)
    """
    Number of replicas to run the batch inference. More replicas are slower to schedule but faster to inference.
    """
    max_runtime_sec: Optional[int] = Field(default=24 * 3600, ge=1, le=2 * 24 * 3600)
    """
    Maximum runtime of the batch inference in seconds. Default to one day.
    """
    tool_config: Optional[ToolConfig] = None
    """
    Configuration for tool use.
    NOTE: this config is highly experimental and signature will change significantly in future iterations.
    """


class CreateBatchCompletionsEngineRequest(CreateBatchCompletionsRequest):
    """
    Internal model for representing request to the llm engine. This contains additional fields that we want
    hidden from the DTO exposed to the client.
    """

    model_cfg: CreateBatchCompletionsModelConfig = Field(alias="model_config")
    """
    Model configuration for the batch inference. Hardware configurations are inferred.

    We rename model_config from api to model_cfg in engine since engine uses pydantic v2 which
    reserves model_config as a keyword.

    We alias `model_config` for deserialization for backwards compatibility.
    """

    max_gpu_memory_utilization: Optional[float] = Field(default=0.9, le=1.0)
    """
    Maximum GPU memory utilization for the batch inference. Default to 90%.
    """

model-engine/model_engine_server/inference/batch_inference/requirements.txt

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
-vllm==0.2.5
-pydantic==1.10.13
+vllm==0.5.0.post1
+pydantic>=2
 boto3==1.34.15
 smart-open==6.4.0
 ddtrace==2.4.0
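
The pin bump is what forces the DTO copy above: per the comment in dto.py, vllm now pulls in pydantic v2 while the rest of model engine stays on 1.x. A purely illustrative runtime sanity check for the batch image:

# Illustrative check, not part of the repo.
from importlib.metadata import version

vllm_version = version("vllm")
pydantic_major = int(version("pydantic").split(".")[0])

assert vllm_version.startswith("0.5."), f"unexpected vllm version: {vllm_version}"
assert pydantic_major >= 2, f"vllm >= 0.5 expects pydantic v2, got {version('pydantic')}"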

model-engine/model_engine_server/inference/batch_inference/vllm_batch.py

Lines changed: 16 additions & 20 deletions
@@ -13,7 +13,7 @@
 import boto3
 import smart_open
 from func_timeout import FunctionTimedOut, func_set_timeout
-from model_engine_server.common.dtos.llms import (
+from model_engine_server.inference.batch_inference.dto import (
     CompletionOutput,
     CreateBatchCompletionsEngineRequest,
     CreateBatchCompletionsRequestContent,
@@ -150,9 +150,9 @@ def get_vllm_engine(model: str, request: CreateBatchCompletionsEngineRequest):

     engine_args = AsyncEngineArgs(
         model=model,
-        quantization=request.model_config.quantize,
-        tensor_parallel_size=request.model_config.num_shards,
-        seed=request.model_config.seed or 0,
+        quantization=request.model_cfg.quantize,
+        tensor_parallel_size=request.model_cfg.num_shards,
+        seed=request.model_cfg.seed or 0,
         disable_log_requests=True,
         gpu_memory_utilization=request.max_gpu_memory_utilization or 0.9,
     )
@@ -316,18 +316,16 @@ async def batch_inference():

     request = CreateBatchCompletionsEngineRequest.parse_file(CONFIG_FILE)

-    if request.model_config.checkpoint_path is not None:
-        download_model(request.model_config.checkpoint_path, MODEL_WEIGHTS_FOLDER)
+    if request.model_cfg.checkpoint_path is not None:
+        download_model(request.model_cfg.checkpoint_path, MODEL_WEIGHTS_FOLDER)

     content = request.content
     if content is None:
         with smart_open.open(request.input_data_path, "r") as f:
             content = CreateBatchCompletionsRequestContent.parse_raw(f.read())

-    model = (
-        MODEL_WEIGHTS_FOLDER if request.model_config.checkpoint_path else request.model_config.model
-    )
-    is_finetuned = request.model_config.checkpoint_path is not None
+    model = MODEL_WEIGHTS_FOLDER if request.model_cfg.checkpoint_path else request.model_cfg.model
+    is_finetuned = request.model_cfg.checkpoint_path is not None

     llm = get_vllm_engine(model, request)

@@ -352,7 +350,7 @@ async def batch_inference():
             prompts,
             tool,
             is_finetuned,
-            request.model_config.model,
+            request.model_cfg.model,
         )
     else:
         bar = tqdm(total=len(prompts), desc="Processed prompts")
@@ -372,7 +370,7 @@ async def batch_inference():
             bar,
             use_tool=False,
             is_finetuned=is_finetuned,
-            model=request.model_config.model,
+            model=request.model_cfg.model,
         )

         bar.close()
@@ -430,27 +428,25 @@ async def generate_with_vllm(
             skip_special_tokens=skip_special_tokens if skip_special_tokens is not None else True,
         )
         results_generator = await engine.add_request(
-            request_id, prompt, sampling_params, None, time.monotonic()
+            request_id, prompt, sampling_params, time.monotonic(), None
        )
         results_generators.append(results_generator)

     outputs = []
     for generator in results_generators:
-        last_output_text = ""
         tokens = []
         async for request_output in generator:
             if request_output.finished:
                 bar.update(1)

-            token_text = request_output.outputs[-1].text[len(last_output_text) :]
-            log_probs = request_output.outputs[0].logprobs[-1] if return_token_log_probs else None
-            last_output_text = request_output.outputs[-1].text
-
             if return_token_log_probs:
+                output = request_output.outputs[0]
+                log_probs = output.logprobs[-1] if return_token_log_probs else None
+                token_id = output.token_ids[-1]
                 tokens.append(
                     TokenOutput(
-                        token=token_text,
-                        log_prob=log_probs[request_output.outputs[0].token_ids[-1]],
+                        token=log_probs[token_id].decoded_token,
+                        log_prob=log_probs[token_id].logprob,
                     )
                 )
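
Two behavioural changes in the hunks above are easy to miss: the positional arguments to `engine.add_request` are reordered to match the newer vllm signature, and per-token log probabilities now come back as vllm `Logprob` objects keyed by token id, so the token text is read from `.decoded_token` rather than diffing the cumulative output text. A condensed sketch of the new extraction path, assuming vllm >= 0.5, an already-initialized `AsyncLLMEngine` passed in as `engine`, and log probs requested via `SamplingParams(logprobs=1)`:

# Sketch only: mirrors the updated generate_with_vllm loop, trimmed to the
# token/log-prob handling.
import time

from vllm import SamplingParams


async def collect_token_logprobs(engine, prompt: str, request_id: str):
    sampling_params = SamplingParams(max_tokens=16, logprobs=1)
    generator = await engine.add_request(
        request_id, prompt, sampling_params, time.monotonic(), None
    )
    tokens = []
    async for request_output in generator:
        output = request_output.outputs[0]
        log_probs = output.logprobs[-1]  # dict: token_id -> vllm Logprob
        token_id = output.token_ids[-1]
        tokens.append((log_probs[token_id].decoded_token, log_probs[token_id].logprob))
    return tokens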

model-engine/tests/unit/inference/conftest.py

Lines changed: 17 additions & 13 deletions
@@ -1,11 +1,10 @@
 from unittest.mock import MagicMock

 import pytest
-from model_engine_server.common.dtos.llms import (
+from model_engine_server.inference.batch_inference.dto import (
     CompletionOutput,
     CreateBatchCompletionsEngineRequest,
     CreateBatchCompletionsModelConfig,
-    CreateBatchCompletionsRequest,
     CreateBatchCompletionsRequestContent,
     TokenOutput,
     ToolConfig,
@@ -14,16 +13,18 @@

 @pytest.fixture
 def create_batch_completions_engine_request() -> CreateBatchCompletionsEngineRequest:
+    model_config = CreateBatchCompletionsModelConfig(
+        model="model",
+        checkpoint_path="checkpoint_path",
+        labels={},
+        seed=123,
+        num_shards=4,
+    )
     return CreateBatchCompletionsEngineRequest(
         input_data_path="input_data_path",
         output_data_path="output_data_path",
-        model_config=CreateBatchCompletionsModelConfig(
-            model="model",
-            checkpoint_path="checkpoint_path",
-            labels={},
-            seed=123,
-            num_shards=4,
-        ),
+        model_cfg=model_config,
+        model_config=model_config,
         data_parallelism=1,
         max_runtime_sec=86400,
         max_gpu_memory_utilization=0.95,
@@ -32,10 +33,13 @@ def create_batch_completions_engine_request() -> CreateBatchCompletionsEngineRequest:

 @pytest.fixture
 def create_batch_completions_tool_completion_request():
-    return CreateBatchCompletionsRequest(
-        model_config=CreateBatchCompletionsModelConfig(
-            checkpoint_path="checkpoint_path", model="model", num_shards=4, seed=123, labels={}
-        ),
+    model_config = CreateBatchCompletionsModelConfig(
+        checkpoint_path="checkpoint_path", model="model", num_shards=4, seed=123, labels={}
+    )
+
+    return CreateBatchCompletionsEngineRequest(
+        model_cfg=model_config,
+        model_config=model_config,
         data_parallelism=1,
         input_data_path="input_data_path",
         output_data_path="output_data_path",
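
Downstream tests then read the renamed field off the fixture; a hypothetical example (not in the repo) of consuming it:

# Hypothetical test body; the assertions reuse the fixture values set above.
def test_engine_request_exposes_model_cfg(create_batch_completions_engine_request):
    request = create_batch_completions_engine_request
    assert request.model_cfg.num_shards == 4
    assert request.max_gpu_memory_utilization == 0.95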
