Skip to content

Commit d021d4c

Browse files
authored
fix: jax backend tolist for tracers in logging (#2580)
* JAX can not convert a tracer during .tolist (because it has no data to put in a list). This fix checks if JAX is currently in tracing time and instead returns the tracer (or rather its abstract value for a little nicer representation). This preserves the same error message as for the NumPy backend in case there's a shape mismatch in data. * Add Peter Fackeldey to contributors list.
1 parent 10488f0 commit d021d4c

File tree

3 files changed

+30
-0
lines changed

3 files changed

+30
-0
lines changed

docs/contributors.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,4 @@ Contributors include:
3636
- Lorenz Gaertner
3737
- Melissa Weber Mendonça
3838
- Matthias Bussonnier
39+
- Peter Fackeldey

src/pyhf/tensor/jax_backend.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
config.update('jax_enable_x64', True)
44

5+
from jax.core import Tracer
56
from jax import Array
67
import jax.numpy as jnp
78
from jax.scipy.special import gammaln, xlogy
@@ -14,6 +15,13 @@
1415
log = logging.getLogger(__name__)
1516

1617

18+
def _currently_jitting():
19+
"""
20+
JAX turns arrays into Tracers during jit-compilation, so check for that.
21+
"""
22+
return isinstance(jnp.array(1), Tracer)
23+
24+
1725
class _BasicPoisson:
1826
def __init__(self, rate):
1927
self.rate = rate
@@ -184,6 +192,9 @@ def conditional(self, predicate, true_callable, false_callable):
184192
return true_callable() if predicate else false_callable()
185193

186194
def tolist(self, tensor_in):
195+
if _currently_jitting():
196+
# .aval is the abstract value and has a little nicer representation
197+
return tensor_in.aval
187198
try:
188199
return jnp.asarray(tensor_in).tolist()
189200
except (TypeError, ValueError):

tests/test_backends.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,21 @@ def test_backend_array_type(backend):
8484
def test_tensor_array_types():
8585
# can't really assert the content of them so easily
8686
assert pyhf.tensor.array_types
87+
88+
89+
@pytest.mark.only_jax
90+
def test_jax_data_shape_mismatch_during_jitting(backend):
91+
"""
92+
Validate that during JAX tracing time pyhf doesn't try
93+
to convert the data to a list, which is not possible with tracers,
94+
for a shape mismatch.
95+
Instead, return the tracer itself for a proper error message.
96+
Issue: https://github.com/scikit-hep/pyhf/issues/1422
97+
PR: https://github.com/scikit-hep/pyhf/pull/2580
98+
"""
99+
model = pyhf.simplemodels.uncorrelated_background([10], [15], [5])
100+
with pytest.raises(
101+
pyhf.exceptions.InvalidPdfData,
102+
match="eval failed as data has len 1 but 2 was expected",
103+
):
104+
pyhf.infer.mle.fit([12.5], model)

0 commit comments

Comments
 (0)