Commit ed288bc

[llama4] store expert weights such that we can transpose before grouped mm to have col-major memory layout (#1517)
# Summary

Rather than storing the expert weights pre-transposed as (E, in_dim, out_dim), we should store them non-transposed as (E, out_dim, in_dim) and then transpose before the grouped GEMM, for (1) compatible dims for the GEMM, and (2) the column-major memory layout required for the right operand of the grouped GEMM. This simple transpose (a metadata-only change) is much more efficient than doing this [inefficient memory layout transformation before every GEMM in fp8](https://github.com/pytorch/ao/blob/6e941c87c4d9fb9a74e6f979dd522605c696ca42/torchao/prototype/moe_training/scaled_grouped_mm.py#L96).

# Eager Performance

Llama4 debug model with FSDP=8, using config:

```python
"debugmodel": TransformerModelArgs(
    dim=5120,
    n_layers=4,
    n_heads=40,
    n_kv_heads=8,
    ffn_dim_multiplier=1.2,
    multiple_of=2048,
    rope_theta=500000,
    max_seq_len=10485760,
    num_experts=16,
    interleave_moe_layer_step=1,
),
```

### bfloat16

With change:

```
=====================================================
Calculating training performance metrics
=====================================================
Median Tokens/Second (excluding step 1): 2147.0
Max Memory Usage: 92.67 GiB
```

Without change:

```
=====================================================
Calculating training performance metrics
=====================================================
Median Tokens/Second (excluding step 1): 1711.0
Max Memory Usage: 92.67 GiB
```

### fp8 rowwise

With change:

```
(torchtitan) [[email protected] ~/ao/benchmarks/float8/training (metdata)]$ TORCHTITAN_ROOT=/home/danvm/torchtitan NGPU=8 EXTRA_ARGS="--model.converters="float8" --float8.recipe_name="rowwise" --float8.filter_fqns="output,auto_filter_small_kn" --float8.moe_fqns_prototype="experts"" ./llama4.sh
=====================================================
Calculating training performance metrics
=====================================================
Median Tokens/Second (excluding step 1): 2675.0
Max Memory Usage: 90.35 GiB
```

Without change:

```
(torchtitan) [[email protected] ~/ao/benchmarks/float8/training (metdata)]$ TORCHTITAN_ROOT=/home/danvm/torchtitan NGPU=8 EXTRA_ARGS="--model.converters="float8" --float8.recipe_name="rowwise" --float8.filter_fqns="output,auto_filter_small_kn" --float8.moe_fqns_prototype="experts"" ./llama4.sh
=====================================================
Calculating training performance metrics
=====================================================
Median Tokens/Second (excluding step 1): 2360.0
Max Memory Usage: 90.35 GiB
```
1 parent 004162a commit ed288bc
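For context, a minimal sketch (illustrative only, not part of the commit; the sizes are made up) of why the transpose is only a metadata change while still giving the column-major per-expert layout that the grouped GEMM needs for its right operand:

```python
import torch

E, in_dim, out_dim = 4, 128, 256

# Store the expert weight non-transposed, as the commit does: (E, out_dim, in_dim), row-major.
w = torch.randn(E, out_dim, in_dim)
assert w.is_contiguous()

# Transposing the last two dims only swaps strides; the underlying storage is untouched.
wt = w.transpose(-2, -1)                              # shape (E, in_dim, out_dim)
assert wt.data_ptr() == w.data_ptr()                  # no copy was made
assert wt.stride() == (out_dim * in_dim, 1, in_dim)   # column-major within each expert matrix
```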

File tree: 2 files changed (+36, -18 lines)

torchtitan/experiments/llama4/infra/expert_parallel.py

Lines changed: 16 additions & 6 deletions
@@ -54,16 +54,21 @@ def set_token_group_alignment_size_m(
 # implementation of Tensor Parallel for the GroupedExperts in MoE
 class TensorParallel(ParallelStyle):
     def _partition_fn(self, name, module, device_mesh):
+        # w1 shape = (experts, out_dim, in_dim)
         module.register_parameter(
-            "w1", nn.Parameter(distribute_tensor(module.w1, device_mesh, [Shard(2)]))
+            "w1", nn.Parameter(distribute_tensor(module.w1, device_mesh, [Shard(1)]))
         )  # Column-wise sharding
+
+        # w2 shape = (experts, in_dim, out_dim)
         module.register_parameter(
             "w2",
-            nn.Parameter(distribute_tensor(module.w2, device_mesh, [Shard(1)])),
+            nn.Parameter(distribute_tensor(module.w2, device_mesh, [Shard(2)])),
         )  # Row-wise sharding
+
+        # w3 shape = (experts, out_dim, in_dim)
         module.register_parameter(
             "w3",
-            nn.Parameter(distribute_tensor(module.w3, device_mesh, [Shard(2)])),
+            nn.Parameter(distribute_tensor(module.w3, device_mesh, [Shard(1)])),
         )  # Column-wise sharding
 
     def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
@@ -223,17 +228,22 @@ def _token_dispatch(self, mod, inputs, device_mesh):
         return super()._token_dispatch(mod, inputs, self.ep_mesh)
 
     def _partition_fn_2d(self, name, mod, ep_tp_mesh):
+        # w1 shape = (experts, out_dim, in_dim)
         mod.register_parameter(
             "w1",
-            nn.Parameter(distribute_tensor(mod.w1, ep_tp_mesh, [Shard(0), Shard(2)])),
+            nn.Parameter(distribute_tensor(mod.w1, ep_tp_mesh, [Shard(0), Shard(1)])),
         )  # Column-wise sharding
+
+        # w2 shape = (experts, in_dim, out_dim)
         mod.register_parameter(
             "w2",
-            nn.Parameter(distribute_tensor(mod.w2, ep_tp_mesh, [Shard(0), Shard(1)])),
+            nn.Parameter(distribute_tensor(mod.w2, ep_tp_mesh, [Shard(0), Shard(2)])),
         )  # Row-wise sharding
+
+        # w3 shape = (experts, out_dim, in_dim)
         mod.register_parameter(
             "w3",
-            nn.Parameter(distribute_tensor(mod.w3, ep_tp_mesh, [Shard(0), Shard(2)])),
+            nn.Parameter(distribute_tensor(mod.w3, ep_tp_mesh, [Shard(0), Shard(1)])),
         )  # Column-wise sharding
 
     def _token_combine(self, mod, routed_output, device_mesh):
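As an illustration only (plain single-process tensors standing in for `distribute_tensor`, with made-up sizes): a sketch of why, under the new (experts, out_dim, in_dim) storage, `Shard(1)` on w1/w3 is still column-wise tensor parallelism and `Shard(2)` on w2 is still row-wise tensor parallelism.

```python
import torch

tokens, dim, hidden_dim, tp_degree = 8, 16, 32, 2

x = torch.randn(tokens, dim)
w1 = torch.randn(hidden_dim, dim)   # one expert's w1, stored (out_dim, in_dim)
w2 = torch.randn(dim, hidden_dim)   # one expert's w2, stored (in_dim=dim, out_dim=hidden_dim)

# Column-wise sharding of w1: Shard(1) on (experts, hidden_dim, dim) splits hidden_dim,
# i.e. dim 0 of the per-expert matrix. Each rank computes a disjoint slice of h.
h = x @ w1.t()
h_shards = [x @ w1_shard.t() for w1_shard in w1.chunk(tp_degree, dim=0)]
assert torch.allclose(h, torch.cat(h_shards, dim=-1))

# Row-wise sharding of w2: Shard(2) on (experts, dim, hidden_dim) splits hidden_dim,
# the contraction dim of h @ w2.T. Ranks produce partial outputs that sum to the full result.
out = h @ w2.t()
partials = [
    h_shard @ w2_shard.t()
    for h_shard, w2_shard in zip(h.chunk(tp_degree, dim=-1), w2.chunk(tp_degree, dim=1))
]
assert torch.allclose(out, sum(partials), atol=1e-5)
```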

torchtitan/experiments/llama4/model/moe.py

Lines changed: 20 additions & 12 deletions
@@ -23,9 +23,9 @@ def __init__(
     ):
         super().__init__()
         self.num_experts = num_experts
-        self.w1 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
-        self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim))
-        self.w3 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
+        self.w1 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim))
+        self.w2 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
+        self.w3 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim))
         self.use_grouped_mm = use_grouped_mm
 
     def forward(
@@ -69,9 +69,9 @@ def _run_experts_for_loop(
         )
         out_experts_splits = []
         for expert_idx, x_expert in enumerate(x):
-            h = F.silu(torch.matmul(x_expert, w1[expert_idx]))
-            h = h * torch.matmul(x_expert, w3[expert_idx])
-            h = torch.matmul(h, w2[expert_idx])
+            h = F.silu(torch.matmul(x_expert, w1[expert_idx].transpose(-2, -1)))
+            h = h * torch.matmul(x_expert, w3[expert_idx].transpose(-2, -1))
+            h = torch.matmul(h, w2[expert_idx].transpose(-2, -1))
             # h shape (tokens_per_expert(varying), dim)
             out_experts_splits.append(h)
         out = torch.cat(out_experts_splits, dim=0)
@@ -80,10 +80,10 @@ def _run_experts_for_loop(
             out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1]))))
     else:
         # x shape (num_experts, tokens_per_expert, dim)
-        h = F.silu(torch.bmm(x, w1))
-        h = h * torch.bmm(x, w3)
+        h = F.silu(torch.bmm(x, w1.transpose(-2, -1)))
+        h = h * torch.bmm(x, w3.transpose(-2, -1))
         # out shape (num_experts, tokens_per_expert, dim)
-        out = torch.bmm(h, w2)
+        out = torch.bmm(h, w2.transpose(-2, -1))
 
     return out
 
@@ -105,9 +105,17 @@ def _run_experts_grouped_mm(
         # fall back to regular bmm between 3D tensors
         assert x.dim() == 3
 
-    h = F.silu(torch._grouped_mm(x.bfloat16(), w1.bfloat16(), offs=offsets))
-    h = h * torch._grouped_mm(x.bfloat16(), w3.bfloat16(), offs=offsets)
-    out = torch._grouped_mm(h, w2.bfloat16(), offs=offsets).type_as(x)
+    h = F.silu(
+        torch._grouped_mm(
+            x.bfloat16(), w1.bfloat16().transpose(-2, -1), offs=offsets
+        )
+    )
+    h = h * torch._grouped_mm(
+        x.bfloat16(), w3.bfloat16().transpose(-2, -1), offs=offsets
+    )
+    out = torch._grouped_mm(
+        h, w2.bfloat16().transpose(-2, -1), offs=offsets
+    ).type_as(x)
 
     return out
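A small sanity sketch (hypothetical sizes, bmm path only, not from the repo) showing that the new non-transposed storage plus `.transpose(-2, -1)` computes the same SwiGLU expert MLP as the old pre-transposed storage:

```python
import torch
import torch.nn.functional as F

num_experts, tokens_per_expert, dim, hidden_dim = 2, 4, 8, 16

x = torch.randn(num_experts, tokens_per_expert, dim)

# Old storage: pre-transposed (experts, dim, hidden_dim) / (experts, hidden_dim, dim).
w1_old = torch.randn(num_experts, dim, hidden_dim)
w2_old = torch.randn(num_experts, hidden_dim, dim)
w3_old = torch.randn(num_experts, dim, hidden_dim)

# New storage: non-transposed; the same values laid out as (experts, out_dim, in_dim).
w1_new = w1_old.transpose(-2, -1).contiguous()  # (experts, hidden_dim, dim)
w2_new = w2_old.transpose(-2, -1).contiguous()  # (experts, dim, hidden_dim)
w3_new = w3_old.transpose(-2, -1).contiguous()  # (experts, hidden_dim, dim)

def swiglu_old(x, w1, w2, w3):
    h = F.silu(torch.bmm(x, w1)) * torch.bmm(x, w3)
    return torch.bmm(h, w2)

def swiglu_new(x, w1, w2, w3):
    # Transpose is a view; bmm then sees the weights in the shapes the math expects.
    h = F.silu(torch.bmm(x, w1.transpose(-2, -1))) * torch.bmm(x, w3.transpose(-2, -1))
    return torch.bmm(h, w2.transpose(-2, -1))

assert torch.allclose(
    swiglu_old(x, w1_old, w2_old, w3_old),
    swiglu_new(x, w1_new, w2_new, w3_new),
    atol=1e-5,
)
```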
