-
Notifications
You must be signed in to change notification settings - Fork 459
Expand file tree
/
Copy pathhelpers.py
More file actions
74 lines (60 loc) · 2.3 KB
/
helpers.py
File metadata and controls
74 lines (60 loc) · 2.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from typing import Optional, Union
import torch
from loguru import logger
from torch.nn import Module
from llmcompressor.core import active_session
# Public API of this helpers module (the names exported by `from ... import *`).
__all__ = [
    "copy_python_files_from_model_cache",
    "parse_dtype",
    "get_session_model",
]
# Accepted spellings (as produced by ``str()``) for each supported precision.
# ``str(torch.float16)`` is ``"torch.float16"``, so dtype objects and their
# string forms resolve through the same table.
_DTYPE_ALIASES = {
    "half": torch.float16,
    "float16": torch.float16,
    "torch.float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "torch.bfloat16": torch.bfloat16,
    "full": torch.float32,
    "float32": torch.float32,
    "torch.float32": torch.float32,
}


def parse_dtype(dtype_arg: Union[str, torch.dtype]) -> Union[str, torch.dtype]:
    """
    Resolve a user-supplied precision spec to a ``torch.dtype``.

    :param dtype_arg: a ``torch.dtype`` or one of its string spellings
        (e.g. ``"half"``, ``"float16"``, ``"torch.bfloat16"``, ``"full"``)
    :return: ``torch.float16``, ``torch.bfloat16`` or ``torch.float32`` when
        ``dtype_arg`` matches a known spelling; otherwise the string
        ``"auto"``, meaning the precision should be taken from the model
    """
    # Normalise dtype objects and strings to one key space before lookup.
    return _DTYPE_ALIASES.get(str(dtype_arg), "auto")
def get_session_model() -> Optional[Module]:
    """
    Fetch the model held by the currently active compression session.

    :return: ``state.model`` of the object returned by ``active_session()``,
        or None when ``active_session()`` returns a falsy value (no session)
    """
    current = active_session()
    return current.state.model if current else None
def copy_python_files_from_model_cache(model, save_path: str) -> None:
    """
    Copy the ``*.py`` files stored next to a model's ``config.json`` into
    ``save_path``.

    The source directory is ``model.config._name_or_path`` when that path
    exists locally. Otherwise it is treated as a Hugging Face Hub repo id,
    and ``config.json`` is resolved through the local HF cache (downloading
    it if needed). The directory containing that file is then scanned.

    :param model: object exposing a ``config`` attribute; presumably a
        ``transformers`` model (TODO confirm against callers)
    :param save_path: destination directory for the copied files; assumed
        to already exist (``shutil.copy`` into a missing path would create a
        file with that name instead)
    :return: None. Nothing is copied when the config has no non-blank
        ``_name_or_path``.
    """
    import os
    import shutil

    name_or_path = getattr(model.config, "_name_or_path", None)
    # No checkpoint location recorded: nothing to copy from.
    if name_or_path is None or not str(name_or_path).strip():
        return
    cache_path = str(name_or_path)

    if not os.path.exists(cache_path):
        # Not a local path, so treat it as a Hub repo id. Imported lazily so
        # local checkpoints do not need these packages at all.
        from huggingface_hub import hf_hub_download
        from transformers.utils import http_user_agent

        config_file_path = hf_hub_download(
            repo_id=cache_path,
            filename="config.json",
            # cache_dir=None respects HF_HOME, HF_HUB_CACHE, and other
            # environment variables for the cache location
            cache_dir=None,
            force_download=False,
            user_agent=http_user_agent(),
        )
        cache_path = os.path.dirname(config_file_path)

    for file_name in os.listdir(cache_path):
        full_file_name = os.path.join(cache_path, file_name)
        if not (file_name.endswith(".py") and os.path.isfile(full_file_name)):
            continue
        logger.debug(f"Transferring {full_file_name} to {save_path}")
        try:
            shutil.copy(full_file_name, save_path)
        except shutil.SameFileError:
            # Saving back into the directory the model was loaded from:
            # the file is already in place, so there is nothing to do.
            pass