|
34 | 34 | FLAGS = flags.FLAGS |
35 | 35 |
|
36 | 36 | tf.flags.DEFINE_integer('batch_size', 64, 'Training batch size.') |
37 | | -tf.flags.DEFINE_integer('io_size', 2, 'Number of channels per feature.') |
38 | | -tf.flags.DEFINE_integer('hidden_size', 2, 'Size of each hidden layer.') |
| 37 | +tf.flags.DEFINE_integer('io_size', 16, 'Number of channels per feature.') |
| 38 | +tf.flags.DEFINE_integer('hidden_size', 16, 'Size of each hidden layer.') |
39 | 39 | tf.flags.DEFINE_integer('num_hidden_layers', 1, 'Number of layers.') |
40 | 40 | tf.flags.DEFINE_string('master_dtype', 'bfloat16', 'dtype for master vars.') |
41 | 41 | tf.flags.DEFINE_string('slice_dtype', 'float32', 'dtype for slice vars.') |
42 | 42 | tf.flags.DEFINE_string('activation_dtype', 'float32', 'dtype for activations.') |
43 | 43 | tf.flags.DEFINE_string('optimizer', 'SGD', 'optimizer (SGD or Adafactor).') |
44 | 44 | tf.flags.DEFINE_string('mesh_shape', 'all:8', 'mesh shape') |
45 | | -tf.flags.DEFINE_string('layout', 'hidden:all', 'layout rules') |
| 45 | +tf.flags.DEFINE_string('layout', 'hidden_odd:all', 'layout rules') |
46 | 46 | tf.flags.DEFINE_integer('iterations', 100, |
47 | 47 | 'Number of iterations per training loop.') |
48 | 48 | tf.flags.DEFINE_integer('train_steps', 10000, 'max steps') |
@@ -112,8 +112,9 @@ def toy_model(features, mesh): |
112 | 112 | x = mtf.import_tf_tensor(mesh, features, mtf.Shape([batch_dim, io_dim])) |
113 | 113 | x = mtf.cast(x, activation_dtype) |
114 | 114 | h = x |
115 | | - for lnum in xrange(FLAGS.num_hidden_layers + 1): |
116 | | - if lnum + 1 == FLAGS.num_hidden_layers + 1: |
| 115 | + for lnum in xrange(1, FLAGS.num_hidden_layers + 2): |
| 116 | + if lnum + 1 == FLAGS.num_hidden_layers + 2: |
| 117 | + # output layer |
117 | 118 | dim = io_dim |
118 | 119 | elif lnum % 2 == 0: |
119 | 120 | dim = mtf.Dimension('hidden_even', FLAGS.hidden_size) |
|
0 commit comments