 from __future__ import division
 from __future__ import print_function
 
+import math
 import gin
 
 import mesh_tensorflow as mtf
@@ -65,7 +66,10 @@ def __init__(self,
                word_embed_mode=None,
                use_second_place_expert_prob=None,
                use_second_place_expert_prob_temp=None,
-               top_n_num_experts_per_token=3):
+               top_n_num_experts_per_token=3,
+               rloo=False,
+               loss_type="load_balance",
+               p_dot_e=True):
     self._hparams = HParams(
         moe_gating=moe_gating,
         moe_num_experts=num_experts,
@@ -95,7 +99,10 @@ def __init__(self,
             use_second_place_expert_prob),
         moe_use_second_place_expert_prob_temp=(
             use_second_place_expert_prob_temp),
-        moe_top_n_num_experts_per_token=top_n_num_experts_per_token)
+        moe_top_n_num_experts_per_token=top_n_num_experts_per_token,
+        moe_rloo=rloo,
+        loss_type=loss_type,
+        p_dot_e=p_dot_e)
     self._activation = activation
 
   def call(self, context, x, losses=None):
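For reference, the three new constructor arguments could be switched on through gin. This is an illustration only: it assumes MoE1D keeps its usual gin registration in moe.py, and the binding names simply mirror the new `__init__` arguments above.

```python
import gin

# Hypothetical bindings; they assume MoE1D is gin-configurable (as elsewhere
# in moe.py) and that these names match the new __init__ arguments above.
gin.parse_config("""
MoE1D.rloo = True
MoE1D.loss_type = "kl"      # "load_balance" keeps the original auxiliary loss
MoE1D.p_dot_e = False       # dispatch with the mask only, no p(e|x) scaling
""")
```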
@@ -127,7 +134,8 @@ def call(self, context, x, losses=None):
         nonpadding=context.nonpadding,
         activation=self._activation,
         num_microbatches=context.num_microbatches,
-        token_embeddings=context.input_embeddings)
+        token_embeddings=context.input_embeddings,
+        context=context)
     if context.losses is not None:
       context.losses.append(loss)
     if not has_length_dim:
@@ -202,7 +210,7 @@ def call(self, context, x, losses=None):
 def transformer_moe_layer_v1(
     inputs, output_dim, hparams, train, variable_dtype,
     layout=None, mesh_shape=None, nonpadding=None, activation=mtf.relu,
-    num_microbatches=None, token_embeddings=None):
+    num_microbatches=None, token_embeddings=None, context=None):
   """Local mixture of experts that works well on TPU.
 
   Adapted from the paper https://arxiv.org/abs/1701.06538
@@ -281,6 +289,8 @@ def transformer_moe_layer_v1(
       [batch_dim(s), length_dim, input_dim]. These are the word embeddings for
       that correspond to the inputs. These can optionally be used to make
       routing decisions.
+    context: a Context object containing extra information that layers need
+      at call time, as defined in transformer.py.
 
   Returns:
     outputs: a Tensor with shape [batch_dim(s), length_dim, output_dim]
@@ -436,7 +446,8 @@ def transformer_moe_layer_v1(
         variable_dtype=variable_dtype,
         importance=nonpadding,
         num_microbatches=num_microbatches,
-        token_embeddings=token_embeddings)
+        token_embeddings=token_embeddings,
+        context=context)
   elif hparams.moe_gating == "ntlb":
     dispatch_tensor, combine_tensor, loss = _ntlb_gating(
         inputs=inputs,
@@ -1303,7 +1314,8 @@ def _expert_selection_gating(
 def _switch_gating(
     inputs, outer_expert_dims, experts_dim, expert_capacity_dim,
     hparams, train, variable_dtype, importance=None, name="switch_gating",
-    num_microbatches=None, token_embeddings=None):
+    num_microbatches=None, token_embeddings=None,
+    context=None):
   """Compute Switch gating."""
   # SELECT EXPERT
   if train:
@@ -1351,6 +1363,11 @@ def _switch_gating(
     expert_gate = mtf.gather(raw_gates, expert_index, dim=experts_dim)
   else:
     raise ValueError("Unknown Switch gating policy %s" % policy)
+  full_expert_gate_log_probs = gate_logits / hparams.moe_switch_temperature
+  full_expert_gate_log_probs -= mtf.reduce_logsumexp(
+      full_expert_gate_log_probs, reduced_dim=experts_dim)
+  expert_gate_log_probs = mtf.gather(
+      full_expert_gate_log_probs, expert_index, dim=experts_dim)
 
   expert_mask = mtf.one_hot(expert_index, experts_dim, dtype=raw_gates.dtype)
 
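The added lines above form a temperature-scaled log-softmax over the experts dimension, then gather the log-probability of the expert actually selected for each position. A minimal NumPy sketch of the same computation, with shapes flattened to [group, experts] and `temperature` standing in for `hparams.moe_switch_temperature` (illustration only, not Mesh TensorFlow code):

```python
import numpy as np

def tempered_log_probs(gate_logits, expert_index, temperature):
  """gate_logits: [group, experts] floats; expert_index: [group] ints."""
  z = gate_logits / temperature
  z = z - z.max(axis=-1, keepdims=True)  # stabilize the softmax
  full_log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
  # log-probability of the expert chosen for each position
  chosen_log_probs = np.take_along_axis(
      full_log_probs, expert_index[:, None], axis=-1)[:, 0]
  return full_log_probs, chosen_log_probs
```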
@@ -1363,21 +1380,40 @@ def _switch_gating(
     expert_gate *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
     density_1_proxy *= mtf.cast(
         mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
-  loss = (
+  load_balance_loss = (
       mtf.reduce_mean(density_1_proxy * density_1) *
       float(experts_dim.size * experts_dim.size))
+
+  kl_with_uniform = (
+      - math.log(float(experts_dim.size))
+      - mtf.reduce_logsumexp(full_expert_gate_log_probs,
+                             reduced_dim=group_size_dim)
+      + math.log(float(group_size_dim.size)))
+  if importance is not None:
+    kl_with_uniform *= mtf.cast(mtf.equal(importance, 1.0),
+                                dtype=raw_gates.dtype)
+  kl_with_uniform = mtf.reduce_mean(kl_with_uniform)
+
+  if hparams.loss_type.lower() == "kl":
+    loss = kl_with_uniform
+  else:
+    loss = load_balance_loss
+
   if num_microbatches and num_microbatches > 1:
     tf.logging.info("Dividing load-balance loss by num_microbatches={}".format(
         num_microbatches))
     loss /= num_microbatches
 
   # Logging
   if train:
-    entropy = mtf.reduce_sum(-raw_gates * mtf.log(raw_gates + 1e-9),
-                             reduced_dim=experts_dim)
+    entropy = mtf.reduce_sum(
+        -mtf.exp(full_expert_gate_log_probs) * full_expert_gate_log_probs,
+        reduced_dim=experts_dim)
     batch_entropy = mtf.reduce_mean(entropy)
     mtf.scalar_summary(name + "/entropy", batch_entropy)
     mtf.scalar_summary("expert_gate", mtf.reduce_mean(expert_gate))
+    mtf.scalar_summary("tempered_expert_gate",
+                       mtf.reduce_mean(mtf.exp(expert_gate_log_probs)))
 
     mask_count_experts = mtf.reduce_sum(expert_mask, output_shape=[experts_dim])
     total_routed = mtf.reduce_sum(mask_count_experts)
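One way to read the new `kl_with_uniform` term in this hunk: per expert it evaluates -log(num_experts) - log p_bar(e), where p_bar is the router's group-averaged routing distribution, so its mean over experts is the KL divergence from a uniform distribution over experts to p_bar. A toy NumPy sketch under that reading (the importance mask and Mesh TensorFlow dimension handling are omitted):

```python
import numpy as np

def kl_with_uniform(full_log_probs):
  """full_log_probs: [group, experts] router log-probabilities per position."""
  group, experts = full_log_probs.shape
  # log p_bar(e) = logsumexp over the group of log p_g(e), minus log(group)
  m = full_log_probs.max(axis=0)
  log_p_bar = m + np.log(np.exp(full_log_probs - m).sum(axis=0)) - np.log(group)
  # KL(uniform || p_bar) = mean over experts of [-log(experts) - log p_bar(e)]
  return float(np.mean(-np.log(experts) - log_p_bar))
```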
@@ -1389,7 +1425,25 @@ def _switch_gating(
     for fraction in split_fractions:
       mtf.scalar_summary("experts/" + fraction.name.replace(":", "/"),
                          mtf.reduce_mean(fraction))
-    mtf.scalar_summary("aux_loss", mtf.reduce_mean(loss))
+    dead_expert_fraction = mtf.reduce_mean(
+        mtf.cast(mtf.equal(mask_count_experts, 0.),
+                 dtype=raw_gates.dtype))
+    mtf.scalar_summary("dead_expert_fraction",
+                       dead_expert_fraction)
+    mtf.scalar_summary("load_balancing_loss",
+                       mtf.reduce_mean(load_balance_loss))
+    mtf.scalar_summary("kl_with_uniform",
+                       mtf.reduce_mean(kl_with_uniform))
+
+    split_expert_index = mtf.rename_dimension(
+        expert_index, "batch", "batch_split")
+    first_expert_index, second_expert_index = mtf.split(
+        split_expert_index,
+        split_expert_index.shape.get_dim_by_name("batch_split"), 2)
+    duplicate_sample = mtf.reduce_mean(
+        mtf.cast(mtf.equal(first_expert_index, second_expert_index),
+                 dtype=raw_gates.dtype))
+    mtf.scalar_summary("duplicate_sample_fraction", duplicate_sample)
 
   # Add in the z_loss for router.
   if train and hparams.moe_z_loss is not None:
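The `duplicate_sample_fraction` summary splits the batch dimension in half and measures how often the two halves chose the same expert; this only makes sense if the batch is laid out as two routed copies of the same examples, which appears to be the intended RLOO setup. A NumPy sketch of the metric itself (an illustration, not the mtf implementation):

```python
import numpy as np

def duplicate_sample_fraction(expert_index):
  """expert_index: [batch, group] integer expert choices; batch must be even."""
  first, second = np.split(expert_index, 2, axis=0)
  return float(np.mean(first == second))
```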
@@ -1421,9 +1475,16 @@ def _switch_gating(
   # Mask out the experts that have overflowed expert capacity. Sparsify the
   # expert_gate.
   expert_gate *= expert_mask_flat
+  if hparams.moe_rloo:
+    expert_gate_log_probs *= expert_mask_flat
+    context.expert_gate_log_probs.append(expert_gate_log_probs)
 
-  combine_tensor = (
-      expert_gate * expert_mask_flat *
+  if hparams.p_dot_e:
+    combine_tensor = expert_gate
+  else:
+    combine_tensor = expert_mask_flat
+
+  combine_tensor *= (
       mtf.one_hot(expert_index, experts_dim, dtype=raw_gates.dtype) *
       mtf.one_hot(
           mtf.to_int32(position_in_expert),
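Two things happen in this last hunk: with `moe_rloo` on, the per-position log-probability of the chosen expert (zeroed for positions dropped by the capacity check) is stashed on the Context so a REINFORCE-leave-one-out style router update can be formed elsewhere, and with `p_dot_e` off, the combine tensor carries only the dispatch mask, so expert outputs are no longer scaled by the gate probability. The reward wiring is outside this diff; the sketch below only illustrates the standard RLOO surrogate that such stored log-probabilities would feed into, with hypothetical names:

```python
import numpy as np

def rloo_surrogate(log_probs, rewards):
  """log_probs, rewards: [k] arrays from k >= 2 sampled routings of one input."""
  k = rewards.shape[0]
  baselines = (rewards.sum() - rewards) / (k - 1)  # leave-one-out baselines
  # REINFORCE with a leave-one-out baseline: (reward - baseline) * log-prob
  return float(np.mean((rewards - baselines) * log_probs))
```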