This repository was archived by the owner on Jan 21, 2025. It is now read-only.
File tree Expand file tree Collapse file tree 2 files changed +8
-3
lines changed
mesh_tensorflow/transformer Expand file tree Collapse file tree 2 files changed +8
-3
lines changed Original file line number Diff line number Diff line change @@ -946,6 +946,8 @@ def _rand_1_gating(
946946
947947 if policy == "argmax" or policy == "input_dropout" or policy == "input_jitter" :
948948 expert_gate , expert_index = mtf .top_1 (raw_gates , reduced_dim = experts_dim )
949+ if train :
950+ mtf .scalar_summary ("expert_gate" , mtf .reduce_mean (expert_gate ))
949951 elif policy == "sample" :
950952 expert_index = mtf .sample_with_temperature (
951953 gate_logits , experts_dim , temperature = hparams .moe_rand_1_temperature )
@@ -1005,6 +1007,12 @@ def _rand_1_gating(
10051007 dtype = raw_gates .dtype )
10061008 expert_mask_flat = mtf .reduce_sum (expert_mask , reduced_dim = experts_dim )
10071009
1010+ if train :
1011+ total_routed = mtf .reduce_sum (expert_mask_flat )
1012+ importance = mtf .cast (importance , dtype = total_routed .dtype )
1013+ mtf .scalar_summary ("fraction_routed" ,
1014+ total_routed / mtf .reduce_sum (importance ))
1015+
10081016 # Mask out the experts that have overflowed expert capacity. Sparsify the
10091017 # expert_gate.
10101018 expert_gate *= expert_mask_flat
Original file line number Diff line number Diff line change @@ -665,9 +665,6 @@ def serialized_fn(mtf_features):
665665
666666 if tpu_summaries :
667667 mtf .scalar_summary ("loss" , loss )
668- for g in var_grads :
669- grad_norm = mtf .sqrt (mtf .reduce_sum (mtf .square (g )))
670- mtf .scalar_summary ("grads/norm" + g .name [:- 2 ], grad_norm )
671668
672669 if callable (learning_rate_schedule ):
673670 # the following happens on CPU since TPU can't handle summaries.
You can’t perform that action at this time.
0 commit comments