Skip to content

Commit 8189c86

Browse files
authored
Adjust chat loop routing for fastchat/transformers (#178)
1 parent 2fdebf6 commit 8189c86

File tree

2 files changed: +12 −3 lines changed

qllm/plugin/chatcli/generation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def generate_stream(model, tokenizer, prompt: str, device, max_new_tokens: int,
1010

1111
lhs_tokens = torch.tensor(inputs.input_ids, dtype=torch.int64, device=device).unsqueeze(0)
1212

13-
past_kvs = transformers.DynamicCache()
13+
past_kvs = None
1414
output_ids = list(inputs.input_ids)
1515
input_echo_len = len(output_ids)
1616

@@ -70,7 +70,7 @@ def generate(model, tokenizer, prompt: str, max_new_tokens:int, context_len: int
7070

7171
lhs_tokens = torch.tensor(inputs.input_ids, dtype=torch.int64, device=device).unsqueeze(0)
7272

73-
past_kvs = transformers.DynamicCache()
73+
past_kvs = None
7474
output_ids = list(inputs.input_ids)
7575
input_echo_len = len(output_ids)
7676

qllm/plugin/chatcli/inference.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import time
22
import torch
3+
import transformers
4+
from packaging.version import Version, InvalidVersion
35
try:
46
import fastchat
57
#from fastchat.conversation import Conversation, SeparatorStyle
@@ -28,9 +30,16 @@ def chat_loop(
2830
debug: bool = True,
2931
echo: bool = False,
3032
):
33+
model_type = str(type(model)).lower()
34+
use_fastchat_v2 = False
3135
if _fastchat_available:
36+
try:
37+
use_fastchat_v2 = Version(transformers.__version__) < Version("4.3") and "llama" not in model_type
38+
except InvalidVersion:
39+
use_fastchat_v2 = False
40+
41+
if use_fastchat_v2:
3242
return chat_loop_v2(model, tokenizer)
33-
model_type = str(type(model)).lower()
3443
if "llama" not in model_type and hasattr(tokenizer, 'apply_chat_template'):
3544
return chat_loop_v3(model, tokenizer)
3645
assert "llama" in model_type, 'have you installed fschat? please run `pip install fschat` and try again.'

0 commit comments

Comments (0)