Skip to content

Commit 9f3cd53

Browse files
author
Lee Yang
committed
introduce compat library to support both TF2.0 and TF2.1
1 parent 699508c commit 9f3cd53

File tree

10 files changed

+32
-17
lines changed

10 files changed

+32
-17
lines changed

examples/mnist/estimator/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# MNIST using Estimator
22

3-
Original Source: https://www.tensorflow.org/beta/tutorials/distribute/multi_worker_with_estimator
3+
Original Source: https://www.tensorflow.org/tutorials/distribute/multi_worker_with_estimator
44

5-
This is the [Multi-worker Training with Estimator](https://www.tensorflow.org/beta/tutorials/distribute/multi_worker_with_estimator) example, adapted for TensorFlowOnSpark.
5+
This is the [Multi-worker Training with Estimator](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_estimator) example, adapted for TensorFlowOnSpark.
66

77
Note: this example assumes that Spark, TensorFlow, and TensorFlowOnSpark are already installed.
88

examples/mnist/estimator/mnist_spark.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ def main_fun(args, ctx):
99
import tensorflow_datasets as tfds
1010
from tensorflowonspark import TFNode
1111

12+
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
13+
1214
tfds.disable_progress_bar()
1315

1416
class StopFeedHook(tf.estimator.SessionRunHook):
@@ -91,7 +93,7 @@ def model_fn(features, labels, mode):
9193
train_op=optimizer.minimize(
9294
loss, tf.compat.v1.train.get_or_create_global_step()))
9395

94-
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
96+
# strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
9597
config = tf.estimator.RunConfig(train_distribute=strategy, save_checkpoints_steps=100)
9698

9799
classifier = tf.estimator.Estimator(

examples/mnist/keras/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# MNIST using Keras
22

3-
Original Source: https://www.tensorflow.org/beta/tutorials/distribute/multi_worker_with_keras
3+
Original Source: https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras
44

5-
This is the [Multi-worker Training with Keras](https://www.tensorflow.org/beta/tutorials/distribute/multi_worker_with_keras) example, adapted for TensorFlowOnSpark.
5+
This is the [Multi-worker Training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras) example, adapted for TensorFlowOnSpark.
66

77
Notes:
88
- This example assumes that Spark, TensorFlow, TensorFlow Datasets, and TensorFlowOnSpark are already installed.

examples/mnist/keras/mnist_pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
def main_fun(args, ctx):
77
import numpy as np
88
import tensorflow as tf
9-
from tensorflowonspark import TFNode
9+
from tensorflowonspark import compat, TFNode
1010

1111
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
1212

@@ -67,7 +67,7 @@ def rdd_generator():
6767

6868
from tensorflow_estimator.python.estimator.export import export_lib
6969
export_dir = export_lib.get_timestamped_export_dir(args.export_dir)
70-
multi_worker_model.save(export_dir, save_format='tf')
70+
compat.export_saved_model(multi_worker_model, export_dir, ctx.job_name == 'chief')
7171

7272
# terminating feed tells spark to skip processing further partitions
7373
tf_feed.terminate()

examples/mnist/keras/mnist_spark.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
def main_fun(args, ctx):
77
import numpy as np
88
import tensorflow as tf
9-
from tensorflowonspark import TFNode
9+
from tensorflowonspark import compat, TFNode
1010

1111
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
1212

@@ -67,7 +67,7 @@ def rdd_generator():
6767

6868
from tensorflow_estimator.python.estimator.export import export_lib
6969
export_dir = export_lib.get_timestamped_export_dir(args.export_dir)
70-
multi_worker_model.save(export_dir, save_format='tf')
70+
compat.export_saved_model(multi_worker_model, export_dir, ctx.job_name == 'chief')
7171

7272
# terminating feed tells spark to skip processing further partitions
7373
tf_feed.terminate()

examples/mnist/keras/mnist_tf.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
def main_fun(args, ctx):
77
import tensorflow_datasets as tfds
88
import tensorflow as tf
9+
from tensorflowonspark import compat
10+
911
tfds.disable_progress_bar()
1012

1113
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
@@ -62,7 +64,7 @@ def build_and_compile_cnn_model():
6264

6365
from tensorflow_estimator.python.estimator.export import export_lib
6466
export_dir = export_lib.get_timestamped_export_dir(args.export_dir)
65-
multi_worker_model.save(export_dir, save_format='tf')
67+
compat.export_saved_model(multi_worker_model, export_dir, ctx.job_name == 'chief')
6668

6769

6870
if __name__ == '__main__':

examples/mnist/keras/mnist_tf_ds.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
def main_fun(args, ctx):
77
"""Example demonstrating loading TFRecords directly from disk (e.g. HDFS) without tensorflow_datasets."""
88
import tensorflow as tf
9+
from tensorflowonspark import compat
910

1011
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
1112

@@ -88,7 +89,7 @@ def build_and_compile_cnn_model():
8889

8990
from tensorflow_estimator.python.estimator.export import export_lib
9091
export_dir = export_lib.get_timestamped_export_dir(args.export_dir)
91-
multi_worker_model.save(export_dir, save_format='tf')
92+
compat.export_saved_model(multi_worker_model, export_dir, ctx.job_name == 'chief')
9293

9394

9495
if __name__ == '__main__':

examples/segmentation/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# Image Segmentation
22

3-
Original Source: https://www.tensorflow.org/beta/tutorials/images/segmentation
3+
Original Source: https://www.tensorflow.org/tutorials/images/segmentation
44

5-
This code is based on the [Image Segmentation](https://www.tensorflow.org/beta/tutorials/images/segmentation) notebook example, converted to a single-node TensorFlow python app, then converted into a distributed TensorFlow app using the `MultiWorkerMirroredStrategy`, and then finally adapted for TensorFlowOnSpark. Compare the different versions to see the conversion steps involved at each stage.
5+
This code is based on the [Image Segmentation](https://www.tensorflow.org/tutorials/images/segmentation) notebook example, converted to a single-node TensorFlow python app, then converted into a distributed TensorFlow app using the `MultiWorkerMirroredStrategy`, and then finally adapted for TensorFlowOnSpark. Compare the different versions to see the conversion steps involved at each stage.
66

77
Notes:
88
- This example assumes that Spark, TensorFlow, and TensorFlowOnSpark are already installed.

examples/segmentation/segmentation_spark.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,15 @@ def unet_model(output_channels):
159159
validation_steps=VALIDATION_STEPS,
160160
validation_data=test_dataset)
161161

162-
model.save(args.export_dir, save_format='tf')
162+
if tf.__version__ == '2.0.0':
163+
# Workaround for: https://github.com/tensorflow/tensorflow/issues/30251
164+
# Save model locally in HDF5 (.h5) format and reload it without a distribution strategy
165+
if ctx.job_name == 'chief':
166+
model.save(args.model_dir + ".h5")
167+
new_model = tf.keras.models.load_model(args.model_dir + ".h5")
168+
tf.keras.experimental.export_saved_model(new_model, args.export_dir)
169+
else:
170+
model.save(args.export_dir, save_format='tf')
163171

164172

165173
if __name__ == '__main__':

test/test_pipeline.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import test
55
import unittest
66

7+
from tensorflowonspark import compat
78
from tensorflowonspark.pipeline import HasBatchSize, HasSteps, Namespace, TFEstimator, TFParams
89
from tensorflow.keras import Sequential
910
from tensorflow.keras.layers import Dense
@@ -117,7 +118,7 @@ def rdd_generator():
117118
ds = tf.data.Dataset.from_generator(rdd_generator, (tf.float32, tf.float32), (tf.TensorShape([2]), tf.TensorShape([1])))
118119
# disable auto-sharding since we're feeding from an RDD generator
119120
options = tf.data.Options()
120-
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
121+
compat.disable_auto_shard(options)
121122
ds = ds.with_options(options)
122123
ds = ds.batch(args.batch_size)
123124

@@ -133,8 +134,9 @@ def rdd_generator():
133134
# This fails with: "NotImplementedError: `fit_generator` is not supported for models compiled with tf.distribute.Strategy"
134135
# model.fit_generator(ds, epochs=args.epochs, steps_per_epoch=steps_per_epoch, callbacks=callbacks)
135136

136-
print("exporting model to: {}".format(args.export_dir))
137-
model.save(args.export_dir, save_format='tf')
137+
if args.export_dir:
138+
print("exporting model to: {}".format(args.export_dir))
139+
compat.export_saved_model(model, args.export_dir, ctx.job_name == 'chief')
138140

139141
tf_feed.terminate()
140142

0 commit comments

Comments
 (0)