|
48 | 48 | logger = init_logger(__name__)
|
49 | 49 |
|
50 | 50 |
|
def extract_prompt(request_json: Dict):
    """Extract prompt message from the request json object."""
    # Plain (non-chat) completion requests carry the prompt directly.
    # NOTE(review): a request with neither "messages" nor "prompt" raises
    # KeyError here, same as the pre-existing call sites did.
    if "messages" not in request_json:
        return request_json["prompt"]

    # Chat completion: stitch together the text of every message,
    # one message per line.
    pieces = []
    for msg in request_json["messages"]:
        body = msg.get("content", "")
        if isinstance(body, list):
            # Multimodal content: keep only the textual fragments,
            # joined by single spaces.
            texts = (
                frag.get("text", "")
                for frag in body
                if frag.get("type") == "text"
            )
            pieces.append(" ".join(texts))
        elif body is not None:
            pieces.append(body)
    return "\n".join(pieces)
| 75 | + |
| 76 | + |
51 | 77 | class RoutingLogic(str, enum.Enum):
|
52 | 78 | ROUND_ROBIN = "roundrobin"
|
53 | 79 | SESSION_BASED = "session"
|
@@ -299,7 +325,7 @@ async def route_request(
|
299 | 325 | self.tokenizer = AutoTokenizer.from_pretrained(endpoints[0].model_names[0])
|
300 | 326 | url = endpoints[0].url + "/tokenize"
|
301 | 327 | # TODO (Yuhan): Handle chat completions
|
302 |
| - token_ids = self.tokenizer.encode(request_json["prompt"]) |
| 328 | + token_ids = self.tokenizer.encode(extract_prompt(request_json)) |
303 | 329 | msg = LookupMsg(event_id="", tokens=token_ids)
|
304 | 330 | instance_id = await self.query_manager(msg)
|
305 | 331 | matched_tokens = math.inf
|
@@ -390,33 +416,7 @@ async def route_request(
|
390 | 416 | request_json (Dict): The request body (needed for finding the
|
391 | 417 | longest prefix match)
|
392 | 418 | """
|
393 |
| - |
394 |
| - # Handle chat completions |
395 |
| - if "messages" in request_json: |
396 |
| - # Get the last message from the messages array |
397 |
| - messages = request_json["messages"] |
398 |
| - if messages: |
399 |
| - # Concatenate all message content |
400 |
| - prompt_parts = [] |
401 |
| - for message in messages: |
402 |
| - content = message.get("content", "") |
403 |
| - if isinstance(content, list): |
404 |
| - # Handle multimodal messages |
405 |
| - text_content = " ".join( |
406 |
| - part.get("text", "") |
407 |
| - for part in content |
408 |
| - if part.get("type") == "text" |
409 |
| - ) |
410 |
| - prompt_parts.append(text_content) |
411 |
| - elif content is not None: |
412 |
| - prompt_parts.append(content) |
413 |
| - prompt = "\n".join(prompt_parts) |
414 |
| - else: |
415 |
| - prompt = "" |
416 |
| - else: |
417 |
| - # Handle regular completions |
418 |
| - prompt = request_json["prompt"] |
419 |
| - |
| 419 | + prompt = extract_prompt(request_json) |
420 | 420 | available_endpoints = set(endpoint.url for endpoint in endpoints)
|
421 | 421 | _, matched_endpoint = await self.hashtrie.longest_prefix_match(
|
422 | 422 | prompt, available_endpoints
|
@@ -539,7 +539,7 @@ async def route_request(
|
539 | 539 | # fallback to use the model of the first endpoint as tokenizer
|
540 | 540 | self.tokenizer = AutoTokenizer.from_pretrained(endpoints[0].model_names[0])
|
541 | 541 |
|
542 |
| - token_ids = self.tokenizer.encode(request_json["prompt"]) |
| 542 | + token_ids = self.tokenizer.encode(extract_prompt(request_json)) |
543 | 543 | try:
|
544 | 544 | if request_stats is None:
|
545 | 545 | ValueError("no request stats was provided")
|
|
0 commit comments