Skip to content

Commit 0980f72

Browse files
committed
add extract_prompt()
Signed-off-by: chickeyton <[email protected]>
1 parent f93d8c3 commit 0980f72

File tree

1 file changed

+29
-29
lines changed

1 file changed

+29
-29
lines changed

src/vllm_router/routers/routing_logic.py

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,32 @@
4848
logger = init_logger(__name__)
4949

5050

51+
def extract_prompt(request_json: Dict) -> str:
    """Extract the prompt text from an OpenAI-style request body.

    Chat-completion requests ("messages" key) have the text content of ALL
    messages concatenated with newlines; multimodal message content (a list
    of parts) contributes only its "text" parts, joined by spaces.  Plain
    completion requests fall through to the "prompt" field.

    Args:
        request_json: Parsed JSON request body.  Expected to contain either
            a "messages" list (chat completions) or a "prompt" key
            (regular completions).

    Returns:
        The extracted prompt string ("" for an empty messages list).

    Raises:
        KeyError: If neither "messages" nor "prompt" is present.
    """
    if "messages" in request_json:
        messages = request_json["messages"]
        if messages:
            # Concatenate the text of every message so the prefix-matching /
            # tokenization callers see the full conversation context.
            prompt_parts = []
            for message in messages:
                content = message.get("content", "")
                if isinstance(content, list):
                    # Multimodal message: keep only the text parts,
                    # joined by single spaces.
                    text_content = " ".join(
                        part.get("text", "")
                        for part in content
                        if part.get("type") == "text"
                    )
                    prompt_parts.append(text_content)
                elif content is not None:
                    prompt_parts.append(content)
            return "\n".join(prompt_parts)
        return ""
    # Handle regular completions
    return request_json["prompt"]
75+
76+
5177
class RoutingLogic(str, enum.Enum):
5278
ROUND_ROBIN = "roundrobin"
5379
SESSION_BASED = "session"
@@ -299,7 +325,7 @@ async def route_request(
299325
self.tokenizer = AutoTokenizer.from_pretrained(endpoints[0].model_names[0])
300326
url = endpoints[0].url + "/tokenize"
301327
# TODO (Yuhan): Handle chat completions
302-
token_ids = self.tokenizer.encode(request_json["prompt"])
328+
token_ids = self.tokenizer.encode(extract_prompt(request_json))
303329
msg = LookupMsg(event_id="", tokens=token_ids)
304330
instance_id = await self.query_manager(msg)
305331
matched_tokens = math.inf
@@ -390,33 +416,7 @@ async def route_request(
390416
request_json (Dict): The request body (needed for finding the
391417
longest prefix match)
392418
"""
393-
394-
# Handle chat completions
395-
if "messages" in request_json:
396-
# Get the last message from the messages array
397-
messages = request_json["messages"]
398-
if messages:
399-
# Concatenate all message content
400-
prompt_parts = []
401-
for message in messages:
402-
content = message.get("content", "")
403-
if isinstance(content, list):
404-
# Handle multimodal messages
405-
text_content = " ".join(
406-
part.get("text", "")
407-
for part in content
408-
if part.get("type") == "text"
409-
)
410-
prompt_parts.append(text_content)
411-
elif content is not None:
412-
prompt_parts.append(content)
413-
prompt = "\n".join(prompt_parts)
414-
else:
415-
prompt = ""
416-
else:
417-
# Handle regular completions
418-
prompt = request_json["prompt"]
419-
419+
prompt = extract_prompt(request_json)
420420
available_endpoints = set(endpoint.url for endpoint in endpoints)
421421
_, matched_endpoint = await self.hashtrie.longest_prefix_match(
422422
prompt, available_endpoints
@@ -539,7 +539,7 @@ async def route_request(
539539
# fallback to use the model of the first endpoint as tokenizer
540540
self.tokenizer = AutoTokenizer.from_pretrained(endpoints[0].model_names[0])
541541

542-
token_ids = self.tokenizer.encode(request_json["prompt"])
542+
token_ids = self.tokenizer.encode(extract_prompt(request_json))
543543
try:
544544
if request_stats is None:
545545
ValueError("no request stats was provided")

0 commit comments

Comments
 (0)