Skip to content

Commit 352579b

Browse files
authored
[Utils] Add skip_weights_download for developers and testing (#1334)
## Purpose ## * Follow up to #1188 * Add utilities which can be used by developers as well as used during testing of model architectures ## Prerequisites ## * #1187 ## Changes ## * Add `skip_weights_download` which allows a model to be initialized and dispatched without downloading the weights * Add `patch_transformers_logger_level` which is used by `skip_weights_download` to reduce warning verbosity --------- Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 3c6204f commit 352579b

File tree

1 file changed

+107
-0
lines changed

1 file changed

+107
-0
lines changed

src/llmcompressor/utils/dev.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
import contextlib
2+
import logging
3+
import os
4+
import tempfile
5+
from typing import Type
6+
7+
import torch
8+
from huggingface_hub import snapshot_download
9+
from safetensors.torch import save_file
10+
from transformers import AutoModelForCausalLM, PreTrainedModel
11+
from transformers.modeling_utils import TORCH_INIT_FUNCTIONS
12+
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME
13+
14+
from llmcompressor.utils.helpers import patch_attr
15+
16+
__all__ = ["skip_weights_download", "patch_transformers_logger_level"]
17+
18+
19+
@contextlib.contextmanager
def skip_weights_download(model_class: Type[PreTrainedModel] = AutoModelForCausalLM):
    """
    Context manager under which models are initialized without having to download
    the model weight files. This differs from `init_empty_weights` in that weights are
    allocated on to assigned devices with random values, as opposed to being on the meta
    device

    Implementation: `model_class.from_pretrained` is temporarily replaced with a
    patched version that downloads everything *except* weight files into a temp
    directory, writes an empty safetensors file there, and loads from that
    directory while weight initialization is skipped via `skip_weights_initialize`
    and checkpoint-mismatch warnings are silenced via
    `patch_transformers_logger_level`.

    :param model_class: class to patch, defaults to `AutoModelForCausalLM`
    """
    # keep a reference to the unpatched loader so the patched version can call it
    original_fn = model_class.from_pretrained
    # weight-file patterns excluded from the snapshot download below
    weights_files = [
        "*.bin",
        "*.safetensors",
        "*.pth",
        SAFE_WEIGHTS_INDEX_NAME,
        WEIGHTS_INDEX_NAME,
        "*.msgpack",
    ]

    @classmethod
    def patched(cls, *args, **kwargs):
        # bound by the `with tempfile.TemporaryDirectory() as tmp_dir` below
        nonlocal tmp_dir

        # intercept model stub (positional or keyword form of from_pretrained)
        model_stub = args[0] if args else kwargs.pop("pretrained_model_name_or_path")

        # download files into tmp dir (config, tokenizer, etc. — no weights)
        os.makedirs(tmp_dir, exist_ok=True)
        snapshot_download(
            repo_id=model_stub, local_dir=tmp_dir, ignore_patterns=weights_files
        )

        # make an empty weights file to avoid errors when loading from tmp_dir
        weights_file_path = os.path.join(tmp_dir, "model.safetensors")
        save_file({}, weights_file_path, metadata={"format": "pt"})

        # load from tmp dir; all parameters are "missing" from the empty
        # checkpoint, so they keep their (skipped) initialization values
        model = original_fn(tmp_dir, **kwargs)

        # replace model_path so the model reports its original stub, not tmp_dir
        model.name_or_path = model_stub
        model.config._name_or_path = model_stub

        return model

    # NOTE(review): assumes `patch_attr` restores `from_pretrained` on exit —
    # defined in llmcompressor.utils.helpers, not visible here
    with tempfile.TemporaryDirectory() as tmp_dir, patch_attr(
        model_class, "from_pretrained", patched
    ), skip_weights_initialize(), patch_transformers_logger_level():
        yield
69+
70+
71+
@contextlib.contextmanager
def skip_weights_initialize(use_zeros: bool = False):
    """
    Context manager that turns torch weight-initialization routines into no-ops.

    Very similar to `transformers.model_utils.no_init_weights`, except that
    torch.Tensor initialization functions are also patched to account for
    tensors which are initialized not on the meta device

    :param use_zeros: if True, tensors are zero-filled instead of being left
        with whatever values their storage already holds
    """

    def _no_init(tensor: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # zero-fill on request; otherwise leave the tensor's memory untouched
        return tensor.fill_(0) if use_zeros else tensor

    with contextlib.ExitStack() as stack:
        # patch both the functional API (torch.nn.init.*) and the in-place
        # Tensor method of the same name (e.g. Tensor.normal_)
        for init_name in TORCH_INIT_FUNCTIONS:
            stack.enter_context(patch_attr(torch.nn.init, init_name, _no_init))
            stack.enter_context(patch_attr(torch.Tensor, init_name, _no_init))
        yield
89+
90+
91+
@contextlib.contextmanager
def patch_transformers_logger_level(level: int = logging.ERROR):
    """
    Context manager under which the "transformers.modeling_utils" logger's level
    is modified, and restored on exit — even if the body raises

    This can be used with `skip_weights_download` to squelch warnings related to
    missing parameters in the checkpoint

    :param level: new logging level for transformers logger. Logs whose level is below
        this level will not be logged
    """
    transformers_logger = logging.getLogger("transformers.modeling_utils")
    # save the logger's own level (may be NOTSET), not getEffectiveLevel():
    # restoring an *effective* level would permanently pin an inherited level
    # onto a logger that previously deferred to its parent
    restore_log_level = transformers_logger.level

    transformers_logger.setLevel(level=level)
    try:
        yield
    finally:
        # restore even when the body raises, so the patch cannot leak
        transformers_logger.setLevel(level=restore_log_level)

0 commit comments

Comments
 (0)