Skip to content

Commit 699508c

Browse files
author
Lee Yang
committed
compatibility w/ TF2.1
1 parent fc24556 commit 699508c

File tree

7 files changed

+20
-40
lines changed

7 files changed

+20
-40
lines changed

examples/mnist/keras/mnist_pipeline.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,9 @@ def rdd_generator():
6565

6666
multi_worker_model.fit(x=ds, epochs=args.epochs, steps_per_epoch=max_steps_per_worker, callbacks=callbacks)
6767

68-
if ctx.job_name == 'chief':
69-
from tensorflow_estimator.python.estimator.export import export_lib
70-
export_dir = export_lib.get_timestamped_export_dir(args.export_dir)
71-
tf.keras.experimental.export_saved_model(multi_worker_model, export_dir)
72-
# multi_worker_model.save(args.model_dir, save_format='tf')
68+
from tensorflow_estimator.python.estimator.export import export_lib
69+
export_dir = export_lib.get_timestamped_export_dir(args.export_dir)
70+
multi_worker_model.save(export_dir, save_format='tf')
7371

7472
# terminating feed tells spark to skip processing further partitions
7573
tf_feed.terminate()

examples/mnist/keras/mnist_spark.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,9 @@ def rdd_generator():
6565

6666
multi_worker_model.fit(x=ds, epochs=args.epochs, steps_per_epoch=max_steps_per_worker, callbacks=callbacks)
6767

68-
if ctx.job_name == 'chief':
69-
from tensorflow_estimator.python.estimator.export import export_lib
70-
export_dir = export_lib.get_timestamped_export_dir(args.export_dir)
71-
tf.keras.experimental.export_saved_model(multi_worker_model, export_dir)
72-
# multi_worker_model.save(args.model_dir, save_format='tf')
68+
from tensorflow_estimator.python.estimator.export import export_lib
69+
export_dir = export_lib.get_timestamped_export_dir(args.export_dir)
70+
multi_worker_model.save(export_dir, save_format='tf')
7371

7472
# terminating feed tells spark to skip processing further partitions
7573
tf_feed.terminate()

examples/mnist/keras/mnist_tf.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,9 @@ def build_and_compile_cnn_model():
6060
multi_worker_model = build_and_compile_cnn_model()
6161
multi_worker_model.fit(x=train_datasets, epochs=args.epochs, steps_per_epoch=args.steps_per_epoch, callbacks=callbacks)
6262

63-
if ctx.job_name == 'chief':
64-
from tensorflow_estimator.python.estimator.export import export_lib
65-
export_dir = export_lib.get_timestamped_export_dir(args.export_dir)
66-
tf.keras.experimental.export_saved_model(multi_worker_model, export_dir)
67-
# multi_worker_model.save(args.model_dir, save_format='tf')
63+
from tensorflow_estimator.python.estimator.export import export_lib
64+
export_dir = export_lib.get_timestamped_export_dir(args.export_dir)
65+
multi_worker_model.save(export_dir, save_format='tf')
6866

6967

7068
if __name__ == '__main__':

examples/mnist/keras/mnist_tf_ds.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,9 @@ def build_and_compile_cnn_model():
8686
multi_worker_model = build_and_compile_cnn_model()
8787
multi_worker_model.fit(x=train_datasets, epochs=args.epochs, steps_per_epoch=steps_per_epoch, callbacks=callbacks)
8888

89-
if ctx.job_name == 'chief':
90-
from tensorflow_estimator.python.estimator.export import export_lib
91-
export_dir = export_lib.get_timestamped_export_dir(args.export_dir)
92-
tf.keras.experimental.export_saved_model(multi_worker_model, export_dir)
93-
# multi_worker_model.save(args.model_dir, save_format='tf')
89+
from tensorflow_estimator.python.estimator.export import export_lib
90+
export_dir = export_lib.get_timestamped_export_dir(args.export_dir)
91+
multi_worker_model.save(export_dir, save_format='tf')
9492

9593

9694
if __name__ == '__main__':

examples/segmentation/segmentation_spark.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -159,18 +159,7 @@ def unet_model(output_channels):
159159
validation_steps=VALIDATION_STEPS,
160160
validation_data=test_dataset)
161161

162-
if ctx.job_name == 'chief':
163-
# Workaround for: https://github.com/tensorflow/tensorflow/issues/30251
164-
print("===== saving h5py model")
165-
model.save(args.model_dir + ".h5")
166-
print("===== re-loading model w/o DistributionStrategy")
167-
new_model = tf.keras.models.load_model(args.model_dir + ".h5")
168-
print("===== exporting saved_model")
169-
tf.keras.experimental.export_saved_model(new_model, args.export_dir)
170-
print("===== done exporting")
171-
else:
172-
print("===== sleeping")
173-
time.sleep(90)
162+
model.save(args.export_dir, save_format='tf')
174163

175164

176165
if __name__ == '__main__':

tensorflowonspark/pipeline.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,7 @@ def _transform(self, dataset):
451451
# global on each python worker process on the executors
452452
pred_fn = None # saved_model prediction function/signature.
453453
pred_args = None # args provided to the _run_model() method. Any change will invalidate the pred_fn.
454+
saved_model = None
454455

455456

456457
def _run_model(iterator, args, tf_args):
@@ -471,7 +472,7 @@ def _run_model(iterator, args, tf_args):
471472
input_tensor_names = [tensor for col, tensor in sorted(args.input_mapping.items())]
472473
output_tensor_names = [tensor for tensor, col in sorted(args.output_mapping.items())]
473474

474-
global pred_fn, pred_args
475+
global pred_fn, pred_args, saved_model
475476

476477
# cache saved_model pred_fn to avoid reloading the model for each partition
477478
if not pred_fn or args != pred_args:

test/test_pipeline.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -115,12 +115,11 @@ def rdd_generator():
115115
return
116116

117117
ds = tf.data.Dataset.from_generator(rdd_generator, (tf.float32, tf.float32), (tf.TensorShape([2]), tf.TensorShape([1])))
118-
ds = ds.batch(args.batch_size)
119-
120-
# disable auto-sharding dataset
118+
# disable auto-sharding since we're feeding from an RDD generator
121119
options = tf.data.Options()
122-
options.experimental_distribute.auto_shard = False
120+
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
123121
ds = ds.with_options(options)
122+
ds = ds.batch(args.batch_size)
124123

125124
# only train 90% of each epoch to account for uneven RDD partition sizes
126125
steps_per_epoch = 1000 * 0.9 // (args.batch_size * ctx.num_workers)
@@ -134,9 +133,8 @@ def rdd_generator():
134133
# This fails with: "NotImplementedError: `fit_generator` is not supported for models compiled with tf.distribute.Strategy"
135134
# model.fit_generator(ds, epochs=args.epochs, steps_per_epoch=steps_per_epoch, callbacks=callbacks)
136135

137-
if ctx.job_name == 'chief' and args.export_dir:
138-
print("exporting model to: {}".format(args.export_dir))
139-
tf.keras.experimental.export_saved_model(model, args.export_dir)
136+
print("exporting model to: {}".format(args.export_dir))
137+
model.save(args.export_dir, save_format='tf')
140138

141139
tf_feed.terminate()
142140

0 commit comments

Comments
 (0)