Skip to content

Commit 6dc4ae7

Browse files
reedwmtensorflower-gardener
authored andcommitted
Return default strategy from get_distribution_strategy when given "off".
Before, it returned None. But almost every use of get_distribution_strategy() assumes an actual strategy is returned and crashes when None is returned. Returning the default strategy fixes these issues and is equivalent to using no strategy, as the default strategy is always in effect when no other strategy is used. PiperOrigin-RevId: 380951055
1 parent 0327186 commit 6dc4ae7

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

official/common/distribute_utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,10 @@ def get_distribution_strategy(distribution_strategy="mirrored",
102102
distribution_strategy: a string specifying which distribution strategy to
103103
use. Accepted values are "off", "one_device", "mirrored",
104104
"parameter_server", "multi_worker_mirrored", and "tpu" -- case
105-
insensitive. "off" means not to use Distribution Strategy; "tpu" means to
106-
use TPUStrategy using `tpu_address`.
105+
insensitive. "tpu" means to use TPUStrategy using `tpu_address`.
106+
"off" means to use the default strategy which is obtained from
107+
tf.distribute.get_strategy (for details on the default strategy, see
108+
https://www.tensorflow.org/guide/distributed_training#default_strategy).
107109
num_gpus: Number of GPUs to run this model.
108110
all_reduce_alg: Optional. Specifies which algorithm to use when performing
109111
all-reduce. For `MirroredStrategy`, valid values are "nccl" and
@@ -141,7 +143,8 @@ def get_distribution_strategy(distribution_strategy="mirrored",
141143
if num_gpus > 1:
142144
raise ValueError("When {} GPUs are specified, distribution_strategy "
143145
"flag cannot be set to `off`.".format(num_gpus))
144-
return None
146+
# Return the default distribution strategy.
147+
return tf.distribute.get_strategy()
145148

146149
if distribution_strategy == "tpu":
147150
# When tpu_address is an empty string, we communicate with local TPUs.

official/common/distribute_utils_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def test_mirrored_strategy(self):
4343

4444
def test_no_strategy(self):
4545
ds = distribute_utils.get_distribution_strategy('off')
46-
self.assertIsNone(ds)
46+
self.assertIs(ds, tf.distribute.get_strategy())
4747

4848
def test_invalid_strategy(self):
4949
with self.assertRaisesRegexp(

0 commit comments

Comments
 (0)