Commit 6377dce
llama4: Avoid staticmethod nested graph break for MoE compile (#1565)

This nested graph break is particularly bad: it causes the scaled grouped mm ops to fall back to eager.

Test plan: `NGPU=2 CONFIG_FILE="./torchtitan/experiments/llama4/train_configs/debug_model.toml" tlp ./run_train.sh --parallelism.data_parallel_shard_degree=2 --parallelism.expert_parallel_degree=2 --training.compile`
1 parent aeb3a4b commit 6377dce
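The fix itself is purely structural: the two expert-execution helpers move off the `GroupedExperts` class so the compiled region no longer traces through the nested `@expert_parallel` + `@staticmethod` lookup that the message above blames for the graph break. A minimal before/after sketch of the pattern (the decorator and classes below are illustrative stand-ins, not the actual torchtitan code):

import torch
import torch.nn as nn


def expert_parallel(fn):
    # Illustrative stand-in for torchtitan's expert_parallel decorator;
    # the real one reshards/permutes inputs across expert-parallel ranks.
    return fn


# Before: the helper lives on the class as a decorated staticmethod and is
# looked up through the class inside forward().
class ExpertsBefore(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.w = nn.Parameter(torch.randn(dim, dim))

    @expert_parallel
    @staticmethod
    def _run(w: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        return x @ w.transpose(-2, -1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return ExpertsBefore._run(self.w, x)


# After: the same computation as a plain module-level function, called
# directly from forward().
@expert_parallel
def _run(w: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    return x @ w.transpose(-2, -1)


class ExpertsAfter(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.w = nn.Parameter(torch.randn(dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return _run(self.w, x)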

File tree

1 file changed: +75 -79 lines


torchtitan/models/moe.py

Lines changed: 75 additions & 79 deletions
@@ -31,6 +31,79 @@ class MoEArgs:
     load_balance_coeff: float | None = 1e-3
 
 
+# TODO: keeping this for-loop implementation for comparison
+# and readability, may remove later
+@expert_parallel
+def _run_experts_for_loop(
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w3: torch.Tensor,
+    x: torch.Tensor,
+    num_tokens_per_expert: torch.Tensor | None = None,
+) -> torch.Tensor:
+    if num_tokens_per_expert is not None:
+        # NOTE: this would incur a synchronization between device and host
+        num_tokens_per_expert = num_tokens_per_expert.tolist()
+
+        # side-effect code due to the usage of generate_permute_indices
+        num_padding = x.shape[0] - sum(num_tokens_per_expert)
+
+        # a tuple of tensors indexed by experts
+        # each with shape (tokens_per_expert(varying), dim)
+        x = torch.split(
+            x[: sum(num_tokens_per_expert)],
+            split_size_or_sections=num_tokens_per_expert,
+            dim=0,
+        )
+        out_experts_splits = []
+        for expert_idx, x_expert in enumerate(x):
+            h = F.silu(torch.matmul(x_expert, w1[expert_idx].transpose(-2, -1)))
+            h = h * torch.matmul(x_expert, w3[expert_idx].transpose(-2, -1))
+            h = torch.matmul(h, w2[expert_idx].transpose(-2, -1))
+            # h shape (tokens_per_expert(varying), dim)
+            out_experts_splits.append(h)
+        out = torch.cat(out_experts_splits, dim=0)
+
+        # side-effect code due to the usage of generate_permute_indices
+        out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1]))))
+    else:
+        # x shape (num_experts, tokens_per_expert, dim)
+        h = F.silu(torch.bmm(x, w1.transpose(-2, -1)))
+        h = h * torch.bmm(x, w3.transpose(-2, -1))
+        # out shape (num_experts, tokens_per_expert, dim)
+        out = torch.bmm(h, w2.transpose(-2, -1))
+
+    return out
+
+
+@expert_parallel
+def _run_experts_grouped_mm(
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w3: torch.Tensor,
+    x: torch.Tensor,
+    num_tokens_per_expert: torch.Tensor | None = None,
+) -> torch.Tensor:
+    if num_tokens_per_expert is not None:
+        offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
+        # grouped mm between a 2D tensor and a 3D tensor
+        assert x.dim() == 2
+    else:
+        offsets = None
+        # fall back to regular bmm between 3D tensors
+        assert x.dim() == 3
+
+    h = F.silu(
+        torch._grouped_mm(x.bfloat16(), w1.bfloat16().transpose(-2, -1), offs=offsets)
+    )
+    h = h * torch._grouped_mm(
+        x.bfloat16(), w3.bfloat16().transpose(-2, -1), offs=offsets
+    )
+    out = torch._grouped_mm(h, w2.bfloat16().transpose(-2, -1), offs=offsets).type_as(x)
+
+    return out
+
+
 class GroupedExperts(nn.Module):
     def __init__(
         self,
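An aside on the data layout both helpers assume: when num_tokens_per_expert is provided, tokens routed to the same expert sit contiguously along dim 0 of a 2D x; the grouped-mm path passes the running cumulative sum as offs, while the for-loop path recovers the per-expert slices with torch.split. A small self-contained sketch with made-up sizes (padding from generate_permute_indices is omitted here):

import torch

# Made-up sizes, just to make the layout concrete.
num_experts, dim = 4, 8
num_tokens_per_expert = torch.tensor([3, 0, 5, 2])

# Tokens routed to the same expert are stored contiguously along dim 0,
# so x is 2D: (total_routed_tokens, dim).
x = torch.randn(int(num_tokens_per_expert.sum()), dim)

# offsets[i] marks the end row of expert i's slice; this is what the
# grouped-mm path hands to the kernel as `offs`.
offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
print(offsets.tolist())  # [3, 3, 8, 10]

# The for-loop path recovers the same grouping with torch.split.
per_expert = torch.split(x, num_tokens_per_expert.tolist(), dim=0)
for i, chunk in enumerate(per_expert):
    assert chunk.shape == (int(num_tokens_per_expert[i]), dim)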
@@ -52,91 +125,14 @@ def forward(
         num_tokens_per_expert: torch.Tensor | None = None,
     ) -> torch.Tensor:
         if self.use_grouped_mm:
-            return GroupedExperts._run_experts_grouped_mm(
+            return _run_experts_grouped_mm(
                 self.w1, self.w2, self.w3, x, num_tokens_per_expert
             )
         else:
-            return GroupedExperts._run_experts_for_loop(
+            return _run_experts_for_loop(
                 self.w1, self.w2, self.w3, x, num_tokens_per_expert
             )
 
-    # TODO: keeping this for-loop implementation for comparison
-    # and readability, may remove later
-    @expert_parallel
-    @staticmethod
-    def _run_experts_for_loop(
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        w3: torch.Tensor,
-        x: torch.Tensor,
-        num_tokens_per_expert: torch.Tensor | None = None,
-    ) -> torch.Tensor:
-        if num_tokens_per_expert is not None:
-            # NOTE: this would incur a synchronization between device and host
-            num_tokens_per_expert = num_tokens_per_expert.tolist()
-
-            # side-effect code due to the usage of generate_permute_indices
-            num_padding = x.shape[0] - sum(num_tokens_per_expert)
-
-            # a tuple of tensors indexed by experts
-            # each with shape (tokens_per_expert(varying), dim)
-            x = torch.split(
-                x[: sum(num_tokens_per_expert)],
-                split_size_or_sections=num_tokens_per_expert,
-                dim=0,
-            )
-            out_experts_splits = []
-            for expert_idx, x_expert in enumerate(x):
-                h = F.silu(torch.matmul(x_expert, w1[expert_idx].transpose(-2, -1)))
-                h = h * torch.matmul(x_expert, w3[expert_idx].transpose(-2, -1))
-                h = torch.matmul(h, w2[expert_idx].transpose(-2, -1))
-                # h shape (tokens_per_expert(varying), dim)
-                out_experts_splits.append(h)
-            out = torch.cat(out_experts_splits, dim=0)
-
-            # side-effect code due to the usage of generate_permute_indices
-            out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1]))))
-        else:
-            # x shape (num_experts, tokens_per_expert, dim)
-            h = F.silu(torch.bmm(x, w1.transpose(-2, -1)))
-            h = h * torch.bmm(x, w3.transpose(-2, -1))
-            # out shape (num_experts, tokens_per_expert, dim)
-            out = torch.bmm(h, w2.transpose(-2, -1))
-
-        return out
-
-    @expert_parallel
-    @staticmethod
-    def _run_experts_grouped_mm(
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        w3: torch.Tensor,
-        x: torch.Tensor,
-        num_tokens_per_expert: torch.Tensor | None = None,
-    ) -> torch.Tensor:
-        if num_tokens_per_expert is not None:
-            offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
-            # grouped mm between a 2D tensor and a 3D tensor
-            assert x.dim() == 2
-        else:
-            offsets = None
-            # fall back to regular bmm between 3D tensors
-            assert x.dim() == 3
-
-        h = F.silu(
-            torch._grouped_mm(
-                x.bfloat16(), w1.bfloat16().transpose(-2, -1), offs=offsets
-            )
-        )
-        h = h * torch._grouped_mm(
-            x.bfloat16(), w3.bfloat16().transpose(-2, -1), offs=offsets
-        )
-        out = torch._grouped_mm(
-            h, w2.bfloat16().transpose(-2, -1), offs=offsets
-        ).type_as(x)
-
-        return out
-
     def init_weights(self, init_std: float):
         nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02)
         nn.init.trunc_normal_(self.w2, mean=0.0, std=init_std)
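For reference, the dense fallback (no num_tokens_per_expert) is just per-expert SwiGLU via batched matmuls, which is what both helpers reduce to on a 3D input. A standalone sketch with made-up sizes:

import torch
import torch.nn.functional as F

# Made-up sizes: every expert sees the same number of tokens, so x is dense 3D.
num_experts, tokens_per_expert, dim, hidden = 4, 6, 8, 16
x = torch.randn(num_experts, tokens_per_expert, dim)
w1 = torch.randn(num_experts, hidden, dim)
w2 = torch.randn(num_experts, dim, hidden)
w3 = torch.randn(num_experts, hidden, dim)

# SwiGLU per expert: silu(x @ w1^T) * (x @ w3^T), projected back with w2^T.
h = F.silu(torch.bmm(x, w1.transpose(-2, -1)))
h = h * torch.bmm(x, w3.transpose(-2, -1))
out = torch.bmm(h, w2.transpose(-2, -1))
assert out.shape == (num_experts, tokens_per_expert, dim)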
