diff --git a/pyproject.toml b/pyproject.toml index e49aa6e325..dbcef129ea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -178,6 +178,7 @@ markers = [ "xpu: Tests that run on XPU (auto-added)", "npu: Tests that run on NPU/Ascend (auto-added)", "musa: Tests that run on MUSA/Moore Threads (auto-added)", + "maca: Tests that run on MACA/MetaX (auto-added)", # specified computation resources marks (auto-added) "H100: Tests that require H100 GPU", "L4: Tests that require L4 GPU", @@ -190,11 +191,13 @@ markers = [ "distributed_xpu: Tests that require multi cards on XPU platform", "distributed_npu: Tests that require multi cards on NPU platform", "distributed_musa: Tests that require multi cards on MUSA platform", + "distributed_maca: Tests that require multi cards on MACA/MetaX platform", "skipif_cuda: Skip if the num of CUDA cards is less than the required", "skipif_rocm: Skip if the num of ROCm cards is less than the required", "skipif_xpu: Skip if the num of XPU cards is less than the required", "skipif_npu: Skip if the num of NPU cards is less than the required", "skipif_musa: Skip if the num of MUSA cards is less than the required", + "skipif_maca: Skip if the num of MACA cards is less than the required", # more detailed markers "slow: Slow tests (may skip in quick CI)", "benchmark: Benchmark tests", diff --git a/requirements/maca.txt b/requirements/maca.txt new file mode 100644 index 0000000000..8c4ff1c500 --- /dev/null +++ b/requirements/maca.txt @@ -0,0 +1,4 @@ +-r common.txt +# MetaX MACA: install matching `vllm` + `vllm-metax` wheels or from source per +# https://github.com/MetaX-MACA/vLLM-metax — not published as a single PyPI extra here. +onnxruntime>=1.23.2 diff --git a/setup.py b/setup.py index 057212d67f..8a2a92d17b 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,9 @@ This setup.py implements platform-aware dependency routing so users can run `pip install vllm-omni` and automatically receive the correct platform-specific -dependencies (CUDA/ROCm/CPU/XPU/NPU/MUSA) without requiring extras like `[cuda]`. +dependencies (CUDA/ROCm/CPU/XPU/NPU/MUSA/MACA) without requiring extras like `[cuda]`. + +Optional PyPI vLLM pin (CUDA): ``pip install -e ".[vllm-upstream]"`` (see pyproject.toml). """ import os @@ -46,16 +48,16 @@ def detect_target_device() -> str: Priority order: 1. VLLM_OMNI_TARGET_DEVICE environment variable (highest priority) - 2. Torch backend detection (cuda, rocm, npu, xpu, musa) + 2. Torch backend detection (maca, cuda, rocm, npu, xpu, musa) 3. CPU fallback (default) Returns: - str: Device name ('cuda', 'rocm', 'npu', 'xpu', 'musa', or 'cpu') + str: Device name ('cuda', 'rocm', 'npu', 'xpu', 'musa', 'maca', or 'cpu') """ # Priority 1: Explicit override via environment variable target_device = os.environ.get("VLLM_OMNI_TARGET_DEVICE") if target_device: - valid_devices = ["cuda", "rocm", "npu", "xpu", "musa", "cpu"] + valid_devices = ["cuda", "rocm", "npu", "xpu", "musa", "maca", "cpu"] if target_device.lower() in valid_devices: print(f"Using target device from VLLM_OMNI_TARGET_DEVICE: {target_device.lower()}") return target_device.lower() @@ -68,7 +70,17 @@ def detect_target_device() -> str: try: import torch - # Check for CUDA + # MACA (MetaX): mcPyTorch may set torch.version.cuda; detect before generic CUDA. + try: + import vllm_metax # noqa: F401 + except ImportError: + pass + else: + if torch.cuda.is_available(): + print("Detected MACA (MetaX) backend from vllm-metax") + return "maca" + + # Check for CUDA (NVIDIA) if torch.version.cuda is not None: print("Detected CUDA backend from torch") return "cuda" @@ -163,6 +175,8 @@ def get_vllm_omni_version() -> str: version += f"{sep}xpu" elif device == "musa": version += f"{sep}musa" + elif device == "maca": + version += f"{sep}maca" elif device == "cpu": version += f"{sep}cpu" else: diff --git a/vllm_omni/diffusion/attention/backends/abstract.py b/vllm_omni/diffusion/attention/backends/abstract.py index 472fde422d..20c3a17eb8 100644 --- a/vllm_omni/diffusion/attention/backends/abstract.py +++ b/vllm_omni/diffusion/attention/backends/abstract.py @@ -95,6 +95,8 @@ def forward( return self.forward_hip(query, key, value, attn_metadata) elif current_omni_platform.is_cuda(): return self.forward_cuda(query, key, value, attn_metadata) + elif current_omni_platform.is_maca(): + return self.forward_maca(query, key, value, attn_metadata) elif current_omni_platform.is_npu(): return self.forward_npu(query, key, value, attn_metadata) elif current_omni_platform.is_xpu(): @@ -150,3 +152,13 @@ def forward_musa( ) -> torch.Tensor: # By default, MUSA ops are compatible with CUDA ops. return self.forward_cuda(query, key, value, attn_metadata) + + def forward_maca( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: T | None = None, + ) -> torch.Tensor: + # MetaX MACA uses the CUDA-compatible runtime (vLLM-metax). + return self.forward_cuda(query, key, value, attn_metadata) diff --git a/vllm_omni/diffusion/layers/custom_op.py b/vllm_omni/diffusion/layers/custom_op.py index 27e3bce1f2..6fa1acd6fb 100644 --- a/vllm_omni/diffusion/layers/custom_op.py +++ b/vllm_omni/diffusion/layers/custom_op.py @@ -21,6 +21,8 @@ def dispatch_forward(self) -> Callable: return self.forward_hip elif current_omni_platform.is_cuda(): return self.forward_cuda + elif current_omni_platform.is_maca(): + return self.forward_maca elif current_omni_platform.is_npu(): return self.forward_npu elif current_omni_platform.is_xpu(): @@ -57,3 +59,7 @@ def forward_hip(self, *args, **kwargs): def forward_musa(self, *args, **kwargs): # By default, we assume that MUSA ops are compatible with CUDA ops. return self.forward_cuda(*args, **kwargs) + + def forward_maca(self, *args, **kwargs): + # MetaX MACA uses the CUDA-compatible runtime (vLLM-metax). + return self.forward_cuda(*args, **kwargs) diff --git a/vllm_omni/platforms/__init__.py b/vllm_omni/platforms/__init__.py index 64a7cdb16f..ba96740ca2 100644 --- a/vllm_omni/platforms/__init__.py +++ b/vllm_omni/platforms/__init__.py @@ -121,12 +121,37 @@ def musa_omni_platform_plugin() -> str | None: return "vllm_omni.platforms.musa.platform.MUSAOmniPlatform" if is_musa else None +def maca_omni_platform_plugin() -> str | None: + """Check if MACA (MetaX) OmniPlatform should be activated. + + Requires vLLM-metax and a CUDA-compatible torch runtime. If multiple Omni + platform plugins match, set ``VLLM_OMNI_TARGET_DEVICE`` (or adjust the + environment) so only one activates. + """ + logger.debug("Checking if MACA OmniPlatform is available.") + try: + import torch + except Exception as e: + logger.debug("MACA OmniPlatform is not available because: %s", str(e)) + return None + try: + import vllm_metax # noqa: F401 + except ImportError: + return None + if not torch.cuda.is_available(): + logger.debug("MACA OmniPlatform is not available: CUDA runtime not available.") + return None + logger.debug("Confirmed MACA OmniPlatform is available.") + return "vllm_omni.platforms.maca.platform.MacaOmniPlatform" + + builtin_omni_platform_plugins = { "cuda": cuda_omni_platform_plugin, "rocm": rocm_omni_platform_plugin, "npu": npu_omni_platform_plugin, "xpu": xpu_omni_platform_plugin, "musa": musa_omni_platform_plugin, + "maca": maca_omni_platform_plugin, } diff --git a/vllm_omni/platforms/interface.py b/vllm_omni/platforms/interface.py index 8f1e66747d..e579e21720 100644 --- a/vllm_omni/platforms/interface.py +++ b/vllm_omni/platforms/interface.py @@ -20,6 +20,7 @@ class OmniPlatformEnum(Enum): NPU = "npu" XPU = "xpu" MUSA = "musa" + MACA = "maca" UNSPECIFIED = "unspecified" @@ -49,6 +50,9 @@ def is_rocm(self) -> bool: def is_musa(self) -> bool: return self._omni_enum == OmniPlatformEnum.MUSA + def is_maca(self) -> bool: + return self._omni_enum == OmniPlatformEnum.MACA + @classmethod def get_omni_ar_worker_cls(cls) -> str: raise NotImplementedError diff --git a/vllm_omni/platforms/maca/__init__.py b/vllm_omni/platforms/maca/__init__.py new file mode 100644 index 0000000000..d2fbc9f4ef --- /dev/null +++ b/vllm_omni/platforms/maca/__init__.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm_omni.platforms.maca.platform import MacaOmniPlatform + +__all__ = ["MacaOmniPlatform"] diff --git a/vllm_omni/platforms/maca/platform.py b/vllm_omni/platforms/maca/platform.py new file mode 100644 index 0000000000..6a6f2b71da --- /dev/null +++ b/vllm_omni/platforms/maca/platform.py @@ -0,0 +1,110 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +from vllm.logger import init_logger +from vllm.utils.torch_utils import cuda_device_count_stateless +from vllm_metax.platform import MacaPlatform as MacaPlatformBase + +from vllm_omni.diffusion.attention.backends.registry import DiffusionAttentionBackendEnum +from vllm_omni.diffusion.envs import PACKAGES_CHECKER +from vllm_omni.platforms.interface import OmniPlatform, OmniPlatformEnum + +logger = init_logger(__name__) + + +class MacaOmniPlatform(OmniPlatform, MacaPlatformBase): + """MetaX MACA implementation of OmniPlatform. + + Inherits MACA-specific behavior from vLLM-metax and adds Omni worker / diffusion hooks. + """ + + _omni_enum = OmniPlatformEnum.MACA + + @classmethod + def get_omni_ar_worker_cls(cls) -> str: + return "vllm_omni.platforms.maca.worker.maca_ar_worker.MacaARWorker" + + @classmethod + def get_omni_generation_worker_cls(cls) -> str: + return "vllm_omni.platforms.maca.worker.maca_generation_worker.MacaGenerationWorker" + + @classmethod + def get_default_stage_config_path(cls) -> str: + return "vllm_omni/model_executor/stage_configs" + + @classmethod + def get_diffusion_model_impl_qualname(cls, op_name: str) -> str: + if op_name == "hunyuan_fused_moe": + return "vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe.HunyuanFusedMoEDefault" + return super().get_diffusion_model_impl_qualname(op_name) + + @classmethod + def get_diffusion_attn_backend_cls( + cls, + selected_backend: str | None, + head_size: int, + ) -> str: + """Diffusion attention backend selection (CUDA-compatible path, like CudaOmniPlatform).""" + + compute_capability = cls.get_device_capability() + compute_supported = False + if compute_capability is not None: + major, minor = compute_capability + capability = major * 10 + minor + compute_supported = 80 <= capability < 100 + + packages_info = PACKAGES_CHECKER.get_packages_info() + packages_available = packages_info.get("has_flash_attn", False) + flash_attn_supported = compute_supported and packages_available + + if selected_backend is not None: + backend_upper = selected_backend.upper() + if backend_upper == "FLASH_ATTN" and not flash_attn_supported: + if not compute_supported: + logger.warning( + "Flash Attention expects compute capability >= 8.0 and < 10.0. " + "Falling back to TORCH_SDPA backend." + ) + elif not packages_available: + logger.warning("Flash Attention packages not available. Falling back to TORCH_SDPA backend.") + logger.info("Defaulting to diffusion attention backend SDPA") + return DiffusionAttentionBackendEnum.TORCH_SDPA.get_path() + backend = DiffusionAttentionBackendEnum[backend_upper] + logger.info("Using diffusion attention backend '%s'", backend_upper) + return backend.get_path() + + if flash_attn_supported: + logger.info("Defaulting to diffusion attention backend FLASH_ATTN") + return DiffusionAttentionBackendEnum.FLASH_ATTN.get_path() + + logger.info("Defaulting to diffusion attention backend SDPA") + return DiffusionAttentionBackendEnum.TORCH_SDPA.get_path() + + @classmethod + def supports_torch_inductor(cls) -> bool: + # vLLM-metax currently disables generic Triton kernel paths; keep inductor off by default. + return False + + @classmethod + def get_torch_device(cls, local_rank: int | None = None) -> torch.device: + if local_rank is None: + return torch.device("cuda") + return torch.device("cuda", local_rank) + + @classmethod + def get_device_count(cls) -> int: + return cuda_device_count_stateless() + + @classmethod + def get_device_version(cls) -> str | None: + return torch.version.cuda + + @classmethod + def synchronize(cls) -> None: + torch.cuda.synchronize() + + @classmethod + def get_free_memory(cls, device: torch.device | None = None) -> int: + free, _ = torch.cuda.mem_get_info(device) + return free diff --git a/vllm_omni/platforms/maca/worker/__init__.py b/vllm_omni/platforms/maca/worker/__init__.py new file mode 100644 index 0000000000..ba40deefb2 --- /dev/null +++ b/vllm_omni/platforms/maca/worker/__init__.py @@ -0,0 +1,7 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm_omni.platforms.maca.worker.maca_ar_worker import MacaARWorker +from vllm_omni.platforms.maca.worker.maca_generation_worker import MacaGenerationWorker + +__all__ = ["MacaARWorker", "MacaGenerationWorker"] diff --git a/vllm_omni/platforms/maca/worker/maca_ar_worker.py b/vllm_omni/platforms/maca/worker/maca_ar_worker.py new file mode 100644 index 0000000000..baa345443e --- /dev/null +++ b/vllm_omni/platforms/maca/worker/maca_ar_worker.py @@ -0,0 +1,14 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""MACA (MetaX) AR worker for vLLM-Omni. + +MetaX uses the CUDA-compatible runtime (``device_type="cuda"`` in vLLM-metax); this +worker reuses the standard Omni GPU AR worker implementation. +""" + +from vllm_omni.worker.gpu_ar_worker import GPUARWorker + + +class MacaARWorker(GPUARWorker): + """Autoregressive omni stages on MetaX MACA via vLLM-metax.""" diff --git a/vllm_omni/platforms/maca/worker/maca_generation_worker.py b/vllm_omni/platforms/maca/worker/maca_generation_worker.py new file mode 100644 index 0000000000..3680d9abf6 --- /dev/null +++ b/vllm_omni/platforms/maca/worker/maca_generation_worker.py @@ -0,0 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""MACA (MetaX) generation worker for vLLM-Omni. + +Non-AR stages use the same CUDA-compatible device path as ``GPUGenerationWorker``. +""" + +from vllm_omni.worker.gpu_generation_worker import GPUGenerationWorker + + +class MacaGenerationWorker(GPUGenerationWorker): + """Non-autoregressive generation stages on MetaX MACA via vLLM-metax.""" diff --git a/vllm_omni/quantization/int8_config.py b/vllm_omni/quantization/int8_config.py index 37d4300470..bc8846d5b1 100644 --- a/vllm_omni/quantization/int8_config.py +++ b/vllm_omni/quantization/int8_config.py @@ -148,6 +148,8 @@ def get_quant_method( if not self.is_checkpoint_int8_serialized: if current_omni_platform.is_cuda(): online_method = Int8OnlineLinearMethod(self) + elif current_omni_platform.is_maca(): + online_method = Int8OnlineLinearMethod(self) elif current_omni_platform.is_npu(): online_method = NPUInt8OnlineLinearMethod(self) else: @@ -156,6 +158,8 @@ def get_quant_method( else: if current_omni_platform.is_cuda(): offline_method = Int8LinearMethod(self) + elif current_omni_platform.is_maca(): + offline_method = Int8LinearMethod(self) elif current_omni_platform.is_npu(): offline_method = NPUInt8LinearMethod(self) else: