@@ -61,7 +61,9 @@ def __init__(self,
                ntlb_top_k=4,
                output_dim=None,
                use_experts_attention=False,
-               z_loss=None):
+               z_loss=None,
+               num_hidden_splits=None,
+               split_hidden_before_routing=False):
     self._hparams = HParams(
         moe_gating=moe_gating,
         moe_num_experts=num_experts,
@@ -85,7 +87,9 @@ def __init__(self,
         moe_output_dim=output_dim,
         moe_ntlb_top_k=ntlb_top_k,
         moe_use_experts_attention=use_experts_attention,
-        moe_z_loss=z_loss)
+        moe_z_loss=z_loss,
+        moe_num_hidden_splits=num_hidden_splits,
+        moe_split_hidden_before_routing=split_hidden_before_routing)
     self._activation = activation

   def call(self, context, x, losses=None):
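For orientation, a minimal construction sketch (not part of the diff): it assumes this constructor belongs to mesh_tensorflow.transformer.moe.MoE1D and picks arbitrary values for the other arguments; only num_hidden_splits and split_hidden_before_routing come from the change above.

from mesh_tensorflow.transformer import moe

# Hypothetical usage of the two new constructor arguments.
layer = moe.MoE1D(
    num_experts=32,                    # assumed value
    hidden_size=4096,                  # assumed value
    num_hidden_splits=2,               # new: split d_model/d_ff into 2 subtokens
    split_hidden_before_routing=True)  # new: route subtokens independently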
@@ -327,8 +331,8 @@ def transformer_moe_layer_v1(
   # We "cheat" here and look at the mesh shape and layout. This is to ensure
   # that the number of groups is a multiple of the mesh dimension
   # over which those groups are split.
-  batch_and_length_dims, input_dim = (orig_inputs.shape.dims[:-1],
-                                      orig_inputs.shape.dims[-1])
+  batch_and_length_dims, orig_input_dim = (
+      orig_inputs.shape.dims[:-1], orig_inputs.shape.dims[-1])
   # Hack: we assume that
   #   "outer_batch" == replication of experts
   #   mesh_dim_size can be derived from mesh_shape and orig_batch_dim
@@ -348,16 +352,57 @@ def transformer_moe_layer_v1(

   n = n // outer_batch_dim.size

-  mesh_dim_size = mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape,
-                                                  orig_batch_dim)
-  num_groups, group_size = _split_into_groups(n, hparams.moe_group_size,
-                                              mesh_dim_size)
+  # Create num_groups and group_size dimensions.
+  mesh_dim_size = mtf.tensor_dim_to_mesh_dim_size(
+      layout, mesh_shape, orig_batch_dim)
+  num_groups, group_size = _split_into_groups(
+      n, hparams.moe_group_size, mesh_dim_size)
+  orig_group_size_dim = mtf.Dimension("group", group_size)
+  orig_num_groups_dim = mtf.Dimension(orig_batch_dim.name, num_groups)
+
+  # The original dimensions correspond to those before splitting tokens
+  # into subtokens.
+  group_size_dim = orig_group_size_dim
+  num_groups_dim = orig_num_groups_dim
+  input_dim = orig_input_dim
+
+  split_hidden_before_routing = False
+  split_hidden_after_routing = False
+  if hparams.moe_num_hidden_splits is not None:
+    if orig_input_dim.size % hparams.moe_num_hidden_splits:
+      raise ValueError("num_hidden_splits {} must divide input_dim {}".format(
+          hparams.moe_num_hidden_splits, input_dim.size))
+    if output_dim.size % hparams.moe_num_hidden_splits:
+      raise ValueError("num_hidden_splits {} must divide output_dim {}".format(
+          hparams.moe_num_hidden_splits, output_dim.size))
+    split_hidden_before_routing = hparams.moe_split_hidden_before_routing
+    split_hidden_after_routing = not hparams.moe_split_hidden_before_routing
+    hidden_dim = mtf.Dimension(
+        "expert_hidden",
+        hparams.moe_hidden_size // hparams.moe_num_hidden_splits)
+    sub_output_dim = mtf.Dimension(
+        output_dim.name, output_dim.size // hparams.moe_num_hidden_splits)
+    num_splits_dim = mtf.Dimension(
+        "num_splits", hparams.moe_num_hidden_splits)
+
+  if split_hidden_before_routing:
+    input_dim = mtf.Dimension(
+        input_dim.name, input_dim.size // hparams.moe_num_hidden_splits)
+
+    # Split into groups and subtokens.
+    inputs = mtf.reshape(
+        inputs, [outer_batch_dim, num_groups_dim, group_size_dim,
+                 num_splits_dim, input_dim])

-  group_size_dim = mtf.Dimension("group", group_size)
-  num_groups_dim = mtf.Dimension(orig_batch_dim.name, num_groups)
+    inputs = mtf.transpose(
+        inputs, [outer_batch_dim, num_groups_dim, num_splits_dim,
+                 group_size_dim, input_dim])

+    num_groups_dim = mtf.Dimension(
+        orig_batch_dim.name, num_groups * hparams.moe_num_hidden_splits)
+
+  # [outer_batch_dim, num_groups_dim.B, group_size_dim, input_dim]
   moe_input_dims = [outer_batch_dim, num_groups_dim, group_size_dim, input_dim]
-  # OGSM Tensor
   inputs = mtf.reshape(inputs, moe_input_dims)

   # Each sequence sends expert_capacity positions to each expert.
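To make the subtoken bookkeeping above concrete, here is a small NumPy sketch of the reshape/transpose performed when split_hidden_before_routing is set (plain arrays stand in for the mtf.Dimension bookkeeping; all sizes are illustrative): each token of width d_model becomes num_hidden_splits subtokens, and the number of routing groups grows by the same factor.

import numpy as np

outer_batch, num_groups, group_size, d_model = 1, 2, 4, 8
num_splits = 2
sub_d = d_model // num_splits

x = np.arange(outer_batch * num_groups * group_size * d_model, dtype=np.float32)
x = x.reshape(outer_batch, num_groups, group_size, d_model)

# Split each token's d_model axis into num_splits subtokens ...
x = x.reshape(outer_batch, num_groups, group_size, num_splits, sub_d)
# ... move the split axis next to the group axis (the mtf.transpose above) ...
x = x.transpose(0, 1, 3, 2, 4)
# ... and fold it into the group axis: num_splits times as many groups, each
# still group_size long, with subtokens of width d_model // num_splits.
x = x.reshape(outer_batch, num_groups * num_splits, group_size, sub_d)
assert x.shape == (1, 4, 4, 4)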
@@ -373,156 +418,138 @@ def transformer_moe_layer_v1(
   expert_capacity_dim = mtf.Dimension("expert_capacity", expert_capacity)
   experts_dim_unsplit = mtf.Dimension("expert_unsplit", experts_dim.size)
   batch_dim_unsplit = mtf.Dimension("batch_unsplit", num_groups_dim.size)
+
   if nonpadding is not None:
     nonpadding = mtf.zeros(
         inputs.mesh, batch_and_length_dims, dtype=inputs.dtype) + nonpadding
+
+    if split_hidden_before_routing:
+      nonpadding = mtf.reshape(
+          nonpadding,
+          [outer_batch_dim, orig_num_groups_dim, orig_group_size_dim])
+
+      # Tile num_hidden_splits times with an einsum.
+      tiling_tensor = mtf.ones(inputs.mesh, [num_splits_dim])
+      nonpadding = mtf.einsum(
+          [nonpadding, tiling_tensor],
+          output_shape=[outer_batch_dim, orig_num_groups_dim, num_splits_dim,
+                        orig_group_size_dim])
+
     nonpadding = mtf.reshape(nonpadding, moe_input_dims[:-1])
-  if hparams.moe_gating == "top_2":
-    # combine_tensor,
-    # dispatch_tensor OG`SEC Tensors
-    # (G is generally split along mesh dim)
-    dispatch_tensor, combine_tensor, loss = _top_2_gating(
-        inputs=inputs,
-        outer_expert_dims=None,
-        experts_dim=experts_dim_unsplit,
-        expert_capacity_dim=expert_capacity_dim,
-        hparams=hparams,
-        train=train,
-        variable_dtype=variable_dtype,
-        importance=nonpadding,
-        num_microbatches=num_microbatches)
-  elif hparams.moe_gating == "switch":
-    dispatch_tensor, combine_tensor, loss = _switch_gating(
-        inputs=inputs,
-        outer_expert_dims=None,
-        experts_dim=experts_dim_unsplit,
-        expert_capacity_dim=expert_capacity_dim,
-        hparams=hparams,
-        train=train,
-        variable_dtype=variable_dtype,
-        importance=nonpadding,
-        num_microbatches=num_microbatches)
-  elif hparams.moe_gating == "ntlb":
-    dispatch_tensor, combine_tensor, loss = _ntlb_gating(
-        inputs=inputs,
-        outer_expert_dims=None,
-        experts_dim=experts_dim_unsplit,
-        expert_capacity_dim=expert_capacity_dim,
-        hparams=hparams,
-        train=train,
-        variable_dtype=variable_dtype,
-        importance=nonpadding,
-        num_microbatches=num_microbatches)
-  elif hparams.moe_gating == "switch_max":
-    dispatch_tensor, combine_tensor, loss = _switch_max_gating(
-        inputs=inputs,
-        outer_expert_dims=None,
-        experts_dim=experts_dim_unsplit,
-        expert_capacity_dim=expert_capacity_dim,
-        hparams=hparams,
-        train=train,
-        variable_dtype=variable_dtype,
-        importance=nonpadding,
-        num_microbatches=num_microbatches)
-  elif hparams.moe_gating == "expert_selection":
-    dispatch_tensor, combine_tensor, loss = _expert_selection_gating(
-        inputs=inputs,
-        outer_expert_dims=None,
-        experts_dim=experts_dim_unsplit,
-        group_size_dim=group_size_dim,
-        expert_capacity_dim=expert_capacity_dim,
-        hparams=hparams,
-        train=train,
-        variable_dtype=variable_dtype,
-        importance=nonpadding,
-        name="expert_selection_gating",
-        num_microbatches=num_microbatches)
-  else:
-    raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

-  expert_inputs = mtf.einsum([inputs, dispatch_tensor],
-                             mtf.Shape([
-                                 outer_batch_dim, experts_dim_unsplit,
-                                 num_groups_dim, expert_capacity_dim, input_dim
-                             ]))
+  # [outer_batch_dim, num_groups_dim.B, group_size_dim,
+  #  experts_dim_unsplit, expert_capacity_dim]
+  gating_fn = get_gating_fn(hparams.moe_gating)
+  dispatch_tensor, combine_tensor, loss = gating_fn(
+      inputs=inputs,
+      outer_expert_dims=None,
+      experts_dim=experts_dim_unsplit,
+      expert_capacity_dim=expert_capacity_dim,
+      hparams=hparams,
+      train=train,
+      variable_dtype=variable_dtype,
+      importance=nonpadding,
+      num_microbatches=num_microbatches)
+
+  # Dispatch to the experts by reducing group_size_dim.
+  # inputs: [outer_batch_dim, num_groups_dim.B, group_size_dim, input_dim]
+  # dispatch_tensor: [outer_batch_dim, num_groups_dim.B, group_size_dim,
+  #                   experts_dim_unsplit, expert_capacity_dim]
+  # expert_inputs: [outer_batch_dim, experts_dim_unsplit, num_groups_dim.B,
+  #                 expert_capacity_dim, input_dim]
+  expert_inputs_shape = [
+      outer_batch_dim, experts_dim_unsplit, num_groups_dim,
+      expert_capacity_dim, input_dim]
+  expert_inputs = mtf.einsum([inputs, dispatch_tensor], expert_inputs_shape)

+  # Split over batch -> split over experts
   # Extra reshape reduces communication cost for model-parallel versions.
   # For model-parallel versions, this reshape causes an mtf.slice and for non-
   # model-parallel versions, this has no effect.
+  # expert_inputs: [outer_batch_dim, experts_dim.B, batch_dim_unsplit,
+  #                 expert_capacity_dim, input_dim or input_dim.M]
   d_model_split_dim = mtf.Dimension("d_model_split", input_dim.size)
-  expert_inputs = mtf.reshape(
-      expert_inputs,
-      mtf.Shape([
-          outer_batch_dim, experts_dim, batch_dim_unsplit, expert_capacity_dim,
-          d_model_split_dim
-      ]))
-
-  # Split over batch -> split over experts
-  expert_inputs = mtf.reshape(
-      expert_inputs,
-      mtf.Shape([
-          outer_batch_dim, experts_dim, batch_dim_unsplit, expert_capacity_dim,
-          input_dim
-      ]))
-
-  # Now feed the expert inputs through the experts.
-  h = mtf.layers.dense_product(
-      expert_inputs,
-      reduced_dims=expert_inputs.shape.dims[-1:],
-      new_dims=[hidden_dim],
-      expert_dims=[experts_dim],
-      activation_functions=activation, use_bias=False,
-      variable_dtype=variable_dtype, name="wi")
-
-  if hparams.moe_dropout_rate != 0.0:
-    h = mtf.dropout(h, is_training=train,
-                    keep_prob=1.0 - hparams.moe_dropout_rate)
-
-  def _compute_output(hidden, layer_name):
-    """Compute the output of the attention layer from the hidden vector."""
+  expert_inputs_shape = [
+      outer_batch_dim, experts_dim, batch_dim_unsplit,
+      expert_capacity_dim, d_model_split_dim]
+  expert_inputs = mtf.reshape(expert_inputs, expert_inputs_shape)
+
+  expert_inputs_shape = [
+      outer_batch_dim, experts_dim, batch_dim_unsplit,
+      expert_capacity_dim, input_dim]
+  expert_inputs = mtf.reshape(expert_inputs, expert_inputs_shape)
+
+  def _apply_experts(x, output_dim, hidden_dim):
+    # x: [outer_batch_dim, experts_dim.B, batch_dim_unsplit,
+    #     expert_capacity_dim, input_dim]
+    h = mtf.layers.dense_product(
+        x,
+        reduced_dims=x.shape.dims[-1:],
+        new_dims=[hidden_dim],
+        expert_dims=[experts_dim],
+        activation_functions=activation, use_bias=False,
+        variable_dtype=variable_dtype, name="wi")
+
+    if hparams.moe_dropout_rate != 0.0:
+      h = mtf.dropout(h, is_training=train,
+                      keep_prob=1.0 - hparams.moe_dropout_rate)
     expert_output = mtf.layers.dense(
-        hidden, output_dim, expert_dims=[experts_dim], use_bias=False,
-        reduced_dims=hidden.shape.dims[-1:], variable_dtype=variable_dtype,
-        name=layer_name)
-
-    # Extra reshape reduces communication cost for model-parallel versions.
-    # For model-parallel versions, this reshape causes an mtf.slice and for non-
-    # model-parallel versions, this has no effect.
-    expert_output = mtf.reshape(
-        expert_output,
-        mtf.Shape([
-            outer_batch_dim, experts_dim_unsplit, num_groups_dim,
-            expert_capacity_dim, d_model_split_dim
-        ]))
-
-    # Split over experts -> split over batch
+        h, output_dim, expert_dims=[experts_dim], use_bias=False,
+        reduced_dims=h.shape.dims[-1:], variable_dtype=variable_dtype,
+        name="wo")
+
+    return expert_output
+
+  if split_hidden_after_routing:
+    input_dim = mtf.Dimension(
+        input_dim.name, input_dim.size // hparams.moe_num_hidden_splits)
+    expert_inputs = mtf.reshape(
+        expert_inputs, expert_inputs.shape[:-1] + [num_splits_dim, input_dim])
+    expert_output = _apply_experts(expert_inputs, sub_output_dim, hidden_dim)
+    # Concat sub_tokens into tokens.
     expert_output = mtf.reshape(
-        expert_output,
-        mtf.Shape([
-            outer_batch_dim,
-            experts_dim_unsplit,
-            num_groups_dim,
-            expert_capacity_dim,
-            output_dim,
-        ]))
-    moe_output_dims = moe_input_dims[:-1] + [output_dim]
-    output = mtf.einsum([expert_output, combine_tensor],
-                        mtf.Shape(moe_output_dims))
-    output = mtf.reshape(output, batch_and_length_dims + [output_dim])
-    return output
-
-  if hparams.moe_use_experts_attention:
-    # We share k_h and v_h with no degradation in performance
-    q_h, k_h = h, h
-    outputs = []
-    q = _compute_output(q_h, layer_name="q_wo")
-    k = _compute_output(k_h, layer_name="k_wo")
-    outputs.append(q)
-    outputs.append(k)
-    return outputs, loss * hparams.moe_loss_coef
+        expert_output, expert_output.shape[:-2] + [output_dim])
+  elif split_hidden_before_routing:
+    expert_output = _apply_experts(expert_inputs, sub_output_dim, hidden_dim)
   else:
-    output = _compute_output(h, layer_name="wo")
-    return output, loss * hparams.moe_loss_coef
+    expert_output = _apply_experts(expert_inputs, output_dim, hidden_dim)
+
+  # Extra reshape reduces communication cost for model-parallel versions.
+  # For model-parallel versions, this reshape causes an mtf.slice and for non-
+  # model-parallel versions, this has no effect.
+  expert_output_shape = [
+      outer_batch_dim, experts_dim_unsplit, num_groups_dim,
+      expert_capacity_dim, d_model_split_dim]
+  expert_output = mtf.reshape(expert_output, expert_output_shape)
+
+  # Split over experts -> split over batch
+  expert_output_shape = [
+      outer_batch_dim, experts_dim_unsplit, num_groups_dim,
+      expert_capacity_dim, expert_output.shape[-1]]
+  expert_output = mtf.reshape(expert_output, expert_output_shape)
+
+  # Combine by reducing experts_dim_unsplit and expert_capacity_dim.
+  # expert_output: [outer_batch_dim, experts_dim_unsplit, num_groups_dim,
+  #                 expert_capacity_dim, output_dim]
+  # combine_tensor: [outer_batch_dim, num_groups_dim.B, group_size_dim,
+  #                  experts_dim_unsplit, expert_capacity_dim]
+  # output: [outer_batch_dim, num_groups_dim.B, group_size_dim, output_dim]
+  moe_output_dims = moe_input_dims[:-1] + [expert_output.shape[-1]]
+  output = mtf.einsum([expert_output, combine_tensor], moe_output_dims)
+  # import pdb; pdb.set_trace()  # pylint:disable=g-import-not-at-top
+
+  if split_hidden_before_routing:
+    output = mtf.reshape(
+        output, [output.shape[0], orig_num_groups_dim, num_splits_dim] + (
+            output.shape[-2:]))
+    output = mtf.transpose(
+        output, output.shape[:2] + [
+            group_size_dim, num_splits_dim, output.shape[-1]])
+    output = mtf.reshape(output, output.shape[:3] + [output_dim])
+
+  output = mtf.reshape(output, batch_and_length_dims + [output_dim])
+
+  return output, loss * hparams.moe_loss_coef


 def transformer_moe_layer_v2(
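As a reference for the dispatch/combine einsums used above, here is a toy NumPy version (illustrative only: the group and batch dimensions are dropped, the dispatch tensor is a hand-built one-hot, and in the real layer combine_tensor also carries the gate probabilities).

import numpy as np

# g = group_size, e = experts, c = expert_capacity, m = d_model.
g, e, c, m = 4, 2, 2, 3
rng = np.random.default_rng(0)
inputs = rng.normal(size=(g, m))

# One-hot dispatch: token t goes to expert t % e, capacity slot t // e.
dispatch = np.zeros((g, e, c))
for t in range(g):
  dispatch[t, t % e, t // e] = 1.0

# Dispatch to the experts by reducing the group_size dimension.
expert_inputs = np.einsum("gm,gec->ecm", inputs, dispatch)

# Stand-in for the expert feed-forward network.
expert_outputs = 2.0 * expert_inputs

# Combine by reducing the expert and expert_capacity dimensions.
outputs = np.einsum("ecm,gec->gm", expert_outputs, dispatch)
assert np.allclose(outputs, 2.0 * inputs)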
@@ -801,6 +828,22 @@ def transformer_moe_layer_v2(
   return output, (loss_outer + loss_inner) * hparams.moe_loss_coef


+def get_gating_fn(moe_gating):
+  """Factory for gating functions."""
+  if moe_gating == "top_2":
+    return _top_2_gating
+  elif moe_gating == "switch":
+    return _switch_gating
+  elif moe_gating == "ntlb":
+    return _ntlb_gating
+  elif moe_gating == "switch_max":
+    return _switch_max_gating
+  elif moe_gating == "expert_selection":
+    return _expert_selection_gating
+  else:
+    raise ValueError("unknown hparams.moe_gating=%s" % moe_gating)
+
+
 def _ntlb_gating(inputs,
                  outer_expert_dims,
                  experts_dim,