Skip to content
This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit b53f293

Browse files
author
Mesh TensorFlow Team
committed
Log expert_gating only after it has been masked by the importance tensor, to be sure no padded probabilities are being logged.
PiperOrigin-RevId: 378321668
1 parent 54b01b4 commit b53f293

File tree

1 file changed

+2
-4
lines changed
  • mesh_tensorflow/transformer

1 file changed

+2
-4
lines changed

mesh_tensorflow/transformer/moe.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -977,8 +977,6 @@ def _switch_max_gating(
977977

978978
if policy == "argmax" or policy == "input_dropout" or policy == "input_jitter":
979979
expert_gate, expert_index = mtf.top_1(raw_gates, reduced_dim=experts_dim)
980-
if train:
981-
mtf.scalar_summary("expert_gate", mtf.reduce_mean(expert_gate))
982980
elif policy == "sample":
983981
expert_index = mtf.sample_with_temperature(
984982
gate_logits, experts_dim, temperature=hparams.moe_switch_temperature)
@@ -1011,6 +1009,7 @@ def _switch_max_gating(
10111009
reduced_dim=experts_dim)
10121010
batch_entropy = mtf.reduce_mean(entropy)
10131011
mtf.scalar_summary(name + "/entropy", batch_entropy)
1012+
mtf.scalar_summary("expert_gate", mtf.reduce_mean(expert_gate))
10141013

10151014
mask_count_experts = mtf.reduce_sum(expert_mask, output_shape=[experts_dim])
10161015
total_routed = mtf.reduce_sum(mask_count_experts)
@@ -1209,8 +1208,6 @@ def _switch_gating(
12091208

12101209
if policy == "argmax" or policy == "input_dropout" or policy == "input_jitter":
12111210
expert_gate, expert_index = mtf.top_1(raw_gates, reduced_dim=experts_dim)
1212-
if train:
1213-
mtf.scalar_summary("expert_gate", mtf.reduce_mean(expert_gate))
12141211
elif policy == "sample":
12151212
expert_index = mtf.sample_with_temperature(
12161213
gate_logits, experts_dim, temperature=hparams.moe_switch_temperature)
@@ -1243,6 +1240,7 @@ def _switch_gating(
12431240
reduced_dim=experts_dim)
12441241
batch_entropy = mtf.reduce_mean(entropy)
12451242
mtf.scalar_summary(name + "/entropy", batch_entropy)
1243+
mtf.scalar_summary("expert_gate", mtf.reduce_mean(expert_gate))
12461244

12471245
mask_count_experts = mtf.reduce_sum(expert_mask, output_shape=[experts_dim])
12481246
total_routed = mtf.reduce_sum(mask_count_experts)

0 commit comments

Comments
 (0)