
Commit f3075ff

Dustin Tran authored and Copybara-Service committed
Remove "mtf_" prefix to module and use mtf namespace.
PiperOrigin-RevId: 217225608
1 parent: 4119b79 · commit f3075ff

15 files changed: +49, -65 lines
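The change is mechanical: call sites that used the private `mtf_*` modules now reach the same functions through the `mtf` namespace. A minimal before/after sketch (the dimensions and the dense call mirror `examples/toy_model_tpu.py` below; the exact sizes are illustrative):

import mesh_tensorflow as mtf
import tensorflow as tf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "my_mesh")
batch_dim = mtf.Dimension("batch", 4)
io_dim = mtf.Dimension("io", 8)
hidden_dim = mtf.Dimension("hidden", 16)

x = mtf.import_tf_tensor(
    mesh, tf.zeros([4, 8]), mtf.Shape([batch_dim, io_dim]))

# Before this commit:
#   from mesh_tensorflow import mtf_layers
#   h = mtf_layers.dense(x, hidden_dim, name="layer1", use_bias=False)
# After this commit, the submodule is reached through the package namespace:
h = mtf.layers.dense(x, hidden_dim, name="layer1", use_bias=False)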

README.md

Lines changed: 5 additions & 5 deletions
@@ -11,7 +11,7 @@
 
 Transformer for EN-FR WMT with model splitting | Transformer for EN-FR WMT with data splitting
 :-------------------------:|:-------------------------:
-![model_splitting](./mtf_transformer_model_splitting.png) | ![data_splitting](./mtf_transformer_data_splitting.png)
+![model_splitting](./transformer_model_splitting.png) | ![data_splitting](./transformer_data_splitting.png)
 
 # Introduction
 
@@ -116,7 +116,7 @@ w2 = mtf.get_variable(mesh, "w2", [hidden_dim, classes_dim])
 # einsum is a generalization of matrix multiplication (see numpy.einsum)
 hidden = mtf.relu(mtf.einsum(images, w1, output_shape=[batch_dim, hidden_dim]))
 logits = mtf.einsum(hidden, w2, output_shape=[batch_dim, classes_dim])
-loss = mtf.reduce_mean(mtf_layers.softmax_cross_entropy_with_logits(
+loss = mtf.reduce_mean(mtf.layers.softmax_cross_entropy_with_logits(
     logits, mtf.one_hot(labels, classes_dim), classes_dim))
 w1_grad, w2_grad = mtf.gradients([loss], [w1, w2])
 update_w1_op = mtf.assign(w1, w1 - w1_grad * 0.001)
@@ -132,7 +132,7 @@ computation.
 devices = ["gpu:0", "gpu:1", "gpu:2", "gpu:3"]
 mesh_shape = [("all_processors", 4)]
 layout_rules = [("batch", "all_processors")]
-mesh_impl = placement_mesh_impl.PlacementMeshImpl(
+mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
     mesh_shape, layout_rules, devices)
 lowering = mtf.Lowering(graph, {mesh:mesh_impl})
 tf_update_ops = [lowering.lowered_operation(update_w1_op),
@@ -371,7 +371,7 @@ TPU_NAME=ylc-mtf-donut
 # 2 ways data-parallelism and 4 ways model-parallelism.
 # In this configuration, we split the batch dimension into 2 cores and the
 # hidden dimension into 4 cores.
-python examples/mtf_toy_model_tpu.py \
+python examples/toy_model_tpu.py \
   --tpu=$TPU \
   --model_dir=$MODEL_DIR \
   --io_size=8 \
@@ -381,7 +381,7 @@ python examples/mtf_toy_model_tpu.py \
 
 # 8 ways model-parallelism.
 # In this configuration, We split the hidden dimension into 8 cores.
-python examples/mtf_toy_model_tpu.py \
+python examples/toy_model_tpu.py \
   --tpu=$TPU \
   --model_dir=$MODEL_DIR \
   --io_size=8 \

examples/mnist.py

Lines changed: 6 additions & 11 deletions
@@ -23,11 +23,6 @@
 from __future__ import print_function
 
 import mesh_tensorflow as mtf
-
-from mesh_tensorflow import mtf_layers
-from mesh_tensorflow import mtf_optimize
-from mesh_tensorflow import placement_mesh_impl
-
 import mnist_dataset as dataset  # local file import
 import tensorflow as tf
 
@@ -104,20 +99,20 @@ def mnist_model(image, labels, mesh):
   hidden_dim1 = mtf.Dimension("hidden1", FLAGS.hidden_size)
   hidden_dim2 = mtf.Dimension("hidden2", FLAGS.hidden_size)
 
-  h1 = mtf_layers.dense(
+  h1 = mtf.layers.dense(
       x, hidden_dim1,
       reduced_dims=x.shape.dims[-4:],
       activation=mtf.relu, name="hidden1")
-  h2 = mtf_layers.dense(
+  h2 = mtf.layers.dense(
       h1, hidden_dim2,
       activation=mtf.relu, name="hidden2")
-  logits = mtf_layers.dense(h2, classes_dim, name="logits")
+  logits = mtf.layers.dense(h2, classes_dim, name="logits")
   if labels is None:
     loss = None
   else:
     labels = mtf.import_tf_tensor(
         mesh, tf.reshape(labels, [FLAGS.batch_size]), mtf.Shape([batch_dim]))
-    loss = mtf_layers.softmax_cross_entropy_with_logits(
+    loss = mtf.layers.softmax_cross_entropy_with_logits(
         logits, mtf.one_hot(labels, classes_dim), classes_dim)
   loss = mtf.reduce_mean(loss)
   return logits, loss
@@ -135,13 +130,13 @@ def model_fn(features, labels, mode, params):
   layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
   mesh_size = mesh_shape.size
   mesh_devices = [""] * mesh_size
-  mesh_impl = placement_mesh_impl.PlacementMeshImpl(
+  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
       mesh_shape, layout_rules, mesh_devices)
 
   if mode == tf.estimator.ModeKeys.TRAIN:
     var_grads = mtf.gradients(
         [loss], [v.outputs[0] for v in graph.trainable_variables])
-    optimizer = mtf_optimize.AdafactorOptimizer()
+    optimizer = mtf.optimize.AdafactorOptimizer()
     update_ops = []
     for grad, var in zip(var_grads, graph.trainable_variables):
       update_ops.extend(optimizer.apply_grad(grad, var))

examples/mtf_toy_model_tpu.py renamed to examples/toy_model_tpu.py

Lines changed: 6 additions & 12 deletions
@@ -20,12 +20,6 @@
 from __future__ import print_function
 
 import mesh_tensorflow as mtf
-
-from mesh_tensorflow import mtf_layers
-from mesh_tensorflow import mtf_optimize
-from mesh_tensorflow import mtf_utils
-from mesh_tensorflow.simd_mesh_impl import SimdMeshImpl
-
 import numpy
 import tensorflow as tf
 
@@ -107,8 +101,8 @@ def toy_model(features, mesh):
   io_dim = mtf.Dimension('io', FLAGS.io_size)
 
   x = mtf.import_tf_tensor(mesh, features, mtf.Shape([batch_dim, io_dim]))
-  h = mtf_layers.dense(x, hidden_dim, name='layer1', use_bias=False)
-  y = mtf_layers.dense(h, io_dim, name='layer2', use_bias=False)
+  h = mtf.layers.dense(x, hidden_dim, name='layer1', use_bias=False)
+  y = mtf.layers.dense(h, io_dim, name='layer2', use_bias=False)
 
   loss = mtf.reduce_sum(mtf.square(y - x))
   return y, loss
@@ -122,17 +116,17 @@ def model_fn(features, labels, mode, params):
   mesh = mtf.Mesh(graph, 'my_mesh')
   mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
   mesh_devices = [''] * mesh_shape.size
-  mesh_impl = SimdMeshImpl(
+  mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
      mesh_shape, mtf.convert_to_layout_rules(FLAGS.layout),
      mesh_devices, params['context'].device_assignment)
-  with mtf_utils.outside_all_rewrites():
+  with mtf.utils.outside_all_rewrites():
     logits, loss = toy_model(features, mesh)
 
   # TRAIN mode
   if mode == tf.estimator.ModeKeys.TRAIN:
     var_grads = mtf.gradients([loss],
                               [v.outputs[0] for v in graph.trainable_variables])
-    optimizer = mtf_optimize.AdafactorOptimizer()
+    optimizer = mtf.optimize.AdafactorOptimizer()
     update_ops = []
     for grad, var in zip(var_grads, graph.trainable_variables):
       update_ops.extend(optimizer.apply_grad(grad, var))
@@ -152,7 +146,7 @@ def model_fn(features, labels, mode, params):
   else:
     tf_logits = lowering.export_to_tf_tensor(fully_replicated_logits)
 
-  with mtf_utils.outside_all_rewrites():
+  with mtf.utils.outside_all_rewrites():
     # Copy master variables to slices. Must be called first.
     restore_hook = mtf.MtfRestoreHook(lowering)
   if mode == tf.estimator.ModeKeys.TRAIN:

mesh_tensorflow/__init__.py

Lines changed: 4 additions & 4 deletions
@@ -19,13 +19,13 @@
 from __future__ import division
 from __future__ import print_function
 
-from mesh_tensorflow import mtf_beam_search
-from mesh_tensorflow import mtf_layers
-from mesh_tensorflow import mtf_optimize
-from mesh_tensorflow import mtf_utils
+from mesh_tensorflow import beam_search
+from mesh_tensorflow import layers
+from mesh_tensorflow import optimize
 from mesh_tensorflow import placement_mesh_impl
 from mesh_tensorflow import simd_mesh_impl
 from mesh_tensorflow import tpu_variables
+from mesh_tensorflow import utils
 from mesh_tensorflow.ops import *  # pylint: disable=wildcard-import
 
 # TODO(trandustin): Seal module.
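These `__init__.py` imports are what create the public namespace: importing a submodule inside a package binds it as an attribute of the package object, so `mtf.layers`, `mtf.optimize`, and `mtf.utils` resolve without any further imports at call sites. A quick check of that effect, assuming the package is installed under the name `mesh_tensorflow`:

import mesh_tensorflow as mtf

# Each submodule imported in __init__.py becomes an attribute of the
# package, which is why the renamed call sites in this commit work:
assert hasattr(mtf, "layers")
assert hasattr(mtf.optimize, "AdafactorOptimizer")
assert callable(mtf.utils.outside_all_rewrites)
assert hasattr(mtf.placement_mesh_impl, "PlacementMeshImpl")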
File renamed without changes.

mesh_tensorflow/mtf_layers.py renamed to mesh_tensorflow/layers.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Layers for mesh tensorflow."""
+"""Layers implemented in Mesh TensorFlow."""
 
 from __future__ import absolute_import
 from __future__ import division

mesh_tensorflow/mtf_layers_test.py renamed to mesh_tensorflow/layers_test.py

Lines changed: 15 additions & 17 deletions
@@ -23,14 +23,12 @@
 
 import mesh_tensorflow as mtf
 
-from mesh_tensorflow import mtf_layers
-from mesh_tensorflow import placement_mesh_impl
 from tensor2tensor.layers import common_layers
 
 import tensorflow as tf
 
 
-class MtfLayersTest(parameterized.TestCase, tf.test.TestCase):
+class LayersTest(parameterized.TestCase, tf.test.TestCase):
 
   @parameterized.parameters(
       (4, True),
@@ -49,12 +47,12 @@ def testDense(self, units, use_bias):
 
     mtf_inputs = mtf.import_tf_tensor(
         mesh, inputs, shape=mtf.Shape([batch_dim, channels_dim]))
-    mtf_outputs = mtf_layers.dense(mtf_inputs,
+    mtf_outputs = mtf.layers.dense(mtf_inputs,
                                    output_dim=depth_dim,
                                    reduced_dims=[channels_dim],
                                    activation=mtf.relu,
                                    use_bias=use_bias)
-    mesh_impl = placement_mesh_impl.PlacementMeshImpl(
+    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape=[], layout={}, devices=[""])
     lowering = mtf.Lowering(graph, {mesh: mesh_impl})
     actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)
@@ -83,9 +81,9 @@ def testLayerNorm(self):
 
     mtf_inputs = mtf.import_tf_tensor(
         mesh, inputs, shape=mtf.Shape([batch_dim, channels_dim]))
-    mtf_outputs = mtf_layers.layer_norm(mtf_inputs,
+    mtf_outputs = mtf.layers.layer_norm(mtf_inputs,
                                         dim=channels_dim)
-    mesh_impl = placement_mesh_impl.PlacementMeshImpl(
+    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape=[], layout={}, devices=[""])
     lowering = mtf.Lowering(graph, {mesh: mesh_impl})
     actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)
@@ -110,8 +108,8 @@ def testWeightsNonzero(self):
 
     mtf_inputs = mtf.import_tf_tensor(
         mesh, inputs, shape=mtf.Shape([batch_dim, channels_dim]))
-    mtf_outputs = mtf_layers.weights_nonzero(mtf_inputs)
-    mesh_impl = placement_mesh_impl.PlacementMeshImpl(
+    mtf_outputs = mtf.layers.weights_nonzero(mtf_inputs)
+    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape=[], layout={}, devices=[""])
     lowering = mtf.Lowering(graph, {mesh: mesh_impl})
     actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)
@@ -138,9 +136,9 @@ def testDenseReluDense(self):
 
     mtf_inputs = mtf.import_tf_tensor(
         mesh, inputs, shape=mtf.Shape([batch_dim, channels_dim]))
-    mtf_outputs = mtf_layers.dense_relu_dense(mtf_inputs,
+    mtf_outputs = mtf.layers.dense_relu_dense(mtf_inputs,
                                               hidden_channels=hidden_dim)
-    mesh_impl = placement_mesh_impl.PlacementMeshImpl(
+    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape=[], layout={}, devices=[""])
     lowering = mtf.Lowering(graph, {mesh: mesh_impl})
     actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)
@@ -179,13 +177,13 @@ def testMaskedLocalAttention1D(self, batch, length, io_channels, kv_channels,
     mtf_memory = mtf.import_tf_tensor(
         mesh, memory,
         shape=mtf.Shape([batch_dim, length_m_dim, io_channels_dim]))
-    mtf_outputs = mtf_layers.masked_local_attention_1d(
+    mtf_outputs = mtf.layers.masked_local_attention_1d(
        mtf_query,
        mtf_memory,
        kv_channels=kv_channels_dim,
        heads=heads_dim,
        block_length=block_length)
-    mesh_impl = placement_mesh_impl.PlacementMeshImpl(
+    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape=[], layout={}, devices=[""])
     lowering = mtf.Lowering(graph, {mesh: mesh_impl})
     actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)
@@ -228,12 +226,12 @@ def testDotProductAttention(
        mesh, value,
        shape=mtf.Shape(
            [batch_dim, heads_dim, length_kv_dim, depth_v_dim]))
-    mtf_outputs = mtf_layers.dot_product_attention(
+    mtf_outputs = mtf.layers.dot_product_attention(
        mtf_query,
        mtf_key,
        mtf_value,
        mask=None)
-    mesh_impl = placement_mesh_impl.PlacementMeshImpl(
+    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape=[], layout={}, devices=[""])
     lowering = mtf.Lowering(graph, {mesh: mesh_impl})
     actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)
@@ -267,13 +265,13 @@ def testMultiheadAttention(self, kv_channels, heads):
     mtf_query = mtf.import_tf_tensor(
        mesh, query,
        shape=mtf.Shape([batch_dim, length_dim, channels_dim]))
-    mtf_outputs = mtf_layers.multihead_attention(
+    mtf_outputs = mtf.layers.multihead_attention(
        mtf_query,
        memory_antecedent=None,
        mask=None,
        kv_channels=kv_channels_dim,
        heads=heads_dim)
-    mesh_impl = placement_mesh_impl.PlacementMeshImpl(
+    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape=[], layout={}, devices=[""])
     lowering = mtf.Lowering(graph, {mesh: mesh_impl})
     actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)
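Each test above follows the same three-step recipe: build ops against the public `mtf` namespace, lower the graph with a single-device `PlacementMeshImpl` (empty shape and layout), and export the result back to an ordinary TensorFlow tensor. A minimal self-contained sketch of that recipe, using `weights_nonzero` since it creates no variables (so no restore hook is needed), written against the TF1-era session API the tests assume:

import mesh_tensorflow as mtf
import tensorflow as tf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "my_mesh")
batch_dim = mtf.Dimension("batch", 2)
channels_dim = mtf.Dimension("channels", 3)

inputs = tf.constant([[0., 1., 0.], [2., 0., 3.]])
mtf_inputs = mtf.import_tf_tensor(
    mesh, inputs, shape=mtf.Shape([batch_dim, channels_dim]))
mtf_outputs = mtf.layers.weights_nonzero(mtf_inputs)

# An empty shape/layout places the whole computation on a single device.
mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
    shape=[], layout={}, devices=[""])
lowering = mtf.Lowering(graph, {mesh: mesh_impl})
actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

with tf.Session() as sess:
  print(sess.run(actual_outputs))  # 1.0 wherever the input is nonzero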

mesh_tensorflow/ops.py

Lines changed: 3 additions & 3 deletions
@@ -23,7 +23,7 @@
 from operator import mul
 import re
 
-from mesh_tensorflow import mtf_utils
+from mesh_tensorflow import utils
 import six
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
@@ -2526,7 +2526,7 @@ def __init__(self, mesh, name, shape, dtype, initializer,
                trainable, **kwargs):
     super(Variable, self).__init__([], mesh, name="name_will_be_set_later")
     self._trainable = trainable
-    with tf.device(mesh.variable_placer_fn), mtf_utils.outside_all_rewrites():
+    with tf.device(mesh.variable_placer_fn), utils.outside_all_rewrites():
       self.master = tf.get_variable(
           name, shape.to_integer_list, dtype=dtype, initializer=initializer,
           **kwargs)
@@ -2538,7 +2538,7 @@ def __init__(self, mesh, name, shape, dtype, initializer,
 
   def lower(self, lowering):
     mesh_impl = lowering.mesh_impl(self)
-    with mtf_utils.outside_all_rewrites():
+    with utils.outside_all_rewrites():
       sv = mesh_impl.LaidOutVariable(self, mesh_impl)
       lowering.variables[self] = sv
       lowering.set_tensor_lowering(self.outputs[0], sv.laid_out_tensor)
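Note that inside the package, `ops.py` imports the sibling module directly (`from mesh_tensorflow import utils`) rather than spelling `mtf.utils`; since `__init__.py` itself re-exports `ops`, relying on the fully initialized package alias here could be circular. A small sketch of the pattern `Variable.__init__` uses, with an illustrative fixed device string standing in for `mesh.variable_placer_fn`:

import tensorflow as tf
from mesh_tensorflow import utils  # direct sibling import, as in ops.py

def make_master_variable(name, shape):
  # Mirrors Variable.__init__ above: pin the master copy of a variable to a
  # chosen device and keep its creation outside TPU graph rewrites.
  with tf.device("/cpu:0"), utils.outside_all_rewrites():
    return tf.get_variable(
        name, shape, dtype=tf.float32, initializer=tf.zeros_initializer())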

mesh_tensorflow/ops_test.py

Lines changed: 1 addition & 3 deletions
@@ -22,8 +22,6 @@
 from absl.testing import parameterized
 
 import mesh_tensorflow as mtf
-from mesh_tensorflow import placement_mesh_impl
-
 import tensorflow as tf
 
 
@@ -126,7 +124,7 @@ def testLowering(self):
     mtf_inputs = mtf.import_tf_tensor(mesh,
                                       tf_tensor=inputs,
                                       shape=mtf.Shape([]))
-    mesh_impl = placement_mesh_impl.PlacementMeshImpl(
+    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape=[], layout={}, devices=[""])
     lowering = mtf.Lowering(graph, {mesh: mesh_impl})
 

mesh_tensorflow/mtf_optimize.py renamed to mesh_tensorflow/optimize.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Mesh-Tensorflow Optimizers."""
+"""Mesh Tensorflow Optimizers."""
 
 
 from __future__ import absolute_import
