Skip to content

Commit 9499e6f

Browse files
authored
s/torch_xla2/torchax (#9353)
1 parent ffce7fb commit 9499e6f

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

docs/source/features/stablehlo.md

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ for more details on how to use the MLIR code generated from it.
4141
from torch.export import export
4242
import torchvision
4343
import torch
44-
import torch_xla2 as tx
45-
import torch_xla2.export
44+
import torchax as tx
45+
import torchax.export
4646
import jax
4747
import jax.numpy as jnp
4848

@@ -111,10 +111,10 @@ import unittest
111111
import torch
112112
import torch.nn.functional as F
113113
from torch.library import Library, impl, impl_abstract
114-
import torch_xla2
115-
import torch_xla2.export
116-
from torch_xla2.ops import jaten
117-
from torch_xla2.ops import jlibrary
114+
import torchax
115+
import torchax.export
116+
from torchax.ops import jaten
117+
from torchax.ops import jlibrary
118118

119119

120120
# Create a `mylib` library which has a basic SDPA op.
@@ -163,7 +163,7 @@ class LibraryTest(unittest.TestCase):
163163

164164
def setUp(self):
165165
torch.manual_seed(0)
166-
torch_xla2.default_env().config.use_torch_native_for_cpu_tensor = False
166+
torchax.default_env().config.use_torch_native_for_cpu_tensor = False
167167

168168
def test_basic_sdpa_library(self):
169169

@@ -179,7 +179,7 @@ class LibraryTest(unittest.TestCase):
179179
args = (arg, arg, arg, )
180180

181181
exported = torch.export.export(model, args=args)
182-
stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
182+
stablehlo = torchax.export.exported_program_to_stablehlo(exported)
183183
module_str = str(stablehlo.mlir_module())
184184

185185
## TODO Update this machinery from producing function calls to producing

0 commit comments

Comments
 (0)