
Conversation

jikunshang (Collaborator) commented Aug 11, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

This PR enables torch.compile on the XPU platform. Users can enable it with the -O3 option (a usage sketch follows the limitations below).
Limitations:

  • Because XPU still uses IPEX kernels for now, custom ops are not registered. All custom ops (except attention) are disabled and fall back to the torch-native implementations; we will improve this with vllm-xpu-kernels. In the meantime, almost all custom passes are unsupported.
  • XPU does not support graph mode yet, so graph capture is bypassed, which means --enforce-eager is still required on the XPU platform.
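
As a reference, here is a minimal usage sketch of enabling -O3 compilation on XPU from the Python API, mirroring the CLI commands in the test plan. It is illustrative only, not part of this PR's diff, and assumes the LLM constructor's compilation_config and enforce_eager arguments.

# Usage sketch (assumptions noted above): -O3-level compilation with eager
# execution, since graph capture is not supported on XPU yet.
from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",
    enforce_eager=True,       # graph capture is bypassed on XPU
    compilation_config=3,     # Python-API equivalent of the -O3 CLI option
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)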

Test Plan

Add a test in CI.

Test Result

(Optional) Documentation Update


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the ci/build label Aug 11, 2025
gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request adds torch.compile support for the XPU platform. The changes correctly bypass CUDA graph capture and functionalization passes which are not yet supported on XPU. However, there is a critical issue where the XPU platform is configured to use CUDAPiecewiseBackend, which contains CUDA-specific API calls for graph capture. This will lead to runtime errors on XPU. I've provided a comment with a suggested fix.

@jikunshang jikunshang force-pushed the kunshang/t_compile_support branch from 27c1a28 to 16611f9 Compare August 15, 2025 00:50
ProExpertProg (Collaborator) left a comment

Nice & small! A few comments

Comment on lines 36 to 39
# XPU does not support auto-functionalization yet.
# Will enable this when switch to vllm-xpu-kernels.
if current_platform.is_xpu():
    continue
Collaborator

I'm confused: what does it mean that XPU does not support autofunctionalization? And if so we should disable it at the pass level, not inside the loop

jikunshang (Collaborator, Author)

My understanding is that the auto-functionalization pass is special: unlike the rule/op-based passes, it is always invoked in the current code path, so it would feel strange to disable it at the pass level.
I moved this part earlier; please take another look.
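
(For reference, a minimal sketch of the pass-level alternative being discussed; the function name and structure are illustrative, not the actual vLLM pass API.)

# Illustrative sketch only: skip the auto-functionalization fix-up entirely on
# XPU instead of checking inside the per-node loop. Names are hypothetical.
from vllm.platforms import current_platform

def defunctionalize_graph(graph):
    # XPU does not support auto-functionalization yet; enable this once
    # vllm-xpu-kernels registers the custom ops.
    if current_platform.is_xpu():
        return graph
    for node in graph.nodes:
        ...  # existing per-node de-functionalization logic
    return graph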


mergify bot commented Aug 15, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @jikunshang.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Aug 15, 2025
@jikunshang jikunshang force-pushed the kunshang/t_compile_support branch from 16611f9 to e3ebfdb Compare August 16, 2025 02:15
@mergify mergify bot removed the needs-rebase label Aug 16, 2025
@jikunshang jikunshang force-pushed the kunshang/t_compile_support branch from a6bb017 to 061c150 Compare August 19, 2025 01:58
Comment on lines 201 to 181
def get_global_graph_pool(self) -> Any:
    """
    Currently xpu does NOT support Graph model.
Collaborator

is this just saying we don't support cudagraphs on xpu?

jikunshang (Collaborator, Author)

yes, torch-xpu will add this feature in the future.

Comment on lines 33 to 34
VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -O3
Collaborator

What is --enforce-eager -O3 ?

Can we do -O3 with use_cudagraph=False ? (or whatever the new way to disable cudagraphs is?)

jikunshang (Collaborator, Author)

--enforce-eager -O3 will use the piecewise CUDA-graph compiler backend but will not capture CUDA graphs on a CUDA device (I am not certain about the current CUDA behavior, but it did work on CUDA two or three months ago).
I think in vLLM --enforce-eager is equivalent to not using CUDA graphs; we don't expose use_cudagraph as a vLLM arg.

Collaborator

There's a separate cudagraph arg. Now we should just use -O.cudagraph_mode=NONE

Collaborator

In vLLM, --enforce-eager is supposed to mean "disable compile and cudagraphs", but it is not there yet. I don't want to add things like --enforce-eager -O3 that will need to be updated later; we should use -O.cudagraph_mode=NONE instead.
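
(A rough sketch of what that would look like from the Python API; this is a hedged illustration, not taken from this PR, and it assumes CompilationConfig exposes level and cudagraph_mode fields corresponding to the -O3 and -O.cudagraph_mode=NONE CLI flags.)

# Hedged sketch: request -O3 compilation but disable cudagraphs explicitly
# instead of relying on --enforce-eager. Field names are assumptions based on
# the flags mentioned in this thread.
from vllm import LLM
from vllm.config import CompilationConfig

llm = LLM(
    model="facebook/opt-125m",
    compilation_config=CompilationConfig(level=3, cudagraph_mode="NONE"),
)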

jikunshang (Collaborator, Author)

Makes sense! Please take a look again, thanks!

@jikunshang jikunshang force-pushed the kunshang/t_compile_support branch from 113bd0c to ddd14c5 Compare August 20, 2025 00:57
chaojun-zhang (Contributor)

How about just changing the CompilationLevel from piecewise to dynamo in xpu.py? Then we wouldn't need to add additional XPU checks in these passes, since none of the passes in VllmBackend are supported on XPU.

ProExpertProg (Collaborator) left a comment

We should change the config so that if cudagraphs are accidentally enabled, we warn and disable them, unless this already happens?

"""
Currently xpu does NOT support Graph model.
"""
return None
Collaborator

This is unsafe, it'll break in an ugly way if someone enables cudagraphs. Have we tested this? At least we should raise an error here

jikunshang (Collaborator, Author)

If we raise an error here, it will break the code path :( I would prefer to log a warning here instead.
I agree it can be unsafe in the future. For now, graph_pool will not be used when cudagraph_mode is None.
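
(A minimal sketch of the warn-instead-of-raise compromise described above; illustrative only, not the exact code in the PR.)

# Sketch only: return a null graph pool with a warning rather than raising,
# since graph_pool is unused when cudagraphs are disabled on XPU.
from typing import Any
from vllm.logger import init_logger

logger = init_logger(__name__)

class XPUPlatform:
    def get_global_graph_pool(self) -> Any:
        logger.warning(
            "XPU does not support graph capture yet; returning no graph pool. "
            "Enabling cudagraphs on XPU is unsupported.")
        return None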

ProExpertProg (Collaborator) commented Aug 21, 2025

If we raise an error here, it will break the code path

Could you show me the error? Why does get_global_graph_pool get called at all?

Collaborator

Oh I see now, the graph pool handle is passed around the backend for no good reason. static_graph_wrapper_cls doesn't need it at all. Let me fix that and unblock you.

Collaborator

AI minion working in #23385

jikunshang (Collaborator, Author)

Thanks for fixing! Will rebase and fix here once merged.

jikunshang (Collaborator, Author)

We should change the config so that if cudagraphs are accidentally enabled, we warn and disable them, unless this already happens?

It will always set cudagraph_mode to None on XPU; see here.
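
(A rough sketch of that behavior; the real hook is presumably the XPU platform's check_and_update_config, and the field names here follow the thread rather than the exact merged code.)

# Sketch only: force cudagraphs off on XPU and warn if the user requested them.
from vllm.logger import init_logger

logger = init_logger(__name__)

class XPUPlatform:
    @classmethod
    def check_and_update_config(cls, vllm_config) -> None:
        compilation_config = vllm_config.compilation_config
        if compilation_config.cudagraph_mode is not None:
            logger.warning("Graph capture is not supported on XPU; "
                           "setting cudagraph_mode to None.")
            compilation_config.cudagraph_mode = None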

@jikunshang jikunshang force-pushed the kunshang/t_compile_support branch from 4b1c7ef to d36b5d4 Compare August 21, 2025 01:32
xuechendi (Contributor)

@ProExpertProg, could you help approve the PR if the latest fix has resolved your comments? Thanks so much!

ProExpertProg (Collaborator) commented Aug 25, 2025

Let's wait for #23385 and remove the get_global_graph_pool, then we can approve and merge. I'll unblock CI so we can be ready to merge.

ProExpertProg (Collaborator)

@xuechendi @jikunshang #23385 has just merged; please remove get_global_graph_pool from the XPU platform. After that we should be good to merge this PR as well!

@@ -190,7 +190,7 @@ def __init__(
     # opaque custom op. For other platforms, we directly call them
     # and let torch.compile handle them.
     self.use_direct_call = not current_platform.is_cuda_alike(
-    ) and not current_platform.is_cpu()
+    ) and not current_platform.is_cpu() and not current_platform.is_xpu()
Collaborator

Could you make this a new property on the platform interface, called opaque_attention_op?

jikunshang (Collaborator, Author)

sure, added.
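
(For illustration, a sketch of what that platform property might look like; the merged implementation may differ in naming and in whether it is a classmethod, an attribute, or a property.)

# Sketch only: a platform-level flag replacing the growing chain of
# is_cuda_alike()/is_cpu()/is_xpu() checks in the attention layer.
class Platform:
    @classmethod
    def opaque_attention_op(cls) -> bool:
        # By default, attention is called directly so torch.compile can trace
        # through it, rather than being wrapped as an opaque custom op.
        return False

class CudaPlatform(Platform):
    @classmethod
    def opaque_attention_op(cls) -> bool:
        return True

# In the attention layer __init__ this would become:
# self.use_direct_call = not current_platform.opaque_attention_op()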


mergify bot commented Aug 26, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @jikunshang.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Aug 26, 2025
@jikunshang jikunshang force-pushed the kunshang/t_compile_support branch from d36b5d4 to ca400b8 Compare August 26, 2025 03:13
@mergify mergify bot removed the needs-rebase label Aug 26, 2025
@jikunshang jikunshang requested a review from bigPYJ1151 as a code owner August 26, 2025 03:22
@mergify mergify bot added the rocm Related to AMD ROCm label Aug 26, 2025
@@ -182,3 +175,13 @@ def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
"Intel Arc A770 have bfloat16 accuracy known issue. "
"You can use float16 instead by explicitly setting the "
"`dtype` flag in CLI, for example: --dtype=half.")

def get_global_graph_pool(self) -> Any:
Collaborator

Please still remove this method

jikunshang (Collaborator, Author)

removed.

@jikunshang jikunshang force-pushed the kunshang/t_compile_support branch from b69a0f9 to 8ca13d9 Compare August 27, 2025 01:48
@ProExpertProg ProExpertProg enabled auto-merge (squash) August 27, 2025 03:16
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Aug 27, 2025
@ProExpertProg ProExpertProg merged commit fce10db into vllm-project:main Aug 27, 2025
48 checks passed
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 28, 2025
xiao-llm pushed a commit to xiao-llm/vllm that referenced this pull request Aug 28, 2025
xiao-llm pushed a commit to xiao-llm/vllm that referenced this pull request Aug 28, 2025
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Aug 28, 2025
dumb0002 pushed a commit to dumb0002/vllm that referenced this pull request Aug 28, 2025
2015aroras pushed a commit to 2015aroras/vllm that referenced this pull request Aug 29, 2025
Labels
ci/build ready ONLY add when PR is ready to merge/full CI is needed rocm Related to AMD ROCm