
Commit b2756dd

nshazeer authored and Copybara-Service committed

Many changes to mesh-tensorflow - breaks existing mtf model checkpoints.

PiperOrigin-RevId: 219522764
1 parent: 58a1f2c

7 files changed: +362 additions, -110 deletions

examples/toy_model_tpu.py
Lines changed: 60 additions & 14 deletions
@@ -36,6 +36,11 @@
 tf.flags.DEFINE_integer('batch_size', 64, 'Training batch size.')
 tf.flags.DEFINE_integer('io_size', 2, 'Number of channels per feature.')
 tf.flags.DEFINE_integer('hidden_size', 2, 'Size of each hidden layer.')
+tf.flags.DEFINE_integer('num_hidden_layers', 1, 'Number of layers.')
+tf.flags.DEFINE_string('master_dtype', 'bfloat16', 'dtype for master vars.')
+tf.flags.DEFINE_string('slice_dtype', 'float32', 'dtype for slice vars.')
+tf.flags.DEFINE_string('activation_dtype', 'float32', 'dtype for activations.')
+tf.flags.DEFINE_string('optimizer', 'SGD', 'optimizer (SGD or Adafactor).')
 tf.flags.DEFINE_string('mesh_shape', 'all:8', 'mesh shape')
 tf.flags.DEFINE_string('layout', 'hidden:all', 'layout rules')
 tf.flags.DEFINE_integer('iterations', 100,
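For context, the mesh_shape and layout flags use Mesh TensorFlow's string syntax: a mesh shape names each mesh axis and its size, and a layout maps tensor-dimension names to mesh axes. A minimal sketch of how the defaults above parse (illustrative, not part of the commit; both helpers also appear in model_fn below):

import mesh_tensorflow as mtf

# 'all:8' is a one-axis mesh of 8 cores; 'hidden:all' is a rule that
# splits any tensor dimension named 'hidden' across that mesh axis.
mesh_shape = mtf.convert_to_shape('all:8')
layout_rules = mtf.convert_to_layout_rules('hidden:all')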
@@ -48,6 +53,7 @@
     'model_dir',
     default='',
     help='The directory where the model will be stored.')
+tf.flags.DEFINE_bool('use_tpu', True, 'use TPU')

 # Cloud TPU Cluster Resolvers
 tf.flags.DEFINE_string(
@@ -97,14 +103,31 @@ def __call__(self, params):
 def toy_model(features, mesh):
   """A toy model implemented by mesh tensorlfow."""
   batch_dim = mtf.Dimension('batch', FLAGS.batch_size)
-  hidden_dim = mtf.Dimension('hidden', FLAGS.hidden_size)
   io_dim = mtf.Dimension('io', FLAGS.io_size)

-  x = mtf.import_tf_tensor(mesh, features, mtf.Shape([batch_dim, io_dim]))
-  h = mtf.layers.dense(x, hidden_dim, name='layer1', use_bias=False)
-  y = mtf.layers.dense(h, io_dim, name='layer2', use_bias=False)
+  master_dtype = tf.as_dtype(FLAGS.master_dtype)
+  slice_dtype = tf.as_dtype(FLAGS.slice_dtype)
+  activation_dtype = tf.as_dtype(FLAGS.activation_dtype)

-  loss = mtf.reduce_sum(mtf.square(y - x))
+  x = mtf.import_tf_tensor(mesh, features, mtf.Shape([batch_dim, io_dim]))
+  x = mtf.cast(x, activation_dtype)
+  h = x
+  for lnum in xrange(FLAGS.num_hidden_layers + 1):
+    if lnum + 1 == FLAGS.num_hidden_layers + 1:
+      dim = io_dim
+    elif lnum % 2 == 0:
+      dim = mtf.Dimension('hidden_even', FLAGS.hidden_size)
+    else:
+      dim = mtf.Dimension('hidden_odd', FLAGS.hidden_size)
+    h = mtf.layers.dense(
+        h, dim,
+        use_bias=False,
+        master_dtype=master_dtype,
+        slice_dtype=slice_dtype,
+        name='layer_%d' % lnum)
+  y = h
+
+  loss = mtf.reduce_mean(mtf.square(y - x))
   return y, loss

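A note on the alternating 'hidden_even' / 'hidden_odd' names: a Mesh TensorFlow shape identifies dimensions by name and cannot contain the same name twice, so a dense layer whose input and output dimensions were both called 'hidden' would be ill-formed; alternating the names sidesteps this (our reading of the change). A minimal standalone sketch of the constraint, not taken from the commit:

import mesh_tensorflow as mtf
import tensorflow as tf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, 'demo_mesh')
batch = mtf.Dimension('batch', 8)
hidden_even = mtf.Dimension('hidden_even', 4)
hidden_odd = mtf.Dimension('hidden_odd', 4)

x = mtf.zeros(mesh, mtf.Shape([batch, hidden_even]))
# The dense layer contracts over 'hidden_even' and introduces
# 'hidden_odd'; reusing the input dimension's name here would collide.
h = mtf.layers.dense(x, hidden_odd, use_bias=False, name='layer_demo')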
@@ -113,20 +136,43 @@ def model_fn(features, labels, mode, params):
   del labels
   global_step = tf.train.get_global_step()
   graph = mtf.Graph()
-  mesh = mtf.Mesh(graph, 'my_mesh')
   mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
-  mesh_devices = [''] * mesh_shape.size
-  mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
-      mesh_shape, mtf.convert_to_layout_rules(FLAGS.layout),
-      mesh_devices, params['context'].device_assignment)
+  layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
+  if FLAGS.use_tpu:
+    ctx = params['context']
+    num_hosts = ctx.num_hosts
+    host_placement_fn = ctx.tpu_host_placement_function
+    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
+    tf.logging.info('device_list = %s' % device_list,)
+    # TODO(ylc): Better estimation of replica cache size?
+    replica_cache_size = 300 * 1000000  # 300M per replica
+    # Worker 0 caches all the TPU binaries.
+    worker0_mem = replica_cache_size * ctx.num_replicas
+    devices_memeory_usage = [worker0_mem] + [0] * (num_hosts - 1)
+    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
+                                                  devices_memeory_usage)
+    mesh_devices = [''] * mesh_shape.size
+    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
+        mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
+  else:
+    var_placer = None
+    mesh_devices = [''] * mesh_shape.size
+    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
+        mesh_shape, layout_rules, mesh_devices)
+  mesh = mtf.Mesh(graph, 'my_mesh', var_placer)
+
   with mtf.utils.outside_all_rewrites():
     logits, loss = toy_model(features, mesh)

   # TRAIN mode
   if mode == tf.estimator.ModeKeys.TRAIN:
     var_grads = mtf.gradients([loss],
                               [v.outputs[0] for v in graph.trainable_variables])
-    optimizer = mtf.optimize.AdafactorOptimizer()
+    if FLAGS.optimizer == 'Adafactor':
+      optimizer = mtf.optimize.AdafactorOptimizer()
+    else:
+      assert FLAGS.optimizer == 'SGD'
+      optimizer = mtf.optimize.SgdOptimizer(lr=1e-4)
     update_ops = []
     for grad, var in zip(var_grads, graph.trainable_variables):
       update_ops.extend(optimizer.apply_grad(grad, var))
@@ -136,7 +182,7 @@ def model_fn(features, labels, mode, params):

   lowering = mtf.Lowering(graph, {mesh: mesh_impl})

-  tf_loss = lowering.export_to_tf_tensor(loss)
+  tf_loss = tf.to_float(lowering.export_to_tf_tensor(loss))

   if mode == tf.estimator.ModeKeys.TRAIN:
     tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
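The tf.to_float wrapper is presumably defensive: with activation_dtype=bfloat16, the loss exported from the lowering would itself be bfloat16, and tf.to_float (equivalent to tf.cast(x, tf.float32)) brings it back to float32 for the surrounding TensorFlow code. A tiny illustration:

import tensorflow as tf

loss_bf16 = tf.constant(0.5, dtype=tf.bfloat16)
loss_f32 = tf.to_float(loss_bf16)  # same as tf.cast(loss_bf16, tf.float32)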
@@ -173,8 +219,8 @@ def model_fn(features, labels, mode, params):
   elif mode == tf.estimator.ModeKeys.EVAL:

     def metric_fn(tf_logits):
-      mean_logitss = tf.metrics.mean(tf_logits)
-      return {'mean_logitss': mean_logitss}
+      mean_logits = tf.metrics.mean(tf_logits)
+      return {'mean_logits': mean_logits}

     eval_metrics = (metric_fn, [tf_logits])

mesh_tensorflow/beam_search.py
Lines changed: 15 additions & 11 deletions
@@ -95,7 +95,8 @@ def beam_search(logits_fn,
                 eos_id=EOS_ID,
                 stop_early=True,
                 decode_length=None,
-                use_tpu=True):
+                use_tpu=True,
+                dtype=tf.float32):
   """Beam search with length penalties.

   Requires a function that can take the currently decoded symbols and return
@@ -128,14 +129,15 @@
           step_num - mtf Scalar
           ids - mtf Tensor with shape [batch, beam, length]
         Should return:
-          logits - [batch, beam, vocab_size]
+          logits - [batch, beam, vocab_size], dtype=dtype
     initial_ids: a mtf.Tensor with shape [batch_dim, beam_dim, length_dim])
     alpha: alpha for length penalty.
     states: list of mtf.Tensor
     eos_id: ID for end of sentence.
     stop_early: a boolean - stop once best sequence is provably determined.
     decode_length: a mtf Scalar of dtype tf.int32 - maximum length of decodes
     use_tpu: a boolean
+    dtype: a tf.dtype
   Returns:
     Tuple of
     (decoded beams [batch, beam, length]
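To make the new argument concrete, here is a hypothetical call site; my_logits_fn and initial_ids are placeholders, keyword arguments are used because the full signature is abbreviated above, and the two return values assume the (beams, scores) tuple described in the Returns section:

beams, scores = beam_search(
    logits_fn=my_logits_fn,    # must return [batch, beam, vocab_size] logits in `dtype`
    initial_ids=initial_ids,   # mtf Tensor of shape [batch, beam, length]
    alpha=0.6,                 # length-penalty strength
    states=[],
    use_tpu=True,
    dtype=tf.bfloat16)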
@@ -150,7 +152,8 @@
           mtf.constant(mesh, 0, dtype=tf.int32),
           beam_dim,
           on_value=0.0,
-          off_value=-INF),
+          off_value=-INF,
+          dtype=dtype),
       batch_by_beam)

   length_scalar = mtf.constant(mesh, length_dim.size, dtype=tf.int32)
@@ -166,7 +169,7 @@
   # Finished log probs will be negative infinity in the beginning
   # finished_flags will keep track of booleans
   finished_seq = initial_ids
-  finished_scores = mtf.constant(mesh, -INF, batch_by_beam)
+  finished_scores = mtf.constant(mesh, -INF, batch_by_beam, dtype=dtype)

   # Setting the scores of the initial to negative infinity.
   finished_flags = mtf.constant(mesh, False, batch_by_beam, tf.bool)
@@ -197,7 +200,7 @@ def grow_finished(finished_seq, finished_scores, finished_flags, curr_seq,

     # Set the scores of the unfinished seq in curr_seq to large negative
     # values
-    curr_scores += (1. - mtf.to_float(curr_finished)) * -INF
+    curr_scores += (1. - mtf.cast(curr_finished, curr_scores.dtype)) * -INF
     unused_batch_dim, beam_dim, unused_length_dim = finished_seq.shape.dims
     # concatenating the sequences and scores along beam axis
     def _my_concat(a, b):
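This hunk and the ones below all follow one pattern: boolean masks that were forced to float32 via mtf.to_float are now cast to whatever dtype the scores already carry, so a bfloat16 search stays bfloat16 throughout. A standalone sketch of the pattern (the INF value here is an illustrative stand-in for the module's own constant):

import mesh_tensorflow as mtf
import tensorflow as tf

INF = 1e7  # illustrative; beam_search.py defines its own constant

graph = mtf.Graph()
mesh = mtf.Mesh(graph, 'cast_demo')
beam = mtf.Dimension('beam', 4)
scores = mtf.constant(mesh, 0.0, mtf.Shape([beam]), dtype=tf.bfloat16)
finished = mtf.constant(mesh, False, mtf.Shape([beam]), tf.bool)

# Mask finished entries with a large negative value, in the scores'
# own dtype (bfloat16 here) rather than float32.
scores += mtf.cast(finished, scores.dtype) * -INF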
@@ -232,7 +235,7 @@ def grow_alive(curr_seq, curr_scores, curr_log_probs, curr_finished, states):
     """
     # Set the scores of the finished seq in curr_seq to large negative
     # values
-    curr_scores += mtf.to_float(curr_finished) * -INF
+    curr_scores += mtf.cast(curr_finished, curr_scores.dtype) * -INF
     return compute_topk_scores_and_seq(curr_seq, curr_scores, curr_log_probs,
                                        curr_finished, beam_dim,
                                        "grow_alive", states)
@@ -273,7 +276,7 @@ def grow_topk(i, alive_seq, alive_log_probs, states=None):
     # (batch_size, beam_size, vocab_size) + (batch_size, beam_size, 1)
     log_probs = candidate_log_probs + alive_log_probs

-    length_penalty = mtf.pow(((5. + mtf.to_float(i + 1)) / 6.), alpha)
+    length_penalty = mtf.pow(((5. + mtf.cast(i + 1, logits.dtype)) / 6.), alpha)

     curr_scores = log_probs / length_penalty

@@ -401,7 +404,7 @@ def _is_finished(i, unused_alive_seq, alive_log_probs, unused_finished_seq,
     if not stop_early:
       return mtf.less(i, decode_length)
     max_length_penalty = mtf.pow(
-        ((5. + mtf.to_float(decode_length)) / 6.), alpha)
+        ((5. + mtf.cast(decode_length, finished_scores.dtype)) / 6.), alpha)
     # The best possible score of the most likely alive sequence.
     lower_bound_alive_scores = mtf.gather(
         alive_log_probs, mtf.constant(mesh, 0, dtype=tf.int32),
@@ -412,16 +415,17 @@ def _is_finished(i, unused_alive_seq, alive_log_probs, unused_finished_seq,
     # scores are all -ve, taking the min will give us the score of the lowest
     # finished item.
     lowest_score_of_finished_in_finished = mtf.reduce_min(
-        finished_scores * mtf.to_float(finished_in_finished),
+        finished_scores * mtf.cast(finished_in_finished, finished_scores.dtype),
         reduced_dim=beam_dim)

     # If none of the sequences have finished, then the min will be 0 and
     # we have to replace it by -ve INF if it is. The score of any seq in alive
     # will be much higher than -ve INF and the termination condition will not
     # be met.
     lowest_score_of_finished_in_finished += (
-        (1. - mtf.to_float(mtf.reduce_any(
-            finished_in_finished, reduced_dim=beam_dim))) * -INF)
+        (1. - mtf.cast(mtf.reduce_any(
+            finished_in_finished, reduced_dim=beam_dim),
+                       finished_scores.dtype)) * -INF)

     bound_is_met = mtf.reduce_all(
         mtf.greater(lowest_score_of_finished_in_finished,
