@@ -1420,6 +1420,7 @@ def _square_grad(op, dy):
14201420 output_dtype: a dtype
14211421 splittable_dims: a list of Dimensions which are ok to split
14221422 grad_function: an optional python function. Default to using tf.gradients
1423+ pass in the number 0 to indicate no gradient
14231424 name: an optional string
14241425 """
14251426 super (SlicewiseOperation , self ).__init__ (inputs , name = name or "slicewise" )
@@ -1428,6 +1429,12 @@ def _square_grad(op, dy):
14281429 self ._splittable_dims = splittable_dims
14291430 self ._grad_function = grad_function
14301431
1432+ @property
1433+ def has_gradient (self ):
1434+ if self ._grad_function == 0 :
1435+ return False
1436+ return super (SlicewiseOperation , self ).has_gradient
1437+
14311438 def gradient (self , grad_ys ):
14321439 if self ._grad_function is not None :
14331440 return self ._grad_function (self , grad_ys [0 ])
@@ -1547,7 +1554,8 @@ def grad_function(op, dy):
15471554 return cwise (tf .tanh , [x ], name = name , grad_function = grad_function )
15481555
15491556
1550- def pow (x , y ): # pylint: disable=redefined-builtin
1557+ def mtf_pow (x , y ):
1558+ """Call externally as mtf.pow()."""
15511559 return exp (log (x ) * y )
15521560
15531561
@@ -1574,6 +1582,16 @@ def relu(x, name="relu"):
15741582 return cwise (tf .nn .relu , [x ], name = name , grad_function = _relu_grad )
15751583
15761584
1585+ def sign (x , name = "sign" ):
1586+ ret = cwise (tf .sign , [x ], name = name , grad_function = 0 )
1587+ return ret
1588+
1589+
1590+ def mtf_abs (x ):
1591+ """Call externally as mtf.abs()."""
1592+ return x * sign (x )
1593+
1594+
15771595def cast (x , dtype , name = "cast" ):
15781596 if dtype == x .dtype :
15791597 return x
@@ -2174,8 +2192,8 @@ def cumsum(x, dim, exclusive=False):
21742192 new_shape = x .shape .rename_dimension (dim .name , new_name )
21752193 comparator = less if exclusive else less_equal
21762194 m = cast (
2177- comparator (range (x .mesh , dim , dtype = tf .float32 ),
2178- range (x .mesh , new_dim , dtype = tf .float32 )), x .dtype )
2195+ comparator (mtf_range (x .mesh , dim , dtype = tf .float32 ),
2196+ mtf_range (x .mesh , new_dim , dtype = tf .float32 )), x .dtype )
21792197 ret = einsum ([x , m ], output_shape = new_shape )
21802198 return reshape (ret , x .shape )
21812199
@@ -3577,7 +3595,7 @@ def top_1(x, reduced_dim, dtype=tf.int32, name=None):
35773595 with tf .name_scope (name , default_name = "top_1" ):
35783596 max_val = reduce_max (x , reduced_dim = reduced_dim )
35793597 is_max = to_float (equal (x , max_val ))
3580- pos = range (x .mesh , reduced_dim , tf .float32 )
3598+ pos = mtf_range (x .mesh , reduced_dim , tf .float32 )
35813599 ret = reduce_max (is_max * pos , reduced_dim = reduced_dim )
35823600 ret = cast (ret , dtype )
35833601 return ret , max_val
@@ -3717,9 +3735,11 @@ def divide(x1, x2, output_shape=None, name=None):
37173735 return multiply (x1 , reciprocal (x2 ), output_shape = output_shape )
37183736
37193737
3720- def slice (x , begin , size , slice_dim_name , name = None ): # pylint: disable=redefined-builtin
3738+ def mtf_slice (x , begin , size , slice_dim_name , name = None ):
37213739 """Slice operation.
37223740
3741+ Call externally as mtf.slice()
3742+
37233743 Args:
37243744 x: a list of Tensors
37253745 begin: integer, where to begin slicing from along the axis
@@ -3754,7 +3774,7 @@ def one_hot(indices, output_dim, on_value=1.0,
37543774
37553775 TODO(noam): Is there a good reason we need a special mtf.Operation here?
37563776 We could just use some code like this:
3757- cast(equal(indices, range (indices.mesh, output_dim, dtype=indices.dtype)),
3777+ cast(equal(indices, mtf_range (indices.mesh, output_dim, dtype=indices.dtype)),
37583778 dtype)
37593779
37603780 Args:
@@ -4067,9 +4087,11 @@ def softmax(x, reduced_dim, extra_logit=None, name=None):
40674087 return exp (log_softmax (x , reduced_dim , extra_logit = extra_logit ))
40684088
40694089
4070- def range (mesh , dim , dtype , name = None ): # pylint: disable=redefined-builtin
4090+ def mtf_range (mesh , dim , dtype , name = None ):
40714091 """Create a 1d mesh tensor with a range from [0, dim.size).
40724092
4093+ Call externally as mtf.range()
4094+
40734095 Args:
40744096 mesh: a Mesh
40754097 dim: a Dimension
@@ -4563,9 +4585,10 @@ def halo_exchange(x, blocks_dim, block_size_dim, halo_size, wrap=False):
45634585 parts = ([shift (x , i , blocks_dim , wrap )] + parts +
45644586 [shift (x , - i , blocks_dim , wrap )])
45654587 if partial_size > 0 :
4566- left_margin = slice (x , 0 , partial_size , block_size_dim .name )
4567- right_margin = slice (x , block_size_dim .size - partial_size , partial_size ,
4568- block_size_dim .name )
4588+ left_margin = mtf_slice (x , 0 , partial_size , block_size_dim .name )
4589+ right_margin = mtf_slice (
4590+ x , block_size_dim .size - partial_size , partial_size ,
4591+ block_size_dim .name )
45694592 parts = (
45704593 [shift (right_margin , num_complete_blocks + 1 , blocks_dim , wrap )]
45714594 + parts +
@@ -4600,8 +4623,9 @@ def left_halo_exchange(x, blocks_dim, block_size_dim, halo_size, wrap=False):
46004623 for i in xrange (1 , num_complete_blocks + 1 ):
46014624 parts = ([shift (x , i , blocks_dim , wrap )] + parts )
46024625 if partial_size > 0 :
4603- right_margin = slice (x , block_size_dim .size - partial_size , partial_size ,
4604- block_size_dim .name )
4626+ right_margin = mtf_slice (
4627+ x , block_size_dim .size - partial_size , partial_size ,
4628+ block_size_dim .name )
46054629 parts = ([shift (right_margin , num_complete_blocks + 1 , blocks_dim , wrap )]
46064630 + parts )
46074631 return concat (parts , block_size_dim .name )
0 commit comments