
Commit c354898

add gemma
1 parent a12f7e8 commit c354898

File tree

3 files changed: +125 −8 lines


gpt_server/model_backend/lmdeploy_backend.py

Lines changed: 35 additions & 5 deletions
@@ -7,6 +7,7 @@
 )
 from typing import Any, Dict, AsyncGenerator
 from lmdeploy.archs import get_task
+# from lmdeploy.serve.openai.reasoning_parser import ReasoningParserManager
 from lmdeploy.serve.async_engine import get_names_from_model
 from loguru import logger
 from gpt_server.model_backend.base import ModelBackend
@@ -87,6 +88,8 @@ def __init__(self, model_path) -> None:
         self.messages_type_select = (
             model_type[1] == "base"
         )  # if True, use prompt: str; otherwise messages: list
+        # self.reasoning_parser = False
+        # self.tokenizer = self.async_engine.tokenizer
 
     async def stream_chat(self, params: Dict[str, Any]) -> AsyncGenerator:
         prompt = params.get("prompt", "")
@@ -131,12 +134,16 @@ async def stream_chat(self, params: Dict[str, Any]) -> AsyncGenerator:
         results_generator = self.async_engine.generate(
             messages=messages, session_id=int(request_id), gen_config=gen_config
         )
-        text_outputs = ""
+        previous_text = ""
+        current_text = ""
+        previous_token_ids = []
+        current_token_ids = []
+        delta_token_ids = []
         async for request_output in results_generator:
             if await request.is_disconnected():
                 # Abort the request if the client disconnects.
                 await self.async_engine.stop_session(session_id=request_id)
-            text_outputs += request_output.response
+            current_text = current_text + request_output.response
 
             usage = {
                 "prompt_tokens": request_output.input_token_len,
@@ -145,16 +152,39 @@ async def stream_chat(self, params: Dict[str, Any]) -> AsyncGenerator:
                 + request_output.generate_token_len,
             }
             ret = {
-                "text": text_outputs,
+                "text": current_text,
                 "error_code": 0,
                 "usage": usage,
                 "finish_reason": request_output.finish_reason,
             }
+            # if self.reasoning_parser is not None:
+            #     delta_token_ids = (
+            #         request_output.token_ids
+            #         if request_output.token_ids is not None
+            #         else []
+            #     )
+            #     current_token_ids = current_token_ids + delta_token_ids
+            #     reasoning_parser = ReasoningParserManager.get("deepseek-r1")(
+            #         self.tokenizer
+            #     )
+            #     reasoning_delta = reasoning_parser.extract_reasoning_content_streaming(
+            #         previous_text=previous_text,
+            #         current_text=current_text,
+            #         delta_text=request_output.response,
+            #         previous_token_ids=previous_token_ids,
+            #         current_token_ids=current_token_ids,
+            #         delta_token_ids=delta_token_ids,
+            #     )
+            #     if reasoning_delta is not None:
+            #         ret["text"] = reasoning_delta.content
+            #         ret["reasoning_content"] = reasoning_delta.reasoning_content
+            #     previous_text = current_text
+            #     previous_token_ids = current_token_ids
             # TODO -------------------------------------------------------------------
             output_info_list = []
             for stop_str in list(stop):
                 if stop_str:
-                    text, bool_value = is_stop(output=text_outputs, stop_str=stop_str)
+                    text, bool_value = is_stop(output=current_text, stop_str=stop_str)
                     output_info_list.append(
                         {"text": text, "bool_value": bool_value, "text_len": len(text)}
                     )
@@ -167,5 +197,5 @@ async def stream_chat(self, params: Dict[str, Any]) -> AsyncGenerator:
                     break
             # TODO -------------------------------------------------------------------
             yield ret
-        logger.info(text_outputs)
+        logger.info(current_text)
         logger.info(usage)
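
The commented-out block above prepares for streaming reasoning extraction: the diff now tracks previous/current text and token ids so that, once enabled, a parser can be handed only the newest delta and split it into reasoning_content and visible content. Below is a minimal, self-contained sketch of that incremental-delta pattern; ToyReasoningParser, Delta, stream_with_reasoning, and the "</think>" convention are illustrative stand-ins, not lmdeploy's ReasoningParserManager, which this commit leaves commented out.

# Sketch only: ToyReasoningParser mimics the streaming interface used above;
# it is NOT lmdeploy's ReasoningParserManager.
from dataclasses import dataclass
from typing import Optional


@dataclass
class Delta:
    reasoning_content: Optional[str]
    content: Optional[str]


class ToyReasoningParser:
    """Treat everything before '</think>' as reasoning, the rest as answer text."""

    def extract_reasoning_content_streaming(
        self, previous_text, current_text, delta_text,
        previous_token_ids, current_token_ids, delta_token_ids,
    ) -> Optional[Delta]:
        if "</think>" not in current_text:
            return Delta(reasoning_content=delta_text, content=None)
        if "</think>" in previous_text:
            return Delta(reasoning_content=None, content=delta_text)
        # The closing tag arrived inside this delta: split the text once.
        reasoning, _, content = current_text.partition("</think>")
        return Delta(reasoning_content=reasoning[len(previous_text):], content=content)


def stream_with_reasoning(chunks):
    """chunks: iterable of (delta_text, delta_token_ids), as a backend would emit."""
    parser = ToyReasoningParser()
    previous_text, current_text = "", ""
    previous_token_ids, current_token_ids = [], []
    for delta_text, delta_token_ids in chunks:
        current_text = current_text + delta_text            # same accumulation as the diff
        current_token_ids = current_token_ids + delta_token_ids
        reasoning_delta = parser.extract_reasoning_content_streaming(
            previous_text, current_text, delta_text,
            previous_token_ids, current_token_ids, delta_token_ids,
        )
        ret = {"text": current_text, "error_code": 0}
        if reasoning_delta is not None:
            ret["text"] = reasoning_delta.content
            ret["reasoning_content"] = reasoning_delta.reasoning_content
        previous_text = current_text                         # roll the window forward
        previous_token_ids = current_token_ids
        yield ret


if __name__ == "__main__":
    for ret in stream_with_reasoning([("thinking…", [10]), ("</think>", [11]), ("42.", [12])]):
        print(ret)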

gpt_server/model_worker/base/model_worker_base.py

Lines changed: 7 additions & 3 deletions
@@ -53,9 +53,13 @@ def __init__(
         model_type: str = "AutoModel",
         multimodal: bool = False,
     ):
-        self.model_config = AutoConfig.from_pretrained(
-            model_path, trust_remote_code=True
-        )
+        try:
+            self.model_config = AutoConfig.from_pretrained(
+                model_path, trust_remote_code=True
+            )
+        except ValueError as e:
+            logger.warning(e)
+            self.model_config = {}
         # logger.info(f"Model config: {self.model_config}")
         self.vision_config = getattr(self.model_config, "vision_config", None)
         is_vision = self.vision_config is not None
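
The empty-dict fallback stays compatible with the vision check two lines below: a plain dict has no vision_config attribute, so getattr simply returns the default. A quick illustration, not taken from the repo:

# Quick illustration (not from the repo): when AutoConfig.from_pretrained raises
# and self.model_config falls back to {}, the attribute probe still works.
model_config = {}                                             # the except-branch fallback
vision_config = getattr(model_config, "vision_config", None)  # dicts have no such attribute
is_vision = vision_config is not None
print(vision_config, is_vision)                               # None False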

gpt_server/model_worker/gemma.py

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
+import json
+from typing import List
+from fastchat.constants import ErrorCode, SERVER_ERROR_MSG
+import torch
+from loguru import logger
+from gpt_server.model_worker.base.model_worker_base import ModelWorkerBase
+
+
+class GemmaWorker(ModelWorkerBase):
+    def __init__(
+        self,
+        controller_addr: str,
+        worker_addr: str,
+        worker_id: str,
+        model_path: str,
+        model_names: List[str],
+        limit_worker_concurrency: int,
+        conv_template: str = None,  # type: ignore
+    ):
+        super().__init__(
+            controller_addr,
+            worker_addr,
+            worker_id,
+            model_path,
+            model_names,
+            limit_worker_concurrency,
+            conv_template,
+            model_type="AutoModelForCausalLM",
+        )
+        self.stop_words_ids = [1, 106]
+        self.stop = [
+            self.tokenizer.decode(skip_word) for skip_word in self.stop_words_ids
+        ]
+        logger.info(f"{model_names[0]} stop words: {self.stop}")
+
+    async def generate_stream_gate(self, params):
+        self.call_ct += 1
+        logger.info(f"params {params}")
+        logger.info(f"worker_id: {self.worker_id}")
+        try:
+            messages = params["messages"]
+            if isinstance(messages, list):
+                task = "chat"
+            elif isinstance(messages, str):
+                task = "completion"
+            if task == "chat":
+                text = self.tokenizer.apply_chat_template(
+                    conversation=messages,
+                    tokenize=True,
+                    add_generation_prompt=True,
+                )
+            elif task == "completion":
+                text = messages
+
+            input_ids = self.tokenizer([text], return_tensors="pt").input_ids
+            params["messages"] = messages
+            params["prompt"] = text
+            params["stop"].extend(self.stop)
+            params["stop_words_ids"] = self.stop_words_ids
+            params["input_ids"] = input_ids
+
+            async for ret in self.backend.stream_chat(params=params):
+                response = ret["text"]
+
+                yield json.dumps(ret).encode() + b"\0"
+
+        except torch.cuda.OutOfMemoryError as e:
+            ret = {
+                "text": f"{SERVER_ERROR_MSG}\n\n({e})",
+                "error_code": ErrorCode.CUDA_OUT_OF_MEMORY,
+            }
+            yield json.dumps(ret).encode() + b"\0"
+        except (ValueError, RuntimeError) as e:
+            logger.info(e)
+            ret = {
+                "text": f"{SERVER_ERROR_MSG}\n\n({e})",
+                "error_code": ErrorCode.INTERNAL_ERROR,
+            }
+            yield json.dumps(ret).encode() + b"\0"
+
+
+if __name__ == "__main__":
+    GemmaWorker.run()
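
GemmaWorker streams its results the same way the other workers do: each chunk is json.dumps(ret).encode() + b"\0", i.e. JSON delimited by a null byte. A minimal consumer sketch under that assumption follows; read_worker_stream, fake_chunks, and _demo are illustrative names, not part of the repo.

# Minimal consumer sketch (assumed client, not part of the commit): split the
# worker's byte stream on the null delimiter and parse each piece as JSON.
import json


async def read_worker_stream(raw_chunks):
    """raw_chunks: an async iterator of bytes, e.g. httpx's response.aiter_bytes()."""
    buffer = b""
    async for chunk in raw_chunks:
        buffer += chunk
        *complete, buffer = buffer.split(b"\0")  # keep any trailing partial chunk
        for piece in complete:
            if piece:
                yield json.loads(piece)


async def _demo():
    # fake_chunks stands in for the worker's HTTP byte stream; illustrative only.
    async def fake_chunks():
        yield json.dumps({"text": "Hel", "error_code": 0}).encode() + b"\0"
        yield json.dumps({"text": "Hello", "error_code": 0}).encode() + b"\0"

    async for ret in read_worker_stream(fake_chunks()):
        print(ret["text"])


if __name__ == "__main__":
    import asyncio

    asyncio.run(_demo())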
