@@ -61,7 +61,8 @@ def __init__(self,
                ntlb_top_k=4,
                output_dim=None,
                use_experts_attention=False,
-               z_loss=None):
+               z_loss=None,
+               token_logging=False):
     self._hparams = HParams(
         moe_gating=moe_gating,
         moe_num_experts=num_experts,
@@ -87,6 +88,7 @@ def __init__(self,
         moe_use_experts_attention=use_experts_attention,
         moe_z_loss=z_loss)
     self._activation = activation
+    self.token_logging = token_logging

   def call(self, context, x, losses=None):
     """Call the layer."""
@@ -106,7 +108,13 @@ def call(self, context, x, losses=None):
       output_dim = self._hparams.moe_output_dim
     else:
       output_dim = context.model.model_dim
-    y, loss = transformer_moe_layer_v1(
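+    # When token logging is enabled, print the detokenized inputs and build
+    # per-position token windows to route through the MoE layer alongside x.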
+    if self.token_logging:
+      tokens = _detokenize(context.inputs, context.model.vocabulary)
+      x = mtf.Print(x, [tokens], "tokens", summarize=1000)
+      extras = _windows(context.inputs, context.length_dim)
+    else:
+      extras = None
+    y, loss, extras = transformer_moe_layer_v1(
         x,
         output_dim,
         self._hparams,
@@ -116,7 +124,16 @@ def call(self, context, x, losses=None):
         mesh_shape=context.model.mesh_shape,
         nonpadding=context.nonpadding,
         activation=self._activation,
-        num_microbatches=context.num_microbatches)
+        num_microbatches=context.num_microbatches,
+        extras=extras)
+
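+    # Detokenize the routed windows and print one line of tokens per expert.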
+    if extras:
+      extras = _detokenize(extras, context.model.vocabulary)
+      experts_dim = mtf.Dimension("experts", self._hparams.moe_num_experts)
+      extras = mtf.unstack(extras, experts_dim)
+      for i, t in enumerate(extras):
+        y = mtf.Print(y, [t], "EXPERT %s" % i, summarize=1000)
+
     if context.losses is not None:
       context.losses.append(loss)
     if not has_length_dim:
@@ -128,6 +145,23 @@ def call(self, context, x, losses=None):
     return y


+@gin.configurable
+def _windows(ids, length_dim, window_start=0, window_end=0):
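+  """For each position, gathers the ids at offsets window_start..window_end.
+
+  The result has an extra trailing "window" dimension of size
+  window_end - window_start + 1.
+  """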
+  to_stack = []
+  for offset in range(window_start, window_end + 1):
+    to_stack.append(mtf.shift(ids, -offset, length_dim, wrap=False))
+  return mtf.stack(to_stack, "window", axis=ids.shape.ndims)
+
+
+def _detokenize(ids, vocabulary):
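+  """Decodes ids over their last dimension into tf.string values."""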
+  return mtf.slicewise(
+      vocabulary.decode_tf,
+      [ids],
+      output_shape=mtf.Shape(ids.shape.dims[:-1]),
+      output_dtype=tf.string,
+      splittable_dims=ids.shape.dims[:-1])
+
+
 class MoE2D(transformer.TransformerLayer):
   """Mixture of Experts Layer."""

@@ -191,7 +225,7 @@ def call(self, context, x, losses=None):
 def transformer_moe_layer_v1(
     inputs, output_dim, hparams, train, variable_dtype,
     layout=None, mesh_shape=None, nonpadding=None, activation=mtf.relu,
-    num_microbatches=None):
+    num_microbatches=None, extras=None):
   """Local mixture of experts that works well on TPU.

   Adapted from the paper https://arxiv.org/abs/1701.06538
@@ -266,6 +300,7 @@ def transformer_moe_layer_v1(
       and zeros(padding).
     activation: a function.
     num_microbatches: number of microbatches.
+    extras: an optional Tensor to dispatch alongside the inputs (for debugging).

   Returns:
     outputs: a Tensor with shape [batch_dim(s), length_dim, output_dim]
@@ -329,6 +364,10 @@ def transformer_moe_layer_v1(
   # over which those groups are split.
   batch_and_length_dims, input_dim = (orig_inputs.shape.dims[:-1],
                                       orig_inputs.shape.dims[-1])
+
+  if extras:
+    extras_dims = extras.shape.dims[len(batch_and_length_dims):]
+
   # Hack: we assume that
   # "outer_batch" == replication of experts
   # mesh_dim_size can be derived from mesh_shape and orig_batch_dim
@@ -360,6 +399,11 @@ def transformer_moe_layer_v1(
   # OGSM Tensor
   inputs = mtf.reshape(inputs, moe_input_dims)

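+  # Regroup the extras to the same (outer_batch, group, position) layout as
+  # the inputs, keeping their trailing window dims.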
+  if extras:
+    extras = mtf.reshape(
+        extras,
+        [outer_batch_dim, num_groups_dim, group_size_dim] + extras_dims)
+
   # Each sequence sends expert_capacity positions to each expert.
   if train:
     capacity_factor = hparams.moe_capacity_factor_train
@@ -465,6 +509,17 @@ def transformer_moe_layer_v1(
           input_dim
       ]))

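+  # Dispatch the extras with the same dispatch_tensor so each logged window
+  # follows its token to the expert that processes it.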
+  if extras:
+    extras = mtf.einsum([extras, mtf.cast(dispatch_tensor, extras.dtype)],
+                        mtf.Shape([
+                            outer_batch_dim, experts_dim_unsplit,
+                            num_groups_dim, expert_capacity_dim] + extras_dims))
+    extras = mtf.reshape(
+        extras,
+        mtf.Shape([
+            outer_batch_dim, experts_dim, batch_dim_unsplit,
+            expert_capacity_dim] + extras_dims))
+
   # Now feed the expert inputs through the experts.
   h = mtf.layers.dense_product(
       expert_inputs,
@@ -519,10 +574,15 @@ def _compute_output(hidden, layer_name):
     k = _compute_output(k_h, layer_name="k_wo")
     outputs.append(q)
     outputs.append(k)
-    return outputs, loss * hparams.moe_loss_coef
+    return outputs, loss * hparams.moe_loss_coef, None
   else:
     output = _compute_output(h, layer_name="wo")
-    return output, loss * hparams.moe_loss_coef
+    loss *= hparams.moe_loss_coef
+
+    if extras:
+      return output, loss, extras
+    else:
+      return output, loss, None


 def transformer_moe_layer_v2(
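A minimal sketch of how the new logging path might be switched on, assuming the layer class modified above is moe.MoE1D and that it, like the _windows helper added here, is registered with gin (binding names are taken from this diff and not verified against the rest of the codebase):

    import gin

    # Hypothetical bindings: enable token logging on the MoE layer and widen
    # the logged window to the current token plus the two that follow it.
    gin.bind_parameter("MoE1D.token_logging", True)
    gin.bind_parameter("_windows.window_start", 0)
    gin.bind_parameter("_windows.window_end", 2)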