@@ -977,8 +977,6 @@ def _switch_max_gating(
 
   if policy == "argmax" or policy == "input_dropout" or policy == "input_jitter":
     expert_gate, expert_index = mtf.top_1(raw_gates, reduced_dim=experts_dim)
-    if train:
-      mtf.scalar_summary("expert_gate", mtf.reduce_mean(expert_gate))
   elif policy == "sample":
     expert_index = mtf.sample_with_temperature(
         gate_logits, experts_dim, temperature=hparams.moe_switch_temperature)
@@ -1011,6 +1009,7 @@ def _switch_max_gating(
         reduced_dim=experts_dim)
     batch_entropy = mtf.reduce_mean(entropy)
     mtf.scalar_summary(name + "/entropy", batch_entropy)
+    mtf.scalar_summary("expert_gate", mtf.reduce_mean(expert_gate))
 
     mask_count_experts = mtf.reduce_sum(expert_mask, output_shape=[experts_dim])
     total_routed = mtf.reduce_sum(mask_count_experts)
@@ -1209,8 +1208,6 @@ def _switch_gating(
 
   if policy == "argmax" or policy == "input_dropout" or policy == "input_jitter":
     expert_gate, expert_index = mtf.top_1(raw_gates, reduced_dim=experts_dim)
-    if train:
-      mtf.scalar_summary("expert_gate", mtf.reduce_mean(expert_gate))
   elif policy == "sample":
     expert_index = mtf.sample_with_temperature(
         gate_logits, experts_dim, temperature=hparams.moe_switch_temperature)
@@ -1243,6 +1240,7 @@ def _switch_gating(
         reduced_dim=experts_dim)
     batch_entropy = mtf.reduce_mean(entropy)
     mtf.scalar_summary(name + "/entropy", batch_entropy)
+    mtf.scalar_summary("expert_gate", mtf.reduce_mean(expert_gate))
 
     mask_count_experts = mtf.reduce_sum(expert_mask, output_shape=[experts_dim])
     total_routed = mtf.reduce_sum(mask_count_experts)
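
Net effect of these hunks, as I read the diff: the `expert_gate` mean summary is no longer written only inside the argmax-style branch (and only when `train` is true there), but from the shared summary block near the end of gating, so it is logged for every routing policy (argmax, sample, input_dropout, input_jitter) next to the per-layer entropy summary. A minimal sketch of the resulting block in both `_switch_max_gating` and `_switch_gating`, assuming, as the removed lines suggest, that it sits under the existing `if train:` guard and that `entropy` has already been computed from `raw_gates` earlier in the function:

```python
# Sketch reconstructed from the diff context, not the full function body.
if train:
  batch_entropy = mtf.reduce_mean(entropy)
  mtf.scalar_summary(name + "/entropy", batch_entropy)
  # New placement: mean gate value of the selected expert, now emitted for
  # all routing policies rather than only when the policy uses mtf.top_1.
  mtf.scalar_summary("expert_gate", mtf.reduce_mean(expert_gate))
```

One consequence of the move: `expert_gate` must be defined for every policy by the time this block runs, since the summary now references it unconditionally.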