
Add automatic configuration generation for the draft model. #167

Merged
sleepcoo merged 9 commits into sgl-project:main from ZhengHSI:add_auto_draft_config
Aug 27, 2025

Conversation

@ZhengHSI
Collaborator

Motivation

Add automatic configuration generation for the draft model.

Modifications

When the --draft-model-config argument is not provided, the script will automatically generate the draft model configuration based on the target model’s configuration.
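For illustration, the generation step could look roughly like the sketch below. This is not the code added in this PR; the helper name, the set of copied fields, and the output-path convention are all assumptions.

# Illustrative sketch only -- the actual helper lives in specforge/utils.py and
# may copy a different set of fields or choose a different output location.
import json
import os

from transformers import AutoConfig


def sketch_generate_draft_config(target_model_path: str, output_dir: str) -> str:
    """Derive a minimal draft-model config from the target model's config."""
    target_cfg = AutoConfig.from_pretrained(target_model_path)

    # Copy the parameters that the draft model must share with the target model.
    dtype = getattr(target_cfg, "torch_dtype", None) or "bfloat16"
    draft_cfg = {
        "hidden_size": target_cfg.hidden_size,
        "vocab_size": target_cfg.vocab_size,
        "num_attention_heads": target_cfg.num_attention_heads,
        # Normalize torch.float16 -> "float16" so the JSON stays serializable.
        "torch_dtype": str(dtype).replace("torch.", ""),
    }

    os.makedirs(output_dir, exist_ok=True)
    model_name = os.path.basename(os.path.normpath(target_model_path)).lower()
    out_path = os.path.join(output_dir, f"{model_name}-eagle3.json")
    with open(out_path, "w") as f:
        json.dump(draft_cfg, f, indent=2)
    return out_path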

Related Issues

Accuracy Test

Benchmark & Profiling

Checklist

Co-authored-by: lukec <118525388+sleepcoo@users.noreply.github.com>

Contributor

@gemini-code-assist bot left a comment


Summary of Changes

Hello @ZhengHSI, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly streamlines the setup process for training Eagle3 models by introducing automatic draft model configuration generation. Previously, users were required to manually provide a draft model configuration file. Now, if this file is not specified, the system intelligently creates one by aligning parameters from the target model, thereby reducing manual overhead and potential configuration errors. This enhancement is reflected in both offline and online training scripts and is supported by new utility functions for configuration management.

Highlights

  • Automatic Draft Model Configuration: The primary change introduces an automatic configuration generation mechanism for draft models. If the --draft-model-config argument is not explicitly provided during training, the system will now automatically derive and create a suitable draft model configuration based on the target model's configuration.
  • Optional --draft-model-config Argument: The --draft-model-config argument in both train_eagle3_offline.py and train_eagle3_online.py scripts has been updated to be optional. This enhances usability by reducing the need for manual configuration file creation.
  • New Configuration Utility Functions: New utility functions (generate_draft_model_config, save_draft_model_config, create_draft_config_from_target) have been added to specforge/utils.py to encapsulate the logic for creating, saving, and managing these auto-generated configurations; a usage sketch follows this list.
  • Documentation Update: The README.md documentation has been updated to reflect the new behavior, guiding users on how to leverage the automatic configuration generation feature.
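
Taken together, a training script can obtain a draft config without any extra file. The snippet below is a hedged usage sketch based on the code quoted in the review comments further down; the import paths and the example model path are assumptions, not part of the PR.

# Usage sketch; import paths and the example target model are assumptions.
from specforge import AutoDraftModelConfig          # assumed import location
from specforge.utils import create_draft_config_from_target

target_model_path = "meta-llama/Meta-Llama-3-8B-Instruct"  # example only

# Generate a draft config JSON under the repo's configs/ directory and load it.
auto_config_path = create_draft_config_from_target(
    target_model_path=target_model_path, cache_dir=None
)
draft_model_config = AutoDraftModelConfig.from_file(auto_config_path)
print(f"Auto-generated draft model config saved to: {auto_config_path}")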

Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request adds a convenient feature for automatically generating draft model configurations. The implementation is mostly solid, but there are opportunities to improve robustness and maintainability. Specifically, I've suggested changes to make file path handling more reliable by using __file__ instead of sys.argv[0], simplifying some logic, and refactoring duplicated code into a shared utility function. These changes will make the new feature more robust and the codebase easier to maintain.

Comment on lines +92 to +100
    if template_config_path is None:
        # Use the script execution directory as base
        import sys

        script_dir = os.path.dirname(os.path.abspath(sys.argv[0]))
        project_root = os.path.dirname(script_dir)  # Go up one level from scripts/
        template_config_path = os.path.join(
            project_root, "configs", "llama3-8B-eagle3.json"
        )

Severity: high

Using sys.argv[0] to determine the project root can be unreliable as it depends on how and from where the script is executed. A more robust approach is to use __file__ to get a path relative to the current module's location, which is independent of the working directory.

    if template_config_path is None:
        # Use the path of this file to robustly locate the project root
        # Assuming this file is in specforge/, two levels up is the project root.
        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        template_config_path = os.path.join(
            project_root, "configs", "llama3-8B-eagle3.json"
        )

Comment on lines +197 to +203
    if output_dir is None:
        # Use the script execution directory as base
        import sys

        script_dir = os.path.dirname(os.path.abspath(sys.argv[0]))
        project_root = os.path.dirname(script_dir)  # Go up one level from scripts/
        output_dir = os.path.join(project_root, "configs")

Severity: high

Using sys.argv[0] to determine the project root can be unreliable as it depends on how and from where the script is executed. A more robust approach is to use __file__ to get a path relative to the current module's location, which is independent of the working directory.

    if output_dir is None:
        # Use the path of this file to robustly locate the project root
        # Assuming this file is in specforge/, two levels up is the project root.
        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        output_dir = os.path.join(project_root, "configs")

Comment on lines 177 to 192
    if args.draft_model_config is None:
        print_with_rank(
            "No draft model config provided, auto-generating from target model..."
        )
        # Auto-generate and save config file
        auto_config_path = create_draft_config_from_target(
            target_model_path=args.target_model_path, cache_dir=args.cache_dir
        )
        draft_model_config = AutoDraftModelConfig.from_file(auto_config_path)
        print_with_rank(
            f"Auto-generated draft model config saved to: {auto_config_path}"
        )
    else:
        # Use provided config file
        draft_model_config = AutoDraftModelConfig.from_file(args.draft_model_config)
        print_with_rank(f"Using provided draft model config: {args.draft_model_config}")

Severity: medium

This logic for handling the draft model configuration is duplicated in scripts/train_eagle3_online.py (lines 170-185). To improve maintainability and follow the DRY (Don't Repeat Yourself) principle, consider refactoring this block into a shared utility function.

For example, you could create a function load_or_create_draft_config(args) in a utility module that encapsulates this logic and returns the draft_model_config.
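
One possible shape for that helper, assembled from the duplicated block quoted above; the import locations (including whether print_with_rank lives in a shared module) are assumptions, and this sketch is not part of the merged PR.

# Sketch of the suggested shared helper; import paths are assumptions.
from specforge import AutoDraftModelConfig
from specforge.utils import create_draft_config_from_target, print_with_rank


def load_or_create_draft_config(args):
    """Return the draft model config, auto-generating one when none is provided."""
    if args.draft_model_config is None:
        print_with_rank(
            "No draft model config provided, auto-generating from target model..."
        )
        auto_config_path = create_draft_config_from_target(
            target_model_path=args.target_model_path, cache_dir=args.cache_dir
        )
        print_with_rank(
            f"Auto-generated draft model config saved to: {auto_config_path}"
        )
        return AutoDraftModelConfig.from_file(auto_config_path)

    print_with_rank(f"Using provided draft model config: {args.draft_model_config}")
    return AutoDraftModelConfig.from_file(args.draft_model_config)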

Comment on lines 131 to 139
            if target_param == "torch_dtype":
                if hasattr(value, "__name__") or "torch" in str(type(value)):
                    value_str = str(value)
                    if "torch." in value_str:
                        value = value_str.split("torch.")[
                            -1
                        ]  # Convert torch.float16 to 'float16'
                    else:
                        value = value_str.split(".")[-1]

Severity: medium

The current logic for converting a torch.dtype to its string representation is complex and can be simplified for better readability and maintainability. Using isinstance to check the type and str.replace for conversion is more direct.

            if target_param == "torch_dtype" and isinstance(value, torch.dtype):
                value = str(value).replace("torch.", "")

    output_dir = os.path.join(project_root, "configs")

    # Extract model name from model path
    model_name = target_model_path.split("/")[-1].lower()

Severity: medium

Using split('/')[-1] to extract the model name is not robust for all path formats, especially local paths or paths with trailing slashes. os.path.basename combined with os.path.normpath provides a more reliable way to handle different path structures.

    model_name = os.path.basename(os.path.normpath(target_model_path)).lower()

@sleepcoo merged commit f346bb9 into sgl-project:main on Aug 27, 2025
2 checks passed
