@@ -591,6 +591,7 @@ def generate(
             Dict[str, Any]
         ] = None,  # List of Image prompt tensors for multimodal models
         start_pos: int = 0,
+        skip_cache_setup: bool = False,
         draft_model: Model,
         speculate_k: Optional[int] = 8,
         sequential_prefill=True,
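The new keyword defaults to `False`, so existing call sites keep their behavior; a caller opts in only when it knows an earlier call already allocated the caches. A minimal sketch of that contract (the function body below is a hypothetical stand-in, not the real `generate`):

```python
def generate(*, start_pos: int = 0, skip_cache_setup: bool = False) -> bool:
    # Stand-in for the real method: returns whether caches would be set up.
    return start_pos == 0 and not skip_cache_setup

assert generate() is True                        # first call: caches allocated
assert generate(skip_cache_setup=True) is False  # later calls: reuse caches
assert generate(start_pos=5) is False            # mid-stream: never re-set-up
```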
@@ -614,26 +615,27 @@ def generate(
         max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - prompt_length)
         # set up caches only if first inference
         if start_pos == 0:
-            model = model.to(device=device)
-            with torch.device(device):
-                if (
-                    self.is_torchtune_model
-                    or self.model.config.model_type == ModelType.Flamingo
-                ):
-                    # 6404 is one-gpu affordable max_seq_length for single image input
-                    model.setup_caches(
-                        batch_size=1,
-                        dtype=self.dtype,
-                        encoder_max_seq_len=6404,
-                        decoder_max_seq_len=max_seq_length,
-                    )
-                else:
-                    model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
-                if is_speculative and draft_model is not model:
-                    draft_model.setup_caches(
-                        max_batch_size=1,
-                        max_seq_length=max_seq_length,
-                    )
+            if not skip_cache_setup:
+                model = model.to(device=device)
+                with torch.device(device):
+                    if (
+                        self.is_torchtune_model
+                        or self.model.config.model_type == ModelType.Flamingo
+                    ):
+                        # 6404 is one-gpu affordable max_seq_length for single image input
+                        model.setup_caches(
+                            batch_size=1,
+                            dtype=self.dtype,
+                            encoder_max_seq_len=6404,
+                            decoder_max_seq_len=max_seq_length,
+                        )
+                    else:
+                        model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+                    if is_speculative and draft_model is not model:
+                        draft_model.setup_caches(
+                            max_batch_size=1,
+                            max_seq_length=max_seq_length,
+                        )
         if model.config.model_type == ModelType.Flamingo:
             model.reset_caches()

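The effect of the guard is easiest to see in isolation: `setup_caches` allocates device-resident KV-cache buffers, and re-running it on every sample would discard and re-allocate them. A self-contained sketch of the pattern, where `TinyModel` and its single buffer are hypothetical stand-ins for the real `Model` and its KV caches:

```python
import torch

class TinyModel(torch.nn.Module):
    """Hypothetical stand-in for Model; one buffer in place of real KV caches."""

    def __init__(self):
        super().__init__()
        self.cache = None
        self.setup_calls = 0

    def setup_caches(self, max_batch_size: int, max_seq_length: int) -> None:
        # The expensive allocation the new flag lets callers skip on repeat samples.
        self.cache = torch.zeros(max_batch_size, max_seq_length)
        self.setup_calls += 1

def generate(model: TinyModel, *, start_pos: int = 0, skip_cache_setup: bool = False):
    # Mirrors the gating in the diff: set up caches only on the first inference,
    # and only when the caller has not already done so.
    if start_pos == 0 and not skip_cache_setup:
        model.setup_caches(max_batch_size=1, max_seq_length=128)

model = TinyModel()
for i in range(3):
    generate(model, skip_cache_setup=(i != 0))
assert model.setup_calls == 1  # allocated once, reused for samples 1 and 2
```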
@@ -1013,6 +1015,7 @@ def chat(
         )
         for i in range(num_samples):
             device_sync(device=self.builder_args.device)
+            is_first_sample: bool = i == 0
             if generator_args.chat_mode:
                 prompt = input("User: ")
                 if prompt == "/bye":
@@ -1038,7 +1041,7 @@ def chat(
                         ]
                     )
                     self.system_prompt = None
-                elif i == 0:
+                elif is_first_sample:
                     encoded = self.chat_formatter.encode_dialog_prompt(
                         [{"role": "user", "content": prompt}]
                     )
@@ -1107,6 +1110,7 @@ def callback(x, *, done_generating=False):
                 top_k=generator_args.top_k,
                 sequential_prefill=generator_args.sequential_prefill,
                 start_pos=start_pos,
+                skip_cache_setup=not is_first_sample,
                 max_seq_length=max_seq_length,
             )
             for token_tensor, metrics in generator_func:
@@ -1116,7 +1120,7 @@ def callback(x, *, done_generating=False):
                 if metrics is not None:
                     aggregate_metrics.update(metrics)
                 yield token_tensor, metrics
-            jit_compile = (i == 0) and (
+            jit_compile = is_first_sample and (
                 generator_args.compile or generator_args.compile_prefill
             )
             compilation_time = time.perf_counter() - t0
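Putting the `chat()`-side changes together: the first sample both pays any `torch.compile` cost and sets up the caches, while later samples reuse both, so only the first iteration's wall time is attributed to compilation. A hedged sketch of the driver loop (`run_sample` and `compile_enabled` are hypothetical; only `is_first_sample`, `skip_cache_setup`, and `jit_compile` come from the diff):

```python
import time

def run_sample(skip_cache_setup: bool) -> None:
    # Hypothetical stand-in for draining self.generate(...); sleeps instead.
    time.sleep(0.01)

compile_enabled = True  # stands in for generator_args.compile / compile_prefill
num_samples = 3

for i in range(num_samples):
    is_first_sample: bool = i == 0
    t0 = time.perf_counter()
    run_sample(skip_cache_setup=not is_first_sample)
    # Only the first sample's timing can include torch.compile overhead.
    jit_compile = is_first_sample and compile_enabled
    compilation_time = time.perf_counter() - t0
    if jit_compile:
        print(f"sample {i}: ~{compilation_time:.3f}s (includes JIT compilation)")
    else:
        print(f"sample {i}: ~{compilation_time:.3f}s")
```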