forked from deepjavalibrary/djl-serving
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvllm_chat_utils.py
More file actions
100 lines (86 loc) · 3.7 KB
/
vllm_chat_utils.py
File metadata and controls
100 lines (86 loc) · 3.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#!/usr/bin/env python
#
# Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
from typing import Dict, List, Optional, Union
from djl_python.chat_completions.vllm_chat_properties import ChatProperties
from djl_python.properties_manager.properties import Properties
from djl_python.rolling_batch.rolling_batch_vllm_utils import maybe_serialize_tool_calls
from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
apply_mistral_chat_template,
parse_chat_messages,
resolve_chat_template_content_format)
def parse_chat_completions_request_vllm(
        input_map: Dict,
        is_rolling_batch: bool,
        rolling_batch,
        tokenizer,
        configs: Properties = None,
        is_mistral_tokenizer: bool = False,
):
    """Parse an OpenAI-style chat-completions request into vLLM prompt inputs.

    Validates the request, resolves tool-calling configuration, renders the
    chat messages through the tokenizer's chat template (HF or Mistral), and
    assembles the generation parameters dict consumed downstream.

    :param input_map: raw request payload (OpenAI chat-completions schema).
    :param is_rolling_batch: whether rolling batch is enabled; chat
        completions require it unless batch_size == 1.
    :param rolling_batch: rolling-batch handler; supplies the tool parser
        and model config. Assumed non-None when chat completions are used.
    :param tokenizer: HF or Mistral tokenizer used for template rendering.
    :param configs: serving properties; only ``batch_size`` is read here.
    :param is_mistral_tokenizer: selects the Mistral chat-template path.
    :return: tuple of (text_inputs, param) where ``text_inputs`` is the
        rendered prompt (str for HF, token-id list for Mistral) and
        ``param`` is the sampling/formatting parameter dict.
    :raises ValueError: on unsupported batching mode or tool configuration.
    """
    # Chat completions can either be a rolling batch or no-batching.
    # Guard configs against None (its default) so callers get the
    # descriptive ValueError below rather than an AttributeError.
    if not (is_rolling_batch or
            (configs is not None and configs.batch_size == 1)):
        raise ValueError(
            "chat completions support is not currently available for dynamic batching. "
            "You must enable rolling batch to use the chat completions format."
        )
    tool_parser = rolling_batch.get_tool_parser()
    chat_params = ChatProperties(**input_map)

    if chat_params.tool_choice == "required":
        raise ValueError("tool_choice = \"required\" is not supported!")

    if is_mistral_tokenizer:
        # Mistral tokenizers need tool calls serialized into message content.
        maybe_serialize_tool_calls(chat_params)
    elif chat_params.tool_choice == "auto" and tool_parser is None:
        raise ValueError(
            "\"auto\" tool choice requires tool_call_parser to be available")

    should_parse_tools = tool_parser is not None and (hasattr(
        chat_params, "tool_choice") and chat_params.tool_choice != "none")
    if should_parse_tools:
        # Let the parser adjust the request (e.g. inject guided decoding).
        chat_params = tool_parser.adjust_request(request=chat_params)

    # Messages are rendered into the prompt separately; exclude them from
    # the sampling-parameter dict.
    exclude = {"messages"}
    param = chat_params.model_dump(exclude_none=True, exclude=exclude)

    tool_dicts = None if chat_params.tools is None else [
        tool.model_dump() for tool in chat_params.tools
    ]
    # TODO - figure out what we need to pass for given format
    content_format = resolve_chat_template_content_format(
        chat_template=None,
        given_format="auto",
        tokenizer=tokenizer,
    )
    conversation, mm_data = parse_chat_messages(
        chat_params.messages, rolling_batch.get_model_config(), tokenizer,
        content_format)

    prompt_data: Union[str, List[int]]
    if is_mistral_tokenizer:
        text_inputs = apply_mistral_chat_template(
            tokenizer,
            chat_params.messages,
            None,
            tools=tool_dicts,
        )
    else:
        text_inputs = apply_hf_chat_template(
            tokenizer,
            conversation,
            None,
            add_generation_prompt=True,
            tools=tool_dicts,
        )

    param["details"] = True  # Enable details for chat completions
    param[
        "output_formatter"] = "jsonlines_chat" if chat_params.stream else "json_chat"
    param["tool_parser"] = tool_parser
    param["chat_params"] = chat_params
    if mm_data:
        param["mm_data"] = mm_data

    # In the case of mistral, text_inputs = List[TokenIds], else = str
    return text_inputs, param