Skip to content

Commit 998be99

Browse files
ved1betadsikkakylesayrsrahul-tuli
authored
fix: Make Recipe.model_dump() output compatible with model_validate() (#1328)
SUMMARY: Fixed issue #1319 where Recipe.model_dump() output couldn't be used with Recipe.model_validate(). Implemented an override of the model_dump() method to ensure it produces output in the format expected by validation, enabling proper round-trip serialization using standard Pydantic methods. TEST PLAN: Created test cases to verify fix works with both simple and complex recipes Confirmed Recipe.model_validate(recipe.model_dump()) succeeds with various recipe formats Validated that recipes with multiple stages having the same group name serialize/deserialize correctly Ensured existing YAML serialization pathways continue to work as expected --------- Signed-off-by: Rahul Tuli <rtuli@redhat.com> Co-authored-by: Dipika Sikka <dipikasikka1@gmail.com> Co-authored-by: Kyle Sayers <kylesayrs@gmail.com> Co-authored-by: Rahul Tuli <rtuli@redhat.com>
1 parent 006899c commit 998be99

File tree

2 files changed

+82
-1
lines changed

2 files changed

+82
-1
lines changed

src/llmcompressor/recipe/recipe.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def from_modifiers(
4646
4747
(Note: all modifiers are wrapped into a single stage
4848
with the modifier_group_name as the stage name. If modifier_group_name is None,
49-
the default run type is `oneshot`)
49+
the default run_type is `oneshot`)
5050
5151
Lfecycle:
5252
| - Validate Modifiers
@@ -524,6 +524,33 @@ def dict(self, *args, **kwargs) -> Dict[str, Any]:
524524

525525
return dict_
526526

527+
def model_dump(self, *args, **kwargs) -> Dict[str, Any]:
528+
"""
529+
Override the model_dump method to provide a dictionary representation that
530+
is compatible with model_validate.
531+
532+
Unlike the standard model_dump, this transforms the stages list to a format
533+
expected by the validation logic, ensuring round-trip compatibility with
534+
model_validate.
535+
536+
:return: A dictionary representation of the recipe compatible with
537+
model_validate
538+
"""
539+
# Get the base dictionary from parent class
540+
base_dict = super().model_dump(*args, **kwargs)
541+
542+
# Transform stages into the expected format
543+
if "stages" in base_dict:
544+
stages_dict = {}
545+
for stage in base_dict["stages"]:
546+
group = stage["group"]
547+
if group not in stages_dict:
548+
stages_dict[group] = []
549+
stages_dict[group].append(stage)
550+
base_dict["stages"] = stages_dict
551+
552+
return base_dict
553+
527554
def yaml(self, file_path: Optional[str] = None) -> str:
528555
"""
529556
Return a yaml string representation of the recipe.

tests/recipe/test_recipe.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
from src.llmcompressor.recipe import Recipe
2+
3+
4+
def test_recipe_model_dump():
5+
"""Test that model_dump produces a format compatible with model_validate."""
6+
# Create a recipe with multiple stages and modifiers
7+
recipe_str = """
8+
version: "1.0"
9+
args:
10+
learning_rate: 0.001
11+
train_stage:
12+
pruning_modifiers:
13+
ConstantPruningModifier:
14+
start: 0.0
15+
end: 2.0
16+
targets: ['re:.*weight']
17+
quantization_modifiers:
18+
QuantizationModifier:
19+
bits: 8
20+
targets: ['re:.*weight']
21+
eval_stage:
22+
pruning_modifiers:
23+
ConstantPruningModifier:
24+
start: 2.0
25+
end: 4.0
26+
targets: ['re:.*weight']
27+
"""
28+
29+
# Create recipe instance
30+
recipe = Recipe.create_instance(recipe_str)
31+
32+
# Get dictionary representation
33+
recipe_dict = recipe.model_dump()
34+
35+
# Verify the structure is compatible with model_validate
36+
# by creating a new recipe from the dictionary
37+
new_recipe = Recipe.model_validate(recipe_dict)
38+
39+
# Verify version and args are preserved
40+
assert new_recipe.version == recipe.version
41+
assert new_recipe.args == recipe.args
42+
43+
# Verify stages are preserved
44+
assert len(new_recipe.stages) == len(recipe.stages)
45+
46+
# Verify stage names and modifiers are preserved
47+
for new_stage, orig_stage in zip(new_recipe.stages, recipe.stages):
48+
assert new_stage.group == orig_stage.group
49+
assert len(new_stage.modifiers) == len(orig_stage.modifiers)
50+
51+
# Verify modifier types and args are preserved
52+
for new_mod, orig_mod in zip(new_stage.modifiers, orig_stage.modifiers):
53+
assert new_mod.type == orig_mod.type
54+
assert new_mod.args == orig_mod.args

0 commit comments

Comments
 (0)