|
| 1 | +--- |
| 2 | +title: "Upgrading to the new ODE interface" |
| 3 | +author: "Ben Bales & Sebastian Weber" |
| 4 | +date: "19 July 2020" |
| 5 | +output: html_document |
| 6 | +--- |
| 7 | + |
| 8 | +# Introduction |
| 9 | + |
| 10 | +Cmdstan 2.24 introduces a new ODE interface intended to make it easier to |
| 11 | +specify the ODE RHS by avoiding packing and unpacking schemes required with |
| 12 | +the old interface. |
| 13 | + |
| 14 | +Stan solves for $y(t, \theta)$ at a sequence of times $t_1, t_2, \cdots, t_N$ |
| 15 | +in the ODE initial value problem defined by: |
| 16 | + |
| 17 | +$$ |
| 18 | +y(t, \theta)' = f(t, y, \theta)\\ |
| 19 | +y(t = t_0, \theta) = y_0 |
| 20 | +$$ |
| 21 | + |
| 22 | +For notation, $y(t, \theta)$ is the state, $f(t, y, \theta)$ is the ODE system |
| 23 | +function, and $y_0$ and $t_0$ are the initial conditions. |
| 24 | + |
| 25 | +The solution, $y(t, \theta)$, is written explicitly in terms of $\theta$ |
| 26 | +as a reminder that the solution of an ODE initial value problem can be a |
| 27 | +function of model parameters or data. This is the usual use case: an ODE |
| 28 | +is parameterized by a set of parameters that are to be estimated. |
| 29 | + |
| 30 | +Specifying an ODE initial value problem in Stan involves writing a Stan |
| 31 | +function for the ODE system function, $f(t, y, \theta)$. The big difference |
| 32 | +in the old and new ODE interfaces is that previously any arguments meant for |
| 33 | +the ODE system function had to be manually packed and unpacked from special |
| 34 | +argument arrays that were passed from the ODE solve function call to the ODE |
| 35 | +system function -- in the new interface none of this packing and unpacking |
| 36 | +is necessary. |
| 37 | + |
| 38 | +The new interface also uses `vector` variables for the state rather than |
| 39 | +`real[]` variables. |
| 40 | + |
| 41 | +As an example, in the old solver interface the system function for the ODE |
| 42 | +$y' = -\alpha y$ would be written: |
| 43 | + |
| 44 | +```{stan, output.var = "", eval = FALSE} |
| 45 | +functions { |
| 46 | + real[] rhs(real t, real[] y, real[] theta, real[] x_r, int[] x_i) { |
| 47 | + real alpha = theta[1]; |
| 48 | +
|
| 49 | + real yp[1] = { -alpha * y[1] }; |
| 50 | +
|
| 51 | + return yp; |
| 52 | + } |
| 53 | +} |
| 54 | +``` |
| 55 | + |
| 56 | +In the new interface the system function can be written: |
| 57 | + |
| 58 | +```{stan, output.var = "", eval = FALSE} |
| 59 | +functions { |
| 60 | + vector rhs(real t, vector y, real alpha) { |
| 61 | + vector[1] yp = -alpha * y; |
| 62 | +
|
| 63 | + return yp; |
| 64 | + } |
| 65 | +} |
| 66 | +``` |
| 67 | + |
| 68 | +The new interface avoids any unused arguments (such as `x_r`, and `x_i` in |
| 69 | +this example), and the parameter `alpha` can be passed directly instead of |
| 70 | +being packed into `theta`. |
| 71 | + |
| 72 | +For a simple function, this does not look like much, but for more |
| 73 | +complicated models with numerous arguments of different types, the packing |
| 74 | +and unpacking is tedious and error prone. This leads to models that are |
| 75 | +difficult to debug and difficult to iterate on. |
| 76 | + |
| 77 | +# New Interface |
| 78 | + |
| 79 | +The new interface introduces six new functions: |
| 80 | + |
| 81 | +`ode_bdf`, `ode_adams`, `ode_rk45` and `ode_bdf_tol`,`ode_adams_tol`, |
| 82 | +`ode_rk45_tol` |
| 83 | + |
| 84 | +The solvers in the first columns have default tolerance settings. The solvers |
| 85 | +in the second column accept arguments for relative tolerance, absolute |
| 86 | +tolerance, and the maximum number of steps to take between output times. |
| 87 | + |
| 88 | +This is different from the old interface where tolerances are presented |
| 89 | +through using the same function name with a few more arguments. |
| 90 | + |
| 91 | +To make it easier to write ODEs, the solve functions take extra arguments |
| 92 | +that are passed along unmodified to the user-supplied system function. |
| 93 | +Because there can be any number of these arguments and they can be of |
| 94 | +different types, they are denoted below as `...`. The types of the |
| 95 | +arguments represented by `...` in the ODE solve function call must match |
| 96 | +the types of the arguments represented by `...` in the user-supplied system |
| 97 | +function. |
| 98 | + |
| 99 | +The new `ode_bdf` solver interface is (the interfaces for `ode_adams` and |
| 100 | +`ode_rk45` are the same): |
| 101 | + |
| 102 | +```{stan, output.var = "", eval = FALSE} |
| 103 | +vector[] ode_bdf(F f, vector y0, real t0, real[] times, ...) |
| 104 | +``` |
| 105 | + |
| 106 | +The arguments are: |
| 107 | + |
| 108 | +1. ```f``` - ODE system function |
| 109 | + |
| 110 | +2. ```y0``` - Initial state of the ODE |
| 111 | + |
| 112 | +3. ```t0``` - Initial time of the ODE |
| 113 | + |
| 114 | +4. ```times``` - Sorted array of times to which the ode will be solved (each |
| 115 | + element must be greater than t0, but times do not need to be strictly |
| 116 | + increasing) |
| 117 | + |
| 118 | +5. ```...``` - Sequence of arguments passed unmodified to the ODE system |
| 119 | +function. There can be any number of ```...``` arguments, and the ```...``` |
| 120 | +arguments can be any type, but they must match the types of the corresponding |
| 121 | +```...``` arguments of ```f```. |
| 122 | + |
| 123 | +The ODE system function should take the form: |
| 124 | + |
| 125 | +```{stan, output.var = "", eval = FALSE} |
| 126 | +vector f(real t, vector y, ...) |
| 127 | +``` |
| 128 | + |
| 129 | +The arguments are: |
| 130 | + |
| 131 | +1. ```t``` - Time at which to evaluate the ODE system function |
| 132 | + |
| 133 | +2. ```y``` - State at which to evaluate the ODE system function |
| 134 | + |
| 135 | +3. ```...``` - Sequence of arguments passed unmodified from the ODE solver |
| 136 | +function call. The ```...``` must match the types of the corresponding |
| 137 | +```...``` arguments of the ODE solver function call. |
| 138 | + |
| 139 | +A call to `ode_bdf` returns the solution of the ODE specified by the system |
| 140 | +function (`f`) and the initial conditions (`y0` and `t0`) at the time points given |
| 141 | +by the `times` argument. The solution is given by an array of vectors. |
| 142 | + |
| 143 | +The `ode_bdf_tol` interface is (the interfaces for `ode_rk45_tol` |
| 144 | +and `ode_adams_tol` are the same): |
| 145 | + |
| 146 | +```{stan, output.var = "", eval = FALSE} |
| 147 | +vector[] ode_bdf_tol(F f, vector y0, real t0, real[] times, |
| 148 | + real rel_tol, real abs_tol, int max_num_steps, ...) |
| 149 | +``` |
| 150 | + |
| 151 | +The arguments are: |
| 152 | +1. ```f``` - ODE system function |
| 153 | + |
| 154 | +2. ```y0``` - Initial state of the ODE |
| 155 | + |
| 156 | +3. ```t0``` - Initial time of the ODE |
| 157 | + |
| 158 | +4. ```times``` - Sorted array of times to which the ode will be solved (each |
| 159 | + element must be greater than t0, but times do not need to be strictly |
| 160 | + increasing) |
| 161 | + |
| 162 | +5. ```rel_tol``` - Relative tolerance for solve (data) |
| 163 | + |
| 164 | +6. ```abs_tol``` - Absolute tolerance for solve (data) |
| 165 | + |
| 166 | +7. ```max_num_steps``` - Maximum number of timesteps to take in integrating |
| 167 | + the ODE solution between output time points (data) |
| 168 | + |
| 169 | +5. ```...``` - Sequence of arguments passed unmodified to the ODE system |
| 170 | +function. There can be any number of ```...``` arguments, and the ```...``` |
| 171 | +arguments can be any type, but they must match the types of the corresponding |
| 172 | +```...``` arguments of ```f```. |
| 173 | + |
| 174 | +The `ode_rk45`/`ode_bdf`/`ode_adams` interfaces are just wrappers around the |
| 175 | +`ode_rk45_tol`/`ode_bdf_tol`/`ode_adams_tol` interfaces with defaults for |
| 176 | +`rel_tol`, `abs_tol`, and `max_num_steps`. For the RK45 solver the defaults |
| 177 | +are $10^{-6}$ for `rel_tol` and `abs_tol` and $10^6$ for `max_num_steps`. |
| 178 | +For the BDF/Adams solvers the defaults are $10^{-10}$ for `rel_tol` and |
| 179 | +`abs_tol` and $10^8$ for `max_num_steps`. |
| 180 | + |
| 181 | +For more detailed information about either interface, look at the function |
| 182 | +reference guide: |
| 183 | +[New interface](https://mc-stan.org/docs/2_24/functions-reference/functions-ode-solver.html), |
| 184 | +[Old interface](https://mc-stan.org/docs/2_24/functions-reference/functions-old-ode-solver.html) |
| 185 | + |
| 186 | +# Example Models |
| 187 | + |
| 188 | +The two models here come from the Stan |
| 189 | +[Statistical Computation Benchmarks](https://github.com/stan-dev/stat_comp_benchmarks). |
| 190 | + |
| 191 | +## SIR Model |
| 192 | + |
| 193 | +### ODE System Function |
| 194 | + |
| 195 | +In the old SIR system function, `beta`, `kappa`, `gamma`, `xi`, and `delta`, |
| 196 | +are packed into the `real[] theta` argument. `kappa` isn't actually a model |
| 197 | +parameter so it is not clear why it is packed in with the other parameters, |
| 198 | +but it is. Promoting `kappa` to a parameter causes there to be more states |
| 199 | +in the extended ODE sensitivity system (used to get gradients of the ODE |
| 200 | +with respect to inputs). Adding states to the sensitivity system makes the |
| 201 | +ODE harder to solve and should always be avoided. The ODE system function looks |
| 202 | +like: |
| 203 | + |
| 204 | +```{stan, output.var = "", eval = FALSE} |
| 205 | +functions { |
| 206 | + // theta[1] = beta, water contact rate |
| 207 | + // theta[2] = kappa, C_{50} |
| 208 | + // theta[3] = gamma, recovery rate |
| 209 | + // theta[4] = xi, bacteria production rate |
| 210 | + // theta[5] = delta, bacteria removal rate |
| 211 | + real[] simple_SIR(real t, |
| 212 | + real[] y, |
| 213 | + real[] theta, |
| 214 | + real[] x_r, |
| 215 | + int[] x_i) { |
| 216 | + real dydt[4]; |
| 217 | +
|
| 218 | + dydt[1] = - theta[1] * y[4] / (y[4] + theta[2]) * y[1]; |
| 219 | + dydt[2] = theta[1] * y[4] / (y[4] + theta[2]) * y[1] - theta[3] * y[2]; |
| 220 | + dydt[3] = theta[3] * y[2]; |
| 221 | + dydt[4] = theta[4] * y[2] - theta[5] * y[4]; |
| 222 | +
|
| 223 | + return dydt; |
| 224 | + } |
| 225 | +} |
| 226 | +``` |
| 227 | + |
| 228 | +For comparison, with the new interface the ODE system function can be |
| 229 | +rewritten to explicitly name all the parameters. No separation of data |
| 230 | +and parameters is necessary either -- the solver will not add more |
| 231 | +equations for arguments that are defined in the `data` and |
| 232 | +`transformed data` blocks. The state variables in the new model are also |
| 233 | +represented by `vector` variables instead of `real[]` variables. |
| 234 | +The new ODE system function is: |
| 235 | + |
| 236 | +```{stan, output.var = "", eval = FALSE} |
| 237 | +functions { |
| 238 | + vector simple_SIR(real t, |
| 239 | + vector y, |
| 240 | + real beta, // water contact rate |
| 241 | + real kappa, // C_{50} |
| 242 | + real gamma, // recovery rate |
| 243 | + real xi, // bacteria production rate |
| 244 | + real delta) { // bacteria removal rate |
| 245 | + vector[4] dydt; |
| 246 | +
|
| 247 | + dydt[1] = -beta * y[4] / (y[4] + kappa) * y[1]; |
| 248 | + dydt[2] = beta * y[4] / (y[4] + kappa) * y[1] - gamma * y[2]; |
| 249 | + dydt[3] = gamma * y[2]; |
| 250 | + dydt[4] = xi * y[2] - delta * y[4]; |
| 251 | +
|
| 252 | + return dydt; |
| 253 | + } |
| 254 | +} |
| 255 | +``` |
| 256 | + |
| 257 | +### Calling the ODE Solver |
| 258 | + |
| 259 | +In the old ODE interface, the parameters are all packed into a `real[]` array |
| 260 | +before calling the ODE solver: |
| 261 | + |
| 262 | +```{stan, output.var = "", eval = FALSE} |
| 263 | +transformed parameters { |
| 264 | + real<lower=0> y[N_t, 4]; |
| 265 | + { |
| 266 | + real theta[5] = {beta, kappa, gamma, xi, delta}; |
| 267 | + y = integrate_ode_rk45(simple_SIR, y0, t0, t, theta, x_r, x_i); |
| 268 | + } |
| 269 | +} |
| 270 | +``` |
| 271 | + |
| 272 | +In the new ODE interface each of the arguments is appended on to the ODE |
| 273 | +solver function call. The RK45 ODE solver with default tolerances is called |
| 274 | +`ode_rk45`. Because the states are handled as `vector` variables, the solver |
| 275 | +output is an array of vectors (`vector[]`). |
| 276 | + |
| 277 | +```{stan, output.var = "", eval = FALSE} |
| 278 | +transformed parameters { |
| 279 | + vector<lower=0>[4] y[N_t] = ode_rk45(simple_SIR, y0, t0, t, beta, kappa, gamma, xi, delta); |
| 280 | +} |
| 281 | +``` |
| 282 | + |
| 283 | +### Full Model |
| 284 | + |
| 285 | +The full model with the new interface can be found |
| 286 | +[here](https://github.com/stan-dev/stat_comp_benchmarks/blob/feature/variadic-odes/benchmarks/sir/sir.stan); |
| 287 | +the full model with the old interface can be found |
| 288 | +[here](https://github.com/stan-dev/stat_comp_benchmarks/blob/master/benchmarks/sir/sir.stan), |
| 289 | +and the data can be found |
| 290 | +[here](https://github.com/stan-dev/stat_comp_benchmarks/blob/master/benchmarks/sir/sir.data.R). |
| 291 | + |
| 292 | +## PKPD Model |
| 293 | + |
| 294 | +### ODE System Function |
| 295 | + |
| 296 | +In the old system function, parameters are manually unpacked from `theta` and |
| 297 | +`x_r`: |
| 298 | + |
| 299 | +```{stan, output.var = "", eval = FALSE} |
| 300 | +functions { |
| 301 | + real[] one_comp_mm_elim_abs(real t, |
| 302 | + real[] y, |
| 303 | + real[] theta, |
| 304 | + real[] x_r, |
| 305 | + int[] x_i) { |
| 306 | + real dydt[1]; |
| 307 | + real k_a = theta[1]; // Dosing rate in 1/day |
| 308 | + real K_m = theta[2]; // Michaelis-Menten constant in mg/L |
| 309 | + real V_m = theta[3]; // Maximum elimination rate in 1/day |
| 310 | + real D = x_r[1]; |
| 311 | + real V = x_r[2]; |
| 312 | + real dose = 0; |
| 313 | + real elim = (V_m / V) * y[1] / (K_m + y[1]); |
| 314 | +
|
| 315 | + if (t > 0) |
| 316 | + dose = exp(- k_a * t) * D * k_a / V; |
| 317 | +
|
| 318 | + dydt[1] = dose - elim; |
| 319 | +
|
| 320 | + return dydt; |
| 321 | + } |
| 322 | +} |
| 323 | +``` |
| 324 | + |
| 325 | +In the new interface, they are passed directly and so the unpacking is avoided: |
| 326 | + |
| 327 | +```{stan, output.var = "", eval = FALSE} |
| 328 | +functions { |
| 329 | + vector one_comp_mm_elim_abs(real t, |
| 330 | + vector y, |
| 331 | + real k_a, // Dosing rate in 1/day |
| 332 | + real K_m, // Michaelis-Menten constant in mg/L |
| 333 | + real V_m, // Maximum elimination rate in 1/day |
| 334 | + real D, |
| 335 | + real V) { |
| 336 | + vector[1] dydt; |
| 337 | +
|
| 338 | + real dose = 0; |
| 339 | + real elim = (V_m / V) * y[1] / (K_m + y[1]); |
| 340 | +
|
| 341 | + if (t > 0) |
| 342 | + dose = exp(- k_a * t) * D * k_a / V; |
| 343 | +
|
| 344 | + dydt[1] = dose - elim; |
| 345 | +
|
| 346 | + return dydt; |
| 347 | + } |
| 348 | +} |
| 349 | +``` |
| 350 | + |
| 351 | +### Calling the ODE Solver |
| 352 | + |
| 353 | +In the old interface the `theta` and `x_r` arguments are packed manually, and |
| 354 | +the `x_i` argument is required even though it isn't used: |
| 355 | + |
| 356 | +```{stan, output.var = "", eval = FALSE} |
| 357 | +transformed data { |
| 358 | + real x_r[2] = {D, V}; |
| 359 | + int x_i[0]; |
| 360 | +} |
| 361 | +... |
| 362 | +transformed parameters { |
| 363 | + real C[N_t, 1]; |
| 364 | + { |
| 365 | + real theta[3] = {k_a, K_m, V_m}; |
| 366 | + C = integrate_ode_bdf(one_comp_mm_elim_abs, C0, t0, times, theta, x_r, x_i); |
| 367 | + } |
| 368 | +} |
| 369 | +``` |
| 370 | + |
| 371 | +In the new interface the arguments are simply passed at the end of the |
| 372 | +`ode_bdf` call (and the `transformed data` block removed): |
| 373 | + |
| 374 | +```{stan, output.var = "", eval = FALSE} |
| 375 | +transformed parameters { |
| 376 | + vector[1] C[N_t] = ode_bdf(one_comp_mm_elim_abs, C0, t0, times, k_a, K_m, V_m, D, V); |
| 377 | +} |
| 378 | +``` |
| 379 | +### Full Model |
| 380 | + |
| 381 | +The full model with the new interface can be found |
| 382 | +[here](https://github.com/stan-dev/stat_comp_benchmarks/blob/feature/variadic-odes/benchmarks/one_comp_mm_elim_abs/one_comp_mm_elim_abs.stan); |
| 383 | +the full model with the old interface can be found |
| 384 | +[here](https://github.com/stan-dev/stat_comp_benchmarks/blob/master/benchmarks/pkpd/one_comp_mm_elim_abs.stan), |
| 385 | +and the data can be found |
| 386 | +[here](https://github.com/stan-dev/stat_comp_benchmarks/blob/master/benchmarks/pkpd/one_comp_mm_elim_abs.data.R). |
0 commit comments