
Conversation


@sshonTT sshonTT commented Oct 3, 2025

Rebase to upstream

zmelumian972 and others added 30 commits July 18, 2025 12:42
…#9501)

* Refactored jax device handling
* Removed the option to use a CPU jax array for CPU torch tensors; changing jax devices after the fact uses different APIs
sshonTT and others added 10 commits October 3, 2025 15:06
Remove the explicit opmath-driven cast chain (bf16→f32→bf16, etc.) from
`mul`. The op now executes in the dtype chosen by standard dtype
promotion, without inserting unconditional upcast/downcast steps. The
cast-chain machinery itself is kept in place for future use.
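
A minimal sketch of the resulting behavior, assuming torch_xla is installed and an XLA device is available; this illustrates the dtype semantics described above rather than the implementation itself:

```python
# Hedged illustration: with the unconditional opmath cast chain removed,
# `mul` executes in the dtype chosen by standard dtype promotion.
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
a = torch.randn(4, 4, dtype=torch.bfloat16, device=device)
b = torch.randn(4, 4, dtype=torch.bfloat16, device=device)

# Previously the inputs were upcast to f32 for the multiply and the result
# downcast back to bf16; now the multiply runs directly in the promoted
# dtype (bf16 here). The output dtype is the same under both schemes.
c = a * b
assert c.dtype == torch.bfloat16
```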
…n PyTorch/XLA (#1)

Adds an environment variable CONVERT_SHLO_TO_SHARDY that does two things (a usage sketch follows the list):

- Uses V2 sharding annotations when generating the GSPMD SHLO module (i.e., in V1 a mesh annotation string like: devices=[2,1,4]0,1,2,3,4,5,6,7 becomes this in V2: devices=[2,1,4]<=[8]).
- Converts the new GSPMD module with the V2 annotations into a Shardy module.
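
A minimal usage sketch, assuming the flag is read at process start (so it is set before importing torch_xla):

```python
import os

# Enable V2 sharding annotations and GSPMD SHLO -> Shardy conversion.
os.environ["CONVERT_SHLO_TO_SHARDY"] = "1"

import torch_xla  # the GSPMD module is now emitted with V2 annotations
                  # and converted to a Shardy module at compile time
```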
…chip training (#2)

* Add V2 sharding support and improve partition spec handling for multi-chip training

These changes are required to support multi-chip training for real models on the torch-xla side.

- Added V2 OpSharding support in XlaShardingSpec, which is used internally by MpLoader for parallel input loading. The original implementation only supported V1 shardings.
- Fixed environment variable parsing for CONVERT_SHLO_TO_SHARDY - previous logic treated values like "0" or "false" as truthy.
- Added logic to compute dims, reshape_dims, and transpose_perm for V2 sharding based on mesh_shape and partition_spec.

The new logic now correctly handles cases that were previously unsupported (a sketch of the computation follows the cases below):

  case 1: mesh_shape=(2,1,1,1), partition_spec=(0,None,None,None)
          -> dims=[2,1,1,1], reshape_dims=[2,1,1,1], transpose_perm=[0,1,2,3]

  case 2: mesh_shape=(2,1,1,1), partition_spec=(0,)
          -> dims=[2], reshape_dims=[2,1,1,1], transpose_perm=[0,1,2,3]

  case 3: mesh_shape=(2,4), partition_spec=(0,None)
          -> dims=[2,1,4], reshape_dims=[2,4], transpose_perm=[0,1]
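
A minimal sketch of the computation behind the three cases above; the function name compute_v2_sharding_params is hypothetical (the real logic lives in torch-xla's sharding code), and only int-or-None partition specs are handled:

```python
from typing import Optional, Sequence, Tuple

def compute_v2_sharding_params(
    mesh_shape: Sequence[int],
    partition_spec: Sequence[Optional[int]],
) -> Tuple[list, list, list]:
    # dims: one tile count per tensor dimension, taken from the mesh axis
    # that dimension is sharded over (1 if the dimension is replicated).
    dims = [mesh_shape[ax] if ax is not None else 1 for ax in partition_spec]

    # Mesh axes not referenced by the partition spec are replicated over.
    used = [ax for ax in partition_spec if ax is not None]
    unused = [ax for ax in range(len(mesh_shape)) if ax not in used]

    # If the unused axes hold more than one device, append a trailing
    # "replicate over this many devices" tile dimension (case 3 above).
    replicated = 1
    for ax in unused:
        replicated *= mesh_shape[ax]
    if replicated > 1:
        dims.append(replicated)

    # reshape_dims is the device mesh itself; transpose_perm orders the
    # sharded mesh axes first, followed by the replicated (unused) axes.
    reshape_dims = list(mesh_shape)
    transpose_perm = used + unused
    return dims, reshape_dims, transpose_perm

# The three cases listed above:
assert compute_v2_sharding_params((2, 1, 1, 1), (0, None, None, None)) == \
    ([2, 1, 1, 1], [2, 1, 1, 1], [0, 1, 2, 3])
assert compute_v2_sharding_params((2, 1, 1, 1), (0,)) == \
    ([2], [2, 1, 1, 1], [0, 1, 2, 3])
assert compute_v2_sharding_params((2, 4), (0, None)) == \
    ([2, 1, 4], [2, 4], [0, 1])
```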

* Fix formatting according to Torch-XLA style guide

---------

Co-authored-by: Het Shah <[email protected]>
… PJRT backend

This change introduces the ability to pass custom compile options from Python down to the PJRT backend, allowing users to fine-tune XLA compilation behavior without modifying core code.
Key changes:
* Python API
    * Added custom_compile_options parameter to torch_xla.compile for passing compile-time options as a dict (supports bool, float, int, and str values).
    * Added torch_xla.set_custom_compile_options() utility for setting compile options globally.
    * Added internal binding _XLAC._set_custom_compile_options().
* C++ Runtime
    * Added SetCustomCompileOptions() virtual method to ComputationClient and implemented it in PjRtComputationClient.
    * PjRtComputationClient now stores custom_compile_options_ and injects them into xla::CompileOptions.env_option_overrides during compilation.
    * Options are stringified before being passed to XLA for compatibility.
Motivation:
This enables advanced users to pass through backend-specific tuning flags (e.g., enabling experimental optimizations, toggling partitioning strategies) without hardcoding them, improving flexibility for research and debugging workflows.
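
A hedged usage sketch based on the parameter and helper names listed above; the option keys shown are hypothetical placeholders, not real XLA flags:

```python
import torch_xla

def step(x):
    return x * 2 + 1

# Per-compile: pass backend-specific options as a dict of bool/float/int/str.
compiled_step = torch_xla.compile(
    step,
    custom_compile_options={"some_backend_flag": True},  # hypothetical key
)

# Or set options globally; they are stringified and injected into
# xla::CompileOptions.env_option_overrides at compilation time.
torch_xla.set_custom_compile_options({"another_backend_flag": "2"})  # hypothetical key
```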
…ation (#7)

This PR adds support for all previously unsupported partition specs and fixes the visualize_tensor_sharding() function to support V2 sharding specs.

See pytorch#9541 for the upstream PR discussion and additional context.

* Add some tests and reviewer suggestions. Will update the V2 op sharding logic in a later commit.

* New implementation (WIP)

* Fix new implementation

* Fix visualize_tensor_sharding function for V2 shardings
sshonTT commented Oct 6, 2025

@AleksKnezevic I tried this out, but it seems to cause random segfaults. I’ll need to dig in further to figure out the root cause.

@AleksKnezevic

Thanks @sshonTT. Then please reopen and merge #9 while we investigate.

Fix for API match
sshonTT commented Oct 6, 2025

Hi @hshahTT, @jazpurTT, @ddilbazTT,

I’ve rebased our branch with the upstream changes and verified it on my side, but I’d appreciate it if you could double-check that everything works correctly.

I’ve also built a wheel for testing, which you can find here:
wh-lb-57:/localdev/sshon/ws/pytorch/pytorch-xla/dist/torch_xla-2.9.0+git86bac8b-cp311-cp311-linux_x86_64.whl

Please let me know if it installs and runs fine on your end. Thanks!

@AleksKnezevic

@sshonTT are you still seeing segfaults?

sshonTT commented Oct 6, 2025

@AleksKnezevic No, I don't see them anymore.

@AleksKnezevic

That's great. Any ideas what was causing them previously?

sshonTT commented Oct 6, 2025

I couldn't root-cause it completely, but it turned out to be a system-level issue related to Torch Inductor's mutex handling. It was failing because a mutex had already been acquired by another process or context, likely left unreleased from a previous pytest run.

After releasing and reassigning the same IRD machine, the issue disappeared. I also verified it on another IRD machine to confirm that it works correctly now.

@AleksKnezevic

Awesome, thanks @sshonTT! Do we have a way of running CI with this wheel?

sshonTT commented Oct 6, 2025

I've triggered this workflow run to build a wheel. Once it's ready, I'll update the torch-xla version in tt-xla and test how it behaves. Other than that, I don't currently have a concrete way to verify this change.

sshonTT commented Oct 6, 2025

I think we have a build issue here. I'll find a way to get past it.

@sshonTT sshonTT force-pushed the sshon/rebase-to-upstream branch 3 times, most recently from 616047b to 626b736 Compare October 7, 2025 16:15
Change Torch build options to avoid build warnings and errors.
@sshonTT sshonTT force-pushed the sshon/rebase-to-upstream branch from 626b736 to 27f7792 Compare October 7, 2025 16:27
sshonTT commented Oct 8, 2025

The build succeeds after turning off warnings-as-errors, but there is an error when publishing it. I think it is related to the S3 bucket credentials.

@jazpurTT I believe you have experience with S3 buckets; do you know what is going on, and do you have any suggestions for fixing it?

@sshonTT sshonTT force-pushed the sshon/rebase-to-upstream branch from 9165c52 to b1ebc54 Compare October 9, 2025 20:07