forked from Cognitive-AI-Systems/learn-to-follow
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexample.py
More file actions
67 lines (48 loc) · 2.54 KB
/
example.py
File metadata and controls
67 lines (48 loc) · 2.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import argparse
from env.create_env import create_env_base
from env.custom_maps import MAPS_REGISTRY
from utils.eval_utils import run_episode
from follower.training_config import EnvironmentMazes
from follower.inference import FollowerInferenceConfig, FollowerInference
from follower.preprocessing import follower_preprocessor
from follower_cpp.inference import FollowerConfigCPP, FollowerInferenceCPP
from follower_cpp.preprocessing import follower_cpp_preprocessor
def create_custom_env(cfg):
    """Create the base environment described by the parsed command-line options.

    Args:
        cfg: Parsed options object. The attributes ``animation``,
            ``num_agents``, ``map_name``, ``seed`` and ``max_episode_steps``
            are read from it.

    Returns:
        Whatever ``create_env_base`` builds from the configured
        ``EnvironmentMazes`` settings.
    """
    maze_settings = EnvironmentMazes(with_animation=cfg.animation)
    grid = maze_settings.grid_config
    # Copy each grid-related option from the CLI namespace onto the grid
    # config, in the order listed here.
    for option in ('num_agents', 'map_name', 'seed', 'max_episode_steps'):
        setattr(grid, option, getattr(cfg, option))
    return create_env_base(maze_settings)
def run_follower(env):
    """Run a single episode with the Python Follower policy.

    Args:
        env: Environment to evaluate on (as produced by ``create_custom_env``).

    Returns:
        The value returned by ``run_episode`` for the preprocessed environment.
    """
    inference_config = FollowerInferenceConfig()
    policy = FollowerInference(inference_config)
    # The environment is wrapped with the same config the policy was built
    # from, after the policy has been constructed.
    return run_episode(follower_preprocessor(env, inference_config), policy)
def run_follower_cpp(env, *, path_to_weights='model/follower-lite', num_threads=6):
    """Run a single episode with the C++ FollowerLite policy.

    Args:
        env: Environment to evaluate on (as produced by ``create_custom_env``).
        path_to_weights: Location of the FollowerLite weights passed to
            ``FollowerConfigCPP``. The default is a relative path, so it
            presumably resolves against the current working directory.
        num_threads: Thread count passed to ``FollowerConfigCPP``.

    Returns:
        The value returned by ``run_episode`` for the preprocessed environment.
    """
    # Previously these two values were hard-coded; they are now keyword-only
    # parameters with the same defaults, so existing callers are unaffected.
    follower_cfg = FollowerConfigCPP(path_to_weights=path_to_weights, num_threads=num_threads)
    algo = FollowerInferenceCPP(follower_cfg)
    env = follower_cpp_preprocessor(env, follower_cfg)
    return run_episode(env, algo)
def _build_parser():
    """Return the argument parser describing this script's command-line options."""
    parser = argparse.ArgumentParser(description='Follower Inference Script')
    # Animation is on by default. A store_true/store_false pair sharing one
    # dest makes ``--animation`` actually enable it and ``--no-animation``
    # disable it (a lone ``store_false`` flag named ``--animation`` would
    # turn animation OFF, contradicting its help text).
    parser.add_argument('--animation', dest='animation', action='store_true', default=True,
                        help='Enable animation (default: %(default)s)')
    parser.add_argument('--no-animation', dest='animation', action='store_false', default=True,
                        help='Disable animation')
    parser.add_argument('--num_agents', type=int, default=128, help='Number of agents (default: %(default)d)')
    parser.add_argument('--seed', type=int, default=0, help='Random seed (default: %(default)d)')
    parser.add_argument('--map_name', type=str, default='wfi_warehouse', help='Map name (default: %(default)s)')
    parser.add_argument('--max_episode_steps', type=int, default=256,
                        help='Maximum episode steps (default: %(default)d)')
    parser.add_argument('--show_map_names', action='store_true', help='Shows names of all available maps')
    parser.add_argument('--algorithm', type=str, choices=['Follower', 'FollowerLite'], default='Follower',
                        help='Algorithm to use: "Follower" or "FollowerLite" (default: "Follower")')
    return parser


def main():
    """Parse CLI options, then either list map names or run one episode and print its result."""
    args = _build_parser().parse_args()
    if args.show_map_names:
        for map_ in MAPS_REGISTRY:
            print(map_)
        return
    # 'choices' restricts --algorithm to the two names; anything other than
    # 'FollowerLite' is the Python Follower.
    runner = run_follower_cpp if args.algorithm == 'FollowerLite' else run_follower
    print(runner(create_custom_env(args)))


if __name__ == '__main__':
    main()