Commit 90a81a7

BUG: repair qwen3 model transformers random characters (#4148)
Parent: 8ac4e33

File tree: 4 files changed, +26 -61 lines changed

setup.cfg

Lines changed: 1 addition & 0 deletions
@@ -79,6 +79,7 @@ dev =
     anthropic
     langchain
     langchain-community
+    langchain-openai
     orjson
     sphinx-tabs
     sphinx-design

xinference/core/tests/test_restful_api.py

Lines changed: 5 additions & 5 deletions
@@ -1178,13 +1178,13 @@ def test_lang_chain(setup):
     model_uid_res = response_data["model_uid"]
     assert model_uid_res == "test_restful_api"
 
-    from langchain.chat_models import ChatOpenAI
-    from langchain.prompts.chat import (
+    from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
+    from langchain_core.prompts import (
         ChatPromptTemplate,
         HumanMessagePromptTemplate,
         SystemMessagePromptTemplate,
     )
-    from langchain.schema import AIMessage, HumanMessage, SystemMessage
+    from langchain_openai import ChatOpenAI
 
     inference_server_url = f"{endpoint}/v1"
 
@@ -1204,7 +1204,7 @@ def test_lang_chain(setup):
             content="Translate the following sentence from English to Italian: I love programming."
         ),
     ]
-    r = chat(messages)
+    r = chat.invoke(messages)
     assert type(r) == AIMessage
     assert r.content
 
@@ -1218,7 +1218,7 @@ def test_lang_chain(setup):
     )
 
     # get a chat completion from the formatted messages
-    r = chat(
+    r = chat.invoke(
        chat_prompt.format_prompt(
             input_language="English",
             output_language="Italian",
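
For context, the test migration above follows the LangChain split into `langchain_core`/`langchain_openai` and the deprecation of calling a chat model directly in favour of `.invoke()`. Below is a minimal standalone sketch of the updated usage pattern against an OpenAI-compatible endpoint; the endpoint URL, API key, and model name are illustrative placeholders, not values from this commit:

# Minimal sketch of the post-migration LangChain usage (placeholder endpoint and model).
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI

chat = ChatOpenAI(
    base_url="http://localhost:9997/v1",    # placeholder OpenAI-compatible endpoint
    api_key="not-needed-for-local-server",  # dummy key; local servers typically ignore it
    model="my-model-uid",                   # placeholder model uid
    temperature=0.7,
)

messages = [
    SystemMessage(content="You are a helpful assistant that translates English to Italian."),
    HumanMessage(content="Translate the following sentence from English to Italian: I love programming."),
]

# Calling chat(messages) directly is deprecated; .invoke() is the replacement used in the test.
reply = chat.invoke(messages)
print(reply.content)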

xinference/model/llm/transformers/core.py

Lines changed: 20 additions & 36 deletions
@@ -549,46 +549,30 @@ def build_decode_attention_mask(
         So we need pad `0` on the left again.
         """
         data = []
-        # For decode phase, attention mask should match the full KV cache sequence length
-        # All requests in batch should have attention mask of length `seq_length`
-        for r in reqs:
-            # Get the actual sequence length for this request from its tracking
-            if "attention_mask_seq_len" not in r.extra_kwargs:
-                # Initialize with the current sequence length (full KV cache length)
-                r.extra_kwargs["attention_mask_seq_len"] = seq_length
-            else:
-                # Use the previously tracked length, but ensure it doesn't exceed current seq_length
-                tracked_len = r.extra_kwargs["attention_mask_seq_len"]
-                r.extra_kwargs["attention_mask_seq_len"] = min(tracked_len, seq_length)
-
-        # For decode phase after KV cache merge, all requests should have attention mask
-        # that matches the merged sequence length
+        max_len = max(r.extra_kwargs["attention_mask_seq_len"] for r in reqs) + 1
         for r in reqs:
+            r.extra_kwargs["attention_mask_seq_len"] += 1
             real_len = r.extra_kwargs["attention_mask_seq_len"]
+            pad_len = max_len - real_len
 
-            # The attention mask should cover the full sequence length
-            if real_len < seq_length:
-                # Pad with zeros on the left to reach full sequence length
-                pad_len = seq_length - real_len
-
-                if self._tokenizer.padding_side == "left":
-                    x = torch.cat(
-                        [
-                            torch.full((pad_len,), 0, dtype=torch.long),
-                            torch.ones((real_len,), dtype=torch.long),
-                        ]
-                    )
-                else:
-                    x = torch.cat(
-                        [
-                            torch.ones((real_len,), dtype=torch.long),
-                            torch.full((pad_len,), 0, dtype=torch.long),
-                        ]
-                    )
+            if self._tokenizer.padding_side == "left":
+                x = torch.cat(
+                    [
+                        (
+                            torch.full((pad_len,), 0, dtype=torch.long)
+                            if pad_len > 0
+                            else torch.tensor([], dtype=torch.long)
+                        ),
+                        torch.ones((real_len,), dtype=torch.long),
+                    ]
+                )
             else:
-                # Already at correct length
-                x = torch.ones((real_len,), dtype=torch.long)
-
+                x = torch.cat(
+                    [
+                        torch.ones((real_len,), dtype=torch.long),
+                        torch.full((pad_len,), 0, dtype=torch.long),
+                    ]
+                )
             data.append(x)
 
         return torch.stack(data).to(self._device)
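
The rewritten mask construction replaces the old per-request clamping against `seq_length` with a simpler invariant: every request's tracked mask length grows by exactly one per decode step, and shorter requests are padded with `0` up to the longest mask in the batch, on whichever side the tokenizer pads. A minimal self-contained sketch of that padding logic follows (illustrative only; `tracked_lens` stands in for each request's `extra_kwargs["attention_mask_seq_len"]`, and `build_decode_masks` is not an xinference function):

import torch

def build_decode_masks(tracked_lens, padding_side="left"):
    # One new token is generated per request on every decode step.
    tracked_lens = [n + 1 for n in tracked_lens]
    max_len = max(tracked_lens)
    masks = []
    for real_len in tracked_lens:
        pad = torch.zeros(max_len - real_len, dtype=torch.long)
        ones = torch.ones(real_len, dtype=torch.long)
        # Zeros sit on the padded side so they line up with the padded KV cache.
        masks.append(torch.cat([pad, ones]) if padding_side == "left" else torch.cat([ones, pad]))
    return torch.stack(masks)

# Two requests whose masks covered 5 and 7 positions before this step:
print(build_decode_masks([5, 7]))
# tensor([[0, 0, 1, 1, 1, 1, 1, 1],
#         [1, 1, 1, 1, 1, 1, 1, 1]])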

xinference/model/llm/transformers/utils.py

Lines changed: 0 additions & 20 deletions
@@ -285,30 +285,10 @@ def _batch_inference_one_step_internal(
             # This prevents batch size mismatches during merging
             decode_kv = decode_reqs[0].kv_cache
 
-            # Verify that all decode requests share the same kv_cache
-            for req in decode_reqs[1:]:
-                if req.kv_cache is not decode_kv:
-                    logger.warning(
-                        "Inconsistent kv_cache references detected in decode requests. "
-                        "This may indicate a batching synchronization issue."
-                    )
-                    # Use the first decode_kv as the reference to maintain consistency
-                    req.kv_cache = decode_kv
-
             # prefill and decode kv cache need to be merged at `batch_size` and `seq_len` dimensions.
             merged_kv_cache = xinf_model_obj.merge_kv_cache(decode_kv, past_key_values)
-            # Update sequence length information after KV cache merge
-            _, merged_seq_len = get_batch_size_and_seq_len_from_kv_cache(
-                merged_kv_cache, xinf_model_obj
-            )
             for r in valid_req_list:
                 r.kv_cache = merged_kv_cache
-                # Update attention mask sequence length to match merged KV cache
-                if "attention_mask_seq_len" in r.extra_kwargs:
-                    # Ensure the attention mask length doesn't exceed the merged sequence length
-                    r.extra_kwargs["attention_mask_seq_len"] = min(
-                        r.extra_kwargs["attention_mask_seq_len"], merged_seq_len - 1
-                    )
             empty_cache()
         else:
             for r in valid_req_list:
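
The deleted blocks were a defensive kv_cache consistency check and a post-merge clamp of `attention_mask_seq_len`, neither of which the simplified mask construction above relies on any longer. The remaining comment about merging at the `batch_size` and `seq_len` dimensions can be pictured with a simplified sketch; this is not xinference's `merge_kv_cache`, and it assumes the legacy cache layout of per-layer `(key, value)` tensors shaped `[batch, heads, seq_len, head_dim]`:

import torch
import torch.nn.functional as F

def merge_legacy_kv(cache_a, cache_b):
    """Concatenate two legacy-format KV caches along batch, left-padding seq_len (illustrative only)."""
    merged = []
    for (ka, va), (kb, vb) in zip(cache_a, cache_b):
        seq = max(ka.shape[2], kb.shape[2])
        # F.pad pads dims from the last backwards: (head_dim left/right, seq_len left/right).
        ka, va = (F.pad(t, (0, 0, seq - t.shape[2], 0)) for t in (ka, va))
        kb, vb = (F.pad(t, (0, 0, seq - t.shape[2], 0)) for t in (kb, vb))
        merged.append((torch.cat([ka, kb], dim=0), torch.cat([va, vb], dim=0)))
    return tuple(merged)

The zero-filled positions introduced on the padded side correspond to the positions the decode attention mask marks with `0`, which is what the docstring in `build_decode_attention_mask` means by "pad `0` on the left again".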
