|
| 1 | +"""This module contains functions to compute derivatives using State Variable Filters.""" |
| 2 | + |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +import logging |
| 6 | +from typing import TYPE_CHECKING |
| 7 | + |
| 8 | +import numpy as np |
| 9 | +from scipy.integrate import solve_ivp |
| 10 | +from scipy.interpolate import interp1d |
| 11 | +from scipy.signal import bilinear, butter, filtfilt, lfilter, lfiltic |
| 12 | +from scipy.spatial.transform import Rotation as R |
| 13 | + |
| 14 | +from drone_models.utils.rotation import rpy_rates2ang_vel |
| 15 | + |
| 16 | +if TYPE_CHECKING: |
| 17 | + from drone_models._typing import Array # To be changed to array_api_typing later |
| 18 | + |
| 19 | +logger = logging.getLogger(__name__) |
| 20 | + |
| 21 | + |
| 22 | +def preprocessing(data: dict[str, Array], constants: dict[str, float]) -> dict[str, Array]: |
| 23 | + """TODO.""" |
| 24 | + data["dt"] = np.diff(data["time"]) |
| 25 | + data["time"] -= data["time"][0] |
| 26 | + ### Outlier detection + interpolation |
| 27 | + b, a = butter(N=4, Wn=1, fs=1 / np.mean(data["dt"])) |
| 28 | + residuals = data["pos"] - filtfilt(b, a, data["pos"], axis=0) |
| 29 | + outliers = np.abs(residuals) > 0.3 |
| 30 | + outliers = np.sum(outliers, axis=-1) |
| 31 | + is_outlier = np.asarray(outliers).astype(bool) |
| 32 | + n_outliers = np.sum(outliers) |
| 33 | + # TODO also check quat for outliers! |
| 34 | + |
| 35 | + if n_outliers > 0: |
| 36 | + logger.warning(f"{n_outliers} outliers detected. Interpolating") |
| 37 | + time_good = data["time"][~is_outlier] |
| 38 | + pos_good = data["pos"][~is_outlier] |
| 39 | + quat_good = data["quat"][~is_outlier] |
| 40 | + interp_pos = interp1d(time_good, pos_good, axis=0, fill_value="extrapolate") |
| 41 | + interp_quat = interp1d(time_good, quat_good, axis=0, fill_value="extrapolate") |
| 42 | + data["pos"][is_outlier] = interp_pos(data["time"][is_outlier]) |
| 43 | + data["quat"][is_outlier] = interp_quat(data["time"][is_outlier]) |
| 44 | + |
| 45 | + ### Normalizing orientation (assuming zero at start) and calculating rpy |
| 46 | + time_span = 0.1 |
| 47 | + time_index = int(time_span / np.mean(data["dt"])) |
| 48 | + quat_avg = np.mean(data["quat"][:time_index], axis=0) |
| 49 | + quat_avg /= np.linalg.norm(quat_avg) |
| 50 | + rot_corr = R.from_quat(quat_avg).inv() |
| 51 | + rot = rot_corr * R.from_quat(data["quat"]) |
| 52 | + data["quat"] = rot.as_quat() |
| 53 | + data["rpy"] = rot.as_euler("xyz") |
| 54 | + data["z_axis"] = rot.inv().as_matrix()[..., -1, :] |
| 55 | + |
| 56 | + ### Force clipping and vectorization |
| 57 | + data["cmd_f"] = data["cmd_pwm"] / constants["PWM_MAX"] * constants["THRUST_MAX"] * 4 |
| 58 | + data["cmd_f"] = np.clip(data["cmd_f"], 0, constants["THRUST_MAX"] * 4) |
| 59 | + data["cmd_pwm"] = np.clip(data["cmd_pwm"], constants["PWM_MIN"], constants["PWM_MAX"]) |
| 60 | + rot = R.from_quat(data["quat"]) |
| 61 | + zeros = np.zeros_like(data["cmd_f"]) |
| 62 | + f_cmd_vec = np.stack((zeros, zeros, data["cmd_f"]), axis=-1) |
| 63 | + data["cmd_f_vec"] = rot.apply(f_cmd_vec) |
| 64 | + |
| 65 | + ### Rotational error |
| 66 | + R_act = rot.as_matrix() |
| 67 | + R_des = R.from_euler("xyz", data["cmd_rpy"], degrees=True).as_matrix() |
| 68 | + eRM = np.matmul(np.swapaxes(R_des, -1, -2), R_act) - np.matmul( |
| 69 | + np.swapaxes(R_act, -1, -2), R_des |
| 70 | + ) |
| 71 | + data["eR"] = np.stack( |
| 72 | + (eRM[..., 2, 1], eRM[..., 0, 2], eRM[..., 1, 0]), axis=-1 |
| 73 | + ) # vee operator (SO3 to R3) |
| 74 | + data["eR_vec"] = (rot.inv() * R.from_euler("xyz", data["cmd_rpy"], degrees=True)).as_rotvec() |
| 75 | + |
| 76 | + return data |
| 77 | + |
| 78 | + |
| 79 | +def derivatives_svf(data: dict[str, Array], constants: dict[str, float]) -> dict[str, Array]: |
| 80 | + """Calculate derivatives with State Variable Filter.""" |
| 81 | + # Important: Don't mix with unfiltered signals (also for input!) |
| 82 | + if data is None: |
| 83 | + return None |
| 84 | + |
| 85 | + svf_linear = state_variable_filter(data["pos"].T, data["time"], f_c=6, N_deriv=3) |
| 86 | + data["SVF_pos"] = svf_linear[:, 0].T |
| 87 | + data["SVF_vel"] = svf_linear[:, 1].T |
| 88 | + data["SVF_acc"] = svf_linear[:, 2].T |
| 89 | + data["SVF_jerk"] = svf_linear[:, 3].T |
| 90 | + |
| 91 | + svf_rotational = state_variable_filter(data["rpy"].T, data["time"], f_c=8, N_deriv=3) |
| 92 | + data["SVF_rpy"] = svf_rotational[:, 0].T |
| 93 | + data["SVF_drpy"] = svf_rotational[:, 1].T |
| 94 | + data["SVF_ddrpy"] = svf_rotational[:, 2].T |
| 95 | + data["SVF_dddrpy"] = svf_rotational[:, 3].T |
| 96 | + rot = R.from_euler("xyz", data["SVF_rpy"]) |
| 97 | + data["SVF_quat"] = rot.as_quat() |
| 98 | + data["SVF_z_axis"] = rot.inv().as_matrix()[..., -1, :] |
| 99 | + data["SVF_ang_vel"] = rpy_rates2ang_vel(data["SVF_quat"], data["SVF_drpy"]) |
| 100 | + data["SVF_ang_acc"] = rpy_rates2ang_vel(data["SVF_quat"], data["SVF_ddrpy"]) |
| 101 | + data["SVF_ang_jerk"] = rpy_rates2ang_vel(data["SVF_quat"], data["SVF_dddrpy"]) |
| 102 | + |
| 103 | + svf_input_pwm = state_variable_filter(data["cmd_pwm"], data["time"], f_c=6, N_deriv=3) |
| 104 | + data["SVF_cmd_pwm"] = svf_input_pwm[0] |
| 105 | + data["SVF_cmd_f"] = data["SVF_cmd_pwm"] / constants["PWM_MAX"] * constants["THRUST_MAX"] * 4 |
| 106 | + |
| 107 | + svf_input_rpy = state_variable_filter(data["cmd_rpy"].T, data["time"], f_c=8, N_deriv=3) |
| 108 | + data["SVF_cmd_rpy"] = svf_input_rpy[:, 0].T |
| 109 | + |
| 110 | + R_act = rot.as_matrix() |
| 111 | + rot_cmd = R.from_euler("xyz", data["SVF_cmd_rpy"], degrees=True) |
| 112 | + R_des = rot_cmd.as_matrix() |
| 113 | + eRM = np.matmul(np.swapaxes(R_des, -1, -2), R_act) - np.matmul( |
| 114 | + np.swapaxes(R_act, -1, -2), R_des |
| 115 | + ) |
| 116 | + data["SVF_eR"] = np.stack( |
| 117 | + (eRM[..., 2, 1], eRM[..., 0, 2], eRM[..., 1, 0]), axis=-1 |
| 118 | + ) # vee operator (SO3 to R3) |
| 119 | + data["SVF_eR_vec"] = (rot.inv() * rot_cmd).as_rotvec() |
| 120 | + |
| 121 | + zeros = np.zeros_like(data["cmd_f"]) |
| 122 | + f_cmd_vec = np.stack((zeros, zeros, data["cmd_f"]), axis=-1) |
| 123 | + data["SVF_cmd_f_vec"] = rot.apply(f_cmd_vec) |
| 124 | + |
| 125 | + return data |
| 126 | + |
| 127 | + |
| 128 | +def state_variable_filter(y: Array, t: Array, f_c: float = 1, N_deriv: int = 2) -> Array: |
| 129 | + """A state variable filter that low pass filters the signal and computes the derivatives. |
| 130 | +
|
| 131 | + Args: |
| 132 | + y: The signal to be filtered. Can be 1D (signal_length) or 2D (batch_size, signal_length). |
| 133 | + t: The time values for the signal. Optimally fixed sampling frequency. |
| 134 | + f_c: Corner frequency of the filter in Hz. Defaults to 1. |
| 135 | + N_deriv: Number of derivatives to be computed. Defaults to 2. |
| 136 | +
|
| 137 | + Returns: |
| 138 | + Array: The filtered signal and its derivatives. Shape (batch_size, N_deriv+1, signal_length). |
| 139 | + """ |
| 140 | + if y.ndim == 1: |
| 141 | + y = y[None, :] # Add batch dimension if single signal |
| 142 | + batch_size, signal_length = y.shape |
| 143 | + |
| 144 | + # The filter needs to have a minimum of two extra states |
| 145 | + # One for the filtered input signal and one for the actual filter |
| 146 | + N_ord = N_deriv + 2 |
| 147 | + omega_c = 2 * np.pi * f_c |
| 148 | + f_s = 1 / np.mean(np.diff(t)) |
| 149 | + |
| 150 | + b, a = butter(N=N_ord, Wn=omega_c, analog=True) |
| 151 | + b_dig, a_dig = bilinear(b, a, fs=f_s) |
| 152 | + a_flipped = np.flip(a) |
| 153 | + |
| 154 | + def _f(t: Array, x: Array, u: Array) -> Array: |
| 155 | + x_dot = [] |
| 156 | + x_dot_last = 0 |
| 157 | + # The first states are a simple integrator chain |
| 158 | + for i in np.arange(1, N_ord): |
| 159 | + x_dot.append(x[i]) |
| 160 | + # Last state uses the filter coefficients |
| 161 | + for i in np.arange(0, N_ord): |
| 162 | + x_dot_last -= a_flipped[i] * x[i] |
| 163 | + x_dot_last += b[0] * u(t) |
| 164 | + x_dot.append(x_dot_last) |
| 165 | + |
| 166 | + return x_dot |
| 167 | + |
| 168 | + results = np.zeros((batch_size, N_deriv + 1, signal_length)) |
| 169 | + |
| 170 | + for i in range(batch_size): |
| 171 | + # Define input |
| 172 | + # Prefilter input backwards to remove time shift |
| 173 | + # Add padding to remove filter oscillations in data |
| 174 | + pad = 100 |
| 175 | + y_backwards = np.flip(y[i], axis=-1) |
| 176 | + y_backwards_padded = np.concatenate([np.ones(pad) * y_backwards[0], y_backwards]) |
| 177 | + zi = lfiltic( |
| 178 | + b_dig, a_dig, y_backwards_padded, x=y_backwards_padded |
| 179 | + ) # initial filter conditions |
| 180 | + y_backwards, _ = lfilter(b_dig, a_dig, y_backwards_padded, axis=-1, zi=zi) |
| 181 | + u = interp1d( |
| 182 | + t, np.flip(y_backwards[pad:], axis=-1), kind="linear", fill_value="extrapolate" |
| 183 | + ) |
| 184 | + |
| 185 | + # Solve system with initial conditions |
| 186 | + x0 = np.zeros(N_ord) |
| 187 | + x0[0] = y[i, 0] |
| 188 | + sol = solve_ivp(_f, [t[0], t[-1]], x0, t_eval=t, args=(u,)) |
| 189 | + |
| 190 | + results[i] = sol.y[:-1] # Last state is not of interest |
| 191 | + |
| 192 | + return results.squeeze() # Remove batch dim if not needed |
0 commit comments