Skip to content

Commit bf2c0d5

Browse files
authored
Merge pull request #187 from stan-dev/feature/hmm-example
Added HMM example
2 parents 8a69641 + 1a550ac commit bf2c0d5

File tree

2 files changed

+347
-0
lines changed

2 files changed

+347
-0
lines changed

knitr/hmm-example/hmm-example.Rmd

Lines changed: 292 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,292 @@
1+
---
2+
title: "HMM Example"
3+
author: "Ben Bales"
4+
date: 10-2-2020
5+
output:
6+
html_document: default
7+
pdf_document: default
8+
---
9+
10+
```{r setup, include=FALSE}
11+
library(tidyverse)
12+
library(ggplot2)
13+
library(cmdstanr)
14+
library(posterior)
15+
```
16+
17+
## Introduction
18+
19+
CmdStan 2.24 introduced a new interface for fitting Hidden Markov models (HMMs)
20+
in Stan. This document is intended to provide an example use of this interface.
21+
22+
HMMs model a process where a system probabilistically switches between $K$
23+
states over a sequence of $N$ points in time. It is assumed that the exact
24+
state of the system is unknown and must be inferred at each state.
25+
26+
HMMs are characterized in terms of the transition matrix $\Gamma_{ij}$ (each
27+
element being the probability of transitioning from state $i$ to state $j$
28+
between measurements), the types of measurements made on the system (the
29+
system may emit continuous or discrete measurements), and the initial state
30+
of the system. Currently the HMM interface in Stan only supports a constant
31+
transition matrix. Future versions will support a transition matrix for each state.
32+
33+
Any realization of an HMM's hidden state is a sequence of $N$ integers in the
34+
range $[1, K]$, however, because of the structure of the HMM, it is not
35+
necessary to sample the hidden states to do inference on the transition
36+
probabilities, the parameters of the measurement model, or the estimates
37+
of the initial state. Posterior draws from the hidden states can be computed
38+
separately.
39+
40+
A more complete mathematical definition of the HMM model and function interface
41+
is given in the [Hidden Markov Models](https://mc-stan.org/docs/2_24/functions-reference/hidden-markov-models.html)
42+
section of the Function Reference Guide.
43+
44+
There are three functions
45+
46+
- `hmm_marginal` - The likelihood of an HMM with the hidden discrete states
47+
integrated out
48+
- `hmm_latent_rng` - A function to generate posterior draws of the hidden state that are
49+
implicitly integrated out of the model when using `hmm_marginal` (this is
50+
different than sampling more states with a posterior draw of a transition matrix
51+
and initial state)
52+
- `hmm_hidden_state_prob` - A function to compute the posterior distributions of the
53+
integrated out hidden states
54+
55+
This guide will demonstrate how to simulate HMM realizations in R, fit the data
56+
with `hmm_marginal`, produce estimates of the distributions of the hidden states
57+
using `hmm_hidden_state_prob`, and generate draws of the hidden state from the
58+
posterior with `hmm_latent_rng`.
59+
60+
### Generating HMM realizations
61+
62+
Simulating an HMM requires a set of states, the transition probabilities
63+
between those states, and an estimate of the initial states.
64+
65+
For illustrative purposes, assume a three state system with states 1, 2, 3.
66+
67+
The transitions happen as follows:
68+
1. In state 1 there is a 50% chance of moving to state 2 and a 50% chance of staying in state 1
69+
2. In state 2 there is a 25% chance of moving to state 1, a 25% change of moving to state 3, and a 50% chance of staying in state 2
70+
3. In state 3 there is a 50% chance of moving to state 2 and a 50% chance of staying at state 3.
71+
72+
Assume that the system starts in state 1.
73+
74+
```{r}
75+
N = 100 # 100 measurements
76+
K = 3 # 3 states
77+
states = rep(1, N)
78+
states[1] = 1 # Start in state 1
79+
for(n in 2:length(states)) {
80+
if(states[n - 1] == 1)
81+
states[n] = sample(c(1, 2), size = 1, prob = c(0.5, 0.5))
82+
else if(states[n - 1] == 2)
83+
states[n] = sample(c(1, 2, 3), size = 1, prob = c(0.25, 0.5, 0.25))
84+
else if(states[n - 1] == 3)
85+
states[n] = sample(c(2, 3), size = 1, prob = c(0.5, 0.5))
86+
}
87+
```
88+
89+
The trajectory can easily be visualized:
90+
```{r}
91+
qplot(1:N, states)
92+
```
93+
94+
An HMM is useful when the hidden state is not measure directly (if the
95+
state was measured directly, it wouldn't be hidden).
96+
97+
In this example the observations are assumed to be
98+
normally distributed with a state specific mean and some measurement error.
99+
100+
```{r}
101+
mus = c(1.0, 5.0, 9.0)
102+
sigma = 2.0
103+
y = rnorm(N, mus[states], sd = sigma)
104+
```
105+
106+
Plotting the simulated measurements gives:
107+
108+
```{r}
109+
qplot(1:N, y)
110+
```
111+
112+
### Fitting the HMM
113+
114+
To make it clear how to use the HMM fit functions, the model here will fit the
115+
transition matrix, the initial state, and the parameters of the measurement
116+
model. It is not necessary to estimate all of these things in practice if some
117+
of them are known.
118+
119+
The data is the previously generated sequence of $N$ measurements:
120+
```{stan, output.var = "", eval = FALSE}
121+
data {
122+
int N; // Number of observations
123+
real y[N];
124+
}
125+
```
126+
127+
For the transition matrix, assume that it is known that states 1 and 3 are not
128+
directly connected. For $K$ states, estimating a full transition matrix means
129+
estimatng a matrix of $O(K^2)$ probabilities. Depending on the data available,
130+
this may not be possible and so it is important to take advantage of available
131+
modeling assumptions. The state means are estimated as an ordered vector
132+
to avoid mode-swap non-identifiabilities.
133+
134+
```{stan, output.var = "", eval = FALSE}
135+
parameters {
136+
// Rows of the transition matrix
137+
simplex[2] t1;
138+
simplex[3] t2;
139+
simplex[2] t3;
140+
141+
// Initial state
142+
simplex[3] rho;
143+
144+
// Parameters of measurement model
145+
vector[3] mu;
146+
real<lower = 0.0> sigma;
147+
}
148+
```
149+
150+
The `hmm_marginal` function takes the transition matrix and initial state
151+
directly. In this case the transition matrix needs to be constructed from `t1`,
152+
`t2`, and `t3` but that is relatively easy to build.
153+
154+
The measurement model, in contrast, is not passed directly to the HMM function.
155+
156+
Instead, a $K \times N$ matrix `log_omega` of log likelihoods is passed in. The
157+
$(k, n)$ entry of this matrix is the log likelihood of the $nth$ measurement
158+
given the system at time $n$ is actually in state $k$. For the generative
159+
model above, these are log normals evaluated at the three different means.
160+
161+
```{stan, output.var = "", eval = FALSE}
162+
transformed parameters {
163+
matrix[3, 3] gamma = rep_matrix(0, 3, 3);
164+
matrix[3, N] log_omega;
165+
166+
// Build the transition matrix
167+
gamma[1, 1:2] = t1;
168+
gamma[2, ] = t2;
169+
gamma[3, 2:3] = t3;
170+
171+
// Compute the log likelihoods in each possible state
172+
for(n in 1:N) {
173+
// The observation model could change with n, or vary in a number of
174+
// different ways (which is why log_omega is passed in as an argument)
175+
log_omega[1, n] = normal_lpdf(y[n] | mu[1], sigma);
176+
log_omega[2, n] = normal_lpdf(y[n] | mu[2], sigma);
177+
log_omega[3, n] = normal_lpdf(y[n] | mu[3], sigma);
178+
}
179+
}
180+
```
181+
182+
With all that in place, the only thing left to do is add priors and increment
183+
the log density:
184+
```{stan, output.var = "", eval = FALSE}
185+
model {
186+
mu ~ normal(0, 1);
187+
sigma ~ normal(0, 1);
188+
189+
target += hmm_marginal(log_omega, Gamma, rho);
190+
}
191+
```
192+
193+
The complete model is available on Github: [hmm-example.stan](https://github.com/stan-dev/example-models/tree/master/knitr/hmm-example/hmm-example.stan).
194+
195+
```{r echo = TRUE, results = FALSE, message = FALSE}
196+
model = cmdstan_model("hmm-example.stan")
197+
fit = model$sample(data = list(N = N, y = y), parallel_chains = 4)
198+
```
199+
200+
The estimated group means match the known ones:
201+
```{r}
202+
fit$summary("mu")
203+
```
204+
The estimated initial conditions are not much more informative than
205+
the prior, but it is there:
206+
```{r}
207+
fit$summary("rho")
208+
```
209+
210+
The transition probabilities from state 1 can be extracted:
211+
```{r}
212+
fit$summary("t1")
213+
```
214+
215+
Similarly for state 2:
216+
```{r}
217+
fit$summary("t2")
218+
```
219+
220+
And state 3:
221+
```{r}
222+
fit$summary("t3")
223+
```
224+
225+
### State Probabilities
226+
227+
Even though the hidden states are integrated out, the distribution
228+
of hidden states at each time point can be computed with the function
229+
`hmm_hidden_state_prob`:
230+
231+
```{stan, output.var = "", eval = FALSE}
232+
generated quantities {
233+
matrix[3, N] hidden_probs = hmm_hidden_state_prob(log_omega, Gamma, rho);
234+
}
235+
```
236+
237+
These can be plotted:
238+
239+
```{r}
240+
hidden_probs_df = fit$draws() %>%
241+
as_draws_df %>%
242+
select(starts_with("hidden_probs")) %>%
243+
pivot_longer(everything(),
244+
names_to = c("state", "n"),
245+
names_transform = list(k = as.integer, n = as.integer),
246+
names_pattern = "hidden_probs\\[([0-9]*),([0-9]*)\\]",
247+
values_to = "hidden_probs")
248+
249+
hidden_probs_df %>%
250+
group_by(state, n) %>%
251+
summarize(qh = quantile(hidden_probs, 0.8),
252+
m = median(hidden_probs),
253+
ql = quantile(hidden_probs, 0.2)) %>%
254+
ungroup() %>%
255+
ggplot() +
256+
geom_errorbar(aes(n, ymin = ql, ymax = qh, width = 0.0), alpha = 0.5) +
257+
geom_point(aes(n, m)) +
258+
facet_grid(state ~ ., labeller = "label_both") +
259+
ggtitle("Ribbon is 60% posterior interval, point is median") +
260+
ylab("Probability of being in state") +
261+
xlab("Time (n)")
262+
```
263+
264+
If it is more convenient to work with draws of the hidden states at each time
265+
point (instead of the probabilities provided by `hmm_hidden_state_prob`), these
266+
can be generated with `hmm_latent_rng`:
267+
268+
```{stan, output.var = "", eval = FALSE}
269+
generated quantities {
270+
int[N] y_sim = hmm_latent_rng(log_omega, Gamma, rho)
271+
}
272+
```
273+
274+
Note that the probabilities from `hmm_hidden_state_prob` are the marginal
275+
probabilities of the hidden states, meaning they cannot be directly used to
276+
jointly sample hidden states. The posterior draws generated by `hmm_latent_rng`
277+
account for the correlation between hidden states.
278+
279+
Note further these are draws of the hidden state that was integrated out. This is
280+
different than sampling new HMM realizations using posterior draws of the initial
281+
condition and the transition matrix.
282+
283+
The draws of the hidden state can be plotted as well:
284+
285+
```{r}
286+
y_sim = fit$draws() %>%
287+
as_draws_df() %>%
288+
select(starts_with("y_sim")) %>%
289+
as.matrix
290+
291+
qplot(1:N, y_sim[1,])
292+
```

knitr/hmm-example/hmm-example.stan

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
data {
2+
int N; // Number of observations
3+
real y[N];
4+
}
5+
6+
parameters {
7+
// Parameters of measurement model
8+
ordered[3] mu;
9+
real<lower = 0.0> sigma;
10+
11+
// Initial state
12+
simplex[3] rho;
13+
14+
// Rows of the transition matrix
15+
simplex[2] t1;
16+
simplex[3] t2;
17+
simplex[2] t3;
18+
}
19+
20+
transformed parameters {
21+
matrix[3, 3] Gamma = rep_matrix(0, 3, 3);
22+
matrix[3, N] log_omega;
23+
24+
// Build the transition matrix
25+
Gamma[1, 1:2] = t1';
26+
Gamma[2, ] = t2';
27+
Gamma[3, 2:3] = t3';
28+
29+
// Compute the log likelihoods in each possible state
30+
for(n in 1:N) {
31+
// The observation model could change with n, or vary in a number of
32+
// different ways (which is why log_omega is passed in as an argument)
33+
log_omega[1, n] = normal_lpdf(y[n] | mu[1], sigma);
34+
log_omega[2, n] = normal_lpdf(y[n] | mu[2], sigma);
35+
log_omega[3, n] = normal_lpdf(y[n] | mu[3], sigma);
36+
}
37+
}
38+
39+
model {
40+
mu ~ normal(0, 10);
41+
sigma ~ normal(0, 1);
42+
43+
rho ~ dirichlet([10, 1, 1]);
44+
45+
t1 ~ dirichlet([1, 1]);
46+
t2 ~ dirichlet([1, 1, 1]);
47+
t3 ~ dirichlet([1, 1]);
48+
49+
target += hmm_marginal(log_omega, Gamma, rho);
50+
}
51+
52+
generated quantities {
53+
matrix[3, N] hidden_probs = hmm_hidden_state_prob(log_omega, Gamma, rho);
54+
int y_sim[N] = hmm_latent_rng(log_omega, Gamma, rho);
55+
}

0 commit comments

Comments
 (0)