Commit 2889f69

Revert 2 accidental commits that I made. (#9536)

1 parent b6a5b82 · commit 2889f69

17 files changed: +313 −770 lines changed

torchax/torchax/CONTRIBUTING.md

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+# Contributing to TorchXLA2
+
+We appreciate all contributions. If you are planning to contribute a bug fix for an open issue, please comment on the thread, and we are happy to provide guidance. You are also very welcome to pick up issues labeled "good first issue" and "help wanted".
+
+If you plan to contribute new features, utility functions, or extensions to the core, please open an issue first and discuss the feature with us. Sending a PR without discussion might result in a rejected PR, because we might be taking the core in a different direction than you are aware of.
+
+
+# Developer setup
+
+## Mac setup:
+@qihqi
+
+I am able to develop directly on a Mac (M1) laptop for the most part. The steps
+in README.md work. The condensed version, for easy copy & paste:
+
+```bash
+conda create --name <your_name> python=3.10
+conda activate <your_name>
+pip install --upgrade "jax[cpu]" torch
+pip install -r test_requirements.txt
+pip install -e .
+pytest test
+```
+
+### VSCode
+
+I use VSCode on my Mac. I loosely followed the instructions in
+https://code.visualstudio.com/docs/python/python-tutorial
+to set up a proper Python environment.
+
+The plugins I installed (a subset of the ones listed there) are:
+* VSCode's official Python plugin
+* Ruff formatter
+* Python Debugger
+
+I also changed the Python interpreter to point at the one in my conda env.
+Those are all the changes I made.

torchax/torchax/__init__.py

Lines changed: 1 addition & 78 deletions
@@ -40,30 +40,7 @@ def default_env():


 def extract_jax(mod: torch.nn.Module, env=None):
-  """Extracts the state of a `torch.nn.Module` into a JAX-compatible format.
-
-  **Arguments:**
-
-  * `mod` (`torch.nn.Module`): The PyTorch model to extract the state from.
-  * `env` (optional): The `torchax` environment to use. If not provided, the default environment is used.
-
-  **Returns:**
-
-  A tuple containing:
-
-  * A `pytree` of `jax.ndarray` representing the model's state (parameters and buffers).
-  * A JAX-callable function that executes the model's forward pass.
-
-  **Usage:**
-
-  ```python
-  import torch
-  import torchax
-
-  model = torch.nn.Linear(10, 20)
-  states, jax_func = torchax.extract_jax(model)
-  ```
-  """
+  """Returns a pytree of jax.ndarray and a jax callable."""
   if env is None:
     env = default_env()
   states = dict(mod.named_buffers())

@@ -83,31 +60,11 @@ def jax_func(states, args, kwargs=None):


 def enable_globally():
-  """Enables `torchax` globally, which intercepts PyTorch operations and routes them to the JAX backend. This is the primary entry point for using `torchax`.
-
-  **Usage:**
-
-  ```python
-  import torchax
-
-  torchax.enable_globally()
-  ```
-  """
   env = default_env().enable_torch_modes()
   return env


 def disable_globally():
-  """Disables the `torchax` backend. After calling this, PyTorch operations will revert to their default behavior.
-
-  **Usage:**
-
-  ```python
-  import torchax
-
-  torchax.disable_globally()
-  ```
-  """
   global env
   default_env().disable_torch_modes()

@@ -153,40 +110,6 @@ class CompileOptions:


 def compile(fn, options: Optional[CompileOptions] = None):
-  """Compiles a function or `torch.nn.Module` for optimized execution with JAX.
-
-  **Arguments:**
-
-  * `fn`: The function or `torch.nn.Module` to compile.
-  * `options` (`CompileOptions`, optional): A `CompileOptions` object to configure the compilation process.
-
-  **`CompileOptions`:**
-
-  * `methods_to_compile` (`List[str]`, default=`['forward']`): A list of methods to compile when `fn` is a `torch.nn.Module`.
-  * `jax_jit_kwargs` (`Dict[str, Any]`, default=`{}`): A dictionary of keyword arguments to pass to `jax.jit`.
-  * `mode` (`str`, default=`'jax'`): The compilation mode. Currently, only `'jax'` is supported.
-
-  **Returns:**
-
-  A compiled version of the input function or module.
-
-  **Usage:**
-
-  ```python
-  import torch
-  import torchax
-
-  model = torch.nn.Linear(10, 20)
-  compiled_model = torchax.compile(model)
-
-  # With options
-  options = torchax.CompileOptions(
-      methods_to_compile=['forward', 'encode'],
-      jax_jit_kwargs={'static_argnums': (0,)}
-  )
-  compiled_model = torchax.compile(model, options)
-  ```
-  """
   options = options or CompileOptions()
   if options.mode == 'jax':
     from torchax import interop
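
For context, the removed `extract_jax` docstring showed the returned pair in use. Below is a minimal sketch of how that pair composes with `jax.jit`; the input shape and the `jitted` name are illustrative, not taken from the diff:

```python
import jax
import jax.numpy as jnp
import torch
import torchax

model = torch.nn.Linear(10, 20)
# states: a pytree of jax.ndarray; jax_func: a JAX-callable forward pass.
states, jax_func = torchax.extract_jax(model)

# Per the hunk header above, jax_func takes (states, args), so it jits cleanly.
jitted = jax.jit(jax_func)
out = jitted(states, (jnp.ones((1, 10)),))
```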

torchax/torchax/amp.py

Lines changed: 0 additions & 13 deletions
@@ -57,19 +57,6 @@ def is_float(a):

 @contextlib.contextmanager
 def autocast(device, dtype=torch.bfloat16, env=None):
-  """A context manager for automatic mixed precision (AMP).
-
-  This context manager enables automatic mixed precision, which can improve
-  performance by using lower-precision data types for certain operations.
-
-  **Arguments:**
-
-  * `device`: The device to use for autocasting (e.g., "cuda", "cpu").
-  * `dtype` (`torch.dtype`, optional): The lower-precision data type to use.
-    Defaults to `torch.bfloat16`.
-  * `env` (optional): The `torchax` environment. If not provided, the default
-    environment is used.
-  """
   del device
   if env is None:
     import torchax
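
The removed `autocast` docstring described a context manager over `(device, dtype, env)`. A hedged usage sketch; the `"jax"` device string is a placeholder, since the body discards `device` via `del device`:

```python
import torch
from torchax import amp

# Run the enclosed ops under mixed precision with bfloat16 as the low dtype.
with amp.autocast("jax", dtype=torch.bfloat16):
  ...  # model code goes here
```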

torchax/torchax/config.py

Lines changed: 0 additions & 30 deletions
@@ -3,36 +3,6 @@

 @dataclasses.dataclass
 class Configuration:
-  """A dataclass for configuring the behavior of `torchax`.
-
-  **Attributes:**
-
-  * `debug_print_each_op` (`bool`): If `True`, prints each operation as it is
-    dispatched.
-  * `debug_accuracy_for_each_op` (`bool`): If `True`, checks the accuracy of
-    each operation by comparing its output with the equivalent PyTorch
-    operation on the CPU.
-  * `debug_mixed_tensor` (`bool`): If `True`, enables debugging for mixed
-    tensor operations.
-  * `debug_print_each_op_operands` (`bool`): If `True`, prints the operands of
-    each operation.
-  * `use_int32_for_index` (`bool`): If `True`, uses `int32` for indexing
-    operations.
-  * `allow_mixed_math_with_scalar_tensor` (`bool`): If `True`, allows mixed
-    math operations between `torchax.Tensor` and scalar `torch.Tensor`s.
-  * `force_materialize_views` (`bool`): If `True`, eagerly materializes `View`
-    objects into `torchax.Tensor`s.
-  * `use_dlpack_for_data_conversion` (`bool`): If `True`, uses DLPack for
-    converting between `jax.Array` and `torch.Tensor`.
-  * `use_tpu_flash_attention` (`bool`): If `True`, uses TPU-optimized flash
-    attention.
-  * `shmap_flash_attention` (`bool`): If `True`, uses `shard_map` for flash
-    attention.
-  * `treat_cuda_as_jax_device` (`bool`): If `True`, treats CUDA devices as JAX
-    devices.
-  * `internal_respect_torch_return_dtypes` (`bool`): If `True`, respects the
-    return data types of PyTorch operations.
-  """
   debug_print_each_op: bool = False
   debug_accuracy_for_each_op: bool = False
   debug_mixed_tensor: bool = False
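
The removed docstring enumerated the `Configuration` flags. Since `Configuration` is a plain dataclass, toggling flags is just construction; how the config object is attached to the environment is not shown in this diff, so treat that wiring as an assumption:

```python
from torchax.config import Configuration

# A config that logs every dispatched op together with its operands.
config = Configuration(
    debug_print_each_op=True,
    debug_print_each_op_operands=True,
)
```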

torchax/torchax/decompositions.py

Lines changed: 9 additions & 14 deletions
@@ -1,10 +1,10 @@
-"""This file contains PyTorch operator decompositions that are not available in
-the stable version of PyTorch.
+"""This file contains some decompositions that are not available in torch stable.

-The decompositions are primarily sourced from the `main` branch of the PyTorch
-repository and are included here to provide support for newer operators. This
-module can also contain decompositions of a PyTorch op in terms of other
-PyTorch ops.
+Most likely from the content of
+https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py
+at main branch HEAD that we find useful here.
+
+Can also contain decompositions of a torch op in terms of other torch ops.
 """

 import functools
@@ -104,21 +104,18 @@ def _reflection_or_replication_pad(


 def bernoulli(self, *, generator=None):
-  """Decomposition for the `bernoulli` operator."""
   return (torch.rand_like(self, dtype=torch.float32) < self).to(self.dtype)


 _try_register(aten.bernoulli.default, bernoulli)


 def rand_like(self, **kwargs):
-  """Decomposition for the `rand_like` operator."""
   dtype = kwargs.get("dtype", self.dtype)
   return torch.rand(self.shape, dtype=dtype)


 def channel_shuffle(self, groups):
-  """Decomposition for the `channel_shuffle` operator."""
   batchsize, channels, height, width = self.shape
   channels_per_group = channels // groups
   self = self.reshape(batchsize, groups, channels_per_group, height, width)
@@ -134,7 +131,6 @@ def channel_shuffle(self, groups):


 def bernoulli_float(self, p=0.5):
-  """Decomposition for the `bernoulli_` operator with a float probability."""
   return self.bernoulli_(p)


@@ -154,10 +150,9 @@ def _grid_sampler_3d(
     padding_mode: int = 0,
     align_corners: bool = False,
 ) -> Tensor:
-  """Decomposition for the `grid_sampler_3d` operator.
+  """References: https://github.com/pytorch/pytorch/blob/06a7dc21c1005750598c37f3adbc031183c74de6/torch/_decomp/decompositions.py#L4075

-  This implementation is based on the 2D version in the PyTorch repository:
-  https://github.com/pytorch/pytorch/blob/06a7dc21c1005750598c37f3adbc031183c74de6/torch/_decomp/decompositions.py#L4075
+  The above implements the 2d case.
   """
   _expand_grid = False
   torch._check(
@@ -778,4 +773,4 @@ def get_summand(ix: torch.Tensor, iy: torch.Tensor, iz: torch.Tensor,
 MUTABLE_DECOMPOSITION = [
     torch.ops.aten.bernoulli_.Tensor,
     torch.ops.aten.bernoulli_.float,
-]
+]
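
The `bernoulli` decomposition kept above is self-contained: draw uniform noise and threshold it against the input probabilities. The same computation in plain PyTorch, for illustration:

```python
import torch

probs = torch.full((3, 3), 0.5)
# Sample uniform [0, 1) noise; an entry is 1 where the noise falls below
# its probability, matching the decomposition's rand-and-compare trick.
sample = (torch.rand_like(probs, dtype=torch.float32) < probs).to(probs.dtype)
```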

torchax/torchax/device_module.py

Lines changed: 0 additions & 8 deletions
@@ -2,40 +2,32 @@


 def _is_in_bad_fork():
-  """Returns `False` as forking is not applicable in the same way as CUDA."""
   return False


 def manual_seed_all(seed):
-  """A placeholder for API compatibility; does not affect JAX's PRNG."""
   pass


 def device_count():
-  """Returns `1` as JAX manages devices as a single logical device."""
   return 1


 def get_rng_state():
-  """Returns an empty list for API compatibility."""
   return []


 def set_rng_state(new_state, device):
-  """A placeholder for API compatibility; does not affect JAX's PRNG."""
   pass


 def is_available():
-  """Returns `True` if JAX is available."""
   return True


 def current_device():
-  """Returns `0` as JAX manages devices as a single logical device."""
   return 0


 def get_amp_supported_dtype():
-  """Returns the data types supported by AMP (Automatic Mixed Precision)."""
   return [torch.float16, torch.bfloat16]