
Commit be0632b

Author: vmpuri

Include empty assistant message for chat

1 parent: 26a99fc

File tree: 2 files changed, +9 -1 lines


torchchat/generate.py

Lines changed: 8 additions & 1 deletion
@@ -603,6 +603,7 @@ def generate(
         if len(prompt.shape) > 1:
             prompt = prompt.squeeze(0)
         prompt_length = prompt.size(0)
+        max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - prompt_length)
         # set up caches only if first inference
         if start_pos == 0:
             model = model.to(device=device)
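
The added clamp keeps generation from overrunning the model's context window: once the prompt (plus any tokens already in the KV cache at start_pos) has consumed most of max_seq_length, the requested max_new_tokens is cut down to whatever room remains. A minimal sketch of the arithmetic, with illustrative values that are not from the commit:

# Illustrative values only.
max_seq_length = 2048   # model context limit
start_pos = 0           # first call, nothing in the KV cache yet
prompt_length = 2000    # tokens consumed by the prompt
max_new_tokens = 500    # caller's request

max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - prompt_length)
print(max_new_tokens)   # 48 -- generation stops before exceeding the context window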
@@ -825,6 +826,12 @@ def _gen_model_input(
                     content=content,
                 )
             )
+        messages.append(
+            Message(
+                role="assistant",
+                content="",
+            )
+        )
 
         transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))
 
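Appending an empty assistant Message makes the rendered chat prompt end with an open assistant turn, so the model generates the reply instead of continuing the user's text. A rough sketch of the idea, using a hypothetical Message dataclass and render function with Llama-3-style header tokens (not torchchat's actual template code):

from dataclasses import dataclass

@dataclass
class Message:
    role: str
    content: str

def render(messages):
    # Hypothetical Llama-3-style chat formatting, for illustration only.
    out = ""
    for m in messages:
        out += f"<|start_header_id|>{m.role}<|end_header_id|>\n\n{m.content}"
        if m.content:
            out += "<|eot_id|>"
    return out

print(render([Message("user", "Hello"), Message("assistant", "")]))
# Output ends with an open assistant header, cueing the model to produce
# the assistant's reply as the next tokens.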
@@ -849,7 +856,7 @@ def _gen_model_input(
         seq_len = encoded.size(0)
         batch = {}
 
-        total_response_length = max_seq_len + max_new_tokens
+        total_response_length = seq_len + max_new_tokens
         batch["causal_mask"] = torch.nn.functional.pad(
             torch.tril(
                 torch.ones(
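
This one-line fix sizes the causal mask from the actual encoded prompt length (seq_len) rather than max_seq_len, so the mask matches the sequence that will really be produced. A small sketch of the effect, with assumed values:

import torch

seq_len, max_new_tokens = 4, 2  # assumed values for illustration
total_response_length = seq_len + max_new_tokens
causal_mask = torch.tril(
    torch.ones(total_response_length, total_response_length, dtype=torch.bool)
)
print(causal_mask.shape)  # torch.Size([6, 6]); previously sized off max_seq_len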

torchchat/usages/browser.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
 
 from openai import OpenAI
 
+st.set_page_config(page_title="torchchat", page_icon="🤖")
 st.title("torchchat")
 
 
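Streamlit requires st.set_page_config to be the first Streamlit command a script executes, which is why the call is placed directly after the imports and before st.title. A minimal standalone sketch:

import streamlit as st

# Must run before any other st.* call in the script.
st.set_page_config(page_title="torchchat", page_icon="🤖")
st.title("torchchat")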