Skip to content

Commit a2604d0

Browse files
Rename policy_saver.PolicySaver to (e.g.) policy_saver.MLGOPolicySaver (google#527)
Addresses issue google#309 Just a typo within the commit message It is indeed MLGOPolicySaver
1 parent a075d97 commit a2604d0

File tree

14 files changed

+18
-17
lines changed

14 files changed

+18
-17
lines changed

compiler_opt/es/blackbox_learner_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def setUp(self):
111111
init_params = policy_utils.get_vectorized_parameters_from_policy(policy)
112112

113113
# save the policy
114-
saver = policy_saver.PolicySaver({policy_name: policy})
114+
saver = policy_saver.MLGOPolicySaver({policy_name: policy})
115115
policy_save_path = os.path.join(output_dir.full_path, 'temp_output',
116116
'policy')
117117
saver.save(policy_save_path)

compiler_opt/es/inlining/inlining_worker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class InliningWorker(worker.Worker):
4242
def _setup_base_policy(self):
4343
self._tf_base_temp_dir = tempfile.mkdtemp()
4444
policy = policy_utils.create_actor_policy()
45-
saver = policy_saver.PolicySaver({"policy": policy})
45+
saver = policy_saver.MLGOPolicySaver({"policy": policy})
4646
saver.save(self._tf_base_temp_dir)
4747
self._tf_base_policy_path = os.path.join(self._tf_base_temp_dir, "policy")
4848

compiler_opt/es/policy_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def save_policy(policy: 'tf_policy.TFPolicy | HasModelVariables',
121121
policy_name: The value to name the policy.
122122
"""
123123
set_vectorized_parameters_for_policy(policy, parameters)
124-
saver = policy_saver.PolicySaver({policy_name: policy})
124+
saver = policy_saver.MLGOPolicySaver({policy_name: policy})
125125
saver.save(save_folder)
126126

127127

compiler_opt/es/policy_utils_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@ class VectorTest(absltest.TestCase):
111111
POLICY_NAME = 'test_policy_name'
112112

113113
def _save_inlining_policy(
114-
self) -> tuple[str, actor_policy.ActorPolicy, policy_saver.PolicySaver]:
114+
self
115+
) -> tuple[str, actor_policy.ActorPolicy, policy_saver.MLGOPolicySaver]:
115116
problem_config = registry.get_configuration(
116117
implementation=inlining.InliningConfig)
117118
time_step_spec, action_spec = problem_config.get_signature_spec()
@@ -138,7 +139,7 @@ def _save_inlining_policy(
138139
actor_network=actor_network)
139140

140141
# save the policy
141-
saver = policy_saver.PolicySaver({VectorTest.POLICY_NAME: policy})
142+
saver = policy_saver.MLGOPolicySaver({VectorTest.POLICY_NAME: policy})
142143
testing_path = self.create_tempdir()
143144
policy_save_path = os.path.join(testing_path, 'temp_output', 'policy')
144145
saver.save(policy_save_path)

compiler_opt/es/regalloc_trace/regalloc_trace_worker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class RegallocTraceWorker(worker.Worker):
5656
def _setup_base_policy(self):
5757
self._tf_base_temp_dir = tempfile.mkdtemp()
5858
policy = policy_utils.create_actor_policy()
59-
saver = policy_saver.PolicySaver({"policy": policy})
59+
saver = policy_saver.MLGOPolicySaver({"policy": policy})
6060
saver.save(self._tf_base_temp_dir)
6161
self._tf_base_policy_path = os.path.join(self._tf_base_temp_dir, "policy")
6262

compiler_opt/rl/distributed/ppo_collect_lib.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def _get_policy_bytes(agent):
4848
"""Recover the collect_policy bytes from a TF agent"""
4949
policy_key = 'collect'
5050
with tempfile.TemporaryDirectory() as tmpdirname:
51-
saver = policy_saver.PolicySaver(policy_dict={
51+
saver = policy_saver.MLGOPolicySaver(policy_dict={
5252
policy_key: agent.collect_policy,
5353
})
5454
saver.save(tmpdirname)

compiler_opt/rl/distributed/ppo_eval_lib.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def sequence_example_iterator_fn(seq_ex: list[str]):
121121
actions = []
122122
while True:
123123
with tempfile.TemporaryDirectory() as tmpdirname:
124-
saver = policy_saver.PolicySaver(policy_dict=policy_dict)
124+
saver = policy_saver.MLGOPolicySaver(policy_dict=policy_dict)
125125
saver.save(tmpdirname)
126126
policy_bytes = policy_saver.Policy.from_filesystem(
127127
os.path.join(tmpdirname, 'policy'))

compiler_opt/rl/imitation_learning/weighted_bc_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def train():
6969
time_step_spec=expected_signature,
7070
action_spec=action_spec)
7171
policy_dict = {'tf_agents_policy': wrapped_keras_model}
72-
saver = policy_saver.PolicySaver(policy_dict=policy_dict)
72+
saver = policy_saver.MLGOPolicySaver(policy_dict=policy_dict)
7373
saver.save(_SAVE_MODEL_DIR.value)
7474

7575

compiler_opt/rl/policy_saver.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,11 +142,11 @@ def from_filesystem(location: str):
142142
return Policy(output_spec=output_spec, policy=policy)
143143

144144

145-
class PolicySaver:
145+
class MLGOPolicySaver:
146146
"""Object that saves policy and model config file required by inference.
147147
148148
```python
149-
policy_saver = PolicySaver(policy_dict, config)
149+
policy_saver = MLGOPolicySaver(policy_dict, config)
150150
policy_saver.save(root_dir)
151151
```
152152
"""

compiler_opt/rl/policy_saver_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from compiler_opt.testing import model_test_utils
2929

3030

31-
class PolicySaverTest(tf.test.TestCase):
31+
class MLGOPolicySaverTest(tf.test.TestCase):
3232

3333
def setUp(self):
3434
super().setUp()
@@ -54,7 +54,7 @@ def test_save_policy(self):
5454
'saved_policy': test_agent.policy,
5555
'saved_collect_policy': test_agent.collect_policy
5656
}
57-
test_policy_saver = policy_saver.PolicySaver(policy_dict=policy_dict)
57+
test_policy_saver = policy_saver.MLGOPolicySaver(policy_dict=policy_dict)
5858

5959
root_dir = self.get_temp_dir()
6060
test_policy_saver.save(root_dir)

0 commit comments

Comments
 (0)