diff --git a/mesh_tensorflow/transformer/attention.py b/mesh_tensorflow/transformer/attention.py
index 80a3fe0b..005a21e9 100644
--- a/mesh_tensorflow/transformer/attention.py
+++ b/mesh_tensorflow/transformer/attention.py
@@ -94,16 +94,20 @@ def attention(q,
           "Dividing attention z-loss loss by num_microbatches={}".format(
               context.num_microbatches))
       z_loss /= context.num_microbatches
-    if context.train:
-      mtf.scalar_summary("attention_z_loss", z_loss)
+    mtf.scalar_summary("attention_z_loss", z_loss)
     z_loss *= z_loss_coeff
     context.losses.append(mtf.cast(z_loss, v.dtype))
   weights = mtf.softmax(logits, memory_length_dim, extra_logit=extra_logit)
   weights = mtf.cast(weights, v.dtype)
-  weights = mtf.dropout(
-      weights, context.train, 1.0 - dropout_rate,
-      noise_shape=weights.shape - dropout_broadcast_dims)
+  if context:
+    weights = mtf.dropout(
+        weights, context.train, 1.0 - dropout_rate,
+        noise_shape=weights.shape - dropout_broadcast_dims)
+  else:
+    weights = mtf.dropout(
+        weights, False, 1.0 - dropout_rate,
+        noise_shape=weights.shape - dropout_broadcast_dims)
   outputs_shape = q.shape - key_dim + value_dim
   outputs = mtf.einsum([weights, v], outputs_shape)
   outputs = mtf.reshape(outputs, orig_q_shape - key_dim + value_dim)
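
A minimal, dependency-free sketch of the guard this hunk introduces (Context, dropout, and attention_dropout below are hypothetical stand-ins, not the mesh_tensorflow API): context.train is only dereferenced when a context object is actually supplied, and the call falls back to inference-time behaviour otherwise.

    from dataclasses import dataclass
    from typing import Optional, Sequence
    import random

    @dataclass
    class Context:
        train: bool = True  # stand-in for the Context.train flag read in the diff

    def dropout(x: Sequence[float], is_training: bool, keep_prob: float):
        # Toy dropout: zero each entry with probability 1 - keep_prob while training.
        if not is_training or keep_prob >= 1.0:
            return list(x)
        return [v if random.random() < keep_prob else 0.0 for v in x]

    def attention_dropout(weights, context: Optional[Context], dropout_rate: float):
        # Mirrors the diff: read context.train only when a context is present.
        if context:
            return dropout(weights, context.train, 1.0 - dropout_rate)
        return dropout(weights, False, 1.0 - dropout_rate)

    print(attention_dropout([0.2, 0.8], None, 0.1))       # no context -> no dropout
    print(attention_dropout([0.2, 0.8], Context(), 0.1))  # training -> stochastic

The same fallback could also be written as a single call with is_training set to (context.train if context else False); the diff keeps the two explicit branches.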