# Help text for the command line interface (passed to click as `help=DESC`).
# The explicit `\n` escapes are part of the string and are kept as-is.
DESC = '''
Helper script to examine an environment and associated policy for behaviors; \n
- either onscreen, or offscreen, or just rollout without rendering.\n
- save resulting paths as pickle or as 2D plots \n
- rollout either learned policies or scripted policies (e.g. see rand_policy class below) \n
USAGE:\n
    $ python examine_env.py --env_name door-v1 \n
    $ python examine_env.py --env_name door-v1 --policy_path robohive.utils.examine_env.rand_policy \n
    $ python examine_env.py --env_name door-v1 --policy_path my_policy.pickle --mode evaluation --episodes 10 \n
'''
2426
2527# Random policy
@@ -30,7 +32,14 @@ def __init__(self, env, seed):
3032
def get_action(self, obs):
    """Return a randomly sampled action together with an info dict.

    Args:
        obs: current observation; ignored, since actions are drawn
            from the action space regardless of the observation.

    Returns:
        A tuple ``(action, info)``. ``action`` is one sample from
        ``self.env.action_space``. ``info`` holds the entries ``'mode'``
        (a fixed description string) and ``'evaluation'`` (a second,
        independently drawn sample from the same action space).
    """
    # Draw the returned action first, then the 'evaluation' entry, so the
    # order in which the sampler is consumed stays the same.
    action = self.env.action_space.sample()
    info = {
        'mode': 'random samples',
        'evaluation': self.env.action_space.sample(),
    }
    return action, info
36+
def load_class_from_str(module_name, class_name):
    """Resolve ``class_name`` inside the module named ``module_name``.

    Args:
        module_name: dotted import path of the module (e.g. ``'pkg.mod'``).
        class_name: name of the attribute to fetch from that module.

    Returns:
        The attribute found on the imported module, or ``None`` when the
        module cannot be imported, the name is not a valid module name
        (e.g. an empty string), or the module has no such attribute.
    """
    try:
        # A non-empty fromlist makes __import__ return the leaf module
        # rather than the top-level package. It must be a list: a bare
        # string would be iterated character by character.
        module = __import__(module_name, globals(), locals(), [class_name])
        return getattr(module, class_name)
    except (ImportError, AttributeError, ValueError):
        # ValueError: __import__ rejects an empty module name, which is
        # what a dot-free input produces once split into module/class.
        return None
3443
3544# MAIN =========================================================
3645@click .command (help = DESC )
@@ -57,13 +66,19 @@ def main(env_name, policy_path, mode, seed, num_episodes, render, camera_name, o
5766
5867 # resolve policy and outputs
5968 if policy_path is not None :
60- pi = pickle .load (open (policy_path , 'rb' ))
61- if output_dir == './' : # overide the default
62- output_dir , pol_name = os .path .split (policy_path )
63- output_name = os .path .splitext (pol_name )[0 ]
64- if output_name is None :
65- pol_name = os .path .split (policy_path )[1 ]
66- output_name = os .path .splitext (pol_name )[0 ]
69+ policy_tokens = policy_path .split ('.' )
70+ pi = load_class_from_str ('.' .join (policy_tokens [:- 1 ]), policy_tokens [- 1 ])
71+
72+ if pi is not None :
73+ pi = pi (env , seed )
74+ else :
75+ pi = pickle .load (open (policy_path , 'rb' ))
76+ if output_dir == './' : # overide the default
77+ output_dir , pol_name = os .path .split (policy_path )
78+ output_name = os .path .splitext (pol_name )[0 ]
79+ if output_name is None :
80+ pol_name = os .path .split (policy_path )[1 ]
81+ output_name = os .path .splitext (pol_name )[0 ]
6782 else :
6883 pi = rand_policy (env , seed )
6984 mode = 'exploration'
0 commit comments