tf.flags.DEFINE_integer('batch_size', 64, 'Training batch size.')
tf.flags.DEFINE_integer('io_size', 2, 'Number of channels per feature.')
tf.flags.DEFINE_integer('hidden_size', 2, 'Size of each hidden layer.')
+tf.flags.DEFINE_integer('num_hidden_layers', 1, 'Number of layers.')
+tf.flags.DEFINE_string('master_dtype', 'bfloat16', 'dtype for master vars.')
+tf.flags.DEFINE_string('slice_dtype', 'float32', 'dtype for slice vars.')
+tf.flags.DEFINE_string('activation_dtype', 'float32', 'dtype for activations.')
+tf.flags.DEFINE_string('optimizer', 'SGD', 'optimizer (SGD or Adafactor).')
tf.flags.DEFINE_string('mesh_shape', 'all:8', 'mesh shape')
tf.flags.DEFINE_string('layout', 'hidden:all', 'layout rules')
tf.flags.DEFINE_integer('iterations', 100,
    'model_dir',
    default='',
    help='The directory where the model will be stored.')
+tf.flags.DEFINE_bool('use_tpu', True, 'use TPU')

# Cloud TPU Cluster Resolvers
tf.flags.DEFINE_string(
@@ -97,14 +103,31 @@ def __call__(self, params):
def toy_model(features, mesh):
  """A toy model implemented by mesh tensorflow."""
  batch_dim = mtf.Dimension('batch', FLAGS.batch_size)
-  hidden_dim = mtf.Dimension('hidden', FLAGS.hidden_size)
  io_dim = mtf.Dimension('io', FLAGS.io_size)

-  x = mtf.import_tf_tensor(mesh, features, mtf.Shape([batch_dim, io_dim]))
-  h = mtf.layers.dense(x, hidden_dim, name='layer1', use_bias=False)
-  y = mtf.layers.dense(h, io_dim, name='layer2', use_bias=False)
+  master_dtype = tf.as_dtype(FLAGS.master_dtype)
+  slice_dtype = tf.as_dtype(FLAGS.slice_dtype)
+  activation_dtype = tf.as_dtype(FLAGS.activation_dtype)

-  loss = mtf.reduce_sum(mtf.square(y - x))
+  x = mtf.import_tf_tensor(mesh, features, mtf.Shape([batch_dim, io_dim]))
+  x = mtf.cast(x, activation_dtype)
+  h = x
+  for lnum in xrange(FLAGS.num_hidden_layers + 1):
+    if lnum + 1 == FLAGS.num_hidden_layers + 1:
+      dim = io_dim
+    elif lnum % 2 == 0:
+      dim = mtf.Dimension('hidden_even', FLAGS.hidden_size)
+    else:
+      dim = mtf.Dimension('hidden_odd', FLAGS.hidden_size)
+    h = mtf.layers.dense(
+        h, dim,
+        use_bias=False,
+        master_dtype=master_dtype,
+        slice_dtype=slice_dtype,
+        name='layer_%d' % lnum)
+  y = h
+
+  loss = mtf.reduce_mean(mtf.square(y - x))
  return y, loss

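Note on the layer loop above: it alternates the names 'hidden_even' and 'hidden_odd' because a Mesh TensorFlow shape cannot contain the same Dimension twice, so chaining dense layers through a single 'hidden' dimension would require an illegal [hidden, hidden] kernel. A minimal sketch of the constraint (sizes are illustrative, not taken from this commit):

    hidden_even = mtf.Dimension('hidden_even', 4)
    hidden_odd = mtf.Dimension('hidden_odd', 4)
    h = mtf.layers.dense(x, hidden_even, use_bias=False, name='layer_0')  # kernel [io, hidden_even]
    h = mtf.layers.dense(h, hidden_odd, use_bias=False, name='layer_1')   # kernel [hidden_even, hidden_odd]
    # dense(h, hidden_even) at this point would need a [hidden_even, hidden_even]
    # kernel, which a mtf.Shape rejects; hence the even/odd alternation.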
@@ -113,20 +136,43 @@ def model_fn(features, labels, mode, params):
  del labels
  global_step = tf.train.get_global_step()
  graph = mtf.Graph()
-  mesh = mtf.Mesh(graph, 'my_mesh')
  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
-  mesh_devices = [''] * mesh_shape.size
-  mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
-      mesh_shape, mtf.convert_to_layout_rules(FLAGS.layout),
-      mesh_devices, params['context'].device_assignment)
+  layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
+  if FLAGS.use_tpu:
+    ctx = params['context']
+    num_hosts = ctx.num_hosts
+    host_placement_fn = ctx.tpu_host_placement_function
+    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
+    tf.logging.info('device_list = %s', device_list)
+    # TODO(ylc): Better estimation of replica cache size?
+    replica_cache_size = 300 * 1000000  # 300M per replica
+    # Worker 0 caches all the TPU binaries.
+    worker0_mem = replica_cache_size * ctx.num_replicas
+    devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
+    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
+                                                  devices_memory_usage)
+    mesh_devices = [''] * mesh_shape.size
+    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
+        mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
+  else:
+    var_placer = None
+    mesh_devices = [''] * mesh_shape.size
+    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
+        mesh_shape, layout_rules, mesh_devices)
+  mesh = mtf.Mesh(graph, 'my_mesh', var_placer)
+
  with mtf.utils.outside_all_rewrites():
    logits, loss = toy_model(features, mesh)

  # TRAIN mode
  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients([loss],
                              [v.outputs[0] for v in graph.trainable_variables])
-    optimizer = mtf.optimize.AdafactorOptimizer()
+    if FLAGS.optimizer == 'Adafactor':
+      optimizer = mtf.optimize.AdafactorOptimizer()
+    else:
+      assert FLAGS.optimizer == 'SGD'
+      optimizer = mtf.optimize.SgdOptimizer(lr=1e-4)
    update_ops = []
    for grad, var in zip(var_grads, graph.trainable_variables):
      update_ops.extend(optimizer.apply_grad(grad, var))
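A worked example of the replica-cache accounting above, with hypothetical host and replica counts (these numbers are not from the commit):

    # Assume num_hosts = 4 and ctx.num_replicas = 32.
    # worker0_mem = 300 * 1000000 * 32 = 9,600,000,000 bytes (about 9.6 GB)
    # devices_memory_usage == [9600000000, 0, 0, 0]
    # Host 0 is charged for the cached TPU binaries up front, so the
    # BalancedVariablePlacer tends to place master variables on hosts 1-3 first.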
@@ -136,7 +182,7 @@ def model_fn(features, labels, mode, params):

  lowering = mtf.Lowering(graph, {mesh: mesh_impl})

-  tf_loss = lowering.export_to_tf_tensor(loss)
+  tf_loss = tf.to_float(lowering.export_to_tf_tensor(loss))

  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
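The tf.to_float wrapper matters when --activation_dtype=bfloat16: the lowered loss can then come out in bfloat16, and the estimator-facing scalar should be float32. An equivalent spelling of the same line:

    tf_loss = tf.cast(lowering.export_to_tf_tensor(loss), tf.float32)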
@@ -173,8 +219,8 @@ def model_fn(features, labels, mode, params):
  elif mode == tf.estimator.ModeKeys.EVAL:

    def metric_fn(tf_logits):
-      mean_logitss = tf.metrics.mean(tf_logits)
-      return {'mean_logitss': mean_logitss}
+      mean_logits = tf.metrics.mean(tf_logits)
+      return {'mean_logits': mean_logits}

    eval_metrics = (metric_fn, [tf_logits])

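eval_metrics here is the (metric_fn, tensors) pair that TPUEstimatorSpec expects for evaluation. A sketch of how it would typically be returned (the tpu_estimator module name is assumed, not shown in this excerpt):

    return tpu_estimator.TPUEstimatorSpec(
        tf.estimator.ModeKeys.EVAL, loss=tf_loss, eval_metrics=eval_metrics)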