Skip to content

Commit e2f68a2

Browse files
committed
Add doctests
1 parent d0e7921 commit e2f68a2

File tree

16 files changed

+434
-89
lines changed

16 files changed

+434
-89
lines changed

.github/workflows/testing.yml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,19 @@ jobs:
2525
run: |
2626
pixi --version
2727
pixi run -e tests tests -v
28+
29+
test-docs:
30+
runs-on: ubuntu-latest
31+
steps:
32+
- uses: actions/checkout@v4
33+
- name: Setup Pixi
34+
uses: prefix-dev/setup-pixi@v0.9.3
35+
with:
36+
pixi-version: v0.61.0
37+
cache: true
38+
cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
39+
environments: tests
40+
activate-environment: false
41+
locked: true
42+
- name: Run doctests
43+
run: pixi run -e tests test-docs

docs/examples/index.md

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,11 @@ pos = np.zeros(3)
1717
quat = np.array([0., 0., 0., 1.]) # xyzw — identity (no rotation)
1818
vel = np.zeros(3)
1919
ang_vel = np.zeros(3)
20+
rotor_vel = np.ones(4) * 12_000.
2021
cmd = np.full(4, 15_000.) # motor RPMs
2122

2223
pos_dot, quat_dot, vel_dot, ang_vel_dot, rotor_vel_dot = model(
23-
pos, quat, vel, ang_vel, cmd
24+
pos, quat, vel, ang_vel, cmd, rotor_vel
2425
)
2526
```
2627

@@ -31,23 +32,40 @@ The outputs are the time derivatives of each state variable — the right-hand s
3132
In the previous example `rotor_vel_dot` is `None` because we didn't tell the model what speed the motors are currently at. The model just assumed they were already at the commanded RPM. That's a reasonable approximation for slow maneuvers, but real motors take time to spin up and down. Passing `rotor_vel` enables the rotor dynamics, which computes the acceleration of each motor toward its target.
3233

3334
```python
35+
import numpy as np
36+
from drone_models import parametrize
37+
from drone_models.first_principles import dynamics
38+
39+
model = parametrize(dynamics, drone_model="cf2x_L250")
40+
pos = np.zeros(3)
41+
quat = np.array([0., 0., 0., 1.])
42+
vel = np.zeros(3)
43+
ang_vel = np.zeros(3)
44+
cmd = np.full(4, 15_000.)
3445
# The motors are at 12 000 RPM but commanded to 15 000 — they're spinning up.
3546
rotor_vel = np.full(4, 12_000.)
3647

3748
pos_dot, quat_dot, vel_dot, ang_vel_dot, rotor_vel_dot = model(
3849
pos, quat, vel, ang_vel, cmd, rotor_vel=rotor_vel
3950
)
40-
print(rotor_vel_dot) # positive — rotors accelerating toward cmd
51+
rotor_vel_dot # positive — rotors accelerating toward cmd
4152
```
4253

4354
## Fitted models
4455

4556
The first-principles model requires individual motor RPMs as input, which means you need rotor-level commands. The fitted models — `so_rpy`, `so_rpy_rotor`, `so_rpy_rotor_drag` — take a higher-level command instead: roll, pitch, yaw setpoints in radians plus collective thrust in Newtons. This matches the command interface of typical flight controllers and makes them convenient for control design and system identification.
4657

4758
```python
48-
from drone_models.so_rpy_rotor_drag import dynamics as srrd
59+
import numpy as np
60+
from drone_models import parametrize
61+
from drone_models.so_rpy import dynamics
4962

50-
model = parametrize(srrd, drone_model="cf2x_L250")
63+
pos = np.zeros(3)
64+
quat = np.array([0., 0., 0., 1.])
65+
vel = np.zeros(3)
66+
ang_vel = np.zeros(3)
67+
68+
model = parametrize(dynamics, drone_model="cf2x_L250")
5169

5270
# Collective thrust near hover: mass * g ≈ 0.0319 * 9.81 ≈ 0.31 N
5371
cmd = np.array([0., 0., 0., 0.31]) # [roll_rad, pitch_rad, yaw_rad, thrust_N]
@@ -73,10 +91,11 @@ pos = torch.zeros(3)
7391
quat = torch.tensor([0., 0., 0., 1.])
7492
vel = torch.zeros(3)
7593
ang_vel = torch.zeros(3)
94+
rotor_vel = torch.ones(4) * 12_000.
7695
cmd = torch.full((4,), 15_000.)
7796

7897
pos_dot, quat_dot, vel_dot, ang_vel_dot, rotor_vel_dot = model(
79-
pos, quat, vel, ang_vel, cmd
98+
pos, quat, vel, ang_vel, cmd, rotor_vel
8099
)
81100
```
82101

@@ -85,6 +104,14 @@ pos_dot, quat_dot, vel_dot, ang_vel_dot, rotor_vel_dot = model(
85104
The same model handles arbitrary leading batch dimensions — no special API, no loops. Add a leading dimension to all state and command arrays and the model evaluates all instances in a single call. This works identically across all backends.
86105

87106
```python
107+
import torch
108+
109+
from drone_models import parametrize
110+
from drone_models.first_principles import dynamics
111+
112+
# Parameters are stored as torch tensors — no per-call conversion needed.
113+
model = parametrize(dynamics, drone_model="cf2x_L250", xp=torch)
114+
88115
N = 1_000
89116

90117
pos = torch.zeros(N, 3)
@@ -97,7 +124,7 @@ rotor_vel = torch.full((N, 4), 15_000.)
97124
pos_dot, quat_dot, vel_dot, ang_vel_dot, rotor_vel_dot = model(
98125
pos, quat, vel, ang_vel, cmd, rotor_vel=rotor_vel
99126
)
100-
print(vel_dot.shape) # (1000, 3)
127+
vel_dot.shape # (1000, 3)
101128
```
102129

103130
## Overriding parameters at call time
@@ -117,12 +144,13 @@ pos = jnp.zeros(3)
117144
quat = jnp.array([0., 0., 0., 1.])
118145
vel = jnp.zeros(3)
119146
ang_vel = jnp.zeros(3)
147+
rotor_vel = jnp.ones(4) * 12_000.
120148
cmd = jnp.full((4,), 15_000.)
121149

122150
# The model was parametrized with mass=0.0319 kg.
123151
# Simulate the same drone carrying a 10 g payload for this one call:
124152
pos_dot, quat_dot, vel_dot, ang_vel_dot, rotor_vel_dot = model(
125-
pos, quat, vel, ang_vel, cmd, mass=jnp.float32(0.0419)
153+
pos, quat, vel, ang_vel, cmd, rotor_vel, mass=jnp.float32(0.0419)
126154
)
127155
```
128156

@@ -164,7 +192,7 @@ nominal_mass = model.keywords["mass"] # scalar
164192
nominal_J = model.keywords["J"] # (3, 3)
165193

166194
key, k1, k2 = jax.random.split(key, 3)
167-
mass_batch = nominal_mass * jax.random.uniform(k1, (N,), minval=0.9, maxval=1.1)
195+
mass_batch = nominal_mass * jax.random.uniform(k1, (N, 1), minval=0.9, maxval=1.1)
168196
J_batch = nominal_J * jax.random.uniform(k2, (N, 3, 3), minval=0.9, maxval=1.1)
169197
J_inv_batch = jnp.linalg.inv(J_batch)
170198

@@ -173,5 +201,5 @@ pos_dot, quat_dot, vel_dot, ang_vel_dot, rotor_vel_dot = step(
173201
pos, quat, vel, ang_vel, cmd, rotor_vel,
174202
mass_batch, J_batch, J_inv_batch,
175203
)
176-
print(vel_dot.shape) # (4096, 3)
204+
vel_dot.shape # (4096, 3)
177205
```

docs/get-started/installation.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,5 @@ If your drone is not listed, you can fit parameters from your own flight data us
8888

8989
```python
9090
from drone_models import available_models
91-
print(list(available_models))
92-
# ['first_principles', 'so_rpy', 'so_rpy_rotor', 'so_rpy_rotor_drag']
91+
list(available_models) # ['first_principles', 'so_rpy', 'so_rpy_rotor', 'so_rpy_rotor_drag']
9392
```

docs/get-started/quick-start.md

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,14 @@ model = parametrize(dynamics, drone_model="cf2x_L250")
3333
pos = np.zeros(3) # [m]
3434
quat = np.array([0., 0., 0., 1.]) # xyzw — identity (no rotation)
3535
vel = np.zeros(3) # [m/s]
36+
rotor_vel = np.ones(4) * 12_000. # [RPM] — motors are spinning but not yet at the 15 000 RPM
3637
ang_vel = np.zeros(3) # [rad/s]
3738

3839
# Command: all four motors at 15 000 RPM (rough hover point for cf2x_L250).
3940
cmd = np.full(4, 15_000.) # [RPM]
4041

4142
pos_dot, quat_dot, vel_dot, ang_vel_dot, rotor_vel_dot = model(
42-
pos, quat, vel, ang_vel, cmd
43+
pos, quat, vel, ang_vel, cmd, rotor_vel
4344
)
4445
```
4546

@@ -64,13 +65,23 @@ These are the right-hand side of the continuous-time ODE $\dot{x} = f(x, u)$. To
6465
Real motors don't respond instantaneously to commands. Passing the current motor state as `rotor_vel` enables the rotor dynamics model, which computes how the motors accelerate or decelerate toward the commanded RPM.
6566

6667
```python
68+
import numpy as np
69+
from drone_models import parametrize
70+
from drone_models.first_principles import dynamics
71+
72+
model = parametrize(dynamics, drone_model="cf2x_L250")
73+
pos = np.zeros(3)
74+
quat = np.array([0., 0., 0., 1.])
75+
vel = np.zeros(3)
76+
ang_vel = np.zeros(3)
77+
cmd = np.full(4, 15_000.)
6778
# Current RPMs lag behind the 15 000 RPM command — motors are still spinning up.
6879
rotor_vel = np.full(4, 12_000.)
6980

7081
pos_dot, quat_dot, vel_dot, ang_vel_dot, rotor_vel_dot = model(
7182
pos, quat, vel, ang_vel, cmd, rotor_vel=rotor_vel
7283
)
73-
print(rotor_vel_dot) # positive — rotors accelerating toward cmd
84+
rotor_vel_dot # positive — rotors accelerating toward cmd
7485
```
7586

7687
## Next steps

docs/index.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,11 @@ pos = np.zeros(3)
1717
quat = np.array([0., 0., 0., 1.]) # xyzw, identity
1818
vel = np.zeros(3)
1919
ang_vel = np.zeros(3)
20+
rotor_vel = np.ones(4) * 12_000.
2021
cmd = np.full(4, 15_000.) # motor RPMs, near hover
2122

2223
pos_dot, quat_dot, vel_dot, ang_vel_dot, rotor_vel_dot = model(
23-
pos, quat, vel, ang_vel, cmd
24+
pos, quat, vel, ang_vel, cmd, rotor_vel
2425
)
2526
```
2627

docs/user-guide/batching.md

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ rotor_vel = jnp.full((N, 4), 15_000.)
2020
pos_dot, quat_dot, vel_dot, ang_vel_dot, rotor_vel_dot = model(
2121
pos, quat, vel, ang_vel, cmd, rotor_vel=rotor_vel
2222
)
23-
print(vel_dot.shape) # (1000, 3)
23+
vel_dot.shape # (1000, 3)
2424
```
2525

2626
A runnable version of this example is in [Examples: Batched evaluation](../examples/index.md#batched-evaluation).
@@ -30,15 +30,21 @@ A runnable version of this example is in [Examples: Batched evaluation](../examp
3030
Any number of leading dimensions works. A common pattern is a grid of environments, each containing multiple drones:
3131

3232
```python
33+
import jax.numpy as jnp
34+
from drone_models import parametrize
35+
from drone_models.first_principles import dynamics
36+
37+
model = parametrize(dynamics, drone_model="cf2x_L250", xp=jnp)
3338
# 50 environments, 20 drones each
3439
pos = jnp.zeros((50, 20, 3))
3540
quat = jnp.broadcast_to(jnp.array([0., 0., 0., 1.]), (50, 20, 4))
3641
vel = jnp.zeros((50, 20, 3))
3742
ang_vel = jnp.zeros((50, 20, 3))
43+
rotor_vel = jnp.full((50, 20, 4), 12_000.)
3844
cmd = jnp.full((50, 20, 4), 15_000.)
3945

40-
vel_dot, *_ = model(pos, quat, vel, ang_vel, cmd)
41-
print(vel_dot.shape) # (50, 20, 3)
46+
vel_dot, *_ = model(pos, quat, vel, ang_vel, cmd, rotor_vel)
47+
vel_dot.shape # (50, 20, 3)
4248
```
4349

4450
## Domain randomization
@@ -56,6 +62,10 @@ from drone_models.first_principles import dynamics
5662
N = 4_096
5763
key = jax.random.PRNGKey(0)
5864

65+
pos, vel, ang_vel = jnp.zeros((N, 3)), jnp.zeros((N, 3)), jnp.zeros((N, 3))
66+
quat = jnp.tile(jnp.array([0., 0., 0., 1.]), (N, 1))
67+
cmd = jnp.full((N, 4), 15_000.)
68+
rotor_vel = jnp.full((N, 4), 15_000.)
5969
model = parametrize(dynamics, drone_model="cf2x_L250", xp=jnp)
6070
nominal_mass = model.keywords["mass"]
6171
nominal_J = model.keywords["J"]
@@ -68,7 +78,7 @@ def step(pos, quat, vel, ang_vel, cmd, rotor_vel, mass, J, J_inv):
6878
)
6979

7080
key, k1, k2 = jax.random.split(key, 3)
71-
mass_batch = nominal_mass * jax.random.uniform(k1, (N,), minval=0.9, maxval=1.1)
81+
mass_batch = nominal_mass * jax.random.uniform(k1, (N, 1), minval=0.9, maxval=1.1)
7282
J_batch = nominal_J * jax.random.uniform(k2, (N, 3, 3), minval=0.9, maxval=1.1)
7383
J_inv_batch = jnp.linalg.inv(J_batch)
7484

@@ -78,7 +88,7 @@ vel_dot = step(pos, quat, vel, ang_vel, cmd, rotor_vel,
7888

7989
**Option 2 — mutate `model.keywords` directly.** Simpler when you don't need JIT or are happy to retrace. Replace a scalar parameter with a `(N,)` array and each element in the batch uses its own value.
8090

81-
```python
91+
```{ .python notest }
8292
model.keywords["mass"] = nominal_mass * mass_batch # shape (N,)
8393
vel_dot = model(pos, quat, vel, ang_vel, cmd)[2]
8494
```

docs/user-guide/models.md

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,17 @@ The full rigid-body physics model. The command is four individual motor RPMs. Th
2727
Working at the rotor-velocity level means you need a controller that converts higher-level commands — position setpoints, attitude + collective thrust — down to individual motor RPMs. [drone-controllers](https://utiasdsl.github.io/drone-controllers/) provides a matching set of controllers designed for exactly this interface.
2828

2929
```python
30+
import numpy as np
3031
from drone_models import parametrize
3132
from drone_models.first_principles import dynamics
3233

3334
model = parametrize(dynamics, drone_model="cf2x_L250")
3435

36+
pos, vel, ang_vel = np.zeros((3,)), np.zeros((3,)), np.zeros((3,))
37+
quat = np.array([0., 0., 0., 1.])
38+
cmd = np.full((4,), 15_000.)
39+
rotor_vel = np.full((4,), 12_000.)
40+
3541
pos_dot, quat_dot, vel_dot, ang_vel_dot, rotor_vel_dot = model(
3642
pos, quat, vel, ang_vel,
3743
cmd, # shape (4,) — motor RPMs
@@ -45,7 +51,7 @@ See the [`first_principles` API reference](../reference/drone_models/first_princ
4551

4652
A fitted second-order model where the command is `[roll_rad, pitch_rad, yaw_rad, thrust_N]` — the same interface used by most flight controller firmware. First-order thrust dynamics model motor spin-up delay, and a linear body-frame drag term accounts for aerodynamic resistance. All coefficients are identified from flight data rather than derived from physics, which makes the model easy to calibrate and well-suited to real-time control.
4753

48-
```python
54+
```{ .python notest }
4955
from drone_models.so_rpy_rotor_drag import dynamics
5056
5157
model = parametrize(dynamics, drone_model="cf2x_L250")
@@ -78,6 +84,14 @@ from drone_models.so_rpy import dynamics
7884
All four models accept optional `dist_f` (external force, world frame, N) and `dist_t` (external torque, body frame, N·m) arguments. These are useful for modelling wind, contact forces, or other perturbations without modifying the model itself.
7985

8086
```python
87+
import numpy as np
88+
from drone_models import parametrize
89+
from drone_models.so_rpy import dynamics
90+
91+
model = parametrize(dynamics, drone_model="cf2x_L250")
92+
pos, vel, ang_vel = np.zeros((3,)), np.zeros((3,)), np.zeros((3,))
93+
quat = np.array([0., 0., 0., 1.])
94+
cmd = np.array([0., 0., 0., 0.31])
8195
dist_f = np.array([0.05, 0., 0.]) # 50 mN headwind [N]
8296
dist_t = np.zeros(3)
8397

@@ -95,8 +109,8 @@ from drone_models import model_features
95109
from drone_models.first_principles import dynamics as fp
96110
from drone_models.so_rpy import dynamics as srpy
97111

98-
print(model_features(fp)) # {'rotor_dynamics': True}
99-
print(model_features(srpy)) # {'rotor_dynamics': False}
112+
model_features(fp) # {'rotor_dynamics': True}
113+
model_features(srpy) # {'rotor_dynamics': False}
100114
```
101115

102116
---

0 commit comments

Comments
 (0)