Skip to content

Commit 481d5c9

Browse files
authored
refactor: add jax deprication warning (#3609)
1 parent 2ca5ccd commit 481d5c9

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

src/awkward/_backends/jax.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from __future__ import annotations
44

5+
import warnings
6+
57
import awkward_cpp
68

79
import awkward as ak
@@ -28,6 +30,12 @@ def nplike(self) -> Jax:
2830
return self._jax
2931

3032
def __init__(self):
33+
warnings.warn(
34+
"The JAX backend is deprecated and will be removed in a future release of Awkward Array. "
35+
"Please plan to migrate your code accordingly.",
36+
DeprecationWarning,
37+
stacklevel=2,
38+
)
3139
self._jax = Jax.instance()
3240

3341
def __getitem__(self, index: KernelKeyType) -> JaxKernel:

src/awkward/_nplikes/jax.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from __future__ import annotations
44

5+
import warnings
6+
57
import awkward as ak
68
from awkward._nplikes.array_like import ArrayLike
79
from awkward._nplikes.array_module import ArrayModuleNumpyLike
@@ -19,6 +21,12 @@ class Jax(ArrayModuleNumpyLike):
1921
supports_virtual_arrays: Final = True
2022

2123
def __init__(self):
24+
warnings.warn(
25+
"The JAX backend is deprecated and will be removed in a future release of Awkward Array. "
26+
"Please plan to migrate your code accordingly.",
27+
DeprecationWarning,
28+
stacklevel=2,
29+
)
2230
jax = ak.jax.import_jax()
2331
self._module = jax.numpy
2432

0 commit comments

Comments
 (0)