|
| 1 | +--- |
| 2 | +title: "HMM Example" |
| 3 | +output: html_document |
| 4 | +--- |
| 5 | + |
| 6 | +```{r setup, include=FALSE} |
| 7 | +library(tidyverse) |
| 8 | +library(ggplot2) |
| 9 | +library(cmdstanr) |
| 10 | +library(posterior) |
| 11 | +``` |
| 12 | + |
| 13 | +## Introduction |
| 14 | + |
| 15 | +CmdStan 2.24 introduced a new interface for fitting Hidden Markov models (HMMs) |
| 16 | +in Stan. This document is intended to provide an example use of this interface. |
| 17 | + |
| 18 | +HMMs model a process where a system probabilistically switches between $K$ |
| 19 | +states over a sequence of $N$ points in time. It is assumed that the exact |
| 20 | +state of the system is unknown and must be measured at each state. |
| 21 | + |
| 22 | +HMMs are characterized in terms of the transition matrix $\Gamma_{ij}$ (each |
| 23 | +element being the probability of transitioning from state $i$ to state $j$ |
| 24 | +between measurements), the types of measurements made on the system (the |
| 25 | +system may emit continuous or discrete measurements), and the initial state |
| 26 | +of the system. |
| 27 | + |
| 28 | +Any realization of an HMM is a sequence of $N$ integers in the range $[1, K]$, |
| 29 | +however, because of the structure of the HMM, it is not necessary to sample |
| 30 | +the latent states to do inference on the transition probabilities, the |
| 31 | +parameters of the measurement model, or the estimates of the initial state. |
| 32 | +Estimates of the distribution of states at each measurement time can be |
| 33 | +computed separately. |
| 34 | + |
| 35 | +A more complete mathematical definition of the HMM model and function interface |
| 36 | +is given in the [Hidden Markov Models](https://mc-stan.org/docs/2_24/functions-reference/hidden-markov-models.html) |
| 37 | +section of the Function Reference Guide. |
| 38 | + |
| 39 | +There are three functions |
| 40 | + |
| 41 | +- `hmm_marginal` - The likelihood of an HMM with the latent discrete states |
| 42 | +integrated out |
| 43 | +- `hmm_latent_rng` - A function to generate random realizations of an HMM |
| 44 | +- `hmm_hidden_state_prob` - A function to compute the distribution of latent |
| 45 | +states at each measurement. |
| 46 | + |
| 47 | +This guide will demonstrate how to simulate HMM realizations in R, fit the data |
| 48 | +with `hmm_marginal` and produce estimates of the latent states using |
| 49 | +`hmm_hidden_state_prob`. |
| 50 | + |
| 51 | +### Generating HMM realizations |
| 52 | + |
| 53 | +Simulating an HMM requires a set of states, the transition probabilities |
| 54 | +between those states, and an estimate of the initial states. |
| 55 | + |
| 56 | +For illustrative purposes, assume a three state system with states 1, 2, 3. |
| 57 | + |
| 58 | +The transitions happen as follows: |
| 59 | +1. In state 1 there is a 50% chance of moving to state 2 and a 50% chance of staying in state 1 |
| 60 | +2. In state 2 there is a 25% chance of moving to state 1, a 25% change of moving to state 3, and a 50% chance of staying in state 2 |
| 61 | +3. In state 3 there is a 50% chance of moving to state 2 and a 50% chance of staying at state 3. |
| 62 | + |
| 63 | +Assume that the system starts in state 1. |
| 64 | + |
| 65 | +```{r} |
| 66 | +N = 100 # 100 measurements |
| 67 | +K = 3 # 3 states |
| 68 | +states = rep(1, N) |
| 69 | +states[1] = 1 # Start in state 1 |
| 70 | +for(n in 2:length(states)) { |
| 71 | + if(states[n - 1] == 1) |
| 72 | + states[n] = sample(c(1, 2), size = 1, prob = c(0.5, 0.5)) |
| 73 | + else if(states[n - 1] == 2) |
| 74 | + states[n] = sample(c(1, 2, 3), size = 1, prob = c(0.25, 0.5, 0.25)) |
| 75 | + else if(states[n - 1] == 3) |
| 76 | + states[n] = sample(c(2, 3), size = 1, prob = c(0.5, 0.5)) |
| 77 | +} |
| 78 | +``` |
| 79 | + |
| 80 | +The trajectory can easily be visualized: |
| 81 | +```{r} |
| 82 | +qplot(1:N, states) |
| 83 | +``` |
| 84 | + |
| 85 | +An HMM is useful when the latent state is only measured indirectly. |
| 86 | + |
| 87 | +In this example the observations are assumed to be |
| 88 | +normally distributed with a state specific mean and some measurement error. |
| 89 | + |
| 90 | +```{r} |
| 91 | +mus = c(1.0, 5.0, 9.0) |
| 92 | +sigma = 2.0 |
| 93 | +y = rnorm(N, mus[states], sd = sigma) |
| 94 | +``` |
| 95 | + |
| 96 | +Plotting the simulated measurements gives: |
| 97 | + |
| 98 | +```{r} |
| 99 | +qplot(1:N, y) |
| 100 | +``` |
| 101 | + |
| 102 | +### Fitting the HMM |
| 103 | + |
| 104 | +To make it clear how to use the HMM fit functions, the model here will fit the |
| 105 | +transition matrix, the initial state, and the parameters of the measurement |
| 106 | +model. It is not necessary to estimate all of these things in practice if some |
| 107 | +of them are known. |
| 108 | + |
| 109 | +The data is the previously generated sequence of $N$ measurements: |
| 110 | +```{stan, output.var = "", eval = FALSE} |
| 111 | +data { |
| 112 | + int N; // Number of observations |
| 113 | + real y[N]; |
| 114 | +} |
| 115 | +``` |
| 116 | + |
| 117 | +For the transition matrix, assume that it is known that states 1 and 3 are not |
| 118 | +directly connected. For $K$ states, estimating a full transition matrix means |
| 119 | +estimatng a matrix of $O(K^2)$ probabilities. Depending on the data available, |
| 120 | +this may not be possible and so it is important to take advantage of available |
| 121 | +modeling assumptions. The state means are estimated as an ordered vector |
| 122 | +to avoid mode-swap non-identifiabilities. |
| 123 | + |
| 124 | +```{stan, output.var = "", eval = FALSE} |
| 125 | +parameters { |
| 126 | + // Rows of the transition matrix |
| 127 | + simplex[2] t1; |
| 128 | + simplex[3] t2; |
| 129 | + simplex[2] t3; |
| 130 | + |
| 131 | + // Initial state |
| 132 | + simplex[3] rho; |
| 133 | + |
| 134 | + // Parameters of measurement model |
| 135 | + vector[3] mu; |
| 136 | + real<lower = 0.0> sigma; |
| 137 | +} |
| 138 | +``` |
| 139 | + |
| 140 | +The `hmm_marginal` function takes the transition matrix and initial state |
| 141 | +directly. In this case the transition matrix needs constructed from `t1`, |
| 142 | +`t2`, and `t3` but that is relatively easy to build. |
| 143 | + |
| 144 | +The measurement model, in contrast, is not passed directly to the HMM function. |
| 145 | + |
| 146 | +Instead, a $KxN$ matrix `log_omega` of log likelihoods is passed in. The |
| 147 | +$(k, n)$ entry of this matrix is the log likelihood of the $nth$ measurement |
| 148 | +given the system at time $n$ is actually in state $k$. For the generative |
| 149 | +model above, these are log normals evaluated at the three different means. |
| 150 | + |
| 151 | +```{stan, output.var = "", eval = FALSE} |
| 152 | +transformed parameters { |
| 153 | + matrix[3, 3] gamma = rep_matrix(0, 3, 3); |
| 154 | + matrix[3, N] log_omega; |
| 155 | + |
| 156 | + // Build the transition matrix |
| 157 | + gamma[1, 1:2] = t1; |
| 158 | + gamma[2, ] = t2; |
| 159 | + gamma[3, 2:3] = t3; |
| 160 | + |
| 161 | + // Compute the log likelihoods in each possible state |
| 162 | + for(n in 1:N) { |
| 163 | + // The observation model could change with n, or vary in a number of |
| 164 | + // different ways (which is why log_omega is passed in as an argument) |
| 165 | + log_omega[1, n] = normal_lpdf(y[n] | mu[1], sigma); |
| 166 | + log_omega[2, n] = normal_lpdf(y[n] | mu[2], sigma); |
| 167 | + log_omega[3, n] = normal_lpdf(y[n] | mu[3], sigma); |
| 168 | + } |
| 169 | +} |
| 170 | +``` |
| 171 | + |
| 172 | +With all that in place, the only thing left to do is add priors and increment |
| 173 | +the log density: |
| 174 | +```{stan, output.var = "", eval = FALSE} |
| 175 | +model { |
| 176 | + mu ~ normal(0, 1); |
| 177 | + sigma ~ normal(0, 1); |
| 178 | +
|
| 179 | + target += hmm_marginal(log_omega, Gamma, rho); |
| 180 | +} |
| 181 | +``` |
| 182 | + |
| 183 | +```{r echo = TRUE, results = FALSE, message = FALSE} |
| 184 | +model = cmdstan_model("hmm-example.stan") |
| 185 | +fit = model$sample(data = list(N = N, y = y), parallel_chains = 4) |
| 186 | +``` |
| 187 | + |
| 188 | +The estimated group means match the known ones: |
| 189 | +```{r} |
| 190 | +fit$summary("mu") |
| 191 | +``` |
| 192 | +The estimated of the initial condition is not much more informative than |
| 193 | +the prior, but it is there: |
| 194 | +```{r} |
| 195 | +fit$summary("rho") |
| 196 | +``` |
| 197 | + |
| 198 | +The transition probabilities from state 1 can be extracted: |
| 199 | +```{r} |
| 200 | +fit$summary("t1") |
| 201 | +``` |
| 202 | + |
| 203 | +Similarly for state 2: |
| 204 | +```{r} |
| 205 | +fit$summary("t2") |
| 206 | +``` |
| 207 | + |
| 208 | +And state 3: |
| 209 | +```{r} |
| 210 | +fit$summary("t3") |
| 211 | +``` |
| 212 | + |
| 213 | +### State Probabilities |
| 214 | + |
| 215 | +Even though the latent states are not sampled directly, the distribution |
| 216 | +of latent states at each time point can be computed with the function |
| 217 | +`hmm_hidden_state_prob`: |
| 218 | + |
| 219 | +```{stan, output.var = "", eval = FALSE} |
| 220 | +generated quantities { |
| 221 | + matrix[3, N] latent_probs = hmm_hidden_state_prob(log_omega, Gamma, rho); |
| 222 | +} |
| 223 | +``` |
| 224 | + |
| 225 | +These can be plotted: |
| 226 | + |
| 227 | +```{r} |
| 228 | +latent_probs_df = fit$draws() %>% |
| 229 | + as_draws_df %>% |
| 230 | + select(starts_with("latent_probs")) %>% |
| 231 | + pivot_longer(everything(), |
| 232 | + names_to = c("state", "n"), |
| 233 | + names_transform = list(k = as.integer, n = as.integer), |
| 234 | + names_pattern = "latent_probs\\[([0-9]*),([0-9]*)\\]", |
| 235 | + values_to = "latent_probs") |
| 236 | +
|
| 237 | +latent_probs_df %>% |
| 238 | + group_by(state, n) %>% |
| 239 | + summarize(qh = quantile(latent_probs, 0.8), |
| 240 | + m = median(latent_probs), |
| 241 | + ql = quantile(latent_probs, 0.2)) %>% |
| 242 | + ungroup() %>% |
| 243 | + ggplot() + |
| 244 | + geom_errorbar(aes(n, ymin = ql, ymax = qh, width = 0.0), alpha = 0.5) + |
| 245 | + geom_point(aes(n, m)) + |
| 246 | + facet_grid(state ~ ., labeller = "label_both") + |
| 247 | + ggtitle("Ribbon is 60% posterior interval, point is median") + |
| 248 | + ylab("Probability of being in state") + |
| 249 | + xlab("Time (n)") |
| 250 | +``` |
| 251 | + |
| 252 | +New simulations from the fitted HMM can be generated with `hmm_latent_rng`: |
| 253 | + |
| 254 | +```{stan, output.var = "", eval = FALSE} |
| 255 | +generated quantities { |
| 256 | + int[N] y_sim = hmm_latent_rng(log_omega, Gamma, rho) |
| 257 | +} |
| 258 | +``` |
| 259 | + |
| 260 | +These can be plotted as well: |
| 261 | + |
| 262 | +```{r} |
| 263 | +y_sim = fit$draws() %>% |
| 264 | + as_draws_df() %>% |
| 265 | + select(starts_with("y_sim")) %>% |
| 266 | + as.matrix |
| 267 | +
|
| 268 | +qplot(1:N, y_sim[1,]) |
| 269 | +``` |
0 commit comments