Commit f8e7add

Fix/async chat serving (#2727)
1 parent 7e65477 commit f8e7add

File tree

5 files changed: +73 -21 lines changed

tests/async_engine/test_chat_template.py

Lines changed: 14 additions & 11 deletions
@@ -60,12 +60,13 @@ class MockServingChat:
     tokenizer: MockTokenizer
 
 
-def test_load_chat_template():
+@pytest.mark.asyncio
+async def test_load_chat_template():
     # Testing chatml template
     tokenizer = MockTokenizer()
     mock_serving_chat = MockServingChat(tokenizer)
-    OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                          chat_template=chatml_jinja_path)
+    await OpenAIServingChat._load_chat_template(
+        mock_serving_chat, chat_template=chatml_jinja_path)
 
     template_content = tokenizer.chat_template
 
@@ -76,26 +77,28 @@ def test_load_chat_template():
 {% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}"""  # noqa: E501
 
 
-def test_no_load_chat_template_filelike():
+@pytest.mark.asyncio
+async def test_no_load_chat_template_filelike():
     # Testing chatml template
     template = "../../examples/does_not_exist"
     tokenizer = MockTokenizer()
 
     mock_serving_chat = MockServingChat(tokenizer)
 
     with pytest.raises(ValueError, match="looks like a file path"):
-        OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                              chat_template=template)
+        await OpenAIServingChat._load_chat_template(mock_serving_chat,
+                                                    chat_template=template)
 
 
-def test_no_load_chat_template_literallike():
+@pytest.mark.asyncio
+async def test_no_load_chat_template_literallike():
     # Testing chatml template
     template = "{{ messages }}"
     tokenizer = MockTokenizer()
 
     mock_serving_chat = MockServingChat(tokenizer)
-    OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                          chat_template=template)
+    await OpenAIServingChat._load_chat_template(mock_serving_chat,
+                                                chat_template=template)
     template_content = tokenizer.chat_template
 
     assert template_content == template
@@ -110,8 +113,8 @@ async def test_get_gen_prompt(model, template, add_generation_prompt,
     # Initialize the tokenizer
     tokenizer = get_tokenizer(tokenizer_name=model)
     mock_serving_chat = MockServingChat(tokenizer)
-    OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                          chat_template=template)
+    await OpenAIServingChat._load_chat_template(mock_serving_chat,
+                                                chat_template=template)
 
     # Create a mock request object using keyword arguments
     mock_request = ChatCompletionRequest(
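
The tests change shape because `_load_chat_template` is now a coroutine: a plain `def test_...` cannot `await` it, so each test becomes `async def` and carries `@pytest.mark.asyncio`, which (with the pytest-asyncio plugin) runs the test body inside an event loop. A minimal standalone sketch of that pattern, using a hypothetical `load_template` helper rather than the real vLLM code:

import asyncio

import pytest


async def load_template(store: dict, template: str) -> None:
    # Stand-in for an async loader such as _load_chat_template.
    await asyncio.sleep(0)  # yield to the event loop once
    store["chat_template"] = template


@pytest.mark.asyncio
async def test_load_template():
    store: dict = {}
    # The test is itself a coroutine, so it can await the loader directly.
    await load_template(store, "{{ messages }}")
    assert store["chat_template"] == "{{ messages }}"
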
Lines changed: 37 additions & 0 deletions (new file)
@@ -0,0 +1,37 @@
+import asyncio
+from dataclasses import dataclass
+
+from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
+
+MODEL_NAME = "openai-community/gpt2"
+CHAT_TEMPLATE = "Dummy chat template for testing {}"
+
+
+@dataclass
+class MockModelConfig:
+    tokenizer = MODEL_NAME
+    trust_remote_code = False
+    tokenizer_mode = "auto"
+    max_model_len = 100
+    tokenizer_revision = None
+
+
+@dataclass
+class MockEngine:
+
+    async def get_model_config(self):
+        return MockModelConfig
+
+
+async def _async_serving_chat_init():
+    serving_completion = OpenAIServingChat(MockEngine(),
+                                           served_model_names=[MODEL_NAME],
+                                           response_role="assistant",
+                                           chat_template=CHAT_TEMPLATE)
+    return serving_completion
+
+
+def test_async_serving_chat_init():
+    serving_completion = asyncio.run(_async_serving_chat_init())
+    assert serving_completion.tokenizer is not None
+    assert serving_completion.tokenizer.chat_template == CHAT_TEMPLATE
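
This new test drives the whole construction path from synchronous code: the `OpenAIServingChat` constructor is wrapped in a coroutine and executed with `asyncio.run`, so the base class sees a running event loop while the chat template is applied asynchronously. A self-contained sketch of the same idea, using toy classes rather than the vLLM ones:

import asyncio


class AsyncInit:
    """Toy class whose setup finishes in an async post-init step."""

    def __init__(self, template: str):
        self.chat_template = None
        # Schedule the async part on the already-running loop, mirroring
        # the create_task() branch in OpenAIServing.__init__.
        asyncio.get_running_loop().create_task(self._post_init(template))

    async def _post_init(self, template: str):
        self.chat_template = template


async def _build() -> AsyncInit:
    obj = AsyncInit("Dummy chat template for testing {}")
    await asyncio.sleep(0)  # let the scheduled _post_init task run
    return obj


def test_async_init():
    obj = asyncio.run(_build())
    assert obj.chat_template == "Dummy chat template for testing {}"
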

tests/entrypoints/test_openai_server.py

Lines changed: 1 addition & 1 deletion
@@ -150,7 +150,7 @@ def server(zephyr_lora_files):
     ray.shutdown()
 
 
-@pytest.fixture(scope="session")
+@pytest.fixture(scope="module")
 def client():
     client = openai.AsyncOpenAI(
         base_url="http://localhost:8000/v1",

vllm/entrypoints/openai/serving_chat.py

Lines changed: 9 additions & 3 deletions
@@ -1,3 +1,4 @@
+import asyncio
 import codecs
 import time
 from typing import (AsyncGenerator, AsyncIterator, Awaitable, Iterable, List,
@@ -40,9 +41,11 @@ def __init__(self,
                  chat_template: Optional[str] = None):
         super().__init__(engine=engine,
                          served_model_names=served_model_names,
-                         lora_modules=lora_modules)
+                         lora_modules=lora_modules,
+                         await_post_init=self._load_chat_template(
+                             chat_template=chat_template))
+
         self.response_role = response_role
-        self._load_chat_template(chat_template)
 
     def _parse_chat_message_content(
         self,
@@ -356,7 +359,10 @@ async def chat_completion_full_generator(
 
         return response
 
-    def _load_chat_template(self, chat_template: Optional[str]):
+    async def _load_chat_template(self, chat_template: Optional[str]):
+        while self.tokenizer is None:
+            # Give the parent class time to load the tokenizer
+            await asyncio.sleep(0.1)
         tokenizer = self.tokenizer
 
         if chat_template is not None:
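
Because the tokenizer is created by the base class's async `_post_init`, `_load_chat_template` can start before the tokenizer exists; the new loop simply polls until it appears. A standalone sketch of that wait-for-sibling-initialization pattern (a toy class, not the vLLM one):

import asyncio
from types import SimpleNamespace


class Worker:

    def __init__(self):
        self.tokenizer = None

    async def _post_init(self):
        # Simulate the base class loading the tokenizer after a delay.
        await asyncio.sleep(0.3)
        self.tokenizer = SimpleNamespace()

    async def _load_chat_template(self, chat_template: str):
        # Poll until the sibling initialization has produced the tokenizer,
        # mirroring the loop added above.
        while self.tokenizer is None:
            await asyncio.sleep(0.1)
        self.tokenizer.chat_template = chat_template


async def main():
    w = Worker()
    await asyncio.gather(w._post_init(),
                         w._load_chat_template("{{ messages }}"))
    assert w.tokenizer.chat_template == "{{ messages }}"


asyncio.run(main())
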

vllm/entrypoints/openai/serving_engine.py

Lines changed: 12 additions & 6 deletions
@@ -2,7 +2,7 @@
 import json
 from dataclasses import dataclass
 from http import HTTPStatus
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union
 
 from pydantic import Field
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
@@ -29,8 +29,11 @@ class LoRAModulePath:
 
 class OpenAIServing:
 
-    def __init__(self, engine: AsyncLLMEngine, served_model_names: List[str],
-                 lora_modules: Optional[List[LoRAModulePath]]):
+    def __init__(self,
+                 engine: AsyncLLMEngine,
+                 served_model_names: List[str],
+                 lora_modules: Optional[List[LoRAModulePath]],
+                 await_post_init: Optional[Awaitable[Any]] = None):
         self.engine = engine
         self.served_model_names = served_model_names
         if lora_modules is None:
@@ -56,12 +59,12 @@ def __init__(self, engine: AsyncLLMEngine, served_model_names: List[str],
         if event_loop is not None and event_loop.is_running():
             # If the current is instanced by Ray Serve,
             # there is already a running event loop
-            event_loop.create_task(self._post_init())
+            event_loop.create_task(self._post_init(await_post_init))
         else:
             # When using single vLLM without engine_use_ray
-            asyncio.run(self._post_init())
+            asyncio.run(self._post_init(await_post_init))
 
-    async def _post_init(self):
+    async def _post_init(self, await_post_init):
         engine_model_config = await self.engine.get_model_config()
         self.max_model_len = engine_model_config.max_model_len
 
@@ -73,6 +76,9 @@ async def _post_init(self):
             trust_remote_code=engine_model_config.trust_remote_code,
             truncation_side="left")
 
+        if await_post_init is not None:
+            await await_post_init
+
     async def show_available_models(self) -> ModelList:
         """Show available models. Right now we only have one model."""
         model_cards = [
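
The `await_post_init` hook lets a subclass hand the base class a coroutine that must run only after the model config and tokenizer are in place. A minimal standalone sketch of that hand-off, covering only the `asyncio.run` branch (not the Ray Serve one) and using toy classes:

import asyncio
from typing import Any, Awaitable, Optional


class Base:

    def __init__(self, await_post_init: Optional[Awaitable[Any]] = None):
        self.ready = False
        # The real class picks create_task() when a loop is already running;
        # this sketch assumes no running loop and uses asyncio.run().
        asyncio.run(self._post_init(await_post_init))

    async def _post_init(self, await_post_init: Optional[Awaitable[Any]]):
        self.ready = True  # stand-in for loading config and tokenizer
        # Only after the base setup is done, run the subclass-supplied step.
        if await_post_init is not None:
            await await_post_init


class Child(Base):

    def __init__(self, template: str):
        self.template = None
        super().__init__(await_post_init=self._apply_template(template))

    async def _apply_template(self, template: str):
        self.template = template


child = Child("{{ messages }}")
assert child.ready and child.template == "{{ messages }}"
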
