Commit 2ccd5dc

Add gemini edited docstring
1 parent 9995e97 commit 2ccd5dc

File tree: 2 files changed (+172 additions, -6 deletions)

  torchax/torchax/__init__.py
  torchax/torchax/interop.py

torchax/torchax/__init__.py

Lines changed: 78 additions & 1 deletion
@@ -40,7 +40,30 @@ def default_env():


 def extract_jax(mod: torch.nn.Module, env=None):
-  """Returns a pytree of jax.ndarray and a jax callable."""
+  """Extracts the state of a `torch.nn.Module` into a JAX-compatible format.
+
+  **Arguments:**
+
+  * `mod` (`torch.nn.Module`): The PyTorch model to extract the state from.
+  * `env` (optional): The `torchax` environment to use. If not provided, the default environment is used.
+
+  **Returns:**
+
+  A tuple containing:
+
+  * A `pytree` of `jax.ndarray` representing the model's state (parameters and buffers).
+  * A JAX-callable function that executes the model's forward pass.
+
+  **Usage:**
+
+  ```python
+  import torch
+  import torchax
+
+  model = torch.nn.Linear(10, 20)
+  states, jax_func = torchax.extract_jax(model)
+  ```
+  """
   if env is None:
     env = default_env()
   states = dict(mod.named_buffers())
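The docstring above shows how to obtain `states` and `jax_func`, but not how to call the result. The next hunk reveals the callable's signature, `jax_func(states, args, kwargs=None)`; based on that, a minimal sketch of invoking it from JAX (the tuple packing of positional inputs is inferred from that signature, not stated in the docstring):

```python
import jax
import jax.numpy as jnp
import torch
import torchax

model = torch.nn.Linear(10, 20)
states, jax_func = torchax.extract_jax(model)

# jax_func is a pure function of (states, args), so it composes with jax.jit
# and other JAX transforms; positional inputs are passed as a tuple.
x = jnp.ones((5, 10))
out = jax.jit(jax_func)(states, (x,))  # expected shape: (5, 20)
```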
@@ -60,11 +83,31 @@ def jax_func(states, args, kwargs=None):


 def enable_globally():
+  """Enables `torchax` globally, which intercepts PyTorch operations and routes them to the JAX backend. This is the primary entry point for using `torchax`.
+
+  **Usage:**
+
+  ```python
+  import torchax
+
+  torchax.enable_globally()
+  ```
+  """
   env = default_env().enable_torch_modes()
   return env


 def disable_globally():
+  """Disables the `torchax` backend. After calling this, PyTorch operations will revert to their default behavior.
+
+  **Usage:**
+
+  ```python
+  import torchax
+
+  torchax.disable_globally()
+  ```
+  """
   global env
   default_env().disable_torch_modes()
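Since `enable_globally` and `disable_globally` are documented separately, a minimal sketch of the enable/compute/disable cycle they imply, assuming the `'jax'` device string used elsewhere in this commit's docstrings:

```python
import torch
import torchax

torchax.enable_globally()

# While torchax is enabled, tensors created on the 'jax' device are backed by
# JAX arrays and torch operations on them dispatch to the JAX backend.
x = torch.randn(4, 4, device='jax')
y = x @ x.T

torchax.disable_globally()  # later torch ops revert to default behavior
```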

@@ -110,6 +153,40 @@ class CompileOptions:


 def compile(fn, options: Optional[CompileOptions] = None):
+  """Compiles a function or `torch.nn.Module` for optimized execution with JAX.
+
+  **Arguments:**
+
+  * `fn`: The function or `torch.nn.Module` to compile.
+  * `options` (`CompileOptions`, optional): A `CompileOptions` object to configure the compilation process.
+
+  **`CompileOptions`:**
+
+  * `methods_to_compile` (`List[str]`, default=`['forward']`): A list of methods to compile when `fn` is a `torch.nn.Module`.
+  * `jax_jit_kwargs` (`Dict[str, Any]`, default=`{}`): A dictionary of keyword arguments to pass to `jax.jit`.
+  * `mode` (`str`, default=`'jax'`): The compilation mode. Currently, only `'jax'` is supported.
+
+  **Returns:**
+
+  A compiled version of the input function or module.
+
+  **Usage:**
+
+  ```python
+  import torch
+  import torchax
+
+  model = torch.nn.Linear(10, 20)
+  compiled_model = torchax.compile(model)
+
+  # With options
+  options = torchax.CompileOptions(
+      methods_to_compile=['forward', 'encode'],
+      jax_jit_kwargs={'static_argnums': (0,)}
+  )
+  compiled_model = torchax.compile(model, options)
+  ```
+  """
   options = options or CompileOptions()
   if options.mode == 'jax':
     from torchax import interop
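The `compile` usage above stops at constructing `compiled_model`. A minimal sketch of invoking it, assuming inputs are created on the `'jax'` device as in the `JittableModule` docstring in the next file:

```python
import torch
import torchax

torchax.enable_globally()

model = torch.nn.Linear(10, 20)
compiled_model = torchax.compile(model)

# The first call with a given input shape triggers tracing and compilation;
# subsequent calls with the same shapes reuse the compiled computation.
inputs = torch.randn(5, 10, device='jax')
outputs = compiled_model(inputs)
```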

torchax/torchax/interop.py

Lines changed: 94 additions & 5 deletions
@@ -56,6 +56,29 @@ def set_one(module, prefix):


 class JittableModule(torch.nn.Module):
+  """A wrapper class that makes a `torch.nn.Module` compatible with `jax.jit`. It separates the model's parameters and buffers, allowing them to be passed as arguments to a functional version of the model.
+
+  **Arguments:**
+
+  * `m` (`torch.nn.Module`): The PyTorch model to wrap.
+  * `extra_jit_args` (`dict`, optional): A dictionary of extra arguments to pass to `jax.jit`.
+  * `dedup_parameters` (`bool`, optional): If `True`, deduplicates parameters that are shared within the model.
+
+  **Usage:**
+
+  ```python
+  import torch
+  import torchax
+  from torchax.interop import JittableModule
+
+  model = torch.nn.Linear(10, 20)
+  jittable_model = JittableModule(model)
+
+  # The first call will compile the model
+  inputs = torch.randn(5, 10, device='jax')
+  outputs = jittable_model(inputs)
+  ```
+  """

   def __init__(self,
                m: torch.nn.Module,
@@ -230,12 +253,17 @@ def call_torch(torch_func: TorchCallable, *args: JaxValue,


 def j2t_autograd(fn, call_jax=call_jax):
-  """Given a JAX function, returns a PyTorch autograd function implemented with `jax.vjp(fn)`.
+  """Given a JAX function, returns a PyTorch `autograd` function that is implemented with `jax.vjp`. This allows you to define custom gradients for your PyTorch operations using JAX.
+
+  **Arguments:**
+
+  * `fn`: The JAX function for which to create a PyTorch `autograd` function.
+  * `call_jax` (optional): The function to use for calling JAX functions from PyTorch.

-  It wraps `fn` with `jax.vjp` to compute both the output and residuals (intermediate
-  activations). The wrapped function is then run via `call_jax` and integrated into
-  the PyTorch autograd framework by saving the residuals into the context object.
-  """
+  **Returns:**
+
+  A PyTorch function with custom gradients defined by the JAX function.
+  """

   @wraps(fn)
   def inner(*args, **kwargs):
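`j2t_autograd` is the one docstring in this file without a usage snippet. A minimal sketch, assuming `torchax` is enabled, tensors live on the `'jax'` device, and the wrapped callable takes the same positional arguments as `fn` (inferred from the `inner(*args, **kwargs)` wrapper above):

```python
import jax.numpy as jnp
import torch
import torchax
from torchax.interop import j2t_autograd

torchax.enable_globally()

# Expose a JAX function to PyTorch autograd; gradients flow through jax.vjp.
torch_sin = j2t_autograd(jnp.sin)

x = torch.randn(5, device='jax', requires_grad=True)
y = torch_sin(x).sum()
y.backward()  # x.grad is populated via the JAX vjp of jnp.sin
```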
@@ -333,24 +361,85 @@ def wrap_jax_jit(torch_function, jax_jit_func=jax.jit, kwargs_for_jax=None):
 def jax_jit(torch_function,
             kwargs_for_jax_jit=None,
             fix_for_buffer_donation=False):
+  """A decorator that applies `jax.jit` to a PyTorch function.
+
+  **Arguments:**
+
+  * `torch_function`: The PyTorch function to be JIT-compiled.
+  * `kwargs_for_jax_jit` (`dict`, optional): A dictionary of keyword arguments to pass to `jax.jit`.
+  * `fix_for_buffer_donation` (`bool`, optional): A flag to enable a workaround for buffer donation issues.
+
+  **Returns:**
+
+  A JIT-compiled version of the PyTorch function.
+
+  **Usage:**
+
+  ```python
+  import torch
+  import torchax
+  from torchax.interop import jax_jit
+
+  @jax_jit
+  def my_function(x, y):
+    return torch.sin(x) + torch.cos(y)
+
+  x = torch.randn(5, 10, device='jax')
+  y = torch.randn(5, 10, device='jax')
+  result = my_function(x, y)
+  ```
+  """
   return wrap_jax_jit(
       torch_function, jax_jit_func=jax.jit, kwargs_for_jax=kwargs_for_jax_jit)


 def jax_shard_map(torch_function, kwargs_for_jax_shard_map=None):
+  """Applies `jax.experimental.shard_map` to a PyTorch function, allowing for data parallelism across multiple devices.
+
+  **Arguments:**
+
+  * `torch_function`: The PyTorch function to be sharded.
+  * `kwargs_for_jax_shard_map` (`dict`, optional): A dictionary of keyword arguments to pass to `shard_map`.
+
+  **Returns:**
+
+  A sharded version of the PyTorch function.
+  """
   return wrap_jax_jit(
       torch_function,
       jax_jit_func=shard_map,
       kwargs_for_jax=kwargs_for_jax_shard_map)


 def jax_value_and_grad(torch_function, kwargs_for_value_and_grad=None):
+  """Applies `jax.value_and_grad` to a PyTorch function.
+
+  **Arguments:**
+
+  * `torch_function`: The PyTorch function.
+  * `kwargs_for_value_and_grad` (`dict`, optional): A dictionary of keyword arguments to pass to `jax.value_and_grad`.
+
+  **Returns:**
+
+  A function that computes both the value and the gradient of the input `torch_function`.
+  """
   return wrap_jax_jit(
       torch_function,
       jax_jit_func=jax.value_and_grad,
       kwargs_for_jax=kwargs_for_value_and_grad)


 def gradient_checkpoint(torch_function, kwargs=None):
+  """Applies `jax.checkpoint` to a PyTorch function. This is useful for reducing memory usage during training by recomputing intermediate activations during the backward pass instead of storing them.
+
+  **Arguments:**
+
+  * `torch_function`: The PyTorch function to checkpoint.
+  * `kwargs` (`dict`, optional): A dictionary of keyword arguments to pass to `jax.checkpoint`.
+
+  **Returns:**
+
+  A checkpointed version of the PyTorch function.
+  """
   return wrap_jax_jit(
       torch_function, jax_jit_func=jax.checkpoint, kwargs_for_jax=kwargs)
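None of the three wrappers in the last hunk carry a usage snippet. As one hedged sketch of `jax_value_and_grad`, assuming the wrapped callable follows `jax.value_and_grad` conventions (differentiating with respect to the first argument) and that inputs live on the `'jax'` device:

```python
import torch
import torchax
from torchax.interop import jax_value_and_grad

torchax.enable_globally()

def loss_fn(weights, x, y):
  # Plain torch ops; torchax bridges them into the JAX transform.
  pred = x @ weights
  return ((pred - y) ** 2).mean()

value_and_grad_fn = jax_value_and_grad(loss_fn)

weights = torch.randn(10, 1, device='jax')
x = torch.randn(8, 10, device='jax')
y = torch.randn(8, 1, device='jax')

# Returns the scalar loss and the gradient with respect to weights.
loss, grads = value_and_grad_fn(weights, x, y)
```

`gradient_checkpoint` follows the same wrapping pattern, trading recomputation during the backward pass for lower peak memory on the wrapped function.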
