Skip to content

Commit f1952ee

Browse files
committed
Fix global mutate of casadi symbols
1 parent e2f68a2 commit f1952ee

File tree

7 files changed

+36
-16
lines changed

7 files changed

+36
-16
lines changed

docs/user-guide/symbolic.md

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,7 @@ State vector layout with `model_rotor_vel=True`:
7575

7676
Both functions return raw CasADi expressions. Wrap them in a `cs.Function` to pass to any CasADi-based solver:
7777

78-
!!! warning
79-
This is currently broken. Investigate and fix the issue.
80-
81-
```{ .python notest }
78+
```python
8279
import casadi as cs
8380
from drone_models.so_rpy_rotor_drag import symbolic_dynamics_euler
8481
from drone_models.core import parametrize

drone_models/first_principles/model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,7 @@ def symbolic_dynamics(
231231
+ rotor_dyn_coef[3] * (U**2 - symbols.rotor_vel**2),
232232
)
233233
else:
234+
_saved_rotor_vel = symbols.rotor_vel
234235
symbols.rotor_vel = U
235236
# Creating force and torque vector
236237
forces_motor = (
@@ -281,4 +282,6 @@ def symbolic_dynamics(
281282
X_dot = cs.vertcat(pos_dot, quat_dot, vel_dot, ang_vel_dot)
282283
Y = cs.vertcat(symbols.pos, symbols.quat)
283284

285+
if not model_rotor_vel:
286+
symbols.rotor_vel = _saved_rotor_vel
284287
return X_dot, X, U, Y

drone_models/so_rpy/model.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,12 @@ def symbolic_dynamics(
164164
* ``Y``: Output ``[pos(3), quat(4)]``.
165165
"""
166166
# We need to set the rpy and drpy symbols before building the euler model
167-
symbols.rpy = rotation.cs_quat2euler(symbols.quat)
168-
symbols.drpy = rotation.cs_ang_vel2rpy_rates(symbols.quat, symbols.ang_vel)
167+
_saved_rpy = symbols.rpy
168+
_saved_drpy = symbols.drpy
169+
_rpy_quat = rotation.cs_quat2euler(symbols.quat)
170+
_drpy_quat = rotation.cs_ang_vel2rpy_rates(symbols.quat, symbols.ang_vel)
171+
symbols.rpy = _rpy_quat
172+
symbols.drpy = _drpy_quat
169173
X_dot_euler, X_euler, U_euler, Y_euler = symbolic_dynamics_euler(
170174
model_rotor_vel=model_rotor_vel,
171175
mass=mass,
@@ -178,6 +182,8 @@ def symbolic_dynamics(
178182
rpy_rates_coef=rpy_rates_coef,
179183
cmd_rpy_coef=cmd_rpy_coef,
180184
)
185+
symbols.rpy = _saved_rpy
186+
symbols.drpy = _saved_drpy
181187

182188
# States and Inputs
183189
X = cs.vertcat(symbols.pos, symbols.quat, symbols.vel, symbols.ang_vel)
@@ -203,7 +209,7 @@ def symbolic_dynamics(
203209
)
204210
quat_dot = 0.5 * (xi @ symbols.quat)
205211
ang_vel_dot = rotation.cs_rpy_rates_deriv2ang_vel_deriv(
206-
symbols.quat, symbols.drpy, X_dot_euler[9:12]
212+
symbols.quat, _drpy_quat, X_dot_euler[9:12]
207213
)
208214
if model_dist_t:
209215
# adding torque disturbances to the state

drone_models/so_rpy_rotor/model.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -184,8 +184,12 @@ def symbolic_dynamics(
184184
* ``Y``: Output ``[pos(3), quat(4)]``.
185185
"""
186186
## We need to set the rpy and drpy symbols before building the euler model
187-
symbols.rpy = rotation.cs_quat2euler(symbols.quat)
188-
symbols.drpy = rotation.cs_ang_vel2rpy_rates(symbols.quat, symbols.ang_vel)
187+
_saved_rpy = symbols.rpy
188+
_saved_drpy = symbols.drpy
189+
_rpy_quat = rotation.cs_quat2euler(symbols.quat)
190+
_drpy_quat = rotation.cs_ang_vel2rpy_rates(symbols.quat, symbols.ang_vel)
191+
symbols.rpy = _rpy_quat
192+
symbols.drpy = _drpy_quat
189193
X_dot_euler, X_euler, U_euler, Y_euler = symbolic_dynamics_euler(
190194
model_rotor_vel=model_rotor_vel,
191195
mass=mass,
@@ -199,6 +203,8 @@ def symbolic_dynamics(
199203
rpy_rates_coef=rpy_rates_coef,
200204
cmd_rpy_coef=cmd_rpy_coef,
201205
)
206+
symbols.rpy = _saved_rpy
207+
symbols.drpy = _saved_drpy
202208

203209
# States and Inputs
204210
X = cs.vertcat(symbols.pos, symbols.quat, symbols.vel, symbols.ang_vel)
@@ -223,7 +229,7 @@ def symbolic_dynamics(
223229
)
224230
quat_dot = 0.5 * (xi @ symbols.quat)
225231
ang_vel_dot = rotation.cs_rpy_rates_deriv2ang_vel_deriv(
226-
symbols.quat, symbols.drpy, X_dot_euler[9:12]
232+
symbols.quat, _drpy_quat, X_dot_euler[9:12]
227233
)
228234
if model_dist_t:
229235
# adding torque disturbances to the state

drone_models/so_rpy_rotor_drag/model.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -199,9 +199,15 @@ def symbolic_dynamics(
199199
* ``U``: Input vector ``[roll_rad, pitch_rad, yaw_rad, thrust_N]``.
200200
* ``Y``: Output ``[pos(3), quat(4)]``.
201201
"""
202-
# We need to set the rpy and drpy symbols before building the euler model
203-
symbols.rpy = rotation.cs_quat2euler(symbols.quat)
204-
symbols.drpy = rotation.cs_ang_vel2rpy_rates(symbols.quat, symbols.ang_vel)
202+
# Temporarily override rpy/drpy so symbolic_dynamics_euler uses quaternion-derived
203+
# expressions for this call. Restore them afterwards so subsequent calls to
204+
# symbolic_dynamics_euler still get the original leaf symbolic variables.
205+
_saved_rpy = symbols.rpy
206+
_saved_drpy = symbols.drpy
207+
_rpy_quat = rotation.cs_quat2euler(symbols.quat)
208+
_drpy_quat = rotation.cs_ang_vel2rpy_rates(symbols.quat, symbols.ang_vel)
209+
symbols.rpy = _rpy_quat
210+
symbols.drpy = _drpy_quat
205211
X_dot_euler, X_euler, U_euler, Y_euler = symbolic_dynamics_euler(
206212
model_rotor_vel=model_rotor_vel,
207213
mass=mass,
@@ -216,6 +222,8 @@ def symbolic_dynamics(
216222
cmd_rpy_coef=cmd_rpy_coef,
217223
drag_matrix=drag_matrix,
218224
)
225+
symbols.rpy = _saved_rpy
226+
symbols.drpy = _saved_drpy
219227

220228
# States and Inputs
221229
X = cs.vertcat(symbols.pos, symbols.quat, symbols.vel, symbols.ang_vel)
@@ -240,7 +248,7 @@ def symbolic_dynamics(
240248
)
241249
quat_dot = 0.5 * (xi @ symbols.quat)
242250
ang_vel_dot = rotation.cs_rpy_rates_deriv2ang_vel_deriv(
243-
symbols.quat, symbols.drpy, X_dot_euler[9:12]
251+
symbols.quat, _drpy_quat, X_dot_euler[9:12]
244252
)
245253
if model_dist_t:
246254
# adding torque disturbances to the state

pixi.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ pytest-markdown-docs = "*"
138138

139139
[tool.pixi.feature.tests.tasks]
140140
tests = { cmd = "pytest -v tests", description = "Run tests" }
141-
test-docs = { cmd = "pytest --markdown-docs --markdown-docs-syntax=superfences drone_models/ docs/ --ignore=docs/gen_ref_pages.py", env = { SCIPY_ARRAY_API = "1" }, description = "Run doctests for docstrings and documentation" }
141+
test-docs = { cmd = "pytest -v --markdown-docs --markdown-docs-syntax=superfences drone_models/ docs/ --ignore=docs/gen_ref_pages.py", env = { SCIPY_ARRAY_API = "1" }, description = "Run doctests for docstrings and documentation" }
142142

143143
[tool.pixi.feature.docs.dependencies]
144144
mkdocs = "*"

0 commit comments

Comments
 (0)