[llama4] store expert weights such that we can transpose before grouped mm to have col-major memory layout (#1517)
# Summary
Rather than storing expert weights pre-transposed as (E, in_dim, out_dim),
we should store them non-transposed as (E, out_dim, in_dim)
and transpose before the grouped GEMM, for (1) compatible dims for the GEMM, and
(2) the column-major memory layout required for the right operand of the grouped
GEMM. This simple transpose (a metadata-only change) is much more efficient than
the [inefficient memory layout transformation before every GEMM in
fp8](https://github.com/pytorch/ao/blob/6e941c87c4d9fb9a74e6f979dd522605c696ca42/torchao/prototype/moe_training/scaled_grouped_mm.py#L96).
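For illustration, here is a minimal sketch of the idea in plain PyTorch. The grouped GEMM call itself is left as a hypothetical comment since the exact API and signature vary by PyTorch version; the point is that transposing the last two dims of a (E, out_dim, in_dim) tensor is a stride-only change that yields a column-major layout per expert, with no data copy.
```python
import torch

E, in_dim, out_dim = 4, 128, 256

# Store expert weights non-transposed: (E, out_dim, in_dim), contiguous (row-major).
w = torch.randn(E, out_dim, in_dim)

# Transposing the last two dims is a metadata-only change (strides swap, no copy).
# Each (in_dim, out_dim) expert matrix is now column-major, which is the layout
# the grouped GEMM expects for its right operand.
w_t = w.transpose(-2, -1)               # shape (E, in_dim, out_dim)
assert w_t.data_ptr() == w.data_ptr()   # same storage, no data movement
assert w_t.stride(-2) == 1              # column-major within each expert matrix

# Hypothetical call site (sketch only; actual grouped-mm API may differ):
# out = torch._grouped_mm(x, w_t, offs=token_offsets)
```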
# Eager Performance
Llama4 debug model with FSDP=8, using config:
```python
"debugmodel": TransformerModelArgs(
dim=5120,
n_layers=4,
n_heads=40,
n_kv_heads=8,
ffn_dim_multiplier=1.2,
multiple_of=2048,
rope_theta=500000,
max_seq_len=10485760,
num_experts=16,
interleave_moe_layer_step=1,
),
```
### bfloat16
With change:
```
=====================================================
Calculating training performance metrics
=====================================================
Median Tokens/Second (excluding step 1): 2147.0
Max Memory Usage: 92.67 GiB
```
Without change:
```
=====================================================
Calculating training performance metrics
=====================================================
Median Tokens/Second (excluding step 1): 1711.0
Max Memory Usage: 92.67 GiB
```
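In bfloat16, the change improves median throughput from 1711 to 2147 tokens/s (roughly a 25% speedup) at identical peak memory.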
### fp8 rowwise
With change:
```
(torchtitan) [[email protected] ~/ao/benchmarks/float8/training (metdata)]$ TORCHTITAN_ROOT=/home/danvm/torchtitan NGPU=8 EXTRA_ARGS="--model.converters="float8" --float8.recipe_name="rowwise" --float8.filter_fqns="output,auto_filter_small_kn" --float8.moe_fqns_prototype="experts"" ./llama4.sh
=====================================================
Calculating training performance metrics
=====================================================
Median Tokens/Second (excluding step 1): 2675.0
Max Memory Usage: 90.35 GiB
```
Without change:
```
(torchtitan) [[email protected] ~/ao/benchmarks/float8/training (metdata)]$ TORCHTITAN_ROOT=/home/danvm/torchtitan NGPU=8 EXTRA_ARGS="--model.converters="float8" --float8.recipe_name="rowwise" --float8.filter_fqns="output,auto_filter_small_kn" --float8.moe_fqns_prototype="experts"" ./llama4.sh
=====================================================
Calculating training performance metrics
=====================================================
Median Tokens/Second (excluding step 1): 2360.0
Max Memory Usage: 90.35 GiB
```
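In fp8 rowwise, the change improves median throughput from 2360 to 2675 tokens/s (roughly a 13% speedup), again with unchanged peak memory.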