8 changes: 4 additions & 4 deletions .buildkite/scripts/hardware_ci/run-cpu-test.sh
@@ -66,10 +66,10 @@ function cpu_tests() {
    tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynamic_per_token"

  # Run AWQ test
-  # docker exec cpu-test-"$NUMA_NODE" bash -c "
-  #  set -e
-  #  VLLM_USE_V1=0 pytest -s -v \
-  #  tests/quantization/test_ipex_quant.py"
+  docker exec cpu-test-"$NUMA_NODE" bash -c "
+    set -e
+    VLLM_USE_V1=0 pytest -s -v \
+    tests/quantization/test_ipex_quant.py"

  # Run chunked-prefill and prefix-cache test
  docker exec cpu-test-"$NUMA_NODE" bash -c "
2 changes: 2 additions & 0 deletions .buildkite/scripts/hardware_ci/run-xpu-test.sh
@@ -26,5 +26,7 @@ docker run \
--name "${container_name}" \
"${image_name}" \
sh -c '
VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m
VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m -tp 2
VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
'
2 changes: 1 addition & 1 deletion examples/online_serving/chart-helm/values.yaml
@@ -8,7 +8,7 @@ image:
  # -- Image tag
  tag: "latest"
  # -- Container launch command
-  command: ["vllm", "serve", "/data/", "--served-model-name", "opt-125m", "--enforce-eager", "--dtype", "bfloat16", "--block-size", "16", "--host", "0.0.0.0", "--port", "8000"]
+  command: ["vllm", "serve", "/data/", "--served-model-name", "opt-125m", "--dtype", "float32", "--block-size", "16", "--host", "0.0.0.0", "--port", "8000"]

# -- Container port
containerPort: 8000
17 changes: 7 additions & 10 deletions tests/kernels/attention/test_attention_selector.py
@@ -36,8 +36,7 @@ def clear_cache():
DEVICE_MLA_BLOCK_SIZES = {
    "cuda": [16, 64],  # CUDA supports both standard and extended block sizes
    "hip": [16, 1],  # HIP requires special handling for block_size=1
-    # "cpu": [16]  # CPU uses fixed block size from test cases
-    "cpu": []  # FIXME(woosuk): Temporarily disable CPU tests
+    "cpu": [16]  # CPU uses fixed block size from test cases
}


Expand Down Expand Up @@ -82,14 +81,14 @@ def test_env(
m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0")

if device == "cpu":
if not use_v1:
pytest.skip("CPU backend only supports V1")

with patch("vllm.attention.selector.current_platform",
CpuPlatform()):
backend = get_attn_backend(16, torch.float16, torch.float16,
block_size, False)
assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
if use_v1:
assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
else:
assert backend.get_name() == "TORCH_SDPA"

elif device == "hip":
with patch("vllm.attention.selector.current_platform",
@@ -205,14 +204,12 @@ def test_fp32_fallback(
m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")

if device == "cpu":
if not use_v1:
pytest.skip("CPU backend only supports V1")

with patch("vllm.attention.selector.current_platform",
CpuPlatform()):
backend = get_attn_backend(16, torch.float32, torch.float32,
16, False)
assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
assert (backend.get_name() == "TORCH_SDPA_VLLM_V1"
if use_v1 else "TORCH_SDPA")

elif device == "cuda":
with patch("vllm.attention.selector.current_platform",
307 changes: 307 additions & 0 deletions vllm/attention/backends/cpu_mla.py
@@ -0,0 +1,307 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

import torch

import vllm._custom_ops as ops
from vllm._ipex_ops import ipex_ops
from vllm.attention.backends.abstract import (AttentionBackend,
                                              AttentionMetadataBuilder,
                                              AttentionType,
                                              is_quantized_kv_cache)
from vllm.attention.backends.mla.common import MLACommonImpl, MLACommonState
from vllm.attention.backends.torch_sdpa import TorchSDPAMetadata
from vllm.utils import make_tensor_with_pad
from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder


class CPUMLABackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        return "CPU_MLA"

    @staticmethod
    def get_metadata_cls() -> Type["CPUMLAMetadata"]:
        return CPUMLAMetadata

    @staticmethod
    def get_builder_cls() -> Type["CPUMLAMetadataBuilder"]:
        return CPUMLAMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["MLACommonState"]:
        return MLACommonState

    @staticmethod
    def get_impl_cls() -> Type["CPUMLAImpl"]:
        return CPUMLAImpl

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,  # assumed to be 1 for MLA
        head_size: int,
    ) -> Tuple[int, ...]:
        return (num_blocks, block_size, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        ops.copy_blocks_mla(kv_caches, src_to_dists)

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        return [576]


@dataclass
class CPUMLAMetadata(TorchSDPAMetadata):
    # New for MLA
    # Input positions for rotary embeddings since for MLA the rotary
    # position embeddings are applied inside the attention backend
    input_positions: torch.Tensor = None

    # required by MLACommonImpl
    is_profile_run: bool = False


class CPUMLAMetadataBuilder(AttentionMetadataBuilder[CPUMLAMetadata]):

    def __init__(self, input_builder: ModelInputForCPUBuilder) -> None:
        self.chunked_prefill = input_builder.chunked_prefill
        self.input_builder = input_builder
        assert not self.chunked_prefill, \
            "chunked prefill is currently not supported"

    def prepare(self):
        self.input_data = self.input_builder.input_data

    def build(self, seq_lens, query_lens, cuda_graph_pad_size, batch_size):
        input_data = self.input_data
        prefill_seq_lens = seq_lens[0:input_data.num_prefills]
        prefill_query_lens = query_lens[0:input_data.num_prefills]
        slot_mapping = torch.tensor(input_data.slot_mapping,
                                    dtype=torch.long,
                                    device="cpu")

        # metadata for prefill
        if input_data.num_prefills > 0:
            query_lens_tensor = torch.tensor(prefill_query_lens,
                                             dtype=torch.int32,
                                             device="cpu")
            kv_lens_tensor = torch.tensor(prefill_seq_lens,
                                          dtype=torch.int32,
                                          device="cpu")
            query_start_loc = torch.zeros(input_data.num_prefills + 1,
                                          dtype=torch.int32,
                                          device="cpu")
            kv_start_loc = torch.zeros(input_data.num_prefills + 1,
                                       dtype=torch.int32,
                                       device="cpu")
            torch.cumsum(query_lens_tensor,
                         dim=0,
                         dtype=torch.int32,
                         out=query_start_loc[1:])
            torch.cumsum(kv_lens_tensor,
                         dim=0,
                         dtype=torch.int32,
                         out=kv_start_loc[1:])
            max_query_len = max(prefill_query_lens)
            max_kv_len = max(prefill_seq_lens)

            # for chunked-prefill
            if self.chunked_prefill:
                prefill_block_tables = make_tensor_with_pad(
                    self.input_data.prefill_block_tables,
                    pad=0,
                    dtype=torch.int32,
                    device="cpu",
                )
            else:
                prefill_block_tables = None

        else:
            query_start_loc = None
            kv_start_loc = None
            max_query_len = None
            max_kv_len = None
            prefill_block_tables = None

        # metadata for decode
        if input_data.num_decode_tokens != 0:
            seq_lens_tensor = torch.tensor(
                input_data.seq_lens[input_data.num_prefills:],
                dtype=torch.int32,
                device="cpu",
            )
            block_tables = make_tensor_with_pad(
                self.input_data.decode_block_tables,
                pad=0,
                dtype=torch.int32,
                device="cpu",
            )
        else:
            block_tables = torch.tensor([])
            seq_lens_tensor = torch.tensor(
                input_data.seq_lens[:input_data.num_prefills],
                dtype=torch.int32,
                device="cpu",
            )

        # For multi-modal models
        placeholder_index_maps = None
        if len(input_data.multi_modal_inputs_list) != 0:
            placeholder_index_maps = {
                modality: placeholder_map.index_map()
                for modality, placeholder_map in
                input_data.multi_modal_placeholder_maps.items()
            }

        return CPUMLAMetadata(
            chunked_prefill=self.chunked_prefill,
            seq_lens=prefill_seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_kv_len=max_kv_len,
            prefill_query_start_loc=query_start_loc,
            kv_start_loc=kv_start_loc,
            max_decode_seq_len=input_data.max_decode_seq_len,
            num_prefills=input_data.num_prefills,
            num_prefill_tokens=input_data.num_prefill_tokens,
            num_decode_tokens=input_data.num_decode_tokens,
            block_tables=block_tables,
            prefill_block_tables=prefill_block_tables,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            enable_kv_scales_calculation=False,
            input_positions=torch.tensor([self.input_data.input_positions]))
Review comment from a contributor (flagged critical) on the input_positions line above:

The input_positions tensor is being created as a 2D tensor by wrapping self.input_data.input_positions in an extra list. The rotary embedding layer expects a 1D tensor for positions. This will likely lead to an indexing error or incorrect behavior at runtime.

Suggested change:
-            input_positions=torch.tensor([self.input_data.input_positions]))
+            input_positions=torch.tensor(self.input_data.input_positions))
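A minimal standalone sketch (not part of the diff) of the shape mismatch the reviewer describes; the positions list and cos_sin_cache below are hypothetical stand-ins for self.input_data.input_positions and a rotary-embedding cache:

import torch

# Hypothetical per-token positions, standing in for
# self.input_data.input_positions.
positions = [0, 1, 2, 3]

wrapped = torch.tensor([positions])  # extra list -> 2D, shape (1, 4)
flat = torch.tensor(positions)       # suggested fix -> 1D, shape (4,)

print(wrapped.shape)  # torch.Size([1, 4])
print(flat.shape)     # torch.Size([4])

# A rotary-style lookup indexed by position works with the 1D form;
# indexing with the 2D form yields an extra leading dim (1, 4, 8) that
# downstream reshapes do not expect.
cos_sin_cache = torch.randn(16, 8)   # (max_positions, rot_dim), illustrative
per_token = cos_sin_cache[flat]      # shape (4, 8), one entry per token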



class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]):

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[List[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            blocksparse_params: Optional[Dict[str, Any]],
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **mla_args) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         blocksparse_params, logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **mla_args)

        unsupported_features = [
            alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
        ]
        if any(unsupported_features):
            raise NotImplementedError(
                "CPUMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, blocksparse_params, "
                "logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "CPUMLAImpl")
        # TODO: remove this check once FP8 KV cache support for MLA states
        # is implemented.
        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "CPUMLAImpl with FP8 KV cache not yet supported")

    def _forward_prefill(
        self,
        q: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: CPUMLAMetadata,  # type: ignore[override]
    ) -> torch.Tensor:

        prefill_metadata = attn_metadata.prefill_metadata
        assert prefill_metadata is not None

        kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
            -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = kv_nope \
            .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)

        # For MLA the v head dim is smaller than qk head dim so we pad out
        # v with 0s to match the qk head dim
        v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
                                           value=0)

        output = torch.empty_like(q)
        ipex_ops.varlen_attention(
            query=q,
            key=k,
            value=v_padded,
            out=output,
            seqlen_q=prefill_metadata.prefill_query_start_loc,
            seqlen_k=prefill_metadata.prefill_query_start_loc,
            max_seqlen_q=prefill_metadata.max_query_len,
            max_seqlen_k=prefill_metadata.max_query_len,
            pdropout=0.0,
            softmax_scale=self.scale,
            zero_tensors=False,
            is_causal=True,
            return_softmax=False,
            gen_=None,
            logits_soft_cap=0.0,
            window_size_left=-1,
            window_size_right=-1,
            alibi_slopes=None,
        )

        # remove padding
        output = output.view(-1, self.num_heads,
                             q.shape[-1])[..., :v.shape[-1]]
        return output.reshape(-1, self.num_heads * v.shape[-1])

    def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: CPUMLAMetadata,  # type: ignore[override]
    ) -> torch.Tensor:
        assert kv_c_and_k_pe_cache.numel() > 0

        decode_meta = attn_metadata.decode_metadata
        assert decode_meta is not None

        q = torch.cat([q_nope, q_pe], dim=-1)
        o = q.new_empty(q.shape[0], self.num_heads, self.kv_lora_rank)

        # Run MQA
        ops.mla_decode_kvcache_cpu(o, q, kv_c_and_k_pe_cache, self.scale,
                                   decode_meta.block_tables,
                                   decode_meta.seq_lens_tensor)
        return self._v_up_proj(o)
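A standalone sketch (not part of the diff) of the pad-then-slice trick noted in the _forward_prefill comment above: MLA's value head dim is smaller than the query/key head dim, so v is zero-padded to the q/k width before the fused attention call and the padding is sliced off afterwards. All names and shapes below are illustrative, not taken from the PR:

import torch
import torch.nn.functional as F

num_tokens, num_heads = 4, 2
qk_head_dim, v_head_dim = 192, 128  # illustrative MLA-style head dims

q = torch.randn(num_tokens, num_heads, qk_head_dim)
v = torch.randn(num_tokens, num_heads, v_head_dim)

# pad v with zeros on the last dim so it matches the q/k head dim
v_padded = F.pad(v, [0, q.shape[-1] - v.shape[-1]], value=0)
assert v_padded.shape[-1] == qk_head_dim

# stand-in for the attention output, which has the padded width
output = torch.randn(num_tokens, num_heads, qk_head_dim)

# strip the padding and flatten heads, mirroring the end of _forward_prefill
output = output[..., :v.shape[-1]]
result = output.reshape(-1, num_heads * v.shape[-1])
print(result.shape)  # torch.Size([4, 256])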