@@ -591,6 +591,7 @@ def generate(
             Dict[str, Any]
         ] = None,  # List of Image prompt tensors for multimodal models
         start_pos: int = 0,
+        skip_cache_setup: bool = False,
         draft_model: Model,
         speculate_k: Optional[int] = 8,
         sequential_prefill=True,
@@ -614,26 +615,27 @@ def generate(
         max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - prompt_length)
         # set up caches only if first inference
         if start_pos == 0:
-            model = model.to(device=device)
-            with torch.device(device):
-                if (
-                    self.is_torchtune_model
-                    or self.model.config.model_type == ModelType.Flamingo
-                ):
-                    # 6404 is one-gpu affordable max_seq_length for single image input
-                    model.setup_caches(
-                        batch_size=1,
-                        dtype=self.dtype,
-                        encoder_max_seq_len=6404,
-                        decoder_max_seq_len=max_seq_length,
-                    )
-                else:
-                    model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
-                if is_speculative and draft_model is not model:
-                    draft_model.setup_caches(
-                        max_batch_size=1,
-                        max_seq_length=max_seq_length,
-                    )
+            if not skip_cache_setup:
+                model = model.to(device=device)
+                with torch.device(device):
+                    if (
+                        self.is_torchtune_model
+                        or self.model.config.model_type == ModelType.Flamingo
+                    ):
+                        # 6404 is one-gpu affordable max_seq_length for single image input
+                        model.setup_caches(
+                            batch_size=1,
+                            dtype=self.dtype,
+                            encoder_max_seq_len=6404,
+                            decoder_max_seq_len=max_seq_length,
+                        )
+                    else:
+                        model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+                    if is_speculative and draft_model is not model:
+                        draft_model.setup_caches(
+                            max_batch_size=1,
+                            max_seq_length=max_seq_length,
+                        )
         if model.config.model_type == ModelType.Flamingo:
             model.reset_caches()

@@ -915,13 +917,6 @@ def chat(
                 ]
             )
         if generator_args.compile:
-            if (
-                self.is_speculative and self.builder_args.use_distributed
-            ):  # and ("cuda" in builder_args.device):
-                torch._inductor.config.triton.cudagraph_trees = (
-                    False  # Bug with cudagraph trees in this case
-                )
-
             if self.builder_args.device == "cpu":
                 if generator_args.max_autotune:
                     kwargs = {"mode": "max-autotune"}
@@ -1020,6 +1015,7 @@ def chat(
             )
         for i in range(num_samples):
             device_sync(device=self.builder_args.device)
+            is_first_sample: bool = i == 0
             if generator_args.chat_mode:
                 prompt = input("User: ")
                 if prompt == "/bye":
@@ -1045,7 +1041,7 @@ def chat(
                         ]
                     )
                     self.system_prompt = None
-                elif i == 0:
+                elif is_first_sample:
                     encoded = self.chat_formatter.encode_dialog_prompt(
                         [{"role": "user", "content": prompt}]
                     )
@@ -1091,9 +1087,7 @@ def callback(x, *, done_generating=False):

                 torch._inductor.config.profiler_mark_wrapper_call = True
                 torch._inductor.config.cpp.enable_kernel_profile = True
-            if (i != generator_args.num_samples - 1 or not self.profile) or (
-                self.builder_args.use_distributed and self.rank != 0
-            ):
+            if i != generator_args.num_samples - 1 or not self.profile:
                 import contextlib

                 prof = contextlib.nullcontext()
@@ -1116,6 +1110,7 @@ def callback(x, *, done_generating=False):
                 top_k=generator_args.top_k,
                 sequential_prefill=generator_args.sequential_prefill,
                 start_pos=start_pos,
+                skip_cache_setup=not is_first_sample,
                 max_seq_length=max_seq_length,
             )
             for token_tensor, metrics in generator_func:
@@ -1125,7 +1120,7 @@ def callback(x, *, done_generating=False):
                 if metrics is not None:
                     aggregate_metrics.update(metrics)
                 yield token_tensor, metrics
-            jit_compile = (i == 0) and (
+            jit_compile = is_first_sample and (
                 generator_args.compile or generator_args.compile_prefill
             )
             compilation_time = time.perf_counter() - t0
@@ -1136,10 +1131,7 @@ def callback(x, *, done_generating=False):
                     print(prof.key_averages().table(sort_by="self_cpu_time_total"))
                 else:
                     print(prof.key_averages().table(sort_by="self_cuda_time_total"))
-                if self.builder_args.use_distributed:
-                    prof.export_chrome_trace(f"{self.profile}_rank_{self.rank}.json")
-                else:
-                    prof.export_chrome_trace(f"{self.profile}.json")
+                prof.export_chrome_trace(f"{self.profile}.json")

             if start_pos >= max_seq_length:
                 print(
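
For context, a minimal, self-contained sketch of the pattern this diff introduces: cache setup runs only for the first sample and is skipped for later ones, even though every sample decodes again from position 0. Only the parameter names `skip_cache_setup`, `start_pos`, and `is_first_sample` come from the diff; the class, its `kv_cache` attribute, and the placeholder decoding are hypothetical and are not torchchat code.

```python
# Hypothetical sketch (not torchchat code) of the skip_cache_setup pattern above:
# the expensive cache allocation happens only for the first sample, while every
# sample still decodes from start_pos == 0.
from typing import List, Optional


class TinyGenerator:
    def __init__(self) -> None:
        self.kv_cache: Optional[List[str]] = None  # stand-in for model.setup_caches()

    def generate(self, prompt: str, *, start_pos: int = 0, skip_cache_setup: bool = False) -> str:
        if start_pos == 0 and not skip_cache_setup:
            self.kv_cache = []  # only the first sample pays for this allocation
        assert self.kv_cache is not None, "caches must be set up before decoding"
        self.kv_cache.append(prompt)  # pretend we stored the prompt's KV entries
        return prompt.upper()  # placeholder for real decoding


gen = TinyGenerator()
num_samples = 3
for i in range(num_samples):
    is_first_sample = i == 0
    # Subsequent samples reuse the caches that the first sample created.
    print(gen.generate("hello", start_pos=0, skip_cache_setup=not is_first_sample))
```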