This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit 8aa9b39

nshazeer authored and Copybara-Service committed
Mesh-TensorFlow:
- Remove redefined-builtins from ops.py and move them to __init__.py
- Add a few potentially useful operations: mtf.sign, mtf.abs, mtf.layers.sigmoid_cross_entropy_with_logits

mtf Transformer implementation:
- Remove logit-jittering and replace it with "z_loss", which seems to work better.
- Hard-code the broadcast dimensions for the dropout layers.

PiperOrigin-RevId: 224581601
1 parent 8b721b0 commit 8aa9b39

9 files changed, +134 −50 lines changed

mesh_tensorflow/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -26,7 +26,8 @@
 from mesh_tensorflow import simd_mesh_impl
 from mesh_tensorflow import tpu_variables
 from mesh_tensorflow import utils
-from mesh_tensorflow.ops import *  # pylint: disable=wildcard-import
+from mesh_tensorflow.ops_with_redefined_builtins import *  # pylint: disable=wildcard-import
+
 
 # TODO(trandustin): Seal module.
 # from tensorflow.python.util.all_util import remove_undocumented  # pylint: disable=line-too-long

mesh_tensorflow/layers.py

Lines changed: 38 additions & 7 deletions
@@ -19,7 +19,7 @@
 from __future__ import division
 from __future__ import print_function
 
-from mesh_tensorflow import ops as mtf
+from mesh_tensorflow import ops_with_redefined_builtins as mtf
 
 import tensorflow as tf
 
@@ -51,9 +51,6 @@ def dense(x, output_dim, reduced_dims=None, expert_dims=None,
   """
   if variable_dtype is None:
     variable_dtype = mtf.VariableDType(master_dtype, slice_dtype, x.dtype)
-  if variable_dtype.activation_dtype != x.dtype:
-    raise ValueError("variable_dtype.activation_dtype must match x.dtype "
-                     "variable_dtype=%s x=%s" % (variable_dtype, x))
   if expert_dims is None:
     expert_dims = []
   if reduced_dims is None:
@@ -70,6 +67,7 @@ def dense(x, output_dim, reduced_dims=None, expert_dims=None,
       w_shape,
       initializer=tf.random_normal_initializer(stddev=stddev),
       dtype=variable_dtype)
+  w = mtf.cast(w, x.dtype)
   y = mtf.einsum([x, w], output_shape)
   if use_bias:
     b = mtf.get_variable(
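The dense() hunks above replace a hard dtype check with a cast: the weight can now be stored in one dtype (say float32) and cast to the activation dtype before the einsum. A minimal sketch of that pattern in plain TF1-style TensorFlow (illustrative only, not part of the commit; shapes are hypothetical):

```python
import tensorflow as tf

# Keep the master weight in float32 for stable updates, but cast it to the
# activation dtype (here bfloat16) before the matmul, as dense() now does.
x = tf.cast(tf.random_normal([8, 512]), tf.bfloat16)     # activations
w = tf.get_variable("w", [512, 1024], dtype=tf.float32)  # master weight
y = tf.matmul(x, tf.cast(w, x.dtype))                    # compute in bfloat16
```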
@@ -186,13 +184,20 @@ def batch_norm(x, is_training, momentum, epsilon=1e-9,
   return (norm_x * scale) + bias
 
 
-def softmax_cross_entropy_with_logits(logits, targets, vocab_dim):
+def softmax_cross_entropy_with_logits(logits, targets, vocab_dim, z_loss=0.0):
   """Per-example softmax loss.
 
+  if z_loss is nonzero, we add a loss equal to z_loss*log(z)^2, where z is the
+  partition function.  Example value: z_loss=1e-4.  Two uses of z_loss are:
+  - To keep the logits from drifting too far from zero, which can cause
+    unacceptable roundoff errors in bfloat16.
+  - To encourage the logits to be normalized log-probabilities.
+
   Args:
     logits: a mtf.Tensor whose shape contains vocab_dim
     targets: a mtf.Tensor with the same shape as logits
     vocab_dim: a mtf.Dimension
+    z_loss: a float
 
   Returns:
     a mtf.Tensor whose shape is equal to logits.shape - vocab_dim
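The docstring change above defines the extra term; the hunk below implements it. As an illustrative NumPy sketch (not part of the commit) of the full computation for a single example:

```python
import numpy as np

def softmax_xent_with_z_loss(logits, targets, z_loss=1e-4):
  # log(z) is the log partition function (logsumexp of the logits).
  log_z = np.log(np.sum(np.exp(logits)))
  log_softmax = logits - log_z             # normalized log-probabilities
  loss = -np.sum(log_softmax * targets)    # ordinary cross-entropy
  return loss + z_loss * log_z ** 2        # penalty pulls log(z) toward 0

logits = np.array([2.0, -1.0, 0.5])
targets = np.array([1.0, 0.0, 0.0])        # one-hot target
print(softmax_xent_with_z_loss(logits, targets))
```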
@@ -206,9 +211,35 @@ def softmax_cross_entropy_with_logits(logits, targets, vocab_dim):
         "logits=%s targets=%s" % (logits.to_string, targets.to_string))
   if vocab_dim not in logits.shape.dims:
     raise ValueError("vocab_dim must be in logits.shape.dims")
-  log_softmax = mtf.log_softmax(logits, vocab_dim)
-  return mtf.negative(
+  log_z = mtf.reduce_logsumexp(logits, vocab_dim)
+  log_softmax = logits - log_z
+  loss = mtf.negative(
       mtf.reduce_sum(log_softmax * targets, reduced_dim=vocab_dim))
+  if z_loss != 0:
+    loss += z_loss * mtf.square(log_z)
+  return loss
+
+
+def sigmoid_cross_entropy_with_logits(logits, targets):
+  """Sigmoid cross-entropy loss.
+
+  Args:
+    logits: a mtf.Tensor
+    targets: a mtf.Tensor with the same shape as logits
+
+  Returns:
+    a mtf.Tensor whose shape is equal to logits.shape
+
+  Raises:
+    ValueError: if the shapes do not match.
+  """
+  if logits.shape != targets.shape:
+    raise ValueError(
+        "logits shape must equal targets shape"
+        "logits=%s targets=%s" % (logits.to_string, targets.to_string))
+  x = logits
+  z = targets
+  return mtf.relu(x) - x * z + mtf.log(1 + mtf.exp(-mtf.abs(x)))
 
 
 def weights_nonzero(targets, dtype=tf.float32):
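The returned expression is the standard numerically stable form of sigmoid cross-entropy (the same one tf.nn.sigmoid_cross_entropy_with_logits uses): for logits x and targets z, max(x, 0) - x*z + log(1 + exp(-|x|)). A quick NumPy check (illustrative only, not part of the commit) that it matches the naive definition for moderate x:

```python
import numpy as np

def stable_form(x, z):
  # max(x, 0) - x*z + log1p(exp(-|x|)), as in the mtf code above
  return np.maximum(x, 0.0) - x * z + np.log1p(np.exp(-np.abs(x)))

def naive_form(x, z):
  s = 1.0 / (1.0 + np.exp(-x))               # sigmoid
  return -z * np.log(s) - (1.0 - z) * np.log(1.0 - s)

x = np.array([-3.0, -0.5, 0.0, 2.0])
z = np.array([0.0, 1.0, 0.5, 1.0])
print(np.allclose(stable_form(x, z), naive_form(x, z)))  # True
```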

mesh_tensorflow/ops.py

Lines changed: 36 additions & 12 deletions
@@ -1420,6 +1420,7 @@ def _square_grad(op, dy):
       output_dtype: a dtype
       splittable_dims: a list of Dimensions which are ok to split
       grad_function: an optional python function. Default to using tf.gradients
+        pass in the number 0 to indicate no gradient
       name: an optional string
     """
     super(SlicewiseOperation, self).__init__(inputs, name=name or "slicewise")
@@ -1428,6 +1429,12 @@ def _square_grad(op, dy):
     self._splittable_dims = splittable_dims
     self._grad_function = grad_function
 
+  @property
+  def has_gradient(self):
+    if self._grad_function == 0:
+      return False
+    return super(SlicewiseOperation, self).has_gradient
+
   def gradient(self, grad_ys):
     if self._grad_function is not None:
       return self._grad_function(self, grad_ys[0])
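With the new has_gradient property, passing grad_function=0 marks a slicewise op as non-differentiable outright, rather than falling back to tf.gradients. A hedged sketch of how a caller might use this (illustrative only, not part of the commit):

```python
import mesh_tensorflow as mtf
import tensorflow as tf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "my_mesh")
dim = mtf.Dimension("d", 4)
x = mtf.constant(mesh, 1.5, shape=mtf.Shape([dim]), dtype=tf.float32)

# grad_function=0: no gradient at all, instead of the tf.gradients fallback.
rounded = mtf.cwise(tf.round, [x], name="round", grad_function=0)
print(rounded.operation.has_gradient)  # False
```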
@@ -1547,7 +1554,8 @@ def grad_function(op, dy):
   return cwise(tf.tanh, [x], name=name, grad_function=grad_function)
 
 
-def pow(x, y):  # pylint: disable=redefined-builtin
+def mtf_pow(x, y):
+  """Call externally as mtf.pow()."""
   return exp(log(x) * y)
 
 
@@ -1574,6 +1582,16 @@ def relu(x, name="relu"):
   return cwise(tf.nn.relu, [x], name=name, grad_function=_relu_grad)
 
 
+def sign(x, name="sign"):
+  ret = cwise(tf.sign, [x], name=name, grad_function=0)
+  return ret
+
+
+def mtf_abs(x):
+  """Call externally as mtf.abs()."""
+  return x * sign(x)
+
+
 def cast(x, dtype, name="cast"):
   if dtype == x.dtype:
     return x
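Note how mtf_abs is built: sign() is declared gradient-free (grad_function=0), so differentiating x * sign(x) treats sign(x) as a constant, and the gradient of abs comes out as sign(x), the usual subgradient. A one-line NumPy check of the forward identity (illustrative only):

```python
import numpy as np

x = np.array([-2.5, 0.0, 3.0])
print(np.allclose(x * np.sign(x), np.abs(x)))  # True: |x| == x * sign(x)
```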
@@ -2174,8 +2192,8 @@ def cumsum(x, dim, exclusive=False):
   new_shape = x.shape.rename_dimension(dim.name, new_name)
   comparator = less if exclusive else less_equal
   m = cast(
-      comparator(range(x.mesh, dim, dtype=tf.float32),
-                 range(x.mesh, new_dim, dtype=tf.float32)), x.dtype)
+      comparator(mtf_range(x.mesh, dim, dtype=tf.float32),
+                 mtf_range(x.mesh, new_dim, dtype=tf.float32)), x.dtype)
   ret = einsum([x, m], output_shape=new_shape)
   return reshape(ret, x.shape)
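This cumsum is a masked contraction: a 0/1 comparator matrix m[i, j] = 1 where i <= j (strictly i < j when exclusive) turns the einsum into a running sum over the renamed dimension. A NumPy sketch of the same trick (illustrative only, not part of the commit):

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0])
i = np.arange(len(x))
m = (i[:, None] <= i[None, :]).astype(x.dtype)  # m[i, j] = 1 iff i <= j

print(x @ m)                                    # [ 1.  3.  6. 10.]
print(np.allclose(x @ m, np.cumsum(x)))         # True
```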

@@ -3577,7 +3595,7 @@ def top_1(x, reduced_dim, dtype=tf.int32, name=None):
   with tf.name_scope(name, default_name="top_1"):
     max_val = reduce_max(x, reduced_dim=reduced_dim)
     is_max = to_float(equal(x, max_val))
-    pos = range(x.mesh, reduced_dim, tf.float32)
+    pos = mtf_range(x.mesh, reduced_dim, tf.float32)
     ret = reduce_max(is_max * pos, reduced_dim=reduced_dim)
     ret = cast(ret, dtype)
     return ret, max_val
@@ -3717,9 +3735,11 @@ def divide(x1, x2, output_shape=None, name=None):
   return multiply(x1, reciprocal(x2), output_shape=output_shape)
 
 
-def slice(x, begin, size, slice_dim_name, name=None):  # pylint: disable=redefined-builtin
+def mtf_slice(x, begin, size, slice_dim_name, name=None):
   """Slice operation.
 
+  Call externally as mtf.slice()
+
   Args:
     x: a list of Tensors
     begin: integer, where to begin slicing from along the axis
@@ -3754,7 +3774,7 @@ def one_hot(indices, output_dim, on_value=1.0,
 
   TODO(noam): Is there a good reason we need a special mtf.Operation here?
   We could just use some code like this:
-  cast(equal(indices, range(indices.mesh, output_dim, dtype=indices.dtype)),
+  cast(equal(indices, mtf_range(indices.mesh, output_dim, dtype=indices.dtype)),
        dtype)
 
   Args:
@@ -4067,9 +4087,11 @@ def softmax(x, reduced_dim, extra_logit=None, name=None):
   return exp(log_softmax(x, reduced_dim, extra_logit=extra_logit))
 
 
-def range(mesh, dim, dtype, name=None):  # pylint: disable=redefined-builtin
+def mtf_range(mesh, dim, dtype, name=None):
   """Create a 1d mesh tensor with a range from [0, dim.size).
 
+  Call externally as mtf.range()
+
   Args:
     mesh: a Mesh
     dim: a Dimension
@@ -4563,9 +4585,10 @@ def halo_exchange(x, blocks_dim, block_size_dim, halo_size, wrap=False):
     parts = ([shift(x, i, blocks_dim, wrap)] + parts +
              [shift(x, -i, blocks_dim, wrap)])
   if partial_size > 0:
-    left_margin = slice(x, 0, partial_size, block_size_dim.name)
-    right_margin = slice(x, block_size_dim.size - partial_size, partial_size,
-                         block_size_dim.name)
+    left_margin = mtf_slice(x, 0, partial_size, block_size_dim.name)
+    right_margin = mtf_slice(
+        x, block_size_dim.size - partial_size, partial_size,
+        block_size_dim.name)
     parts = (
         [shift(right_margin, num_complete_blocks + 1, blocks_dim, wrap)]
         + parts +
@@ -4600,8 +4623,9 @@ def left_halo_exchange(x, blocks_dim, block_size_dim, halo_size, wrap=False):
   for i in xrange(1, num_complete_blocks + 1):
     parts = ([shift(x, i, blocks_dim, wrap)] + parts)
   if partial_size > 0:
-    right_margin = slice(x, block_size_dim.size - partial_size, partial_size,
-                         block_size_dim.name)
+    right_margin = mtf_slice(
+        x, block_size_dim.size - partial_size, partial_size,
+        block_size_dim.name)
     parts = ([shift(right_margin, num_complete_blocks + 1, blocks_dim, wrap)]
              + parts)
   return concat(parts, block_size_dim.name)
mesh_tensorflow/ops_with_redefined_builtins.py (new file)

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+# coding=utf-8
+# Copyright 2018 The Mesh TensorFlow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Mesh TensorFlow."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from mesh_tensorflow.ops import *  # pylint: disable=wildcard-import
+from mesh_tensorflow.ops import mtf_abs as abs  # pylint: disable=redefined-builtin,unused-import
+from mesh_tensorflow.ops import mtf_pow as pow  # pylint: disable=redefined-builtin,unused-import
+from mesh_tensorflow.ops import mtf_range as range  # pylint: disable=redefined-builtin,unused-import
+from mesh_tensorflow.ops import mtf_slice as slice  # pylint: disable=redefined-builtin,unused-import
+
+
+
+# TODO(trandustin): Seal module.
+# from tensorflow.python.util.all_util import remove_undocumented  # pylint: disable=line-too-long
+#
+# _allowed_symbols = None
+#
+# remove_undocumented(__name__, _allowed_symbols)
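The net effect: ops.py no longer shadows Python builtins, while the public mtf namespace (re-exported through __init__.py) keeps the short names. A hedged usage sketch (illustrative only, not part of the commit):

```python
import mesh_tensorflow as mtf
import tensorflow as tf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "my_mesh")
length = mtf.Dimension("length", 8)

pos = mtf.range(mesh, length, dtype=tf.float32)  # resolves to ops.mtf_range
sq = mtf.pow(pos + 1.0, 2.0)   # ops.mtf_pow is exp(log(x) * y), so x > 0
mag = mtf.abs(sq - 10.0)       # ops.mtf_abs, i.e. x * sign(x)
head = mtf.slice(pos, 0, 4, "length")            # resolves to ops.mtf_slice
```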

mesh_tensorflow/optimize.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@
 from __future__ import division
 from __future__ import print_function
 
-from mesh_tensorflow import ops as mtf
+from mesh_tensorflow import ops_with_redefined_builtins as mtf
 import tensorflow as tf
 
 
mesh_tensorflow/placement_mesh_impl.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@
 
 import functools
 
-from mesh_tensorflow import ops as mtf
+from mesh_tensorflow import ops_with_redefined_builtins as mtf
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
 import tensorflow as tf

mesh_tensorflow/simd_mesh_impl.py

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@
 from __future__ import division
 from __future__ import print_function
 
-from mesh_tensorflow import ops as mtf
+from mesh_tensorflow import ops_with_redefined_builtins as mtf
 from mesh_tensorflow import tpu_variables
 from mesh_tensorflow import utils
 from six.moves import xrange  # pylint: disable=redefined-builtin

mesh_tensorflow/transformer/transformer.py

Lines changed: 12 additions & 10 deletions
@@ -323,6 +323,7 @@ def __init__(self,
                max_length,
                shared_embedding_and_softmax_weights=False,
                label_smoothing=0.0,
+               z_loss=1e-4,
                name="transformer"):
     self.layer_stack = layer_stack
     self.model_dim = mtf.Dimension("d_model", d_model)
@@ -338,6 +339,7 @@ def __init__(self,
     self.shared_embedding_and_softmax_weights = (
         shared_embedding_and_softmax_weights)
     self.label_smoothing = label_smoothing
+    self.z_loss = z_loss
     self.name = name
 
   def _call_internal(self, context, inputs, targets=None):
@@ -381,27 +383,25 @@ def _call_internal(self, context, inputs, targets=None):
     if self.output_vocab_dim is None:
       return x
     if self.shared_embedding_and_softmax_weights:
-      logits = tf.einsum(
+      logits = mtf.einsum(
           [x * (self.model_dim ** -0.5), embedding_weights],
           reduced_dims=[self.model_dim])
     else:
       logits = mtf.layers.dense(
           x, self.output_vocab_dim, use_bias=False,
           variable_dtype=context.variable_dtype,
          name="logits")
-    if context.train:
-      logits = mtf.layers.multiplicative_jitter(logits, epsilon=1e-2)
     if targets is not None and context.losses is not None:
       off_value = self.label_smoothing / self.output_vocab_dim.size
       on_value = 1.0 - self.label_smoothing + off_value
+      soft_targets = mtf.one_hot(
+          targets, self.output_vocab_dim,
+          dtype=context.activation_dtype,
+          on_value=on_value,
+          off_value=off_value)
       loss = mtf.layers.softmax_cross_entropy_with_logits(
-          logits,
-          mtf.one_hot(
-              targets, self.output_vocab_dim,
-              dtype=context.activation_dtype,
-              on_value=on_value,
-              off_value=off_value),
-          self.output_vocab_dim)
+          logits, soft_targets, self.output_vocab_dim,
+          z_loss=self.z_loss if context.train else 0.0)
       weights = mtf.layers.weights_nonzero(
           targets, dtype=context.activation_dtype)
       loss = mtf.reduce_mean(loss * weights)
@@ -674,6 +674,7 @@ def __init__(self,
                max_length,
                shared_embedding=True,
                label_smoothing=0.0,
+               z_loss=1e-4,
                encoder_name="encoder",
                decoder_name="decoder"):
     self.encoder = Unitransformer(
@@ -692,6 +693,7 @@ def __init__(self,
         autoregressive=True,
         max_length=max_length,
         label_smoothing=label_smoothing,
+        z_loss=z_loss,
         name=decoder_name)
     self.shared_embedding = shared_embedding
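On the arithmetic in the loss hunk: label smoothing spreads off_value = smoothing / vocab_size over every id and puts on_value = 1 - smoothing + off_value on the target id, so the soft targets still sum to 1; z_loss is applied only when context.train is true. A small NumPy check (illustrative only, hypothetical sizes):

```python
import numpy as np

vocab_size = 8
label_smoothing = 0.1
off_value = label_smoothing / vocab_size
on_value = 1.0 - label_smoothing + off_value

soft_targets = np.full(vocab_size, off_value)
soft_targets[3] = on_value        # suppose the target id is 3
print(soft_targets.sum())         # 1.0 -- still a valid distribution
```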
