Skip to content

Commit 385510e

Browse files
palancPatrick Lancaster
authored and committed
FEATURE: Allows examine_env to load hard coded policies (#87)
* Allow examine_envs to load hard coded policies, e.g. python robohive/utils/examine_env.py -e FrankaReachRandom-v0 -p robohive.utils.examine_env.rand_policy * Remove unnecessary import * examine_env loading scripted policies: update DESC in examine_env and add unit test Co-authored-by: Patrick Lancaster <[email protected]>
1 parent ec43d54 commit 385510e

File tree

2 files changed

+36
-12
lines changed

2 files changed

+36
-12
lines changed

robohive/tests/test_examine_env.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,19 @@ def test_offscreen_rendering(self):
2222
result = runner.invoke(examine_env, ["--env_name", "door-v1", \
2323
"--num_episodes", 1, \
2424
"--render", "offscreen",\
25-
"--camera_name", "top_acam"])
25+
"--camera_name", "top_cam"])
2626
print(result.output.strip())
2727
self.assertEqual(result.exception, None)
2828

29+
def test_scripted_policy_loading(self):
    """Run examine_env with a scripted policy given as a dotted class path.

    The policy is passed via --policy_path as
    'robohive.utils.examine_env.rand_policy'; the CLI invocation must
    finish without raising an exception.
    """
    print("Testing scripted policy loading")
    cli_args = [
        "--env_name", "door-v1",
        "--num_episodes", 1,
        "--policy_path", "robohive.utils.examine_env.rand_policy",
    ]
    outcome = click.testing.CliRunner().invoke(examine_env, cli_args)
    print(outcome.output.strip())
    # CliRunner captures exceptions instead of propagating them, so check explicitly
    self.assertEqual(outcome.exception, None)
2938

3039
if __name__ == '__main__':
3140
unittest.main()

robohive/utils/examine_env.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@
1616
DESC = '''
1717
Helper script to examine an environment and associated policy for behaviors; \n
1818
- either onscreen, or offscreen, or just rollout without rendering.\n
19-
- save resulting paths as pickle or as 2D plots
19+
- save resulting paths as pickle or as 2D plots \n
20+
- rollout either learned policies or scripted policies (e.g. see rand_policy class below) \n
2021
USAGE:\n
21-
$ python examine_env.py --env_name door-v0 \n
22-
$ python examine_env.py --env_name door-v0 --policy my_policy.pickle --mode evaluation --episodes 10 \n
22+
$ python examine_env.py --env_name door-v1 \n
23+
$ python examine_env.py --env_name door-v1 --policy_path robohive.utils.examine_env.rand_policy \n
24+
$ python examine_env.py --env_name door-v1 --policy_path my_policy.pickle --mode evaluation --num_episodes 10 \n
2325
'''
2426

2527
# Random policy
@@ -30,7 +32,14 @@ def __init__(self, env, seed):
3032

3133
def get_action(self, obs):
3234
# return self.env.np_random.uniform(high=self.env.action_space.high, low=self.env.action_space.low)
33-
return self.env.action_space.sample(), {'mode': 'random samples'}
35+
return self.env.action_space.sample(), {'mode': 'random samples', 'evaluation':self.env.action_space.sample()}
36+
37+
def load_class_from_str(module_name, class_name):
    """Look up an attribute (typically a class) by module path and name.

    Args:
        module_name: Absolute dotted module path, e.g. 'package.module'.
        class_name: Name of the attribute to fetch from that module.

    Returns:
        The attribute named `class_name` of the imported module, or None
        when the module cannot be imported or has no such attribute.
    """
    # Local import keeps this function self-contained
    import importlib

    # An empty name (input had no dot) or a leading dot (e.g. './policy.pickle')
    # is not an absolute module path; report "not found" instead of raising so
    # callers can fall back to another loading strategy.
    if not module_name or module_name.startswith('.'):
        return None

    try:
        module = importlib.import_module(module_name)
    except (ImportError, ValueError):
        return None

    return getattr(module, class_name, None)
3443

3544
# MAIN =========================================================
3645
@click.command(help=DESC)
@@ -57,13 +66,19 @@ def main(env_name, policy_path, mode, seed, num_episodes, render, camera_name, o
5766

5867
# resolve policy and outputs
5968
if policy_path is not None:
60-
pi = pickle.load(open(policy_path, 'rb'))
61-
if output_dir == './': # overide the default
62-
output_dir, pol_name = os.path.split(policy_path)
63-
output_name = os.path.splitext(pol_name)[0]
64-
if output_name is None:
65-
pol_name = os.path.split(policy_path)[1]
66-
output_name = os.path.splitext(pol_name)[0]
69+
policy_tokens = policy_path.split('.')
70+
pi = load_class_from_str('.'.join(policy_tokens[:-1]), policy_tokens[-1])
71+
72+
if pi is not None:
73+
pi = pi(env, seed)
74+
else:
75+
pi = pickle.load(open(policy_path, 'rb'))
76+
if output_dir == './': # overide the default
77+
output_dir, pol_name = os.path.split(policy_path)
78+
output_name = os.path.splitext(pol_name)[0]
79+
if output_name is None:
80+
pol_name = os.path.split(policy_path)[1]
81+
output_name = os.path.splitext(pol_name)[0]
6782
else:
6883
pi = rand_policy(env, seed)
6984
mode = 'exploration'

0 commit comments

Comments
 (0)