@@ -631,11 +631,10 @@ def __init__(self,
                combine_dims=True,
                ensemble_dim=None,
                keep_query_heads_dims=False,
-               fold_scaling_into_initializer=False,
+               fold_scaling_into_initializer=True,
                context=None,
                experts_hparams=None,
-               expert_computation="qkv",
-               is_encdec=False):
+               expert_computation="qkv"):
     super(ExpertsAttentionParams, self).__init__(
         mesh=mesh,
         query_input_dim=query_input_dim,
@@ -656,7 +655,6 @@ def __init__(self,
 
     self.context = context
     self.expert_computation = expert_computation
-    self.is_encdec = is_encdec
 
     # Unless we want to compute both q and kv, we can use the normal MoE
     # settings.
@@ -698,6 +696,9 @@ def __init__(self,
     # we want to partition both "experts_hidden" and "heads".
     moe_output_dims = mtf.Dimension("d_model", self.q_shape[-1].size)
 
+    tf.logging.info("ExpertsAttention moe_hidden_size: {}".format(
+        experts_hparams.hidden_size))
+    tf.logging.info("moe_output_dims: {}".format(moe_output_dims))
     self.moe_layer = mtf.transformer.moe.MoE1D(
         moe_gating=experts_hparams.moe_gating,
         num_experts=experts_hparams.num_experts,
@@ -718,70 +719,55 @@ def __init__(self,
         activation=experts_hparams.activation,
         z_loss=experts_hparams.z_loss)
 
-  def _replace_d_model_dim(self, t):
-    """Used to replace the `d_model` dim with `heads`."""
-    new_last_dim = mtf.Dimension(self.q_shape[-1].name, t.shape[-1].size)
-    return mtf.reshape(t, new_shape=mtf.Shape(t.shape[:-1] + [new_last_dim]))
-
-  def _compute_q_with_experts(self, antecedent):
-    q = self.moe_layer.call(self.context, antecedent)
-    q = self._replace_d_model_dim(q)
-    return q
-
-  def _compute_kv_with_experts(self, antecedent):
-    kv = self.moe_layer.call(
-        self.context, antecedent, use_enc_nonpadding=self.is_encdec)
-    kv = self._replace_d_model_dim(kv)
-    return kv
-
   def _compute_merge_qkv(self, antecedent):
     """Computes qkv all in one call using MoE layer."""
-    # This mode assumes query and memory antecedent are the same.
-    qkv = self.moe_layer.call(self.context, antecedent)
-    q, kv = qkv
-    q = self._replace_d_model_dim(q)
-    kv = self._replace_d_model_dim(kv)
-    self._q = q
-    self._kv = kv
-
-  def compute_q(self, query_antecedent):
+    def _replace_d_model_dim(t):
+      """Used to replace the `d_model` dim with `heads`."""
+      new_last_dim = mtf.Dimension(self.q_shape[-1].name, t.shape[-1].size)
+      return mtf.reshape(
+          t, new_shape=mtf.Shape(t.shape[:-1] + [new_last_dim]))
     if self.expert_computation == "qkv":
-      self._compute_merge_qkv(query_antecedent)
-      q = self._q
+      # NOTE: This assumes query and memory antecedent are the same.
+      qk = self.moe_layer.call(self.context, antecedent)
+      # Split qk here since it went through the expert layers.
+      q, k = qk
+      q = _replace_d_model_dim(q)
+      k = _replace_d_model_dim(k)
     elif self.expert_computation == "q":
-      q = self._compute_q_with_experts(query_antecedent)
-    # If computing "kv" with experts, then compute q normally.
+      q = self.moe_layer.call(self.context, antecedent)
+      q = _replace_d_model_dim(q)
+      # Compute key/value normally
+      k = mtf.layers.us_einsum(
+          [antecedent, self.wkv], reduced_dims=[self.memory_input_dim])
     elif self.expert_computation == "kv":
+      k = self.moe_layer.call(self.context, antecedent)
+      k = _replace_d_model_dim(k)
+      # Compute query normally
       q = mtf.layers.us_einsum(
-          [query_antecedent, self.wq], reduced_dims=[self.query_input_dim])
+          [antecedent, self.wq], reduced_dims=[self.query_input_dim])
+    else:
+      raise ValueError("Invalid expert computation mode: {}".format(
+          self.expert_computation))
+
+    # Scale query
     q *= self.key_dim.size ** -0.5
-    return mtf.replace_dimensions(q, q.shape.dims[-1], self.q_dims)
+    self._q = mtf.replace_dimensions(q, q.shape.dims[-1], self.q_dims)
+    self._k = mtf.replace_dimensions(k, k.shape.dims[-1], self.k_dims)
+
+  def compute_q(self, query_antecedent):
+    self._compute_merge_qkv(query_antecedent)
+    return self._q
 
   def compute_k(self, memory_antecedent):
-    raise NotImplementedError("ExpertsAttention uses shared_kv = True.")
+    del memory_antecedent
+    return self._k
 
   def compute_kv(self, memory_antecedent):
-    if self.expert_computation == "qkv":
-      # We have already computing "kv" with "q", so just return its value.
-      kv = self._kv
-      # Check if the "length" dimension should be "memory_length" since both
-      # q and kv were computed using the same antecedent. This is why we must
-      # always have the same query and memory antecedent for the qkv mode.
-      if self.context.length_dim in kv.shape.dims:
-        memory_length = mtf.Dimension(
-            "memory_length", self.context.length_dim.size)
-        kv = mtf.replace_dimensions(
-            kv, self.context.length_dim, memory_length)
-    # If computing "q" with experts, then compute "kv" normally.
-    elif self.expert_computation == "q":
-      kv = mtf.layers.us_einsum(
-          [memory_antecedent, self.wkv], reduced_dims=[self.memory_input_dim])
-    elif self.expert_computation == "kv":
-      kv = self._compute_kv_with_experts(memory_antecedent)
-    kv = mtf.replace_dimensions(kv, kv.shape.dims[-1], self.k_dims)
-    return kv
+    del memory_antecedent
+    return self._k
 
   def compute_v(self, memory_antecedent):
+    del memory_antecedent
     raise NotImplementedError("ExpertsAttention uses shared_kv = True.")
 
 
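Note on the change, with a hedged sketch: the rewritten code folds query and key computation into a single _compute_merge_qkv call that dispatches on expert_computation ("qkv", "q", or "kv") and caches the results, so compute_k and compute_kv simply return the cached key. Below is a minimal, framework-free illustration of that dispatch-and-cache pattern only; the class and helper names (MergedQKSketch, moe, project_q, project_k) are hypothetical stand-ins rather than the Mesh TensorFlow API, and all tensor-dimension bookkeeping (_replace_d_model_dim, mtf.replace_dimensions) is omitted.

class MergedQKSketch(object):
  """Simplified sketch of the q/k dispatch-and-cache pattern (not mtf code)."""

  def __init__(self, moe, project_q, project_k, key_dim_size,
               expert_computation="qkv"):
    self.moe = moe              # hypothetical expert layer: x -> output(s)
    self.project_q = project_q  # hypothetical dense query projection (wq)
    self.project_k = project_k  # hypothetical dense key/value projection (wkv)
    self.key_dim_size = key_dim_size
    self.expert_computation = expert_computation

  def _compute_merge_qk(self, antecedent):
    if self.expert_computation == "qkv":
      # Experts produce q and k jointly from the same antecedent.
      q, k = self.moe(antecedent)
    elif self.expert_computation == "q":
      q = self.moe(antecedent)
      k = self.project_k(antecedent)
    elif self.expert_computation == "kv":
      k = self.moe(antecedent)
      q = self.project_q(antecedent)
    else:
      raise ValueError("Invalid expert computation mode: {}".format(
          self.expert_computation))
    # Scale the query and cache both results for the accessor methods below.
    self._q = q * self.key_dim_size ** -0.5
    self._k = k

  def compute_q(self, query_antecedent):
    # Triggers the merged computation; query and memory antecedent must match.
    self._compute_merge_qk(query_antecedent)
    return self._q

  def compute_k(self, memory_antecedent):
    del memory_antecedent  # unused: k was cached during compute_q
    return self._k

Because the "qkv" mode produces q and k from one expert-layer call, the pattern only applies when the query and memory antecedent are the same tensor (self-attention), and values are shared with keys (shared_kv = True), which is why compute_v still raises NotImplementedError in the diff above.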