diff --git a/torchax/torchax/__init__.py b/torchax/torchax/__init__.py
index fe4c1c8ff04..2a7bf5f148e 100644
--- a/torchax/torchax/__init__.py
+++ b/torchax/torchax/__init__.py
@@ -15,6 +15,8 @@
     'default_env',
     'extract_jax',
     'enable_globally',
+    'enable_performance_mode',
+    'enable_accuracy_mode',
 ]
 
 from jax._src import xla_bridge
@@ -40,7 +42,30 @@ def default_env():
 
 
 def extract_jax(mod: torch.nn.Module, env=None):
-  """Returns a pytree of jax.ndarray and a jax callable."""
+  """Extracts a `torch.nn.Module` into a (pytree, function) pair.
+
+  **Arguments:**
+
+  * `mod` (`torch.nn.Module`): The PyTorch model to extract the state from.
+  * `env` (optional): The `torchax` environment to use. If not provided, the default environment is used.
+
+  **Returns:**
+
+  A tuple containing:
+
+  * A `pytree` of `jax.ndarray` representing the model's state (parameters and buffers).
+  * A JAX-callable function that executes the model's forward pass.
+
+  **Usage:**
+
+  ```python
+  import torch
+  import torchax
+
+  model = torch.nn.Linear(10, 20)
+  states, jax_func = torchax.extract_jax(model)
+  ```
+  """
   if env is None:
     env = default_env()
   states = dict(mod.named_buffers())
@@ -60,11 +85,36 @@ def jax_func(states, args, kwargs=None):
 
 
 def enable_globally():
+  """Enables `torchax` globally.
+
+  This intercepts PyTorch operations and routes them to
+  the JAX backend. It is the primary entry point for using `torchax`.
+
+  **Usage:**
+
+  ```python
+  import torchax
+
+  torchax.enable_globally()
+  ```
+  """
   env = default_env().enable_torch_modes()
   return env
 
 
 def disable_globally():
+  """Disables the `torchax` backend.
+
+  After calling this, PyTorch operations will revert to their default behavior.
+
+  **Usage:**
+
+  ```python
+  import torchax
+
+  torchax.disable_globally()
+  ```
+  """
   global env
   default_env().disable_torch_modes()
 
@@ -110,6 +160,40 @@ class CompileOptions:
 
 
 def compile(fn, options: Optional[CompileOptions] = None):
+  """Compiles a function or `torch.nn.Module` for optimized execution with JAX.
+
+  **Arguments:**
+
+  * `fn`: The function or `torch.nn.Module` to compile.
+  * `options` (`CompileOptions`, optional): A `CompileOptions` object to configure the compilation process.
+
+  **`CompileOptions`:**
+
+  * `methods_to_compile` (`List[str]`, default=`['forward']`): A list of methods to compile when `fn` is a `torch.nn.Module`.
+  * `jax_jit_kwargs` (`Dict[str, Any]`, default=`{}`): A dictionary of keyword arguments to pass to `jax.jit`.
+  * `mode` (`str`, default=`'jax'`): The compilation mode. Currently, only `'jax'` is supported.
+
+  **Returns:**
+
+  A compiled version of the input function or module.
+
+  **Usage:**
+
+  ```python
+  import torch
+  import torchax
+
+  model = torch.nn.Linear(10, 20)
+  compiled_model = torchax.compile(model)
+
+  # With options
+  options = torchax.CompileOptions(
+      methods_to_compile=['forward', 'encode'],
+      jax_jit_kwargs={'static_argnums': (0,)}
+  )
+  compiled_model = torchax.compile(model, options)
+  ```
+  """
   options = options or CompileOptions()
   if options.mode == 'jax':
     from torchax import interop
diff --git a/torchax/torchax/interop.py b/torchax/torchax/interop.py
index a87efe9dfe7..d746e9d03ba 100644
--- a/torchax/torchax/interop.py
+++ b/torchax/torchax/interop.py
@@ -56,6 +56,29 @@ def set_one(module, prefix):
 
 
 class JittableModule(torch.nn.Module):
+  """A wrapper class that makes a `torch.nn.Module` compatible with `jax.jit`. It separates the model's parameters and buffers, allowing them to be passed as arguments to a functional version of the model.
+
+  **Arguments:**
+
+  * `m` (`torch.nn.Module`): The PyTorch model to wrap.
+  * `extra_jit_args` (`dict`, optional): A dictionary of extra arguments to pass to `jax.jit`.
+  * `dedup_parameters` (`bool`, optional): If `True`, deduplicates parameters that are shared within the model.
+
+  **Usage:**
+
+  ```python
+  import torch
+  import torchax
+  from torchax.interop import JittableModule
+
+  model = torch.nn.Linear(10, 20)
+  jittable_model = JittableModule(model)
+
+  # The first call will compile the model
+  inputs = torch.randn(5, 10, device='jax')
+  outputs = jittable_model(inputs)
+  ```
+  """
 
   def __init__(self,
                m: torch.nn.Module,
@@ -230,12 +253,17 @@ def call_torch(torch_func: TorchCallable, *args: JaxValue,
 
 
 def j2t_autograd(fn, call_jax=call_jax):
-  """Given a JAX function, returns a PyTorch autograd function implemented with `jax.vjp(fn)`.
+  """Given a JAX function, returns a PyTorch `autograd` function that is implemented with `jax.vjp`. This allows you to define custom gradients for your PyTorch operations using JAX.
+
+  **Arguments:**
+
+  * `fn`: The JAX function for which to create a PyTorch `autograd` function.
+  * `call_jax` (optional): The function to use for calling JAX functions from PyTorch.
 
-  It wraps `fn` with `jax.vjp` to compute both the output and residuals (intermediate
-  activations). The wrapped function is then run via `call_jax` and integrated into
-  the PyTorch autograd framework by saving the residuals into the context object.
-  """
+  **Returns:**
+
+  A PyTorch function with custom gradients defined by the JAX function.
+  """
 
   @wraps(fn)
   def inner(*args, **kwargs):
@@ -333,11 +361,50 @@ def wrap_jax_jit(torch_function, jax_jit_func=jax.jit, kwargs_for_jax=None):
 
 
 def jax_jit(torch_function, kwargs_for_jax_jit=None, fix_for_buffer_donation=False):
+  """A decorator that applies `jax.jit` to a PyTorch function.
+
+  **Arguments:**
+
+  * `torch_function`: The PyTorch function to be JIT-compiled.
+  * `kwargs_for_jax_jit` (`dict`, optional): A dictionary of keyword arguments to pass to `jax.jit`.
+  * `fix_for_buffer_donation` (`bool`, optional): A flag to enable a workaround for buffer donation issues.
+
+  **Returns:**
+
+  A JIT-compiled version of the PyTorch function.
+
+  **Usage:**
+
+  ```python
+  import torch
+  import torchax
+  from torchax.interop import jax_jit
+
+  @jax_jit
+  def my_function(x, y):
+    return torch.sin(x) + torch.cos(y)
+
+  x = torch.randn(5, 10, device='jax')
+  y = torch.randn(5, 10, device='jax')
+  result = my_function(x, y)
+  ```
+  """
   return wrap_jax_jit(
       torch_function, jax_jit_func=jax.jit, kwargs_for_jax=kwargs_for_jax_jit)
 
 
 def jax_shard_map(torch_function, kwargs_for_jax_shard_map=None):
+  """Applies `jax.experimental.shard_map` to a PyTorch function, allowing for data parallelism across multiple devices.
+
+  **Arguments:**
+
+  * `torch_function`: The PyTorch function to be sharded.
+  * `kwargs_for_jax_shard_map` (`dict`, optional): A dictionary of keyword arguments to pass to `shard_map`.
+
+  **Returns:**
+
+  A sharded version of the PyTorch function.
+  """
   return wrap_jax_jit(
       torch_function,
       jax_jit_func=shard_map,
@@ -345,6 +412,17 @@ def jax_shard_map(torch_function, kwargs_for_jax_shard_map=None):
 
 
 def jax_value_and_grad(torch_function, kwargs_for_value_and_grad=None):
+  """Applies `jax.value_and_grad` to a PyTorch function.
+
+  **Arguments:**
+
+  * `torch_function`: The PyTorch function.
+  * `kwargs_for_value_and_grad` (`dict`, optional): A dictionary of keyword arguments to pass to `jax.value_and_grad`.
+
+  **Returns:**
+
+  A function that computes both the value and the gradient of the input `torch_function`.
+  """
   return wrap_jax_jit(
       torch_function,
       jax_jit_func=jax.value_and_grad,
@@ -352,5 +430,16 @@ def jax_value_and_grad(torch_function, kwargs_for_value_and_grad=None):
 
 
 def gradient_checkpoint(torch_function, kwargs=None):
+  """Applies `jax.checkpoint` to a PyTorch function. This is useful for reducing memory usage during training by recomputing intermediate activations during the backward pass instead of storing them.
+
+  **Arguments:**
+
+  * `torch_function`: The PyTorch function to checkpoint.
+  * `kwargs` (`dict`, optional): A dictionary of keyword arguments to pass to `jax.checkpoint`.
+
+  **Returns:**
+
+  A checkpointed version of the PyTorch function.
+  """
   return wrap_jax_jit(
       torch_function, jax_jit_func=jax.checkpoint, kwargs_for_jax=kwargs)
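
A minimal sketch (not part of the patch) of how the newly documented `extract_jax` pair composes with `jax.jit`. The calling convention used here, state pytree first and the positional inputs packed into a tuple, is an assumption based on the `jax_func(states, args, kwargs=None)` signature visible in the hunk context above.

```python
import jax
import jax.numpy as jnp
import torch
import torchax

model = torch.nn.Linear(10, 20)
states, jax_func = torchax.extract_jax(model)

# jax_func is a plain JAX callable, so standard JAX transforms apply to it.
jitted_func = jax.jit(jax_func)

# Assumed convention: (state_pytree, tuple_of_positional_inputs).
x = jnp.ones((5, 10), dtype=jnp.float32)
out = jitted_func(states, (x,))
print(out.shape)  # (5, 20) for a Linear(10, 20) forward pass
```

If the extracted callable expects a different argument packing, only the final call changes; the point of the sketch is the (pytree, function) split described in the new docstring, which is what lets `jax.jit` and the other JAX transforms apply directly.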