@@ -41,8 +41,8 @@ for more details on how to use the MLIR code generated from it.
41
41
from torch.export import export
42
42
import torchvision
43
43
import torch
44
- import torch_xla2 as tx
45
- import torch_xla2 .export
44
+ import torchax as tx
45
+ import torchax .export
46
46
import jax
47
47
import jax.numpy as jnp
48
48
@@ -111,10 +111,10 @@ import unittest
111
111
import torch
112
112
import torch.nn.functional as F
113
113
from torch.library import Library, impl, impl_abstract
114
- import torch_xla2
115
- import torch_xla2 .export
116
- from torch_xla2 .ops import jaten
117
- from torch_xla2 .ops import jlibrary
114
+ import torchax
115
+ import torchax .export
116
+ from torchax .ops import jaten
117
+ from torchax .ops import jlibrary
118
118
119
119
120
120
# Create a `mylib` library which has a basic SDPA op.
@@ -163,7 +163,7 @@ class LibraryTest(unittest.TestCase):
163
163
164
164
def setUp (self ):
165
165
torch.manual_seed(0 )
166
- torch_xla2 .default_env().config.use_torch_native_for_cpu_tensor = False
166
+ torchax .default_env().config.use_torch_native_for_cpu_tensor = False
167
167
168
168
def test_basic_sdpa_library (self ):
169
169
@@ -179,7 +179,7 @@ class LibraryTest(unittest.TestCase):
179
179
args = (arg, arg, arg, )
180
180
181
181
exported = torch.export.export(model, args = args)
182
- stablehlo = torch_xla2 .export.exported_program_to_stablehlo(exported)
182
+ stablehlo = torchax .export.exported_program_to_stablehlo(exported)
183
183
module_str = str (stablehlo.mlir_module())
184
184
185
185
# # TODO Update this machinery from producing function calls to producing
0 commit comments