Skip to content

Commit 327a3da

Browse files
authored
Merge pull request #183 from stan-dev/convert-odes
Case study for converting old ODE code to new ODE code (design-doc #19)
2 parents 74dbca4 + 724dc77 commit 327a3da

File tree

2 files changed

+1027
-0
lines changed

2 files changed

+1027
-0
lines changed
Lines changed: 386 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,386 @@
1+
---
2+
title: "Upgrading to the new ODE interface"
3+
author: "Ben Bales & Sebastian Weber"
4+
date: "19 July 2020"
5+
output: html_document
6+
---
7+
8+
# Introduction
9+
10+
Cmdstan 2.24 introduces a new ODE interface intended to make it easier to
11+
specify the ODE RHS by avoiding packing and unpacking schemes required with
12+
the old interface.
13+
14+
Stan solves for $y(t, \theta)$ at a sequence of times $t_1, t_2, \cdots, t_N$
15+
in the ODE initial value problem defined by:
16+
17+
$$
18+
y(t, \theta)' = f(t, y, \theta)\\
19+
y(t = t_0, \theta) = y_0
20+
$$
21+
22+
For notation, $y(t, \theta)$ is the state, $f(t, y, \theta)$ is the ODE system
23+
function, and $y_0$ and $t_0$ are the initial conditions.
24+
25+
The solution, $y(t, \theta)$, is written explicitly in terms of $\theta$
26+
as a reminder that the solution of an ODE initial value problem can be a
27+
function of model parameters or data. This is the usual use case: an ODE
28+
is parameterized by a set of parameters that are to be estimated.
29+
30+
Specifying an ODE initial value problem in Stan involves writing a Stan
31+
function for the ODE system function, $f(t, y, \theta)$. The big difference
32+
in the old and new ODE interfaces is that previously any arguments meant for
33+
the ODE system function had to be manually packed and unpacked from special
34+
argument arrays that were passed from the ODE solve function call to the ODE
35+
system function -- in the new interface none of this packing and unpacking
36+
is necessary.
37+
38+
The new interface also uses `vector` variables for the state rather than
39+
`real[]` variables.
40+
41+
As an example, in the old solver interface the system function for the ODE
42+
$y' = -\alpha y$ would be written:
43+
44+
```{stan, output.var = "", eval = FALSE}
45+
functions {
46+
real[] rhs(real t, real[] y, real[] theta, real[] x_r, int[] x_i) {
47+
real alpha = theta[1];
48+
49+
real yp[1] = { -alpha * y[1] };
50+
51+
return yp;
52+
}
53+
}
54+
```
55+
56+
In the new interface the system function can be written:
57+
58+
```{stan, output.var = "", eval = FALSE}
59+
functions {
60+
vector rhs(real t, vector y, real alpha) {
61+
vector[1] yp = -alpha * y;
62+
63+
return yp;
64+
}
65+
}
66+
```
67+
68+
The new interface avoids any unused arguments (such as `x_r`, and `x_i` in
69+
this example), and the parameter `alpha` can be passed directly instead of
70+
being packed into `theta`.
71+
72+
For a simple function, this does not look like much, but for more
73+
complicated models with numerous arguments of different types, the packing
74+
and unpacking is tedious and error prone. This leads to models that are
75+
difficult to debug and difficult to iterate on.
76+
77+
# New Interface
78+
79+
The new interface introduces six new functions:
80+
81+
`ode_bdf`, `ode_adams`, `ode_rk45` and `ode_bdf_tol`,`ode_adams_tol`,
82+
`ode_rk45_tol`
83+
84+
The solvers in the first columns have default tolerance settings. The solvers
85+
in the second column accept arguments for relative tolerance, absolute
86+
tolerance, and the maximum number of steps to take between output times.
87+
88+
This is different from the old interface where tolerances are presented
89+
through using the same function name with a few more arguments.
90+
91+
To make it easier to write ODEs, the solve functions take extra arguments
92+
that are passed along unmodified to the user-supplied system function.
93+
Because there can be any number of these arguments and they can be of
94+
different types, they are denoted below as `...`. The types of the
95+
arguments represented by `...` in the ODE solve function call must match
96+
the types of the arguments represented by `...` in the user-supplied system
97+
function.
98+
99+
The new `ode_bdf` solver interface is (the interfaces for `ode_adams` and
100+
`ode_rk45` are the same):
101+
102+
```{stan, output.var = "", eval = FALSE}
103+
vector[] ode_bdf(F f, vector y0, real t0, real[] times, ...)
104+
```
105+
106+
The arguments are:
107+
108+
1. ```f``` - ODE system function
109+
110+
2. ```y0``` - Initial state of the ODE
111+
112+
3. ```t0``` - Initial time of the ODE
113+
114+
4. ```times``` - Sorted array of times to which the ode will be solved (each
115+
element must be greater than t0, but times do not need to be strictly
116+
increasing)
117+
118+
5. ```...``` - Sequence of arguments passed unmodified to the ODE system
119+
function. There can be any number of ```...``` arguments, and the ```...```
120+
arguments can be any type, but they must match the types of the corresponding
121+
```...``` arguments of ```f```.
122+
123+
The ODE system function should take the form:
124+
125+
```{stan, output.var = "", eval = FALSE}
126+
vector f(real t, vector y, ...)
127+
```
128+
129+
The arguments are:
130+
131+
1. ```t``` - Time at which to evaluate the ODE system function
132+
133+
2. ```y``` - State at which to evaluate the ODE system function
134+
135+
3. ```...``` - Sequence of arguments passed unmodified from the ODE solver
136+
function call. The ```...``` must match the types of the corresponding
137+
```...``` arguments of the ODE solver function call.
138+
139+
A call to `ode_bdf` returns the solution of the ODE specified by the system
140+
function (`f`) and the initial conditions (`y0` and `t0`) at the time points given
141+
by the `times` argument. The solution is given by an array of vectors.
142+
143+
The `ode_bdf_tol` interface is (the interfaces for `ode_rk45_tol`
144+
and `ode_adams_tol` are the same):
145+
146+
```{stan, output.var = "", eval = FALSE}
147+
vector[] ode_bdf_tol(F f, vector y0, real t0, real[] times,
148+
real rel_tol, real abs_tol, int max_num_steps, ...)
149+
```
150+
151+
The arguments are:
152+
1. ```f``` - ODE system function
153+
154+
2. ```y0``` - Initial state of the ODE
155+
156+
3. ```t0``` - Initial time of the ODE
157+
158+
4. ```times``` - Sorted array of times to which the ode will be solved (each
159+
element must be greater than t0, but times do not need to be strictly
160+
increasing)
161+
162+
5. ```rel_tol``` - Relative tolerance for solve (data)
163+
164+
6. ```abs_tol``` - Absolute tolerance for solve (data)
165+
166+
7. ```max_num_steps``` - Maximum number of timesteps to take in integrating
167+
the ODE solution between output time points (data)
168+
169+
5. ```...``` - Sequence of arguments passed unmodified to the ODE system
170+
function. There can be any number of ```...``` arguments, and the ```...```
171+
arguments can be any type, but they must match the types of the corresponding
172+
```...``` arguments of ```f```.
173+
174+
The `ode_rk45`/`ode_bdf`/`ode_adams` interfaces are just wrappers around the
175+
`ode_rk45_tol`/`ode_bdf_tol`/`ode_adams_tol` interfaces with defaults for
176+
`rel_tol`, `abs_tol`, and `max_num_steps`. For the RK45 solver the defaults
177+
are $10^{-6}$ for `rel_tol` and `abs_tol` and $10^6$ for `max_num_steps`.
178+
For the BDF/Adams solvers the defaults are $10^{-10}$ for `rel_tol` and
179+
`abs_tol` and $10^8$ for `max_num_steps`.
180+
181+
For more detailed information about either interface, look at the function
182+
reference guide:
183+
[New interface](https://mc-stan.org/docs/2_24/functions-reference/functions-ode-solver.html),
184+
[Old interface](https://mc-stan.org/docs/2_24/functions-reference/functions-old-ode-solver.html)
185+
186+
# Example Models
187+
188+
The two models here come from the Stan
189+
[Statistical Computation Benchmarks](https://github.com/stan-dev/stat_comp_benchmarks).
190+
191+
## SIR Model
192+
193+
### ODE System Function
194+
195+
In the old SIR system function, `beta`, `kappa`, `gamma`, `xi`, and `delta`,
196+
are packed into the `real[] theta` argument. `kappa` isn't actually a model
197+
parameter so it is not clear why it is packed in with the other parameters,
198+
but it is. Promoting `kappa` to a parameter causes there to be more states
199+
in the extended ODE sensitivity system (used to get gradients of the ODE
200+
with respect to inputs). Adding states to the sensitivity system makes the
201+
ODE harder to solve and should always be avoided. The ODE system function looks
202+
like:
203+
204+
```{stan, output.var = "", eval = FALSE}
205+
functions {
206+
// theta[1] = beta, water contact rate
207+
// theta[2] = kappa, C_{50}
208+
// theta[3] = gamma, recovery rate
209+
// theta[4] = xi, bacteria production rate
210+
// theta[5] = delta, bacteria removal rate
211+
real[] simple_SIR(real t,
212+
real[] y,
213+
real[] theta,
214+
real[] x_r,
215+
int[] x_i) {
216+
real dydt[4];
217+
218+
dydt[1] = - theta[1] * y[4] / (y[4] + theta[2]) * y[1];
219+
dydt[2] = theta[1] * y[4] / (y[4] + theta[2]) * y[1] - theta[3] * y[2];
220+
dydt[3] = theta[3] * y[2];
221+
dydt[4] = theta[4] * y[2] - theta[5] * y[4];
222+
223+
return dydt;
224+
}
225+
}
226+
```
227+
228+
For comparison, with the new interface the ODE system function can be
229+
rewritten to explicitly name all the parameters. No separation of data
230+
and parameters is necessary either -- the solver will not add more
231+
equations for arguments that are defined in the `data` and
232+
`transformed data` blocks. The state variables in the new model are also
233+
represented by `vector` variables instead of `real[]` variables.
234+
The new ODE system function is:
235+
236+
```{stan, output.var = "", eval = FALSE}
237+
functions {
238+
vector simple_SIR(real t,
239+
vector y,
240+
real beta, // water contact rate
241+
real kappa, // C_{50}
242+
real gamma, // recovery rate
243+
real xi, // bacteria production rate
244+
real delta) { // bacteria removal rate
245+
vector[4] dydt;
246+
247+
dydt[1] = -beta * y[4] / (y[4] + kappa) * y[1];
248+
dydt[2] = beta * y[4] / (y[4] + kappa) * y[1] - gamma * y[2];
249+
dydt[3] = gamma * y[2];
250+
dydt[4] = xi * y[2] - delta * y[4];
251+
252+
return dydt;
253+
}
254+
}
255+
```
256+
257+
### Calling the ODE Solver
258+
259+
In the old ODE interface, the parameters are all packed into a `real[]` array
260+
before calling the ODE solver:
261+
262+
```{stan, output.var = "", eval = FALSE}
263+
transformed parameters {
264+
real<lower=0> y[N_t, 4];
265+
{
266+
real theta[5] = {beta, kappa, gamma, xi, delta};
267+
y = integrate_ode_rk45(simple_SIR, y0, t0, t, theta, x_r, x_i);
268+
}
269+
}
270+
```
271+
272+
In the new ODE interface each of the arguments is appended on to the ODE
273+
solver function call. The RK45 ODE solver with default tolerances is called
274+
`ode_rk45`. Because the states are handled as `vector` variables, the solver
275+
output is an array of vectors (`vector[]`).
276+
277+
```{stan, output.var = "", eval = FALSE}
278+
transformed parameters {
279+
vector<lower=0>[4] y[N_t] = ode_rk45(simple_SIR, y0, t0, t, beta, kappa, gamma, xi, delta);
280+
}
281+
```
282+
283+
### Full Model
284+
285+
The full model with the new interface can be found
286+
[here](https://github.com/stan-dev/stat_comp_benchmarks/blob/feature/variadic-odes/benchmarks/sir/sir.stan);
287+
the full model with the old interface can be found
288+
[here](https://github.com/stan-dev/stat_comp_benchmarks/blob/master/benchmarks/sir/sir.stan),
289+
and the data can be found
290+
[here](https://github.com/stan-dev/stat_comp_benchmarks/blob/master/benchmarks/sir/sir.data.R).
291+
292+
## PKPD Model
293+
294+
### ODE System Function
295+
296+
In the old system function, parameters are manually unpacked from `theta` and
297+
`x_r`:
298+
299+
```{stan, output.var = "", eval = FALSE}
300+
functions {
301+
real[] one_comp_mm_elim_abs(real t,
302+
real[] y,
303+
real[] theta,
304+
real[] x_r,
305+
int[] x_i) {
306+
real dydt[1];
307+
real k_a = theta[1]; // Dosing rate in 1/day
308+
real K_m = theta[2]; // Michaelis-Menten constant in mg/L
309+
real V_m = theta[3]; // Maximum elimination rate in 1/day
310+
real D = x_r[1];
311+
real V = x_r[2];
312+
real dose = 0;
313+
real elim = (V_m / V) * y[1] / (K_m + y[1]);
314+
315+
if (t > 0)
316+
dose = exp(- k_a * t) * D * k_a / V;
317+
318+
dydt[1] = dose - elim;
319+
320+
return dydt;
321+
}
322+
}
323+
```
324+
325+
In the new interface, they are passed directly and so the unpacking is avoided:
326+
327+
```{stan, output.var = "", eval = FALSE}
328+
functions {
329+
vector one_comp_mm_elim_abs(real t,
330+
vector y,
331+
real k_a, // Dosing rate in 1/day
332+
real K_m, // Michaelis-Menten constant in mg/L
333+
real V_m, // Maximum elimination rate in 1/day
334+
real D,
335+
real V) {
336+
vector[1] dydt;
337+
338+
real dose = 0;
339+
real elim = (V_m / V) * y[1] / (K_m + y[1]);
340+
341+
if (t > 0)
342+
dose = exp(- k_a * t) * D * k_a / V;
343+
344+
dydt[1] = dose - elim;
345+
346+
return dydt;
347+
}
348+
}
349+
```
350+
351+
### Calling the ODE Solver
352+
353+
In the old interface the `theta` and `x_r` arguments are packed manually, and
354+
the `x_i` argument is required even though it isn't used:
355+
356+
```{stan, output.var = "", eval = FALSE}
357+
transformed data {
358+
real x_r[2] = {D, V};
359+
int x_i[0];
360+
}
361+
...
362+
transformed parameters {
363+
real C[N_t, 1];
364+
{
365+
real theta[3] = {k_a, K_m, V_m};
366+
C = integrate_ode_bdf(one_comp_mm_elim_abs, C0, t0, times, theta, x_r, x_i);
367+
}
368+
}
369+
```
370+
371+
In the new interface the arguments are simply passed at the end of the
372+
`ode_bdf` call (and the `transformed data` block removed):
373+
374+
```{stan, output.var = "", eval = FALSE}
375+
transformed parameters {
376+
vector[1] C[N_t] = ode_bdf(one_comp_mm_elim_abs, C0, t0, times, k_a, K_m, V_m, D, V);
377+
}
378+
```
379+
### Full Model
380+
381+
The full model with the new interface can be found
382+
[here](https://github.com/stan-dev/stat_comp_benchmarks/blob/feature/variadic-odes/benchmarks/one_comp_mm_elim_abs/one_comp_mm_elim_abs.stan);
383+
the full model with the old interface can be found
384+
[here](https://github.com/stan-dev/stat_comp_benchmarks/blob/master/benchmarks/pkpd/one_comp_mm_elim_abs.stan),
385+
and the data can be found
386+
[here](https://github.com/stan-dev/stat_comp_benchmarks/blob/master/benchmarks/pkpd/one_comp_mm_elim_abs.data.R).

0 commit comments

Comments
 (0)