-
Notifications
You must be signed in to change notification settings - Fork 182
Add automatic configuration generation for the draft model. #167
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
2d51e95
7d5d607
c1bccee
f712564
0be2736
48ad58c
cef0654
1fa5ca7
6e63c7b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,7 +7,7 @@ | |
|
|
||
| import torch | ||
| import torch.distributed as dist | ||
| from transformers import PretrainedConfig | ||
| from transformers import AutoConfig, PretrainedConfig | ||
|
|
||
|
|
||
| @contextmanager | ||
|
|
@@ -74,3 +74,145 @@ def get_last_checkpoint(folder): | |
| folder, | ||
| max(checkpoints, key=lambda x: int(_re_checkpoint.search(x).groups()[0])), | ||
| ) | ||
|
|
||
|
|
||
def generate_draft_model_config(
    target_model_path: str, template_config_path: str = None, cache_dir: str = None
):
    """
    Auto-generate a draft model config based on target model parameters
    aligned with a template config.

    Loads the target model's config via ``AutoConfig.from_pretrained`` and
    overlays its key parameters (vocab size, hidden size, attention heads,
    etc.) onto a template draft config.

    Args:
        target_model_path (str): Path (or hub id) of the target model.
        template_config_path (str, optional): Template config file path,
            defaults to configs/llama3-8B-eagle3.json under the project root.
        cache_dir (str, optional): Cache directory passed to AutoConfig.

    Returns:
        dict: Generated draft model config dictionary.
    """
    # Get target model config
    target_config = AutoConfig.from_pretrained(target_model_path, cache_dir=cache_dir)

    # If no template specified, use default llama3-8B-eagle3.json
    if template_config_path is None:
        # Locate the project root from this file's own path rather than
        # sys.argv[0]: argv[0] points at whatever script launched the
        # process, so the old approach broke under pytest, `python -c`,
        # or installed console entry points.
        # NOTE(review): assumes this module sits one directory below the
        # project root (e.g. specforge/) — confirm against repo layout.
        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        template_config_path = os.path.join(
            project_root, "configs", "llama3-8B-eagle3.json"
        )

    # Read template config
    with open(template_config_path, "r") as f:
        draft_config = json.load(f)

    # Adjust architecture config based on target model type
    if hasattr(target_config, "model_type"):
        # Default to llama architecture
        draft_config["model_type"] = "llama"

    # Align key parameters: target config attribute -> draft config key
    param_mappings = {
        "vocab_size": "vocab_size",
        "hidden_size": "hidden_size",
        "num_attention_heads": "num_attention_heads",
        "num_key_value_heads": "num_key_value_heads",
        "intermediate_size": "intermediate_size",
        "max_position_embeddings": "max_position_embeddings",
        "rms_norm_eps": "rms_norm_eps",
        "hidden_act": "hidden_act",
        "bos_token_id": "bos_token_id",
        "eos_token_id": "eos_token_id",
        "torch_dtype": "torch_dtype",
    }

    # Copy parameters from target model to draft config
    for target_param, draft_param in param_mappings.items():
        if hasattr(target_config, target_param):
            value = getattr(target_config, target_param)
            # Special handling for torch_dtype to make it JSON serializable:
            # torch.float16 (a torch.dtype object) -> "float16"
            if target_param == "torch_dtype":
                if hasattr(value, "__name__") or "torch" in str(type(value)):
                    value_str = str(value)
                    if "torch." in value_str:
                        value = value_str.split("torch.")[-1]
                    else:
                        value = value_str.split(".")[-1]
            draft_config[draft_param] = value

    # Special handling for some parameters
    # Ensure num_hidden_layers is always 1 (EAGLE3 feature)
    draft_config["num_hidden_layers"] = 1

    # Keep some fixed draft model specific parameters
    draft_config["tie_word_embeddings"] = False
    draft_config["use_cache"] = True

    # If template doesn't have draft_vocab_size, set default
    if "draft_vocab_size" not in draft_config:
        draft_config["draft_vocab_size"] = 32000  # Default value

    return draft_config
|
|
||
|
|
||
def save_draft_model_config(config_dict: dict, output_path: str):
    """
    Save a draft model config to a JSON file.

    Creates the parent directory if needed, then writes the config as
    UTF-8 JSON (indent=2, non-ASCII preserved).

    Args:
        config_dict (dict): Config dictionary.
        output_path (str): Output file path.
    """
    # os.path.dirname returns "" for a bare filename, and os.makedirs("")
    # raises FileNotFoundError — only create the directory when one is given.
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(config_dict, f, indent=2, ensure_ascii=False)

    print(f"Draft model config saved to: {output_path}")
|
|
||
|
|
||
def create_draft_config_from_target(
    target_model_path: str,
    output_dir: str = None,
    template_config_path: str = None,
    cache_dir: str = None,
):
    """
    Convenience function to create a draft model config file from a target model.

    Generates the config via ``generate_draft_model_config`` and writes it to
    ``<output_dir>/<model_name>-eagle3-auto.json``.

    Args:
        target_model_path (str): Target model path (or hub id).
        output_dir (str, optional): Output directory, defaults to the
            project's configs/ folder.
        template_config_path (str, optional): Template config path.
        cache_dir (str, optional): Cache directory.

    Returns:
        str: Generated config file path.
    """
    # Generate config
    config_dict = generate_draft_model_config(
        target_model_path, template_config_path, cache_dir
    )

    # Determine output path
    if output_dir is None:
        # Locate the project root from this file's own path rather than
        # sys.argv[0], which breaks when the process was not started from
        # a script in <project>/scripts (pytest, python -c, entry points).
        # NOTE(review): assumes this module sits one directory below the
        # project root (e.g. specforge/) — confirm against repo layout.
        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        output_dir = os.path.join(project_root, "configs")

    # Extract model name from model path. basename + normpath handles
    # trailing slashes and OS-specific separators, unlike split("/")[-1],
    # which returns "" for "org/model/".
    model_name = os.path.basename(os.path.normpath(target_model_path)).lower()

    output_filename = f"{model_name}-eagle3-auto.json"
    output_path = os.path.join(output_dir, output_filename)

    # Save config
    save_draft_model_config(config_dict, output_path)

    return output_path
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This logic for handling the draft model configuration is duplicated in
scripts/train_eagle3_online.py(lines 170-185). To improve maintainability and follow the DRY (Don't Repeat Yourself) principle, consider refactoring this block into a shared utility function.For example, you could create a function
load_or_create_draft_config(args)in a utility module that encapsulates this logic and returns thedraft_model_config.