@@ -115,12 +115,11 @@ def rdd_generator():
115115 return
116116
117117 ds = tf.data.Dataset.from_generator(rdd_generator, (tf.float32, tf.float32), (tf.TensorShape([2]), tf.TensorShape([1])))
118-    ds = ds.batch(args.batch_size)
119-
120-    # disable auto-sharding dataset
118+    # disable auto-sharding since we're feeding from an RDD generator
121119 options = tf.data.Options()
122-    options.experimental_distribute.auto_shard = False
120+    options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
123121 ds = ds.with_options(options)
122+    ds = ds.batch(args.batch_size)
124123
125124 # only train 90% of each epoch to account for uneven RDD partition sizes
126125 steps_per_epoch = 1000 * 0.9 // (args.batch_size * ctx.num_workers)
@@ -134,9 +133,8 @@ def rdd_generator():
134133 # This fails with: "NotImplementedError: `fit_generator` is not supported for models compiled with tf.distribute.Strategy"
135134 # model.fit_generator(ds, epochs=args.epochs, steps_per_epoch=steps_per_epoch, callbacks=callbacks)
136135
137-    if ctx.job_name == 'chief' and args.export_dir:
138-        print("exporting model to: {}".format(args.export_dir))
139-        tf.keras.experimental.export_saved_model(model, args.export_dir)
136+    print("exporting model to: {}".format(args.export_dir))
137+    model.save(args.export_dir, save_format='tf')
140138
141139 tf_feed.terminate()
142140
0 commit comments