Skip to content

Commit bb1912c

Browse files
committed
add tracing tests
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 547e68f commit bb1912c

File tree

5 files changed

+47
-15
lines changed

5 files changed

+47
-15
lines changed

.github/workflows/test-check-transformers.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,7 @@ jobs:
8888
if: (success() || failure()) && steps.install.outcome == 'success'
8989
run: |
9090
pytest -v tests/llmcompressor/transformers/obcq
91+
- name: Running Tracing Tests
92+
if: (success() || failure()) && steps.install.outcome == 'success'
93+
run: |
94+
pytest -v tests/llmcompressor/transformers/tracing

src/llmcompressor/transformers/finetune/data/peoples_speech.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,20 +26,20 @@ class PeoplesSpeech(TextGenerationDataset):
2626
:param processor: processor or tokenizer to use on dataset
2727
"""
2828

29-
def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
30-
data_args = deepcopy(data_args)
31-
data_args.dataset = "MLCommons/peoples_speech"
32-
data_args.dataset_config_name = "test"
33-
if not data_args.overwrite_cache:
29+
def __init__(self, dataset_args: "DataArgs", split: str, processor: Processor):
30+
dataset_args = deepcopy(dataset_args)
31+
dataset_args.dataset = "MLCommons/peoples_speech"
32+
dataset_args.dataset_config_name = "test"
33+
if not dataset_args.overwrite_cache:
3434
logger.warning(
3535
"Because audio processors are more complex, dataset mapping functions "
3636
"vary with model architecture and their results cannot be cached. "
3737
"Setting overwrite_cache=True"
3838
)
39-
data_args.overwrite_cache = True
39+
dataset_args.overwrite_cache = True
4040
self.processor_type = processor.__class__.__name__
4141

42-
super().__init__(data_args=data_args, split=split, processor=processor)
42+
super().__init__(dataset_args=dataset_args, split=split, processor=processor)
4343

4444
def dataset_template(self, example):
4545
audio = example["audio"]["array"]

src/llmcompressor/transformers/tracing/debug.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from llmcompressor.transformers import TextGenerationDataset
1313
from llmcompressor.args import DatasetArguments
1414

15+
from llmcompressor.utils.dev import skip_weights_download
16+
1517
__all__ = [
1618
"get_model_class"
1719
]
@@ -24,6 +26,7 @@ def parse_args():
2426
parser.add_argument("--sequential_targets", type=str, nargs="*", default=None, metavar="TARGET", help="List of targets for sequential tracing") # noqa: E501
2527
parser.add_argument("--ignore", type=str, nargs="*", default=[], metavar="PATTERN", help="List of patterns to ignore during tracing") # noqa: E501
2628
parser.add_argument("--modality", type=str, default="text", help="Modality of calibration dataset, defaults to text") # noqa: E501
29+
parser.add_argument("--trust_remote_code", type=bool, default=False, help="Whether to trust model remote code") # noqa: E501
2730
return parser.parse_args()
2831

2932

@@ -33,6 +36,7 @@ def trace(
3336
sequential_targets: Optional[Union[List[str], str]] = None,
3437
ignore: Union[List[str], str] = [],
3538
modality: str = "text",
39+
trust_remote_code: bool = True
3640
):
3741
"""
3842
Debug traceability by tracing a pre-trained model into subgraphs
@@ -44,6 +48,7 @@ def trace(
4448
inference
4549
:param ignore: patterns to ignore during tracing
4650
:param modality: data modality for dummy tracing data, defaults to 'text'
51+
:param trust_remote_code: trust remote model code
4752
4853
Example usage from CLI
4954
llmcompressor.trace \
@@ -54,12 +59,16 @@ def trace(
5459
--modality text
5560
"""
5661
# Load model
57-
model = model_class.from_pretrained(
58-
model_id,
59-
device_map="auto",
60-
torch_dtype="auto",
62+
with skip_weights_download(model_class):
63+
model = model_class.from_pretrained(
64+
model_id,
65+
device_map="cpu",
66+
torch_dtype="auto",
67+
trust_remote_code=trust_remote_code,
68+
)
69+
processor = AutoProcessor.from_pretrained(
70+
model_id, trust_remote_code=trust_remote_code
6171
)
62-
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
6372
print("Loaded model")
6473

6574
# Prepare sample data
@@ -138,6 +147,7 @@ def main():
138147
sequential_targets=args.sequential_targets,
139148
ignore=args.ignore,
140149
modality=args.modality,
150+
trust_remote_code=args.trust_remote_code
141151
)
142152

143153

src/llmcompressor/transformers/tracing/idefics3.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ def __init__(self, config: Idefics3Config):
285285

286286
def forward(
287287
self,
288-
input_ids: torch.LongTensor = None,
288+
input_ids: Optional[torch.LongTensor] = None,
289289
attention_mask: Optional[torch.Tensor] = None,
290290
position_ids: Optional[torch.LongTensor] = None,
291291
past_key_values: Optional[List[torch.FloatTensor]] = None,
@@ -296,6 +296,7 @@ def forward(
296296
use_cache: Optional[bool] = None,
297297
output_attentions: Optional[bool] = None,
298298
output_hidden_states: Optional[bool] = None,
299+
cache_position: Optional[torch.LongTensor] = None,
299300
return_dict: Optional[bool] = None,
300301
) -> Union[Tuple, Idefics3BaseModelOutputWithPast]:
301302
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -394,6 +395,7 @@ def forward(
394395
use_cache=use_cache,
395396
output_attentions=output_attentions,
396397
output_hidden_states=output_hidden_states,
398+
cache_position=cache_position,
397399
return_dict=return_dict,
398400
)
399401

src/llmcompressor/utils/dev.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44
import tempfile
55
from typing import Type
66

7+
import torch
78
from huggingface_hub import snapshot_download
89
from safetensors.torch import save_file
910
from transformers import AutoModelForCausalLM, PreTrainedModel
10-
from transformers.modeling_utils import no_init_weights
11+
from transformers.modeling_utils import TORCH_INIT_FUNCTIONS
1112
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME
1213

1314
from llmcompressor.utils.helpers import patch_attr
@@ -32,6 +33,7 @@ def skip_weights_download(model_class: Type[PreTrainedModel] = AutoModelForCausa
3233
"*.pth",
3334
SAFE_WEIGHTS_INDEX_NAME,
3435
WEIGHTS_INDEX_NAME,
36+
"*.msgpack",
3537
]
3638

3739
@classmethod
@@ -62,7 +64,21 @@ def patched(cls, *args, **kwargs):
6264

6365
with tempfile.TemporaryDirectory() as tmp_dir, patch_attr(
6466
model_class, "from_pretrained", patched
65-
), no_init_weights(), patch_transformers_logger_level():
67+
), skip_weights_initialize(), patch_transformers_logger_level():
68+
yield
69+
70+
71+
@contextlib.contextmanager
72+
def skip_weights_initialize(use_zeros: bool = False):
73+
def skip(tensor: torch.Tensor, *args, **kwargs) -> torch.Tensor:
74+
if use_zeros:
75+
return tensor.fill_(0)
76+
return tensor
77+
78+
with contextlib.ExitStack() as stack:
79+
for name in TORCH_INIT_FUNCTIONS.keys():
80+
stack.enter_context(patch_attr(torch.nn.init, name, skip))
81+
stack.enter_context(patch_attr(torch.Tensor, name, skip))
6682
yield
6783

6884

0 commit comments

Comments (0)