Skip to content

Commit 6948a61

Browse files
committed
[WIP] Add identification pipeline
1 parent d603ef4 commit 6948a61

File tree

2 files changed

+697
-0
lines changed

2 files changed

+697
-0
lines changed

drone_models/utils/derivatives.py

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
"""This module contains functions to compute derivatives using State Variable Filters."""
2+
3+
from __future__ import annotations
4+
5+
import logging
6+
from typing import TYPE_CHECKING
7+
8+
import numpy as np
9+
from scipy.integrate import solve_ivp
10+
from scipy.interpolate import interp1d
11+
from scipy.signal import bilinear, butter, filtfilt, lfilter, lfiltic
12+
from scipy.spatial.transform import Rotation as R
13+
14+
from drone_models.utils.rotation import rpy_rates2ang_vel
15+
16+
if TYPE_CHECKING:
17+
from drone_models._typing import Array # To be changed to array_api_typing later
18+
19+
logger = logging.getLogger(__name__)
20+
21+
22+
def preprocessing(data: dict[str, Array], constants: dict[str, float]) -> dict[str, Array]:
23+
"""TODO."""
24+
data["dt"] = np.diff(data["time"])
25+
data["time"] -= data["time"][0]
26+
### Outlier detection + interpolation
27+
b, a = butter(N=4, Wn=1, fs=1 / np.mean(data["dt"]))
28+
residuals = data["pos"] - filtfilt(b, a, data["pos"], axis=0)
29+
outliers = np.abs(residuals) > 0.3
30+
outliers = np.sum(outliers, axis=-1)
31+
is_outlier = np.asarray(outliers).astype(bool)
32+
n_outliers = np.sum(outliers)
33+
# TODO also check quat for outliers!
34+
35+
if n_outliers > 0:
36+
logger.warning(f"{n_outliers} outliers detected. Interpolating")
37+
time_good = data["time"][~is_outlier]
38+
pos_good = data["pos"][~is_outlier]
39+
quat_good = data["quat"][~is_outlier]
40+
interp_pos = interp1d(time_good, pos_good, axis=0, fill_value="extrapolate")
41+
interp_quat = interp1d(time_good, quat_good, axis=0, fill_value="extrapolate")
42+
data["pos"][is_outlier] = interp_pos(data["time"][is_outlier])
43+
data["quat"][is_outlier] = interp_quat(data["time"][is_outlier])
44+
45+
### Normalizing orientation (assuming zero at start) and calculating rpy
46+
time_span = 0.1
47+
time_index = int(time_span / np.mean(data["dt"]))
48+
quat_avg = np.mean(data["quat"][:time_index], axis=0)
49+
quat_avg /= np.linalg.norm(quat_avg)
50+
rot_corr = R.from_quat(quat_avg).inv()
51+
rot = rot_corr * R.from_quat(data["quat"])
52+
data["quat"] = rot.as_quat()
53+
data["rpy"] = rot.as_euler("xyz")
54+
data["z_axis"] = rot.inv().as_matrix()[..., -1, :]
55+
56+
### Force clipping and vectorization
57+
data["cmd_f"] = data["cmd_pwm"] / constants["PWM_MAX"] * constants["THRUST_MAX"] * 4
58+
data["cmd_f"] = np.clip(data["cmd_f"], 0, constants["THRUST_MAX"] * 4)
59+
data["cmd_pwm"] = np.clip(data["cmd_pwm"], constants["PWM_MIN"], constants["PWM_MAX"])
60+
rot = R.from_quat(data["quat"])
61+
zeros = np.zeros_like(data["cmd_f"])
62+
f_cmd_vec = np.stack((zeros, zeros, data["cmd_f"]), axis=-1)
63+
data["cmd_f_vec"] = rot.apply(f_cmd_vec)
64+
65+
### Rotational error
66+
R_act = rot.as_matrix()
67+
R_des = R.from_euler("xyz", data["cmd_rpy"], degrees=True).as_matrix()
68+
eRM = np.matmul(np.swapaxes(R_des, -1, -2), R_act) - np.matmul(
69+
np.swapaxes(R_act, -1, -2), R_des
70+
)
71+
data["eR"] = np.stack(
72+
(eRM[..., 2, 1], eRM[..., 0, 2], eRM[..., 1, 0]), axis=-1
73+
) # vee operator (SO3 to R3)
74+
data["eR_vec"] = (rot.inv() * R.from_euler("xyz", data["cmd_rpy"], degrees=True)).as_rotvec()
75+
76+
return data
77+
78+
79+
def derivatives_svf(data: dict[str, Array], constants: dict[str, float]) -> dict[str, Array]:
80+
"""Calculate derivatives with State Variable Filter."""
81+
# Important: Don't mix with unfiltered signals (also for input!)
82+
if data is None:
83+
return None
84+
85+
svf_linear = state_variable_filter(data["pos"].T, data["time"], f_c=6, N_deriv=3)
86+
data["SVF_pos"] = svf_linear[:, 0].T
87+
data["SVF_vel"] = svf_linear[:, 1].T
88+
data["SVF_acc"] = svf_linear[:, 2].T
89+
data["SVF_jerk"] = svf_linear[:, 3].T
90+
91+
svf_rotational = state_variable_filter(data["rpy"].T, data["time"], f_c=8, N_deriv=3)
92+
data["SVF_rpy"] = svf_rotational[:, 0].T
93+
data["SVF_drpy"] = svf_rotational[:, 1].T
94+
data["SVF_ddrpy"] = svf_rotational[:, 2].T
95+
data["SVF_dddrpy"] = svf_rotational[:, 3].T
96+
rot = R.from_euler("xyz", data["SVF_rpy"])
97+
data["SVF_quat"] = rot.as_quat()
98+
data["SVF_z_axis"] = rot.inv().as_matrix()[..., -1, :]
99+
data["SVF_ang_vel"] = rpy_rates2ang_vel(data["SVF_quat"], data["SVF_drpy"])
100+
data["SVF_ang_acc"] = rpy_rates2ang_vel(data["SVF_quat"], data["SVF_ddrpy"])
101+
data["SVF_ang_jerk"] = rpy_rates2ang_vel(data["SVF_quat"], data["SVF_dddrpy"])
102+
103+
svf_input_pwm = state_variable_filter(data["cmd_pwm"], data["time"], f_c=6, N_deriv=3)
104+
data["SVF_cmd_pwm"] = svf_input_pwm[0]
105+
data["SVF_cmd_f"] = data["SVF_cmd_pwm"] / constants["PWM_MAX"] * constants["THRUST_MAX"] * 4
106+
107+
svf_input_rpy = state_variable_filter(data["cmd_rpy"].T, data["time"], f_c=8, N_deriv=3)
108+
data["SVF_cmd_rpy"] = svf_input_rpy[:, 0].T
109+
110+
R_act = rot.as_matrix()
111+
rot_cmd = R.from_euler("xyz", data["SVF_cmd_rpy"], degrees=True)
112+
R_des = rot_cmd.as_matrix()
113+
eRM = np.matmul(np.swapaxes(R_des, -1, -2), R_act) - np.matmul(
114+
np.swapaxes(R_act, -1, -2), R_des
115+
)
116+
data["SVF_eR"] = np.stack(
117+
(eRM[..., 2, 1], eRM[..., 0, 2], eRM[..., 1, 0]), axis=-1
118+
) # vee operator (SO3 to R3)
119+
data["SVF_eR_vec"] = (rot.inv() * rot_cmd).as_rotvec()
120+
121+
zeros = np.zeros_like(data["cmd_f"])
122+
f_cmd_vec = np.stack((zeros, zeros, data["cmd_f"]), axis=-1)
123+
data["SVF_cmd_f_vec"] = rot.apply(f_cmd_vec)
124+
125+
return data
126+
127+
128+
def state_variable_filter(y: Array, t: Array, f_c: float = 1, N_deriv: int = 2) -> Array:
129+
"""A state variable filter that low pass filters the signal and computes the derivatives.
130+
131+
Args:
132+
y: The signal to be filtered. Can be 1D (signal_length) or 2D (batch_size, signal_length).
133+
t: The time values for the signal. Optimally fixed sampling frequency.
134+
f_c: Corner frequency of the filter in Hz. Defaults to 1.
135+
N_deriv: Number of derivatives to be computed. Defaults to 2.
136+
137+
Returns:
138+
Array: The filtered signal and its derivatives. Shape (batch_size, N_deriv+1, signal_length).
139+
"""
140+
if y.ndim == 1:
141+
y = y[None, :] # Add batch dimension if single signal
142+
batch_size, signal_length = y.shape
143+
144+
# The filter needs to have a minimum of two extra states
145+
# One for the filtered input signal and one for the actual filter
146+
N_ord = N_deriv + 2
147+
omega_c = 2 * np.pi * f_c
148+
f_s = 1 / np.mean(np.diff(t))
149+
150+
b, a = butter(N=N_ord, Wn=omega_c, analog=True)
151+
b_dig, a_dig = bilinear(b, a, fs=f_s)
152+
a_flipped = np.flip(a)
153+
154+
def _f(t: Array, x: Array, u: Array) -> Array:
155+
x_dot = []
156+
x_dot_last = 0
157+
# The first states are a simple integrator chain
158+
for i in np.arange(1, N_ord):
159+
x_dot.append(x[i])
160+
# Last state uses the filter coefficients
161+
for i in np.arange(0, N_ord):
162+
x_dot_last -= a_flipped[i] * x[i]
163+
x_dot_last += b[0] * u(t)
164+
x_dot.append(x_dot_last)
165+
166+
return x_dot
167+
168+
results = np.zeros((batch_size, N_deriv + 1, signal_length))
169+
170+
for i in range(batch_size):
171+
# Define input
172+
# Prefilter input backwards to remove time shift
173+
# Add padding to remove filter oscillations in data
174+
pad = 100
175+
y_backwards = np.flip(y[i], axis=-1)
176+
y_backwards_padded = np.concatenate([np.ones(pad) * y_backwards[0], y_backwards])
177+
zi = lfiltic(
178+
b_dig, a_dig, y_backwards_padded, x=y_backwards_padded
179+
) # initial filter conditions
180+
y_backwards, _ = lfilter(b_dig, a_dig, y_backwards_padded, axis=-1, zi=zi)
181+
u = interp1d(
182+
t, np.flip(y_backwards[pad:], axis=-1), kind="linear", fill_value="extrapolate"
183+
)
184+
185+
# Solve system with initial conditions
186+
x0 = np.zeros(N_ord)
187+
x0[0] = y[i, 0]
188+
sol = solve_ivp(_f, [t[0], t[-1]], x0, t_eval=t, args=(u,))
189+
190+
results[i] = sol.y[:-1] # Last state is not of interest
191+
192+
return results.squeeze() # Remove batch dim if not needed

0 commit comments

Comments
 (0)