Commit b6a5b82

add more files
1 parent 2ccd5dc commit b6a5b82

File tree

15 files changed: +598 additions, -307 deletions


torchax/torchax/CONTRIBUTING.md

Lines changed: 0 additions & 38 deletions
This file was deleted.

torchax/torchax/amp.py

Lines changed: 13 additions & 0 deletions
@@ -57,6 +57,19 @@ def is_float(a):
 
 @contextlib.contextmanager
 def autocast(device, dtype=torch.bfloat16, env=None):
+  """A context manager for automatic mixed precision (AMP).
+
+  This context manager enables automatic mixed precision, which can improve
+  performance by using lower-precision data types for certain operations.
+
+  **Arguments:**
+
+  * `device`: The device to use for autocasting (e.g., "cuda", "cpu").
+  * `dtype` (`torch.dtype`, optional): The lower-precision data type to use.
+    Defaults to `torch.bfloat16`.
+  * `env` (optional): The `torchax` environment. If not provided, the default
+    environment is used.
+  """
   del device
   if env is None:
     import torchax
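
To illustrate the `autocast` API documented above, here is a minimal usage sketch. It assumes `torchax.default_env()` returns the active environment and that tensors are moved to the JAX backend with `.to('jax')`; the model setup is purely illustrative.

```python
# Minimal usage sketch for torchax.amp.autocast (assumptions noted above).
import torch
import torchax
from torchax import amp

env = torchax.default_env()

with env:
  model = torch.nn.Linear(128, 128).to('jax')  # move weights to the JAX backend
  x = torch.randn(4, 128).to('jax')
  # Ops covered by the autocast policy run in bfloat16 inside this block.
  with amp.autocast('jax', dtype=torch.bfloat16, env=env):
    y = model(x)
```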

torchax/torchax/config.py

Lines changed: 30 additions & 0 deletions
@@ -3,6 +3,36 @@
 
 @dataclasses.dataclass
 class Configuration:
+  """A dataclass for configuring the behavior of `torchax`.
+
+  **Attributes:**
+
+  * `debug_print_each_op` (`bool`): If `True`, prints each operation as it is
+    dispatched.
+  * `debug_accuracy_for_each_op` (`bool`): If `True`, checks the accuracy of
+    each operation by comparing its output with the equivalent PyTorch
+    operation on the CPU.
+  * `debug_mixed_tensor` (`bool`): If `True`, enables debugging for mixed
+    tensor operations.
+  * `debug_print_each_op_operands` (`bool`): If `True`, prints the operands of
+    each operation.
+  * `use_int32_for_index` (`bool`): If `True`, uses `int32` for indexing
+    operations.
+  * `allow_mixed_math_with_scalar_tensor` (`bool`): If `True`, allows mixed
+    math operations between `torchax.Tensor` and scalar `torch.Tensor`s.
+  * `force_materialize_views` (`bool`): If `True`, eagerly materializes `View`
+    objects into `torchax.Tensor`s.
+  * `use_dlpack_for_data_conversion` (`bool`): If `True`, uses DLPack for
+    converting between `jax.Array` and `torch.Tensor`.
+  * `use_tpu_flash_attention` (`bool`): If `True`, uses TPU-optimized flash
+    attention.
+  * `shmap_flash_attention` (`bool`): If `True`, uses `shard_map` for flash
+    attention.
+  * `treat_cuda_as_jax_device` (`bool`): If `True`, treats CUDA devices as JAX
+    devices.
+  * `internal_respect_torch_return_dtypes` (`bool`): If `True`, respects the
+    return data types of PyTorch operations.
+  """
   debug_print_each_op: bool = False
   debug_accuracy_for_each_op: bool = False
   debug_mixed_tensor: bool = False
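
As a quick illustration of how these flags might be toggled, here is a hedged sketch. It assumes the active environment exposes the `Configuration` instance as `env.config`; that attribute name is an assumption, not something stated in this diff.

```python
# Hedged sketch: flipping Configuration flags on the active environment.
# The `env.config` attribute is assumed; adjust to however your torchax
# version exposes its Configuration instance.
import torchax

env = torchax.default_env()
env.config.debug_print_each_op = True        # log every dispatched op
env.config.use_tpu_flash_attention = False   # disable the TPU flash-attention path
```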

torchax/torchax/decompositions.py

Lines changed: 14 additions & 9 deletions
@@ -1,10 +1,10 @@
-"""This file contains some decompositons that are not available in torch stable.
+"""This file contains PyTorch operator decompositions that are not available in
+the stable version of PyTorch.
 
-Most likely from Content of
-https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py
-at main branch HEAD that we find useful here.
-
-Can also contain decompositions of a torch op in terms of other torch ops.
+The decompositions are primarily sourced from the `main` branch of the PyTorch
+repository and are included here to provide support for newer operators. This
+module can also contain decompositions of a PyTorch op in terms of other
+PyTorch ops.
 """
 
 import functools
@@ -104,18 +104,21 @@ def _reflection_or_replication_pad(
 
 
 def bernoulli(self, *, generator=None):
+  """Decomposition for the `bernoulli` operator."""
   return (torch.rand_like(self, dtype=torch.float32) < self).to(self.dtype)
 
 
 _try_register(aten.bernoulli.default, bernoulli)
 
 
 def rand_like(self, **kwargs):
+  """Decomposition for the `rand_like` operator."""
   dtype = kwargs.get("dtype", self.dtype)
   return torch.rand(self.shape, dtype=dtype)
 
 
 def channel_shuffle(self, groups):
+  """Decomposition for the `channel_shuffle` operator."""
   batchsize, channels, height, width = self.shape
   channels_per_group = channels // groups
   self = self.reshape(batchsize, groups, channels_per_group, height, width)
@@ -131,6 +134,7 @@ def channel_shuffle(self, groups):
 
 
 def bernoulli_float(self, p=0.5):
+  """Decomposition for the `bernoulli_` operator with a float probability."""
   return self.bernoulli_(p)
 
 
@@ -150,9 +154,10 @@ def _grid_sampler_3d(
     padding_mode: int = 0,
     align_corners: bool = False,
 ) -> Tensor:
-  """References: https://github.com/pytorch/pytorch/blob/06a7dc21c1005750598c37f3adbc031183c74de6/torch/_decomp/decompositions.py#L4075
+  """Decomposition for the `grid_sampler_3d` operator.
 
-  The above implement the 2d case.
+  This implementation is based on the 2D version in the PyTorch repository:
+  https://github.com/pytorch/pytorch/blob/06a7dc21c1005750598c37f3adbc031183c74de6/torch/_decomp/decompositions.py#L4075
   """
   _expand_grid = False
   torch._check(
@@ -773,4 +778,4 @@ def get_summand(ix: torch.Tensor, iy: torch.Tensor, iz: torch.Tensor,
 MUTABLE_DECOMPOSITION = [
     torch.ops.aten.bernoulli_.Tensor,
     torch.ops.aten.bernoulli_.float,
-]
+]
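
The module docstring above describes decompositions of one PyTorch op in terms of others. A hypothetical example of that pattern (not part of this commit) is shown below; the `_try_register` helper referenced in the comment is the one visible in the diff.

```python
# Hypothetical decomposition, illustrating the pattern only: express an ATen
# op using other torch ops. Not part of this commit.
import torch
from torch import Tensor


def silu_decomp(self: Tensor) -> Tensor:
  """Decomposes silu as x * sigmoid(x) using existing torch ops."""
  return self * torch.sigmoid(self)


# Registration would follow the same pattern as in the diff, e.g.:
# _try_register(torch.ops.aten.silu.default, silu_decomp)
```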

torchax/torchax/device_module.py

Lines changed: 8 additions & 0 deletions
@@ -2,32 +2,40 @@
 
 
 def _is_in_bad_fork():
+  """Returns `False` as forking is not applicable in the same way as CUDA."""
   return False
 
 
 def manual_seed_all(seed):
+  """A placeholder for API compatibility; does not affect JAX's PRNG."""
   pass
 
 
 def device_count():
+  """Returns `1` as JAX manages devices as a single logical device."""
   return 1
 
 
 def get_rng_state():
+  """Returns an empty list for API compatibility."""
   return []
 
 
 def set_rng_state(new_state, device):
+  """A placeholder for API compatibility; does not affect JAX's PRNG."""
   pass
 
 
 def is_available():
+  """Returns `True` if JAX is available."""
   return True
 
 
 def current_device():
+  """Returns `0` as JAX manages devices as a single logical device."""
   return 0
 
 
 def get_amp_supported_dtype():
+  """Returns the data types supported by AMP (Automatic Mixed Precision)."""
   return [torch.float16, torch.bfloat16]
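
These functions exist so that PyTorch can treat the JAX backend like any other accelerator device module. Below is a sketch of how such a module is typically wired up; whether torchax registers it under the name 'jax' in exactly this way is an assumption, not something shown in this diff.

```python
# Sketch of registering a custom device module with PyTorch (assumed wiring;
# the actual registration in torchax may differ).
import torch
from torchax import device_module

torch.utils.rename_privateuse1_backend('jax')         # name the private backend 'jax'
torch._register_device_module('jax', device_module)   # expose the API above to torch
```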

torchax/torchax/export.py

Lines changed: 52 additions & 10 deletions
@@ -16,7 +16,13 @@
 
 
 class JaxInterpreter(torch.fx.Interpreter):
-  """Experimental."""
+  """An `fx.Interpreter` that executes a PyTorch FX graph using JAX.
+
+  This interpreter traverses an FX graph and replaces PyTorch operations with
+  their corresponding JAX implementations from the `torchax` operator registry.
+  It is a key component in the process of exporting PyTorch models to JAX and
+  StableHLO.
+  """
 
   def __init__(self, graph_module):
     super().__init__(graph_module)
@@ -74,11 +80,24 @@ def _extract_states_from_exported_program(exported_model):
 
 
 def exported_program_to_jax(exported_program, export_raw: bool = False):
-  """returns a pytree of jax arrays(state), and
+  """Converts a `torch.export.ExportedProgram` to a JAX-compatible function and state.
+
+  This function takes a PyTorch `ExportedProgram`, runs the necessary
+  decompositions, and returns a JAX-compatible function and the model's state
+  (parameters and buffers) as JAX arrays.
+
+  **Arguments:**
 
-  a callable(func) that is jax function.
+  * `exported_program` (`torch.export.ExportedProgram`): The PyTorch
+    `ExportedProgram` to convert.
+  * `export_raw` (`bool`, optional): If `True`, returns the raw states and
+    function without converting them to JAX arrays. Defaults to `False`.
 
-  func(state, input) would be how you call it.
+  **Returns:**
+
+  A tuple containing:
+  * A pytree of JAX arrays representing the model's state.
+  * A JAX-callable function that takes the state and inputs as arguments.
   """
   if torch.__version__ >= '2.2':
     # torch version 2.1 didn't expose this yet
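
A hedged usage sketch for `exported_program_to_jax`, following the docstring above. The toy model and the exact pytree structure of the inputs are assumptions made for illustration.

```python
# Hedged sketch: export a toy model and run it through exported_program_to_jax.
import jax.numpy as jnp
import torch
from torchax import export as tx_export


class SinPlusOne(torch.nn.Module):  # toy model, for illustration only

  def forward(self, x):
    return torch.sin(x) + 1.0


exported = torch.export.export(SinPlusOne(), (torch.randn(4),))
state, func = tx_export.exported_program_to_jax(exported)
# Per the docstring, the function is called as func(state, inputs); the exact
# pytree structure expected for `inputs` is assumed here.
out = func(state, (jnp.ones(4, dtype=jnp.float32),))
```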
@@ -115,8 +134,19 @@ def func(states, inputs):
 
 
 def extract_avals(exported):
-  """Return JAX Abstract Value shapes for all input parameters of the exported
-  program. This supports dynamic batch dimensions, including with constraints.
+  """Returns JAX abstract values (`ShapeDtypeStruct`) for all input parameters of the exported program.
+
+  This function supports dynamic batch dimensions, including those with
+  constraints.
+
+  **Arguments:**
+
+  * `exported` (`torch.export.ExportedProgram`): The exported PyTorch program.
+
+  **Returns:**
+
+  A list of `jax.ShapeDtypeStruct` objects representing the abstract values of
+  the input parameters.
   """
 
   def _to_aval(arg_meta, symbolic_shapes):
@@ -232,12 +262,24 @@ def _build_symbolic_shape(sym, constraint, free_symbols):
 
 
 def exported_program_to_stablehlo(exported_program):
-  """Replacement for torch_xla.stablehlo.exported_program_to_stablehlo
+  """Converts a `torch.export.ExportedProgram` to StableHLO.
+
+  This function serves as a replacement for
+  `torch_xla.stablehlo.exported_program_to_stablehlo`. It supports dynamic
+  dimension sizes and generates explicit checks for Dynamo guards in the IR
+  using `shape_assertion` custom calls.
+
+  **Arguments:**
+
+  * `exported_program` (`torch.export.ExportedProgram`): The exported PyTorch
+    program.
 
-  Convert a program exported via torch.export to StableHLO.
+  **Returns:**
 
-  This supports dynamic dimension sizes and generates explicit checks for
-  dynamo guards in the IR using shape_assertion custom_call ops.
+  A tuple containing:
+  * The model's state (weights) as a pytree of JAX arrays.
+  * A `jax.export.Exported` object containing the StableHLO representation of
+    the model.
   """
   weights, func = exported_program_to_jax(exported_program)
   jax_avals = extract_avals(exported_program)
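
A companion sketch for `exported_program_to_stablehlo`, again using a toy model. The attributes available on the returned `jax.export.Exported` object depend on your JAX version and are noted as assumptions in the comments.

```python
# Hedged sketch: export a toy model and lower it to StableHLO via torchax.
import torch
from torchax import export as tx_export


class SinPlusOne(torch.nn.Module):  # toy model, for illustration only

  def forward(self, x):
    return torch.sin(x) + 1.0


exported = torch.export.export(SinPlusOne(), (torch.randn(4),))
weights, stablehlo_exported = tx_export.exported_program_to_stablehlo(exported)
# `weights` is a pytree of JAX arrays; `stablehlo_exported` is a
# jax.export.Exported object. Depending on the JAX version, the StableHLO text
# can be inspected via something like stablehlo_exported.mlir_module().
```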

torchax/torchax/flax.py

Lines changed: 25 additions & 0 deletions
@@ -6,8 +6,32 @@
 
 
 class FlaxNNModule(torch.nn.Module):
+  """A `torch.nn.Module` that wraps a Flax module for interoperability.
+
+  This class allows you to use a Flax module within a PyTorch model. It
+  initializes the Flax module, extracts its parameters, and wraps them in a
+  `torch.nn.ParameterDict` so they can be managed by PyTorch. The `forward`
+  pass then calls the Flax module's `apply` method with the appropriate
+  parameters.
+
+  **Attributes:**
+
+  * `_params` (`torch.nn.Module`): A nested `torch.nn.Module` that holds the
+    parameters of the Flax module.
+  * `_flax_module`: The original Flax module.
+  """
 
   def __init__(self, env, flax_module, sample_args, sample_kwargs=None):
+    """Initializes the `FlaxNNModule`.
+
+    **Args:**
+
+    * `env`: The `torchax` environment.
+    * `flax_module`: The Flax module to wrap.
+    * `sample_args`: A tuple of sample arguments to initialize the Flax module.
+    * `sample_kwargs` (optional): A dictionary of sample keyword arguments to
+      initialize the Flax module.
+    """
     super().__init__()
     prng = env.prng_key
     sample_kwargs = sample_kwargs or {}
@@ -34,6 +58,7 @@ def _decode_nested_dict(self, child_module):
     return result
 
   def forward(self, *args, **kwargs):
+    """Performs the forward pass by calling the wrapped Flax module."""
     nested_dict_params = self._decode_nested_dict(self._params)
     return tx.interop.call_jax(self._flax_module.apply, nested_dict_params,
                                *args, **kwargs)
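
A hedged usage sketch of `FlaxNNModule`, assuming `flax` is installed, that the environment is obtained via `torchax.default_env()`, and that tensors are moved to the JAX backend with `.to('jax')`; argument order follows the `__init__` docstring above.

```python
# Hedged sketch: wrapping a Flax Dense layer so it can be used from PyTorch.
import flax.linen as nn
import torch
import torchax
from torchax.flax import FlaxNNModule

env = torchax.default_env()
flax_layer = nn.Dense(features=16)

with env:
  sample = torch.randn(2, 8).to('jax')          # sample input used to init Flax params
  wrapped = FlaxNNModule(env, flax_layer, (sample,))
  out = wrapped(sample)                         # forward() calls flax_layer.apply
  print(out.shape)                              # expected: (2, 16)
```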
