Skip to content
This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit 3457056

Browse files
nshazeer, Copybara-Service
authored and committed
Fix toy_model_tpu.
PiperOrigin-RevId: 219566591
1 parent b2756dd commit 3457056

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

examples/toy_model_tpu.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -34,15 +34,15 @@
3434
FLAGS = flags.FLAGS
3535

3636
tf.flags.DEFINE_integer('batch_size', 64, 'Training batch size.')
37-
tf.flags.DEFINE_integer('io_size', 2, 'Number of channels per feature.')
38-
tf.flags.DEFINE_integer('hidden_size', 2, 'Size of each hidden layer.')
37+
tf.flags.DEFINE_integer('io_size', 16, 'Number of channels per feature.')
38+
tf.flags.DEFINE_integer('hidden_size', 16, 'Size of each hidden layer.')
3939
tf.flags.DEFINE_integer('num_hidden_layers', 1, 'Number of layers.')
4040
tf.flags.DEFINE_string('master_dtype', 'bfloat16', 'dtype for master vars.')
4141
tf.flags.DEFINE_string('slice_dtype', 'float32', 'dtype for slice vars.')
4242
tf.flags.DEFINE_string('activation_dtype', 'float32', 'dtype for activations.')
4343
tf.flags.DEFINE_string('optimizer', 'SGD', 'optimizer (SGD or Adafactor).')
4444
tf.flags.DEFINE_string('mesh_shape', 'all:8', 'mesh shape')
45-
tf.flags.DEFINE_string('layout', 'hidden:all', 'layout rules')
45+
tf.flags.DEFINE_string('layout', 'hidden_odd:all', 'layout rules')
4646
tf.flags.DEFINE_integer('iterations', 100,
4747
'Number of iterations per training loop.')
4848
tf.flags.DEFINE_integer('train_steps', 10000, 'max steps')
@@ -112,8 +112,9 @@ def toy_model(features, mesh):
112112
x = mtf.import_tf_tensor(mesh, features, mtf.Shape([batch_dim, io_dim]))
113113
x = mtf.cast(x, activation_dtype)
114114
h = x
115-
for lnum in xrange(FLAGS.num_hidden_layers + 1):
116-
if lnum + 1 == FLAGS.num_hidden_layers + 1:
115+
for lnum in xrange(1, FLAGS.num_hidden_layers + 2):
116+
if lnum + 1 == FLAGS.num_hidden_layers + 2:
117+
# output layer
117118
dim = io_dim
118119
elif lnum % 2 == 0:
119120
dim = mtf.Dimension('hidden_even', FLAGS.hidden_size)

0 commit comments

Comments (0)