@rakkit rakkit commented Aug 25, 2025

This is a distributed version of Scion (Modular Norm); Muon can be considered a variant of it that uses explicit AdamW for the LLM's embedding/output layers.

Works:

  • Embedding/head
  • FSDP/DP/TP/EP/CP/PP parameters
  • Bias (mainly for norms)
  • Weight decay

Missing:

  • Conv

Some extra work is needed to adjust the EP changes for EP-[shard(1)] and ETP(?), and it does not yet work for multiple shared_experts.
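For context, the core of a Scion/Muon-style update for a 2-D weight is (approximately) orthogonalizing the gradient before applying it, usually via a Newton-Schulz iteration. Below is a minimal single-device sketch in NumPy using the simple cubic iteration rather than Muon's tuned quintic; the names are illustrative and this is not the code in this PR:

```python
import numpy as np

def newton_schulz_orthogonalize(g: np.ndarray, steps: int = 30) -> np.ndarray:
    """Approximate the orthogonal polar factor U V^T of g = U S V^T.

    Cubic Newton-Schulz iteration X <- 1.5 X - 0.5 X X^T X. It converges
    when the spectral norm of the starting point is below sqrt(3), which
    normalizing by the Frobenius norm guarantees.
    """
    transposed = g.shape[0] > g.shape[1]
    x = g.T if transposed else g          # work with a wide matrix
    x = x / (np.linalg.norm(x) + 1e-12)   # Frobenius normalization
    for _ in range(steps):
        x = 1.5 * x - 0.5 * x @ x.T @ x
    return x.T if transposed else x

rng = np.random.default_rng(0)
grad = rng.standard_normal((8, 4))
update = newton_schulz_orthogonalize(grad)
# Singular values of the result are all close to 1, i.e. the update
# direction is (near-)orthogonal regardless of the gradient's scale.
print(np.linalg.svd(update, compute_uv=False))
```

The distributed part of this PR is about running exactly this kind of per-matrix computation on sharded (FSDP/TP/EP/...) parameters; the math above is unchanged.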

CC @janEbert @ofivite

@meta-cla bot added the CLA Signed label Aug 25, 2025
@tianyu-l tianyu-l left a comment


Thanks for the PR on cutting-edge features!

I didn't read the papers so please forgive me if what I comment doesn't make sense.

I guess for "core" changes such as this one on optimizers, the recommended path is to first land in pytorch/pytorch, and then expose minimal interfaces to torchtitan. torchtitan shouldn't be a place to host core features.

cc @janeyx99 on interesting optimizer work

rakkit commented Sep 1, 2025

Update: the init refactor is done. You can check the diff here; this optimizer is not intrusive at all and does not require much hacking in `components.optimizer`. We still need to add the configs, though.
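To illustrate the grouping logic described above (orthogonalized updates for 2-D matmul weights, explicit AdamW for embedding/output plus 1-D parameters such as norm scales and biases), here is a hedged, framework-free sketch. The names and heuristics are illustrative, not the actual torchtitan wiring:

```python
def split_param_groups(named_shapes):
    """Split parameters into a Scion/Muon group and an AdamW group.

    named_shapes: iterable of (name, ndim) pairs. Heuristic: 2-D weights
    that are not embeddings or the output head get the orthogonalized
    update; everything else (embedding, lm head, norms, biases) uses AdamW.
    """
    adamw_keywords = ("embed", "lm_head", "output")
    scion, adamw = [], []
    for name, ndim in named_shapes:
        if ndim == 2 and not any(k in name for k in adamw_keywords):
            scion.append(name)
        else:
            adamw.append(name)
    return scion, adamw

# Hypothetical Llama-style parameter names, for illustration only.
params = [
    ("tok_embeddings.weight", 2),
    ("layers.0.attention.wq.weight", 2),
    ("layers.0.attention_norm.weight", 1),
    ("layers.0.feed_forward.w1.weight", 2),
    ("norm.weight", 1),
    ("output.weight", 2),
]
scion_group, adamw_group = split_param_groups(params)
```

In a real optimizer the two groups would then be handed to the Scion update and to `torch.optim.AdamW` respectively, each with its own LR and weight decay.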

I added the debug configs, so you can try it now: `CONFIG_FILE="./torchtitan/experiments/distributed_scion/train_configs/debug_model.toml" NGPU=4 ./run_train.sh --compile.enable`

There is a "clean" version where I removed the logging code, which makes the code easier to read and understand.

A quick test (note that Scion uses a higher LR here; Muon/Scion allow training a model with a high LR):

Distributed-Scion (LR=0.1, disable gradient norm clipping)

```
[rank0]:[titan] 2025-09-01 18:32:41,146 - root - INFO - step:  1  loss:  8.1319  grad_norm:  2.8994  memory:  1.85GiB(1.99%)  tps: 1,205  tflops: 0.12  mfu: 0.01%
[rank0]:[titan] 2025-09-01 18:32:41,146 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-09-01 18:32:41,221 - root - INFO - step:  2  loss:  7.2448  grad_norm:  2.8807  memory:  1.86GiB(2.00%)  tps: 219,132  tflops: 22.46  mfu: 2.27%
[rank0]:[titan] 2025-09-01 18:32:41,291 - root - INFO - step:  3  loss:  5.1976  grad_norm:  3.2225  memory:  1.86GiB(2.00%)  tps: 232,809  tflops: 23.86  mfu: 2.41%
[rank0]:[titan] 2025-09-01 18:32:41,363 - root - INFO - step:  4  loss:  4.4230  grad_norm:  2.7686  memory:  1.86GiB(2.00%)  tps: 230,201  tflops: 23.60  mfu: 2.39%
[rank0]:[titan] 2025-09-01 18:32:41,432 - root - INFO - step:  5  loss:  4.1285  grad_norm:  4.4248  memory:  1.86GiB(2.00%)  tps: 235,659  tflops: 24.16  mfu: 2.44%
[rank0]:[titan] 2025-09-01 18:32:41,507 - root - INFO - step:  6  loss:  3.9301  grad_norm:  4.0992  memory:  1.86GiB(2.00%)  tps: 219,604  tflops: 22.51  mfu: 2.28%
[rank0]:[titan] 2025-09-01 18:32:41,577 - root - INFO - step:  7  loss:  3.7481  grad_norm:  4.3226  memory:  1.86GiB(2.00%)  tps: 237,191  tflops: 24.31  mfu: 2.46%
[rank0]:[titan] 2025-09-01 18:32:41,645 - root - INFO - step:  8  loss:  3.4784  grad_norm:  3.3010  memory:  1.86GiB(2.00%)  tps: 238,348  tflops: 24.43  mfu: 2.47%
[rank0]:[titan] 2025-09-01 18:32:41,715 - root - INFO - step:  9  loss:  3.5601  grad_norm:  2.4235  memory:  1.86GiB(2.00%)  tps: 235,945  tflops: 24.18  mfu: 2.45%
[rank0]:[titan] 2025-09-01 18:32:41,790 - root - INFO - step: 10  loss:  3.3342  grad_norm:  2.3503  memory:  1.86GiB(2.00%)  tps: 219,221  tflops: 22.47  mfu: 2.27%
```

AdamW (LR=8e-4)

```
[rank0]:[titan] 2025-09-01 18:28:22,954 - root - INFO - step:  1  loss:  8.1141  grad_norm:  2.7299  memory:  1.86GiB(2.00%)  tps: 4,072  tflops: 0.42  mfu: 0.04%
[rank0]:[titan] 2025-09-01 18:28:22,954 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-09-01 18:28:23,015 - root - INFO - step:  2  loss:  6.8200  grad_norm:  3.5221  memory:  1.87GiB(2.01%)  tps: 268,831  tflops: 27.56  mfu: 2.79%
[rank0]:[titan] 2025-09-01 18:28:23,073 - root - INFO - step:  3  loss:  5.0779  grad_norm:  3.7820  memory:  1.87GiB(2.01%)  tps: 282,964  tflops: 29.00  mfu: 2.93%
[rank0]:[titan] 2025-09-01 18:28:23,133 - root - INFO - step:  4  loss:  4.8205  grad_norm:  2.8645  memory:  1.87GiB(2.01%)  tps: 277,407  tflops: 28.43  mfu: 2.88%
[rank0]:[titan] 2025-09-01 18:28:23,189 - root - INFO - step:  5  loss:  4.4854  grad_norm:  2.4403  memory:  1.87GiB(2.01%)  tps: 292,630  tflops: 30.00  mfu: 3.03%
[rank0]:[titan] 2025-09-01 18:28:23,250 - root - INFO - step:  6  loss:  4.2748  grad_norm:  2.1404  memory:  1.87GiB(2.01%)  tps: 268,434  tflops: 27.52  mfu: 2.78%
[rank0]:[titan] 2025-09-01 18:28:23,305 - root - INFO - step:  7  loss:  4.1164  grad_norm:  1.8821  memory:  1.87GiB(2.01%)  tps: 297,987  tflops: 30.54  mfu: 3.09%
[rank0]:[titan] 2025-09-01 18:28:23,362 - root - INFO - step:  8  loss:  4.0172  grad_norm:  1.9915  memory:  1.87GiB(2.01%)  tps: 291,299  tflops: 29.86  mfu: 3.02%
[rank0]:[titan] 2025-09-01 18:28:23,418 - root - INFO - step:  9  loss:  4.0967  grad_norm:  1.7458  memory:  1.87GiB(2.01%)  tps: 292,345  tflops: 29.97  mfu: 3.03%
[rank0]:[titan] 2025-09-01 18:28:23,478 - root - INFO - step: 10  loss:  3.9069  grad_norm:  1.5619  memory:  1.87GiB(2.01%)  tps: 272,612  tflops: 27.94  mfu: 2.83%
```
