Skip to content

Commit 48d8dcd

Browse files
authored
make mx recipe name more generic (#1512)
Summary: Instead of maintaining a mapping in torchtitan with valid mx recipe names, just pass the string recipe directly to torchao. This way torchao can iterate on recipes without any changes to torchtitan to use those recipes. Note that appropriate error messages will be thrown from torchao if user specifies an invalid config name, so there is no need to duplicate them in torchtitan. Test Plan: ```bash with-proxy CONFIG_FILE="torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.print_after_conversion --training.compile --training.steps 50 --model.converters mx --mx.recipe_name "mxfp8_cublas_rceil" ``` Reviewers: Subscribers: Tasks: Tags:
1 parent 43fa980 commit 48d8dcd

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

torchtitan/components/quantization/mx.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,6 @@
2222

2323
from .utils import module_filter_fn
2424

25-
# Maps titan recipe names to torchao mx recipe names
26-
NAME_MAP = {"mxfp8": "mxfp8_cublas"}
27-
2825

2926
class MXConverter(ModelConverter):
3027
"""Converts the linear layers of `model` to `MXLinear`."""
@@ -76,7 +73,7 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
7673
)
7774

7875
mx_job_config: MX = job_config.mx
79-
config = MXLinearConfig.from_recipe_name(NAME_MAP[mx_job_config.recipe_name])
76+
config = MXLinearConfig.from_recipe_name(mx_job_config.recipe_name)
8077
config.mxfp8_dim1_cast_kernel_choice = MXFP8Dim1CastKernelChoice[
8178
mx_job_config.mxfp8_dim1_cast_kernel_choice.upper()
8279
]

torchtitan/config/job_config.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -562,8 +562,11 @@ class MX:
562562
mxfp8_dim1_cast_kernel_choice: Literal["triton", "cuda", "torch"] = "triton"
563563
"""Temp work around for inductor performance gap"""
564564

565-
recipe_name: Literal["mxfp8"] = "mxfp8"
566-
"""If specified, creates float8 config from recipe name"""
565+
recipe_name: str = "mxfp8_cublas"
566+
"""
567+
If specified, creates MX config from recipe name. See
568+
https://github.com/pytorch/ao/tree/main/torchao/prototype/mx_formats for more information.
569+
"""
567570

568571
filter_fqns: list[str] = field(default_factory=lambda: ["output"])
569572
"""

0 commit comments

Comments
 (0)