Commit 027caa4

[BugFix] Directly Convert Modifiers to Recipe Instance (#1271)
Currently, recipe creation follows this sequence: **Modifiers → String (serialization) → Recipe instance (deserialization)**. The intermediate serialization and deserialization step introduces issues with more complex objects, such as **SmoothQuant mappings**, and can lead to parsing errors.

### Solution

This PR refactors the flow to construct the **Recipe instance** directly from **Modifiers**, thereby **removing an unnecessary conversion step** and eliminating a potential source of error.

### Issue Tracking

This issue was originally surfaced in [#37](https://github.com/vllm-project/llm-compressor/issues/37) and is formally tracked under [INFERENG-358](https://issues.redhat.com/browse/INFERENG-358).

### Testing

The issue was reproduced using the following script, which previously errored out but now runs **successfully** with this fix:

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier

DATASET_ID = "HuggingFaceH4/ultrachat_200k"
MODEL_ID = "bigscience/bloom-3b"
DATASET_SPLIT = "train_sft"
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, device_map="auto", torch_dtype="auto"
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Define quantization recipe
recipe = [
    SmoothQuantModifier(
        smoothing_strength=0.8,
        mappings=[
            (["re:.*query_key_value"], "re:.*input_layernorm"),
            (["re:.*dense_h_to_4h"], "re:.*post_attention_layernorm"),
        ],
    ),
    GPTQModifier(
        scheme="W8A8",
        targets="Linear",
        ignore=["lm_head"],
        dampening_frac=0.003,
    ),
]

# Load and preprocess dataset
dataset = load_dataset(DATASET_ID, split=DATASET_SPLIT)
dataset = dataset.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))


def preprocess(example):
    """Formats the messages into a simple dialogue format."""
    text = "\n".join([msg["content"] for msg in example["messages"]])
    return {"text": text}


dataset = dataset.map(preprocess)

# Apply quantization
oneshot(
    model=model,
    dataset=dataset,
    recipe=recipe,
    output_dir="bloom-3b-gptq-w8a8",
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)
```

With this fix, the script now runs **to completion** without errors. Automated tests have also been added to cover the new behavior.

---------

Signed-off-by: Rahul Tuli <[email protected]>
1 parent 4f35d48 commit 027caa4
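
For context, a minimal sketch of the new direct path (assuming the `Recipe.from_modifiers` classmethod shown in `src/llmcompressor/recipe/recipe.py` below; the modifier arguments are illustrative):

```python
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
from llmcompressor.recipe import Recipe

# Modifiers are now converted straight into a Recipe instance, with no
# intermediate YAML-string serialization/deserialization step in between.
modifiers = [
    SmoothQuantModifier(
        smoothing_strength=0.8,
        mappings=[(["re:.*query_key_value"], "re:.*input_layernorm")],
    )
]
recipe = Recipe.from_modifiers(modifiers)  # single stage, "default" group
```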

File tree

7 files changed: +178 -141 lines changed


src/llmcompressor/recipe/modifier.py

Lines changed: 14 additions & 7 deletions
```diff
@@ -83,13 +83,20 @@ def create_modifier(self) -> "Modifier":
     @model_validator(mode="before")
     @classmethod
     def extract_modifier_type(cls, values: Dict[str, Any]) -> Dict[str, Any]:
-        modifier = {"group": values.pop("group")}
-        assert len(values) == 1, "multiple key pairs found for modifier"
-        modifier_type, args = list(values.items())[0]
-
-        modifier["type"] = modifier_type
-        modifier["args"] = args
-        return modifier
+        if len(values) == 2:
+            if "group" not in values:
+                raise ValueError(
+                    "Invalid format: expected keys 'group' and one modifier "
+                    f"type, but got keys: {list(values.keys())}"
+                )
+
+            # values contains only group and the Modifier type as keys
+            group = values.pop("group")
+            modifier_type, args = values.popitem()
+            return {"group": group, "type": modifier_type, "args": args}
+
+        # values already in the correct format
+        return values
 
     def dict(self, *args, **kwargs) -> Dict[str, Any]:
         """
```

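As a rough sketch of the two input shapes the validator now accepts (assuming `RecipeModifier` is the Pydantic model shown above and validation is invoked via `model_validate`; names and values here are illustrative):

```python
from llmcompressor.recipe.modifier import RecipeModifier

# Raw YAML-style shape: "group" plus a single "<ModifierType>: {args}" pair.
raw = {"group": "quant", "SmoothQuantModifier": {"smoothing_strength": 0.8}}
parsed = RecipeModifier.model_validate(raw)
assert parsed.type == "SmoothQuantModifier"
assert parsed.args == {"smoothing_strength": 0.8}

# Already-normalized shape (e.g. what Recipe.from_modifiers now builds)
# passes through the validator unchanged.
normalized = {
    "group": "quant",
    "type": "SmoothQuantModifier",
    "args": {"smoothing_strength": 0.8},
}
assert RecipeModifier.model_validate(normalized).type == parsed.type
```
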
src/llmcompressor/recipe/recipe.py

Lines changed: 18 additions & 69 deletions
```diff
@@ -12,6 +12,7 @@
 from llmcompressor.recipe.args import RecipeArgs
 from llmcompressor.recipe.base import RecipeBase
 from llmcompressor.recipe.metadata import RecipeMetaData
+from llmcompressor.recipe.modifier import RecipeModifier
 from llmcompressor.recipe.stage import RecipeStage
 
 __all__ = [
@@ -61,20 +62,29 @@ def from_modifiers(
         """
         logger.info("Creating recipe from modifiers")
 
-        # validate Modifiers
         if isinstance(modifiers, Modifier):
-            modifiers: List[Modifier] = [modifiers]
+            modifiers = [modifiers]
 
         if any(not isinstance(modifier, Modifier) for modifier in modifiers):
             raise ValueError("modifiers must be a list of Modifier instances")
 
-        recipe_string: str = create_recipe_string_from_modifiers(
-            modifiers=modifiers,
-            modifier_group_name=modifier_group_name,
-        )
+        group_name = modifier_group_name or "default"
 
-        # modifier group name already included in the recipe string
-        return cls.create_instance(path_or_modifiers=recipe_string)
+        recipe_modifiers: List[RecipeModifier] = [
+            RecipeModifier(
+                type=modifier.__class__.__name__,
+                group=group_name,
+                args=modifier.model_dump(exclude_unset=True),
+            )
+            for modifier in modifiers
+        ]
+        # assume one stage for modifier instances
+        stages: List[RecipeStage] = [
+            RecipeStage(group=group_name, modifiers=recipe_modifiers)
+        ]
+        recipe = cls()
+        recipe.stages = stages
+        return recipe
 
     @classmethod
     def create_instance(
@@ -652,67 +662,6 @@ def _parse_recipe_from_md(file_path, yaml_str):
     return yaml_str
 
 
-def create_recipe_string_from_modifiers(
-    modifiers: List[Modifier],
-    modifier_group_name: Optional[str] = None,
-) -> str:
-    """
-    Create a recipe string from a list of Modifier instances
-
-    (Note: this pathway assumes there's only one stage in the recipe
-    associated by the modifier_group_name, if None, a dummy default
-    group_name will be assigned.)
-
-    :param modifiers: The list of Modifier instances
-    :param modifier_group_name: The stage_name of the recipe,
-        if `oneshot` or `train` the run_type of the recipe will be
-        inferred from the modifier_group_name, if None, a dummy default
-        group_name will be assigned.
-    :return: A string in yaml format from which the recipe can be created
-    """
-
-    # Recipe(s) are yaml/json strings of the following format:
-    # run_type_stage: # should contain oneshot/train
-    #     modifiers:
-    #         ModifierTypeOne:
-    #             start: 0.0
-    #             end: 2.0
-    #             ...
-    #         ModifierTypeTwo:
-    #             ...
-
-    # Create a recipe string from the modifiers
-    default_group_name: str = "DEFAULT"
-    modifier_group_name: str = modifier_group_name or default_group_name
-
-    recipe_dict = {
-        f"{modifier_group_name}_stage": {
-            f"{default_group_name}_modifiers": {
-                modifier.__class__.__name__: modifier.model_dump(exclude_unset=True)
-                for modifier in modifiers
-            }
-        }
-    }
-    recipe_str: str = yaml.dump(recipe_dict, sort_keys=False)
-    return recipe_str
-
-
-def get_modifiers_dict(modifiers: List[Dict[str, Any]]) -> Dict[str, Any]:
-    group_dict = {}
-
-    for modifier in modifiers:
-        modifier_type = modifier["type"]
-        modifier_group = modifier["group"]
-
-        if modifier_group not in group_dict:
-            group_dict[modifier_group] = []
-
-        modifier_dict = {modifier_type: modifier["args"]}
-        group_dict[modifier_group].append(modifier_dict)
-
-    return group_dict
-
-
 def get_yaml_serializable_stage_dict(modifiers: List[Dict[str, Any]]) -> Dict[str, Any]:
     """
     This function is used to convert a list of modifiers into a dictionary
```

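For a rough picture of what `from_modifiers` now constructs (attribute names taken from the diff above; treat this as a sketch rather than a guaranteed public API):

```python
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.recipe import Recipe

recipe = Recipe.from_modifiers([GPTQModifier(scheme="W8A8", targets="Linear")])

# A single RecipeStage in the "default" group wrapping one RecipeModifier,
# whose args mirror the modifier's explicitly-set fields.
stage = recipe.stages[0]
assert stage.group == "default"
assert stage.modifiers[0].type == "GPTQModifier"
assert stage.modifiers[0].args == {"scheme": "W8A8", "targets": "Linear"}
```
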
src/llmcompressor/recipe/stage.py

Lines changed: 20 additions & 20 deletions
```diff
@@ -139,26 +139,26 @@ def extract_dict_modifiers(values: Dict[str, Any]) -> List[Dict[str, Any]]:
         """
 
         modifiers = []
-        remove_keys = []
-
-        if "modifiers" in values and values["modifiers"]:
-            remove_keys.append("modifiers")
-            for mod_key, mod_value in values["stages"].items():
-                modifier = {mod_key: mod_value}
-                modifier["group"] = "default"
-                modifiers.append(modifier)
-
-        for key, value in list(values.items()):
-            if key.endswith("_modifiers"):
-                remove_keys.append(key)
-                group = key.rsplit("_modifiers", 1)[0]
-                for mod_key, mod_value in value.items():
-                    modifier = {mod_key: mod_value}
-                    modifier["group"] = group
-                    modifiers.append(modifier)
-
-        for key in remove_keys:
-            del values[key]
+
+        if "modifiers" in values:
+            modifier_values = values.pop("modifiers")
+            if "stages" in values:
+                for mod_key, mod_value in values.pop("stages").items():
+                    modifiers.append({mod_key: mod_value, "group": "default"})
+            else:
+                values["default_stage"] = {
+                    "default_modifiers": {mod.type: mod.args for mod in modifier_values}
+                }
+                modifiers.extend(
+                    {mod.type: mod.args, "group": "default"} for mod in modifier_values
+                )
+
+        for key in [k for k in values if k.endswith("_modifiers")]:
+            group = key.rsplit("_modifiers", 1)[0]
+            modifiers.extend(
+                {mod_key: mod_value, "group": group}
+                for mod_key, mod_value in values.pop(key).items()
+            )
 
         return modifiers
```

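As a rough illustration of the data shape the refactored `extract_dict_modifiers` handles for the common `<group>_modifiers` layout (the values below are hypothetical; the function itself is internal to recipe parsing):

```python
# Input: a stage dict as parsed from recipe YAML (hypothetical example values).
values = {
    "quant_modifiers": {
        "SmoothQuantModifier": {"smoothing_strength": 0.8},
        "GPTQModifier": {"targets": "Linear", "scheme": "W8A8"},
    }
}

# Expected output: one flat entry per modifier tagged with its group name,
# which RecipeModifier.extract_modifier_type later normalizes into
# {"group": ..., "type": ..., "args": ...}.
expected = [
    {"SmoothQuantModifier": {"smoothing_strength": 0.8}, "group": "quant"},
    {"GPTQModifier": {"targets": "Linear", "scheme": "W8A8"}, "group": "quant"},
]
```
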
tests/e2e/recipe.yaml

Lines changed: 17 additions & 0 deletions
```yaml
quant_stage:
  quant_modifiers:
    SmoothQuantModifier:
      smoothing_strength: 0.8
      mappings:
        - - ['re:.*q_proj', 're:.*k_proj', 're:.*v_proj']
          - re:.*input_layernorm
        - - ['re:.*gate_proj', 're:.*up_proj']
          - re:.*post_attention_layernorm
    GPTQModifier:
      sequential_update: false
      ignore: [lm_head]
      config_groups:
        group_0:
          weights: {num_bits: 8, type: int, symmetric: true, strategy: channel}
          input_activations: {num_bits: 8, symmetric: false}
          targets: [Linear]
```

tests/e2e/test_recipe_parsing.py

Lines changed: 95 additions & 0 deletions
```python
from pathlib import Path

import pytest
from transformers import AutoModelForCausalLM

from llmcompressor.core.session_functions import reset_session
from llmcompressor.modifiers.quantization.gptq import GPTQModifier
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
from llmcompressor.modifiers.smoothquant.utils import DEFAULT_SMOOTHQUANT_MAPPINGS
from llmcompressor.transformers import oneshot
from tests.testing_utils import requires_gpu


@pytest.fixture
def common_setup():
    model_stub = "Xenova/llama2.c-stories110M"
    model = AutoModelForCausalLM.from_pretrained(
        model_stub, device_map="auto", torch_dtype="auto"
    )

    dataset = "ultrachat-200k"
    output_dir = "./test_output"
    splits = {"calibration": "train_gen[:5%]"}
    max_seq_length = 2048
    pad_to_max_length = False
    num_calibration_samples = 8

    return (
        model,
        dataset,
        output_dir,
        splits,
        max_seq_length,
        pad_to_max_length,
        num_calibration_samples,
    )


def recipes():
    modifier_objects = [
        SmoothQuantModifier(
            smoothing_strength=0.8, mappings=DEFAULT_SMOOTHQUANT_MAPPINGS
        ),
        GPTQModifier(
            targets="Linear", scheme="W8A8", ignore=["lm_head"], sequential_update=False
        ),
    ]

    recipe_str = """
    DEFAULT_stage:
        DEFAULT_modifiers:
            SmoothQuantModifier:
                smoothing_strength: 0.8
                mappings:
                    - - ['re:.*q_proj', 're:.*k_proj', 're:.*v_proj']
                      - re:.*input_layernorm
                    - - ['re:.*gate_proj', 're:.*up_proj']
                      - re:.*post_attention_layernorm
            GPTQModifier:
                sequential_update: false
                targets: Linear
                scheme: W8A8
    """

    recipe_file = str(Path(__file__).parent / "recipe.yaml")

    return [modifier_objects, recipe_str, recipe_file]


@requires_gpu
@pytest.mark.parametrize("recipe", recipes())
def test_oneshot(common_setup, recipe):
    (
        model,
        dataset,
        output_dir,
        splits,
        max_seq_length,
        pad_to_max_length,
        num_calibration_samples,
    ) = common_setup

    oneshot(
        model=model,
        dataset=dataset,
        recipe=recipe,
        output_dir=output_dir,
        splits=splits,
        max_seq_length=max_seq_length,
        pad_to_max_length=pad_to_max_length,
        num_calibration_samples=num_calibration_samples,
        save_compressed=True,
    )

    reset_session()
```

tests/llmcompressor/helpers.py

Lines changed: 13 additions & 0 deletions
```diff
@@ -1,3 +1,6 @@
+# flake8: noqa
+
+
 def valid_recipe_strings():
     return [
         """
@@ -52,4 +55,14 @@ def valid_recipe_strings():
            final_sparsity: 0.5
            targets: __ALL_PRUNABLE__
        """,
+        """
+        test1_stage:
+            smoothquant_modifiers:
+                SmoothQuantModifier:
+                    smoothing_strength: 0.5
+                    mappings: [
+                        [["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*input_layernorm"],
+                        [["re:.*gate_proj", "re:.*up_proj"], "re:.*post_attention_layernorm"]
+                    ]
+        """,
     ]
```

tests/llmcompressor/recipe/test_recipe.py

Lines changed: 1 addition & 45 deletions
```diff
@@ -3,10 +3,8 @@
 import pytest
 import yaml
 
-from llmcompressor.modifiers import Modifier
 from llmcompressor.modifiers.obcq.base import SparseGPTModifier
 from llmcompressor.recipe import Recipe
-from llmcompressor.recipe.recipe import create_recipe_string_from_modifiers
 from tests.llmcompressor.helpers import valid_recipe_strings
 
 
@@ -97,46 +95,4 @@ def test_recipe_can_be_created_from_modifier_instances():
         actual_modifiers[0].modifiers, expected_modifiers[0].modifiers
     ):
         assert isinstance(actual_modifier, type(expected_modifier))
-        assert actual_modifier.dict() == expected_modifier.dict()
-
-
-class A_FirstDummyModifier(Modifier):
-    def on_initialize(self, *args, **kwargs) -> bool:
-        return True
-
-
-class B_SecondDummyModifier(Modifier):
-    def on_initialize(self, *args, **kwargs) -> bool:
-        return True
-
-
-def test_create_recipe_string_from_modifiers_with_default_group_name():
-    modifiers = [B_SecondDummyModifier(), A_FirstDummyModifier()]
-    expected_recipe_str = (
-        "DEFAULT_stage:\n"
-        "  DEFAULT_modifiers:\n"
-        "    B_SecondDummyModifier: {}\n"
-        "    A_FirstDummyModifier: {}\n"
-    )
-    actual_recipe_str = create_recipe_string_from_modifiers(modifiers)
-    assert actual_recipe_str == expected_recipe_str
-
-
-def test_create_recipe_string_from_modifiers_with_custom_group_name():
-    modifiers = [B_SecondDummyModifier(), A_FirstDummyModifier()]
-    group_name = "custom"
-    expected_recipe_str = (
-        "custom_stage:\n"
-        "  DEFAULT_modifiers:\n"
-        "    B_SecondDummyModifier: {}\n"
-        "    A_FirstDummyModifier: {}\n"
-    )
-    actual_recipe_str = create_recipe_string_from_modifiers(modifiers, group_name)
-    assert actual_recipe_str == expected_recipe_str
-
-
-def test_create_recipe_string_from_modifiers_with_empty_modifiers():
-    modifiers = []
-    expected_recipe_str = "DEFAULT_stage:\n  DEFAULT_modifiers: {}\n"
-    actual_recipe_str = create_recipe_string_from_modifiers(modifiers)
-    assert actual_recipe_str == expected_recipe_str
+        assert actual_modifier.model_dump() == expected_modifier.model_dump()
```
