"""Enables `torchax` globally, which intercepts PyTorch operations and routes them to the JAX backend. This is the primary entry point for using `torchax`.
87
+
88
+
**Usage:**
89
+
90
+
```python
91
+
import torchax
92
+
93
+
torchax.enable_globally()
94
+
```
95
+
"""
63
96
env=default_env().enable_torch_modes()
64
97
returnenv
65
98
66
99
67
100
defdisable_globally():
101
+
"""Disables the `torchax` backend. After calling this, PyTorch operations will revert to their default behavior.
**torchax/torchax/interop.py** (+94 −5 lines changed)

```diff
@@ -56,6 +56,29 @@ def set_one(module, prefix):
```
```python
class JittableModule(torch.nn.Module):
  """A wrapper class that makes a `torch.nn.Module` compatible with `jax.jit`.
  It separates the model's parameters and buffers, allowing them to be passed
  as arguments to a functional version of the model.

  **Arguments:**

  * `m` (`torch.nn.Module`): The PyTorch model to wrap.
  * `extra_jit_args` (`dict`, optional): A dictionary of extra arguments to
    pass to `jax.jit`.
  * `dedup_parameters` (`bool`, optional): If `True`, deduplicates parameters
    that are shared within the model.
  """
```
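A short usage sketch; only the constructor arguments above come from the diff, and the assumption here is that the wrapper is called like the module it wraps:

```python
import torch
import torchax
from torchax.interop import JittableModule

torchax.enable_globally()

model = torch.nn.Linear(8, 4)
jittable = JittableModule(model)   # parameters/buffers are split out so the
                                   # functional forward can go through jax.jit

x = torch.randn(2, 8)
y = jittable(x)                    # assumed: forward runs via the jitted call
```

Per the argument list, `extra_jit_args` would forward options such as `static_argnums` straight to `jax.jit`.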
"""Given a JAX function, returns a PyTorch autograd function implemented with `jax.vjp(fn)`.
256
+
"""Given a JAX function, returns a PyTorch `autograd` function that is implemented with `jax.vjp`. This allows you to define custom gradients for your PyTorch operations using JAX.
257
+
258
+
**Arguments:**
259
+
260
+
* `fn`: The JAX function for which to create a PyTorch `autograd` function.
261
+
* `call_jax` (optional): The function to use for calling JAX functions from PyTorch.
234
262
235
-
It wraps `fn` with `jax.vjp` to compute both the output and residuals (intermediate
236
-
activations). The wrapped function is then run via `call_jax` and integrated into
237
-
the PyTorch autograd framework by saving the residuals into the context object.
238
-
"""
263
+
**Returns:**
264
+
265
+
A PyTorch function with custom gradients defined by the JAX function.
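A sketch of the intended use. The helper's name is not visible in this excerpt; `j2t_autograd` is assumed here, as is the convention that the returned callable accepts torch tensors:

```python
import torch
import torchax
from torchax.interop import j2t_autograd  # assumed name, not shown in the diff

torchax.enable_globally()

# A pure JAX function; jax.vjp on it will provide the torch backward pass.
def jax_scaled_square(x):
  return 0.5 * (x * x)

scaled_square = j2t_autograd(jax_scaled_square)

x = torch.randn(3, requires_grad=True)
loss = scaled_square(x).sum()
loss.backward()   # d(0.5*x^2)/dx == x, computed through jax.vjp
```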
"""Applies `jax.checkpoint` to a PyTorch function. This is useful for reducing memory usage during training by recomputing intermediate activations during the backward pass instead of storing them.
434
+
435
+
**Arguments:**
436
+
437
+
* `torch_function`: The PyTorch function to checkpoint.
438
+
* `kwargs` (`dict`, optional): A dictionary of keyword arguments to pass to `jax.checkpoint`.
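The diff shows the docstring but not the helper's name, so rather than guess it, the sketch below reproduces the described composition directly with `jax.checkpoint`, assuming torchax's `jax_view`/`torch_view` converters between torch-level and JAX-level callables:

```python
import jax
import torch
from torchax import interop

def torch_block(x):
  # an activation-heavy block worth recomputing in the backward pass
  return torch.sin(x) * torch.cos(x)

# torch-callable -> jax-callable, rematerialized, then back to torch-callable.
# Keyword arguments such as policy= would correspond to `kwargs` above.
checkpointed = interop.torch_view(jax.checkpoint(interop.jax_view(torch_block)))
```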