Commit 8d3bd07

feat: Update JAX backend array_type and array_subtype to jax.Array (#2079)
* Update JAX backend array_type and array_subtype to jax.Array, which is a
  unified array type introduced in jax v0.4.1 that subsumes the DeviceArray,
  ShardedDeviceArray, and GlobalDeviceArray types in JAX.
  - c.f. https://github.com/google/jax/releases/tag/jax-v0.4.1
* Update the lower bound of jax and jaxlib to v0.4.1 to ensure support for jax.Array.
* Update the lower bound of the supported scipy versions to v1.5.0, as required by jax v0.4.1.
* Update the JAX backend docstring examples to use Array as the array type.
1 parent d9a5da2 commit 8d3bd07
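For context on the change above (a minimal sketch, not part of the commit itself): assuming jax>=0.4.1 is installed, ordinary jax.numpy arrays and the outputs of jit-compiled functions are all instances of the single unified jax.Array type described in the JAX v0.4.1 release notes.

import jax
import jax.numpy as jnp

x = jnp.arange(3.0)                # a plain jax.numpy array
y = jax.jit(lambda a: a * 2.0)(x)  # output of a jit-compiled function

# Both are instances of the unified jax.Array type (jax>=0.4.1)
print(isinstance(x, jax.Array), isinstance(y, jax.Array))  # True True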

4 files changed: +38 −38 lines changed


setup.cfg

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ package_dir =
 include_package_data = True
 python_requires = >=3.8
 install_requires =
-    scipy>=1.3.2 # requires numpy, which is required by pyhf and tensorflow
+    scipy>=1.5.0 # requires numpy, which is required by pyhf and tensorflow
     click>=8.0.0 # for console scripts
     tqdm>=4.56.0 # for readxml
     jsonschema>=4.15.0 # for utils

setup.py

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@
         'tensorflow-probability>=0.11.0', # c.f. PR #1657
     ],
     'torch': ['torch>=1.10.0'], # c.f. PR #1657
-    'jax': ['jax>=0.2.10', 'jaxlib>=0.1.61,!=0.1.68'], # c.f. PR #1962, Issue #1501
+    'jax': ['jax>=0.4.1', 'jaxlib>=0.4.1'], # c.f. PR #2079
     'xmlio': ['uproot>=4.1.1'], # c.f. PR #1567
     'minuit': ['iminuit>=2.7.0'], # c.f. PR #1895
 }
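The pins above only raise lower bounds. As a hedged sanity check (a sketch, not part of the commit; it assumes the third-party packaging library is installed), an environment can be verified against the new minimums like this:

from importlib import metadata

from packaging.version import Version

# Check that the installed versions satisfy the new lower bounds from this PR
for package, minimum in [("jax", "0.4.1"), ("jaxlib", "0.4.1"), ("scipy", "1.5.0")]:
    installed = Version(metadata.version(package))
    assert installed >= Version(minimum), f"{package} {installed} is older than {minimum}"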

src/pyhf/tensor/jax_backend.py

Lines changed: 33 additions & 33 deletions
@@ -2,6 +2,7 @@
 
 config.update('jax_enable_x64', True)
 
+from jax import Array
 import jax.numpy as jnp
 from jax.scipy.special import gammaln, xlogy
 from jax.scipy import special
@@ -54,10 +55,10 @@ class jax_backend:
     __slots__ = ['name', 'precision', 'dtypemap', 'default_do_grad']
 
     #: The array type for jax
-    array_type = jnp.DeviceArray
+    array_type = Array
 
     #: The array content type for jax
-    array_subtype = jnp.DeviceArray
+    array_subtype = Array
 
     def __init__(self, **kwargs):
         self.name = 'jax'
@@ -84,7 +85,7 @@ def clip(self, tensor_in, min_value, max_value):
         >>> pyhf.set_backend("jax")
         >>> a = pyhf.tensorlib.astensor([-2, -1, 0, 1, 2])
         >>> pyhf.tensorlib.clip(a, -1, 1)
-        DeviceArray([-1., -1., 0., 1., 1.], dtype=float64)
+        Array([-1., -1., 0., 1., 1.], dtype=float64)
 
         Args:
             tensor_in (:obj:`tensor`): The input tensor object
@@ -106,8 +107,7 @@ def erf(self, tensor_in):
         >>> pyhf.set_backend("jax")
         >>> a = pyhf.tensorlib.astensor([-2., -1., 0., 1., 2.])
         >>> pyhf.tensorlib.erf(a)
-        DeviceArray([-0.99532227, -0.84270079, 0. , 0.84270079,
-                     0.99532227], dtype=float64)
+        Array([-0.99532227, -0.84270079, 0. , 0.84270079, 0.99532227], dtype=float64)
 
         Args:
             tensor_in (:obj:`tensor`): The input tensor object
@@ -127,7 +127,7 @@ def erfinv(self, tensor_in):
         >>> pyhf.set_backend("jax")
         >>> a = pyhf.tensorlib.astensor([-2., -1., 0., 1., 2.])
         >>> pyhf.tensorlib.erfinv(pyhf.tensorlib.erf(a))
-        DeviceArray([-2., -1., 0., 1., 2.], dtype=float64)
+        Array([-2., -1., 0., 1., 2.], dtype=float64)
 
         Args:
             tensor_in (:obj:`tensor`): The input tensor object
@@ -147,8 +147,8 @@ def tile(self, tensor_in, repeats):
         >>> pyhf.set_backend("jax")
         >>> a = pyhf.tensorlib.astensor([[1.0], [2.0]])
         >>> pyhf.tensorlib.tile(a, (1, 2))
-        DeviceArray([[1., 1.],
-                     [2., 2.]], dtype=float64)
+        Array([[1., 1.],
+               [2., 2.]], dtype=float64)
 
         Args:
             tensor_in (:obj:`tensor`): The tensor to be repeated
@@ -171,7 +171,7 @@ def conditional(self, predicate, true_callable, false_callable):
         >>> a = tensorlib.astensor([4])
         >>> b = tensorlib.astensor([5])
         >>> tensorlib.conditional((a < b)[0], lambda: a + b, lambda: a - b)
-        DeviceArray([9.], dtype=float64)
+        Array([9.], dtype=float64)
 
         Args:
             predicate (:obj:`scalar`): The logical condition that determines which callable to evaluate
@@ -213,16 +213,16 @@ def astensor(self, tensor_in, dtype="float"):
         >>> pyhf.set_backend("jax")
         >>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
         >>> tensor
-        DeviceArray([[1., 2., 3.],
-                     [4., 5., 6.]], dtype=float64)
+        Array([[1., 2., 3.],
+               [4., 5., 6.]], dtype=float64)
         >>> type(tensor) # doctest:+ELLIPSIS
-        <class '...DeviceArray'>
+        <class '...Array'>
 
         Args:
             tensor_in (Number or Tensor): Tensor object
 
         Returns:
-            `jaxlib.xla_extension.DeviceArray`: A multi-dimensional, fixed-size homogeneous array.
+            `jaxlib.xla_extension.Array`: A multi-dimensional, fixed-size homogeneous array.
         """
         # TODO: Remove doctest:+ELLIPSIS when JAX API stabilized
         try:
@@ -294,9 +294,9 @@ def percentile(self, tensor_in, q, axis=None, interpolation="linear"):
         >>> pyhf.set_backend("jax")
         >>> a = pyhf.tensorlib.astensor([[10, 7, 4], [3, 2, 1]])
         >>> pyhf.tensorlib.percentile(a, 50)
-        DeviceArray(3.5, dtype=float64)
+        Array(3.5, dtype=float64)
         >>> pyhf.tensorlib.percentile(a, 50, axis=1)
-        DeviceArray([7., 2.], dtype=float64)
+        Array([7., 2.], dtype=float64)
 
         Args:
             tensor_in (`tensor`): The tensor containing the data
@@ -355,7 +355,7 @@ def simple_broadcast(self, *args):
         ... pyhf.tensorlib.astensor([1]),
         ... pyhf.tensorlib.astensor([2, 3, 4]),
         ... pyhf.tensorlib.astensor([5, 6, 7]))
-        [DeviceArray([1., 1., 1.], dtype=float64), DeviceArray([2., 3., 4.], dtype=float64), DeviceArray([5., 6., 7.], dtype=float64)]
+        [Array([1., 1., 1.], dtype=float64), Array([2., 3., 4.], dtype=float64), Array([5., 6., 7.], dtype=float64)]
 
         Args:
             args (Array of Tensors): Sequence of arrays
@@ -381,13 +381,13 @@ def ravel(self, tensor):
         >>> pyhf.set_backend("jax")
         >>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
         >>> pyhf.tensorlib.ravel(tensor)
-        DeviceArray([1., 2., 3., 4., 5., 6.], dtype=float64)
+        Array([1., 2., 3., 4., 5., 6.], dtype=float64)
 
         Args:
             tensor (Tensor): Tensor object
 
         Returns:
-            `jaxlib.xla_extension.DeviceArray`: A flattened array.
+            `jaxlib.xla_extension.Array`: A flattened array.
         """
         return jnp.ravel(tensor)
 
@@ -441,11 +441,11 @@ def poisson(self, n, lam):
         >>> import pyhf
         >>> pyhf.set_backend("jax")
         >>> pyhf.tensorlib.poisson(5., 6.)
-        DeviceArray(0.16062314, dtype=float64, weak_type=True)
+        Array(0.16062314, dtype=float64, weak_type=True)
         >>> values = pyhf.tensorlib.astensor([5., 9.])
         >>> rates = pyhf.tensorlib.astensor([6., 8.])
         >>> pyhf.tensorlib.poisson(values, rates)
-        DeviceArray([0.16062314, 0.12407692], dtype=float64)
+        Array([0.16062314, 0.12407692], dtype=float64)
 
         Args:
             n (:obj:`tensor` or :obj:`float`): The value at which to evaluate the approximation to the Poisson distribution p.m.f.
@@ -484,12 +484,12 @@ def normal(self, x, mu, sigma):
         >>> import pyhf
         >>> pyhf.set_backend("jax")
         >>> pyhf.tensorlib.normal(0.5, 0., 1.)
-        DeviceArray(0.35206533, dtype=float64, weak_type=True)
+        Array(0.35206533, dtype=float64, weak_type=True)
         >>> values = pyhf.tensorlib.astensor([0.5, 2.0])
         >>> means = pyhf.tensorlib.astensor([0., 2.3])
         >>> sigmas = pyhf.tensorlib.astensor([1., 0.8])
         >>> pyhf.tensorlib.normal(values, means, sigmas)
-        DeviceArray([0.35206533, 0.46481887], dtype=float64)
+        Array([0.35206533, 0.46481887], dtype=float64)
 
         Args:
             x (:obj:`tensor` or :obj:`float`): The value at which to evaluate the Normal distribution p.d.f.
@@ -510,10 +510,10 @@ def normal_cdf(self, x, mu=0, sigma=1):
         >>> import pyhf
         >>> pyhf.set_backend("jax")
         >>> pyhf.tensorlib.normal_cdf(0.8)
-        DeviceArray(0.7881446, dtype=float64)
+        Array(0.7881446, dtype=float64)
         >>> values = pyhf.tensorlib.astensor([0.8, 2.0])
         >>> pyhf.tensorlib.normal_cdf(values)
-        DeviceArray([0.7881446 , 0.97724987], dtype=float64)
+        Array([0.7881446 , 0.97724987], dtype=float64)
 
         Args:
             x (:obj:`tensor` or :obj:`float`): The observed value of the random variable to evaluate the CDF for
@@ -536,7 +536,7 @@ def poisson_dist(self, rate):
         >>> values = pyhf.tensorlib.astensor([4, 9])
         >>> poissons = pyhf.tensorlib.poisson_dist(rates)
         >>> poissons.log_prob(values)
-        DeviceArray([-1.74030218, -2.0868536 ], dtype=float64)
+        Array([-1.74030218, -2.0868536 ], dtype=float64)
 
         Args:
             rate (:obj:`tensor` or :obj:`float`): The mean of the Poisson distribution (the expected number of events)
@@ -558,7 +558,7 @@ def normal_dist(self, mu, sigma):
         >>> values = pyhf.tensorlib.astensor([4, 9])
         >>> normals = pyhf.tensorlib.normal_dist(means, stds)
         >>> normals.log_prob(values)
-        DeviceArray([-1.41893853, -2.22579135], dtype=float64)
+        Array([-1.41893853, -2.22579135], dtype=float64)
 
         Args:
             mu (:obj:`tensor` or :obj:`float`): The mean of the Normal distribution
@@ -579,8 +579,8 @@ def to_numpy(self, tensor_in):
         >>> pyhf.set_backend("jax")
         >>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
         >>> tensor
-        DeviceArray([[1., 2., 3.],
-                     [4., 5., 6.]], dtype=float64)
+        Array([[1., 2., 3.],
+               [4., 5., 6.]], dtype=float64)
         >>> numpy_ndarray = pyhf.tensorlib.to_numpy(tensor)
         >>> numpy_ndarray
         array([[1., 2., 3.],
@@ -606,12 +606,12 @@ def transpose(self, tensor_in):
         >>> pyhf.set_backend("jax")
         >>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
         >>> tensor
-        DeviceArray([[1., 2., 3.],
-                     [4., 5., 6.]], dtype=float64)
+        Array([[1., 2., 3.],
+               [4., 5., 6.]], dtype=float64)
         >>> pyhf.tensorlib.transpose(tensor)
-        DeviceArray([[1., 4.],
-                     [2., 5.],
-                     [3., 6.]], dtype=float64)
+        Array([[1., 4.],
+               [2., 5.],
+               [3., 6.]], dtype=float64)
 
         Args:
             tensor_in (:obj:`tensor`): The input tensor object.
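With these docstring updates, results from the pyhf JAX backend print with the Array repr and can be type-checked against the updated array_type attribute. A minimal doctest-style sketch (assuming pyhf is installed with the jax extra and jax>=0.4.1):

>>> import pyhf
>>> import jax
>>> pyhf.set_backend("jax")
>>> tensor = pyhf.tensorlib.astensor([1.0, 2.0, 3.0])
>>> tensor
Array([1., 2., 3.], dtype=float64)
>>> isinstance(tensor, pyhf.tensorlib.array_type) and isinstance(tensor, jax.Array)
True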

tests/constraints.txt

Lines changed: 3 additions & 3 deletions
@@ -1,5 +1,5 @@
 # core
-scipy==1.3.2 # c.f. PR #2044
+scipy==1.5.0 # c.f. PR #2079
 click==8.0.0 # c.f. PR #1958, #1909
 tqdm==4.56.0
 jsonschema==4.15.0 # c.f. PR #1979
@@ -19,5 +19,5 @@ torch==1.10.0
 # Use Google Cloud Storage buckets for long term wheel support
 # c.f. https://github.com/google/jax/discussions/7608#discussioncomment-1269342
 --find-links https://storage.googleapis.com/jax-releases/jax_releases.html
-jax==0.2.10
-jaxlib==0.1.61 # c.f. PR #1962
+jax==0.4.1 # c.f. PR #2079
+jaxlib==0.4.1 # c.f. PR #2079
