This repository was archived by the owner on Sep 10, 2025. It is now read-only.
File tree: 2 files changed, +9 −1 lines changed
lines changed Original file line number Diff line number Diff line change @@ -603,6 +603,7 @@ def generate(
603 603         if len(prompt.shape) > 1:
604 604             prompt = prompt.squeeze(0)
605 605         prompt_length = prompt.size(0)
    606 +       max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - prompt_length)
606607 # set up caches only if first inference
607608 if start_pos == 0 :
608609 model = model .to (device = device )
@@ -825,6 +826,12 @@ def _gen_model_input(
825 826                     content=content,
826 827                 )
827 828             )
    829 +       messages.append(
    830 +           Message(
    831 +               role="assistant",
    832 +               content="",
    833 +           )
    834 +       )
828835
829836 transform = llama3_2_vision_transform (str (self .tokenizer_args .tokenizer_path ))
830837
@@ -849,7 +856,7 @@ def _gen_model_input(
849 856         seq_len = encoded.size(0)
850 857         batch = {}
851 858
852     -       total_response_length = max_seq_len + max_new_tokens
    859 +       total_response_length = seq_len + max_new_tokens
853 860         batch["causal_mask"] = torch.nn.functional.pad(
854 861             torch.tril(
855 862                 torch.ones(
Original file line number Diff line number Diff line change 1010
11 11   from openai import OpenAI
12 12
    13 + st.set_page_config(page_title="torchchat", page_icon="🤖")
13 14   st.title("torchchat")
14 15
1516
You can’t perform that action at this time.
0 commit comments