Skip to content

Commit f107897

Browse files
committed
[WIP] Add Mellinger tests
1 parent 08ffb18 commit f107897

File tree

4 files changed

+154
-26
lines changed

4 files changed

+154
-26
lines changed

drone_models/controller/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,15 @@
55
controllers as pure functions to ensure that users can jit-compile them. All controllers use
66
broadcasting to support batching of arbitrary leading dimensions.
77
"""
8+
9+
from typing import Callable
10+
11+
from drone_models.controller.mellinger import (
12+
attitude2force_torque as mellinger_attitude2force_torque,
13+
)
14+
from drone_models.controller.mellinger import pos2attitude as mellinger_pos2attitude
15+
16+
available_controller: dict[str, Callable] = {
17+
"mellinger_pos2attitude": mellinger_pos2attitude,
18+
"mellinger_attitude2force_torque": mellinger_attitude2force_torque,
19+
}
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,60 @@
11
"""Constants for the controllers."""
2+
3+
from types import ModuleType
4+
5+
import numpy as np
6+
from array_api_typing import Array
7+
8+
# Constants for the controllers
9+
# Same as in the firmware. Do not touch
10+
# Not part of the Constants class, since the values are controller specific and not drone specific
11+
12+
#### Mellinger controller (see controller_mellinger.c)
13+
# Note: The firmware assumes mass=0.027. With battery thats closer to 0.034 though!!!
14+
mass = 0.034 # TODO This is the wrong mass (cf with battery weighs more!)
15+
massThrust = 132000 * 0.034 / 0.027
16+
17+
# XY Position PID
18+
kp_xy = 0.4 # P
19+
kd_xy = 0.2 # D
20+
ki_xy = 0.05 # I
21+
i_range_xy = 2.0
22+
23+
# Z Position
24+
kp_z = 1.25 # P
25+
kd_z = 0.4 # D
26+
ki_z = 0.05 # I
27+
i_range_z = 0.4
28+
29+
# Attitude
30+
kR_xy = 70000.0 # P
31+
kw_xy = 20000.0 # D
32+
ki_m_xy = 0.0 # I
33+
i_range_m_xy = 1.0
34+
35+
# Yaw
36+
kR_z = 60000.0 # P
37+
kw_z = 12000.0 # D
38+
ki_m_z = 500.0 # I
39+
i_range_m_z = 1500.0
40+
41+
# roll and pitch angular velocity
42+
kd_omega_rp = 200.0 # D
43+
44+
45+
def cntrl_const_mel(xp: ModuleType = np) -> dict[str, Array | float]:
46+
"""Returns the controller constants for the Mellinger controller."""
47+
return {
48+
"mass": mass,
49+
"massThrust": massThrust,
50+
"kp": xp.asarray([kp_xy, kp_xy, kp_z]),
51+
"kd": xp.asarray([kd_xy, kd_xy, kd_z]),
52+
"ki": xp.asarray([ki_xy, ki_xy, ki_z]),
53+
"i_range": xp.asarray([i_range_xy, i_range_xy, i_range_z]),
54+
"kR": xp.asarray([kR_xy, kR_xy, kR_z]),
55+
"kw": xp.asarray([kw_xy, kw_xy, kw_z]),
56+
"ki_m": xp.asarray([ki_m_xy, ki_m_xy, ki_m_z]),
57+
"kd_omega": xp.asarray([kd_omega_rp, kd_omega_rp, 0.0]),
58+
"i_range_m": xp.asarray([i_range_m_xy, i_range_m_xy, i_range_m_z]),
59+
"torque_pwm_range": 32_000.0,
60+
}

drone_models/controller/mellinger.py

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,10 @@
55
from typing import TYPE_CHECKING
66

77
from array_api_compat import array_namespace
8+
from scipy.spatial.transform import Rotation as R
89

9-
import drone_models.utils.cf2 as cf2
10-
import drone_models.utils.rotation as R
10+
import drone_models.utils.rotation as rotation
1111
from drone_models.transform import force2pwm, motor_force2rotor_speed, pwm2force
12-
from drone_models.utils.constants_controllers import cntrl_const_mel
1312

1413
if TYPE_CHECKING:
1514
from array_api_typing import Array
@@ -24,6 +23,7 @@ def pos2attitude(
2423
ang_vel: Array,
2524
cmd: Array,
2625
constants: Constants,
26+
parameters: dict[str, Array | float],
2727
dt: float = 1 / 500,
2828
i_error: Array | None = None,
2929
) -> tuple[Array, Array]:
@@ -44,6 +44,7 @@ def pos2attitude(
4444
cmd: Full state command in SI units and rad with shape (..., 13). The entries are
4545
[x, y, z, vx, vy, vz, ax, ay, az, yaw, roll_rate, pitch_rate, yaw_rate].
4646
constants: Drone specific constants.
47+
parameters: Controller specific parameters. TODO make class?
4748
dt: Time since last call.
4849
i_error: Integral error of the position controller with shape (..., 3).
4950
@@ -63,17 +64,15 @@ def pos2attitude(
6364
# l.151 ff Integral Error
6465
if i_error is None:
6566
i_error = xp.zeros_like(pos)
66-
i_error = xp.clip(
67-
i_error + r_error * dt, -cntrl_const_mel["i_range"], cntrl_const_mel["i_range"]
68-
)
67+
i_error = xp.clip(i_error + r_error * dt, -parameters["i_range"], parameters["i_range"])
6968
# l. 161 Desired thrust [F_des]
7069
# => only one case here, since setpoint is always in absolute mode
7170
# Note: since we've defined the gravity in z direction, a "-" needs to be added
7271
target_thrust = (
73-
cntrl_const_mel["mass"] * (setpoint_acc - constants.GRAVITY_VEC)
74-
+ cntrl_const_mel["kp"] * r_error
75-
+ cntrl_const_mel["kd"] * v_error
76-
+ cntrl_const_mel["ki"] * i_error
72+
parameters["mass"] * (setpoint_acc - constants.GRAVITY_VEC)
73+
+ parameters["kp"] * r_error
74+
+ parameters["kd"] * v_error
75+
+ parameters["ki"] * i_error
7776
)
7877
# l. 178 Rate-controlled YAW is moving YAW angle setpoint
7978
# => only one case here, since the setpoint is always in absolute mode
@@ -88,7 +87,7 @@ def pos2attitude(
8887
# Taking the dot product of the last axis:
8988
current_thrust = xp.vecdot(target_thrust, z_axis, axis=-1)
9089
# l. 207 Calculate axis [zB_des]
91-
z_axis_desired = target_thrust / xp.linalg.norm(target_thrust)
90+
z_axis_desired = target_thrust / xp.linalg.vector_norm(target_thrust)
9291
# l. 210 [xC_des]
9392
# x_axis_desired = z_axis_desired x [sin(yaw), cos(yaw), 0]^T
9493
x_c_des_x = xp.cos(desiredYaw)
@@ -97,16 +96,16 @@ def pos2attitude(
9796
x_c_des = xp.stack((x_c_des_x, x_c_des_y, x_c_des_z), axis=-1)
9897
# [yB_des]
9998
y_axis_desired = xp.linalg.cross(z_axis_desired, x_c_des)
100-
y_axis_desired = y_axis_desired / xp.linalg.norm(y_axis_desired)
99+
y_axis_desired = y_axis_desired / xp.linalg.vector_norm(y_axis_desired)
101100
# [xB_des]
102101
x_axis_desired = xp.linalg.cross(y_axis_desired, z_axis_desired)
103102
# converting desired axis to rotation matrix and then to RPY
104103
matrix = xp.stack((x_axis_desired, y_axis_desired, z_axis_desired), axis=-1)
105104
command_RPY = R.from_matrix(matrix).as_euler("xyz", degrees=False)
106105
# l. 283
107-
thrust = cntrl_const_mel["massThrust"] * current_thrust
106+
thrust = parameters["massThrust"] * current_thrust
108107
# Transform thrust into N to keep uniform interface
109-
thrust = cf2.pwm2force(thrust, constants, perMotor=False)
108+
thrust = pwm2force(thrust, constants.THRUST_MAX, constants.PWM_MAX) * 4
110109
command_rpyt = xp.concat((command_RPY, thrust[..., None]), axis=-1)
111110
return command_rpyt, i_error
112111

@@ -116,8 +115,9 @@ def attitude2force_torque(
116115
quat: Array,
117116
vel: Array,
118117
ang_vel: Array,
119-
command_rpyt: Array,
118+
cmd: Array,
120119
constants: Constants,
120+
parameters: dict[str, Array | float],
121121
dt: float = 1 / 500,
122122
i_error_m: Array | None = None,
123123
ang_vel_des: Array | None = None,
@@ -131,8 +131,9 @@ def attitude2force_torque(
131131
quat: Drone orientation as xyzw quaternion with shape (..., 4).
132132
vel: Drone velocity with shape (..., 3).
133133
ang_vel: Drone angular drone velocity in rad/s with shape (..., 3).
134-
command_rpyt: Commanded attitude (roll, pitch, yaw) and total thrust [rad, rad, rad, N]
134+
cmd: Commanded attitude (roll, pitch, yaw) and total thrust [rad, rad, rad, N]
135135
constants (Constants): Constants of the specific drone
136+
parameters: Controller specific parameters. TODO make class?
136137
dt: Time since last call.
137138
i_error_m: Integral error.
138139
ang_vel_des: Desired angular velocity in rad/s.
@@ -142,10 +143,10 @@ def attitude2force_torque(
142143
Returns:
143144
4 Motor forces [N], i_error_m
144145
"""
145-
xp = array_namespace(pos)
146-
force_des = command_rpyt[..., -1] # Total thrust in N
147-
rpy_des = command_rpyt[..., :-1]
148-
axis_flip = xp.array([1, -1, 1]) # to change the direction of the y axis
146+
xp = array_namespace(quat)
147+
force_des = cmd[..., -1] # Total thrust in N
148+
rpy_des = cmd[..., -4:-1]
149+
axis_flip = xp.asarray([1.0, -1.0, 1.0]) # to change the direction of the y axis
149150
# l. 220 ff [eR]. We're using the "inefficient" code path from the firmware
150151
rot = R.from_quat(quat)
151152
rot_des = R.from_euler("xyz", rpy_des, degrees=False)
@@ -174,16 +175,18 @@ def attitude2force_torque(
174175
if i_error_m is None:
175176
i_error_m = xp.zeros_like(pos)
176177
i_error_m = i_error_m - eR * dt
177-
i_error_m = xp.clip(i_error_m, -cntrl_const_mel["i_range_m"], cntrl_const_mel["i_range_m"])
178+
i_error_m = xp.clip(i_error_m, -parameters["i_range_m"], parameters["i_range_m"])
178179
# l. 278 ff Moment:
179180
torque_pwm = (
180-
-cntrl_const_mel["kR"] * eR
181-
+ cntrl_const_mel["kw"] * ew
182-
+ cntrl_const_mel["ki_m"] * i_error_m
183-
+ cntrl_const_mel["kd_omega"] * err_d
181+
-parameters["kR"] * eR
182+
+ parameters["kw"] * ew
183+
+ parameters["ki_m"] * i_error_m
184+
+ parameters["kd_omega"] * err_d
184185
)
185186
# l. 297 ff
186-
torque_pwm = xp.clip(torque_pwm, -32_000, 32_000)
187+
torque_pwm = xp.clip(
188+
torque_pwm, -parameters["torque_pwm_range"], parameters["torque_pwm_range"]
189+
)
187190
torque_pwm = xp.where((force_des > 0)[..., None], torque_pwm, 0.0)
188191
# Info: The following part is NOT part of the Mellinger controller, but of the firmware.
189192
# The mixing of force and torque is done on the PWM level. We therefore mix those here according

tests/unit/test_controller.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
"""Tests of the numeric models."""
2+
3+
from __future__ import annotations
4+
5+
from typing import TYPE_CHECKING, Callable
6+
7+
import array_api_compat.numpy as np
8+
import array_api_strict as xp
9+
import jax
10+
import pytest
11+
12+
from drone_models.controller import available_controller
13+
from drone_models.controller.constants import cntrl_const_mel
14+
from drone_models.utils.constants import Constants, available_drone_types
15+
16+
if TYPE_CHECKING:
17+
from array_api_typing import Array
18+
19+
# For all tests to pass, we need the same precsion in jax as in np
20+
jax.config.update("jax_enable_x64", True)
21+
22+
23+
def create_rnd_states(
24+
shape: tuple[int, ...] = (),
25+
) -> tuple[Array, Array, Array, Array, Array, Array, Array]:
26+
"""Creates N random states."""
27+
pos = xp.asarray(np.random.uniform(-5, 5, shape + (3,)))
28+
quat = xp.asarray(np.random.uniform(-1, 1, shape + (4,))) # Libraries normalize automatically
29+
vel = xp.asarray(np.random.uniform(-5, 5, shape + (3,)))
30+
ang_vel = xp.asarray(np.random.uniform(-2, 2, shape + (3,)))
31+
rotor_vel = xp.asarray(np.random.uniform(0, 0.2, shape + (4,)))
32+
forces_dist = xp.asarray(np.random.uniform(-0.2, 0.2, shape + (3,)))
33+
torques_dist = xp.asarray(np.random.uniform(-0.05, 0.05, shape + (3,)))
34+
return pos, quat, vel, ang_vel, rotor_vel, forces_dist, torques_dist
35+
36+
37+
def create_rnd_commands(shape: tuple[int, ...] = (), dim: int = 4) -> Array:
38+
"""Creates N random inputs with size dim."""
39+
return xp.asarray(np.random.uniform(0, 0.2, shape + (dim,)))
40+
41+
42+
@pytest.mark.unit
43+
@pytest.mark.parametrize("controller_name, controller", available_controller.items())
44+
@pytest.mark.parametrize("drone_type", available_drone_types)
45+
def test_controller(controller_name: str, controller: Callable, drone_type: str):
46+
"""TODO."""
47+
constants = Constants.from_config(drone_type, xp)
48+
batch_shape = (10,)
49+
pos, quat, vel, ang_vel, _, _, _ = create_rnd_states(batch_shape)
50+
cmd = create_rnd_commands(batch_shape, dim=13) # TODO make dependent on controller
51+
52+
parameters = cntrl_const_mel(xp)
53+
54+
controller(pos, quat, vel, ang_vel, cmd, constants, parameters)

0 commit comments

Comments
 (0)