Skip to content

Commit 5808349

Browse files
fix: Update JAX DeviceArray type in docstrings for doctest (#1206)
* Update JAX DeviceArray type in docstrings to jax.interpreters.xla._DeviceArray - Needed to pass doctest
1 parent 62466c3 commit 5808349

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/pyhf/tensor/jax_backend.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -210,13 +210,13 @@ def astensor(self, tensor_in, dtype='float'):
210210
DeviceArray([[1., 2., 3.],
211211
[4., 5., 6.]], dtype=float64)
212212
>>> type(tensor)
213-
<class 'jax.interpreters.xla.DeviceArray'>
213+
<class 'jax.interpreters.xla._DeviceArray'>
214214
215215
Args:
216216
tensor_in (Number or Tensor): Tensor object
217217
218218
Returns:
219-
`jax.interpreters.xla.DeviceArray`: A multi-dimensional, fixed-size homogenous array.
219+
`jax.interpreters.xla._DeviceArray`: A multi-dimensional, fixed-size homogenous array.
220220
"""
221221
try:
222222
dtype = self.dtypemap[dtype]
@@ -320,7 +320,7 @@ def ravel(self, tensor):
320320
tensor (Tensor): Tensor object
321321
322322
Returns:
323-
`jax.interpreters.xla.DeviceArray`: A flattened array.
323+
`jax.interpreters.xla._DeviceArray`: A flattened array.
324324
"""
325325
return jnp.ravel(tensor)
326326

0 commit comments

Comments
 (0)