Commit 11590c1

ysiraichi and bhavya01 authored
Update PyTorch and XLA pin. (#9668)
This PR updates the following pins:

- PyTorch: pytorch/pytorch@928ac57 to pytorch/pytorch@21fec65 (v2.9.0-rc5)
- OpenXLA: openxla/xla@92f7b59 to openxla/xla@9a9aa0e
- `libtpu`: 0.0.21 to 0.0.24
- JAX (and `jaxlib`): 0.7.1 to 0.8.0

**Key Changes:**

- `@python` was replaced by `@rules_python` in the `BUILD` file (ref: [jax-ml/jax#31709](jax-ml/jax#31709))
- `TF_ATTRIBUTE_NORETURN` was removed in favor of Abseil (ref: [openxla/xla#31699](openxla/xla#31699))
- Replaced the include of `xla/pjrt/tfrt_cpu_pjrt_client.h` with `xla/pjrt/cpu/cpu_client.h` in `pjrt_registry.cpp` ([openxla/xla#30936](openxla/xla#30936))
- Moved the old `xla/tsl/platform/default/logging.*` to `torch_xla/csrc/runtime/tsl_platform_logging.*`
  - They were removed in [openxla/xla#29477](openxla/xla#29477)
  - Copied them here temporarily; they should be removed once we update our error-throwing macros
  - Commented out a few macro definitions to avoid macro re-definitions

**Update (Oct 3):**

- Added an OpenXLA patch that fixes `static_assert(false)` for GCC < 13 ([ref](https://gcc.gnu.org/git/?p=gcc.git;a=commit;h=9944ca17c0766623bce260684edc614def7ea761))
- Removed the `flax` pin, since it no longer overwrites `jax`
- Removed the `TPU*` prefix from `jax.experimental.pallas.tpu` components (ref: [jax-ml/jax#29115](jax-ml/jax#29115))

---------

Co-authored-by: Bhavya Bahl <[email protected]>
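As background for the `static_assert(false)` patch mentioned above, here is a minimal standalone sketch. It is not code from this PR; `AsInt` and the types are hypothetical stand-ins for the patched OpenXLA helpers. It shows why the construct breaks GCC < 13 and how the type-dependent form avoids it:

```cpp
#include <string>
#include <type_traits>

// Simplified stand-in for the patched OpenXLA helpers (e.g. AttributeMap::Get,
// GetPluginKind). GCC < 13 evaluates a non-dependent static_assert(false)
// while parsing the template, even though that branch is never instantiated,
// so the whole file fails to compile. Making the condition depend on T
// (!sizeof(T) is always false, since sizeof(T) >= 1 for complete types)
// defers the check to instantiation time, which every compiler accepts.
template <typename T>
int AsInt(const T& value) {
  if constexpr (std::is_same_v<T, int>) {
    return value;
  } else if constexpr (std::is_same_v<T, std::string>) {
    return static_cast<int>(value.size());
  } else {
    // static_assert(false, "Unsupported type");   // rejected by GCC < 13
    static_assert(!sizeof(T), "Unsupported type");  // dependent: accepted
  }
}

int main() {
  // int is supported; AsInt(3.14) would fail to compile with the intended error.
  return AsInt(7) == 7 ? 0 : 1;
}
```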
1 parent 2a9138a · commit 11590c1

18 files changed (+1390, -82 lines)

.github/workflows/_tpu_ci.yml

Lines changed: 5 additions & 13 deletions
@@ -51,10 +51,14 @@ jobs:
           pip install fsspec
           pip install rich
 
+          # Test dependencies
+          pip install --upgrade protobuf
+          pip install flax
+
           # PyTorch/XLA Optional Dependencies
           # =================================
           #
-          # Install `JAX` and `libtpu` dependencies for pallas and TPU tests.
+          # Install `jax` and `libtpu` dependencies for pallas and TPU tests.
           #
           # Note that we might need to install pre-release versions of both, in
           # external artifact repositories.
@@ -70,18 +74,6 @@ jobs:
           pip install "$WHL[pallas]" --pre --index-url $INDEX --find-links $LINKS
           pip install "$WHL[tpu]" --pre --index-url $INDEX --find-links $LINKS
 
-          pip install --upgrade protobuf
-
-          # Flax Pin
-          # ========
-          #
-          # Be careful when bumping the `flax` version, since it can cause tests that
-          # depend on `jax` to start breaking.
-          #
-          # Newer `flax` versions might pull newer `jax` versions, which might be incompatible
-          # with the current version of PyTorch/XLA.
-          pip install flax==0.11.2
-
       - name: Run Tests (${{ matrix.test_script }})
         if: inputs.has_code_changes == 'true'
         env:

.torch_commit

Lines changed: 2 additions & 2 deletions
@@ -1,2 +1,2 @@
-# 2025-09-17
-928ac57c2ab03f9f79376f9995553eea2e6f4ca8
+# 2025-09-29
+21fec65781bebe867faf209f89bb687ffd236ca4

BUILD

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
-load("@python//:defs.bzl", "compile_pip_requirements")
 load("@python_version_repo//:py_version.bzl", "REQUIREMENTS")
+load("@rules_python//python:pip.bzl", "compile_pip_requirements")
 
 compile_pip_requirements(
     name = "requirements",

WORKSPACE

Lines changed: 2 additions & 1 deletion
@@ -52,7 +52,7 @@ new_local_repository(
 
 # To build PyTorch/XLA with a new revison of OpenXLA, update the xla_hash to
 # the openxla git commit hash and note the date of the commit.
-xla_hash = '92f7b5952dd585c5be17c9a5caad27407005b513' # Committed on 2025-08-15.
+xla_hash = '9a9aa0e11e4fcda8d6a9c3267dca6776ddbdb0ca' # Committed on 2025-10-01.
 
 http_archive(
     name = "xla",
@@ -63,6 +63,7 @@ http_archive(
     patch_tool = "patch",
     patches = [
         "//openxla_patches:no_fortify.diff",
+        "//openxla_patches:if_constexpr_static_assert.diff",
     ],
     strip_prefix = "xla-" + xla_hash,
     urls = [
openxla_patches/if_constexpr_static_assert.diff (new file)

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
diff --git a/xla/python/ifrt/attribute_map.h b/xla/python/ifrt/attribute_map.h
index a8c9f11c8d..e5bb70bcf8 100644
--- a/xla/python/ifrt/attribute_map.h
+++ b/xla/python/ifrt/attribute_map.h
@@ -106,7 +106,9 @@ class AttributeMap {
     } else if constexpr (std::is_same_v<T, float>) {
       return Get<T, FloatValue>(key);
     } else {
-      static_assert(false, "Unsupported type for AttributeMap::Get");
+      // Same as: static_assert(false).
+      // Make it compileable by GCC version < 13.
+      static_assert(!sizeof(T), "Unsupported type for AttributeMap::Get");
     }
   }

diff --git a/xla/stream_executor/plugin_registry.cc b/xla/stream_executor/plugin_registry.cc
index f16a4f6707..8bfd51b238 100644
--- a/xla/stream_executor/plugin_registry.cc
+++ b/xla/stream_executor/plugin_registry.cc
@@ -41,7 +41,9 @@ PluginKind GetPluginKind() {
   } else if constexpr (std::is_same_v<FactoryT, PluginRegistry::FftFactory>) {
     return PluginKind::kFft;
   } else {
-    static_assert(false, "Unsupported factory type");
+    // Same as: static_assert(false).
+    // Make it compileable by GCC version < 13.
+    static_assert(!sizeof(FactoryT), "Unsupported factory type");
   }
 }
 template <typename FactoryT>
@@ -53,7 +55,9 @@ absl::string_view GetPluginName() {
   } else if constexpr (std::is_same_v<FactoryT, PluginRegistry::FftFactory>) {
     return "FFT";
   } else {
-    static_assert(false, "Unsupported factory type");
+    // Same as: static_assert(false).
+    // Make it compileable by GCC version < 13.
+    static_assert(!sizeof(FactoryT), "Unsupported factory type");
   }
 }

setup.py

Lines changed: 5 additions & 5 deletions
@@ -112,12 +112,12 @@
 
 USE_NIGHTLY = True  # Whether to use nightly or stable libtpu and JAX.
 
-_libtpu_version = '0.0.21'
-_libtpu_date = '20250813'
+_libtpu_version = '0.0.24'
+_libtpu_date = '20250929'
 
-_jax_version = '0.7.1'
-_jaxlib_version = '0.7.1'
-_jax_date = '20250813'  # Date for jax and jaxlib.
+_jax_version = '0.8.0'
+_jaxlib_version = '0.8.0'
+_jax_date = '20251001'  # Date for jax and jaxlib.
 
 if USE_NIGHTLY:
   _libtpu_version += f".dev{_libtpu_date}+nightly"

test/spmd/test_fsdp_v2.py

Lines changed: 1 addition & 1 deletion
@@ -55,7 +55,7 @@ def test_fsdp_v2_basic(self):
     # Make sure optimization barrier is applied.
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([model.fc2.weight.grad])
     self.assertIn(
-        'opt-barrier.38 = (f32[1,64]{0,1}, f32[1]{0}, f32[16,64]{1,0}) opt-barrier((f32[1,64]{0,1}, f32[1]{0}, f32[16,64]{1,0}) %tuple.37',
+        'opt-barrier.1 = (f32[1,64]{0,1}, f32[1]{0}, f32[16,64]{1,0}) opt-barrier((f32[1,64]{0,1}, f32[1]{0}, f32[16,64]{1,0}) %tuple.2',
         hlo)
 
     # Make sure the model can execute without error.

test/spmd/test_xla_sharding.py

Lines changed: 14 additions & 13 deletions
@@ -613,7 +613,7 @@ def test_inplace_add_with_sharding(self):
     self.assertEqual(sharding_spec, torch_xla._XLAC._get_xla_sharding_spec(xt))
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([xt])
     self.assertIn(
-        '%custom-call.7 = f32[2,2]{1,0} custom-call(f32[2,2]{1,0} %add.6), custom_call_target="Sharding", sharding=',
+        '%custom-call.1 = f32[2,2]{1,0} custom-call(f32[2,2]{1,0} %add.1), custom_call_target="Sharding", sharding=',
         hlo)
 
   # avoid calling xr.addressable_device_count here otherwise it will init the test
@@ -713,7 +713,8 @@ def test_xla_sharded_hlo_dump(self):
         partition_spec)
     xst2 = xst1 + 5
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([xst2.global_tensor])
-    self.assertIn('%p1.3 = f32[1,8]{1,0} parameter(1), sharding', hlo)
+    print(hlo)
+    self.assertIn('%p1.1 = f32[1,8]{1,0} parameter(1), sharding', hlo)
     if torch_xla._XLAC._xla_get_auto_sharding():
       # scalar 5 should be implicitly replicated, so the pre-optimization HLO
       # shouldn't mark it with sharding.
@@ -828,13 +829,13 @@ def test_mark_sharding_ir(self):
         (0, 1))
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([actual.global_tensor])
     self.assertIn(
-        '%custom-call.7 = f32[1,128]{1,0} custom-call(f32[1,128]{1,0} %add.6), custom_call_target="Sharding", sharding=',
+        '%custom-call.1 = f32[1,128]{1,0} custom-call(f32[1,128]{1,0} %add.1), custom_call_target="Sharding", sharding=',
         hlo)
 
     actual += 0
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([actual.global_tensor])
     self.assertIn(
-        '%add.12 = f32[1,128]{1,0} add(f32[1,128]{1,0} %custom-call.9, f32[1,128]{1,0} %broadcast.11)',
+        '%add.3 = f32[1,128]{1,0} add(f32[1,128]{1,0} %custom-call.1, f32[1,128]{1,0} %broadcast.3)',
         hlo)
 
     self.assertTrue(torch.allclose(expected, actual.cpu()))
@@ -1141,7 +1142,7 @@ def test_backward_optimization_barrier(self):
 
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([model.fc2.weight.grad])
     self.assertIn(
-        '%opt-barrier.37 = (f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) opt-barrier((f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) %tuple.36)',
+        '%opt-barrier.1 = (f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) opt-barrier((f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) %tuple.2)',
         hlo)
 
   def test_mark_shard_scalar(self):
@@ -1198,7 +1199,7 @@ def test_spmd_full_to_shard_shape(self):
 
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([xx])
     self.assertEqual(xx.shape, (8, 8 // self.n_devices))
-    self.assertIn(f'%custom-call.2 = f32[8,{8//self.n_devices}]{{1,0}}', hlo)
+    self.assertIn(f'%custom-call.1 = f32[8,{8//self.n_devices}]{{1,0}}', hlo)
     self.assertIn(
         f'custom_call_target="SPMDFullToShardShape", sharding={{manual}}', hlo)
     self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(xx), "{manual}")
@@ -1215,7 +1216,7 @@ def test_spmd_full_to_shard_shape(self):
 
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([xx])
     self.assertEqual(xx.shape, (8, 4))
-    self.assertIn(f'%custom-call.2 = f32[8,4]{{1,0}}', hlo)
+    self.assertIn(f'%custom-call.1 = f32[8,4]{{1,0}}', hlo)
     self.assertIn(
         f'custom_call_target="SPMDFullToShardShape", sharding={{manual}}', hlo)
     self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(xx), "{manual}")
@@ -1246,7 +1247,7 @@ def test_spmd_shard_to_full_shape(self):
 
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([xx])
     self.assertEqual(xx.shape, x.shape)
-    self.assertIn('%custom-call.9 = f32[8,8]{1,0}', hlo)
+    self.assertIn('%custom-call.5 = f32[8,8]{1,0}', hlo)
     self.assertIn(
         'custom_call_target="SPMDShardToFullShape", sharding={replicated}', hlo)
     self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(xx), "{replicated}")
@@ -1297,7 +1298,7 @@ def test_spmd_reduce_scatter(self):
 
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([x])
     self.assertIn(
-        f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{0}}, to_apply=%AddComputation.3",
+        f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.3), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{0}}, to_apply=%AddComputation.1",
         hlo)
 
     expected_x = torch.ones(8 // self.n_devices, 8) * self.n_devices
@@ -1318,7 +1319,7 @@ def test_spmd_reduce_scatter_canonical_index(self):
 
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([x])
     self.assertIn(
-        f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{1}}, to_apply=%AddComputation.3",
+        f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.3), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{1}}, to_apply=%AddComputation.1",
         hlo)
 
     expected_x = torch.ones(8, 8 // self.n_devices) * self.n_devices
@@ -1338,7 +1339,7 @@ def test_spmd_all_reduce(self):
 
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([x])
     self.assertIn(
-        f"all-reduce(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, to_apply=%AddComputation.3",
+        f"all-reduce(f32[8,8]{{1,0}} %custom-call.3), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, to_apply=%AddComputation.1",
        hlo)
 
     expected_x = torch.ones(8, 8) * self.n_devices
@@ -1359,7 +1360,7 @@ def test_spmd_all_reduce_scale(self):
 
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([x])
     self.assertIn(
-        f"all-reduce(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, to_apply=%AddComputation.3",
+        f"all-reduce(f32[8,8]{{1,0}} %custom-call.3), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, to_apply=%AddComputation.1",
         hlo)
 
     expected_x = torch.ones(8, 8) * int(self.n_devices * scale)
@@ -1713,7 +1714,7 @@ def test_annotate_custom_sharding(self):
         f'%p0.1 = f32[2,4,64,64]{{3,2,1,0}} parameter(0), sharding={original_sharding_spec}',
         hlo)
     self.assertIn(
-        f'%custom-call.2 = f32[2,4,64,64]{{3,2,1,0}} custom-call(f32[2,4,64,64]{{3,2,1,0}} %p0.1), custom_call_target="Sharding", sharding={custom_sharding_spec}',
+        f'%custom-call.1 = f32[2,4,64,64]{{3,2,1,0}} custom-call(f32[2,4,64,64]{{3,2,1,0}} %p0.1), custom_call_target="Sharding", sharding={custom_sharding_spec}',
         hlo)
     xm.mark_step()
     # Ensure that the resulting sharding spec is preserved

torch_xla/csrc/lowering_context.h

Lines changed: 2 additions & 2 deletions
@@ -124,8 +124,8 @@ class LoweringContext : public torch::lazy::LoweringContext {
   };
 
   // Reports an XLA builder error for the given node.
-  TF_ATTRIBUTE_NORETURN void ReportBuilderError(const torch::lazy::Node& node,
-                                                absl::string_view error_msg);
+  ABSL_ATTRIBUTE_NORETURN void ReportBuilderError(const torch::lazy::Node& node,
+                                                  absl::string_view error_msg);
 
   xla::XlaBuilder builder_;
   std::unordered_map<torch::lazy::BackendData::Handle, Parameter>
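For reference, a minimal sketch of the Abseil replacement (hypothetical function names, not part of this diff, and it assumes Abseil headers are on the include path): `ABSL_ATTRIBUTE_NORETURN` from `absl/base/attributes.h` plays the same role `TF_ATTRIBUTE_NORETURN` did, telling the compiler the annotated function never returns.

```cpp
#include <cstdlib>
#include <iostream>

#include "absl/base/attributes.h"

// Hypothetical error reporter mirroring ReportBuilderError above: it logs
// and aborts, and the attribute promises the compiler it never returns.
ABSL_ATTRIBUTE_NORETURN void ReportFatal(const char* msg) {
  std::cerr << "fatal: " << msg << std::endl;
  std::abort();
}

int RequirePositive(int value) {
  if (value > 0) {
    return value;
  }
  // Because ReportFatal is marked noreturn, the compiler does not warn
  // about a missing return statement on this path.
  ReportFatal("value must be positive");
}
```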

torch_xla/csrc/runtime/BUILD

Lines changed: 19 additions & 3 deletions
@@ -381,18 +381,34 @@ cc_test(
     ],
 )
 
+cc_library(
+    name = "tsl_platform_logging",
+    srcs = ["tsl_platform_logging.cpp"],
+    hdrs = ["tsl_platform_logging.h"],
+    deps = [
+        "@xla//xla/tsl/platform:env_time",
+        "@xla//xla/tsl/platform:logging",
+        "@xla//xla/tsl/platform:macros",
+        "@xla//xla/tsl/platform:types",
+        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/base:log_severity",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/strings:string_view",
+    ],
+)
+
 cc_library(
     name = "tf_logging",
     srcs = ["tf_logging.cpp"],
     hdrs = ["tf_logging.h"],
     deps = [
+        ":tsl_platform_logging",
         "//torch_xla/csrc:status",
         "@torch//:headers",
         "@torch//:runtime_headers",
-        "@tsl//tsl/platform:stacktrace",
-        "@tsl//tsl/platform:statusor",
-        "@xla//xla/service:platform_util",
         "@com_google_absl//absl/base:log_severity",
+        "@com_google_absl//absl/log:absl_log",
     ],
 )
398414
