86 changes: 85 additions & 1 deletion torchax/torchax/__init__.py
@@ -15,6 +15,8 @@
'default_env',
'extract_jax',
'enable_globally',
'enable_performance_mode',
'enable_accuracy_mode',
]

from jax._src import xla_bridge
@@ -40,7 +42,30 @@ def default_env():


def extract_jax(mod: torch.nn.Module, env=None):
"""Returns a pytree of jax.ndarray and a jax callable."""
"""Extracts `torch.nn.Module` into a (pytree, function) pair.

**Arguments:**

* `mod` (`torch.nn.Module`): The PyTorch model to extract the state from.
* `env` (optional): The `torchax` environment to use. If not provided, the default environment is used.

**Returns:**

A tuple containing:

* A `pytree` of `jax.ndarray` representing the model's state (parameters and buffers).
* A JAX-callable function that executes the model's forward pass.

**Usage:**

```python
import torch
import torchax

model = torch.nn.Linear(10, 20)
states, jax_func = torchax.extract_jax(model)
```
"""
if env is None:
env = default_env()
states = dict(mod.named_buffers())
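For the `extract_jax` docstring above, a minimal sketch of calling the returned JAX function; the `jnp` import and the sample input shape are illustrative assumptions, not part of this change:

```python
import jax.numpy as jnp
import torch
import torchax

model = torch.nn.Linear(10, 20)
states, jax_func = torchax.extract_jax(model)

# The returned callable takes the state pytree and a tuple of positional args.
sample_input = jnp.ones((5, 10))  # illustrative batch of 5 inputs
output = jax_func(states, (sample_input,))
```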
@@ -60,11 +85,36 @@ def jax_func(states, args, kwargs=None):


def enable_globally():
"""Enables `torchax` globally.

This intercepts PyTorch operations and routes them to
the JAX backend. It is the primary entry point for using `torchax`.

**Usage:**

```python
import torchax

torchax.enable_globally()
```
"""
env = default_env().enable_torch_modes()
return env
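To illustrate the interception described above, a small sketch (assuming a working JAX backend is available): after enabling, tensors created on the `'jax'` device are backed by JAX arrays and operations on them dispatch to JAX.

```python
import torch
import torchax

torchax.enable_globally()

# Tensors placed on the 'jax' device are executed by the JAX backend.
x = torch.randn(2, 2, device='jax')
y = x @ x  # this matmul is routed to JAX
```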


def disable_globally():
"""Disables the `torchax` backend.

After calling this, PyTorch operations will revert to their default behavior.

**Usage:**

```python
import torchax

torchax.disable_globally()
```
"""
global env
default_env().disable_torch_modes()

@@ -110,6 +160,40 @@ class CompileOptions:


def compile(fn, options: Optional[CompileOptions] = None):
"""Compiles a function or `torch.nn.Module` for optimized execution with JAX.

**Arguments:**

* `fn`: The function or `torch.nn.Module` to compile.
* `options` (`CompileOptions`, optional): A `CompileOptions` object to configure the compilation process.

**`CompileOptions`:**

* `methods_to_compile` (`List[str]`, default=`['forward']`): A list of methods to compile when `fn` is a `torch.nn.Module`.
* `jax_jit_kwargs` (`Dict[str, Any]`, default=`{}`): A dictionary of keyword arguments to pass to `jax.jit`.
* `mode` (`str`, default=`'jax'`): The compilation mode. Currently, only `'jax'` is supported.

**Returns:**

A compiled version of the input function or module.

**Usage:**

```python
import torch
import torchax

model = torch.nn.Linear(10, 20)
compiled_model = torchax.compile(model)

# With options
options = torchax.CompileOptions(
    methods_to_compile=['forward', 'encode'],
    jax_jit_kwargs={'static_argnums': (0,)}
)
compiled_model = torchax.compile(model, options)
```
"""
options = options or CompileOptions()
if options.mode == 'jax':
from torchax import interop
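As a follow-up to the usage example in the docstring above, a minimal sketch of actually running a compiled module; it assumes `enable_globally()` has been called and that the inputs live on the `'jax'` device:

```python
import torch
import torchax

torchax.enable_globally()

model = torch.nn.Linear(10, 20)
compiled_model = torchax.compile(model)

# The first call triggers JIT compilation of `forward`; later calls reuse it.
inputs = torch.randn(5, 10, device='jax')
outputs = compiled_model(inputs)
```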
99 changes: 94 additions & 5 deletions torchax/torchax/interop.py
@@ -56,6 +56,29 @@ def set_one(module, prefix):


class JittableModule(torch.nn.Module):
"""A wrapper class that makes a `torch.nn.Module` compatible with `jax.jit`. It separates the model's parameters and buffers, allowing them to be passed as arguments to a functional version of the model.

**Arguments:**

* `m` (`torch.nn.Module`): The PyTorch model to wrap.
* `extra_jit_args` (`dict`, optional): A dictionary of extra arguments to pass to `jax.jit`.
* `dedup_parameters` (`bool`, optional): If `True`, deduplicates parameters that are shared within the model.

**Usage:**

```python
import torch
import torchax
from torchax.interop import JittableModule

model = torch.nn.Linear(10, 20)
jittable_model = JittableModule(model)

# The first call will compile the model
inputs = torch.randn(5, 10, device='jax')
outputs = jittable_model(inputs)
```
"""

def __init__(self,
m: torch.nn.Module,
@@ -230,12 +253,17 @@ def call_torch(torch_func: TorchCallable, *args: JaxValue,


def j2t_autograd(fn, call_jax=call_jax):
"""Given a JAX function, returns a PyTorch autograd function implemented with `jax.vjp(fn)`.
"""Given a JAX function, returns a PyTorch `autograd` function that is implemented with `jax.vjp`. This allows you to define custom gradients for your PyTorch operations using JAX.

**Arguments:**

* `fn`: The JAX function for which to create a PyTorch `autograd` function.
* `call_jax` (optional): The function to use for calling JAX functions from PyTorch.

It wraps `fn` with `jax.vjp` to compute both the output and residuals (intermediate
activations). The wrapped function is then run via `call_jax` and integrated into
the PyTorch autograd framework by saving the residuals into the context object.
"""
**Returns:**

A PyTorch function with custom gradients defined by the JAX function.
"""

@wraps(fn)
def inner(*args, **kwargs):
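A rough usage sketch for `j2t_autograd`; the JAX function, tensor shapes, and the `enable_globally()` call are illustrative assumptions:

```python
import jax.numpy as jnp
import torch
import torchax
from torchax.interop import j2t_autograd

torchax.enable_globally()

# A JAX function whose gradient should flow through PyTorch autograd.
def jax_fn(x):
    return jnp.sum(jnp.sin(x) * x)

torch_fn = j2t_autograd(jax_fn)

x = torch.randn(4, device='jax', requires_grad=True)
loss = torch_fn(x)
loss.backward()  # the backward pass uses the residuals saved from jax.vjp
```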
@@ -333,24 +361,85 @@ def wrap_jax_jit(torch_function, jax_jit_func=jax.jit, kwargs_for_jax=None):
def jax_jit(torch_function,
kwargs_for_jax_jit=None,
fix_for_buffer_donation=False):
"""A decorator that applies `jax.jit` to a PyTorch function.

**Arguments:**

* `torch_function`: The PyTorch function to be JIT-compiled.
* `kwargs_for_jax_jit` (`dict`, optional): A dictionary of keyword arguments to pass to `jax.jit`.
* `fix_for_buffer_donation` (`bool`, optional): A flag to enable a workaround for buffer donation issues.

**Returns:**

A JIT-compiled version of the PyTorch function.

**Usage:**

```python
import torch
import torchax
from torchax.interop import jax_jit

@jax_jit
def my_function(x, y):
    return torch.sin(x) + torch.cos(y)

x = torch.randn(5, 10, device='jax')
y = torch.randn(5, 10, device='jax')
result = my_function(x, y)
```
"""
return wrap_jax_jit(
torch_function, jax_jit_func=jax.jit, kwargs_for_jax=kwargs_for_jax_jit)
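Beyond the decorator form shown in the docstring, options can be forwarded to `jax.jit` via `kwargs_for_jax_jit`; a small sketch, where the specific keyword is only an example of a valid `jax.jit` argument:

```python
import torch
from torchax.interop import jax_jit

def scale(x, factor):
    return x * factor

# Forward static_argnums to jax.jit so `factor` is treated as a static value.
scale_jit = jax_jit(scale, kwargs_for_jax_jit={'static_argnums': (1,)})
```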


def jax_shard_map(torch_function, kwargs_for_jax_shard_map=None):
"""Applies `jax.experimental.shard_map` to a PyTorch function, allowing for data parallelism across multiple devices.

**Arguments:**

* `torch_function`: The PyTorch function to be sharded.
* `kwargs_for_jax_shard_map` (`dict`, optional): A dictionary of keyword arguments to pass to `shard_map`.

**Returns:**

A sharded version of the PyTorch function.
"""
return wrap_jax_jit(
torch_function,
jax_jit_func=shard_map,
kwargs_for_jax=kwargs_for_jax_shard_map)
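A hedged sketch of how `jax_shard_map` might be used on a single host; the mesh construction and partition specs are illustrative, with `shard_map`'s usual `mesh`/`in_specs`/`out_specs` arguments passed through `kwargs_for_jax_shard_map`:

```python
import jax
import numpy as np
import torch
from jax.sharding import Mesh, PartitionSpec as P
from torchax.interop import jax_shard_map

# A 1-D mesh over all local devices, with a single 'batch' axis.
mesh = Mesh(np.array(jax.devices()), axis_names=('batch',))

def torch_fn(x):
    return torch.relu(x)

# Shard the leading (batch) dimension of the input and output across the mesh.
sharded_fn = jax_shard_map(
    torch_fn,
    kwargs_for_jax_shard_map={
        'mesh': mesh,
        'in_specs': P('batch'),
        'out_specs': P('batch'),
    })
```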


def jax_value_and_grad(torch_function, kwargs_for_value_and_grad=None):
"""Applies `jax.value_and_grad` to a PyTorch function.

**Arguments:**

* `torch_function`: The PyTorch function.
* `kwargs_for_value_and_grad` (`dict`, optional): A dictionary of keyword arguments to pass to `jax.value_and_grad`.

**Returns:**

A function that computes both the value and the gradient of the input `torch_function`.
"""
return wrap_jax_jit(
torch_function,
jax_jit_func=jax.value_and_grad,
kwargs_for_jax=kwargs_for_value_and_grad)
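A minimal sketch of computing a loss and its gradient; like `jax.value_and_grad`, differentiation is with respect to the first argument by default. The loss function and tensor shapes are illustrative assumptions:

```python
import torch
import torchax
from torchax.interop import jax_value_and_grad

torchax.enable_globally()

def loss_fn(w, x, y):
    pred = x @ w
    return ((pred - y) ** 2).mean()

loss_and_grad = jax_value_and_grad(loss_fn)

w = torch.randn(10, 1, device='jax')
x = torch.randn(8, 10, device='jax')
y = torch.randn(8, 1, device='jax')
loss, grad_w = loss_and_grad(w, x, y)  # gradient is taken with respect to w
```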


def gradient_checkpoint(torch_function, kwargs=None):
"""Applies `jax.checkpoint` to a PyTorch function. This is useful for reducing memory usage during training by recomputing intermediate activations during the backward pass instead of storing them.

**Arguments:**

* `torch_function`: The PyTorch function to checkpoint.
* `kwargs` (`dict`, optional): A dictionary of keyword arguments to pass to `jax.checkpoint`.

**Returns:**

A checkpointed version of the PyTorch function.
"""
return wrap_jax_jit(
torch_function, jax_jit_func=jax.checkpoint, kwargs_for_jax=kwargs)
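A short sketch of wrapping a sub-computation so its intermediates are recomputed rather than stored during the backward pass; the block function is an illustrative assumption:

```python
import torch
from torchax.interop import gradient_checkpoint

def block(x):
    # An arbitrary sub-computation whose intermediate activations we
    # prefer to recompute in the backward pass instead of storing.
    return torch.tanh(x @ x.transpose(-1, -2))

# jax.checkpoint recomputes block's intermediates during backprop.
checkpointed_block = gradient_checkpoint(block)
```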