Commit b5ca3af

Merge pull request #483 from yahoo/leewyang_compat
Add compat layer for TF2.1rc0 compatibility
2 parents (fc24556 + fedc361) · commit b5ca3af

15 files changed: +355 −69 lines


examples/mnist/estimator/README.md

Lines changed: 2 additions & 2 deletions
@@ -1,8 +1,8 @@
 # MNIST using Estimator
 
-Original Source: https://www.tensorflow.org/beta/tutorials/distribute/multi_worker_with_estimator
+Original Source: https://www.tensorflow.org/tutorials/distribute/multi_worker_with_estimator
 
-This is the [Multi-worker Training with Estimator](https://www.tensorflow.org/beta/tutorials/distribute/multi_worker_with_estimator) example, adapted for TensorFlowOnSpark.
+This is the [Multi-worker Training with Estimator](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_estimator) example, adapted for TensorFlowOnSpark.
 
 Note: this example assumes that Spark, TensorFlow, and TensorFlowOnSpark are already installed.
 
examples/mnist/estimator/mnist_pipeline.py

Lines changed: 1 addition & 2 deletions
@@ -45,8 +45,7 @@ def input_fn(mode, input_context=None):
       ds = tf.data.Dataset.from_generator(rdd_generator, (tf.float32, tf.float32), (tf.TensorShape([28, 28, 1]), tf.TensorShape([1])))
       return ds.batch(BATCH_SIZE)
     else:
-      raise Exception("I'm evaluating: mode={}, input_context={}".format(mode, input_context))
-
+      # read evaluation data from tensorflow_datasets directly
       def scale(image, label):
         image = tf.cast(image, tf.float32) / 255.0
         return image, label
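
The new evaluation branch is only partially visible in this hunk. For context, a rough sketch of how the test split might be read straight from tensorflow_datasets (an assumption based on the comment above; `BATCH_SIZE`, `as_supervised=True`, and the function name `eval_input_fn` are illustrative, not taken from the actual file):

import tensorflow as tf
import tensorflow_datasets as tfds

BATCH_SIZE = 64  # illustrative; the example defines its own constant


def eval_input_fn():
  # Hypothetical sketch: read the MNIST test split directly from
  # tensorflow_datasets instead of the Spark RDD feed used for training.
  def scale(image, label):
    image = tf.cast(image, tf.float32) / 255.0
    return image, label

  ds = tfds.load('mnist', split='test', as_supervised=True)
  return ds.map(scale).batch(BATCH_SIZE)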

examples/mnist/estimator/mnist_spark.py

Lines changed: 2 additions & 1 deletion
@@ -9,6 +9,8 @@ def main_fun(args, ctx):
   import tensorflow_datasets as tfds
   from tensorflowonspark import TFNode
 
+  strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
+
   tfds.disable_progress_bar()
 
   class StopFeedHook(tf.estimator.SessionRunHook):
@@ -91,7 +93,6 @@ def model_fn(features, labels, mode):
         train_op=optimizer.minimize(
           loss, tf.compat.v1.train.get_or_create_global_step()))
 
-  strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
   config = tf.estimator.RunConfig(train_distribute=strategy, save_checkpoints_steps=100)
 
   classifier = tf.estimator.Estimator(
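
Net effect of the two hunks above: the `MultiWorkerMirroredStrategy` is now constructed at the top of `main_fun`, before any other TensorFlow setup, rather than just ahead of the `RunConfig`. A condensed sketch of the resulting ordering (everything elided is unchanged from the example, and the `Estimator` arguments shown are abbreviated):

def main_fun(args, ctx):
  import tensorflow as tf
  import tensorflow_datasets as tfds
  from tensorflowonspark import TFNode

  # Build the distribution strategy first, before any other TF objects exist.
  strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

  tfds.disable_progress_bar()

  def model_fn(features, labels, mode):
    ...  # unchanged from the example

  def input_fn(mode, input_context=None):
    ...  # unchanged from the example

  # The strategy created above is handed to the Estimator via RunConfig.
  config = tf.estimator.RunConfig(train_distribute=strategy, save_checkpoints_steps=100)
  classifier = tf.estimator.Estimator(model_fn=model_fn, config=config)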

examples/mnist/keras/README.md

Lines changed: 3 additions & 3 deletions
@@ -1,8 +1,8 @@
 # MNIST using Keras
 
-Original Source: https://www.tensorflow.org/beta/tutorials/distribute/multi_worker_with_keras
+Original Source: https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras
 
-This is the [Multi-worker Training with Keras](https://www.tensorflow.org/beta/tutorials/distribute/multi_worker_with_keras) example, adapted for TensorFlowOnSpark.
+This is the [Multi-worker Training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras) example, adapted for TensorFlowOnSpark.
 
 Notes:
 - This example assumes that Spark, TensorFlow, TensorFlow Datasets, and TensorFlowOnSpark are already installed.
@@ -130,7 +130,7 @@ For batch inferencing use cases, you can use Spark to run multiple single-node T
 ${TFoS_HOME}/examples/mnist/keras/mnist_inference.py \
 --cluster_size ${SPARK_WORKER_INSTANCES} \
 --images_labels ${TFoS_HOME}/data/mnist/tfr/test \
---export_dir ${TFoS_HOME}/mnist_export \
+--export_dir ${SAVED_MODEL} \
 --output ${TFoS_HOME}/predictions
 
 #### Train and Inference via Spark ML Pipeline API
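
Aside on the `--export_dir ${SAVED_MODEL}` change above: `SAVED_MODEL` is expected to point at the SavedModel directory produced by training. A small, hypothetical sanity check (not part of the repo) for confirming the path is loadable before launching the Spark inference job:

import tensorflow as tf

# Hypothetical path; in practice this is whatever ${SAVED_MODEL} resolves to,
# e.g. a timestamped export directory created by the training run.
saved_model_path = '/tmp/mnist_export/1577000000'

loaded = tf.saved_model.load(saved_model_path)
print(list(loaded.signatures.keys()))  # typically ['serving_default']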

examples/mnist/keras/mnist_pipeline.py

Lines changed: 4 additions & 6 deletions
@@ -6,7 +6,7 @@
 def main_fun(args, ctx):
   import numpy as np
   import tensorflow as tf
-  from tensorflowonspark import TFNode
+  from tensorflowonspark import compat, TFNode
 
   strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
 
@@ -65,11 +65,9 @@ def rdd_generator():
 
   multi_worker_model.fit(x=ds, epochs=args.epochs, steps_per_epoch=max_steps_per_worker, callbacks=callbacks)
 
-  if ctx.job_name == 'chief':
-    from tensorflow_estimator.python.estimator.export import export_lib
-    export_dir = export_lib.get_timestamped_export_dir(args.export_dir)
-    tf.keras.experimental.export_saved_model(multi_worker_model, export_dir)
-    # multi_worker_model.save(args.model_dir, save_format='tf')
+  from tensorflow_estimator.python.estimator.export import export_lib
+  export_dir = export_lib.get_timestamped_export_dir(args.export_dir)
+  compat.export_saved_model(multi_worker_model, export_dir, ctx.job_name == 'chief')
 
   # terminating feed tells spark to skip processing further partitions
   tf_feed.terminate()
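
The `tensorflowonspark/compat.py` module added by this PR is among the 15 changed files but is not shown in the hunks here. Judging only from the call sites above, it wraps the version-dependent Keras export; a rough, hypothetical sketch of what `compat.export_saved_model` might look like (the real implementation may handle non-chief workers and version detection differently):

import tensorflow as tf


def export_saved_model(model, export_dir, is_chief=False):
  """Hypothetical sketch: export a Keras model in a way that works on both TF 2.0 and TF 2.1."""
  if tf.__version__.startswith('2.0.'):
    # TF 2.0.x: model.save(..., save_format='tf') has issues under
    # MultiWorkerMirroredStrategy (see tensorflow/tensorflow#30251),
    # so use the older experimental export API on the chief only.
    if is_chief:
      tf.keras.experimental.export_saved_model(model, export_dir)
  else:
    # TF 2.1+: the experimental API is deprecated; use the native Keras
    # SavedModel export. This sketch only exports from the chief worker.
    if is_chief:
      model.save(export_dir, save_format='tf')

Either way, the call sites now pass `ctx.job_name == 'chief'` into the helper, so the chief/non-chief decision lives in one place instead of being repeated in every example.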

examples/mnist/keras/mnist_spark.py

Lines changed: 4 additions & 6 deletions
@@ -6,7 +6,7 @@
 def main_fun(args, ctx):
   import numpy as np
   import tensorflow as tf
-  from tensorflowonspark import TFNode
+  from tensorflowonspark import compat, TFNode
 
   strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
 
@@ -65,11 +65,9 @@ def rdd_generator():
 
   multi_worker_model.fit(x=ds, epochs=args.epochs, steps_per_epoch=max_steps_per_worker, callbacks=callbacks)
 
-  if ctx.job_name == 'chief':
-    from tensorflow_estimator.python.estimator.export import export_lib
-    export_dir = export_lib.get_timestamped_export_dir(args.export_dir)
-    tf.keras.experimental.export_saved_model(multi_worker_model, export_dir)
-    # multi_worker_model.save(args.model_dir, save_format='tf')
+  from tensorflow_estimator.python.estimator.export import export_lib
+  export_dir = export_lib.get_timestamped_export_dir(args.export_dir)
+  compat.export_saved_model(multi_worker_model, export_dir, ctx.job_name == 'chief')
 
   # terminating feed tells spark to skip processing further partitions
   tf_feed.terminate()

examples/mnist/keras/mnist_tf.py

Lines changed: 5 additions & 5 deletions
@@ -6,6 +6,8 @@
 def main_fun(args, ctx):
   import tensorflow_datasets as tfds
   import tensorflow as tf
+  from tensorflowonspark import compat
+
   tfds.disable_progress_bar()
 
   strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
@@ -60,11 +62,9 @@ def build_and_compile_cnn_model():
   multi_worker_model = build_and_compile_cnn_model()
   multi_worker_model.fit(x=train_datasets, epochs=args.epochs, steps_per_epoch=args.steps_per_epoch, callbacks=callbacks)
 
-  if ctx.job_name == 'chief':
-    from tensorflow_estimator.python.estimator.export import export_lib
-    export_dir = export_lib.get_timestamped_export_dir(args.export_dir)
-    tf.keras.experimental.export_saved_model(multi_worker_model, export_dir)
-    # multi_worker_model.save(args.model_dir, save_format='tf')
+  from tensorflow_estimator.python.estimator.export import export_lib
+  export_dir = export_lib.get_timestamped_export_dir(args.export_dir)
+  compat.export_saved_model(multi_worker_model, export_dir, ctx.job_name == 'chief')
 
 
 if __name__ == '__main__':

examples/mnist/keras/mnist_tf_ds.py

Lines changed: 4 additions & 5 deletions
@@ -6,6 +6,7 @@
 def main_fun(args, ctx):
   """Example demonstrating loading TFRecords directly from disk (e.g. HDFS) without tensorflow_datasets."""
   import tensorflow as tf
+  from tensorflowonspark import compat
 
   strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
 
@@ -86,11 +87,9 @@ def build_and_compile_cnn_model():
   multi_worker_model = build_and_compile_cnn_model()
   multi_worker_model.fit(x=train_datasets, epochs=args.epochs, steps_per_epoch=steps_per_epoch, callbacks=callbacks)
 
-  if ctx.job_name == 'chief':
-    from tensorflow_estimator.python.estimator.export import export_lib
-    export_dir = export_lib.get_timestamped_export_dir(args.export_dir)
-    tf.keras.experimental.export_saved_model(multi_worker_model, export_dir)
-    # multi_worker_model.save(args.model_dir, save_format='tf')
+  from tensorflow_estimator.python.estimator.export import export_lib
+  export_dir = export_lib.get_timestamped_export_dir(args.export_dir)
+  compat.export_saved_model(multi_worker_model, export_dir, ctx.job_name == 'chief')
 
 
 if __name__ == '__main__':

examples/segmentation/README.md

Lines changed: 2 additions & 2 deletions
@@ -1,8 +1,8 @@
 # Image Segmentation
 
-Original Source: https://www.tensorflow.org/beta/tutorials/images/segmentation
+Original Source: https://www.tensorflow.org/tutorials/images/segmentation
 
-This code is based on the [Image Segmentation](https://www.tensorflow.org/beta/tutorials/images/segmentation) notebook example, converted to a single-node TensorFlow python app, then converted into a distributed TensorFlow app using the `MultiWorkerMirroredStrategy`, and then finally adapted for TensorFlowOnSpark. Compare the different versions to see the conversion steps involved at each stage.
+This code is based on the [Image Segmentation](https://www.tensorflow.org/tutorials/images/segmentation) notebook example, converted to a single-node TensorFlow python app, then converted into a distributed TensorFlow app using the `MultiWorkerMirroredStrategy`, and then finally adapted for TensorFlowOnSpark. Compare the different versions to see the conversion steps involved at each stage.
 
 Notes:
 - this example assumes that Spark, TensorFlow, and TensorFlowOnSpark are already installed.

examples/segmentation/segmentation_spark.py

Lines changed: 7 additions & 10 deletions
@@ -159,18 +159,15 @@ def unet_model(output_channels):
                             validation_steps=VALIDATION_STEPS,
                             validation_data=test_dataset)
 
-  if ctx.job_name == 'chief':
+  if tf.__version__ == '2.0.0':
     # Workaround for: https://github.com/tensorflow/tensorflow/issues/30251
-    print("===== saving h5py model")
-    model.save(args.model_dir + ".h5")
-    print("===== re-loading model w/o DistributionStrategy")
-    new_model = tf.keras.models.load_model(args.model_dir + ".h5")
-    print("===== exporting saved_model")
-    tf.keras.experimental.export_saved_model(new_model, args.export_dir)
-    print("===== done exporting")
+    # Save model locally as h5py and reload it w/o distribution strategy
+    if ctx.job_name == 'chief':
+      model.save(args.model_dir + ".h5")
+      new_model = tf.keras.models.load_model(args.model_dir + ".h5")
+      tf.keras.experimental.export_saved_model(new_model, args.export_dir)
   else:
-    print("===== sleeping")
-    time.sleep(90)
+    model.save(args.export_dir, save_format='tf')
 
 
 if __name__ == '__main__':
