@@ -197,24 +197,11 @@ class VLMEvalWrapper(HFMultimodalLM):
     the max number of images in MMMU.
     """

-    # Having the imports here allows running other evals without installing torchtune
-    from torchtune import utils
-    from torchtune.data import (
-        format_content_with_images,
-        left_pad_sequence,
-        Message,
-        padded_collate_tiled_images_and_mask,
-    )
-    from torchtune.generation import generate, sample
-
-    from torchtune.modules.common_utils import local_kv_cache
-    from torchtune.modules.model_fusion import DeepFusionModel
-    from torchtune.modules.transforms import Transform

     def __init__(
         self,
-        model: DeepFusionModel,
-        transform: Transform,
+        model: Model,
+        transform,
         *,
         device: torch.device,
         max_seq_length: int = 4096,
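The signature change above drops the torchtune-specific annotations so this module imports cleanly when torchtune is absent. For reference, a minimal sketch of an alternative that keeps the hints without the runtime dependency, using `typing.TYPE_CHECKING` (illustrative only, not part of this diff):

import torch
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by type checkers, never at runtime, so torchtune
    # is not required just to import this module.
    from torchtune.modules.model_fusion import DeepFusionModel
    from torchtune.modules.transforms import Transform


class VLMEvalWrapper:
    def __init__(self, model: "DeepFusionModel", transform: "Transform") -> None:
        self._model = model
        self._transform = transform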
@@ -226,6 +213,25 @@ def __init__(
         image_tag: str = "<image>",
         max_images_per_sample: int = 7,
     ):
+        # Having the imports here allows running other evals without installing torchtune
+        from torchtune.utils import batch_to_device
+        from torchtune.data import (
+            format_content_with_images,
+            left_pad_sequence,
+            Message,
+            padded_collate_tiled_images_and_mask,
+        )
+        from torchtune.generation import generate, sample
+        from torchtune.modules.common_utils import local_kv_cache
+        self.batch_to_device = batch_to_device
+        self.format_content_with_images = format_content_with_images
+        self.left_pad_sequence = left_pad_sequence
+        self.Message = Message
+        self.padded_collate_tiled_images_and_mask = padded_collate_tiled_images_and_mask
+        self.generate = generate
+        self.sample = sample
+        self.local_kv_cache = local_kv_cache
+
         self._model = model
         self._transform = transform
         self._device = device
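The moved imports illustrate a general pattern: class-level imports create class attributes, not module globals, so method bodies cannot see them as bare names. Importing inside `__init__` and binding each name to `self` keeps the dependency optional while making it reachable from every method. A standalone sketch of the pattern (numpy stands in for torchtune, purely for illustration):

class LazyDeps:
    """Defer an optional heavy dependency until an instance is created."""

    def __init__(self):
        # Importing here means `import this_module` never touches numpy;
        # only constructing an instance does.
        import numpy as np

        # The import above is a local variable, invisible to other
        # methods, so it must be stashed on the instance explicitly.
        self.np = np

    def zeros(self, n):
        return self.np.zeros(n)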
@@ -326,24 +332,24 @@ def tok_batch_multimodal_encode(

             # Construct the messages
             messages = []
-            content = format_content_with_images(
+            content = self.format_content_with_images(
                 text, image_tag=self._image_tag, images=proper_images
             )
-            messages.append(Message(role="user", content=content))
-            messages.append(Message(role="assistant", content=""))
+            messages.append(self.Message(role="user", content=content))
+            messages.append(self.Message(role="assistant", content=""))

             # Transform the messages
             tok_batch = self.model_transform({"messages": messages}, inference=True)
             all_encoded_messages.append(tok_batch)

         # Pad the encoded messages
-        tok_batch = padded_collate_tiled_images_and_mask(
+        tok_batch = self.padded_collate_tiled_images_and_mask(
             all_encoded_messages,
             pad_direction="left",
             pad_max_images=self._max_images_per_sample,
             pad_max_tiles=self._transform.max_num_tiles,
         )
-        utils.batch_to_device(tok_batch, self.device)
+        self.batch_to_device(tok_batch, self.device)

         # Convert the batch to the format expected by the HF
         tok_batch["input_ids"] = tok_batch.pop("tokens")
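`pad_direction="left"` matters here: left padding aligns the final prompt token of every sequence at the same position, which is what a decoder expects when it generates from the last column of logits. A minimal pure-PyTorch sketch of left-padded collation (an illustration, not torchtune's implementation):

import torch
import torch.nn.functional as F

def left_pad_collate(seqs, pad_id=0):
    """Left-pad 1-D token tensors to a common length -> (batch, max_len)."""
    max_len = max(s.size(0) for s in seqs)
    # F.pad takes (left, right) amounts for the last dimension.
    return torch.stack(
        [F.pad(s, (max_len - s.size(0), 0), value=pad_id) for s in seqs]
    )

tokens = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
print(left_pad_collate(tokens))  # tensor([[1, 2, 3], [0, 4, 5]])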
@@ -398,7 +404,7 @@ def _model_multimodal_generate(

         with measure_time(message=None) as measure:
             # 2. Setup KV cache
-            with local_kv_cache(
+            with self.local_kv_cache(
                 self.model,
                 batch_size=self.batch_size,
                 device=self.device,
@@ -409,7 +415,7 @@ def _model_multimodal_generate(
                 # 3. Prefill step
                 generated_tokens = []
                 logits = self.model(prompt, **batch)[:, -1]
-                token = sample(logits, temperature=0.0, top_k=None)
+                token = self.sample(logits, temperature=0.0, top_k=None)
                 generated_tokens.append(token.item())

                 cache_mask = batch["encoder_mask"][:, -1:]
@@ -425,7 +431,7 @@ def _model_multimodal_generate(
                         encoder_mask=cache_mask,
                         input_pos=input_pos[None, seq_len],
                     )[:, -1]
-                    token = sample(logits, temperature=0.0, top_k=None)
+                    token = self.sample(logits, temperature=0.0, top_k=None)
                     generated_tokens.append(token.item())
                     seq_len += 1
         self.times.append(measure.get_time())
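With `temperature=0.0` and `top_k=None`, sampling degenerates to greedy argmax, so the prefill-then-decode loop above emits one deterministic token per forward pass. A toy sketch of that control flow (the random-logits `toy_model` is a stand-in, not the real model):

import torch

def greedy_token(logits: torch.Tensor) -> torch.Tensor:
    # temperature=0.0 / top_k=None sampling reduces to argmax.
    return logits.argmax(dim=-1, keepdim=True)

def toy_model(tokens: torch.Tensor) -> torch.Tensor:
    # Stand-in model: random logits over a 16-token vocabulary.
    return torch.randn(tokens.size(0), tokens.size(1), 16)

prompt = torch.tensor([[1, 2, 3]])
logits = toy_model(prompt)[:, -1]   # prefill: logits for the next token
token = greedy_token(logits)
generated = [token.item()]
for _ in range(4):                  # decode loop, one token per step
    logits = toy_model(token)[:, -1]
    token = greedy_token(logits)
    generated.append(token.item())
print(generated)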
@@ -460,6 +466,7 @@ def eval(
     Returns:
         eval_results (dict): A dictionary of evaluation results for the specified task(s).
     """
+
     if tasks is None:
         if modality == "text":
             tasks = ["wikitext"]
@@ -478,11 +485,14 @@ def eval(
         # use eot_token_id as prefix_token_id.
         model_eval_wrapper.custom_prefix_token_id = model_eval_wrapper.eot_token_id
     elif modality == "text-image":
+        from torchtune.utils import get_device
+        from torchtune.models.llama3_2_vision import llama3_2_vision_transform
+
         model_eval_wrapper = VLMEvalWrapper(
             model,
-            transform=tokenizer,
+            transform=llama3_2_vision_transform(path=str(tokenizer.tokenizer_path)),
             max_seq_length=4096 if max_seq_length is None else max_seq_length,
-            device=utils.get_device(device) if isinstance(device, str) else device,
+            device=get_device(device) if isinstance(device, str) else device,
         )

     try:
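Rebuilding the transform from `tokenizer.tokenizer_path` (attached in `main`, see the last hunk below) hands the wrapper the full multimodal transform rather than the text-only tokenizer it received before. A hedged usage sketch; the checkpoint path is hypothetical, and `model` is assumed to come from the surrounding script:

import torch
from torchtune.models.llama3_2_vision import llama3_2_vision_transform

# Illustrative path: point it at the tokenizer.model shipped with the checkpoint.
transform = llama3_2_vision_transform(
    path="/checkpoints/Llama-3.2-11B-Vision/tokenizer.model"
)

# `model` is the vision-fusion model produced by _initialize_model in main().
wrapper = VLMEvalWrapper(model, transform=transform, device=torch.device("cuda"))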
@@ -531,6 +541,7 @@ def main(args) -> None:
     set_precision(builder_args.precision)

     tokenizer = _initialize_tokenizer(tokenizer_args)
+    tokenizer.tokenizer_path = tokenizer_args.tokenizer_path
     builder_args.setup_caches = False
     model = _initialize_model(
         builder_args,