-
Notifications
You must be signed in to change notification settings - Fork 743
feat(platforms): add MetaX MACA support and vllm-upstream extra #2596
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
168ce10
eae4d72
1f34f09
2a7983e
cac28ba
989ba54
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| -r common.txt | ||
| # MetaX MACA: install matching `vllm` + `vllm-metax` wheels or from source per | ||
| # https://github.com/MetaX-MACA/vLLM-metax — not published as a single PyPI extra here. | ||
| onnxruntime>=1.23.2 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -121,12 +121,37 @@ def musa_omni_platform_plugin() -> str | None: | |
| return "vllm_omni.platforms.musa.platform.MUSAOmniPlatform" if is_musa else None | ||
|
|
||
|
|
||
def maca_omni_platform_plugin() -> str | None:
    """Check if MACA (MetaX) OmniPlatform should be activated.

    Requires vLLM-metax and a CUDA-compatible torch runtime. If multiple Omni
    platform plugins match, set ``VLLM_OMNI_TARGET_DEVICE`` (or adjust the
    environment) so only one activates.

    Returns:
        The fully-qualified class path of ``MacaOmniPlatform`` when both
        ``vllm_metax`` and a CUDA-compatible torch runtime are available,
        otherwise ``None``.
    """
    logger.debug("Checking if MACA OmniPlatform is available.")
    try:
        import torch
    except Exception as e:
        logger.debug("MACA OmniPlatform is not available because: %s", str(e))
        return None
    try:
        import vllm_metax  # noqa: F401
    except ImportError as e:
        # Log the reason for parity with the torch check above; previously this
        # branch returned silently, which made plugin-selection problems on a
        # non-MetaX host hard to diagnose.
        logger.debug("MACA OmniPlatform is not available because: %s", str(e))
        return None
    if not torch.cuda.is_available():
        logger.debug("MACA OmniPlatform is not available: CUDA runtime not available.")
        return None
    logger.debug("Confirmed MACA OmniPlatform is available.")
    return "vllm_omni.platforms.maca.platform.MacaOmniPlatform"
|
|
||
|
|
||
| builtin_omni_platform_plugins = { | ||
| "cuda": cuda_omni_platform_plugin, | ||
| "rocm": rocm_omni_platform_plugin, | ||
| "npu": npu_omni_platform_plugin, | ||
| "xpu": xpu_omni_platform_plugin, | ||
| "musa": musa_omni_platform_plugin, | ||
| "maca": maca_omni_platform_plugin, | ||
| } | ||
|
|
||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,6 @@ | ||
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""MACA (MetaX) platform package for vLLM-Omni.

Re-exports :class:`MacaOmniPlatform` as the package's public API.
"""

from vllm_omni.platforms.maca.platform import MacaOmniPlatform

__all__ = ["MacaOmniPlatform"]
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,110 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
|
|
||
| import torch | ||
| from vllm.logger import init_logger | ||
| from vllm.utils.torch_utils import cuda_device_count_stateless | ||
| from vllm_metax.platform import MacaPlatform as MacaPlatformBase | ||
|
|
||
| from vllm_omni.diffusion.attention.backends.registry import DiffusionAttentionBackendEnum | ||
| from vllm_omni.diffusion.envs import PACKAGES_CHECKER | ||
| from vllm_omni.platforms.interface import OmniPlatform, OmniPlatformEnum | ||
|
|
||
| logger = init_logger(__name__) | ||
|
|
||
|
|
||
class MacaOmniPlatform(OmniPlatform, MacaPlatformBase):
    """MetaX MACA implementation of OmniPlatform.

    Inherits MACA-specific behavior from vLLM-metax and adds Omni worker / diffusion hooks.
    MACA exposes a CUDA-compatible runtime, so the torch calls below use the
    ``cuda`` device namespace throughout.
    """

    # Identity used by Omni platform dispatch (e.g. ``current_omni_platform.is_maca()``).
    _omni_enum = OmniPlatformEnum.MACA

    @classmethod
    def get_omni_ar_worker_cls(cls) -> str:
        """Return the qualified class path of the autoregressive-stage worker."""
        return "vllm_omni.platforms.maca.worker.maca_ar_worker.MacaARWorker"

    @classmethod
    def get_omni_generation_worker_cls(cls) -> str:
        """Return the qualified class path of the non-AR generation-stage worker."""
        return "vllm_omni.platforms.maca.worker.maca_generation_worker.MacaGenerationWorker"

    @classmethod
    def get_default_stage_config_path(cls) -> str:
        """Return the default stage-config directory (relative path)."""
        return "vllm_omni/model_executor/stage_configs"

    @classmethod
    def get_diffusion_model_impl_qualname(cls, op_name: str) -> str:
        """Resolve a diffusion-model op name to an implementation class path.

        Only ``hunyuan_fused_moe`` is special-cased here; all other op names
        defer to the base platform's resolution.
        """
        if op_name == "hunyuan_fused_moe":
            return "vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe.HunyuanFusedMoEDefault"
        return super().get_diffusion_model_impl_qualname(op_name)

    @classmethod
    def get_diffusion_attn_backend_cls(
        cls,
        selected_backend: str | None,
        head_size: int,
    ) -> str:
        """Diffusion attention backend selection (CUDA-compatible path, like CudaOmniPlatform).

        Args:
            selected_backend: Explicit backend name (matched case-insensitively
                against ``DiffusionAttentionBackendEnum``) or ``None`` to
                auto-select.
            head_size: Attention head size; not consulted by the current
                selection logic, kept for interface parity.

        Returns:
            The import path of the chosen attention backend.
        """
        # FLASH_ATTN requires compute capability in [8.0, 10.0).
        compute_capability = cls.get_device_capability()
        compute_supported = False
        if compute_capability is not None:
            major, minor = compute_capability
            capability = major * 10 + minor
            compute_supported = 80 <= capability < 100

        packages_info = PACKAGES_CHECKER.get_packages_info()
        packages_available = packages_info.get("has_flash_attn", False)
        flash_attn_supported = compute_supported and packages_available

        if selected_backend is not None:
            backend_upper = selected_backend.upper()
            # Explicit FLASH_ATTN request that cannot be honored falls back to
            # TORCH_SDPA with a warning rather than failing.
            if backend_upper == "FLASH_ATTN" and not flash_attn_supported:
                if not compute_supported:
                    logger.warning(
                        "Flash Attention expects compute capability >= 8.0 and < 10.0. "
                        "Falling back to TORCH_SDPA backend."
                    )
                elif not packages_available:
                    logger.warning("Flash Attention packages not available. Falling back to TORCH_SDPA backend.")
                logger.info("Defaulting to diffusion attention backend SDPA")
                return DiffusionAttentionBackendEnum.TORCH_SDPA.get_path()
            # NOTE(review): an unrecognized backend name raises KeyError here —
            # confirm callers expect that rather than a ValueError.
            backend = DiffusionAttentionBackendEnum[backend_upper]
            logger.info("Using diffusion attention backend '%s'", backend_upper)
            return backend.get_path()

        # Auto-selection: prefer FLASH_ATTN when supported, otherwise TORCH_SDPA.
        if flash_attn_supported:
            logger.info("Defaulting to diffusion attention backend FLASH_ATTN")
            return DiffusionAttentionBackendEnum.FLASH_ATTN.get_path()

        logger.info("Defaulting to diffusion attention backend SDPA")
        return DiffusionAttentionBackendEnum.TORCH_SDPA.get_path()

    @classmethod
    def supports_torch_inductor(cls) -> bool:
        # vLLM-metax currently disables generic Triton kernel paths; keep inductor off by default.
        return False

    @classmethod
    def get_torch_device(cls, local_rank: int | None = None) -> torch.device:
        """Return the torch device for this platform (CUDA-compatible namespace).

        With ``local_rank=None`` the unindexed ``cuda`` device is returned.
        """
        if local_rank is None:
            return torch.device("cuda")
        return torch.device("cuda", local_rank)

    @classmethod
    def get_device_count(cls) -> int:
        """Return the number of visible devices (fork-safe stateless count)."""
        return cuda_device_count_stateless()

    @classmethod
    def get_device_version(cls) -> str | None:
        # Reports the CUDA-compat runtime version string; None when torch was
        # built without CUDA support.
        return torch.version.cuda

    @classmethod
    def synchronize(cls) -> None:
        """Block until all queued work on the current device has completed."""
        torch.cuda.synchronize()

    @classmethod
    def get_free_memory(cls, device: torch.device | None = None) -> int:
        """Return free device memory in bytes (first element of mem_get_info)."""
        free, _ = torch.cuda.mem_get_info(device)
        return free
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""MACA (MetaX) worker package for vLLM-Omni.

Re-exports the AR and generation worker classes as the package's public API.
"""

from vllm_omni.platforms.maca.worker.maca_ar_worker import MacaARWorker
from vllm_omni.platforms.maca.worker.maca_generation_worker import MacaGenerationWorker

__all__ = ["MacaARWorker", "MacaGenerationWorker"]
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,14 @@ | ||
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""MACA (MetaX) AR worker for vLLM-Omni.

MetaX uses the CUDA-compatible runtime (``device_type="cuda"`` in vLLM-metax); this
worker reuses the standard Omni GPU AR worker implementation.
"""

from vllm_omni.worker.gpu_ar_worker import GPUARWorker


class MacaARWorker(GPUARWorker):
    """Autoregressive omni stages on MetaX MACA via vLLM-metax.

    Intentionally adds no overrides: the GPU AR worker's CUDA code path runs
    unchanged on MACA's CUDA-compatible runtime.
    """
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,13 @@ | ||
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""MACA (MetaX) generation worker for vLLM-Omni.

Non-AR stages use the same CUDA-compatible device path as ``GPUGenerationWorker``.
"""

from vllm_omni.worker.gpu_generation_worker import GPUGenerationWorker


class MacaGenerationWorker(GPUGenerationWorker):
    """Non-autoregressive generation stages on MetaX MACA via vLLM-metax.

    Intentionally adds no overrides: the GPU generation worker's CUDA code path
    runs unchanged on MACA's CUDA-compatible runtime.
    """
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -148,6 +148,8 @@ def get_quant_method( | |||||
| if not self.is_checkpoint_int8_serialized: | ||||||
| if current_omni_platform.is_cuda(): | ||||||
| online_method = Int8OnlineLinearMethod(self) | ||||||
| elif current_omni_platform.is_maca(): | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
Can you just |
||||||
| online_method = Int8OnlineLinearMethod(self) | ||||||
| elif current_omni_platform.is_npu(): | ||||||
| online_method = NPUInt8OnlineLinearMethod(self) | ||||||
| else: | ||||||
|
|
@@ -156,6 +158,8 @@ def get_quant_method( | |||||
| else: | ||||||
| if current_omni_platform.is_cuda(): | ||||||
| offline_method = Int8LinearMethod(self) | ||||||
| elif current_omni_platform.is_maca(): | ||||||
| offline_method = Int8LinearMethod(self) | ||||||
| elif current_omni_platform.is_npu(): | ||||||
| offline_method = NPUInt8LinearMethod(self) | ||||||
| else: | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The `vllm-upstream` extra looks like it was dropped in a later commit — please remove this line from the docstring.