Skip to content
This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit 3016312

Browse files
William Fedus and the Mesh TensorFlow Team
authored and committed
Explicitly pass named-arg to mtf.dropout
PiperOrigin-RevId: 370775492
1 parent 0ae4a03 commit 3016312

File tree

1 file changed

+15
-8
lines changed
  • mesh_tensorflow/transformer

1 file changed

+15
-8
lines changed

mesh_tensorflow/transformer/moe.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -474,8 +474,9 @@ def transformer_moe_layer_v1(
474474
activation_functions=activation, use_bias=False,
475475
variable_dtype=variable_dtype, name="wi")
476476

477-
if train and hparams.moe_dropout_rate != 0.0:
478-
h = mtf.dropout(h, 1.0 - hparams.moe_dropout_rate)
477+
if hparams.moe_dropout_rate != 0.0:
478+
h = mtf.dropout(h, is_training=train,
479+
keep_prob=1.0 - hparams.moe_dropout_rate)
479480

480481
def _compute_output(hidden, layer_name):
481482
"""Compute the output of the attention layer from the hidden vector."""
@@ -957,8 +958,10 @@ def _switch_max_gating(
957958
gate_inputs = mtf.to_float(inputs)
958959

959960
# Input perturbations
960-
if train and policy == "input_dropout":
961-
gate_inputs = mtf.dropout(gate_inputs, 1.0 - hparams.moe_switch_dropout)
961+
if policy == "input_dropout":
962+
gate_inputs = mtf.dropout(
963+
gate_inputs, is_training=train,
964+
keep_prob=1.0 - hparams.moe_switch_dropout)
962965
elif train and policy == "input_jitter":
963966
gate_inputs = mtf.layers.multiplicative_jitter(gate_inputs,
964967
hparams.moe_switch_jitter)
@@ -1068,8 +1071,9 @@ def _expert_selection_gating(
10681071
gate_inputs = mtf.to_float(inputs)
10691072

10701073
# Input perturbations for exploration.
1071-
if train and policy == "input_dropout":
1072-
gate_inputs = mtf.dropout(gate_inputs, 1.0 - hparams.moe_switch_dropout)
1074+
if policy == "input_dropout":
1075+
gate_inputs = mtf.dropout(gate_inputs, is_training=train,
1076+
keep_prob=1.0 - hparams.moe_switch_dropout)
10731077
elif train and policy == "input_jitter":
10741078
gate_inputs = mtf.layers.multiplicative_jitter(gate_inputs,
10751079
hparams.moe_switch_jitter)
@@ -1185,8 +1189,11 @@ def _switch_gating(
11851189
gate_inputs = mtf.to_float(inputs)
11861190

11871191
# Input perturbations
1188-
if train and policy == "input_dropout":
1189-
gate_inputs = mtf.dropout(gate_inputs, 1.0 - hparams.moe_switch_dropout)
1192+
if policy == "input_dropout":
1193+
gate_inputs = mtf.dropout(
1194+
gate_inputs,
1195+
is_training=train,
1196+
keep_prob=1.0 - hparams.moe_switch_dropout)
11901197
elif train and policy == "input_jitter":
11911198
gate_inputs = mtf.layers.multiplicative_jitter(gate_inputs,
11921199
hparams.moe_switch_jitter)

0 commit comments

Comments (0)