Skip to content

Commit d74a404

Browse files
authored
Add functional control API (#50)
1 parent 48bdc36 commit d74a404

File tree

8 files changed

+252
-78
lines changed

8 files changed

+252
-78
lines changed

crazyflow/sim/data.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ class ControlData(typing.Protocol):
102102

103103
@dataclass
104104
class SimControls:
105+
mode: Control = field(pytree_node=False)
106+
"""Control mode of the simulation."""
105107
state: ControlData | None
106108
"""State control data."""
107109
attitude: ControlData | None
@@ -136,7 +138,11 @@ def create(
136138
n_worlds, n_drones, force_torque_freq, drone_model, device
137139
)
138140
return SimControls(
139-
state=state, attitude=attitude, force_torque=force_torque, rotor_vel=rotor_vel
141+
mode=control,
142+
state=state,
143+
attitude=attitude,
144+
force_torque=force_torque,
145+
rotor_vel=rotor_vel,
140146
)
141147
case Control.attitude:
142148
attitude = attitude = MellingerAttitudeData.create(
@@ -146,14 +152,22 @@ def create(
146152
n_worlds, n_drones, force_torque_freq, drone_model, device
147153
)
148154
return SimControls(
149-
state=None, attitude=attitude, force_torque=force_torque, rotor_vel=rotor_vel
155+
mode=control,
156+
state=None,
157+
attitude=attitude,
158+
force_torque=force_torque,
159+
rotor_vel=rotor_vel,
150160
)
151161
case Control.force_torque:
152162
force_torque = MellingerForceTorqueData.create(
153163
n_worlds, n_drones, force_torque_freq, drone_model, device
154164
)
155165
return SimControls(
156-
state=None, attitude=None, force_torque=force_torque, rotor_vel=rotor_vel
166+
mode=control,
167+
state=None,
168+
attitude=None,
169+
force_torque=force_torque,
170+
rotor_vel=rotor_vel,
157171
)
158172
case _:
159173
raise ValueError(f"Control mode {control} not implemented")

crazyflow/sim/functional.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
from crazyflow.control import Control
6+
from crazyflow.control.control import controllable as _controllable
7+
from crazyflow.utils import to_device
8+
9+
if TYPE_CHECKING:
10+
from jax import Array
11+
12+
from crazyflow.sim.data import SimData
13+
14+
15+
def state_control(data: SimData, controls: Array) -> SimData:
16+
"""State control function."""
17+
assert data.controls.mode == Control.state, f"control type {data.controls.mode} not enabled"
18+
assert controls.shape == (data.core.n_worlds, data.core.n_drones, 13), "controls shape mismatch"
19+
controls = to_device(controls, data.core.steps.device)
20+
data = data.replace(
21+
controls=data.controls.replace(state=data.controls.state.replace(staged_cmd=controls))
22+
)
23+
return data
24+
25+
26+
def attitude_control(data: SimData, controls: Array) -> SimData:
27+
"""Attitude control function.
28+
29+
We need to stage the attitude controls because the sys_id physics mode operates directly on
30+
the attitude controls. If we were to directly update the controls, this would effectively
31+
bypass the control frequency and run the attitude controller at the physics update rate. By
32+
staging the controls, we ensure that the physics module sees the old controls until the
33+
controller updates at its correct frequency.
34+
"""
35+
assert data.controls.mode == Control.attitude, f"control type {data.controls.mode} not enabled"
36+
assert controls.shape == (data.core.n_worlds, data.core.n_drones, 4), "controls shape mismatch"
37+
controls = to_device(controls, data.core.steps.device)
38+
data = data.replace(
39+
controls=data.controls.replace(attitude=data.controls.attitude.replace(staged_cmd=controls))
40+
)
41+
return data
42+
43+
44+
def force_torque_control(data: SimData, controls: Array) -> SimData:
45+
"""Force-torque control function."""
46+
assert data.controls.mode == Control.force_torque, (
47+
f"control type {data.controls.mode} not enabled"
48+
)
49+
assert controls.shape == (data.core.n_worlds, data.core.n_drones, 4), "controls shape mismatch"
50+
controls = to_device(controls, data.core.steps.device)
51+
data = data.replace(
52+
controls=data.controls.replace(
53+
force_torque=data.controls.force_torque.replace(staged_cmd=controls)
54+
)
55+
)
56+
return data
57+
58+
59+
def controllable(data: SimData) -> Array:
60+
"""Check which worlds can currently update their controllers."""
61+
controls = data.controls
62+
match data.controls.mode:
63+
case Control.state:
64+
control_steps, control_freq = controls.state.steps, controls.state.freq
65+
case Control.attitude:
66+
control_steps, control_freq = controls.attitude.steps, controls.attitude.freq
67+
case Control.force_torque:
68+
control_steps = controls.force_torque.steps
69+
control_freq = controls.force_torque.freq
70+
case _:
71+
raise NotImplementedError(f"Control mode {data.controls.mode} not implemented")
72+
return _controllable(data.core.steps, data.core.freq, control_steps, control_freq)

crazyflow/sim/sim.py

Lines changed: 8 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from gymnasium.envs.mujoco.mujoco_rendering import MujocoRenderer
1919
from jax import Array, Device
2020

21+
import crazyflow.sim.functional as F
2122
from crazyflow.control.control import Control, controllable
2223
from crazyflow.exception import ConfigError, NotInitializedError
2324
from crazyflow.sim.data import SimControls, SimCore, SimData, SimParams, SimState, SimStateDeriv
@@ -29,7 +30,7 @@
2930
so_rpy_rotor_drag_physics,
3031
so_rpy_rotor_physics,
3132
)
32-
from crazyflow.utils import grid_2d, leaf_replace, pytree_replace, to_device
33+
from crazyflow.utils import grid_2d, leaf_replace, pytree_replace
3334

3435
if TYPE_CHECKING:
3536
from mujoco.mjx import Data, Model
@@ -134,45 +135,15 @@ def step(self, n_steps: int = 1):
134135

135136
def state_control(self, controls: Array):
136137
"""Set the desired state for all drones in all worlds."""
137-
assert controls.shape == (self.n_worlds, self.n_drones, 13), "controls shape mismatch"
138-
assert self.control == Control.state, "State control is not enabled by the sim config"
139-
controls = to_device(controls, self.device)
140-
self.data = self.data.replace(
141-
controls=self.data.controls.replace(
142-
state=self.data.controls.state.replace(staged_cmd=controls)
143-
)
144-
)
138+
self.data = F.state_control(self.data, controls)
145139

146140
def attitude_control(self, controls: Array):
147-
"""Set the desired attitude for all drones in all worlds.
141+
"""Set the desired attitude for all drones in all worlds."""
142+
self.data = F.attitude_control(self.data, controls)
148143

149-
We need to stage the attitude controls because the sys_id physics mode operates directly on
150-
the attitude controls. If we were to directly update the controls, this would effectively
151-
bypass the control frequency and run the attitude controller at the physics update rate. By
152-
staging the controls, we ensure that the physics module sees the old controls until the
153-
controller updates at its correct frequency.
154-
"""
155-
assert controls.shape == (self.n_worlds, self.n_drones, 4), "controls shape mismatch"
156-
assert self.control == Control.attitude, "Attitude control is not enabled by the sim config"
157-
controls = to_device(controls, self.device)
158-
self.data = self.data.replace(
159-
controls=self.data.controls.replace(
160-
attitude=self.data.controls.attitude.replace(staged_cmd=controls)
161-
)
162-
)
163-
164-
def force_torque_control(self, cmd: Array):
144+
def force_torque_control(self, controls: Array):
165145
"""Set the desired force and torque for all drones in all worlds."""
166-
assert cmd.shape == (self.n_worlds, self.n_drones, 4), "Command shape mismatch"
167-
assert self.control == Control.force_torque, (
168-
"Force-torque control is not enabled by the sim config"
169-
)
170-
controls = to_device(cmd, self.device)
171-
self.data = self.data.replace(
172-
controls=self.data.controls.replace(
173-
force_torque=self.data.controls.force_torque.replace(staged_cmd=controls)
174-
)
175-
)
146+
self.data = F.force_torque_control(self.data, controls)
176147

177148
@requires_mujoco_sync
178149
def render(
@@ -408,18 +379,7 @@ def controllable(self) -> Array:
408379
as soon as the controller frequency allows for an update. Successive control updates that
409380
happen before the staged buffers are applied overwrite the desired values.
410381
"""
411-
controls = self.data.controls
412-
match self.control:
413-
case Control.state:
414-
control_steps, control_freq = controls.state.steps, controls.state.freq
415-
case Control.attitude:
416-
control_steps, control_freq = controls.attitude.steps, controls.attitude.freq
417-
case Control.force_torque:
418-
control_steps = controls.force_torque.steps
419-
control_freq = controls.force_torque.freq
420-
case _:
421-
raise NotImplementedError(f"Control mode {self.control} not implemented")
422-
return controllable(self.data.core.steps, self.data.core.freq, control_steps, control_freq)
382+
return F.controllable(self.data)
423383

424384
@requires_mujoco_sync
425385
def contacts(self, body: str | None = None) -> Array:

pixi.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/unit/test_functional.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
from __future__ import annotations
2+
3+
import jax.numpy as jnp
4+
import numpy as np
5+
import pytest
6+
7+
import crazyflow.sim.functional as F
8+
from crazyflow.control import Control
9+
from crazyflow.sim import Sim
10+
11+
12+
@pytest.mark.unit
13+
def test_functional_resets():
14+
"""Test that the functional API works as expected for resets."""
15+
sim = Sim()
16+
data, default_data = sim.build_data(), sim.build_default_data()
17+
reset_fn = sim.build_reset_fn()
18+
data = data.replace(states=data.states.replace(pos=jnp.ones_like(data.states.pos)))
19+
sim.data = data
20+
# Test types
21+
assert callable(reset_fn), "reset_fn must be a pure function"
22+
assert not hasattr(reset_fn, "__self__"), "reset_fn must not be a bound method"
23+
# Test the reset runs as expected
24+
data_reset = reset_fn(data, default_data, None)
25+
assert jnp.all(data_reset.states.pos == 0), "reset_fn did not reset positions to zero"
26+
data_reset = reset_fn(data, default_data, jnp.array([True] * sim.n_worlds))
27+
assert jnp.all(data_reset.states.pos == 0), "reset_fn did not reset positions to zero with mask"
28+
29+
30+
@pytest.mark.unit
31+
def test_functional_steps():
32+
"""Test that the functional API works as expected for steps."""
33+
sim = Sim()
34+
data = sim.build_data()
35+
data = data.replace(states=data.states.replace(pos=jnp.ones_like(data.states.pos)))
36+
step_fn = sim.build_step_fn()
37+
# Test types
38+
assert callable(step_fn), "step_fn must be a pure function"
39+
assert not hasattr(step_fn, "__self__"), "step_fn must not be a bound method"
40+
# Test the step function runs as expected
41+
data_step = step_fn(data, 5)
42+
assert jnp.all(data_step.states.pos[..., 2] < 1), "step_fn did not step correctly"
43+
44+
45+
@pytest.mark.unit
46+
@pytest.mark.parametrize("attitude_freq", [33, 50, 100, 200])
47+
def test_functional_attitude_control(attitude_freq: int):
48+
"""Test that functional attitude control respects frequency and applies commands correctly.
49+
50+
Ported from test_attitude_control in test_sim.py.
51+
"""
52+
sim = Sim(n_worlds=2, n_drones=3, control="attitude", freq=100, attitude_freq=attitude_freq)
53+
54+
data = sim.build_data()
55+
default_data = sim.build_default_data()
56+
reset_fn = sim.build_reset_fn()
57+
step_fn = sim.build_step_fn()
58+
59+
can_control_1 = np.arange(6) * attitude_freq % sim.freq < attitude_freq
60+
can_control_2 = np.array([0, 0, 1, 2, 3, 4]) * attitude_freq % sim.freq < attitude_freq
61+
for i in range(6):
62+
cmd = np.random.rand(sim.n_worlds, sim.n_drones, 4)
63+
# Check controllable status
64+
controllable = F.controllable(data)
65+
assert jnp.all(controllable[0] == can_control_1[i]), f"Controllable 1 mismatch at t={i}"
66+
assert jnp.all(controllable[1] == can_control_2[i]), f"Controllable 2 mismatch at t={i}"
67+
# Apply control
68+
data = F.attitude_control(data, cmd)
69+
data = step_fn(data, 1)
70+
sim_cmd = data.controls.attitude.cmd[0]
71+
if can_control_1[i]:
72+
assert jnp.all(sim_cmd == cmd[0]), f"Controls do not match at t={i}"
73+
else:
74+
assert not jnp.all(sim_cmd == cmd[0]), f"Controls shouldn't match at t={i}"
75+
sim_cmd = data.controls.attitude.cmd[1]
76+
if can_control_2[i]:
77+
assert jnp.all(sim_cmd == cmd[1]), f"Controls do not match at t={i}"
78+
else:
79+
assert not jnp.all(sim_cmd == cmd[1]), f"Controls shouldn't match at t={i}"
80+
if i == 0: # Make world 2 asynchronous
81+
data = reset_fn(data, default_data, np.array([False, True]))
82+
83+
84+
@pytest.mark.unit
85+
def test_functional_attitude_control_device(device: str):
86+
"""Test that functional attitude control maintains JAX arrays on correct device."""
87+
sim = Sim(n_worlds=2, n_drones=3, control=Control.attitude, device=device)
88+
data = sim.build_data()
89+
cmd = np.random.rand(sim.n_worlds, sim.n_drones, 4)
90+
data = F.attitude_control(data, cmd)
91+
controls = data.controls.attitude
92+
assert isinstance(controls.staged_cmd, jnp.ndarray), "Buffers must remain JAX arrays"
93+
assert jnp.all(controls.staged_cmd == cmd), "Buffers must match command"
94+
95+
96+
@pytest.mark.unit
97+
@pytest.mark.parametrize("state_freq", [33, 50, 100, 200])
98+
def test_functional_state_control(state_freq: int):
99+
"""Test that functional state control respects frequency and applies commands correctly."""
100+
sim = Sim(n_worlds=2, n_drones=3, control=Control.state, freq=100, state_freq=state_freq)
101+
102+
data = sim.build_data()
103+
default_data = sim.build_default_data()
104+
reset_fn = sim.build_reset_fn()
105+
step_fn = sim.build_step_fn()
106+
107+
can_control_1 = np.arange(6) * state_freq % sim.freq < state_freq
108+
can_control_2 = np.array([0, 0, 1, 2, 3, 4]) * state_freq % sim.freq < state_freq
109+
110+
for i in range(6):
111+
cmd = np.random.rand(sim.n_worlds, sim.n_drones, 13)
112+
# Check controllable status
113+
controllable = F.controllable(data)
114+
assert jnp.all(controllable[0] == can_control_1[i]), f"Controllable 1 mismatch at t={i}"
115+
assert jnp.all(controllable[1] == can_control_2[i]), f"Controllable 2 mismatch at t={i}"
116+
# Apply control
117+
data = F.state_control(data, cmd)
118+
last_attitude = data.controls.attitude.staged_cmd
119+
data = step_fn(data, 1)
120+
attitude = data.controls.attitude.staged_cmd
121+
122+
last_att, att = last_attitude[0], attitude[0]
123+
if can_control_1[i]:
124+
assert not jnp.all(att == last_att), f"Controls haven't been applied at t={i}"
125+
else:
126+
assert jnp.all(att == last_att), f"Controls should be unchanged at t={i}"
127+
128+
last_att, att = last_attitude[1], attitude[1]
129+
if can_control_2[i]:
130+
assert not jnp.all(att == last_att), f"Controls haven't been applied at t={i}"
131+
else:
132+
assert jnp.all(att == last_att), f"Controls should be unchanged at t={i}"
133+
if i == 0: # Make world 2 asynchronous
134+
data = reset_fn(data, default_data, np.array([False, True]))
135+
136+
137+
@pytest.mark.unit
138+
def test_functional_state_control_device(device: str):
139+
"""Test that functional state control maintains JAX arrays on correct device."""
140+
sim = Sim(n_worlds=2, n_drones=3, control=Control.state, device=device)
141+
data = sim.build_data()
142+
cmd = np.random.rand(sim.n_worlds, sim.n_drones, 13)
143+
data = F.state_control(data, cmd)
144+
controls = data.controls.state
145+
assert isinstance(controls.cmd, jnp.ndarray), "Buffers must remain JAX arrays"
146+
assert isinstance(controls.staged_cmd, jnp.ndarray), "Buffers must remain JAX arrays"
147+
assert jnp.all(controls.staged_cmd == cmd), "Buffers must match command"

0 commit comments

Comments
 (0)