Commit 367d4da
[JAX] Fixes for tensorflow_probability for an upcoming change to jnp.array().
An upcoming change to jax.numpy.array() means that, under a transformation like jax.jit(), it will always stage its arrays into the trace. This often breaks if the array is being used for a shape calculation. Make sure we use static shapes in more places to fix test failures.
PiperOrigin-RevId: 3978492991 parent 4746a60 commit 367d4da
File tree
2 files changed
+4
-2
lines changed- tensorflow_probability/python/internal/backend
- meta
- numpy/gen
2 files changed
+4
-2
lines changedLines changed: 2 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
196 | 196 | | |
197 | 197 | | |
198 | 198 | | |
| 199 | + | |
| 200 | + | |
199 | 201 | | |
200 | 202 | | |
201 | 203 | | |
| |||
Lines changed: 2 additions & 2 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
486 | 486 | | |
487 | 487 | | |
488 | 488 | | |
489 | | - | |
| 489 | + | |
490 | 490 | | |
491 | 491 | | |
492 | 492 | | |
| |||
530 | 530 | | |
531 | 531 | | |
532 | 532 | | |
533 | | - | |
| 533 | + | |
534 | 534 | | |
535 | 535 | | |
536 | 536 | | |
| |||
0 commit comments