You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Don't require Root in JDC when sampling with a trivial sample shape.
Pros: In JAX world we often wrap large chunks of computation inside a vmap, which can encompass the log_prob/sample calls of a JDC. In this setting a JDC will never see a non-trivial sample shape. Root was introduced to handle non-trivial sample shapes, so in this setting, this annotation is superfluous. Thus, this change removes the requirement for Root when only trivial sample shapes are considered, improving the UX of JDCs in JAX.
Cons: The check is deferred to non-trivial sample shapes, meaning that a TF user might construct a malformed distribution and not receive an error until they try to sample with a non-trivial sample shape. Prior to this change, they would get an error as soon as `_get_single_sample_distributions` was called, which happens with most property accesses.
Backwards compatibility: Existing well-formed JDCs will continue to work, and will continue being checked for correctness modulo the 'con' above.
PiperOrigin-RevId: 385498069
0 commit comments