
Commit 4ab4f6c

Merge pull request #131 from hmd101/likelihood-api
### Likelihood refactor

- Subclasses of `Likelihood` now only need to implement `predict`, because `log_likelihood` vmaps over `predict`.
- Added a `simulate` method that creates synthetic data based on a task likelihood.
2 parents d57e40c + ee9b8c4 commit 4ab4f6c

10 files changed

Lines changed: 253 additions & 479 deletions
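The pattern described in the commit message — subclasses implement only `predict`, while the base class derives batched `log_likelihood` and `simulate` from it — can be pictured with a small stand-in. This is a hypothetical numpy sketch, not psyphy's actual base class (which uses `jax.vmap` rather than a Python loop), and `ToyDistanceTask` is an invented subclass for illustration:

```python
import numpy as np

class Likelihood:
    """Sketch of the refactored pattern: implement `predict` once,
    get batched log-likelihood and simulation for free."""

    def predict(self, params, ref, comparison):
        # Subclasses return p(correct) for a single (ref, comparison) trial.
        raise NotImplementedError

    def log_likelihood(self, params, refs, comparisons, responses):
        # In psyphy this maps over `predict` with jax.vmap; a loop stands in here.
        p = np.array([self.predict(params, r, c) for r, c in zip(refs, comparisons)])
        p = np.clip(p, 1e-12, 1.0 - 1e-12)  # guard the logs
        return float(np.sum(responses * np.log(p) + (1 - responses) * np.log(1.0 - p)))

    def simulate(self, params, refs, comparisons, rng):
        # Synthetic responses y ~ Bernoulli(p_correct), mirroring the new `simulate`.
        p = np.array([self.predict(params, r, c) for r, c in zip(refs, comparisons)])
        ys = (rng.random(p.shape) < p).astype(np.int32)
        return ys, p

class ToyDistanceTask(Likelihood):
    # Toy subclass: p(correct) grows with the ref-comparison distance.
    def predict(self, params, ref, comparison):
        d = np.linalg.norm(np.asarray(comparison) - np.asarray(ref))
        return 1.0 / (1.0 + np.exp(-params["slope"] * d))

task = ToyDistanceTask()
rng = np.random.default_rng(0)
refs = np.zeros((4, 2))
comparisons = np.array([[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [0.0, 0.0]])
ys, p_correct = task.simulate({"slope": 2.0}, refs, comparisons, rng)
ll = task.log_likelihood({"slope": 2.0}, refs, comparisons, ys)
```

The payoff of the design is that every task-specific likelihood stays a single-trial function, and batching lives in exactly one place.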

docs/examples/wppm/full_wppm_fit_example.md

Lines changed: 8 additions & 39 deletions
@@ -69,8 +69,7 @@ For OddityTask, we store trials as (ref, comparison) even though the task involv
 
 Note on data used in this script: here, we simulate data (and hence have a ground truth to compare against). To see how to conveniently simulate data yourself, check out the [script](https://github.com/flatironinstitute/psyphy/blob/main/docs/examples/wppm/full_wppm_fit_example.py).
 
-<details>
-<summary>### 2 ways of representing data in `psyphy` (important)</summary>
+### 2 ways of representing data in `psyphy` (important)
 
 `psyphy` provides two lightweight containers for trial data (defined in [`src/psyphy/data/dataset.py`](https://github.com/flatironinstitute/psyphy/blob/main/src/psyphy/data/dataset.py)):
 
 **`TrialData` (compute-first; used for fitting):**
@@ -96,7 +95,7 @@ Avoid repeatedly converting Python lists -> JAX arrays inside tight loops.
 
 
 Note that here we simulate data, resulting in a `TrialData` object; for details, check out [`full_wppm_fit_example.py`](full_wppm_fit_example.py) directly.
-</details>
+
 
 ---
 
@@ -109,8 +108,7 @@ The WPPM parameters are basis weights stored as a dict:
 where `W` is a tensor of Chebyshev-basis coefficients.
 
 
-<details>
-<summary>### Prior distribution over weights:</summary>
+### Prior distribution over weights
 
 `Prior.sample_params(key)` samples weights `W` from a **zero-mean Gaussian** with a *degree-dependent variance*.
 
@@ -146,7 +144,7 @@ W_{ijde} \sim \mathcal{N}(0, \sigma^2_{ij}).
 \]
 
 This is the state of the WPPM: **before any data**, WPPM draws smooth random fields because high-frequency coefficients are shrunk by the decay.
-</details>
+
 
 ---
 
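The degree-dependent prior from the section above can be sketched in a few lines. This is a hypothetical stand-in, not psyphy's actual `Prior.sample_params`: it assumes an exponential decay `sigma2[i, j] = variance_scale * exp(-decay_rate * (i + j))`, which may differ from the library's parameterization:

```python
import numpy as np

def sample_prior_weights(rng, basis_degree=4, input_dim=2, embed_dim=2,
                         decay_rate=0.4, variance_scale=1.0):
    """Draw W[i, j, d, e] ~ N(0, sigma2[i, j]) with a degree-dependent variance.

    Assumed decay: sigma2 shrinks with the total Chebyshev degree i + j, so
    high-frequency coefficients are small and prior draws are smooth fields.
    """
    i = np.arange(basis_degree + 1)[:, None]   # degree along the first input dim
    j = np.arange(basis_degree + 1)[None, :]   # degree along the second input dim
    sigma2 = variance_scale * np.exp(-decay_rate * (i + j))  # (deg+1, deg+1)
    std = np.sqrt(sigma2)[:, :, None, None]    # broadcast over the (d, e) axes
    W = rng.normal(size=(basis_degree + 1, basis_degree + 1, input_dim, embed_dim))
    return std * W

rng = np.random.default_rng(0)
W = sample_prior_weights(rng)
print(W.shape)  # (5, 5, 2, 2)
```

Scaling a standard-normal draw by `std` is equivalent to sampling each `W[i, j, d, e]` from `N(0, sigma2[i, j])`, matching the displayed formula.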
@@ -240,43 +238,14 @@ To see how we generate the covariance field figures, check out the plotting code
 ---
 
 ## To recap: Minimal recipe (copy/paste mental model)
+We are Bayesian, so we need to define the Prior and Likelihood and choose an inference method (here, MAP) that will hand us the posterior distribution over the parameters.
 
-To use WPPM on your own data, these are the essential calls:
-
-**1. Create** task + noise + prior:
-
-- `task = OddityTask()`
-
-- `noise = GaussianNoise(sigma=...)`
-
-- `prior = Prior(input_dim=..., basis_degree=..., extra_embedding_dims=..., decay_rate=..., variance_scale=...)`
-
-**2. Create** WPPM:
-
-- `model = WPPM(input_dim=..., prior=prior, task=task, noise=noise, diag_term=...)`
-
-**3. Initialize** parameters:
-
-- `params0 = model.init_params(jax.random.PRNGKey(...))` (draws from `Prior.sample_params`)
-
-**4. Load/build** a dataset:
-
-- `data = TrialData(refs=..., comparisons=..., responses=...)`
-
-**5. Fit**:
-
-- `map = MAPOptimizer(...).fit(model, data, init_params=params0, ...)`
-
-**6. Inspect** $\Sigma(x)$:
-
-- `field = WPPMCovarianceField(model, map.params)`
-- `Sigmas = field(xs)`
-
+For an even more minimal code setup that doesn't require a GPU but will run on your CPU in < 1 min, you may find [`quickstart`](https://flatironinstitute.github.io/psyphy/examples/wppm/quick_start/) helpful.
 ---
 
 ## Notes and pitfalls
 
-- **CPU vs GPU:** this example can be heavy because the oddity likelihood uses Monte Carlo. A GPU can help a lot.
+- **CPU vs GPU:** this example can be heavy because the oddity likelihood uses Monte Carlo; a GPU can help a lot. See [`quickstart`](https://flatironinstitute.github.io/psyphy/examples/wppm/quick_start/) for a CPU-friendly version.
 - **Positive definiteness:** `diag_term` is important. If you ever see a non-PD covariance, increase `diag_term` slightly.
 - **MC variance:** optimization stability depends on `MC_SAMPLES`. Too small means noisy gradients.
 
@@ -308,7 +277,7 @@ instead of using relative filesystem paths.
 - MAP fitting: [`src/psyphy/inference/map_optimizer.py`](https://github.com/flatironinstitute/psyphy/blob/main/src/psyphy/inference/map_optimizer.py) (see `MAPOptimizer`)
 - Data container: [`src/psyphy/data/dataset.py`](https://github.com/flatironinstitute/psyphy/blob/main/src/psyphy/data/dataset.py) (see `ResponseData`)
 
-If you want to follow the call graph:
+If you want to "follow the call graph":
 
 1. `WPPM.init_params(...)` (defined in [`src/psyphy/model/wppm.py`](https://github.com/flatironinstitute/psyphy/blob/main/src/psyphy/model/wppm.py)) → delegates to the prior’s `Prior.sample_params(...)` (defined in [`src/psyphy/model/prior.py`](https://github.com/flatironinstitute/psyphy/blob/main/src/psyphy/model/prior.py)).
 2. `OddityTask.predict_with_kwargs(...)` / `OddityTask.loglik(...)` (defined in [`src/psyphy/model/likelihood.py`](https://github.com/flatironinstitute/psyphy/blob/main/src/psyphy/model/likelihood.py)) → calls into the model to get $\Sigma(x)$ and then runs the task’s decision rule (Monte Carlo in the full model).

docs/examples/wppm/full_wppm_fit_example.py

Lines changed: 5 additions & 26 deletions
@@ -209,7 +209,7 @@ def _cov_to_points(cov: jnp.ndarray, center: jnp.ndarray) -> jnp.ndarray:
 #
 
 
-# --8<-- [start:simulate_data]
+##### Simulate data
 num_trials_per_ref = NUM_TRIALS_Per_Ref  # (trials per reference point)
 n_ref_grid = 5  # NUM_GRID_PTS
 ref_grid = jnp.linspace(-1, 1, n_ref_grid)  # [-1,1] space
@@ -234,7 +234,7 @@ def _cov_to_points(cov: jnp.ndarray, center: jnp.ndarray) -> jnp.ndarray:
 Sigmas_ref = truth_field(refs)  # (N, 2, 2)
 
 # Sample unit directions on the circle.
-k_dir, k_pred, k_y = jr.split(key, 3)
+k_dir, k_sim = jr.split(key)
 angles = jr.uniform(k_dir, shape=(num_trials_total,), minval=0.0, maxval=2.0 * jnp.pi)
 unit_dirs = jnp.stack([jnp.cos(angles), jnp.sin(angles)], axis=1)  # (N, 2)
 
@@ -250,30 +250,9 @@ def _cov_to_points(cov: jnp.ndarray, center: jnp.ndarray) -> jnp.ndarray:
 deltas = MAHAL_RADIUS * jnp.einsum("nij,nj->ni", L, unit_dirs)  # (N, 2)
 comparisons = jnp.clip(refs + deltas, -1.0, 1.0)
 
-# Compute p(correct) in batch. We vmap the single-trial predictor.
-trial_pred_keys = jr.split(k_pred, num_trials_total)
-
-
-# we use task as the generative model to create observations (user responses)
-def _p_correct_one(ref: jnp.ndarray, comp: jnp.ndarray, kk: jnp.ndarray) -> jnp.ndarray:
-    # Task MC settings (num_samples/bandwidth) come from OddityTaskConfig.
-    # Only the randomness is threaded dynamically.
-    return task._simulate_trial_mc(
-        params=truth_params,
-        ref=ref,
-        comparison=comp,
-        model=truth_model,
-        noise=truth_model.noise,
-        num_samples=int(task.config.num_samples),
-        bandwidth=float(task.config.bandwidth),
-        key=kk,
-    )
-
-
-p_correct = jax.vmap(_p_correct_one)(refs, comparisons, trial_pred_keys)
-
-# Sample observed y ~ Bernoulli(p_correct) in batch.
-ys = jr.bernoulli(k_y, p_correct, shape=(num_trials_total,)).astype(jnp.int32)
+# --8<-- [start:simulate_data]
+# Simulate observed responses using the likelihood implied by the task.
+ys, p_correct = task.simulate(truth_params, refs, comparisons, truth_model, key=k_sim)
 
 # Build the canonical batched dataset for compute.
 #
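The unchanged context lines above build comparison stimuli by pushing unit directions through a Cholesky factor (`deltas = MAHAL_RADIUS * jnp.einsum("nij,nj->ni", L, unit_dirs)`). A small numpy check (with illustrative values, not the script's) confirms why: every probe lands at the same Mahalanobis distance from its reference, since for `delta = L @ u` with `Sigma = L @ L.T` and unit `u`, `delta.T @ inv(Sigma) @ delta = u.T @ u = 1`:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 6
Sigma = np.array([[0.04, 0.01], [0.01, 0.02]])   # one illustrative covariance
Sigmas = np.repeat(Sigma[None], n, axis=0)        # (n, 2, 2), stands in for truth_field(refs)
L = np.linalg.cholesky(Sigmas)                    # per-trial Cholesky factors

angles = rng.uniform(0.0, 2.0 * np.pi, size=n)
unit_dirs = np.stack([np.cos(angles), np.sin(angles)], axis=1)  # (n, 2)

MAHAL_RADIUS = 1.0
deltas = MAHAL_RADIUS * np.einsum("nij,nj->ni", L, unit_dirs)   # (n, 2)

# Mahalanobis length of each delta under its Sigma: all equal MAHAL_RADIUS.
mahal = np.sqrt(np.einsum("ni,nij,nj->n", deltas, np.linalg.inv(Sigmas), deltas))
```

This is what keeps the simulated trials at a fixed discriminability ring around each reference point.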

docs/examples/wppm/quick_start.py

Lines changed: 5 additions & 23 deletions
@@ -134,7 +134,7 @@ def _cov_to_points(cov: jnp.ndarray, center: jnp.ndarray) -> jnp.ndarray:
 # Step 2 — Simulate data at a *single* reference point
 # ---------------------------------------------------------------------------
 
-# --8<-- [start:simulate_data]
+
 # Single reference point at the centre of the stimulus space.
 ref_point = jnp.array([[0.0, 0.0]])  # shape (1, 2) — kept as a batch for generality
 
@@ -149,7 +149,7 @@ def _cov_to_points(cov: jnp.ndarray, center: jnp.ndarray) -> jnp.ndarray:
 Sigmas_ref = truth_field(refs)  # (NUM_TRIALS, 2, 2)
 
 # Sample unit directions and build covariance-scaled probe displacements.
-k_dir, k_pred, k_y = jr.split(key, 3)
+k_dir, k_sim = jr.split(key)
 angles = jr.uniform(k_dir, shape=(NUM_TRIALS,), minval=0.0, maxval=2.0 * jnp.pi)
 unit_dirs = jnp.stack([jnp.cos(angles), jnp.sin(angles)], axis=1)  # (N, 2)
 
@@ -160,27 +160,9 @@ def _cov_to_points(cov: jnp.ndarray, center: jnp.ndarray) -> jnp.ndarray:
 deltas = MAHAL_RADIUS * jnp.einsum("nij,nj->ni", L, unit_dirs)  # (N, 2)
 comparisons = jnp.clip(refs + deltas, -1.0, 1.0)
 
-# Compute p(correct) via MC simulation of the oddity task.
-trial_pred_keys = jr.split(k_pred, NUM_TRIALS)
-
-
-def _p_correct_one(ref: jnp.ndarray, comp: jnp.ndarray, kk: jnp.ndarray) -> jnp.ndarray:
-    return task._simulate_trial_mc(
-        params=truth_params,
-        ref=ref,
-        comparison=comp,
-        model=truth_model,
-        noise=truth_model.noise,
-        num_samples=int(task.config.num_samples),
-        bandwidth=float(task.config.bandwidth),
-        key=kk,
-    )
-
-
-p_correct = jax.vmap(_p_correct_one)(refs, comparisons, trial_pred_keys)
-
-# Sample observed responses y ~ Bernoulli(p_correct).
-ys = jr.bernoulli(k_y, p_correct, shape=(NUM_TRIALS,)).astype(jnp.int32)
+# --8<-- [start:simulate_data]
+# Simulate observed responses using the likelihood implied by the task.
+ys, p_correct = task.simulate(truth_params, refs, comparisons, truth_model, key=k_sim)
 # --8<-- [end:simulate_data]
 
 # --8<-- [start:data]
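The helper `_cov_to_points`, visible in the hunk headers, turns a 2×2 covariance into ellipse boundary points for the covariance-field plots. A plausible numpy sketch of such a helper — assuming the common construction of mapping the unit circle through the Cholesky factor; the actual psyphy implementation may differ:

```python
import numpy as np

def cov_to_points(cov, center, num_points=100, radius=1.0):
    """Boundary of the ellipse {x : (x-center)^T cov^{-1} (x-center) = radius^2}."""
    t = np.linspace(0.0, 2.0 * np.pi, num_points)
    circle = np.stack([np.cos(t), np.sin(t)], axis=1)   # (num_points, 2) unit circle
    L = np.linalg.cholesky(cov)                          # cov = L @ L.T
    return center + radius * circle @ L.T                # (num_points, 2)

pts = cov_to_points(np.array([[0.04, 0.01], [0.01, 0.02]]), np.array([0.0, 0.0]))
```

Because `pts = circle @ L.T`, every point satisfies `pts @ inv(cov) @ pts.T = circle @ circle.T = 1`, i.e. a constant-Mahalanobis contour of the covariance.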
