 
 config.update('jax_enable_x64', True)
 
+from jax import Array
 import jax.numpy as jnp
 from jax.scipy.special import gammaln, xlogy
 from jax.scipy import special
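Context for the change, as an illustrative sketch rather than part of the patch: recent JAX releases expose the unified array type as jax.Array, while the jnp.DeviceArray alias used before has been removed. Code that must run against both old and new JAX could resolve the type roughly like this; the name jax_array_type is illustrative only and is not defined in pyhf:

# Illustrative sketch (assumption, not pyhf code): pick the JAX array type
# across releases. jax.Array exists in recent JAX; older releases only
# provided jnp.DeviceArray.
try:
    from jax import Array as jax_array_type
except ImportError:  # pre-jax.Array releases
    import jax.numpy as jnp

    jax_array_type = jnp.DeviceArray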
@@ -54,10 +55,10 @@ class jax_backend:
     __slots__ = ['name', 'precision', 'dtypemap', 'default_do_grad']
 
     #: The array type for jax
-    array_type = jnp.DeviceArray
+    array_type = Array
 
     #: The array content type for jax
-    array_subtype = jnp.DeviceArray
+    array_subtype = Array
 
     def __init__(self, **kwargs):
         self.name = 'jax'
@@ -84,7 +85,7 @@ def clip(self, tensor_in, min_value, max_value):
             >>> pyhf.set_backend("jax")
             >>> a = pyhf.tensorlib.astensor([-2, -1, 0, 1, 2])
             >>> pyhf.tensorlib.clip(a, -1, 1)
-            DeviceArray([-1., -1.,  0.,  1.,  1.], dtype=float64)
+            Array([-1., -1.,  0.,  1.,  1.], dtype=float64)
 
         Args:
             tensor_in (:obj:`tensor`): The input tensor object
@@ -106,8 +107,7 @@ def erf(self, tensor_in):
             >>> pyhf.set_backend("jax")
             >>> a = pyhf.tensorlib.astensor([-2., -1., 0., 1., 2.])
             >>> pyhf.tensorlib.erf(a)
-            DeviceArray([-0.99532227, -0.84270079,  0.        ,  0.84270079,
-                          0.99532227], dtype=float64)
+            Array([-0.99532227, -0.84270079,  0.        ,  0.84270079,  0.99532227], dtype=float64)
 
         Args:
             tensor_in (:obj:`tensor`): The input tensor object
@@ -127,7 +127,7 @@ def erfinv(self, tensor_in):
             >>> pyhf.set_backend("jax")
             >>> a = pyhf.tensorlib.astensor([-2., -1., 0., 1., 2.])
             >>> pyhf.tensorlib.erfinv(pyhf.tensorlib.erf(a))
-            DeviceArray([-2., -1.,  0.,  1.,  2.], dtype=float64)
+            Array([-2., -1.,  0.,  1.,  2.], dtype=float64)
 
         Args:
             tensor_in (:obj:`tensor`): The input tensor object
@@ -147,8 +147,8 @@ def tile(self, tensor_in, repeats):
             >>> pyhf.set_backend("jax")
             >>> a = pyhf.tensorlib.astensor([[1.0], [2.0]])
             >>> pyhf.tensorlib.tile(a, (1, 2))
-            DeviceArray([[1., 1.],
-                         [2., 2.]], dtype=float64)
+            Array([[1., 1.],
+                   [2., 2.]], dtype=float64)
 
         Args:
             tensor_in (:obj:`tensor`): The tensor to be repeated
@@ -171,7 +171,7 @@ def conditional(self, predicate, true_callable, false_callable):
             >>> a = tensorlib.astensor([4])
             >>> b = tensorlib.astensor([5])
             >>> tensorlib.conditional((a < b)[0], lambda: a + b, lambda: a - b)
-            DeviceArray([9.], dtype=float64)
+            Array([9.], dtype=float64)
 
         Args:
             predicate (:obj:`scalar`): The logical condition that determines which callable to evaluate
@@ -213,16 +213,16 @@ def astensor(self, tensor_in, dtype="float"):
             >>> pyhf.set_backend("jax")
             >>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
             >>> tensor
-            DeviceArray([[1., 2., 3.],
-                         [4., 5., 6.]], dtype=float64)
+            Array([[1., 2., 3.],
+                   [4., 5., 6.]], dtype=float64)
             >>> type(tensor) # doctest:+ELLIPSIS
-            <class '...DeviceArray'>
+            <class '...Array'>
 
         Args:
             tensor_in (Number or Tensor): Tensor object
 
         Returns:
-            `jaxlib.xla_extension.DeviceArray`: A multi-dimensional, fixed-size homogeneous array.
+            `jaxlib.xla_extension.Array`: A multi-dimensional, fixed-size homogeneous array.
         """
         # TODO: Remove doctest:+ELLIPSIS when JAX API stabilized
         try:
@@ -294,9 +294,9 @@ def percentile(self, tensor_in, q, axis=None, interpolation="linear"):
             >>> pyhf.set_backend("jax")
             >>> a = pyhf.tensorlib.astensor([[10, 7, 4], [3, 2, 1]])
             >>> pyhf.tensorlib.percentile(a, 50)
-            DeviceArray(3.5, dtype=float64)
+            Array(3.5, dtype=float64)
             >>> pyhf.tensorlib.percentile(a, 50, axis=1)
-            DeviceArray([7., 2.], dtype=float64)
+            Array([7., 2.], dtype=float64)
 
         Args:
             tensor_in (`tensor`): The tensor containing the data
@@ -355,7 +355,7 @@ def simple_broadcast(self, *args):
             ... pyhf.tensorlib.astensor([1]),
             ... pyhf.tensorlib.astensor([2, 3, 4]),
             ... pyhf.tensorlib.astensor([5, 6, 7]))
-            [DeviceArray([1., 1., 1.], dtype=float64), DeviceArray([2., 3., 4.], dtype=float64), DeviceArray([5., 6., 7.], dtype=float64)]
+            [Array([1., 1., 1.], dtype=float64), Array([2., 3., 4.], dtype=float64), Array([5., 6., 7.], dtype=float64)]
 
         Args:
             args (Array of Tensors): Sequence of arrays
@@ -381,13 +381,13 @@ def ravel(self, tensor):
             >>> pyhf.set_backend("jax")
             >>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
             >>> pyhf.tensorlib.ravel(tensor)
-            DeviceArray([1., 2., 3., 4., 5., 6.], dtype=float64)
+            Array([1., 2., 3., 4., 5., 6.], dtype=float64)
 
         Args:
             tensor (Tensor): Tensor object
 
         Returns:
-            `jaxlib.xla_extension.DeviceArray`: A flattened array.
+            `jaxlib.xla_extension.Array`: A flattened array.
         """
         return jnp.ravel(tensor)
 
@@ -441,11 +441,11 @@ def poisson(self, n, lam):
             >>> import pyhf
             >>> pyhf.set_backend("jax")
             >>> pyhf.tensorlib.poisson(5., 6.)
-            DeviceArray(0.16062314, dtype=float64, weak_type=True)
+            Array(0.16062314, dtype=float64, weak_type=True)
             >>> values = pyhf.tensorlib.astensor([5., 9.])
             >>> rates = pyhf.tensorlib.astensor([6., 8.])
             >>> pyhf.tensorlib.poisson(values, rates)
-            DeviceArray([0.16062314, 0.12407692], dtype=float64)
+            Array([0.16062314, 0.12407692], dtype=float64)
 
         Args:
             n (:obj:`tensor` or :obj:`float`): The value at which to evaluate the approximation to the Poisson distribution p.m.f.
@@ -484,12 +484,12 @@ def normal(self, x, mu, sigma):
             >>> import pyhf
             >>> pyhf.set_backend("jax")
             >>> pyhf.tensorlib.normal(0.5, 0., 1.)
-            DeviceArray(0.35206533, dtype=float64, weak_type=True)
+            Array(0.35206533, dtype=float64, weak_type=True)
             >>> values = pyhf.tensorlib.astensor([0.5, 2.0])
             >>> means = pyhf.tensorlib.astensor([0., 2.3])
             >>> sigmas = pyhf.tensorlib.astensor([1., 0.8])
             >>> pyhf.tensorlib.normal(values, means, sigmas)
-            DeviceArray([0.35206533, 0.46481887], dtype=float64)
+            Array([0.35206533, 0.46481887], dtype=float64)
 
         Args:
             x (:obj:`tensor` or :obj:`float`): The value at which to evaluate the Normal distribution p.d.f.
@@ -510,10 +510,10 @@ def normal_cdf(self, x, mu=0, sigma=1):
             >>> import pyhf
             >>> pyhf.set_backend("jax")
             >>> pyhf.tensorlib.normal_cdf(0.8)
-            DeviceArray(0.7881446, dtype=float64)
+            Array(0.7881446, dtype=float64)
             >>> values = pyhf.tensorlib.astensor([0.8, 2.0])
             >>> pyhf.tensorlib.normal_cdf(values)
-            DeviceArray([0.7881446 , 0.97724987], dtype=float64)
+            Array([0.7881446 , 0.97724987], dtype=float64)
 
         Args:
             x (:obj:`tensor` or :obj:`float`): The observed value of the random variable to evaluate the CDF for
@@ -536,7 +536,7 @@ def poisson_dist(self, rate):
             >>> values = pyhf.tensorlib.astensor([4, 9])
             >>> poissons = pyhf.tensorlib.poisson_dist(rates)
             >>> poissons.log_prob(values)
-            DeviceArray([-1.74030218, -2.0868536 ], dtype=float64)
+            Array([-1.74030218, -2.0868536 ], dtype=float64)
 
         Args:
             rate (:obj:`tensor` or :obj:`float`): The mean of the Poisson distribution (the expected number of events)
@@ -558,7 +558,7 @@ def normal_dist(self, mu, sigma):
             >>> values = pyhf.tensorlib.astensor([4, 9])
             >>> normals = pyhf.tensorlib.normal_dist(means, stds)
             >>> normals.log_prob(values)
-            DeviceArray([-1.41893853, -2.22579135], dtype=float64)
+            Array([-1.41893853, -2.22579135], dtype=float64)
 
         Args:
             mu (:obj:`tensor` or :obj:`float`): The mean of the Normal distribution
@@ -579,8 +579,8 @@ def to_numpy(self, tensor_in):
             >>> pyhf.set_backend("jax")
             >>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
             >>> tensor
-            DeviceArray([[1., 2., 3.],
-                         [4., 5., 6.]], dtype=float64)
+            Array([[1., 2., 3.],
+                   [4., 5., 6.]], dtype=float64)
             >>> numpy_ndarray = pyhf.tensorlib.to_numpy(tensor)
             >>> numpy_ndarray
             array([[1., 2., 3.],
@@ -606,12 +606,12 @@ def transpose(self, tensor_in):
             >>> pyhf.set_backend("jax")
             >>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
             >>> tensor
-            DeviceArray([[1., 2., 3.],
-                         [4., 5., 6.]], dtype=float64)
+            Array([[1., 2., 3.],
+                   [4., 5., 6.]], dtype=float64)
             >>> pyhf.tensorlib.transpose(tensor)
-            DeviceArray([[1., 4.],
-                         [2., 5.],
-                         [3., 6.]], dtype=float64)
+            Array([[1., 4.],
+                   [2., 5.],
+                   [3., 6.]], dtype=float64)
 
         Args:
             tensor_in (:obj:`tensor`): The input tensor object.
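A quick end-to-end check of the migration, as an illustrative sketch rather than part of the patch, assuming pyhf is installed alongside a JAX release that provides jax.Array: tensors returned by the backend should now be instances of jax.Array rather than the removed DeviceArray class.

# Illustrative check, not from the patch.
import jax
import pyhf

pyhf.set_backend("jax")
tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

# jax.Array is the public base type that isinstance checks should use now.
assert isinstance(tensor, jax.Array)
print(type(tensor))  # a concrete subclass of jax.Array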