Skip to content

Conversation

ysiraichi
Copy link
Collaborator

@ysiraichi ysiraichi commented Oct 2, 2025

This PR updates the following pins:

Key Changes:

  • @python was replaced by @rules_python at BUILD file (ref: jax-ml/jax#31709)
  • TF_ATTRIBUTE_NORETURN was removed in favor of abseil (ref: openxla/xla#31699)
  • Replaced include of xla/pjrt/tfrt_cpu_pjrt_client.h file by xla/pjrt/cpu/cpu_client.h in pjrt_registry.cpp (openxla/xla#30936)
  • Moved the old xla/tsl/platform/default/logging.* to torch_xla/csrc/runtime/tsl_platform_logging.*
    • They were removed in openxla/xla#29477
    • Copied them here, temporarily. They should be removed once we update our error throwing macros.
    • Commented out a few macro definitions, avoiding macro re-definitions

Update (Oct 3):

  • Add an OpenXLA patch for fixing static_assert(false) for GCC < 13 (ref)
  • Removed the flax pin, since it does not overwrite jax anymore
  • Removed TPU* prefix of jax.experimental.pallas.tpu components (ref: jax-ml/jax#29115)

#
# Newer `flax` versions might pull newer `jax` versions, which might be incompatible
# with the current version of PyTorch/XLA.
pip install flax==0.11.2
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks like flax is still required

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I am going to revert the last two changes.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants