Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion src/llmcompressor/recipe/recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def from_modifiers(

(Note: all modifiers are wrapped into a single stage
with the modifier_group_name as the stage name. If modifier_group_name is None,
the default run type is `oneshot`)
the default run_type is `oneshot`)

Lfecycle:
| - Validate Modifiers
Expand Down Expand Up @@ -524,6 +524,33 @@ def dict(self, *args, **kwargs) -> Dict[str, Any]:

return dict_

def model_dump(self, *args, **kwargs) -> Dict[str, Any]:
"""
Override the model_dump method to provide a dictionary representation that
is compatible with model_validate.

Unlike the standard model_dump, this transforms the stages list to a format
expected by the validation logic, ensuring round-trip compatibility with
model_validate.

:return: A dictionary representation of the recipe compatible with
model_validate
"""
# Get the base dictionary from parent class
base_dict = super().model_dump(*args, **kwargs)

# Transform stages into the expected format
if "stages" in base_dict:
stages_dict = {}
for stage in base_dict["stages"]:
group = stage["group"]
if group not in stages_dict:
stages_dict[group] = []
stages_dict[group].append(stage)
base_dict["stages"] = stages_dict

return base_dict

def yaml(self, file_path: Optional[str] = None) -> str:
"""
Return a yaml string representation of the recipe.
Expand Down
54 changes: 54 additions & 0 deletions tests/recipe/test_recipe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from src.llmcompressor.recipe import Recipe


def test_recipe_model_dump():
"""Test that model_dump produces a format compatible with model_validate."""
# Create a recipe with multiple stages and modifiers
recipe_str = """
version: "1.0"
args:
learning_rate: 0.001
train_stage:
pruning_modifiers:
ConstantPruningModifier:
start: 0.0
end: 2.0
targets: ['re:.*weight']
quantization_modifiers:
QuantizationModifier:
bits: 8
targets: ['re:.*weight']
eval_stage:
pruning_modifiers:
ConstantPruningModifier:
start: 2.0
end: 4.0
targets: ['re:.*weight']
"""

# Create recipe instance
recipe = Recipe.create_instance(recipe_str)

# Get dictionary representation
recipe_dict = recipe.model_dump()

# Verify the structure is compatible with model_validate
# by creating a new recipe from the dictionary
new_recipe = Recipe.model_validate(recipe_dict)

# Verify version and args are preserved
assert new_recipe.version == recipe.version
assert new_recipe.args == recipe.args

# Verify stages are preserved
assert len(new_recipe.stages) == len(recipe.stages)

# Verify stage names and modifiers are preserved
for new_stage, orig_stage in zip(new_recipe.stages, recipe.stages):
assert new_stage.group == orig_stage.group
assert len(new_stage.modifiers) == len(orig_stage.modifiers)

# Verify modifier types and args are preserved
for new_mod, orig_mod in zip(new_stage.modifiers, orig_stage.modifiers):
assert new_mod.type == orig_mod.type
assert new_mod.args == orig_mod.args