This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 1abd632

Include empty assistant message for chat

Author: vmpuri (committed)
Parent: 26a99fc

File tree: torchchat/generate.py, torchchat/usages/browser.py

2 files changed: +12 −2 lines


torchchat/generate.py

Lines changed: 8 additions & 1 deletion
@@ -603,6 +603,7 @@ def generate(
         if len(prompt.shape) > 1:
             prompt = prompt.squeeze(0)
         prompt_length = prompt.size(0)
+        max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - prompt_length)
         # set up caches only if first inference
         if start_pos == 0:
             model = model.to(device=device)
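The added line caps max_new_tokens so that the prompt plus the newly generated tokens can never run past the sequence budget left in the cache. A minimal sketch of the same arithmetic, with made-up numbers purely for illustration:

# Illustration only; the values below are hypothetical, not taken from torchchat.
max_seq_length = 2048   # total context / KV-cache budget
start_pos = 512         # tokens already sitting in the cache from earlier turns
prompt_length = 300     # tokens in the new prompt
requested = 1500        # max_new_tokens asked for by the caller

# Same clamp as the added line: never generate past the end of the cache.
max_new_tokens = min(requested, max_seq_length - start_pos - prompt_length)
print(max_new_tokens)   # 1236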
@@ -825,6 +826,12 @@ def _gen_model_input(
                     content=content,
                 )
             )
+            messages.append(
+                Message(
+                    role="assistant",
+                    content="",
+                )
+            )
 
             transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))
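Appending an empty assistant Message is what the commit title refers to: with it, the rendered chat prompt ends on an assistant header, which cues the model to produce a reply instead of continuing the user turn. A rough, self-contained sketch of that idea using a Llama-3-style template; the Msg dataclass and render helper below are stand-ins for illustration, not torchchat or torchtune code.

from dataclasses import dataclass

@dataclass
class Msg:  # hypothetical stand-in for the tokenizer's Message type
    role: str
    content: str

def render(messages):
    # Llama-3-style headers; only non-empty turns are closed with <|eot_id|>.
    parts = []
    for m in messages:
        parts.append(f"<|start_header_id|>{m.role}<|end_header_id|>\n\n{m.content}")
        if m.content:
            parts.append("<|eot_id|>")
    return "".join(parts)

print(render([Msg("user", "Describe this image."), Msg("assistant", "")]))
# The output ends with the assistant header, so generation begins as the assistant's reply.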

@@ -849,7 +856,7 @@ def _gen_model_input(
         seq_len = encoded.size(0)
         batch = {}
 
-        total_response_length = max_seq_len + max_new_tokens
+        total_response_length = seq_len + max_new_tokens
         batch["causal_mask"] = torch.nn.functional.pad(
             torch.tril(
                 torch.ones(
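The second change sizes total_response_length from the encoded prompt (seq_len) rather than max_seq_len, so the causal mask covers only the tokens that can exist this turn before being padded out to the width the model's cache expects. A sketch of that mask construction with assumed shapes; the exact padding arguments in _gen_model_input may differ.

import torch

# Hypothetical sizes, for illustration only.
seq_len, max_new_tokens, max_seq_len = 32, 96, 512
total_response_length = seq_len + max_new_tokens  # previously max_seq_len + max_new_tokens

# Lower-triangular causal mask over prompt + generated tokens,
# right-padded with False out to the cache width.
causal = torch.tril(
    torch.ones(total_response_length, total_response_length, dtype=torch.bool)
)
causal_mask = torch.nn.functional.pad(
    causal, (0, max_seq_len - total_response_length), value=False
)
print(causal_mask.shape)  # torch.Size([128, 512])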

torchchat/usages/browser.py

Lines changed: 4 additions & 1 deletion
@@ -10,6 +10,7 @@
 
 from openai import OpenAI
 
+st.set_page_config(page_title="torchchat", page_icon="🤖")
 st.title("torchchat")
 
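st.set_page_config sets the browser tab title and icon, and Streamlit requires it to be the first Streamlit command executed in the script, which is why it lands directly above st.title. A minimal skeleton for context (illustrative, not the full browser.py):

import streamlit as st

# Must run before any other Streamlit call in the script.
st.set_page_config(page_title="torchchat", page_icon="🤖")
st.title("torchchat")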

@@ -26,15 +27,18 @@ def reset_chat():
     st.session_state["messages"] = start_state
     st.session_state["conversation_images"] = []
 
+
 if "messages" not in st.session_state:
     st.session_state.messages = start_state
 if "conversation_images" not in st.session_state:
     st.session_state.conversation_images = []
 
+
 def _upload_image_prompts(file_uploads):
     for file in file_uploads:
         st.session_state.conversation_images.append(file)
 
+
 with st.sidebar:
     if st.button("Reset Chat", type="primary"):
         reset_chat()
@@ -105,7 +109,6 @@ def _upload_image_prompts(file_uploads):
         for img in st.session_state.conversation_images:
             st.image(img)
         st.session_state.conversation_images = []
-
 
 with st.chat_message("assistant"), st.status(
     "Generating... ", expanded=True

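For context, the st.chat_message("assistant") / st.status(...) pair at the end of the last hunk opens the container the reply is streamed into. One common way such a block is filled against an OpenAI-compatible endpoint is sketched below; the base_url, model id, and use of st.write_stream are assumptions for illustration, not necessarily what browser.py does.

import streamlit as st
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:5000/v1", api_key="not-needed")  # hypothetical local server

with st.chat_message("assistant"), st.status("Generating... ", expanded=True):
    # Stream tokens from the server into the status container as they arrive.
    stream = client.chat.completions.create(
        model="llama3.2",                      # hypothetical model id
        messages=st.session_state.messages,
        stream=True,
    )
    reply = st.write_stream(
        chunk.choices[0].delta.content or "" for chunk in stream
    )
st.session_state.messages.append({"role": "assistant", "content": reply})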