import contextlib
import logging
import os
import tempfile
from typing import Type

import torch
from huggingface_hub import snapshot_download
from safetensors.torch import save_file
from transformers import AutoModelForCausalLM, PreTrainedModel
from transformers.modeling_utils import TORCH_INIT_FUNCTIONS
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME

from llmcompressor.utils.helpers import patch_attr

__all__ = ["skip_weights_download", "patch_transformers_logger_level"]


@contextlib.contextmanager
def skip_weights_download(model_class: Type[PreTrainedModel] = AutoModelForCausalLM):
    """
    Context manager under which models are initialized without downloading the model
    weight files. This differs from `init_empty_weights` in that weights are allocated
    on their assigned devices with random values, rather than being placed on the meta
    device

    :param model_class: class to patch, defaults to `AutoModelForCausalLM`
    """
    original_fn = model_class.from_pretrained
    weights_files = [
        "*.bin",
        "*.safetensors",
        "*.pth",
        SAFE_WEIGHTS_INDEX_NAME,
        WEIGHTS_INDEX_NAME,
        "*.msgpack",
    ]

    @classmethod
    def patched(cls, *args, **kwargs):
        nonlocal tmp_dir

        # intercept model stub
        model_stub = args[0] if args else kwargs.pop("pretrained_model_name_or_path")

        # download everything except the weight files into the tmp dir
        os.makedirs(tmp_dir, exist_ok=True)
        snapshot_download(
            repo_id=model_stub, local_dir=tmp_dir, ignore_patterns=weights_files
        )

        # make an empty weights file to avoid errors when loading
        weights_file_path = os.path.join(tmp_dir, "model.safetensors")
        save_file({}, weights_file_path, metadata={"format": "pt"})

        # load from the tmp dir
        model = original_fn(tmp_dir, **kwargs)

        # restore the original model stub on the loaded model
        model.name_or_path = model_stub
        model.config._name_or_path = model_stub

        return model

    with tempfile.TemporaryDirectory() as tmp_dir, patch_attr(
        model_class, "from_pretrained", patched
    ), skip_weights_initialize(), patch_transformers_logger_level():
        yield

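# A minimal usage sketch (the model stub below is purely illustrative, and network
# access to the Hugging Face Hub is assumed): the config and other non-weight files
# are fetched as usual, but the returned model's parameters hold random values.
#
#   with skip_weights_download():
#       model = AutoModelForCausalLM.from_pretrained("some-org/some-model")

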
@contextlib.contextmanager
def skip_weights_initialize(use_zeros: bool = False):
    """
    Very similar to `transformers.modeling_utils.no_init_weights`, except that
    torch.Tensor initialization functions are also patched to account for tensors
    which are initialized on devices other than the meta device

    :param use_zeros: if True, fill tensors with zeros instead of leaving their
        allocated values untouched
    """

    def skip(tensor: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # replacement for the torch init functions: leave values as allocated,
        # or fill with zeros if requested
        if use_zeros:
            return tensor.fill_(0)
        return tensor

    with contextlib.ExitStack() as stack:
        for name in TORCH_INIT_FUNCTIONS.keys():
            stack.enter_context(patch_attr(torch.nn.init, name, skip))
            stack.enter_context(patch_attr(torch.Tensor, name, skip))
        yield

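# A minimal usage sketch (module construction shown for illustration only): the
# patched init functions become no-ops, so parameter values are whatever the
# underlying allocation happens to contain.
#
#   with skip_weights_initialize():
#       layer = torch.nn.Linear(1024, 1024)  # weights/bias are not initialized

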
@contextlib.contextmanager
def patch_transformers_logger_level(level: int = logging.ERROR):
    """
    Context manager under which the transformers logger's level is modified

    This can be used with `skip_weights_download` to squelch warnings related to
    missing parameters in the checkpoint

    :param level: new logging level for the transformers logger. Logs whose level is
        below this level will not be logged
    """
    transformers_logger = logging.getLogger("transformers.modeling_utils")
    restore_log_level = transformers_logger.getEffectiveLevel()

    transformers_logger.setLevel(level=level)
    try:
        yield
    finally:
        # restore the original level even if an exception is raised in the context
        transformers_logger.setLevel(level=restore_log_level)
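

# A minimal usage sketch: warnings emitted by `transformers.modeling_utils` (for
# example, about weights newly initialized because the checkpoint was empty) are
# suppressed while the context is active. The model stub is illustrative only.
#
#   with patch_transformers_logger_level(logging.ERROR):
#       model = AutoModelForCausalLM.from_pretrained("some-org/some-model")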