Skip to content

Commit 0e86cc9

Browse files
committed
test
1 parent d621e44 commit 0e86cc9

File tree

1 file changed

+1
-6
lines changed
  • torchtitan/models/common/moe

1 file changed

+1
-6
lines changed

torchtitan/models/common/moe/moe.py

Lines changed: 1 addition & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -100,7 +100,6 @@ def __init__(self, config: Config):
100100
)
101101
self.use_grouped_mm = config.use_grouped_mm
102102
self._local_map_fn: Callable | None = None
103-
self._local_map_run_experts_fn: Callable | None = None
104103

105104
def forward(
106105
self,
@@ -127,10 +126,7 @@ def forward(
127126
# tensors. The output has a dynamic token dimension that cannot be
128127
# wrapped as a DTensor, so we use None out_placements to keep it
129128
# as a plain tensor.
130-
if (
131-
self._local_map_fn is None
132-
or self._local_map_run_experts_fn is not run_experts_fn
133-
):
129+
if self._local_map_fn is None:
134130
self._local_map_fn = local_map(
135131
run_experts_fn,
136132
in_placements=(
@@ -143,7 +139,6 @@ def forward(
143139
out_placements=None, # output stays as plain tensor
144140
device_mesh=self.w1.device_mesh,
145141
)
146-
self._local_map_run_experts_fn = run_experts_fn
147142
return self._local_map_fn(
148143
self.w1, self.w2, self.w3, x, num_tokens_per_expert
149144
)

0 commit comments

Comments
 (0)