Skip to content

Commit de2af64

Browse files
committed
add tests
1 parent 564294f commit de2af64

File tree

3 files changed

+104
-1
lines changed

3 files changed

+104
-1
lines changed

tests/test_marl_world.yaml

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
world:
2+
height: 12 # the height of the world
3+
width: 12 # the height of the world
4+
step_time: 1.0
5+
sample_time: 1.0
6+
collision_mode: 'reactive'
7+
8+
robot:
9+
- number: 3
10+
kinematics: {name: 'diff'}
11+
distribution: {name: 'manual'}
12+
shape: {name: 'circle', radius: 0.2}
13+
vel_min: [ 0, -1.0 ]
14+
vel_max: [ 1.0, 1.0 ]
15+
state: [[3, 10, 0], [3, 6, 0], [3, 2, 0]]
16+
goal: [[9, 9, 0], [8, 8, 0], [7, 7, 0]]
17+
color: ['royalblue', 'red', 'green', 'orange', 'purple', 'yellow', 'cyan', 'magenta', 'lime', 'pink', 'brown']
18+
arrive_mode: position
19+
goal_threshold: 0.3
20+
21+
plot:
22+
show_trajectory: False

tests/test_model.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
from pathlib import Path
22

3+
import torch
4+
5+
from robot_nav.SIM_ENV.marl_sim import MARL_SIM
36
from robot_nav.models.RCPG.RCPG import RCPG
47
from robot_nav.models.TD3.TD3 import TD3
58
from robot_nav.models.CNNTD3.CNNTD3 import CNNTD3
69
from robot_nav.models.SAC.SAC import SAC
710
from robot_nav.models.DDPG.DDPG import DDPG
811
from robot_nav.utils import get_buffer
912
from robot_nav.SIM_ENV.sim import SIM
13+
from robot_nav.models.MARL.marlTD3 import TD3 as marlTD3
1014
import pytest
1115

1216
PROJECT_ROOT = Path(__file__).resolve().parents[1]
@@ -91,3 +95,56 @@ def test_max_bound_models(model, state_dim):
9195
iterations=2,
9296
batch_size=8,
9397
)
98+
99+
100+
def test_marl_models():
101+
sim = MARL_SIM("/tests/test_marl_world.yaml", disable_plotting=True)
102+
test_model = marlTD3(
103+
state_dim=11,
104+
action_dim=2,
105+
max_action=1,
106+
num_robots=sim.num_robots,
107+
device="cpu",
108+
save_every=0,
109+
load_model=False,
110+
) # instantiate a model
111+
112+
replay_buffer = get_buffer(
113+
model=test_model,
114+
sim=sim,
115+
load_saved_buffer=False,
116+
pretrain=False,
117+
pretraining_iterations=0,
118+
training_iterations=0,
119+
batch_size=0,
120+
buffer_size=100,
121+
)
122+
123+
for _ in range(10):
124+
connections = torch.tensor(
125+
[[0.0 for _ in range(sim.num_robots - 1)] for _ in range(3)]
126+
)
127+
(
128+
poses,
129+
distance,
130+
cos,
131+
sin,
132+
collision,
133+
goal,
134+
a,
135+
reward,
136+
positions,
137+
goal_positions,
138+
) = sim.step([[0, 0] for _ in range(sim.num_robots)], connections)
139+
state, terminal = test_model.prepare_state(
140+
poses, distance, cos, sin, collision, a, goal_positions
141+
)
142+
replay_buffer.add(
143+
state, [[0, 0] for _ in range(sim.num_robots)], reward, terminal, state
144+
)
145+
146+
test_model.train(
147+
replay_buffer=replay_buffer,
148+
iterations=2,
149+
batch_size=8,
150+
)

tests/test_sim.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import os
22

33
import pytest
4+
import torch
45

6+
from robot_nav.SIM_ENV.marl_sim import MARL_SIM
57
from robot_nav.SIM_ENV.sim import SIM
68
import numpy as np
79

@@ -12,7 +14,7 @@
1214

1315
@skip_on_ci
1416
def test_sim():
15-
sim = SIM("/tests/test_world.yaml")
17+
sim = SIM("/tests/test_world.yaml", disable_plotting=True)
1618
robot_state = sim.env.get_robot_state()
1719
state = sim.step(1, 0)
1820
next_robot_state = sim.env.get_robot_state()
@@ -28,6 +30,28 @@ def test_sim():
2830
assert np.not_equal(robot_state[1], new_robot_state[1])
2931

3032

33+
def test_marl_sim():
34+
sim = MARL_SIM("/tests/test_marl_world.yaml", disable_plotting=True)
35+
robot_state = [sim.env.robot_list[i].state[:2] for i in range(3)]
36+
connections = torch.tensor(
37+
[[0.0 for _ in range(sim.num_robots - 1)] for _ in range(3)]
38+
)
39+
40+
_ = sim.step([[1, 0], [1, 0], [1, 0]], connections)
41+
next_robot_state = [sim.env.robot_list[i].state[:2] for i in range(3)]
42+
for j in range(3):
43+
assert np.isclose(robot_state[j][0], next_robot_state[j][0] - 1)
44+
assert np.isclose(robot_state[j][1], robot_state[j][1])
45+
46+
assert len(sim.env.obstacle_list) == 0
47+
48+
sim.reset()
49+
new_robot_state = [sim.env.robot_list[i].state[:2] for i in range(3)]
50+
for j in range(3):
51+
assert np.not_equal(robot_state[j][0], new_robot_state[j][0])
52+
assert np.not_equal(robot_state[j][1], new_robot_state[j][1])
53+
54+
3155
@skip_on_ci
3256
def test_sincos():
3357
sim = SIM("/tests/test_world.yaml")

0 commit comments

Comments
 (0)