@@ -591,6 +591,7 @@ def generate(
             Dict[str, Any]
         ] = None,  # List of Image prompt tensors for multimodal models
         start_pos: int = 0,
+        skip_cache_setup: bool = False,
         draft_model: Model,
         speculate_k: Optional[int] = 8,
         sequential_prefill=True,
@@ -613,7 +614,7 @@ def generate(
         prompt_length = prompt.size(0)
         max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - prompt_length)
         # set up caches only if first inference
-        if start_pos == 0:
+        if start_pos == 0 and not skip_cache_setup:
             model = model.to(device=device)
             with torch.device(device):
                 if (
@@ -1020,6 +1021,7 @@ def chat(
         )
         for i in range(num_samples):
             device_sync(device=self.builder_args.device)
+            is_first_sample: bool = i == 0
             if generator_args.chat_mode:
                 prompt = input("User: ")
                 if prompt == "/bye":
@@ -1045,7 +1047,7 @@ def chat(
                         ]
                     )
                     self.system_prompt = None
-                elif i == 0:
+                elif is_first_sample:
                     encoded = self.chat_formatter.encode_dialog_prompt(
                         [{"role": "user", "content": prompt}]
                     )
@@ -1116,6 +1118,7 @@ def callback(x, *, done_generating=False):
                 top_k=generator_args.top_k,
                 sequential_prefill=generator_args.sequential_prefill,
                 start_pos=start_pos,
+                skip_cache_setup=not is_first_sample,
                 max_seq_length=max_seq_length,
             )
             for token_tensor, metrics in generator_func:
@@ -1125,7 +1128,7 @@ def callback(x, *, done_generating=False):
                 if metrics is not None:
                     aggregate_metrics.update(metrics)
                 yield token_tensor, metrics
-            jit_compile = (i == 0) and (
+            jit_compile = is_first_sample and (
                 generator_args.compile or generator_args.compile_prefill
             )
             compilation_time = time.perf_counter() - t0
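The sketch below shows the overall pattern these hunks implement: a multi-sample driver that pays cache setup only on the first sample and passes skip_cache_setup on every later one. TinyModel, its setup_caches method, and this generate are simplified stand-ins invented for illustration, not torchchat's actual classes or signatures.

import time
from typing import Iterator, List, Optional


class TinyModel:
    def __init__(self) -> None:
        # Stands in for the real KV-cache tensors (hypothetical).
        self.kv_cache: Optional[List[int]] = None

    def setup_caches(self, max_seq_length: int) -> None:
        # Pretend allocation; in torchchat this builds per-layer caches.
        self.kv_cache = [0] * max_seq_length


def generate(
    model: TinyModel,
    prompt: str,
    *,
    start_pos: int = 0,
    skip_cache_setup: bool = False,
    max_new_tokens: int = 3,
) -> Iterator[str]:
    # Mirrors the gate added in the first two hunks: set up caches only on
    # the first inference, and only if the caller has not opted out.
    if start_pos == 0 and not skip_cache_setup:
        model.setup_caches(max_seq_length=128)
    for t in range(max_new_tokens):
        yield f"tok{t}"


num_samples = 3
model = TinyModel()
for i in range(num_samples):
    is_first_sample = i == 0
    t0 = time.perf_counter()
    # Samples after the first reuse the caches allocated for sample 0.
    tokens = list(generate(model, "hello", skip_cache_setup=not is_first_sample))
    print(f"sample {i}: {tokens} in {time.perf_counter() - t0:.6f}s")

Only the first iteration pays the allocation cost; in the real code the same is_first_sample flag also keeps the jit_compile timing attributable to sample 0, where torch.compile warm-up actually happens.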