@@ -444,7 +444,7 @@ def __init__(
444
444
max_trials : int = None ,
445
445
study_id : Optional [Text ] = None ,
446
446
container_uri : Optional [Text ] = None ,
447
- replica_config = "auto" ,
447
+ replica_config : Optional [ machine_config . MachineConfig ] = None ,
448
448
replica_count : Optional [int ] = 1 ,
449
449
** kwargs ):
450
450
"""Constructor.
@@ -469,10 +469,10 @@ def __init__(
469
469
container_uri: Base image to use for AI Platform Training. This
470
470
image must follow cloud_fit image with a cloud_fit.remote() as
471
471
entry point. Refer to cloud_fit documentation for more details
472
- at tensorflow_cloud/experimental/cloud_fit/README .md
472
+ at tensorflow_cloud/tuner/cloud_fit_readme .md.
473
473
replica_config: Optional `MachineConfig` that represents the
474
474
configuration for the general workers in a distribution cluster.
475
- Defaults to 'auto'. 'auto' maps to a standard CPU config such as
475
+ Defaults is None and mapped to a standard CPU config such as
476
476
`tensorflow_cloud.core.COMMON_MACHINE_CONFIGS.CPU`.
477
477
replica_count: Optional integer that represents the total number of
478
478
workers in a distribution cluster including a chief worker. Has
@@ -489,7 +489,9 @@ def __init__(
489
489
# here.
490
490
self ._replica_count = replica_count
491
491
self ._replica_config = replica_config
492
- if replica_config == "auto" :
492
+ if replica_config :
493
+ self ._replica_config = replica_config
494
+ else :
493
495
self ._replica_config = machine_config .COMMON_MACHINE_CONFIGS ["CPU" ]
494
496
495
497
# Setting AI Platform Training runtime configurations. User can create
0 commit comments