Skip to content

Commit d509c23

Browse files
yinghsienwuTensorflow Cloud maintainers
authored andcommitted
Fix machine_config for DistributingCloudTuner
PiperOrigin-RevId: 365931415
1 parent 4fda0d8 commit d509c23

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

src/python/tensorflow_cloud/tuner/cloud_fit_readme.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,8 @@ you can skip the setup and authentication steps and start from step 8.
143143
Create a dockerfile as follows:
144144

145145
```shell
146-
# Using DLVM base image
146+
# Using DLVM base image. For GPU training use
147+
# gcr.io/deeplearning-platform-release/tf2-gpu instead.
147148
FROM gcr.io/deeplearning-platform-release/tf2-cpu
148149
WORKDIR /root
149150

src/python/tensorflow_cloud/tuner/tuner.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,7 @@ def __init__(
444444
max_trials: int = None,
445445
study_id: Optional[Text] = None,
446446
container_uri: Optional[Text] = None,
447-
replica_config="auto",
447+
replica_config: Optional[machine_config.MachineConfig] = None,
448448
replica_count: Optional[int] = 1,
449449
**kwargs):
450450
"""Constructor.
@@ -469,10 +469,10 @@ def __init__(
469469
container_uri: Base image to use for AI Platform Training. This
470470
image must follow cloud_fit image with a cloud_fit.remote() as
471471
entry point. Refer to cloud_fit documentation for more details
472-
at tensorflow_cloud/experimental/cloud_fit/README.md
472+
at tensorflow_cloud/tuner/cloud_fit_readme.md.
473473
replica_config: Optional `MachineConfig` that represents the
474474
configuration for the general workers in a distribution cluster.
475-
Defaults to 'auto'. 'auto' maps to a standard CPU config such as
475+
Defaults is None and mapped to a standard CPU config such as
476476
`tensorflow_cloud.core.COMMON_MACHINE_CONFIGS.CPU`.
477477
replica_count: Optional integer that represents the total number of
478478
workers in a distribution cluster including a chief worker. Has
@@ -489,7 +489,9 @@ def __init__(
489489
# here.
490490
self._replica_count = replica_count
491491
self._replica_config = replica_config
492-
if replica_config == "auto":
492+
if replica_config:
493+
self._replica_config = replica_config
494+
else:
493495
self._replica_config = machine_config.COMMON_MACHINE_CONFIGS["CPU"]
494496

495497
# Setting AI Platform Training runtime configurations. User can create

0 commit comments

Comments
 (0)