|
| 1 | +--- |
| 2 | +title: "HMM Example" |
| 3 | +author: "Ben Bales" |
| 4 | +date: 10-2-2020 |
| 5 | +output: |
| 6 | + html_document: default |
| 7 | + pdf_document: default |
| 8 | +--- |
| 9 | + |
| 10 | +```{r setup, include=FALSE} |
| 11 | +library(tidyverse) |
| 12 | +library(ggplot2) |
| 13 | +library(cmdstanr) |
| 14 | +library(posterior) |
| 15 | +``` |
| 16 | + |
| 17 | +## Introduction |
| 18 | + |
| 19 | +CmdStan 2.24 introduced a new interface for fitting Hidden Markov models (HMMs) |
| 20 | +in Stan. This document is intended to provide an example use of this interface. |
| 21 | + |
| 22 | +HMMs model a process where a system probabilistically switches between $K$ |
| 23 | +states over a sequence of $N$ points in time. It is assumed that the exact |
| 24 | +state of the system is unknown and must be inferred at each time point. |
| 25 | + |
| 26 | +HMMs are characterized in terms of the transition matrix $\Gamma_{ij}$ (each |
| 27 | +element being the probability of transitioning from state $i$ to state $j$ |
| 28 | +between measurements), the types of measurements made on the system (the |
| 29 | +system may emit continuous or discrete measurements), and the initial state |
| 30 | +of the system. Currently the HMM interface in Stan only supports a constant |
| 31 | +transition matrix. Future versions will support a transition matrix for each state. |
| 32 | + |
| 33 | +Any realization of an HMM's hidden state is a sequence of $N$ integers in the |
| 34 | +range $[1, K]$. However, because of the structure of the HMM, it is not |
| 35 | +necessary to sample the hidden states to do inference on the transition |
| 36 | +probabilities, the parameters of the measurement model, or the estimates |
| 37 | +of the initial state. Posterior draws from the hidden states can be computed |
| 38 | +separately. |
| 39 | + |
| 40 | +A more complete mathematical definition of the HMM model and function interface |
| 41 | +is given in the [Hidden Markov Models](https://mc-stan.org/docs/2_24/functions-reference/hidden-markov-models.html) |
| 42 | +section of the Function Reference Guide. |
| 43 | + |
| 44 | +There are three functions: |
| 45 | + |
| 46 | +- `hmm_marginal` - The likelihood of an HMM with the hidden discrete states |
| 47 | +integrated out |
| 48 | +- `hmm_latent_rng` - A function to generate posterior draws of the hidden state that are |
| 49 | +implicitly integrated out of the model when using `hmm_marginal` (this is |
| 50 | +different from sampling more states with a posterior draw of a transition matrix |
| 51 | +and initial state) |
| 52 | +- `hmm_hidden_state_prob` - A function to compute the posterior distributions of the |
| 53 | +integrated out hidden states |
| 54 | + |
| 55 | +This guide will demonstrate how to simulate HMM realizations in R, fit the data |
| 56 | +with `hmm_marginal`, produce estimates of the distributions of the hidden states |
| 57 | +using `hmm_hidden_state_prob`, and generate draws of the hidden state from the |
| 58 | +posterior with `hmm_latent_rng`. |
| 59 | + |
| 60 | +### Generating HMM realizations |
| 61 | + |
| 62 | +Simulating an HMM requires a set of states, the transition probabilities |
| 63 | +between those states, and the initial state. |
| 64 | + |
| 65 | +For illustrative purposes, assume a three state system with states 1, 2, 3. |
| 66 | + |
| 67 | +The transitions happen as follows: |
| 68 | +1. In state 1 there is a 50% chance of moving to state 2 and a 50% chance of staying in state 1 |
| 69 | +2. In state 2 there is a 25% chance of moving to state 1, a 25% chance of moving to state 3, and a 50% chance of staying in state 2 |
| 70 | +3. In state 3 there is a 50% chance of moving to state 2 and a 50% chance of staying at state 3. |
| 71 | + |
| 72 | +Assume that the system starts in state 1. |
| 73 | + |
| 74 | +```{r} |
| 75 | +N = 100 # 100 measurements |
| 76 | +K = 3 # 3 states |
| 77 | +states = rep(1, N) |
| 78 | +states[1] = 1 # Start in state 1 |
| 79 | +for(n in 2:length(states)) { |
| 80 | + if(states[n - 1] == 1) |
| 81 | + states[n] = sample(c(1, 2), size = 1, prob = c(0.5, 0.5)) |
| 82 | + else if(states[n - 1] == 2) |
| 83 | + states[n] = sample(c(1, 2, 3), size = 1, prob = c(0.25, 0.5, 0.25)) |
| 84 | + else if(states[n - 1] == 3) |
| 85 | + states[n] = sample(c(2, 3), size = 1, prob = c(0.5, 0.5)) |
| 86 | +} |
| 87 | +``` |
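The same transition rules can also be collected into a $K \times K$ matrix whose row $i$ holds the probabilities of moving out of state $i$. The following is a minimal sketch and not part of the original example; `Gamma_true`, `simulate_states`, and `sim_states` are hypothetical names introduced here for illustration.

```r
# Transition matrix for the three-state example: row i gives the
# probabilities of moving from state i to states 1, 2, 3.
Gamma_true <- matrix(c(0.50, 0.50, 0.00,
                       0.25, 0.50, 0.25,
                       0.00, 0.50, 0.50),
                     nrow = 3, byrow = TRUE)

# Hypothetical helper: simulate N hidden states starting from `init`,
# drawing each new state from the row of Gamma for the previous state.
simulate_states <- function(N, Gamma, init = 1) {
  states <- integer(N)
  states[1] <- init
  for (n in seq_len(N)[-1]) {
    states[n] <- sample(nrow(Gamma), size = 1, prob = Gamma[states[n - 1], ])
  }
  states
}

set.seed(1)
sim_states <- simulate_states(100, Gamma_true)
```

Writing the rules as a matrix makes the structural zeros (no direct move between states 1 and 3) explicit, which is the same structure the Stan model later encodes.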
| 88 | + |
| 89 | +The trajectory can easily be visualized: |
| 90 | +```{r} |
| 91 | +qplot(1:N, states) |
| 92 | +``` |
| 93 | + |
| 94 | +An HMM is useful when the hidden state is not measured directly (if the |
| 95 | +state was measured directly, it wouldn't be hidden). |
| 96 | + |
| 97 | +In this example the observations are assumed to be |
| 98 | +normally distributed with a state specific mean and some measurement error. |
| 99 | + |
| 100 | +```{r} |
| 101 | +mus = c(1.0, 5.0, 9.0) |
| 102 | +sigma = 2.0 |
| 103 | +y = rnorm(N, mus[states], sd = sigma) |
| 104 | +``` |
| 105 | + |
| 106 | +Plotting the simulated measurements gives: |
| 107 | + |
| 108 | +```{r} |
| 109 | +qplot(1:N, y) |
| 110 | +``` |
| 111 | + |
| 112 | +### Fitting the HMM |
| 113 | + |
| 114 | +To make it clear how to use the HMM fit functions, the model here will fit the |
| 115 | +transition matrix, the initial state, and the parameters of the measurement |
| 116 | +model. It is not necessary to estimate all of these things in practice if some |
| 117 | +of them are known. |
| 118 | + |
| 119 | +The data is the previously generated sequence of $N$ measurements: |
| 120 | +```{stan, output.var = "", eval = FALSE} |
| 121 | +data { |
| 122 | + int N; // Number of observations |
| 123 | + real y[N]; |
| 124 | +} |
| 125 | +``` |
| 126 | + |
| 127 | +For the transition matrix, assume that it is known that states 1 and 3 are not |
| 128 | +directly connected. For $K$ states, estimating a full transition matrix means |
| 129 | +estimating a matrix of $O(K^2)$ probabilities. Depending on the data available, |
| 130 | +this may not be possible and so it is important to take advantage of available |
| 131 | +modeling assumptions. The state means are estimated as an ordered vector |
| 132 | +to avoid mode-swap non-identifiabilities. |
| 133 | + |
| 134 | +```{stan, output.var = "", eval = FALSE} |
| 135 | +parameters { |
| 136 | + // Rows of the transition matrix |
| 137 | + simplex[2] t1; |
| 138 | + simplex[3] t2; |
| 139 | + simplex[2] t3; |
| 140 | + |
| 141 | + // Initial state |
| 142 | + simplex[3] rho; |
| 143 | + |
| 144 | + // Parameters of measurement model |
| 145 | + ordered[3] mu; |
| 146 | + real<lower = 0.0> sigma; |
| 147 | +} |
| 148 | +``` |
| 149 | + |
| 150 | +The `hmm_marginal` function takes the transition matrix and initial state |
| 151 | +directly. In this case the transition matrix needs to be constructed from `t1`, |
| 152 | +`t2`, and `t3` but that is relatively easy to build. |
| 153 | + |
| 154 | +The measurement model, in contrast, is not passed directly to the HMM function. |
| 155 | + |
| 156 | +Instead, a $K \times N$ matrix `log_omega` of log likelihoods is passed in. The |
| 157 | +$(k, n)$ entry of this matrix is the log likelihood of the $n$th measurement |
| 158 | +given the system at time $n$ is actually in state $k$. For the generative |
| 159 | +model above, these are normal log densities evaluated at the three different means. |
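To make the shape of `log_omega` concrete, here is a minimal R sketch of the same matrix. It is an illustration only: the ten observations are simulated on the spot, and the means and standard deviation are assumed to be the values used in the simulation rather than fitted parameters.

```r
# Illustrative inputs (assumed values, not fitted parameters).
set.seed(2)
mus <- c(1.0, 5.0, 9.0)
sigma <- 2.0
y <- rnorm(10, mean = 5, sd = sigma)

# K x N matrix: entry (k, n) is the normal log density of y[n]
# evaluated with mean mus[k] and standard deviation sigma.
log_omega <- outer(mus, y, function(m, obs) {
  dnorm(obs, mean = m, sd = sigma, log = TRUE)
})
dim(log_omega)
```

Each column describes one observation under every candidate state, so nothing in the matrix depends on the transition probabilities.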
| 160 | + |
| 161 | +```{stan, output.var = "", eval = FALSE} |
| 162 | +transformed parameters { |
| 163 | + matrix[3, 3] Gamma = rep_matrix(0, 3, 3); |
| 164 | + matrix[3, N] log_omega; |
| 165 | + |
| 166 | + // Build the transition matrix (simplexes are column vectors, rows need a transpose) |
| 167 | + Gamma[1, 1:2] = t1'; |
| 168 | + Gamma[2, ] = t2'; |
| 169 | + Gamma[3, 2:3] = t3'; |
| 170 | + |
| 171 | + // Compute the log likelihoods in each possible state |
| 172 | + for(n in 1:N) { |
| 173 | + // The observation model could change with n, or vary in a number of |
| 174 | + // different ways (which is why log_omega is passed in as an argument) |
| 175 | + log_omega[1, n] = normal_lpdf(y[n] | mu[1], sigma); |
| 176 | + log_omega[2, n] = normal_lpdf(y[n] | mu[2], sigma); |
| 177 | + log_omega[3, n] = normal_lpdf(y[n] | mu[3], sigma); |
| 178 | + } |
| 179 | +} |
| 180 | +``` |
| 181 | + |
| 182 | +With all that in place, the only thing left to do is add priors and increment |
| 183 | +the log density: |
| 184 | +```{stan, output.var = "", eval = FALSE} |
| 185 | +model { |
| 186 | + mu ~ normal(0, 1); |
| 187 | + sigma ~ normal(0, 1); |
| 188 | + |
| 189 | + target += hmm_marginal(log_omega, Gamma, rho); |
| 190 | +} |
| 191 | +``` |
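For intuition about what is being added to the target, the marginal log likelihood can be sketched in R with the forward algorithm. This is a sketch under stated assumptions and not Stan's implementation: `hmm_marginal_r` and `log_sum_exp` are hypothetical helpers, and `rho` is assumed to be the distribution of the hidden state at the first observation.

```r
# Numerically stable log(sum(exp(x))).
log_sum_exp <- function(x) {
  m <- max(x)
  m + log(sum(exp(x - m)))
}

# Hypothetical helper: forward algorithm for the marginal log likelihood.
# log_omega: K x N matrix of log densities
# Gamma:     K x K transition matrix (rows sum to 1)
# rho:       length-K initial state distribution
hmm_marginal_r <- function(log_omega, Gamma, rho) {
  K <- nrow(log_omega)
  N <- ncol(log_omega)
  # log_alpha[k] = log p(y[1:n], state at time n is k)
  log_alpha <- log(rho) + log_omega[, 1]
  for (n in seq_len(N)[-1]) {
    log_alpha <- sapply(seq_len(K), function(k) {
      log_sum_exp(log_alpha + log(Gamma[, k])) + log_omega[k, n]
    })
  }
  # Sum over the final state to integrate out the whole hidden sequence.
  log_sum_exp(log_alpha)
}

# Small usage example with made-up inputs.
Gamma_ex <- matrix(c(0.50, 0.50, 0.00,
                     0.25, 0.50, 0.25,
                     0.00, 0.50, 0.50), nrow = 3, byrow = TRUE)
rho_ex <- c(0.2, 0.3, 0.5)
set.seed(3)
log_omega_ex <- matrix(rnorm(12), nrow = 3, ncol = 4)
hmm_marginal_r(log_omega_ex, Gamma_ex, rho_ex)
```

The loop costs $O(N K^2)$ operations, whereas summing over every possible hidden sequence would cost $O(K^N)$; that saving is why the hidden states never need to be sampled during fitting.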
| 192 | + |
| 193 | +The complete model is available on GitHub: [hmm-example.stan](https://github.com/stan-dev/example-models/tree/master/knitr/hmm-example/hmm-example.stan). |
| 194 | + |
| 195 | +```{r echo = TRUE, results = FALSE, message = FALSE} |
| 196 | +model = cmdstan_model("hmm-example.stan") |
| 197 | +fit = model$sample(data = list(N = N, y = y), parallel_chains = 4) |
| 198 | +``` |
| 199 | + |
| 200 | +The estimated group means match the known ones: |
| 201 | +```{r} |
| 202 | +fit$summary("mu") |
| 203 | +``` |
| 204 | +The estimated initial state distribution is not much more informative than |
| 205 | +the prior, but it is available: |
| 206 | +```{r} |
| 207 | +fit$summary("rho") |
| 208 | +``` |
| 209 | + |
| 210 | +The transition probabilities from state 1 can be extracted: |
| 211 | +```{r} |
| 212 | +fit$summary("t1") |
| 213 | +``` |
| 214 | + |
| 215 | +Similarly for state 2: |
| 216 | +```{r} |
| 217 | +fit$summary("t2") |
| 218 | +``` |
| 219 | + |
| 220 | +And state 3: |
| 221 | +```{r} |
| 222 | +fit$summary("t3") |
| 223 | +``` |
| 224 | + |
| 225 | +### State Probabilities |
| 226 | + |
| 227 | +Even though the hidden states are integrated out, the distribution |
| 228 | +of hidden states at each time point can be computed with the function |
| 229 | +`hmm_hidden_state_prob`: |
| 230 | + |
| 231 | +```{stan, output.var = "", eval = FALSE} |
| 232 | +generated quantities { |
| 233 | + matrix[3, N] hidden_probs = hmm_hidden_state_prob(log_omega, Gamma, rho); |
| 234 | +} |
| 235 | +``` |
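As a rough illustration of what these marginal probabilities are, the following R sketch computes them for a single set of parameter values with the forward-backward recursions. It is not Stan's implementation; `hmm_state_prob_r` is a hypothetical helper, and the inputs in the usage lines are made up.

```r
# Hypothetical helper: forward-backward smoothing.
# Returns a K x N matrix whose column n is p(state at n | all of y).
hmm_state_prob_r <- function(log_omega, Gamma, rho) {
  K <- nrow(log_omega)
  N <- ncol(log_omega)
  # Subtract each column's maximum before exponentiating to avoid
  # underflow; the constants cancel when columns are normalized below.
  omega <- exp(sweep(log_omega, 2, apply(log_omega, 2, max)))
  alpha <- matrix(0, K, N)  # normalized forward messages
  beta <- matrix(1, K, N)   # rescaled backward messages
  alpha[, 1] <- rho * omega[, 1]
  alpha[, 1] <- alpha[, 1] / sum(alpha[, 1])
  for (n in seq_len(N)[-1]) {
    a <- as.vector(t(Gamma) %*% alpha[, n - 1]) * omega[, n]
    alpha[, n] <- a / sum(a)
  }
  for (n in rev(seq_len(N - 1))) {
    b <- as.vector(Gamma %*% (omega[, n + 1] * beta[, n + 1]))
    beta[, n] <- b / sum(b)
  }
  probs <- alpha * beta
  sweep(probs, 2, colSums(probs), "/")
}

# Usage with made-up inputs: every column of the result sums to 1.
Gamma_ex <- matrix(c(0.50, 0.50, 0.00,
                     0.25, 0.50, 0.25,
                     0.00, 0.50, 0.50), nrow = 3, byrow = TRUE)
rho_ex <- c(0.2, 0.3, 0.5)
set.seed(4)
log_omega_ex <- matrix(rnorm(12), nrow = 3, ncol = 4)
state_probs_ex <- hmm_state_prob_r(log_omega_ex, Gamma_ex, rho_ex)
```

In the fitted model this calculation is repeated for every posterior draw of the parameters, which is why each entry of `hidden_probs` has its own posterior distribution.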
| 236 | + |
| 237 | +These can be plotted: |
| 238 | + |
| 239 | +```{r} |
| 240 | +hidden_probs_df = fit$draws() %>% |
| 241 | + as_draws_df %>% |
| 242 | + select(starts_with("hidden_probs")) %>% |
| 243 | + pivot_longer(everything(), |
| 244 | + names_to = c("state", "n"), |
| 245 | + names_transform = list(state = as.integer, n = as.integer), |
| 246 | + names_pattern = "hidden_probs\\[([0-9]*),([0-9]*)\\]", |
| 247 | + values_to = "hidden_probs") |
| 248 | + |
| 249 | +hidden_probs_df %>% |
| 250 | + group_by(state, n) %>% |
| 251 | + summarize(qh = quantile(hidden_probs, 0.8), |
| 252 | + m = median(hidden_probs), |
| 253 | + ql = quantile(hidden_probs, 0.2)) %>% |
| 254 | + ungroup() %>% |
| 255 | + ggplot() + |
| 256 | + geom_errorbar(aes(n, ymin = ql, ymax = qh, width = 0.0), alpha = 0.5) + |
| 257 | + geom_point(aes(n, m)) + |
| 258 | + facet_grid(state ~ ., labeller = "label_both") + |
| 259 | + ggtitle("Error bar is 60% posterior interval, point is median") + |
| 260 | + ylab("Probability of being in state") + |
| 261 | + xlab("Time (n)") |
| 262 | +``` |
| 263 | + |
| 264 | +If it is more convenient to work with draws of the hidden states at each time |
| 265 | +point (instead of the probabilities provided by `hmm_hidden_state_prob`), these |
| 266 | +can be generated with `hmm_latent_rng`: |
| 267 | + |
| 268 | +```{stan, output.var = "", eval = FALSE} |
| 269 | +generated quantities { |
| 270 | + int y_sim[N] = hmm_latent_rng(log_omega, Gamma, rho); |
| 271 | +} |
| 272 | +``` |
| 273 | + |
| 274 | +Note that the probabilities from `hmm_hidden_state_prob` are the marginal |
| 275 | +probabilities of the hidden states, meaning they cannot be directly used to |
| 276 | +jointly sample hidden states. The posterior draws generated by `hmm_latent_rng` |
| 277 | +account for the correlation between hidden states. |
| 278 | + |
| 279 | +Note further these are draws of the hidden state that was integrated out. This is |
| 280 | +different from sampling new HMM realizations using posterior draws of the initial |
| 281 | +condition and the transition matrix. |
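One way to see why joint draws differ from the marginal probabilities is forward filtering followed by backward sampling, sketched below in R. This is an assumption about how such draws can be produced, not a description of Stan's internals; `hmm_latent_draw_r` is a hypothetical helper and the inputs are made up.

```r
# Hypothetical helper: one joint draw of the hidden state sequence
# via forward filtering, backward sampling.
hmm_latent_draw_r <- function(log_omega, Gamma, rho) {
  K <- nrow(log_omega)
  N <- ncol(log_omega)
  omega <- exp(sweep(log_omega, 2, apply(log_omega, 2, max)))
  # Forward pass: alpha[, n] is p(state at n | y[1:n]).
  alpha <- matrix(0, K, N)
  alpha[, 1] <- rho * omega[, 1]
  alpha[, 1] <- alpha[, 1] / sum(alpha[, 1])
  for (n in seq_len(N)[-1]) {
    a <- as.vector(t(Gamma) %*% alpha[, n - 1]) * omega[, n]
    alpha[, n] <- a / sum(a)
  }
  # Backward pass: draw the last state, then each earlier state
  # conditional on the state already drawn for the next time point.
  states <- integer(N)
  states[N] <- sample(K, size = 1, prob = alpha[, N])
  for (n in rev(seq_len(N - 1))) {
    states[n] <- sample(K, size = 1,
                        prob = alpha[, n] * Gamma[, states[n + 1]])
  }
  states
}

# Usage with made-up inputs.
Gamma_ex <- matrix(c(0.50, 0.50, 0.00,
                     0.25, 0.50, 0.25,
                     0.00, 0.50, 0.50), nrow = 3, byrow = TRUE)
rho_ex <- c(0.2, 0.3, 0.5)
set.seed(5)
log_omega_ex <- matrix(rnorm(60), nrow = 3, ncol = 20)
draw_ex <- hmm_latent_draw_r(log_omega_ex, Gamma_ex, rho_ex)
```

Because each state is drawn conditional on its successor, a draw never contains a transition that the transition matrix forbids. Sampling every time point independently from its marginal probabilities could produce such a transition.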
| 282 | + |
| 283 | +The draws of the hidden state can be plotted as well: |
| 284 | + |
| 285 | +```{r} |
| 286 | +y_sim = fit$draws() %>% |
| 287 | + as_draws_df() %>% |
| 288 | + select(starts_with("y_sim")) %>% |
| 289 | + as.matrix |
| 290 | + |
| 291 | +qplot(1:N, y_sim[1,]) |
| 292 | +``` |