Commit 1233902

[dsv3] patch graph break fix, works up until sharding rules
1 parent 91c5639 commit 1233902

2 files changed: +81 -83 lines changed


torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py

Lines changed: 12 additions & 12 deletions
@@ -44,8 +44,7 @@ def input_fn():
         return (
             torch.randint(
                 0,
-                # job_config.training.vocab_size,
-                model.vocab_size,
+                model.model_args.vocab_size,
                 (global_batch_size, job_config.training.seq_len),
                 device=torch.device("cuda"),
             ),
@@ -63,23 +62,24 @@ def input_fn():
     #     lambda bucket_idx: 1000 / parallel_dims.tp
     # )

-    # bail out
-    return model
-
     # if job_config.experimental.autop_force_bf16:
     #     logger.info("Forcing bf16 on model")
     #     model = model.bfloat16()

     # param_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_param]
     # reduce_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce]
     # mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
-    # with AutoParallel(
-    #     model,
-    #     input_fn,
-    #     world_mesh,
-    #     mp_policy=mp_policy,
-    #     compile=job_config.training.compile,
-    # ) as autop:
+    mp_policy = None
+    with AutoParallel(
+        model,
+        input_fn,
+        world_mesh,
+        mp_policy=mp_policy,
+        compile=job_config.training.compile,
+    ) as autop:
+        # currently errors due to missing sharding prop rules
+        torch.distributed.breakpoint()
+
     # autop.add_parameter_memory_constraint(low=None, high=None)

     # possible_input_shardings = {
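Note on the second hunk: it enables the previously commented-out AutoParallel context and parks at torch.distributed.breakpoint(), which is where the missing sharding-propagation rules currently surface. For reference, torch.distributed.breakpoint() drops one chosen rank into pdb while the other ranks wait at a barrier. Below is a minimal, standalone sketch of that debugging pattern, not taken from this repo; the gloo backend, the rank argument, and the torchrun invocation are illustrative assumptions.

# debug_breakpoint.py -- run with: torchrun --nproc-per-node=2 debug_breakpoint.py
import torch
import torch.distributed as dist

dist.init_process_group(backend="gloo")  # torchrun supplies MASTER_ADDR/PORT, RANK, WORLD_SIZE

x = torch.ones(4) * dist.get_rank()

# Rank 0 drops into pdb here; the remaining ranks block on a barrier until it continues.
dist.breakpoint(rank=0)

dist.all_reduce(x)  # sums the per-rank contributions
print(f"rank {dist.get_rank()}: {x}")

dist.destroy_process_group()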

torchtitan/models/deepseek_v3/model/moe.py

Lines changed: 69 additions & 71 deletions
@@ -48,6 +48,73 @@ def init_weights(self, init_std: float = 0.02):
             nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)


+# TODO: keeping this for-loop implementation for comparison
+# and readability, may remove later
+@expert_parallel
+def _run_experts_for_loop(
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w3: torch.Tensor,
+    x: torch.Tensor,
+    num_tokens_per_expert: torch.Tensor | None = None,
+) -> torch.Tensor:
+    if num_tokens_per_expert is not None:
+        # NOTE: this would incur a synchronization between device and host
+        num_tokens_per_expert = num_tokens_per_expert.tolist()
+
+        # side-effect code due to the usage of generate_permute_indices
+        num_padding = x.shape[0] - sum(num_tokens_per_expert)
+
+        # a tuple of tensors indexed by experts
+        # each with shape (tokens_per_expert(varying), dim)
+        x = torch.split(
+            x[: sum(num_tokens_per_expert)],
+            split_size_or_sections=num_tokens_per_expert,
+            dim=0,
+        )
+        out_experts_splits = []
+        for expert_idx, x_expert in enumerate(x):
+            h = F.silu(torch.matmul(x_expert, w1[expert_idx]))
+            h = h * torch.matmul(x_expert, w3[expert_idx])
+            h = torch.matmul(h, w2[expert_idx])
+            # h shape (tokens_per_expert(varying), dim)
+            out_experts_splits.append(h)
+        out = torch.cat(out_experts_splits, dim=0)
+
+        # side-effect code due to the usage of generate_permute_indices
+        out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1]))))
+    else:
+        # x shape (num_experts, tokens_per_expert, dim)
+        h = F.silu(torch.bmm(x, w1))
+        h = h * torch.bmm(x, w3)
+        # out shape (num_experts, tokens_per_expert, dim)
+        out = torch.bmm(h, w2)
+
+    return out
+
+@expert_parallel
+def _run_experts_grouped_mm(
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w3: torch.Tensor,
+    x: torch.Tensor,
+    num_tokens_per_expert: torch.Tensor | None = None,
+) -> torch.Tensor:
+    if num_tokens_per_expert is not None:
+        offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
+        # grouped mm between a 2D tensor and a 3D tensor
+        assert x.dim() == 2
+    else:
+        offsets = None
+        # fall back to regular bmm between 3D tensors
+        assert x.dim() == 3
+
+    h = F.silu(torch._grouped_mm(x.bfloat16(), w1.bfloat16(), offs=offsets))
+    h = h * torch._grouped_mm(x.bfloat16(), w3.bfloat16(), offs=offsets)
+    out = torch._grouped_mm(h, w2.bfloat16(), offs=offsets).type_as(x)
+
+    return out
+
 class GroupedExperts(nn.Module):
     def __init__(
         self,
@@ -69,83 +136,14 @@ def forward(
         num_tokens_per_expert: torch.Tensor | None = None,
     ) -> torch.Tensor:
         if self.use_grouped_mm:
-            return GroupedExperts._run_experts_grouped_mm(
+            return _run_experts_grouped_mm(
                 self.w1, self.w2, self.w3, x, num_tokens_per_expert
             )
         else:
-            return GroupedExperts._run_experts_for_loop(
+            return _run_experts_for_loop(
                 self.w1, self.w2, self.w3, x, num_tokens_per_expert
             )

-    # TODO: keeping this for-loop implementation for comparison
-    # and readability, may remove later
-    @expert_parallel
-    @staticmethod
-    def _run_experts_for_loop(
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        w3: torch.Tensor,
-        x: torch.Tensor,
-        num_tokens_per_expert: torch.Tensor | None = None,
-    ) -> torch.Tensor:
-        if num_tokens_per_expert is not None:
-            # NOTE: this would incur a synchronization between device and host
-            num_tokens_per_expert = num_tokens_per_expert.tolist()
-
-            # side-effect code due to the usage of generate_permute_indices
-            num_padding = x.shape[0] - sum(num_tokens_per_expert)
-
-            # a tuple of tensors indexed by experts
-            # each with shape (tokens_per_expert(varying), dim)
-            x = torch.split(
-                x[: sum(num_tokens_per_expert)],
-                split_size_or_sections=num_tokens_per_expert,
-                dim=0,
-            )
-            out_experts_splits = []
-            for expert_idx, x_expert in enumerate(x):
-                h = F.silu(torch.matmul(x_expert, w1[expert_idx]))
-                h = h * torch.matmul(x_expert, w3[expert_idx])
-                h = torch.matmul(h, w2[expert_idx])
-                # h shape (tokens_per_expert(varying), dim)
-                out_experts_splits.append(h)
-            out = torch.cat(out_experts_splits, dim=0)
-
-            # side-effect code due to the usage of generate_permute_indices
-            out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1]))))
-        else:
-            # x shape (num_experts, tokens_per_expert, dim)
-            h = F.silu(torch.bmm(x, w1))
-            h = h * torch.bmm(x, w3)
-            # out shape (num_experts, tokens_per_expert, dim)
-            out = torch.bmm(h, w2)
-
-        return out
-
-    @expert_parallel
-    @staticmethod
-    def _run_experts_grouped_mm(
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        w3: torch.Tensor,
-        x: torch.Tensor,
-        num_tokens_per_expert: torch.Tensor | None = None,
-    ) -> torch.Tensor:
-        if num_tokens_per_expert is not None:
-            offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
-            # grouped mm between a 2D tensor and a 3D tensor
-            assert x.dim() == 2
-        else:
-            offsets = None
-            # fall back to regular bmm between 3D tensors
-            assert x.dim() == 3
-
-        h = F.silu(torch._grouped_mm(x.bfloat16(), w1.bfloat16(), offs=offsets))
-        h = h * torch._grouped_mm(x.bfloat16(), w3.bfloat16(), offs=offsets)
-        out = torch._grouped_mm(h, w2.bfloat16(), offs=offsets).type_as(x)
-
-        return out
-
     def init_weights(self, init_std: float):
         nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02)
         nn.init.trunc_normal_(self.w2, mean=0.0, std=init_std)
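As background on the relocated _run_experts_grouped_mm: torch._grouped_mm treats the 2D activation tensor as a concatenation of per-expert row blocks, with offs holding the cumulative row count at the end of each block (hence the torch.cumsum over num_tokens_per_expert above). The following is a tiny sketch of those semantics, not code from this repo; torch._grouped_mm is a private API that needs a recent CUDA build of PyTorch, the shapes are made up for illustration, and the group sizes are kept at multiples of 16 because some builds require aligned offsets (which is also why the production path pads tokens via generate_permute_indices).

import torch

num_experts, dim, hidden = 2, 64, 128
# group sizes kept aligned (multiples of 16) for the grouped-mm kernel
num_tokens_per_expert = torch.tensor([16, 32], device="cuda")

# 2D activations: rows 0:16 belong to expert 0, rows 16:48 to expert 1
x = torch.randn(48, dim, device="cuda", dtype=torch.bfloat16)
w1 = torch.randn(num_experts, dim, hidden, device="cuda", dtype=torch.bfloat16)

# offsets are cumulative token counts, matching the diff above
offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)

h = torch._grouped_mm(x, w1, offs=offsets)  # shape (48, hidden)

# per-expert reference computation for comparison
ref = torch.cat([x[0:16] @ w1[0], x[16:48] @ w1[1]], dim=0)
print(h.shape, (h - ref).abs().max())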

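On the commit title's graph break fix: the moe.py change moves the @expert_parallel-decorated kernels out of GroupedExperts, presumably so Dynamo traces a plain module-level function instead of a decorator stacked on top of @staticmethod. One way to confirm that such a refactor removes graph breaks is torch._dynamo.explain. Below is a minimal, self-contained sketch on a toy module, not the real GroupedExperts (whose constructor is outside this diff); note that _dynamo.explain is a private API and its report fields can shift between releases.

import torch
import torch._dynamo as dynamo
import torch.nn.functional as F


def _toy_experts(w: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # stand-in for a module-level expert kernel
    return F.silu(torch.bmm(x, w))


class ToyMoE(torch.nn.Module):
    def __init__(self, num_experts: int, dim: int):
        super().__init__()
        self.w = torch.nn.Parameter(torch.randn(num_experts, dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return _toy_experts(self.w, x)


model = ToyMoE(num_experts=4, dim=16)
x = torch.randn(4, 8, 16)

# ExplainOutput reports how many graphs Dynamo captured and why it broke;
# after a successful fix, graph_break_count should be 0.
report = dynamo.explain(model)(x)
print(report.graph_count, report.graph_break_count)
for reason in report.break_reasons:
    print(reason)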