44# This source code is licensed under the license found in the 
55# LICENSE file in the root directory of this source tree. 
66import  argparse 
7- from  typing  import  Callable , Dict , List , Optional 
7+ from  typing  import  Callable , Dict , List , Optional ,  Literal 
88
99import  torch 
1010import  torch ._dynamo .config 
@@ -184,7 +184,12 @@ def _model_generate(self, context, max_length, eos_token_id):
184184
185185
186186class  VLMEvalWrapper (HFMultimodalLM ):
187-     """An EvalWrapper for EleutherAI's eval harness based on gpt-fast's 
187+     """ 
188+     This class is adapted from torchtune. 
189+     Source: https://github.com/pytorch/torchtune/blob/main/recipes/eleuther_eval.py 
190+     ------------------------------------------------------------------------------- 
191+ 
192+     An EvalWrapper for EleutherAI's eval harness based on gpt-fast's 
188193    EvalWrapper: https://github.com/pytorch-labs/gpt-fast/blob/main/eval.py. 
189194
190195    Note: 
@@ -437,6 +442,7 @@ def eval(
437442    max_seq_length : Optional [int ] =  None ,
438443    device : str  =  "cpu" ,
439444    is_pte_model : bool  =  False ,
445+     modality : Literal ["text" , "text-image" ] =  "text" ,
440446) ->  dict :
441447    """ 
442448    Evaluates a language model on a specified task using the lm-evaluation-harness library. 
@@ -447,21 +453,33 @@ def eval(
447453        tasks (Optional[list]): The names of the evaluation tasks to perform. 
448454        limit (Optional[int]): The maximum number of samples to evaluate (None for all available). 
449455        max_seq_length (Optional[int]): The maximum sequence length allowed for input text. 
456+         modality (Literal["text", "text-image"]): The modality of the model, either "text" (the default) or "text-image". 
450457
451458    Returns: 
452459        eval_results (dict): A dictionary of evaluation results for the specified task(s). 
453460    """ 
454461    if  tasks  is  None :
455-         tasks  =  ["wikitext" ]
456- 
457-     model_eval_wrapper  =  GPTFastEvalWrapper (
458-         model ,
459-         tokenizer ,
460-         model_forward = model_forward ,
461-         max_seq_length = max_seq_length ,
462-         device = device ,
463-         is_pte_model = is_pte_model ,
464-     )
462+         if  modality  ==  "text" :
463+             tasks  =  ["wikitext" ]
464+         elif  modality  ==  "text-image" :
465+             tasks  =  ["mmmu-val-art" ]
466+ 
467+     if  modality  ==  "text" :
468+         model_eval_wrapper  =  GPTFastEvalWrapper (
469+             model ,
470+             tokenizer ,
471+             model_forward = model_forward ,
472+             max_seq_length = max_seq_length ,
473+             device = device ,
474+             is_pte_model = is_pte_model ,
475+         )
476+     elif  modality  ==  "text-image" :
477+         model_eval_wrapper  =  VLMEvalWrapper (
478+             model ,
479+             transform = tokenizer , 
480+             max_seq_length  =  4096  if  max_seq_length  is  None  else  max_seq_length ,
481+             device  =  utils .get_device (device ) if  isinstance (device , str ) else  device ,
482+         )
465483
466484    try :
467485        lm_eval .tasks .initialize_tasks ()
@@ -482,57 +500,6 @@ def eval(
482500    return  eval_results 
483501
484502
485- def  multi_model_eval (
486-     model : Model ,
487-     model_forward : Callable ,
488-     tokenizer ,
489-     tasks : Optional [list ] =  None ,
490-     limit : Optional [int ] =  None ,
491-     max_seq_length : Optional [int ] =  None ,
492-     device : str  =  "cpu" ,
493-     is_pte_model : bool  =  False ,
494- ):
495-     """ 
496-     Evaluates a language model on a specified task using the lm-evaluation-harness library. 
497- 
498-     Args: 
499-         model (Model): The pre-trained language model to evaluate. 
500-         tokenizer: The tokenizer to use for encoding/decoding text. 
501-         tasks (Optional[list]): The names of the evaluation tasks to perform. 
502-         limit (Optional[int]): The maximum number of samples to evaluate (None for all available). 
503-         max_seq_length (Optional[int]): The maximum sequence length allowed for input text. 
504- 
505-     Returns: 
506-         eval_results (dict): A dictionary of evaluation results for the specified task(s). 
507-     """ 
508-     if  tasks  is  None :
509-         tasks  =  ["wikitext" ]
510-     max_seq_length  =  4096  if  max_seq_length  is  None  else  max_seq_length 
511-     device  =  utils .get_device (device ) if  isinstance (device , str ) else  device 
512- 
513-     model_eval_wrapper  =  VLMEvalWrapper (
514-         model ,
515-         transform = tokenizer ,  # transform is the tokenizer for multimodal models 
516-         max_seq_length = max_seq_length ,
517-         device = device ,
518-     )
519- 
520-     try :
521-         lm_eval .tasks .initialize_tasks ()
522-     except :
523-         pass 
524- 
525-     task_dict  =  get_task_dict (tasks )
526- 
527-     eval_results  =  evaluate (
528-         model_eval_wrapper ,
529-         task_dict ,
530-         limit = limit ,
531-     )
532-     eval_results ["times" ] =  model_eval_wrapper .times 
533-     return  eval_results 
534- 
535- 
536503def  main (args ) ->  None :
537504    """Evaluates model on a task from the `lm-evaluation-harness` library. 
538505
@@ -553,13 +520,8 @@ def main(args) -> None:
553520    limit  =  args .limit 
554521    compile  =  args .compile 
555522    max_seq_length  =  args .max_seq_length 
523+     modality  =  args .modality 
556524
557-     modality  =  builder_args .modality 
558-     print (f"Modality of model={ modality }  )
559-     assert  modality  in  [
560-         "text" ,
561-         "text-image" ,
562-     ], "Only text and text-image modality is supported for evaluation" 
563525
564526    print (f"Using device={ device }  )
565527    set_precision (builder_args .precision )
@@ -588,30 +550,19 @@ def main(args) -> None:
588550            False  if  device  ==  "cpu"  else  True 
589551        )
590552
553+ 
591554    with  measure_time ("Time to run eval: {time:.02f}s." ):
592-         if  modality  ==  "text" :
593-             result  =  eval (
594-                 model .to (device ),
595-                 model_forward ,
596-                 tokenizer ,
597-                 tasks ,
598-                 limit ,
599-                 max_seq_length ,
600-                 device = builder_args .device ,
601-                 is_pte_model = builder_args .pte_path  is  not None ,
602-             )
603-         elif  modality  ==  "text-image" :
604-             result  =  multi_model_eval (
605-                 model .to (device ),
606-                 model_forward ,
607-                 tokenizer ,
608-                 tasks ,
609-                 limit ,
610-                 max_seq_length ,
611-                 device = builder_args .device ,
612-             )
613-         else :
614-             raise  ValueError (f"Unsupported modality: { modality }  )
555+         result  =  eval (
556+             model .to (device ),
557+             model_forward ,
558+             tokenizer ,
559+             tasks ,
560+             limit ,
561+             max_seq_length ,
562+             device = builder_args .device ,
563+             is_pte_model = builder_args .pte_path  is  not None ,
564+             modality = modality ,
565+         )
615566
616567    times  =  torch .tensor (result ["times" ])
617568    print (
0 commit comments