Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 7721be9

Browse files
author
anirudh
committed
fixes from code review
1 parent ae66baf commit 7721be9

File tree

2 files changed

+51
-101
lines changed

2 files changed

+51
-101
lines changed

torchchat/cli/cli.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -137,15 +137,6 @@ def _add_model_specification_args(parser) -> None:
137137
help=argparse.SUPPRESS,
138138
)
139139

140-
model_specification_parser.add_argument(
141-
"--modality",
142-
type=str,
143-
default="text",
144-
choices=["text", "text-image"],
145-
# help=argparse.SUPPRESS,
146-
help="Modality of the model. Options: text, text-image",
147-
)
148-
149140

150141
# Add CLI Args related to model configuration (compilation, quant, etc)
151142
# Excludes compile args if subcommand is export
@@ -441,6 +432,14 @@ def _add_evaluation_args(parser) -> None:
441432
help="Maximum length sequence to evaluate",
442433
)
443434

435+
eval_parser.add_argument(
436+
"--modality",
437+
type=str,
438+
default="text",
439+
choices=["text", "text-image"],
440+
help="Modality of the model. Options: text, text-image",
441+
)
442+
444443

445444
# Add CLI Args related to distributed inference
446445
# This feature is currently a [WIP] and hidden from --help

torchchat/usages/eval.py

Lines changed: 43 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# This source code is licensed under the license found in the
55
# LICENSE file in the root directory of this source tree.
66
import argparse
7-
from typing import Callable, Dict, List, Optional
7+
from typing import Callable, Dict, List, Optional, Literal
88

99
import torch
1010
import torch._dynamo.config
@@ -184,7 +184,12 @@ def _model_generate(self, context, max_length, eos_token_id):
184184

185185

186186
class VLMEvalWrapper(HFMultimodalLM):
187-
"""An EvalWrapper for EleutherAI's eval harness based on gpt-fast's
187+
"""
188+
This class is adapted from torchtune.
189+
Source: https://github.com/pytorch/torchtune/blob/main/recipes/eleuther_eval.py
190+
-------------------------------------------------------------------------------
191+
192+
An EvalWrapper for EleutherAI's eval harness based on gpt-fast's
188193
EvalWrapper: https://github.com/pytorch-labs/gpt-fast/blob/main/eval.py.
189194
190195
Note:
@@ -437,6 +442,7 @@ def eval(
437442
max_seq_length: Optional[int] = None,
438443
device: str = "cpu",
439444
is_pte_model: bool = False,
445+
modality: Literal["text", "text-image"] = "text",
440446
) -> dict:
441447
"""
442448
Evaluates a language model on a specified task using the lm-evaluation-harness library.
@@ -447,21 +453,33 @@ def eval(
447453
tasks (Optional[list]): The names of the evaluation tasks to perform.
448454
limit (Optional[int]): The maximum number of samples to evaluate (None for all available).
449455
max_seq_length (Optional[int]): The maximum sequence length allowed for input text.
456+
modality (str): The modality of the model. Options: text, text-image
450457
451458
Returns:
452459
eval_results (dict): A dictionary of evaluation results for the specified task(s).
453460
"""
454461
if tasks is None:
455-
tasks = ["wikitext"]
456-
457-
model_eval_wrapper = GPTFastEvalWrapper(
458-
model,
459-
tokenizer,
460-
model_forward=model_forward,
461-
max_seq_length=max_seq_length,
462-
device=device,
463-
is_pte_model=is_pte_model,
464-
)
462+
if modality == "text":
463+
tasks = ["wikitext"]
464+
elif modality == "text-image":
465+
tasks = ["mmmu-val-art"]
466+
467+
if modality == "text":
468+
model_eval_wrapper = GPTFastEvalWrapper(
469+
model,
470+
tokenizer,
471+
model_forward=model_forward,
472+
max_seq_length=max_seq_length,
473+
device=device,
474+
is_pte_model=is_pte_model,
475+
)
476+
elif modality == "text-image":
477+
model_eval_wrapper = VLMEvalWrapper(
478+
model,
479+
transform=tokenizer,
480+
max_seq_length = 4096 if max_seq_length is None else max_seq_length,
481+
device = utils.get_device(device) if isinstance(device, str) else device,
482+
)
465483

466484
try:
467485
lm_eval.tasks.initialize_tasks()
@@ -482,57 +500,6 @@ def eval(
482500
return eval_results
483501

484502

485-
def multi_model_eval(
486-
model: Model,
487-
model_forward: Callable,
488-
tokenizer,
489-
tasks: Optional[list] = None,
490-
limit: Optional[int] = None,
491-
max_seq_length: Optional[int] = None,
492-
device: str = "cpu",
493-
is_pte_model: bool = False,
494-
):
495-
"""
496-
Evaluates a language model on a specified task using the lm-evaluation-harness library.
497-
498-
Args:
499-
model (Model): The pre-trained language model to evaluate.
500-
tokenizer: The tokenizer to use for encoding/decoding text.
501-
tasks (Optional[list]): The names of the evaluation tasks to perform.
502-
limit (Optional[int]): The maximum number of samples to evaluate (None for all available).
503-
max_seq_length (Optional[int]): The maximum sequence length allowed for input text.
504-
505-
Returns:
506-
eval_results (dict): A dictionary of evaluation results for the specified task(s).
507-
"""
508-
if tasks is None:
509-
tasks = ["wikitext"]
510-
max_seq_length = 4096 if max_seq_length is None else max_seq_length
511-
device = utils.get_device(device) if isinstance(device, str) else device
512-
513-
model_eval_wrapper = VLMEvalWrapper(
514-
model,
515-
transform=tokenizer, # transform is the tokenizer for multimodal models
516-
max_seq_length=max_seq_length,
517-
device=device,
518-
)
519-
520-
try:
521-
lm_eval.tasks.initialize_tasks()
522-
except:
523-
pass
524-
525-
task_dict = get_task_dict(tasks)
526-
527-
eval_results = evaluate(
528-
model_eval_wrapper,
529-
task_dict,
530-
limit=limit,
531-
)
532-
eval_results["times"] = model_eval_wrapper.times
533-
return eval_results
534-
535-
536503
def main(args) -> None:
537504
"""Evaluates model on a task from the `lm-evaluation-harness` library.
538505
@@ -553,13 +520,8 @@ def main(args) -> None:
553520
limit = args.limit
554521
compile = args.compile
555522
max_seq_length = args.max_seq_length
523+
modality = args.modality
556524

557-
modality = builder_args.modality
558-
print(f"Modality of model={modality}")
559-
assert modality in [
560-
"text",
561-
"text-image",
562-
], "Only text and text-image modality is supported for evaluation"
563525

564526
print(f"Using device={device}")
565527
set_precision(builder_args.precision)
@@ -588,30 +550,19 @@ def main(args) -> None:
588550
False if device == "cpu" else True
589551
)
590552

553+
591554
with measure_time("Time to run eval: {time:.02f}s."):
592-
if modality == "text":
593-
result = eval(
594-
model.to(device),
595-
model_forward,
596-
tokenizer,
597-
tasks,
598-
limit,
599-
max_seq_length,
600-
device=builder_args.device,
601-
is_pte_model=builder_args.pte_path is not None,
602-
)
603-
elif modality == "text-image":
604-
result = multi_model_eval(
605-
model.to(device),
606-
model_forward,
607-
tokenizer,
608-
tasks,
609-
limit,
610-
max_seq_length,
611-
device=builder_args.device,
612-
)
613-
else:
614-
raise ValueError(f"Unsupported modality: {modality}")
555+
result = eval(
556+
model.to(device),
557+
model_forward,
558+
tokenizer,
559+
tasks,
560+
limit,
561+
max_seq_length,
562+
device=builder_args.device,
563+
is_pte_model=builder_args.pte_path is not None,
564+
modality=modality,
565+
)
615566

616567
times = torch.tensor(result["times"])
617568
print(

0 commit comments

Comments
 (0)