@@ -614,27 +614,28 @@ def generate(
         prompt_length = prompt.size(0)
         max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - prompt_length)
         # set up caches only if first inference
-        if start_pos == 0 and not skip_cache_setup:
-            model = model.to(device=device)
-            with torch.device(device):
-                if (
-                    self.is_torchtune_model
-                    or self.model.config.model_type == ModelType.Flamingo
-                ):
-                    # 6404 is one-gpu affordable max_seq_length for single image input
-                    model.setup_caches(
-                        batch_size=1,
-                        dtype=self.dtype,
-                        encoder_max_seq_len=6404,
-                        decoder_max_seq_len=max_seq_length,
-                    )
-                else:
-                    model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
-                if is_speculative and draft_model is not model:
-                    draft_model.setup_caches(
-                        max_batch_size=1,
-                        max_seq_length=max_seq_length,
-                    )
+        if start_pos == 0:
+            if not skip_cache_setup:
+                model = model.to(device=device)
+                with torch.device(device):
+                    if (
+                        self.is_torchtune_model
+                        or self.model.config.model_type == ModelType.Flamingo
+                    ):
+                        # 6404 is one-gpu affordable max_seq_length for single image input
+                        model.setup_caches(
+                            batch_size=1,
+                            dtype=self.dtype,
+                            encoder_max_seq_len=6404,
+                            decoder_max_seq_len=max_seq_length,
+                        )
+                    else:
+                        model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+                    if is_speculative and draft_model is not model:
+                        draft_model.setup_caches(
+                            max_batch_size=1,
+                            max_seq_length=max_seq_length,
+                        )
         if model.config.model_type == ModelType.Flamingo:
             model.reset_caches()

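The practical effect of the restructuring in this hunk is that the Flamingo cache reset now depends only on start_pos == 0: it still runs on the first inference step even when the caller passes skip_cache_setup=True, whereas previously it sat under the combined condition and was skipped together with cache setup. Below is a minimal, self-contained Python sketch of that control-flow difference; StubModel, first_step_old, and first_step_new are hypothetical stand-ins for illustration only (not part of the actual codebase), and start_pos == 0 is assumed throughout.

from types import SimpleNamespace

# Hypothetical stub: records which cache-related calls were made.
class StubModel:
    def __init__(self, model_type):
        self.config = SimpleNamespace(model_type=model_type)
        self.calls = []

    def setup_caches(self, **kwargs):
        self.calls.append("setup_caches")

    def reset_caches(self):
        self.calls.append("reset_caches")

def first_step_old(model, skip_cache_setup):
    # Old structure: reset_caches is nested under the combined condition,
    # so skipping cache setup also skips the Flamingo reset.
    if not skip_cache_setup:  # start_pos == 0 assumed
        model.setup_caches()
        if model.config.model_type == "Flamingo":
            model.reset_caches()

def first_step_new(model, skip_cache_setup):
    # New structure: the reset is gated only on the first step,
    # so it still runs when cache setup is skipped.
    if not skip_cache_setup:  # start_pos == 0 assumed
        model.setup_caches()
    if model.config.model_type == "Flamingo":
        model.reset_caches()

m_old = StubModel("Flamingo")
first_step_old(m_old, skip_cache_setup=True)
print(m_old.calls)  # [] -- the reset never happens

m_new = StubModel("Flamingo")
first_step_new(m_new, skip_cache_setup=True)
print(m_new.calls)  # ['reset_caches']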