-
Notifications
You must be signed in to change notification settings - Fork 458
Expand file tree
/
Copy pathhelpers.py
More file actions
118 lines (88 loc) · 3.4 KB
/
helpers.py
File metadata and controls
118 lines (88 loc) · 3.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os
import re
from collections import defaultdict
from typing import Mapping, TypeVar
import torch
from compressed_tensors.utils.match import match_name
from loguru import logger
from transformers.file_utils import CONFIG_NAME
__all__ = [
"gpu_if_available",
"find_safetensors_index_path",
"find_config_path",
"find_safetensors_index_file",
"match_names_set_eager",
"MatchedNamesSet",
"invert_mapping",
]
KeyType = TypeVar("K")
ValueType = TypeVar("V")
MatchedNamesSet = dict[str, str | None]
def gpu_if_available(device: torch.device | str | None) -> torch.device:
if device is not None:
return torch.device(device)
elif torch.cuda.is_available():
return torch.device("cuda:0")
elif hasattr(torch, "xpu") and torch.xpu.is_available():
return torch.device("xpu:0")
elif hasattr(torch, "npu") and torch.npu.is_available():
return torch.device("npu:0")
else:
logger.warning(
"CUDA/XPU/NPU is not available! Compressing model on CPU instead"
)
return torch.device("cpu")
def find_safetensors_index_path(save_directory: str | os.PathLike) -> str | None:
for file_name in os.listdir(save_directory):
if file_name.endswith("safetensors.index.json"):
return os.path.join(save_directory, file_name)
return None
def find_config_path(save_directory: str | os.PathLike) -> str | None:
    """Return the path of the model config inside *save_directory*.

    Looks for the transformers config file (``CONFIG_NAME``, presumably
    ``config.json`` — defined by transformers) or ``params.json``; returns
    None when neither is present.
    """
    recognized = {CONFIG_NAME, "params.json"}
    for entry in os.listdir(save_directory):
        if entry in recognized:
            return os.path.join(save_directory, entry)
    return None
def find_safetensors_index_file(model_files: dict[str, str]) -> str | None:
for file_path, resolved_path in model_files.items():
if file_path.endswith("safetensors.index.json"):
return resolved_path
return None
def match_names_set_eager(
    names: set[str] | list[str],
    targets: set[str] | list[str],
    return_unmatched: bool = True,
) -> list[MatchedNamesSet] | tuple[list[MatchedNamesSet], MatchedNamesSet]:
    """Greedily group *names* into complete sets keyed by *targets*.

    Names are visited in natural-sort order; each name is tested against every
    target (via ``match_name``), and as soon as every target has a match the
    set is emitted and the accumulator resets. A single name may satisfy more
    than one target.

    Returns the list of complete sets; when *return_unmatched* is True, also
    returns the trailing partial set (or None if nothing was left over).

    Raises:
        ValueError: if a target is matched a second time before the current
            set is complete.
    """

    def _natural_key(text: str) -> list[str | int]:
        # Split out digit runs so e.g. "layer2" orders before "layer10".
        return [int(tok) if tok.isdigit() else tok for tok in re.split(r"(\d+)", text)]

    complete_sets: list[MatchedNamesSet] = []
    current: MatchedNamesSet = dict.fromkeys(targets, None)

    # Natural sort keeps grouping deterministic regardless of input order.
    for name in sorted(names, key=_natural_key):
        for target in targets:
            if not match_name(name, target):
                continue
            if current[target] is not None:
                # A second hit on this target before the set completed.
                raise ValueError(
                    f"Matched a {target} twice before "
                    f"completing set ({current[target]}, {name})"
                )
            current[target] = name

        # Emit and reset as soon as every target is filled.
        if None not in current.values():
            complete_sets.append(current)
            current = dict.fromkeys(targets, None)

    leftover = current if any(v is not None for v in current.values()) else None
    return (complete_sets, leftover) if return_unmatched else complete_sets
def invert_mapping(
    mapping: Mapping[KeyType, ValueType],
) -> dict[ValueType, list[KeyType]]:
    """Invert *mapping*, grouping keys that share the same value.

    Args:
        mapping: mapping whose values are hashable.

    Returns:
        A plain dict from each distinct value to the list of keys that mapped
        to it, preserving the mapping's iteration order.
    """
    inverse: defaultdict[ValueType, list[KeyType]] = defaultdict(list)
    for key, value in mapping.items():
        inverse[value].append(key)
    # Return a plain dict: leaking the defaultdict would make missing-value
    # lookups on the result silently insert empty lists.
    return dict(inverse)