
Commit a4209da

Improve prompt logging in the lmdeploy backend
1 parent 9eaf563 commit a4209da

4 files changed: +39 −5 lines changed


gpt_server/model_backend/hf_backend.py

Lines changed: 1 addition & 1 deletion

@@ -56,7 +56,7 @@ def __init__(self, tokenizer, model: torch.nn.Module) -> None:
 
     async def stream_chat(self, params: Dict[str, Any]):
         prompt = params.get("prompt", "")
-        logger.info(prompt)
+        logger.info(f"prompt\n{prompt}")
         temperature = float(params.get("temperature", 0.8))
         top_p = float(params.get("top_p", 0.8))
         max_new_tokens = int(params.get("max_new_tokens", 512))
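The hf_backend change above (and the identical one-line edits to sglang_backend.py and vllm_backend.py further down) replaces a bare logger.info(prompt) with an f-string that labels the value and starts the prompt on its own line. A minimal sketch of the pattern, using the standard-library logging module as a stand-in for the repository's logger object:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("model_backend")  # stand-in for the repo's logger


def log_prompt(prompt: str) -> None:
    # Label the value and put the prompt on a new line so that
    # multi-line prompts are easy to pick out in the log stream.
    logger.info(f"prompt\n{prompt}")


if __name__ == "__main__":
    log_prompt("System: you are a helpful assistant.\nUser: hello")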

gpt_server/model_backend/lmdeploy_backend.py

Lines changed: 36 additions & 2 deletions

@@ -6,7 +6,7 @@
     PytorchEngineConfig,
 )
 from transformers import PreTrainedTokenizerBase
-from typing import Any, Dict, AsyncGenerator
+from typing import Any, Dict, AsyncGenerator, List, Optional
 from lmdeploy.archs import get_task
 from gpt_server.model_handler.reasoning_parser import ReasoningParserManager
 from lmdeploy.serve.async_engine import get_names_from_model
@@ -60,6 +60,39 @@ def is_messages_with_tool(messages: list):
     return flag
 
 
+from lmdeploy.logger import RequestLogger
+
+
+class CustomRequestLogger(RequestLogger):
+    def log_prompt(self, session_id: int, prompt: str) -> None:
+        if not isinstance(prompt, str):
+            # Prompt may be a GPT4V message with base64 images;
+            # logging might be impractical due to length
+            return
+
+    def log_inputs(
+        self,
+        session_id: int,
+        prompt: Optional[str],
+        prompt_token_ids: Optional[List[int]],
+        gen_config: GenerationConfig,
+        adapter_name: str,
+    ) -> None:
+        max_log_len = self.max_log_len
+        input_tokens = len(prompt_token_ids)
+        if max_log_len is not None:
+            if prompt is not None:
+                prompt = prompt[:max_log_len]
+
+            if prompt_token_ids is not None:
+                prompt_token_ids = prompt_token_ids[:max_log_len]
+
+        logger.info(
+            f"session_id={session_id} adapter_name={adapter_name} gen_config={gen_config}"
+        )
+        logger.info(f"prompt:\n{prompt}")
+
+
 class LMDeployBackend(ModelBackend):
     def __init__(self, model_path, tokenizer: PreTrainedTokenizerBase) -> None:
         model_config = get_model_config()
@@ -95,6 +128,8 @@ def __init__(self, model_path, tokenizer: PreTrainedTokenizerBase) -> None:
         self.chat_template_name = chat_template_name
         self.tokenizer = self.async_engine.tokenizer
         self.reasoning_parser_cache = {}
+        # Custom request logger
+        self.async_engine.request_logger = CustomRequestLogger(max_log_len=None)
 
     async def stream_chat(self, params: Dict[str, Any]) -> AsyncGenerator:
         prompt = params.get("prompt", "")
@@ -141,7 +176,6 @@ async def stream_chat(self, params: Dict[str, Any]) -> AsyncGenerator:
         messages = params["messages"]
         if isinstance(messages, str):
             logger.info(f"使用prompt模式")
-            logger.info(prompt)
         else:
             logger.info(f"使用messages模式")
         results_generator = self.async_engine.generate(
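The core of this commit: lmdeploy's AsyncEngine delegates per-request logging to a RequestLogger instance, so the backend subclasses it and assigns the subclass to self.async_engine.request_logger. That lets it decide exactly what gets printed per request (session id, adapter name, generation config, and the full prompt; max_log_len=None means no truncation). The sketch below reproduces the idea without importing lmdeploy; RequestLoggerStub and the plain-dict gen_config are illustrative stand-ins, not lmdeploy's actual classes or signatures.

import logging
from typing import List, Optional

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("lmdeploy_backend")


class RequestLoggerStub:
    # Illustrative stand-in for lmdeploy.logger.RequestLogger.
    def __init__(self, max_log_len: Optional[int] = None) -> None:
        # max_log_len=None means "log the whole prompt"; an int truncates it.
        self.max_log_len = max_log_len


class CustomRequestLogger(RequestLoggerStub):
    def log_inputs(
        self,
        session_id: int,
        prompt: Optional[str],
        prompt_token_ids: Optional[List[int]],
        gen_config: dict,
        adapter_name: str,
    ) -> None:
        # Truncate only when a limit is configured.
        if self.max_log_len is not None:
            if prompt is not None:
                prompt = prompt[: self.max_log_len]
            if prompt_token_ids is not None:
                prompt_token_ids = prompt_token_ids[: self.max_log_len]
        logger.info(
            f"session_id={session_id} adapter_name={adapter_name} gen_config={gen_config}"
        )
        logger.info(f"prompt:\n{prompt}")


# Usage mirrors the commit: build the logger with max_log_len=None and hand it
# to the engine (the commit assigns it to self.async_engine.request_logger).
request_logger = CustomRequestLogger(max_log_len=None)
request_logger.log_inputs(
    session_id=0,
    prompt="User: hello",
    prompt_token_ids=[1, 2, 3],
    gen_config={"temperature": 0.8, "top_p": 0.8},
    adapter_name="default",
)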

gpt_server/model_backend/sglang_backend.py

Lines changed: 1 addition & 1 deletion

@@ -65,7 +65,7 @@ def __init__(self, model_path, tokenizer: PreTrainedTokenizerBase) -> None:
     async def stream_chat(self, params: Dict[str, Any]) -> AsyncGenerator:
         prompt = params.get("prompt", "")
         messages = params["messages"]
-        logger.info(prompt)
+        logger.info(f"prompt\n{prompt}")
        request_id = params.get("request_id", "0")
        temperature = float(params.get("temperature", 0.8))
        top_p = float(params.get("top_p", 0.8))

gpt_server/model_backend/vllm_backend.py

Lines changed: 1 addition & 1 deletion

@@ -61,7 +61,7 @@ def __init__(self, model_path, tokenizer: AutoTokenizer) -> None:
     async def stream_chat(self, params: Dict[str, Any]) -> AsyncGenerator:
         prompt = params.get("prompt", "")
         messages = params["messages"]
-        logger.info(prompt)
+        logger.info(f"prompt\n{prompt}")
        request_id = params.get("request_id", "0")
        temperature = float(params.get("temperature", 0.8))
        top_p = float(params.get("top_p", 0.8))
