Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ markers = [
"xpu: Tests that run on XPU (auto-added)",
"npu: Tests that run on NPU/Ascend (auto-added)",
"musa: Tests that run on MUSA/Moore Threads (auto-added)",
"maca: Tests that run on MACA/MetaX (auto-added)",
# specified computation resources marks (auto-added)
"H100: Tests that require H100 GPU",
"L4: Tests that require L4 GPU",
Expand All @@ -190,11 +191,13 @@ markers = [
"distributed_xpu: Tests that require multi cards on XPU platform",
"distributed_npu: Tests that require multi cards on NPU platform",
"distributed_musa: Tests that require multi cards on MUSA platform",
"distributed_maca: Tests that require multi cards on MACA/MetaX platform",
"skipif_cuda: Skip if the num of CUDA cards is less than the required",
"skipif_rocm: Skip if the num of ROCm cards is less than the required",
"skipif_xpu: Skip if the num of XPU cards is less than the required",
"skipif_npu: Skip if the num of NPU cards is less than the required",
"skipif_musa: Skip if the num of MUSA cards is less than the required",
"skipif_maca: Skip if the num of MACA cards is less than the required",
# more detailed markers
"slow: Slow tests (may skip in quick CI)",
"benchmark: Benchmark tests",
Expand Down
4 changes: 4 additions & 0 deletions requirements/maca.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
-r common.txt
# MetaX MACA: install matching `vllm` + `vllm-metax` wheels or from source per
# https://github.com/MetaX-MACA/vLLM-metax — not published as a single PyPI extra here.
onnxruntime>=1.23.2
24 changes: 19 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

This setup.py implements platform-aware dependency routing so users can run
`pip install vllm-omni` and automatically receive the correct platform-specific
dependencies (CUDA/ROCm/CPU/XPU/NPU/MUSA) without requiring extras like `[cuda]`.
dependencies (CUDA/ROCm/CPU/XPU/NPU/MUSA/MACA) without requiring extras like `[cuda]`.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The vllm-upstream extra looks like it was dropped in a later commit — please remove this line from the docstring.

"""

import os
Expand Down Expand Up @@ -46,16 +48,16 @@ def detect_target_device() -> str:

Priority order:
1. VLLM_OMNI_TARGET_DEVICE environment variable (highest priority)
2. Torch backend detection (cuda, rocm, npu, xpu, musa)
2. Torch backend detection (maca, cuda, rocm, npu, xpu, musa)
3. CPU fallback (default)

Returns:
str: Device name ('cuda', 'rocm', 'npu', 'xpu', 'musa', or 'cpu')
str: Device name ('cuda', 'rocm', 'npu', 'xpu', 'musa', 'maca', or 'cpu')
"""
# Priority 1: Explicit override via environment variable
target_device = os.environ.get("VLLM_OMNI_TARGET_DEVICE")
if target_device:
valid_devices = ["cuda", "rocm", "npu", "xpu", "musa", "cpu"]
valid_devices = ["cuda", "rocm", "npu", "xpu", "musa", "maca", "cpu"]
if target_device.lower() in valid_devices:
print(f"Using target device from VLLM_OMNI_TARGET_DEVICE: {target_device.lower()}")
return target_device.lower()
Expand All @@ -68,7 +70,17 @@ def detect_target_device() -> str:
try:
import torch

# Check for CUDA
# MACA (MetaX): mcPyTorch may set torch.version.cuda; detect before generic CUDA.
try:
import vllm_metax # noqa: F401
except ImportError:
pass
else:
if torch.cuda.is_available():
print("Detected MACA (MetaX) backend from vllm-metax")
return "maca"

# Check for CUDA (NVIDIA)
if torch.version.cuda is not None:
print("Detected CUDA backend from torch")
return "cuda"
Expand Down Expand Up @@ -163,6 +175,8 @@ def get_vllm_omni_version() -> str:
version += f"{sep}xpu"
elif device == "musa":
version += f"{sep}musa"
elif device == "maca":
version += f"{sep}maca"
elif device == "cpu":
version += f"{sep}cpu"
else:
Expand Down
12 changes: 12 additions & 0 deletions vllm_omni/diffusion/attention/backends/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ def forward(
return self.forward_hip(query, key, value, attn_metadata)
elif current_omni_platform.is_cuda():
return self.forward_cuda(query, key, value, attn_metadata)
elif current_omni_platform.is_maca():
return self.forward_maca(query, key, value, attn_metadata)
elif current_omni_platform.is_npu():
return self.forward_npu(query, key, value, attn_metadata)
elif current_omni_platform.is_xpu():
Expand Down Expand Up @@ -150,3 +152,13 @@ def forward_musa(
) -> torch.Tensor:
# By default, MUSA ops are compatible with CUDA ops.
return self.forward_cuda(query, key, value, attn_metadata)

    def forward_maca(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: T | None = None,
    ) -> torch.Tensor:
        """Run attention on MetaX MACA by delegating to the CUDA implementation.

        vLLM-metax exposes a CUDA-compatible runtime, so the CUDA attention
        path is reused unchanged.
        """
        # MetaX MACA uses the CUDA-compatible runtime (vLLM-metax).
        return self.forward_cuda(query, key, value, attn_metadata)
6 changes: 6 additions & 0 deletions vllm_omni/diffusion/layers/custom_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ def dispatch_forward(self) -> Callable:
return self.forward_hip
elif current_omni_platform.is_cuda():
return self.forward_cuda
elif current_omni_platform.is_maca():
return self.forward_maca
elif current_omni_platform.is_npu():
return self.forward_npu
elif current_omni_platform.is_xpu():
Expand Down Expand Up @@ -57,3 +59,7 @@ def forward_hip(self, *args, **kwargs):
def forward_musa(self, *args, **kwargs):
# By default, we assume that MUSA ops are compatible with CUDA ops.
return self.forward_cuda(*args, **kwargs)

    def forward_maca(self, *args, **kwargs):
        """MACA (MetaX) dispatch target; delegates to the CUDA implementation."""
        # MetaX MACA uses the CUDA-compatible runtime (vLLM-metax).
        return self.forward_cuda(*args, **kwargs)
25 changes: 25 additions & 0 deletions vllm_omni/platforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,37 @@ def musa_omni_platform_plugin() -> str | None:
return "vllm_omni.platforms.musa.platform.MUSAOmniPlatform" if is_musa else None


def maca_omni_platform_plugin() -> str | None:
    """Check if MACA (MetaX) OmniPlatform should be activated.

    Requires vLLM-metax and a CUDA-compatible torch runtime. Because a host
    with ``vllm_metax`` installed also satisfies the generic CUDA probe
    (mcPyTorch reports a CUDA runtime), ``VLLM_OMNI_TARGET_DEVICE`` is honored
    here to resolve the collision with ``cuda_omni_platform_plugin``: when the
    variable is set to anything other than ``maca``, this plugin stands down.

    Returns:
        Dotted path of ``MacaOmniPlatform`` when MACA should activate,
        otherwise ``None``.
    """
    import os

    logger.debug("Checking if MACA OmniPlatform is available.")
    # Explicit target override: never activate when the user asked for a
    # different device; this keeps exactly one plugin active on hosts where
    # both the CUDA and MACA probes would otherwise match.
    target_device = os.environ.get("VLLM_OMNI_TARGET_DEVICE")
    if target_device and target_device.lower() != "maca":
        logger.debug(
            "MACA OmniPlatform disabled: VLLM_OMNI_TARGET_DEVICE=%s", target_device
        )
        return None
    try:
        import torch
    except Exception as e:
        logger.debug("MACA OmniPlatform is not available because: %s", str(e))
        return None
    try:
        import vllm_metax  # noqa: F401
    except ImportError as e:
        # Log instead of silently returning, matching the other failure paths.
        logger.debug("MACA OmniPlatform is not available because: %s", str(e))
        return None
    if not torch.cuda.is_available():
        logger.debug("MACA OmniPlatform is not available: CUDA runtime not available.")
        return None
    logger.debug("Confirmed MACA OmniPlatform is available.")
    return "vllm_omni.platforms.maca.platform.MacaOmniPlatform"


builtin_omni_platform_plugins = {
"cuda": cuda_omni_platform_plugin,
"rocm": rocm_omni_platform_plugin,
"npu": npu_omni_platform_plugin,
"xpu": xpu_omni_platform_plugin,
"musa": musa_omni_platform_plugin,
"maca": maca_omni_platform_plugin,
}


Expand Down
4 changes: 4 additions & 0 deletions vllm_omni/platforms/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class OmniPlatformEnum(Enum):
NPU = "npu"
XPU = "xpu"
MUSA = "musa"
MACA = "maca"
UNSPECIFIED = "unspecified"


Expand Down Expand Up @@ -49,6 +50,9 @@ def is_rocm(self) -> bool:
def is_musa(self) -> bool:
return self._omni_enum == OmniPlatformEnum.MUSA

def is_maca(self) -> bool:
return self._omni_enum == OmniPlatformEnum.MACA

@classmethod
def get_omni_ar_worker_cls(cls) -> str:
raise NotImplementedError
Expand Down
6 changes: 6 additions & 0 deletions vllm_omni/platforms/maca/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm_omni.platforms.maca.platform import MacaOmniPlatform

__all__ = ["MacaOmniPlatform"]
110 changes: 110 additions & 0 deletions vllm_omni/platforms/maca/platform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
from vllm.logger import init_logger
from vllm.utils.torch_utils import cuda_device_count_stateless
from vllm_metax.platform import MacaPlatform as MacaPlatformBase

from vllm_omni.diffusion.attention.backends.registry import DiffusionAttentionBackendEnum
from vllm_omni.diffusion.envs import PACKAGES_CHECKER
from vllm_omni.platforms.interface import OmniPlatform, OmniPlatformEnum

logger = init_logger(__name__)


class MacaOmniPlatform(OmniPlatform, MacaPlatformBase):
    """MetaX MACA implementation of OmniPlatform.

    Inherits MACA-specific behavior from vLLM-metax and adds Omni worker / diffusion hooks.
    vLLM-metax presents a CUDA-compatible runtime, so most device operations
    below go through ``torch.cuda``.
    """

    # Tag consulted by OmniPlatform.is_maca() dispatch checks.
    _omni_enum = OmniPlatformEnum.MACA

    @classmethod
    def get_omni_ar_worker_cls(cls) -> str:
        """Return the dotted path of the autoregressive-stage worker class."""
        return "vllm_omni.platforms.maca.worker.maca_ar_worker.MacaARWorker"

    @classmethod
    def get_omni_generation_worker_cls(cls) -> str:
        """Return the dotted path of the non-AR generation worker class."""
        return "vllm_omni.platforms.maca.worker.maca_generation_worker.MacaGenerationWorker"

    @classmethod
    def get_default_stage_config_path(cls) -> str:
        """Return the repo-relative directory holding default stage configs."""
        return "vllm_omni/model_executor/stage_configs"

    @classmethod
    def get_diffusion_model_impl_qualname(cls, op_name: str) -> str:
        """Resolve a diffusion-model op name to an implementation class path.

        Only ``hunyuan_fused_moe`` is specialized here; every other op falls
        through to the base platform's resolution.
        """
        if op_name == "hunyuan_fused_moe":
            return "vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe.HunyuanFusedMoEDefault"
        return super().get_diffusion_model_impl_qualname(op_name)

    @classmethod
    def get_diffusion_attn_backend_cls(
        cls,
        selected_backend: str | None,
        head_size: int,
    ) -> str:
        """Diffusion attention backend selection (CUDA-compatible path, like CudaOmniPlatform).

        Args:
            selected_backend: Explicit backend name (case-insensitive), or
                ``None`` to auto-select.
            head_size: Attention head size; part of the platform interface but
                not used by this selection logic.

        Returns:
            Dotted path of the chosen diffusion attention backend class.
        """
        # Flash Attention is gated on compute capability in [8.0, 10.0).
        # NOTE(review): this range is copied from the CUDA platform — confirm
        # that MetaX device capability numbers map onto the same scheme.
        compute_capability = cls.get_device_capability()
        compute_supported = False
        if compute_capability is not None:
            major, minor = compute_capability
            capability = major * 10 + minor
            compute_supported = 80 <= capability < 100

        packages_info = PACKAGES_CHECKER.get_packages_info()
        packages_available = packages_info.get("has_flash_attn", False)
        flash_attn_supported = compute_supported and packages_available

        if selected_backend is not None:
            backend_upper = selected_backend.upper()
            # An explicit FLASH_ATTN request is downgraded to SDPA (with a
            # warning) when prerequisites are missing.
            if backend_upper == "FLASH_ATTN" and not flash_attn_supported:
                if not compute_supported:
                    logger.warning(
                        "Flash Attention expects compute capability >= 8.0 and < 10.0. "
                        "Falling back to TORCH_SDPA backend."
                    )
                elif not packages_available:
                    logger.warning("Flash Attention packages not available. Falling back to TORCH_SDPA backend.")
                logger.info("Defaulting to diffusion attention backend SDPA")
                return DiffusionAttentionBackendEnum.TORCH_SDPA.get_path()
            # Raises KeyError for names that are not known backend members.
            backend = DiffusionAttentionBackendEnum[backend_upper]
            logger.info("Using diffusion attention backend '%s'", backend_upper)
            return backend.get_path()

        # Auto-selection: prefer Flash Attention when usable, else SDPA.
        if flash_attn_supported:
            logger.info("Defaulting to diffusion attention backend FLASH_ATTN")
            return DiffusionAttentionBackendEnum.FLASH_ATTN.get_path()

        logger.info("Defaulting to diffusion attention backend SDPA")
        return DiffusionAttentionBackendEnum.TORCH_SDPA.get_path()

    @classmethod
    def supports_torch_inductor(cls) -> bool:
        """Whether torch.inductor compilation is supported on this platform."""
        # vLLM-metax currently disables generic Triton kernel paths; keep inductor off by default.
        return False

    @classmethod
    def get_torch_device(cls, local_rank: int | None = None) -> torch.device:
        """Return the torch device for this platform ("cuda", optionally ranked)."""
        if local_rank is None:
            return torch.device("cuda")
        return torch.device("cuda", local_rank)

    @classmethod
    def get_device_count(cls) -> int:
        """Return the number of visible devices via the stateless CUDA count."""
        return cuda_device_count_stateless()

    @classmethod
    def get_device_version(cls) -> str | None:
        """Return the CUDA-compatible runtime version string, if any."""
        return torch.version.cuda

    @classmethod
    def synchronize(cls) -> None:
        """Block until all queued work on the current device has completed."""
        torch.cuda.synchronize()

    @classmethod
    def get_free_memory(cls, device: torch.device | None = None) -> int:
        """Return free device memory in bytes (current device when ``device`` is None)."""
        free, _ = torch.cuda.mem_get_info(device)
        return free
7 changes: 7 additions & 0 deletions vllm_omni/platforms/maca/worker/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm_omni.platforms.maca.worker.maca_ar_worker import MacaARWorker
from vllm_omni.platforms.maca.worker.maca_generation_worker import MacaGenerationWorker

__all__ = ["MacaARWorker", "MacaGenerationWorker"]
14 changes: 14 additions & 0 deletions vllm_omni/platforms/maca/worker/maca_ar_worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""MACA (MetaX) AR worker for vLLM-Omni.

MetaX uses the CUDA-compatible runtime (``device_type="cuda"`` in vLLM-metax); this
worker reuses the standard Omni GPU AR worker implementation.
"""

from vllm_omni.worker.gpu_ar_worker import GPUARWorker


class MacaARWorker(GPUARWorker):
    """Autoregressive omni stages on MetaX MACA via vLLM-metax.

    Pure marker subclass: vLLM-metax presents a CUDA-compatible device, so the
    standard GPU AR worker implementation is reused unchanged.
    """
13 changes: 13 additions & 0 deletions vllm_omni/platforms/maca/worker/maca_generation_worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""MACA (MetaX) generation worker for vLLM-Omni.

Non-AR stages use the same CUDA-compatible device path as ``GPUGenerationWorker``.
"""

from vllm_omni.worker.gpu_generation_worker import GPUGenerationWorker


class MacaGenerationWorker(GPUGenerationWorker):
    """Non-autoregressive generation stages on MetaX MACA via vLLM-metax.

    Pure marker subclass: vLLM-metax presents a CUDA-compatible device, so the
    standard GPU generation worker implementation is reused unchanged.
    """
4 changes: 4 additions & 0 deletions vllm_omni/quantization/int8_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ def get_quant_method(
if not self.is_checkpoint_int8_serialized:
if current_omni_platform.is_cuda():
online_method = Int8OnlineLinearMethod(self)
elif current_omni_platform.is_maca():
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
elif current_omni_platform.is_maca():
elif current_omni_platform.is_cuda() or current_omni_platform.is_maca():

Can you just or the two branches? The maca branch body is identical.

online_method = Int8OnlineLinearMethod(self)
elif current_omni_platform.is_npu():
online_method = NPUInt8OnlineLinearMethod(self)
else:
Expand All @@ -156,6 +158,8 @@ def get_quant_method(
else:
if current_omni_platform.is_cuda():
offline_method = Int8LinearMethod(self)
elif current_omni_platform.is_maca():
offline_method = Int8LinearMethod(self)
elif current_omni_platform.is_npu():
offline_method = NPUInt8LinearMethod(self)
else:
Expand Down
Loading