Skip to content

Commit ec327d0

Browse files
rchen152Tensorflow Cloud maintainers
authored andcommitted
Add missing typing.Optional type annotations to function parameters.
PiperOrigin-RevId: 376238040
1 parent 5ac694e commit ec327d0

File tree

5 files changed

+30
-28
lines changed

5 files changed

+30
-28
lines changed

src/python/tensorflow_cloud/tuner/cloud_fit_client.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,12 @@
4545
def cloud_fit(
4646
model: tf.keras.Model,
4747
remote_dir: Text,
48-
region: Text = None,
49-
project_id: Text = None,
50-
image_uri: Text = None,
48+
region: Optional[Text] = None,
49+
project_id: Optional[Text] = None,
50+
image_uri: Optional[Text] = None,
5151
distribution_strategy: Text = DEFAULT_DISTRIBUTION_STRATEGY,
52-
job_spec: Dict[str, Any] = None,
53-
job_id: Text = None,
52+
job_spec: Optional[Dict[str, Any]] = None,
53+
job_id: Optional[Text] = None,
5454
**fit_kwargs
5555
) -> Text:
5656
"""Executes in-memory Model and Dataset remotely on AI Platform.
@@ -209,7 +209,7 @@ def _serialize_assets(remote_dir: Text,
209209
def _default_job_spec(
210210
region: Text,
211211
image_uri: Text,
212-
entry_point_args: Sequence[Text] = None,
212+
entry_point_args: Optional[Sequence[Text]] = None,
213213
) -> Dict[str, Any]:
214214
"""Creates a basic job_spec for cloud AI Training.
215215

src/python/tensorflow_cloud/tuner/tuner.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,10 @@ def __init__(
9393
self,
9494
project_id: Text,
9595
region: Text,
96-
objective: Union[Text, oracle_module.Objective] = None,
97-
hyperparameters: hp_module.HyperParameters = None,
96+
objective: Optional[Union[Text, oracle_module.Objective]] = None,
97+
hyperparameters: Optional[hp_module.HyperParameters] = None,
9898
study_config: Optional[Dict[Text, Any]] = None,
99-
max_trials: int = None,
99+
max_trials: Optional[int] = None,
100100
study_id: Optional[Text] = None,
101101
):
102102
"""KerasTuner Oracle interface implemented with Vizier backend.
@@ -458,10 +458,10 @@ def __init__(
458458
tf.keras.Model]],
459459
project_id: Text,
460460
region: Text,
461-
objective: Union[Text, oracle_module.Objective] = None,
462-
hyperparameters: hp_module.HyperParameters = None,
461+
objective: Optional[Union[Text, oracle_module.Objective]] = None,
462+
hyperparameters: Optional[hp_module.HyperParameters] = None,
463463
study_config: Optional[Dict[Text, Any]] = None,
464-
max_trials: int = None,
464+
max_trials: Optional[int] = None,
465465
study_id: Optional[Text] = None,
466466
**kwargs):
467467
"""Constructor.
@@ -516,10 +516,10 @@ def __init__(
516516
project_id: Text,
517517
region: Text,
518518
directory: Text,
519-
objective: Union[Text, oracle_module.Objective] = None,
520-
hyperparameters: hp_module.HyperParameters = None,
519+
objective: Optional[Union[Text, oracle_module.Objective]] = None,
520+
hyperparameters: Optional[hp_module.HyperParameters] = None,
521521
study_config: Optional[Dict[Text, Any]] = None,
522-
max_trials: int = None,
522+
max_trials: Optional[int] = None,
523523
study_id: Optional[Text] = None,
524524
container_uri: Optional[Text] = None,
525525
replica_config: Optional[machine_config.MachineConfig] = None,

src/python/tensorflow_cloud/tuner/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,7 @@ def convert_hyperparams_to_hparams(
380380
def format_objective(
381381
objective: Union[Text, oracle_module.Objective,
382382
List[Union[Text, oracle_module.Objective]]],
383-
direction: Text = None) -> List[oracle_module.Objective]:
383+
direction: Optional[Text] = None) -> List[oracle_module.Objective]:
384384
"""Formats objective to a list of oracle_module.Objective.
385385
386386
Args:

src/python/tensorflow_cloud/tuner/vizier_client.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def __init__(self,
3939
service_client: discovery.Resource,
4040
project_id: Text,
4141
region: Text,
42-
study_id: Text = None):
42+
study_id: Optional[Text] = None):
4343
"""Create an VizierClient object.
4444
4545
Use this constructor when you know the study_id, and when the Study
@@ -206,10 +206,11 @@ def should_trial_stop(self, trial_id: Text) -> bool:
206206
return True
207207
return False
208208

209-
def complete_trial(self,
210-
trial_id: Text,
211-
trial_infeasible: bool,
212-
infeasibility_reason: Text = None) -> Dict[Text, Any]:
209+
def complete_trial(
210+
self,
211+
trial_id: Text,
212+
trial_infeasible: bool,
213+
infeasibility_reason: Optional[Text] = None) -> Dict[Text, Any]:
213214
"""Marks the trial as COMPLETED and sets the final measurement.
214215
215216
Args:
@@ -289,7 +290,7 @@ def list_studies(self) -> List[Dict[Text, Any]]:
289290
raise
290291
return resp.get("studies", [])
291292

292-
def delete_study(self, study_name: Text = None) -> None:
293+
def delete_study(self, study_name: Optional[Text] = None) -> None:
293294
"""Deletes the study.
294295
295296
Args:

src/python/tensorflow_cloud/tuner/vizier_client_interface.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515
"""An abstract class for the client used in both OSS Vizier and Cloud AI Platform Optimizer Service."""
1616
import abc
17-
from typing import List, Mapping, Text, Union, Dict, Any
17+
from typing import Any, Dict, List, Mapping, Optional, Text, Union
1818

1919

2020
class VizierClientInterface(abc.ABC):
@@ -77,10 +77,11 @@ def should_trial_stop(self, trial_id: Text) -> bool:
7777
"""
7878

7979
@abc.abstractmethod
80-
def complete_trial(self,
81-
trial_id: Text,
82-
trial_infeasible: bool,
83-
infeasibility_reason: Text = None) -> Dict[Text, Any]:
80+
def complete_trial(
81+
self,
82+
trial_id: Text,
83+
trial_infeasible: bool,
84+
infeasibility_reason: Optional[Text] = None) -> Dict[Text, Any]:
8485
"""Marks the trial as COMPLETED and sets the final measurement.
8586
8687
Args:
@@ -110,7 +111,7 @@ def list_studies(self) -> List[Dict[Text, Any]]:
110111
"""
111112

112113
@abc.abstractmethod
113-
def delete_study(self, study_name: Text = None) -> None:
114+
def delete_study(self, study_name: Optional[Text] = None) -> None:
114115
"""Deletes the study.
115116
116117
Args:

0 commit comments

Comments
 (0)