Rebase to upstream #10
base: master
Conversation
…_call in the first parameter (pytorch#9451) Co-authored-by: zmelumian <[email protected]>
Co-authored-by: qihqi <[email protected]>
…ropagate status. (pytorch#9429)
…#9501)
* Refactored jax device handling
* Removed option to use CPU jax array for CPU torch tensors.
  - changing jax devices after the fact will use different APIs
…`XlaDataToTensors`. (pytorch#9431)
… to propagate status. (pytorch#9445)
Remove the explicit opmath-driven cast chain (bf16→f32→bf16, etc.) from `mul`. The op now executes in the dtype chosen by standard dtype promotion, without inserting unconditional upcast/downcast steps. The cast functionality itself is kept in place for future use.
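As a quick illustration of the behavior described above, here is a minimal sketch (my own example, not code from the commit) showing that a bf16 multiply stays in the promoted dtype rather than being routed through an f32 round trip:

```python
import torch

a = torch.randn(4, dtype=torch.bfloat16)
b = torch.randn(4, dtype=torch.bfloat16)

# Standard dtype promotion keeps bf16 * bf16 in bf16; with the opmath cast
# chain removed, no unconditional bf16 -> f32 -> bf16 round trip is inserted.
out = a * b
assert out.dtype == torch.bfloat16
```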
…n PyTorch/XLA (#1)

Adds an environment variable CONVERT_SHLO_TO_SHARDY that does 2 things:
- Uses V2 sharding annotations when generating the GSPMD SHLO module (i.e., a V1 mesh annotation string like devices=[2,1,4]0,1,2,3,4,5,6,7 becomes devices=[2,1,4]<=[8] in V2).
- Converts the new GSPMD module with the V2 annotations into a Shardy module.
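For reference, a minimal sketch of how one might toggle the flag, alongside the two annotation formats quoted in the commit message; the exact point at which the environment variable must be set (and "1" as the truthy value) are assumptions, not confirmed by the commit:

```python
import os

# Assumption: the flag is read at runtime initialization, so set it before
# torch_xla starts compiling anything.
os.environ["CONVERT_SHLO_TO_SHARDY"] = "1"

# V1 vs. V2 mesh annotation strings for 8 devices on a (2, 1, 4) mesh,
# as quoted in the commit message above.
v1_annotation = "devices=[2,1,4]0,1,2,3,4,5,6,7"  # explicit device ids
v2_annotation = "devices=[2,1,4]<=[8]"            # iota-based V2 form
```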
…chip training (#2)

* Add V2 sharding support and improve partition spec handling for multi-chip training

These changes are required to support multi-chip training for real models on the torch-xla side.
- Added V2 OpSharding support in XlaShardingSpec, which is used internally by MpLoader for parallel input loading. The original implementation only supported V1 shardings.
- Fixed environment variable parsing for CONVERT_SHLO_TO_SHARDY; the previous logic treated values like "0" or "false" as truthy.
- Added logic to compute dims, reshape_dims, and transpose_perm for V2 sharding based on mesh_shape and partition_spec. The new logic now correctly handles cases that were previously unsupported (see the sketch after this commit message):
  - case 1: mesh_shape=(2,1,1,1), partition_spec=(0,None,None,None) -> dims=[2,1,1,1], reshape_dims=[2,1,1,1], transpose_perm=[0,1,2,3]
  - case 2: mesh_shape=(2,1,1,1), partition_spec=(0,) -> dims=[2], reshape_dims=[2,1,1,1], transpose_perm=[0,1,2,3]
  - case 3: mesh_shape=(2,4), partition_spec=(0,None) -> dims=[2,1,4], reshape_dims=[2,4], transpose_perm=[0,1]

* Fix formatting according to Torch-XLA style guide

Co-authored-by: Het Shah <[email protected]>
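The following is a minimal Python sketch, not the actual torch-xla implementation, of one way to derive dims, reshape_dims, and transpose_perm from mesh_shape and partition_spec; it reproduces the three cases listed in the commit message:

```python
def v2_sharding_params(mesh_shape, partition_spec):
    """Hypothetical helper: derive V2 OpSharding parameters.

    Tensor dims mapped to a mesh axis take that axis' size; mesh axes not
    used by any tensor dim are folded into one trailing replicated dim.
    """
    dims, used = [], []
    for axis in partition_spec:
        if axis is None:
            dims.append(1)               # replicated tensor dimension
        else:
            dims.append(mesh_shape[axis])
            used.append(axis)
    unused = [i for i in range(len(mesh_shape)) if i not in used]
    replicated = 1
    for i in unused:
        replicated *= mesh_shape[i]
    if replicated > 1:
        dims.append(replicated)          # trailing "last tile dim replicate"
    # Devices are reshaped to the mesh shape and permuted so the mesh axes
    # used by the tensor come first, followed by the unused axes.
    return dims, list(mesh_shape), used + unused


# Matches the three cases above:
assert v2_sharding_params((2, 1, 1, 1), (0, None, None, None)) == ([2, 1, 1, 1], [2, 1, 1, 1], [0, 1, 2, 3])
assert v2_sharding_params((2, 1, 1, 1), (0,)) == ([2], [2, 1, 1, 1], [0, 1, 2, 3])
assert v2_sharding_params((2, 4), (0, None)) == ([2, 1, 4], [2, 4], [0, 1])
```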
… PJRT backend

This change introduces the ability to pass custom compile options from Python down to the PJRT backend, allowing users to fine-tune XLA compilation behavior without modifying core code.

Key changes:
* Python API
  * Added custom_compile_options parameter to torch_xla.compile for passing compile-time options as a dict (supports bool, float, int, and str values).
  * Added torch_xla.set_custom_compile_options() utility for setting compile options globally.
  * Added internal binding _XLAC._set_custom_compile_options().
* C++ Runtime
  * Added SetCustomCompileOptions() virtual method to ComputationClient and implemented it in PjRtComputationClient.
  * PjRtComputationClient now stores custom_compile_options_ and injects them into xla::CompileOptions.env_option_overrides during compilation.
  * Options are stringified before being passed to XLA for compatibility.

Motivation: This enables advanced users to pass through backend-specific tuning flags (e.g., enabling experimental optimizations, toggling partitioning strategies) without hardcoding them, improving flexibility for research and debugging workflows.
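Based on the API described above, usage would presumably look something like the sketch below; the option keys are placeholder names, and `model` is assumed to be defined elsewhere:

```python
import torch_xla

# Per-compile options (dict values may be bool, float, int, or str).
compiled_model = torch_xla.compile(
    model,  # assumed to be an existing module/callable
    custom_compile_options={"example_backend_flag": True},  # placeholder key
)

# Or set options globally for subsequent compilations.
torch_xla.set_custom_compile_options({"example_tuning_knob": "3"})
```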
…ation (#7)

This PR adds support for all previously unsupported partition specs and fixes the visualize_tensor_sharding() function to support V2 sharding specs. See pytorch#9541 for the upstream PR discussion and additional context.

* Add some tests and reviewer suggestions. Will update V2 op sharding logic in a later commit soon.
* New implementation (WIP)
* Fix new implementation
* Fix visualize_tensor_sharding function for V2 shardings
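For context, a sketch of how visualize_tensor_sharding() is typically invoked on a sharded tensor; the import path and surrounding SPMD helper calls are my assumptions about the torch_xla API, not part of this PR:

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
# Assumed import path; it may differ across torch_xla versions.
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding

xr.use_spmd()
n = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(n), (n, 1), ("data", "model"))

t = torch.randn(8, 128).to(xm.xla_device())
xs.mark_sharding(t, mesh, ("data", "model"))
visualize_tensor_sharding(t)  # renders the per-device tiling of `t`
```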
@AleksKnezevic I tried this out, but it seems to cause random segfaults. I’ll need to dig in further to figure out the root cause.
fix for api match
Hi @hshahTT, @jazpurTT, @ddilbazTT, I’ve rebased our branch with the upstream changes and verified it on my side, but I’d appreciate it if you could double-check that everything works correctly. I’ve also built a wheel for testing, which you can find here: Please let me know if it installs and runs fine on your end. Thanks!
@sshonTT are you still seeing segfaults?
@AleksKnezevic No, I don't see it now.
That's great, what was causing them previously, any ideas?
I couldn’t root-cause it completely, but it turned out to be a system-level issue related to Torch Inductor’s mutex handling. It was failing because a mutex was already acquired by another process or context, likely left uncleared from a previous pytest run. After releasing and reassigning the same IRD machine, the issue disappeared. I also verified it on another IRD machine to confirm that it works correctly now.
awesome, thanks @sshonTT! Do we have a way of running CI with this wheel?
I’ve triggered this workflow run |
I think we have a build issue starting here. I'll find a way to get past this.
616047b to 626b736
Torch build option change to avoid build warning and error.
626b736 to 27f7792
The build succeeds after turning off warnings-as-errors, but there is an error when publishing it. I think it is related to the S3 bucket credentials. @jazpurTT, I believe you have experience with S3 buckets, so do you know what is going on, and do you have any suggestions to fix it?
checkout branch
9165c52 to b1ebc54
Rebase to upstream