Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ When `tp_size` is greater than 1, the script will automatically load the distrib

#### Customize Draft Model

If you want to change the draft model configuration, you can write your own configuration file and pass its path to the `--draft-model-config` argument. If you wish to serve your customized draft model with SGLang, make sure you implement the draft model in SGLang as well and the architecture name must match. To implement your own draft model, you can create a new class and inherit it from the `Eagle3DraftModel` class in the `specforge.modeling.draft.base.py` file.
If you want to change the draft model configuration, you can write your own configuration file and pass its path to the `--draft-model-config` argument. Or, if you do not provide the `--draft-model-config` argument, the script will automatically generate the draft model configuration based on the target model configuration. If you wish to serve your customized draft model with SGLang, make sure you implement the draft model in SGLang as well and the architecture name must match. To implement your own draft model, you can create a new class and inherit it from the `Eagle3DraftModel` class in the `specforge.modeling.draft.base.py` file.


```python
Expand Down
27 changes: 25 additions & 2 deletions scripts/train_eagle3_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from specforge.optimizer import BF16Optimizer
from specforge.tracker import create_tracker, get_tracker_class
from specforge.utils import (
create_draft_config_from_target,
get_last_checkpoint,
print_on_rank0,
print_with_rank,
Expand All @@ -37,7 +38,12 @@ def parse_args():

# add model-related arguments
parser.add_argument("--target-model-path", type=str, required=True)
parser.add_argument("--draft-model-config", type=str, required=True)
parser.add_argument(
"--draft-model-config",
type=str,
required=False,
help="Draft model config path. If not provided, will auto-generate from target model.",
)
parser.add_argument(
"--embedding-key",
type=str,
Expand Down Expand Up @@ -213,7 +219,24 @@ def main():
target_head = target_head.eval().cuda().to(torch.bfloat16)
print_with_rank("Initialized target head")

draft_model_config = AutoDraftModelConfig.from_file(args.draft_model_config)
# Handle draft model config
if args.draft_model_config is None:
print_with_rank(
"No draft model config provided, auto-generating from target model..."
)
# Auto-generate and save config file
auto_config_path = create_draft_config_from_target(
target_model_path=args.target_model_path, cache_dir=args.cache_dir
)
draft_model_config = AutoDraftModelConfig.from_file(auto_config_path)
print_with_rank(
f"Auto-generated draft model config saved to: {auto_config_path}"
)
else:
# Use provided config file
draft_model_config = AutoDraftModelConfig.from_file(args.draft_model_config)
print_with_rank(f"Using provided draft model config: {args.draft_model_config}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This logic for handling the draft model configuration is duplicated in scripts/train_eagle3_online.py (lines 170-185). To improve maintainability and follow the DRY (Don't Repeat Yourself) principle, consider refactoring this block into a shared utility function.

For example, you could create a function load_or_create_draft_config(args) in a utility module that encapsulates this logic and returns the draft_model_config.


if draft_model_last_checkpoint:
draft_model = (
AutoEagle3DraftModel.from_pretrained(
Expand Down
29 changes: 26 additions & 3 deletions scripts/train_eagle3_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from specforge.optimizer import BF16Optimizer
from specforge.tracker import create_tracker, get_tracker_class
from specforge.utils import (
create_draft_config_from_target,
get_last_checkpoint,
print_on_rank0,
print_with_rank,
Expand All @@ -41,7 +42,12 @@ def parse_args():

# add model-related arguments
parser.add_argument("--target-model-path", type=str, required=True)
parser.add_argument("--draft-model-config", type=str, required=True)
parser.add_argument(
"--draft-model-config",
type=str,
required=False,
help="Draft model config path. If not provided, will auto-generate from target model.",
)
parser.add_argument(
"--embedding-key",
type=str,
Expand Down Expand Up @@ -205,8 +211,24 @@ def main():
parser.error(f"Unknown tracker: {args.report_to}")

tracker = create_tracker(args, args.output_dir)
# load draft model config
draft_model_config = AutoDraftModelConfig.from_file(args.draft_model_config)

# Handle draft model config
if args.draft_model_config is None:
print_with_rank(
"No draft model config provided, auto-generating from target model..."
)
# Auto-generate and save config file
auto_config_path = create_draft_config_from_target(
target_model_path=args.target_model_path, cache_dir=args.cache_dir
)
draft_model_config = AutoDraftModelConfig.from_file(auto_config_path)
print_with_rank(
f"Auto-generated draft model config saved to: {auto_config_path}"
)
else:
# Use provided config file
draft_model_config = AutoDraftModelConfig.from_file(args.draft_model_config)
print_with_rank(f"Using provided draft model config: {args.draft_model_config}")

# detecting last ckpt for draft model
draft_model_last_checkpoint = None
Expand Down Expand Up @@ -247,6 +269,7 @@ def main():
.cuda()
)
print_with_rank("Initialized target model")

# load model with resume
if draft_model_last_checkpoint:
draft_model = (
Expand Down
144 changes: 143 additions & 1 deletion specforge/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import torch
import torch.distributed as dist
from transformers import PretrainedConfig
from transformers import AutoConfig, PretrainedConfig


@contextmanager
Expand Down Expand Up @@ -74,3 +74,145 @@ def get_last_checkpoint(folder):
folder,
max(checkpoints, key=lambda x: int(_re_checkpoint.search(x).groups()[0])),
)


def generate_draft_model_config(
    target_model_path: str, template_config_path: str = None, cache_dir: str = None
):
    """
    Auto-generate a draft model config based on target model parameters,
    using a template config as the starting point.

    Args:
        target_model_path (str): Path (or hub id) of the target model
        template_config_path (str, optional): Template config file path,
            defaults to configs/llama3-8B-eagle3.json at the project root
        cache_dir (str, optional): Cache directory passed to AutoConfig

    Returns:
        dict: Generated draft model config dictionary
    """
    # Get target model config
    target_config = AutoConfig.from_pretrained(target_model_path, cache_dir=cache_dir)

    # If no template specified, use default llama3-8B-eagle3.json.
    # Locate it relative to this file (specforge/utils.py) rather than
    # sys.argv[0], so the lookup works no matter how / from where the
    # entry-point script was invoked.
    if template_config_path is None:
        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        template_config_path = os.path.join(
            project_root, "configs", "llama3-8B-eagle3.json"
        )

    # Read template config
    with open(template_config_path, "r") as f:
        draft_config = json.load(f)

    # Adjust architecture config based on target model type
    if hasattr(target_config, "model_type"):
        # Default to llama architecture
        draft_config["model_type"] = "llama"

    # Align key parameters (target config attribute -> draft config key)
    param_mappings = {
        "vocab_size": "vocab_size",
        "hidden_size": "hidden_size",
        "num_attention_heads": "num_attention_heads",
        "num_key_value_heads": "num_key_value_heads",
        "intermediate_size": "intermediate_size",
        "max_position_embeddings": "max_position_embeddings",
        "rms_norm_eps": "rms_norm_eps",
        "hidden_act": "hidden_act",
        "bos_token_id": "bos_token_id",
        "eos_token_id": "eos_token_id",
        "torch_dtype": "torch_dtype",
    }

    # Copy parameters from target model to draft config
    for target_param, draft_param in param_mappings.items():
        if hasattr(target_config, target_param):
            value = getattr(target_config, target_param)
            # torch.dtype is not JSON serializable: torch.float16 -> "float16".
            # A plain string value (already serializable) passes through as-is.
            if target_param == "torch_dtype" and isinstance(value, torch.dtype):
                value = str(value).replace("torch.", "")
            draft_config[draft_param] = value

    # Special handling for some parameters
    # Ensure num_hidden_layers is always 1 (EAGLE3 feature)
    draft_config["num_hidden_layers"] = 1

    # Keep some fixed draft model specific parameters
    draft_config["tie_word_embeddings"] = False
    draft_config["use_cache"] = True

    # If template doesn't have draft_vocab_size, set default
    if "draft_vocab_size" not in draft_config:
        draft_config["draft_vocab_size"] = 32000  # Default value

    return draft_config


def save_draft_model_config(config_dict: dict, output_path: str):
    """
    Save draft model config to a JSON file.

    Args:
        config_dict (dict): Config dictionary
        output_path (str): Output file path; parent directories are created
            as needed
    """
    # os.path.dirname returns "" for a bare filename, and os.makedirs("")
    # raises FileNotFoundError — only create a directory when there is one.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(config_dict, f, indent=2, ensure_ascii=False)

    print(f"Draft model config saved to: {output_path}")


def create_draft_config_from_target(
    target_model_path: str,
    output_dir: str = None,
    template_config_path: str = None,
    cache_dir: str = None,
):
    """
    Convenient function to create a draft model config file from a target model.

    Args:
        target_model_path (str): Target model path (or hub id)
        output_dir (str, optional): Output directory, defaults to the configs
            folder at the project root
        template_config_path (str, optional): Template config path
        cache_dir (str, optional): Cache directory

    Returns:
        str: Generated config file path
    """
    # Generate config
    config_dict = generate_draft_model_config(
        target_model_path, template_config_path, cache_dir
    )

    # Determine output path. Locate the project root relative to this file
    # (specforge/utils.py) instead of sys.argv[0], which depends on how and
    # from where the entry-point script was executed.
    if output_dir is None:
        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        output_dir = os.path.join(project_root, "configs")

    # Extract model name from the model path. basename + normpath handles
    # trailing slashes and local directory paths more reliably than
    # split("/")[-1].
    model_name = os.path.basename(os.path.normpath(target_model_path)).lower()
    output_filename = f"{model_name}-eagle3-auto.json"
    output_path = os.path.join(output_dir, output_filename)

    # Save config
    save_draft_model_config(config_dict, output_path)

    return output_path