|
12 | 12 | from llmcompressor.recipe.args import RecipeArgs |
13 | 13 | from llmcompressor.recipe.base import RecipeBase |
14 | 14 | from llmcompressor.recipe.metadata import RecipeMetaData |
| 15 | +from llmcompressor.recipe.modifier import RecipeModifier |
15 | 16 | from llmcompressor.recipe.stage import RecipeStage |
16 | 17 |
|
17 | 18 | __all__ = [ |
@@ -61,20 +62,29 @@ def from_modifiers( |
61 | 62 | """ |
62 | 63 | logger.info("Creating recipe from modifiers") |
63 | 64 |
|
64 | | - # validate Modifiers |
65 | 65 | if isinstance(modifiers, Modifier): |
66 | | - modifiers: List[Modifier] = [modifiers] |
| 66 | + modifiers = [modifiers] |
67 | 67 |
|
68 | 68 | if any(not isinstance(modifier, Modifier) for modifier in modifiers): |
69 | 69 | raise ValueError("modifiers must be a list of Modifier instances") |
70 | 70 |
|
71 | | - recipe_string: str = create_recipe_string_from_modifiers( |
72 | | - modifiers=modifiers, |
73 | | - modifier_group_name=modifier_group_name, |
74 | | - ) |
| 71 | + group_name = modifier_group_name or "default" |
75 | 72 |
|
76 | | - # modifier group name already included in the recipe string |
77 | | - return cls.create_instance(path_or_modifiers=recipe_string) |
| 73 | + recipe_modifiers: List[RecipeModifier] = [ |
| 74 | + RecipeModifier( |
| 75 | + type=modifier.__class__.__name__, |
| 76 | + group=group_name, |
| 77 | + args=modifier.model_dump(exclude_unset=True), |
| 78 | + ) |
| 79 | + for modifier in modifiers |
| 80 | + ] |
| 81 | + # assume one stage for modifier instances |
| 82 | + stages: List[RecipeStage] = [ |
| 83 | + RecipeStage(group=group_name, modifiers=recipe_modifiers) |
| 84 | + ] |
| 85 | + recipe = cls() |
| 86 | + recipe.stages = stages |
| 87 | + return recipe |
78 | 88 |
|
79 | 89 | @classmethod |
80 | 90 | def create_instance( |
@@ -652,67 +662,6 @@ def _parse_recipe_from_md(file_path, yaml_str): |
652 | 662 | return yaml_str |
653 | 663 |
|
654 | 664 |
|
655 | | -def create_recipe_string_from_modifiers( |
656 | | - modifiers: List[Modifier], |
657 | | - modifier_group_name: Optional[str] = None, |
658 | | -) -> str: |
659 | | - """ |
660 | | - Create a recipe string from a list of Modifier instances |
661 | | -
|
662 | | - (Note: this pathway assumes there's only one stage in the recipe |
663 | | - associated by the modifier_group_name, if None, a dummy default |
664 | | - group_name will be assigned.) |
665 | | -
|
666 | | - :param modifiers: The list of Modifier instances |
667 | | - :param modifier_group_name: The stage_name of the recipe, |
668 | | - if `oneshot` or `train` the run_type of the recipe will be |
669 | | - inferred from the modifier_group_name, if None, a dummy default |
670 | | - group_name will be assigned. |
671 | | - :return: A string in yaml format from which the recipe can be created |
672 | | - """ |
673 | | - |
674 | | - # Recipe(s) are yaml/json strings of the following format: |
675 | | - # run_type_stage: # should contain oneshot/train |
676 | | - # modifiers: |
677 | | - # ModifierTypeOne: |
678 | | - # start: 0.0 |
679 | | - # end: 2.0 |
680 | | - # ... |
681 | | - # ModifierTypeTwo: |
682 | | - # ... |
683 | | - |
684 | | - # Create a recipe string from the modifiers |
685 | | - default_group_name: str = "DEFAULT" |
686 | | - modifier_group_name: str = modifier_group_name or default_group_name |
687 | | - |
688 | | - recipe_dict = { |
689 | | - f"{modifier_group_name}_stage": { |
690 | | - f"{default_group_name}_modifiers": { |
691 | | - modifier.__class__.__name__: modifier.model_dump(exclude_unset=True) |
692 | | - for modifier in modifiers |
693 | | - } |
694 | | - } |
695 | | - } |
696 | | - recipe_str: str = yaml.dump(recipe_dict, sort_keys=False) |
697 | | - return recipe_str |
698 | | - |
699 | | - |
700 | | -def get_modifiers_dict(modifiers: List[Dict[str, Any]]) -> Dict[str, Any]: |
701 | | - group_dict = {} |
702 | | - |
703 | | - for modifier in modifiers: |
704 | | - modifier_type = modifier["type"] |
705 | | - modifier_group = modifier["group"] |
706 | | - |
707 | | - if modifier_group not in group_dict: |
708 | | - group_dict[modifier_group] = [] |
709 | | - |
710 | | - modifier_dict = {modifier_type: modifier["args"]} |
711 | | - group_dict[modifier_group].append(modifier_dict) |
712 | | - |
713 | | - return group_dict |
714 | | - |
715 | | - |
716 | 665 | def get_yaml_serializable_stage_dict(modifiers: List[Dict[str, Any]]) -> Dict[str, Any]: |
717 | 666 | """ |
718 | 667 | This function is used to convert a list of modifiers into a dictionary |
|
0 commit comments