Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/test-check-transformers.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,7 @@ jobs:
if: (success() || failure()) && steps.install.outcome == 'success'
run: |
pytest -v tests/llmcompressor/transformers/obcq
- name: Running Tracing Tests
if: (success() || failure()) && steps.install.outcome == 'success'
run: |
pytest -v tests/llmcompressor/transformers/tracing
14 changes: 7 additions & 7 deletions src/llmcompressor/transformers/finetune/data/peoples_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,20 @@ class PeoplesSpeech(TextGenerationDataset):
:param processor: processor or tokenizer to use on dataset
"""

def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
data_args = deepcopy(data_args)
data_args.dataset = "MLCommons/peoples_speech"
data_args.dataset_config_name = "test"
if not data_args.overwrite_cache:
def __init__(self, dataset_args: "DataArgs", split: str, processor: Processor):
dataset_args = deepcopy(dataset_args)
dataset_args.dataset = "MLCommons/peoples_speech"
dataset_args.dataset_config_name = "test"
if not dataset_args.overwrite_cache:
logger.warning(
"Because audio processors are more complex, dataset mapping functions "
"vary with model architecture and their results cannot be cached. "
"Setting overwrite_cache=True"
)
data_args.overwrite_cache = True
dataset_args.overwrite_cache = True
self.processor_type = processor.__class__.__name__

super().__init__(data_args=data_args, split=split, processor=processor)
super().__init__(dataset_args=dataset_args, split=split, processor=processor)

def dataset_template(self, example):
audio = example["audio"]["array"]
Expand Down
20 changes: 15 additions & 5 deletions src/llmcompressor/transformers/tracing/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from llmcompressor.transformers import TextGenerationDataset
from llmcompressor.args import DatasetArguments

from llmcompressor.utils.dev import skip_weights_download

__all__ = [
"get_model_class"
]
Expand All @@ -24,6 +26,7 @@ def parse_args():
parser.add_argument("--sequential_targets", type=str, nargs="*", default=None, metavar="TARGET", help="List of targets for sequential tracing") # noqa: E501
parser.add_argument("--ignore", type=str, nargs="*", default=[], metavar="PATTERN", help="List of patterns to ignore during tracing") # noqa: E501
parser.add_argument("--modality", type=str, default="text", help="Modality of calibration dataset, defaults to text") # noqa: E501
parser.add_argument("--trust_remote_code", type=bool, default=False, help="Whether to trust model remote code") # noqa: E501
return parser.parse_args()


Expand All @@ -33,6 +36,7 @@ def trace(
sequential_targets: Optional[Union[List[str], str]] = None,
ignore: Union[List[str], str] = [],
modality: str = "text",
trust_remote_code: bool = True
):
"""
Debug traceability by tracing a pre-trained model into subgraphs
Expand All @@ -44,6 +48,7 @@ def trace(
inference
:param ignore: patterns to ignore during tracing
:param modality: data modality for dummy tracing data, defaults to 'text'
:param trust_remote_code: trust remote model code

Example usage from CLI
llmcompressor.trace \
Expand All @@ -54,12 +59,16 @@ def trace(
--modality text
"""
# Load model
model = model_class.from_pretrained(
model_id,
device_map="auto",
torch_dtype="auto",
with skip_weights_download(model_class):
model = model_class.from_pretrained(
model_id,
device_map="cpu",
torch_dtype="auto",
trust_remote_code=trust_remote_code,
)
processor = AutoProcessor.from_pretrained(
model_id, trust_remote_code=trust_remote_code
)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
print("Loaded model")

# Prepare sample data
Expand Down Expand Up @@ -138,6 +147,7 @@ def main():
sequential_targets=args.sequential_targets,
ignore=args.ignore,
modality=args.modality,
trust_remote_code=args.trust_remote_code
)


Expand Down
4 changes: 3 additions & 1 deletion src/llmcompressor/transformers/tracing/idefics3.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def __init__(self, config: Idefics3Config):

def forward(
self,
input_ids: torch.LongTensor = None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
Expand All @@ -296,6 +296,7 @@ def forward(
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, Idefics3BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
Expand Down Expand Up @@ -394,6 +395,7 @@ def forward(
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
return_dict=return_dict,
)

Expand Down
101 changes: 101 additions & 0 deletions src/llmcompressor/utils/dev.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import contextlib
import logging
import os
import tempfile
from typing import Type

import torch
from huggingface_hub import snapshot_download
from safetensors.torch import save_file
from transformers import AutoModelForCausalLM, PreTrainedModel
from transformers.modeling_utils import TORCH_INIT_FUNCTIONS
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME

from llmcompressor.utils.helpers import patch_attr

__all__ = ["skip_weights_download", "patch_transformers_logger_level"]


@contextlib.contextmanager
def skip_weights_download(model_class: Type[PreTrainedModel] = AutoModelForCausalLM):
    """
    Context manager under which models are initialized without having to download
    the model weight files. This differs from `init_empty_weights` in that weights are
    allocated on to assigned devices with random values, as opposed to being on the meta
    device

    :param model_class: class to patch, defaults to `AutoModelForCausalLM`
    """
    original_fn = model_class.from_pretrained
    # checkpoint file patterns excluded from the snapshot download
    weights_files = [
        "*.bin",
        "*.safetensors",
        "*.pth",
        SAFE_WEIGHTS_INDEX_NAME,
        WEIGHTS_INDEX_NAME,
        "*.msgpack",
    ]

    @classmethod
    def patched(cls, *args, **kwargs):
        # intercept model stub (positional or keyword form)
        model_stub = args[0] if args else kwargs.pop("pretrained_model_name_or_path")

        # download non-weight files (config, tokenizer, ...) into tmp dir
        os.makedirs(tmp_dir, exist_ok=True)
        snapshot_download(
            repo_id=model_stub, local_dir=tmp_dir, ignore_patterns=weights_files
        )

        # make an empty weights file to avoid errors from a missing checkpoint
        weights_file_path = os.path.join(tmp_dir, "model.safetensors")
        save_file({}, weights_file_path, metadata={"format": "pt"})

        # load from tmp dir, forwarding any additional positional model args
        # (previously these were silently dropped)
        model = original_fn(tmp_dir, *args[1:], **kwargs)

        # replace model_path so the model reports its true origin, not tmp_dir
        model.name_or_path = model_stub
        model.config._name_or_path = model_stub

        return model

    with tempfile.TemporaryDirectory() as tmp_dir, patch_attr(
        model_class, "from_pretrained", patched
    ), skip_weights_initialize(), patch_transformers_logger_level():
        yield


@contextlib.contextmanager
def skip_weights_initialize(use_zeros: bool = False):
    """
    Context under which all torch weight-initialization functions are no-ops

    :param use_zeros: if True, tensors are zero-filled instead of being left with
        whatever values they already hold
    """

    def _no_init(tensor: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # either zero the tensor in place or leave its memory untouched
        return tensor.fill_(0) if use_zeros else tensor

    with contextlib.ExitStack() as stack:
        for init_name in TORCH_INIT_FUNCTIONS:
            # patch both the functional form and the in-place Tensor method form
            stack.enter_context(patch_attr(torch.nn.init, init_name, _no_init))
            stack.enter_context(patch_attr(torch.Tensor, init_name, _no_init))
        yield


@contextlib.contextmanager
def patch_transformers_logger_level(level: int = logging.ERROR):
    """
    Context under which the transformers logger's level is modified

    This can be used with `skip_weights_download` to squelch warnings related to
    missing parameters in the checkpoint

    :param level: new logging level for transformers logger. Logs whose level is below
        this level will not be logged
    """
    transformers_logger = logging.getLogger("transformers.modeling_utils")
    # capture the logger's own level (not the effective/inherited one) so a NOTSET
    # logger is restored to NOTSET rather than pinned to its parent's level
    restore_log_level = transformers_logger.level

    transformers_logger.setLevel(level=level)
    try:
        yield
    finally:
        # restore even if the managed block raises
        transformers_logger.setLevel(level=restore_log_level)
98 changes: 98 additions & 0 deletions tests/llmcompressor/transformers/tracing/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import pytest
from transformers import AutoModelForCausalLM

from llmcompressor.transformers.tracing import (
TraceableIdefics3ForConditionalGeneration,
TraceableLlavaForConditionalGeneration,
TraceableMllamaForConditionalGeneration,
TraceableQwen2_5_VLForConditionalGeneration,
TraceableQwen2VLForConditionalGeneration,
TraceableWhisperForConditionalGeneration,
)
from llmcompressor.transformers.tracing.debug import trace


@pytest.mark.parametrize(
    "model_id,model_class,targets",
    [
        ("meta-llama/Meta-Llama-3-8B-Instruct", AutoModelForCausalLM, None),
    ],
)
def test_text_trace(model_id, model_class, targets):
    """Smoke test: a text-only model must trace end-to-end without raising"""
    trace(
        model_id=model_id,
        model_class=model_class,
        sequential_targets=targets,
        ignore=[],
        modality="text",
        trust_remote_code=True,
    )


@pytest.mark.parametrize(
    "model_id,model_class,targets,ignore",
    [
        (
            "Qwen/Qwen2-VL-2B-Instruct",
            TraceableQwen2VLForConditionalGeneration,
            None,
            ["lm_head", "re:visual.*"],
        ),
        (
            "Qwen/Qwen2.5-VL-7B-Instruct",
            TraceableQwen2_5_VLForConditionalGeneration,
            None,
            ["lm_head", "re:visual.*"],
        ),
        (
            "mgoin/pixtral-12b",
            TraceableLlavaForConditionalGeneration,
            ["MistralDecoderLayer"],
            ["re:.*lm_head", "re:vision_tower.*", "re:multi_modal_projector.*"],
        ),
        (
            "meta-llama/Llama-3.2-11B-Vision-Instruct",
            TraceableMllamaForConditionalGeneration,
            None,
            ["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*"],
        ),
        (
            "llava-hf/llava-1.5-7b-hf",
            TraceableLlavaForConditionalGeneration,
            ["LlamaDecoderLayer"],
            ["re:.*lm_head", "re:vision_tower.*", "re:multi_modal_projector.*"],
        ),
        (
            "HuggingFaceM4/Idefics3-8B-Llama3",
            TraceableIdefics3ForConditionalGeneration,
            ["Idefics3EncoderLayer", "LlamaDecoderLayer"],
            ["re:.*lm_head", "re:model.vision_model.*", "re:model.connector.*"],
        ),
    ],
)
def test_vision_trace(model_id, model_class, targets, ignore):
    """Smoke test: each vision-language model must trace without raising"""
    trace(
        model_id=model_id,
        model_class=model_class,
        sequential_targets=targets,
        ignore=ignore,
        modality="vision",
        trust_remote_code=True,
    )


@pytest.mark.parametrize(
    "model_id,model_class,targets,ignore",
    [
        ("openai/whisper-large-v3", TraceableWhisperForConditionalGeneration, None, []),
    ],
)
def test_audio_trace(model_id, model_class, targets, ignore):
    """Smoke test: an audio model must trace end-to-end without raising"""
    trace(
        model_id=model_id,
        model_class=model_class,
        sequential_targets=targets,
        ignore=ignore,
        modality="audio",
        trust_remote_code=True,
    )
Loading