Description
Hi all,
I'm trying to compute a posterior predictive distribution from samples of a posterior distribution (Colab here), using TFP 0.25 with the JAX backend.
My (MRE, and therefore contrived) model specification is
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

@tfd.JointDistributionCoroutineAutoBatched
def model_autobatched():
    theta = yield tfd.Normal(loc=0., scale=1., name="theta")
    yield tfd.Normal(loc=theta, scale=0.1, name="y")
i.e. a Normally-distributed observation model with a Normally-distributed mean. To compute the posterior predictive distribution, I wish to sample the y component conditional on a vector of theta samples.
import jax
import numpy as np

theta_samples = np.arange(5.)
model_autobatched.sample(theta=theta_samples, seed=jax.random.key(0))
giving
StructTuple(
  theta=Array([0., 1., 2., 3., 4.], dtype=float32),
  y=Array([0.06215769, 1.0621576 , 2.0621576 , 3.0621576 , 4.0621576 ], dtype=float32)
)
Oh dear: we notice that y - theta is constant across all five draws. This suggests that a single PRNG key is being reused for every draw of y given its corresponding theta sample.
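For contrast, here is a sketch of the behaviour I expected, written out by hand with jax.vmap and one fresh key per theta sample (my own illustration, not from the Colab):
# Split the seed so each conditional draw of y gets its own key.
keys = jax.random.split(jax.random.key(0), len(theta_samples))
y_expected = jax.vmap(
    lambda t, k: tfd.Normal(loc=t, scale=0.1).sample(seed=k)
)(theta_samples, keys)
# Now y_expected - theta_samples is not constant, since the noise is
# drawn independently per element.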
Moreover, this approach fails entirely if sample_distributions is called.
model_autobatched.sample_distributions(theta=theta_samples, seed=jax.random.key(0))
ValueError: Attempt to convert a value (<object object at 0x7a53561590d0>) with an unsupported type (<class 'object'>) to a Tensor.
As a workaround, we could use the older JointDistributionCoroutine with Root annotations, which works as desired (see Colab).
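For reference, my reconstruction of that workaround looks roughly like this (a sketch; the Colab has the exact code):
Root = tfd.JointDistributionCoroutine.Root

@tfd.JointDistributionCoroutine
def model_unbatched():
    theta = yield Root(tfd.Normal(loc=0., scale=1., name="theta"))
    yield tfd.Normal(loc=theta, scale=0.1, name="y")

model_unbatched.sample(theta=theta_samples, seed=jax.random.key(0))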
[edit] actually, JointDistributionCoroutine/Root only works because the whole theta vector is passed into y's constructor, which broadcasts inside that one distribution; it is not genuine vectorisation over the whole model.
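In other words, for this toy model the Root workaround is equivalent to conditioning y on the theta samples by hand, along these lines (a sketch):
# Build the conditional distribution of y given the theta samples directly;
# a batched Normal draws independent noise for each batch element.
y_pred = tfd.Normal(loc=theta_samples, scale=0.1).sample(seed=jax.random.key(0))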
Do we have a bug or a feature, I wonder?
Regards,
Chris