
Commit 5be6f79

reindex_fused_weights.py script

Signed-off-by: Kyle Sayers <[email protected]>
1 parent: 0677428

7 files changed: 317 additions, 106 deletions


src/llmcompressor/entrypoints/model_free/__init__.py (7 additions, 7 deletions)

@@ -12,11 +12,7 @@
 from loguru import logger
 from safetensors.torch import load_file, save_file

-from llmcompressor.entrypoints.model_free.helpers import (
-    gpu_if_available,
-    validate_safetensors_index,
-    validate_scheme,
-)
+from llmcompressor.entrypoints.model_free.helpers import gpu_if_available
 from llmcompressor.entrypoints.model_free.lifecycle import (
     calibrate_global_scale,
     calibrate_scale_zp,
@@ -35,6 +31,10 @@
     update_config,
     update_safetensors_index,
 )
+from llmcompressor.entrypoints.model_free.validate import (
+    validate_safetensors_index,
+    validate_scheme,
+)

 __all__ = ["model_free_ptq"]

@@ -71,15 +71,15 @@ def model_free_ptq(
         if not is_microscale_scheme(scheme)
         else _process_file_microscale_scheme
     )
-    for file_path, resolved_path in model_files:
+    for file_path, resolved_path in model_files.items():
         save_path = Path(save_directory) / file_path

         if file_path.endswith("safetensors"):
             jobs.append((job_fn, resolved_path, save_path, scheme, ignore, device))

         else:
             if is_weights_file(file_path):
-                logger.warning(f"Skipping weights file {file_path}")
+                logger.warning(f"Skip processing for weights file {file_path}")
             save_path.parent.mkdir(parents=True, exist_ok=True)
             logger.info(f"Copying {file_path} {save_path}")
             shutil.copyfile(resolved_path, save_path)
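
For orientation, a minimal sketch of the loop's new contract (file names below are hypothetical): model_files is now the dict[str, str] returned by get_checkpoint_files, mapping repo-relative paths to resolved local paths, hence the switch to .items():

# hypothetical checkpoint listing, keyed by repo-relative path
model_files = {
    "model-00001-of-00002.safetensors": "/cache/model-00001-of-00002.safetensors",
    "config.json": "/cache/config.json",
}

for file_path, resolved_path in model_files.items():
    if file_path.endswith("safetensors"):
        ...  # queued as a quantization job
    else:
        ...  # copied through unchanged (non-safetensors weights files also emit a warning)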
src/llmcompressor/entrypoints/model_free/helpers.py (50 additions, 76 deletions)

@@ -1,83 +1,18 @@
-import json
+import os
+from collections import defaultdict
+from typing import Mapping, TypeVar

 import torch
-from compressed_tensors.quantization import (
-    QuantizationScheme,
-    preset_name_to_scheme,
-)
-from compressed_tensors.utils import getattr_chain
 from loguru import logger
+from transformers.file_utils import CONFIG_NAME

-from .microscale import get_fused_names, is_microscale_scheme
-
-__all__ = ["validate_scheme", "gpu_if_available"]
-
-
-def validate_scheme(scheme: QuantizationScheme) -> tuple[str, QuantizationScheme]:
-    # treat strings as preset schemes
-    if isinstance(scheme, str):
-        scheme_name, scheme = scheme, preset_name_to_scheme(scheme, [])
-    else:
-        scheme_name = "config_group_0"
-
-    # weight quantization must be provided
-    if scheme.weights is None:
-        raise ValueError(
-            "Must provide a weights quanitization scheme to perform weights-only PTQ"
-        )
-
-    # activation quantization must be dynamic
-    input_dynamic = getattr_chain(scheme, "input_activations.dynamic", True)
-    output_dynamic = getattr_chain(scheme, "output_activations.dynamic", True)
-    if input_dynamic is not True or output_dynamic is not True:
-        raise ValueError(
-            "Model Free PTQ cannot calibrate activations. "
-            "Please use `oneshot` instead."
-        )
-
-    # override with static observers
-    # Remove after https://github.com/vllm-project/compressed-tensors/pull/489
-    if scheme.weights.observer in ("minmax", "mse"):
-        new_observer = f"static_{scheme.weights.observer}"
-        logger.warning(
-            f"Scheme uses {scheme.weights.observer} weight observer. "
-            f"Using {new_observer} instead"
-        )
-        scheme.weights.observer = new_observer
-
-    # target all modules; filter by ignore list
-    # technically this should be "re:.*", but vllm's
-    # ct moe layer has a hard coded check for "Linear"
-    scheme.targets = ["Linear"]
-    return scheme_name, scheme
-
-
-def validate_safetensors_index(
-    model_files: list[tuple[str, str]], scheme: QuantizationScheme
-):
-    resolved_paths = [
-        resolved_path
-        for file_path, resolved_path in model_files
-        if file_path.endswith("safetensors.index.json")
-    ]
-    if len(resolved_paths) <= 0:
-        return
-    resolved_path = resolved_paths[0]
-
-    if is_microscale_scheme(scheme):
-        with open(resolved_path, "r") as file:
-            weight_map: dict[str, str] = json.load(file)["weight_map"]
-
-        fused_names = get_fused_names(weight_map)
-        for submodule_names in fused_names.values():
-            file_names = [weight_map[name] for name in submodule_names]
-            if not all(file_name == file_names[0] for file_name in file_names):
-                raise NotImplementedError(
-                    "When using a microscale scheme (NVFP4, MXFP4), global scales "
-                    "will be fused. Current implmentation requires that all fused "
-                    "modules (attention and non-moe mlp) be stored in the same file. "
-                    f"Instead, got {submodule_names}\n\n {file_names}"
-                )
+__all__ = [
+    "gpu_if_available",
+    "find_safetensors_index_path",
+    "find_config_path",
+    "find_safetensors_index_file",
+    "invert_mapping",
+]


 def gpu_if_available(device: torch.device | str | None) -> torch.device:
@@ -93,3 +28,42 @@ def gpu_if_available(device: torch.device | str | None) -> torch.device:
     else:
         logger.warning("CUDA/XPU is not available! Compressing model on CPU instead")
         return torch.device("cpu")
+
+
+def find_safetensors_index_path(save_directory: str | os.PathLike) -> str | None:
+    for file_name in os.listdir(save_directory):
+        if file_name.endswith("safetensors.index.json"):
+            return os.path.join(save_directory, file_name)
+
+    return None
+
+
+def find_config_path(save_directory: str | os.PathLike) -> str | None:
+    for file_name in os.listdir(save_directory):
+        if file_name in (CONFIG_NAME, "params.json"):
+            return os.path.join(save_directory, file_name)
+
+    return None
+
+
+def find_safetensors_index_file(model_files: dict[str, str]) -> str | None:
+    for file_path, resolved_path in model_files.items():
+        if file_path.endswith("safetensors.index.json"):
+            return resolved_path
+
+    return None
+
+
+KeyType = TypeVar("K")
+ValueType = TypeVar("V")
+
+
+def invert_mapping(
+    mapping: Mapping[KeyType, ValueType],
+) -> dict[ValueType, list[KeyType]]:
+    inverse = defaultdict(list)
+
+    for key, value in mapping.items():
+        inverse[value].append(key)
+
+    return inverse
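
As a usage sketch (tensor and shard names are hypothetical), invert_mapping turns a safetensors weight_map, which maps tensor name to shard file, into a shard-to-tensors grouping; the reindexing script below relies on exactly this inversion:

from llmcompressor.entrypoints.model_free.helpers import invert_mapping

weight_map = {
    "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
    "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
    "model.layers.1.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
}

file_map = invert_mapping(weight_map)
# file_map["model-00001-of-00002.safetensors"]
#   == ["model.layers.0.self_attn.q_proj.weight",
#       "model.layers.0.self_attn.k_proj.weight"]
# file_map["model-00002-of-00002.safetensors"]
#   == ["model.layers.1.mlp.gate_proj.weight"]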

src/llmcompressor/entrypoints/model_free/microscale.py (39 additions, 1 deletion)

@@ -1,14 +1,52 @@
 import torch
 from compressed_tensors.quantization import QuantizationScheme, QuantizationStrategy
+from compressed_tensors.utils.match import _match_name

-__all__ = ["get_fused_names", "is_microscale_scheme"]
+__all__ = ["get_fused_names", "is_microscale_scheme", "match_names_set_eager"]
+
+
+MatchedNamesSet = dict[str, str | None]


 def is_microscale_scheme(scheme: QuantizationScheme) -> bool:
     assert scheme.weights is not None
     return scheme.weights.strategy == QuantizationStrategy.TENSOR_GROUP


+def match_names_set_eager(
+    tensor_names: set[str] | list[str],
+    targets: set[str] | list[str],
+    return_unmatched: bool = True,
+) -> list[MatchedNamesSet] | tuple[list[MatchedNamesSet], MatchedNamesSet]:
+    matched_sets = []
+    matches = dict.fromkeys(targets, None)
+
+    for name in tensor_names:
+        # match until we get a full set
+        for target in targets:
+            if _match_name(name, target):
+                if matches[target] is None:
+                    matches[target] = name
+                else:
+                    # matched target twice without completing a set
+                    raise ValueError(
+                        f"Matched a {target} twice before "
+                        f"completing set ({matches[target]}, {name})"
+                    )
+
+        # once we have a full set, yield and reset
+        if all((matches[target] is not None for target in targets)):
+            matched_sets.append(matches)
+            matches = dict.fromkeys(targets, None)
+
+    unmatched_set = matches if any((v is not None for v in matches.values())) else None
+
+    if return_unmatched:
+        return matched_sets, unmatched_set
+    else:
+        return matched_sets
+
+
 def get_fused_names(tensors: dict[str, torch.Tensor]) -> dict[str, list[str]]:
     fused_names = {}
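
A usage sketch of match_names_set_eager, with hypothetical tensor names (the re:-prefixed targets follow the same convention _match_name already uses): it eagerly fills one slot per target, emits each completed set, and hands back any partially filled set so the caller can carry it into the next shard:

from llmcompressor.entrypoints.model_free.microscale import match_names_set_eager

tensor_names = [
    "model.layers.0.self_attn.k_proj.weight",
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.self_attn.v_proj.weight",
    "model.layers.1.self_attn.q_proj.weight",  # k/v live in the next shard
]
targets = [
    r"re:.*attn\.q_proj\.weight$",
    r"re:.*attn\.k_proj\.weight$",
    r"re:.*attn\.v_proj\.weight$",
]

matched_sets, unmatched = match_names_set_eager(tensor_names, targets)
# matched_sets: one complete q/k/v set for layer 0
# unmatched: layer 1's q_proj paired with None for the k/v targets,
#            which the reindexing script carries over to the next shard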

src/llmcompressor/entrypoints/model_free/model_utils.py (2 additions, 2 deletions)

@@ -18,15 +18,15 @@ def is_weights_file(file_name: str) -> bool:
     return any(file_name.endswith(suffix) for suffix in weights_files)


-def get_checkpoint_files(model_stub: str | os.PathLike) -> list[tuple[str, str]]:
+def get_checkpoint_files(model_stub: str | os.PathLike) -> dict[str, str]:
     # In the future, this function can accept and pass download kwargs to cached_file

     if os.path.exists(model_stub):
         file_paths = walk_file_paths(model_stub, ignore=".cache")
     else:
         file_paths = list_repo_files(model_stub)

-    return [(file_path, cached_file(model_stub, file_path)) for file_path in file_paths]
+    return {file_path: cached_file(model_stub, file_path) for file_path in file_paths}


 def walk_file_paths(root_dir: str, ignore: str | None = None) -> list[str]:
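
A before/after sketch of the calling convention (the stub and index key are hypothetical); the dict form also allows direct lookups, which the reindexing script uses to resolve shard names taken from the safetensors index:

from llmcompressor.entrypoints.model_free.model_utils import get_checkpoint_files

model_stub = "org/model"  # hypothetical stub

# before: a list of (repo-relative path, resolved local path) tuples
# for file_path, resolved_path in get_checkpoint_files(model_stub): ...

# after: a dict keyed by repo-relative path
model_files = get_checkpoint_files(model_stub)
for file_path, resolved_path in model_files.items():
    ...
resolved = model_files["model-00001-of-00002.safetensors"]  # direct lookup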
reindex_fused_weights.py (new file, 139 additions)

import json
import os
import shutil
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

import torch
import tqdm
from loguru import logger
from safetensors.torch import load_file, save_file

from llmcompressor.entrypoints.model_free.helpers import (
    find_safetensors_index_file,
    invert_mapping,
)
from llmcompressor.entrypoints.model_free.microscale import match_names_set_eager
from llmcompressor.entrypoints.model_free.model_utils import (
    get_checkpoint_files,
    is_weights_file,
)
from llmcompressor.entrypoints.model_free.save_utils import update_safetensors_index

# very naive script
# assumes weight locality, meaning that if a set of fused weights is not fully
# contained in one file:
# 1. the incomplete set is the last set of weights (sorted alphabetically)
# 2. the remainder of the incomplete set is in the next file (sorted alphabetically)

model_stub = ""
fused_mappings: list[list[str]] = []

DEFAULT_FUSED_MAPPINGS = [
    [
        r"re:.*(attn|attention)\.q_proj\.weight$",
        r"re:.*(attn|attention)\.k_proj\.weight$",
        r"re:.*(attn|attention)\.v_proj\.weight$",
    ],
    [
        r"re:.*(attn|attention)\.wq_a\.weight$",
        r"re:.*(attn|attention)\.wkv_a_with_mqa\.weight$",
    ],
    [r"re:.*mlp\.gate_proj\.weight$", r"re:.*mlp\.up_proj\.weight$"],
    [r"re:.*w1\.weight$", r"re:.*w3\.weight$"],
]


def main(
    model_stub: str,
    save_directory: str,
    fused_mappings: list[list[str]] = DEFAULT_FUSED_MAPPINGS,
):
    # read files
    model_files = get_checkpoint_files(model_stub)
    index_file = find_safetensors_index_file(model_files)
    if index_file is None:
        raise ValueError(
            "This script is used to modify safetensors file shards, "
            "but was unable to find a safetensors index file"
        )

    # copy non-weight files
    for file_path, resolved_path in model_files.items():
        save_path = Path(save_directory) / file_path

        if file_path.endswith("safetensors"):
            continue
        else:
            if is_weights_file(file_path):
                logger.warning(f"Skip processing for weights file {file_path}")
            save_path.parent.mkdir(parents=True, exist_ok=True)
            logger.info(f"Copying {file_path} {save_path}")
            shutil.copyfile(resolved_path, save_path)

    # read index file
    with open(index_file, "r") as file:
        index_file_data = json.load(file)

    weight_map: dict[str, str] = index_file_data["weight_map"]
    final_weight_map: dict[str, str] = {}

    # set up save executor and carry-over buffer
    executor = ThreadPoolExecutor(max_workers=10)
    carry_over_tensors: dict[str, torch.Tensor] = {}

    # iterate in alphabetical order on the assumption of weight-file locality
    file_map = invert_mapping(weight_map)
    file_names = sorted(file_map)
    progress = tqdm.tqdm(total=len(file_names))
    for file_name in file_names:
        file_path = model_files[file_name]
        save_path = os.path.join(save_directory, file_name)
        tensors = load_file(file_path)

        if len(carry_over_tensors) > 0:
            # add carry-over from the previous file
            tensors.update(carry_over_tensors)
            carry_over_tensors = {}

        tensor_names = sorted(tensors.keys())
        for mapping in fused_mappings:
            _matches, unmatched = match_names_set_eager(tensor_names, mapping)

            if unmatched is not None:
                # move the incomplete set to the carry-over buffer
                unmatched_tensors = {
                    key: tensors[key] for key in unmatched.values() if key is not None
                }
                carry_over_tensors.update(unmatched_tensors)

                # delete from the current file
                for key in unmatched_tensors:
                    tensor_names.remove(key)
                    del tensors[key]

        # save tensors after modification
        executor.submit(with_progress, save_file, tensors, save_path, progress=progress)
        final_weight_map.update({name: file_name for name in tensor_names})

    update_safetensors_index(
        save_directory, index_file_data["metadata"]["total_size"], final_weight_map
    )

    executor.shutdown()


def with_progress(fn: callable, *args, progress: tqdm.tqdm):
    ret = fn(*args)
    progress.update(1)
    return ret


if __name__ == "__main__":
    main(
        # "mistralai/mistral-large-3",
        "/raid/engine/hub_cache/mistral-fp8-block",
        "temp",
    )
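
To make the locality assumption concrete, a hypothetical two-shard layout before and after reindexing: layer 1's q_proj sits alone at the end of shard 1, so the script deletes it there, carries it into shard 2, and the rewritten index places the whole q/k/v set in one file:

# before reindexing (weight_map, abbreviated):
before = {
    "layers.0.attn.q_proj.weight": "model-00001.safetensors",
    "layers.0.attn.k_proj.weight": "model-00001.safetensors",
    "layers.0.attn.v_proj.weight": "model-00001.safetensors",
    "layers.1.attn.q_proj.weight": "model-00001.safetensors",  # incomplete set
    "layers.1.attn.k_proj.weight": "model-00002.safetensors",
    "layers.1.attn.v_proj.weight": "model-00002.safetensors",
}

# after reindexing, every fused set is complete within a single shard:
after = {
    "layers.0.attn.q_proj.weight": "model-00001.safetensors",
    "layers.0.attn.k_proj.weight": "model-00001.safetensors",
    "layers.0.attn.v_proj.weight": "model-00001.safetensors",
    "layers.1.attn.q_proj.weight": "model-00002.safetensors",  # carried over
    "layers.1.attn.k_proj.weight": "model-00002.safetensors",
    "layers.1.attn.v_proj.weight": "model-00002.safetensors",
}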
