Commit 8001224

Added HMM example
1 parent 8a69641 commit 8001224

2 files changed

+324
-0
lines changed

knitr/hmm-example/hmm-example.Rmd

Lines changed: 269 additions & 0 deletions
@@ -0,0 +1,269 @@
1+
---
2+
title: "HMM Example"
3+
output: html_document
4+
---
5+
6+
```{r setup, include=FALSE}
7+
library(tidyverse)
8+
library(ggplot2)
9+
library(cmdstanr)
10+
library(posterior)
11+
```
12+
13+
## Introduction
14+
15+
CmdStan 2.24 introduced a new interface for fitting Hidden Markov models (HMMs)
16+
in Stan. This document is intended to provide an example use of this interface.
17+
18+
HMMs model a process where a system probabilistically switches between $K$
19+
states over a sequence of $N$ points in time. It is assumed that the exact
20+
state of the system is unknown and can only be measured indirectly at each time point.
21+
22+
HMMs are characterized in terms of the transition matrix $\Gamma$ (each
23+
element $\Gamma_{ij}$ being the probability of transitioning from state $i$ to state $j$
24+
between measurements), the types of measurements made on the system (the
25+
system may emit continuous or discrete measurements), and the initial state
26+
of the system.
27+
28+
Any realization of an HMM is a sequence of $N$ integers in the range $[1, K]$.
29+
However, because of the structure of the HMM, it is not necessary to sample
30+
the latent states to do inference on the transition probabilities, the
31+
parameters of the measurement model, or the estimates of the initial state.
32+
Estimates of the distribution of states at each measurement time can be
33+
computed separately.
34+
35+
A more complete mathematical definition of the HMM model and function interface
36+
is given in the [Hidden Markov Models](https://mc-stan.org/docs/2_24/functions-reference/hidden-markov-models.html)
37+
section of the Function Reference Guide.
38+
39+
There are three functions:
40+
41+
- `hmm_marginal` - The likelihood of an HMM with the latent discrete states
42+
integrated out
43+
- `hmm_latent_rng` - A function to generate random realizations of an HMM
44+
- `hmm_hidden_state_prob` - A function to compute the distribution of latent
45+
states at each measurement.
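
The marginalization behind `hmm_marginal` can be illustrated in plain R with the forward algorithm. This is only a sketch of the idea, not Stan's implementation: the names `log_sum_exp` and `hmm_marginal_r` are made up for this example, the measurements are invented, and `rho` is assumed to be the distribution of the state at the first measurement.

```r
# Numerically stable log(sum(exp(x)))
log_sum_exp <- function(x) {
  m <- max(x)
  if (!is.finite(m)) return(m)  # all entries are -Inf (probability zero)
  m + log(sum(exp(x - m)))
}

# Marginal log likelihood of an HMM via the forward algorithm (hypothetical helper).
# log_omega: K x N matrix, log density of measurement n given state k
# Gamma:     K x K transition matrix, rows sum to 1
# rho:       length-K distribution of the state at the first measurement (assumed)
hmm_marginal_r <- function(log_omega, Gamma, rho) {
  K <- nrow(log_omega)
  N <- ncol(log_omega)
  # log_alpha[k] = log p(y_1, ..., y_n, state at time n is k)
  log_alpha <- log(rho) + log_omega[, 1]
  for (n in seq_len(N)[-1]) {
    log_alpha <- sapply(seq_len(K), function(j) {
      log_sum_exp(log_alpha + log(Gamma[, j]))
    }) + log_omega[, n]
  }
  # Sum over the final state to integrate out the whole latent path
  log_sum_exp(log_alpha)
}

# Usage with made-up measurements and the three-state structure used later on
Gamma <- matrix(c(0.50, 0.50, 0.00,
                  0.25, 0.50, 0.25,
                  0.00, 0.50, 0.50), nrow = 3, byrow = TRUE)
rho <- c(1, 0, 0)
y_demo <- c(0.8, 1.5, 4.7, 9.2)
log_omega <- sapply(y_demo, function(obs) {
  dnorm(obs, mean = c(1, 5, 9), sd = 2, log = TRUE)
})
hmm_marginal_r(log_omega, Gamma, rho)
```

The cost of this recursion grows linearly in $N$, which is why the latent states never need to be sampled to evaluate the likelihood.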
46+
47+
This guide will demonstrate how to simulate HMM realizations in R, fit the data
48+
with `hmm_marginal` and produce estimates of the latent states using
49+
`hmm_hidden_state_prob`.
50+
51+
### Generating HMM realizations
52+
53+
Simulating an HMM requires a set of states, the transition probabilities
54+
between those states, and the initial state.
55+
56+
For illustrative purposes, assume a three-state system with states 1, 2, 3.
57+
58+
The transitions happen as follows:
59+
1. In state 1 there is a 50% chance of moving to state 2 and a 50% chance of staying in state 1
60+
2. In state 2 there is a 25% chance of moving to state 1, a 25% chance of moving to state 3, and a 50% chance of staying in state 2
61+
3. In state 3 there is a 50% chance of moving to state 2 and a 50% chance of staying at state 3.
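
The three rules above can also be collected into a single transition matrix. The sketch below is one way to write that down in R. The names `Gamma_true` and `next_state` are illustrative and do not appear elsewhere in this example.

```r
# Transition matrix for the three-state example.
# Row i holds the probabilities of moving from state i to states 1, 2, 3.
Gamma_true <- matrix(c(0.50, 0.50, 0.00,
                       0.25, 0.50, 0.25,
                       0.00, 0.50, 0.50),
                     nrow = 3, byrow = TRUE)

# Every row is a probability distribution, so each must sum to 1
stopifnot(all(abs(rowSums(Gamma_true) - 1) < 1e-12))

# One step of the chain: draw the next state from the current state's row
next_state <- function(state, Gamma) {
  sample(seq_len(nrow(Gamma)), size = 1, prob = Gamma[state, ])
}

set.seed(1)  # arbitrary seed, only to make the sketch reproducible
next_state(1, Gamma_true)
```

The zeros in positions (1, 3) and (3, 1) encode that states 1 and 3 are not directly connected.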
62+
63+
Assume that the system starts in state 1.
64+
65+
```{r}
66+
N = 100 # 100 measurements
67+
K = 3 # 3 states
68+
states = rep(1, N)
69+
states[1] = 1 # Start in state 1
70+
for(n in 2:length(states)) {
71+
if(states[n - 1] == 1)
72+
states[n] = sample(c(1, 2), size = 1, prob = c(0.5, 0.5))
73+
else if(states[n - 1] == 2)
74+
states[n] = sample(c(1, 2, 3), size = 1, prob = c(0.25, 0.5, 0.25))
75+
else if(states[n - 1] == 3)
76+
states[n] = sample(c(2, 3), size = 1, prob = c(0.5, 0.5))
77+
}
78+
```
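
As a sanity check on a simulation like the one above, the observed transitions can be tabulated and compared with the intended probabilities. The sketch below is self-contained, so it re-simulates the chain with its own arbitrary seed. The names `Gamma_true`, `sim_states`, and `counts` are illustrative.

```r
set.seed(123)  # arbitrary seed, for reproducibility of this sketch only

N_sim <- 100
Gamma_true <- matrix(c(0.50, 0.50, 0.00,
                       0.25, 0.50, 0.25,
                       0.00, 0.50, 0.50),
                     nrow = 3, byrow = TRUE)

# Simulate the chain, starting in state 1
sim_states <- integer(N_sim)
sim_states[1] <- 1L
for (n in 2:N_sim) {
  sim_states[n] <- sample(1:3, size = 1, prob = Gamma_true[sim_states[n - 1], ])
}

# Count transitions from the state at time n to the state at time n + 1.
# Using factors with fixed levels keeps the table 3 x 3 even if a state is never visited.
from <- factor(sim_states[-N_sim], levels = 1:3)
to   <- factor(sim_states[-1], levels = 1:3)
counts <- table(from, to)

# Row-normalized counts are empirical transition frequencies
round(prop.table(counts, margin = 1), 2)
```

With only 100 time points the empirical frequencies will be noisy, but the structurally impossible transitions (1 to 3 and 3 to 1) should have zero counts.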
79+
80+
The trajectory can easily be visualized:
81+
```{r}
82+
qplot(1:N, states)
83+
```
84+
85+
An HMM is useful when the latent state is only measured indirectly.
86+
87+
In this example the observations are assumed to be
88+
normally distributed with a state-specific mean and some measurement error.
89+
90+
```{r}
91+
mus = c(1.0, 5.0, 9.0)
92+
sigma = 2.0
93+
y = rnorm(N, mus[states], sd = sigma)
94+
```
95+
96+
Plotting the simulated measurements gives:
97+
98+
```{r}
99+
qplot(1:N, y)
100+
```
101+
102+
### Fitting the HMM
103+
104+
To make it clear how to use the HMM fit functions, the model here will fit the
105+
transition matrix, the initial state, and the parameters of the measurement
106+
model. It is not necessary to estimate all of these things in practice if some
107+
of them are known.
108+
109+
The data is the previously generated sequence of $N$ measurements:
110+
```{stan, output.var = "", eval = FALSE}
111+
data {
112+
int N; // Number of observations
113+
real y[N];
114+
}
115+
```
116+
117+
For the transition matrix, assume that it is known that states 1 and 3 are not
118+
directly connected. For $K$ states, estimating a full transition matrix means
119+
estimating a matrix of $O(K^2)$ probabilities. Depending on the data available,
120+
this may not be possible and so it is important to take advantage of available
121+
modeling assumptions. The state means are estimated as an ordered vector
122+
to avoid mode-swap non-identifiabilities.
123+
124+
```{stan, output.var = "", eval = FALSE}
125+
parameters {
126+
// Rows of the transition matrix
127+
simplex[2] t1;
128+
simplex[3] t2;
129+
simplex[2] t3;
130+
131+
// Initial state
132+
simplex[3] rho;
133+
134+
// Parameters of measurement model
135+
ordered[3] mu;
136+
real<lower = 0.0> sigma;
137+
}
138+
```
139+
140+
The `hmm_marginal` function takes the transition matrix and initial state
141+
directly. In this case the transition matrix needs to be constructed from `t1`,
142+
`t2`, and `t3`, but that is relatively easy to do.
143+
144+
The measurement model, in contrast, is not passed directly to the HMM function.
145+
146+
Instead, a $K \times N$ matrix `log_omega` of log likelihoods is passed in. The
147+
$(k, n)$ entry of this matrix is the log likelihood of the $n$th measurement
148+
given that the system is in state $k$ at time $n$. For the generative
149+
model above, these are normal log densities evaluated at the three different means.
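
To make the layout of `log_omega` concrete, the same matrix can be built in R for a handful of measurements. This is a sketch with made-up data. The names `mus_demo`, `sigma_demo`, `y_demo`, and `log_omega_demo` are illustrative and not part of the fitted model.

```r
set.seed(42)  # arbitrary seed, only to make the sketch reproducible

mus_demo <- c(1, 5, 9)   # state-specific means
sigma_demo <- 2          # shared measurement noise

# Five made-up measurements generated from a hand-picked state sequence
y_demo <- rnorm(5, mean = mus_demo[c(1, 1, 2, 3, 3)], sd = sigma_demo)

# Rows index states, columns index measurements:
# entry (k, n) is the normal log density of y_demo[n] under the mean of state k
log_omega_demo <- outer(mus_demo, y_demo, function(m, obs) {
  dnorm(obs, mean = m, sd = sigma_demo, log = TRUE)
})

dim(log_omega_demo)
```

Each column holds the log density of one measurement under every candidate state. That per-state density is all the HMM functions need to know about the measurement model.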
150+
151+
```{stan, output.var = "", eval = FALSE}
152+
transformed parameters {
153+
matrix[3, 3] Gamma = rep_matrix(0, 3, 3);
154+
matrix[3, N] log_omega;
155+
156+
// Build the transition matrix (simplexes are column vectors, so transpose into rows)
157+
Gamma[1, 1:2] = t1';
158+
Gamma[2, ] = t2';
159+
Gamma[3, 2:3] = t3';
160+
161+
// Compute the log likelihoods in each possible state
162+
for(n in 1:N) {
163+
// The observation model could change with n, or vary in a number of
164+
// different ways (which is why log_omega is passed in as an argument)
165+
log_omega[1, n] = normal_lpdf(y[n] | mu[1], sigma);
166+
log_omega[2, n] = normal_lpdf(y[n] | mu[2], sigma);
167+
log_omega[3, n] = normal_lpdf(y[n] | mu[3], sigma);
168+
}
169+
}
170+
```
171+
172+
With all that in place, the only thing left to do is add priors and increment
173+
the log density:
174+
```{stan, output.var = "", eval = FALSE}
175+
model {
176+
mu ~ normal(0, 10);
177+
sigma ~ normal(0, 1);
178+
179+
target += hmm_marginal(log_omega, Gamma, rho);
180+
}
181+
```
182+
183+
```{r echo = TRUE, results = FALSE, message = FALSE}
184+
model = cmdstan_model("hmm-example.stan")
185+
fit = model$sample(data = list(N = N, y = y), parallel_chains = 4)
186+
```
187+
188+
The estimated state means should be close to the known values (1, 5, and 9):
189+
```{r}
190+
fit$summary("mu")
191+
```
192+
The estimate of the initial state is not much more informative than
193+
the prior, but it is there:
194+
```{r}
195+
fit$summary("rho")
196+
```
197+
198+
The transition probabilities from state 1 can be extracted:
199+
```{r}
200+
fit$summary("t1")
201+
```
202+
203+
Similarly for state 2:
204+
```{r}
205+
fit$summary("t2")
206+
```
207+
208+
And state 3:
209+
```{r}
210+
fit$summary("t3")
211+
```
212+
213+
### State Probabilities
214+
215+
Even though the latent states are not sampled directly, the distribution
216+
of latent states at each time point can be computed with the function
217+
`hmm_hidden_state_prob`:
218+
219+
```{stan, output.var = "", eval = FALSE}
220+
generated quantities {
221+
matrix[3, N] latent_probs = hmm_hidden_state_prob(log_omega, Gamma, rho);
222+
}
223+
```
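
For intuition about what `hmm_hidden_state_prob` returns, the marginal state probabilities can be sketched in R with the forward-backward algorithm. This illustrates the technique and is not Stan's implementation. The function name `hidden_state_prob_r` is made up, and `rho` is assumed to be the distribution of the state at the first measurement.

```r
# Marginal probability of each latent state at each time, given all measurements.
# log_omega: K x N matrix of log likelihoods, Gamma: K x K transition matrix,
# rho: length-K distribution of the first state (assumed convention).
hidden_state_prob_r <- function(log_omega, Gamma, rho) {
  # Rescaling a column by a constant does not change the normalized result,
  # so subtract each column's maximum before exponentiating to limit underflow
  omega <- exp(sweep(log_omega, 2, apply(log_omega, 2, max)))
  K <- nrow(omega)
  N <- ncol(omega)
  alpha <- matrix(0, K, N)
  beta <- matrix(1, K, N)

  # Forward pass: alpha[, n] is proportional to p(state at n, y_1, ..., y_n)
  a <- rho * omega[, 1]
  alpha[, 1] <- a / sum(a)
  for (n in seq_len(N)[-1]) {
    a <- as.vector(t(Gamma) %*% alpha[, n - 1]) * omega[, n]
    alpha[, n] <- a / sum(a)
  }

  # Backward pass: beta[, n] is proportional to p(y_{n+1}, ..., y_N | state at n)
  for (n in rev(seq_len(N - 1))) {
    b <- as.vector(Gamma %*% (omega[, n + 1] * beta[, n + 1]))
    beta[, n] <- b / sum(b)
  }

  # Combine both passes and normalize each column to sum to 1
  probs <- alpha * beta
  sweep(probs, 2, colSums(probs), "/")
}
```

Applied to a single posterior draw of `log_omega`, `Gamma`, and `rho`, this returns a $K \times N$ matrix whose columns each sum to one.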
224+
225+
These can be plotted:
226+
227+
```{r}
228+
latent_probs_df = fit$draws() %>%
229+
as_draws_df %>%
230+
select(starts_with("latent_probs")) %>%
231+
pivot_longer(everything(),
232+
names_to = c("state", "n"),
233+
names_transform = list(state = as.integer, n = as.integer),
234+
names_pattern = "latent_probs\\[([0-9]*),([0-9]*)\\]",
235+
values_to = "latent_probs")
236+
237+
latent_probs_df %>%
238+
group_by(state, n) %>%
239+
summarize(qh = quantile(latent_probs, 0.8),
240+
m = median(latent_probs),
241+
ql = quantile(latent_probs, 0.2)) %>%
242+
ungroup() %>%
243+
ggplot() +
244+
geom_errorbar(aes(n, ymin = ql, ymax = qh, width = 0.0), alpha = 0.5) +
245+
geom_point(aes(n, m)) +
246+
facet_grid(state ~ ., labeller = "label_both") +
247+
ggtitle("Error bar is 60% posterior interval, point is median") +
248+
ylab("Probability of being in state") +
249+
xlab("Time (n)")
250+
```
251+
252+
Draws of the latent state sequence from its joint posterior can be generated with `hmm_latent_rng`:
253+
254+
```{stan, output.var = "", eval = FALSE}
255+
generated quantities {
256+
int y_sim[N] = hmm_latent_rng(log_omega, Gamma, rho);
257+
}
258+
```
259+
260+
These can be plotted as well:
261+
262+
```{r}
263+
y_sim = fit$draws() %>%
264+
as_draws_df() %>%
265+
select(starts_with("y_sim")) %>%
266+
as.matrix
267+
268+
qplot(1:N, y_sim[1,])
269+
```

knitr/hmm-example/hmm-example.stan

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
1+
data {
2+
int N; // Number of observations
3+
real y[N];
4+
}
5+
6+
parameters {
7+
// Parameters of measurement model
8+
ordered[3] mu;
9+
real<lower = 0.0> sigma;
10+
11+
// Initial state
12+
simplex[3] rho;
13+
14+
// Rows of the transition matrix
15+
simplex[2] t1;
16+
simplex[3] t2;
17+
simplex[2] t3;
18+
}
19+
20+
transformed parameters {
21+
matrix[3, 3] Gamma = rep_matrix(0, 3, 3);
22+
matrix[3, N] log_omega;
23+
24+
// Build the transition matrix
25+
Gamma[1, 1:2] = t1';
26+
Gamma[2, ] = t2';
27+
Gamma[3, 2:3] = t3';
28+
29+
// Compute the log likelihoods in each possible state
30+
for(n in 1:N) {
31+
// The observation model could change with n, or vary in a number of
32+
// different ways (which is why log_omega is passed in as an argument)
33+
log_omega[1, n] = normal_lpdf(y[n] | mu[1], sigma);
34+
log_omega[2, n] = normal_lpdf(y[n] | mu[2], sigma);
35+
log_omega[3, n] = normal_lpdf(y[n] | mu[3], sigma);
36+
}
37+
}
38+
39+
model {
40+
mu ~ normal(0, 10);
41+
sigma ~ normal(0, 1);
42+
43+
rho ~ dirichlet([10, 1, 1]);
44+
45+
t1 ~ dirichlet([1, 1]);
46+
t2 ~ dirichlet([1, 1, 1]);
47+
t3 ~ dirichlet([1, 1]);
48+
49+
target += hmm_marginal(log_omega, Gamma, rho);
50+
}
51+
52+
generated quantities {
53+
matrix[3, N] latent_probs = hmm_hidden_state_prob(log_omega, Gamma, rho);
54+
int y_sim[N] = hmm_latent_rng(log_omega, Gamma, rho);
55+
}
