 from lm_eval.models.hf_vlms import HFMultimodalLM
 from lm_eval.evaluator import evaluate
 
+from torchtune.modules.common_utils import local_kv_cache
 from torchtune.modules.model_fusion import DeepFusionModel
 from torchtune.modules.transforms import Transform
 from torchtune.data import (
@@ -183,12 +184,7 @@ def _model_generate(self, context, max_length, eos_token_id):
         raise Exception("unimplemented")
 
 
-# Dummy class which _VLMEvalWrapper can inherit from when the imports don't work
-# class HFMultimodalLM():
-#     def __init__(self):
-#         return
-
-class _VLMEvalWrapper(HFMultimodalLM):
+class VLMEvalWrapper(HFMultimodalLM):
     """An EvalWrapper for EleutherAI's eval harness based on gpt-fast's
     EvalWrapper: https://github.com/pytorch-labs/gpt-fast/blob/main/eval.py.
 
@@ -234,6 +230,7 @@ def __init__(
         self._enable_kv_cache = True
         self._image_tag = image_tag
         self._max_images_per_sample = max_images_per_sample
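+        # wall-clock generation time per sample, filled in _model_multimodal_generate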
+        self.times = []
 
     @property
     def model(self):
@@ -338,6 +335,7 @@ def tok_batch_multimodal_encode(
             all_encoded_messages,
             pad_direction="left",
             pad_max_images=self._max_images_per_sample,
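+            # pad the tile dimension up to the transform's max so image tensors batch uniformly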
+            pad_max_tiles=self._transform.max_num_tiles,
         )
         utils.batch_to_device(tok_batch, self.device)
 
@@ -376,15 +374,11 @@ def _model_multimodal_generate(
376374 "multimodal generation."
377375 )
378376
-        # 2. Setup KV cache and masks for bsz 1
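+        # the encoder cache must be large enough to hold every image's token sequence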
+        encoder_max_seq_len = (
+            self.model_transform.image_seq_len * self._max_images_per_sample
+        )
+        # Setup masks for bsz 1
         with self.device:
-            self.model.setup_caches(
-                batch_size=1,
-                dtype=self._dtype,
-                encoder_max_seq_len=self.model_transform.image_seq_len
-                * self._max_images_per_sample,
-                decoder_max_seq_len=self.max_length,
-            )
             causal_mask = torch.tril(
                 torch.ones(
                     size=(self.max_length, self.max_length),
@@ -396,28 +390,39 @@ def _model_multimodal_generate(
396390 batch ["input_pos" ] = input_pos [None , :seq_len ]
397391 batch ["mask" ] = causal_mask [None , :seq_len ]
398392
399- # 3. Prefill step
400- generated_tokens = []
401- logits = self .model (prompt , ** batch )[:, - 1 ]
402- token = sample (logits , temperature = 0.0 , top_k = None )
403- generated_tokens .append (token .item ())
404-
405- cache_mask = batch ["encoder_mask" ][:, - 1 :]
406-
407- # 4. Continue generating
408- for _ in range (max_length ):
409- if token .item () in self .model_transform .stop_tokens :
410- break
411- logits = self .model (
412- token ,
413- mask = causal_mask [None , seq_len , None , :],
414- encoder_input = None ,
415- encoder_mask = cache_mask ,
416- input_pos = input_pos [None , seq_len ],
417- )[:, - 1 ]
418- token = sample (logits , temperature = 0.0 , top_k = None )
419- generated_tokens .append (token .item ())
420- seq_len += 1
+        with measure_time(message=None) as measure:
+            # 2. Setup KV cache
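+            # local_kv_cache allocates KV caches for this call and frees them on
+            # exit, replacing the manual setup_caches call removed above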
+            with local_kv_cache(
+                self.model,
+                batch_size=self.batch_size,
+                device=self.device,
+                dtype=self._dtype,
+                encoder_max_seq_len=encoder_max_seq_len,
+                decoder_max_seq_len=self.max_length,
+            ):
+                # 3. Prefill step
+                generated_tokens = []
+                logits = self.model(prompt, **batch)[:, -1]
+                token = sample(logits, temperature=0.0, top_k=None)
+                generated_tokens.append(token.item())
+
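+                # once the prompt is prefilled, each decode step only needs the
+                # final row of the encoder mask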
+                cache_mask = batch["encoder_mask"][:, -1:]
+
+                # 4. Continue generating
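+                # decode one token per step, advancing input_pos and the mask row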
+                for _ in range(max_length):
+                    if token.item() in self.model_transform.stop_tokens:
+                        break
+                    logits = self.model(
+                        token,
+                        mask=causal_mask[None, seq_len, None, :],
+                        encoder_input=None,
+                        encoder_mask=cache_mask,
+                        input_pos=input_pos[None, seq_len],
+                    )[:, -1]
+                    token = sample(logits, temperature=0.0, top_k=None)
+                    generated_tokens.append(token.item())
+                    seq_len += 1
+        self.times.append(measure.get_time())
 
         # 5. Return generated tokens
         return torch.tensor(generated_tokens, dtype=torch.int32).unsqueeze(0)
@@ -506,7 +511,7 @@ def multi_model_eval(
     max_seq_length = 4096 if max_seq_length is None else max_seq_length
     device = utils.get_device(device) if isinstance(device, str) else device
 
-    model_eval_wrapper = _VLMEvalWrapper(
+    model_eval_wrapper = VLMEvalWrapper(
         model,
         transform=tokenizer,  # transform is the tokenizer for multimodal models
         max_seq_length=max_seq_length,