This repository was archived by the owner on Sep 10, 2025. It is now read-only.
Added support for Multimodal eval #1499
Merged
Changes from 7 commits (of 16 total)

Commits
2aa67b4  [wip] Added cli args and other changes to eval multi-modal models
78bdacf  remove redundant comment
bfc62dc  Added Llama3VisionTransform in TokenizerArgs and other changes
8900f8a  use kv caching and other minor fixes
59ce657  default batch size 1
afdb3ce  lint eval.py and builder.py
ae66baf  lm-eval 0.4.2->0.4.7 in install_requirements.sh
7721be9  fixes from code review
e9c0d34  Merge branch 'main' into multimodal-eval-2 (Jack-Khuu)
96ab799  remove modality from builder args
1e609d8  Merge branch 'main' into multimodal-eval-2 (Jack-Khuu)
51b0e83  use custom prefix token
842be23  Merge branch 'main' into multimodal-eval-2 (anirudhs001)
51135fd  move torchtune imports inside VLMEvalWrapper
14502bf  revert changes from builder.py
815966c  instantiate transform in eval()
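Several of these commits (VLMEvalWrapper, the custom prefix token, instantiating the transform in eval()) touch eval.py, whose diff did not render in this view. For orientation only, below is a minimal sketch of what an lm-eval wrapper for a vision-language model looks like; apart from lm_eval.api.model.LM, every name, argument, and default here is an assumption, not the PR's actual code.

```python
# Illustrative skeleton of an lm-eval model wrapper for multimodal evaluation.
# The real VLMEvalWrapper lives in torchchat's eval.py and is not shown here.
from typing import List, Tuple

from lm_eval.api.model import LM


class VLMEvalWrapper(LM):  # hypothetical shape, not the merged implementation
    def __init__(self, model, transform, device="cuda", max_seq_length=4096):
        super().__init__()
        self.model = model          # torchchat text-image model
        self.transform = transform  # e.g. llama3_2_vision_transform(path=...)
        self.device = device
        self.max_seq_length = max_seq_length

    # lm-eval drives evaluation through these three entry points.
    def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
        raise NotImplementedError("return (logprob, is_greedy) per request")

    def loglikelihood_rolling(self, requests) -> List[float]:
        raise NotImplementedError("return full-sequence logprobs")

    def generate_until(self, requests) -> List[str]:
        raise NotImplementedError("free-form generation for generative tasks")


# A wrapper like this would then be passed to lm_eval.simple_evaluate(...)
# together with a multimodal task name supported by the installed lm-eval.
```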
install_requirements.sh

@@ -34,4 +34,4 @@ streamlit
 flask
 
 # eval
-lm_eval==0.4.2
+lm_eval==0.4.7
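The bump to lm_eval==0.4.7 is presumably what the new multimodal wrapper needs, since 0.4.2 predates lm-eval's multimodal support. A quick, illustrative check that the installed version matches the new pin (the exact-version assert is only for demonstration):

```python
# Sanity-check the installed lm-eval against the pin in install_requirements.sh.
from importlib.metadata import version

installed = version("lm_eval")
print(f"lm_eval {installed}")
assert installed == "0.4.7", f"expected lm_eval==0.4.7, found {installed}"
```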
torchchat/cli/builder.py

@@ -16,13 +16,22 @@
 import torch._inductor.config
 import torch.distributed as dist
 
-from torchchat.distributed.utils import(
+from torchtune.models.convert_weights import meta_to_tune
+
+from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
+
+from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
+
+from torchtune.training import set_default_dtype
+
+from torchchat.distributed.logging_utils import SingletonLogger
+from torchchat.distributed.utils import (
     Color as color,
     CUDATrackTime,
-    init_distributed,
     GPUMemoryMonitor,
+    init_distributed,
 )
-from torchchat.distributed.logging_utils import SingletonLogger
 
 from torchchat.model import Model, ModelArgs, ModelType, Transformer, TransformerArgs
 from torchchat.model_config.model_config import resolve_model_config
@@ -36,15 +45,6 @@
 from torchchat.utils.quantize import quantize_model
 
-
-from torchtune.models.convert_weights import meta_to_tune
-
-from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
-
-from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
-
-from torchtune.training import set_default_dtype
-
 
 @dataclass
 class BuilderArgs:
     checkpoint_path: Optional[Union[Path, str]] = None
@@ -71,6 +71,7 @@ class BuilderArgs:
     dynamic_shapes: bool = False
     max_seq_length: Optional[int] = None
     attention_backend: str = "math"
+    modality: Optional[str] = "text"
 
     def __post_init__(self):
         if self.device is None:
@@ -146,6 +147,10 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         aoti_package_path = getattr(args, "aoti_package_path", None)
         snapshot_path = getattr(args, "snapshot_path", None)
 
+        modality = "text"
+        if args.modality:
+            modality = args.modality
+
         is_chat_model = False
         if args.is_chat_model:
             is_chat_model = True
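from_args reads args.modality directly, which assumes every subcommand registers the flag; a more defensive variant, mirroring the getattr(...) pattern used for the neighbouring options (illustrative only, not part of the diff):

```python
# Falls back to "text" when the parser for the current subcommand never
# registered --modality.
modality = getattr(args, "modality", None) or "text"
```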
@@ -189,15 +194,19 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         tp = getattr(args, "tp", 1)
         chpt_from = getattr(args, "chpt_from", "hf")
         sdp_backend_dict = {
-            'math': torch.nn.attention.SDPBackend.MATH,
-            'flash_attention': torch.nn.attention.SDPBackend.FLASH_ATTENTION,
-            'efficient_attention': torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
-            'cudnn_attention': torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
+            "math": torch.nn.attention.SDPBackend.MATH,
+            "flash_attention": torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+            "efficient_attention": torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+            "cudnn_attention": torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
         }
         attention_backend = sdp_backend_dict[args.attention_backend]
-        if args.device == "cpu" and (args.attention_backend == "efficient_attention"
-                                     or args.attention_backend == "cudnn_attention"):
-            print(f"Warning: {args.attention_backend} is not supported on CPU. Using math instead.")
+        if args.device == "cpu" and (
+            args.attention_backend == "efficient_attention"
+            or args.attention_backend == "cudnn_attention"
+        ):
+            print(
+                f"Warning: {args.attention_backend} is not supported on CPU. Using math instead."
+            )
             attention_backend = torch.nn.attention.SDPBackend.MATH
         return cls(
             checkpoint_dir=checkpoint_dir,
@@ -222,6 +231,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
             chpt_from=chpt_from,
             distribution_path=distribution_path,
             is_chat_model=is_chat_model,
+            modality=modality,
             dynamic_shapes=getattr(args, "dynamic_shapes", False),
             max_seq_length=getattr(args, "max_seq_length", None),
             attention_backend=attention_backend,
@@ -246,13 +256,29 @@ class TokenizerArgs:
     is_sentencepiece: bool = False
     is_tiktoken: bool = False
     is_hf_tokenizer: bool = False
+    is_llama_3_2_mm: bool = False
     t: Optional[Any] = None
 
     def __post_init__(self):
+        # special handling for llama-3.2-mm
+        if "llama-3.2-11b-vision" in str(self.tokenizer_path).lower():
+            try:
+                from torchtune.models.llama3_2_vision import llama3_2_vision_transform
+
+                self.t = llama3_2_vision_transform(path=str(self.tokenizer_path))
+                self.is_llama_3_2_mm = True
+                self.is_tiktoken = False
+                self.is_sentencepiece = False
+                self.is_hf_tokenizer = False
+                return
+            except:
+                pass
+
         try:
             from tokenizer.tiktoken import Tokenizer as TiktokenTokenizer
 
             self.t = TiktokenTokenizer(model_path=str(self.tokenizer_path))
+            self.is_llama_3_2_mm = False
             self.is_tiktoken = True
             self.is_sentencepiece = False
             self.is_hf_tokenizer = False
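With this branch, TokenizerArgs.t holds torchtune's Llama3VisionTransform instead of a plain tokenizer whenever the tokenizer path points at Llama 3.2 11B Vision. A hedged usage sketch follows; the path is a placeholder, and encode/decode are assumed to behave like torchtune's other tokenizers (the transform wraps a tiktoken tokenizer internally):

```python
# Not from the PR: a minimal demonstration of the transform the new branch builds.
from torchtune.models.llama3_2_vision import llama3_2_vision_transform

transform = llama3_2_vision_transform(path="/path/to/llama-3.2-11b-vision/tokenizer.model")

ids = transform.encode("Describe the image.", add_bos=True, add_eos=False)
print(transform.decode(ids))
```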
@@ -264,6 +290,7 @@ def __post_init__(self):
             from sentencepiece import SentencePieceProcessor
 
             self.t = SentencePieceProcessor(model_file=str(self.tokenizer_path))
+            self.is_llama_3_2_mm = False
             self.is_tiktoken = False
             self.is_sentencepiece = True
             self.is_hf_tokenizer = False
@@ -275,13 +302,15 @@ def __post_init__(self):
             from tokenizer.hf_tokenizer import HFTokenizer
 
             self.t = HFTokenizer(str(self.tokenizer_path))
+            self.is_llama_3_2_mm = False
             self.is_tiktoken = False
             self.is_sentencepiece = False
             self.is_hf_tokenizer = True
             return
         except:
             pass
 
+        self.is_llama_3_2_mm = False
         self.is_tiktoken = False
         self.is_sentencepiece = False
         self.is_hf_tokenizer = False
@@ -296,20 +325,32 @@ def validate_model(
         if model is None:
             return
 
-        if sum([self.is_tiktoken, self.is_hf_tokenizer, self.is_sentencepiece]) != 1:
+        if (
+            sum(
+                [
+                    self.is_tiktoken,
+                    self.is_hf_tokenizer,
+                    self.is_sentencepiece,
+                    self.is_llama_3_2_mm,
+                ]
+            )
+            != 1
+        ):
             raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}")
 
         is_tiktoken = self.is_tiktoken
         is_sentencepiece = self.is_sentencepiece
         is_hf_tokenizer = self.is_hf_tokenizer
+        is_llama_3_2_mm = self.is_llama_3_2_mm
 
         use_tiktoken = model.config.use_tiktoken
         use_hf_tokenizer = model.config.use_hf_tokenizer
-        use_sentencepiece = not (use_tiktoken or use_hf_tokenizer)
-
+        use_other_tokenizer = not (use_tiktoken or use_hf_tokenizer)
         if (
-            (is_tiktoken and not use_tiktoken) or
-            (is_hf_tokenizer and not use_hf_tokenizer) or
-            (is_sentencepiece and not use_sentencepiece)
+            (is_tiktoken and not use_tiktoken)
+            or (is_hf_tokenizer and not use_hf_tokenizer)
+            or (is_sentencepiece and not use_other_tokenizer)
+            or (is_llama_3_2_mm and not use_other_tokenizer)
         ):
             raise RuntimeError(
                 "model-specified tokenizer ({}) does not match provided tokenizer ({}) for {}".format(
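In words: the model config only distinguishes tiktoken and HF tokenizers, so both SentencePiece and the Llama 3.2 vision transform fall under use_other_tokenizer, and exactly one of the four detection flags may be set. A standalone illustration of that invariant (not the PR's code):

```python
# Exactly one detected flag, and it must be compatible with the model config.
flags = {
    "is_tiktoken": False,
    "is_hf_tokenizer": False,
    "is_sentencepiece": False,
    "is_llama_3_2_mm": True,
}
assert sum(flags.values()) == 1

use_tiktoken, use_hf_tokenizer = False, False  # what model.config would report
use_other_tokenizer = not (use_tiktoken or use_hf_tokenizer)
assert use_other_tokenizer  # the vision transform is accepted as an "other" tokenizer
```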
@@ -507,6 +548,7 @@ def _load_model(builder_args: BuilderArgs) -> Model:
         # AOTI-compoiled model will load its own weights.
         # Release weights here to avoid OOM
         import gc
+
         if hasattr(model, "model"):
             model.model = None
         gc.collect()
@@ -564,6 +606,7 @@ def _initialize_model(
 
             def do_nothing(max_batch_size, max_seq_length):
                 pass
+
             model.setup_caches = do_nothing
 
             model.forward = torch._export.aot_load(
@@ -601,6 +644,7 @@ def do_nothing(max_batch_size, max_seq_length):
 
             def do_nothing(max_batch_size, max_seq_length):
                 pass
+
             model.setup_caches = do_nothing
 
             model.forward = aoti_compiled_model
@@ -675,7 +719,9 @@ def do_nothing(max_batch_size, max_seq_length):
     logger = SingletonLogger.get_logger()
 
     gpu_memory_monitor = GPUMemoryMonitor("cuda")
-    logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}")
+    logger.info(
+        f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}"
+    )
 
     # Model-level config
     if builder_args.params_table:
@@ -686,20 +732,16 @@ def do_nothing(max_batch_size, max_seq_length):
         config = TransformerArgs.from_params(model_config.transformer_args["text"])
     logger.info(f"Transformer Config: {config}")
 
-    #TODO: Move into head of file after solving circular import
-    from torchchat.distributed.checkpoint_utils import (
-        load_model_weights,
-    )
+    # TODO: Move into head of file after solving circular import
+    from torchchat.distributed.checkpoint_utils import load_model_weights
 
     # Validate pipeline degree
     assert config.n_layers % pp_degree == 0
 
     # Create device mesh
     device_mesh = dist.init_device_mesh(
-        "cuda",
-        (pp_degree, tp_degree),
-        mesh_dim_names=("pp", "tp")
-    )
+        "cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp")
+    )
     tp_mesh = device_mesh["tp"]
     pp_mesh = device_mesh["pp"]
     logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}")
@@ -728,7 +770,13 @@ def do_nothing(max_batch_size, max_seq_length):
     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
     with CUDATrackTime() as timer:
-        load_model_weights(model, builder_args.distribution_path, device, config, builder_args.chpt_from)
+        load_model_weights(
+            model,
+            builder_args.distribution_path,
+            device,
+            config,
+            builder_args.chpt_from,
+        )
 
     logger.info(
         f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
@@ -742,7 +790,7 @@ def do_nothing(max_batch_size, max_seq_length):
     # lanes.
     # TODO: bump up the lane count
     pipeline_lanes = 1
-    seqlen_prefill=1024
+    seqlen_prefill = 1024
     with device:
         model.setup_caches(1, seqlen_prefill, cache_lanes=pipeline_lanes)
torchchat/cli/cli.py

@@ -137,6 +137,15 @@ def _add_model_specification_args(parser) -> None:
         help=argparse.SUPPRESS,
     )
 
+    model_specification_parser.add_argument(
+        "--modality",
+        type=str,
+        default="text",
+        choices=["text", "text-image"],
+        # help=argparse.SUPPRESS,
+    )
Review comment (on the `# help=argparse.SUPPRESS,` line above):

Since this arg is only used for evaluation, let's bump it into _add_evaluation_args() below
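A sketch of the suggested move; the argument-group name and help text are assumptions about cli.py's structure, not the merged change:

```python
import argparse


# Hypothetical: register --modality with the evaluation-only flags instead of
# the model-specification group.
def _add_evaluation_args(parser) -> None:
    eval_parser = parser.add_argument_group("Evaluation Arguments")
    eval_parser.add_argument(
        "--modality",
        type=str,
        default="text",
        choices=["text", "text-image"],
        help="Modality of the model to be evaluated",
    )


parser = argparse.ArgumentParser()
_add_evaluation_args(parser)
print(parser.parse_args(["--modality", "text-image"]).modality)  # -> text-image
```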
Review comment (on install_requirements.sh):

Beyond the scope of this PR, but the duplicated requirements in here vs requirements.txt will be collapsed when we introduce packaging