@@ -474,8 +474,9 @@ def transformer_moe_layer_v1(
474474 activation_functions = activation , use_bias = False ,
475475 variable_dtype = variable_dtype , name = "wi" )
476476
477- if train and hparams .moe_dropout_rate != 0.0 :
478- h = mtf .dropout (h , 1.0 - hparams .moe_dropout_rate )
477+ if hparams .moe_dropout_rate != 0.0 :
478+ h = mtf .dropout (h , is_training = train ,
479+ keep_prob = 1.0 - hparams .moe_dropout_rate )
479480
480481 def _compute_output (hidden , layer_name ):
481482 """Compute the output of the attention layer from the hidden vector."""
@@ -957,8 +958,10 @@ def _switch_max_gating(
957958 gate_inputs = mtf .to_float (inputs )
958959
959960 # Input perturbations
960- if train and policy == "input_dropout" :
961- gate_inputs = mtf .dropout (gate_inputs , 1.0 - hparams .moe_switch_dropout )
961+ if policy == "input_dropout" :
962+ gate_inputs = mtf .dropout (
963+ gate_inputs , is_training = train ,
964+ keep_prob = 1.0 - hparams .moe_switch_dropout )
962965 elif train and policy == "input_jitter" :
963966 gate_inputs = mtf .layers .multiplicative_jitter (gate_inputs ,
964967 hparams .moe_switch_jitter )
@@ -1068,8 +1071,9 @@ def _expert_selection_gating(
10681071 gate_inputs = mtf .to_float (inputs )
10691072
10701073 # Input perturbations for exploration.
1071- if train and policy == "input_dropout" :
1072- gate_inputs = mtf .dropout (gate_inputs , 1.0 - hparams .moe_switch_dropout )
1074+ if policy == "input_dropout" :
1075+ gate_inputs = mtf .dropout (gate_inputs , is_training = train ,
1076+ keep_prob = 1.0 - hparams .moe_switch_dropout )
10731077 elif train and policy == "input_jitter" :
10741078 gate_inputs = mtf .layers .multiplicative_jitter (gate_inputs ,
10751079 hparams .moe_switch_jitter )
@@ -1185,8 +1189,11 @@ def _switch_gating(
11851189 gate_inputs = mtf .to_float (inputs )
11861190
11871191 # Input perturbations
1188- if train and policy == "input_dropout" :
1189- gate_inputs = mtf .dropout (gate_inputs , 1.0 - hparams .moe_switch_dropout )
1192+ if policy == "input_dropout" :
1193+ gate_inputs = mtf .dropout (
1194+ gate_inputs ,
1195+ is_training = train ,
1196+ keep_prob = 1.0 - hparams .moe_switch_dropout )
11901197 elif train and policy == "input_jitter" :
11911198 gate_inputs = mtf .layers .multiplicative_jitter (gate_inputs ,
11921199 hparams .moe_switch_jitter )
0 commit comments