diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py index a3133c1796..94a83c65a7 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py @@ -1,198 +1,214 @@ from __future__ import annotations from collections.abc import Iterable -from typing import Any import torch import torch.nn as nn +import torch.nn.functional as F from vllm.config import VllmConfig from vllm.config.vllm import set_current_vllm_config -from vllm.forward_context import set_forward_context from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear +from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name, ) -from vllm.model_executor.models.qwen3 import Qwen3DecoderLayer +from vllm.model_executor.models.qwen2 import Qwen2MLP as Qwen3MLP from vllm.model_executor.models.utils import is_pp_missing_parameter -from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheGroupSpec, KVCacheSpec, KVCacheTensor -from vllm.v1.worker.gpu import attn_utils - -from vllm_omni.platforms import current_omni_platform +from vllm.transformers_utils.config import set_default_rope_theta from .configuration_qwen3_tts import Qwen3TTSTalkerCodePredictorConfig, Qwen3TTSTalkerConfig +# Type alias for per-layer KV cache: (k_cache, v_cache) each of shape +# [max_batch_size, num_kv_heads, max_seq_len, head_dim]. +KVCache = tuple[torch.Tensor, torch.Tensor] + -class _LocalPredictorKVCache: - """Minimal local KV cache + attention metadata for running - code_predictor inside one worker (independent of engine KV).""" +class CodePredictorAttention(nn.Module): + """Standalone attention using SDPA + dense KV buffers. + + Reuses QKVParallelLinear, RowParallelLinear, RMSNorm (QK-norm), and RoPE + from vLLM but replaces the paged-attention backend with + ``F.scaled_dot_product_attention``. + """ def __init__( self, - *, - vllm_config: VllmConfig, - max_seq_len: int, - max_batch_size: int, - device: torch.device, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_parameters: dict, + max_position: int = 32768, + head_dim: int | None = None, + rms_norm_eps: float = 1e-06, + qkv_bias: bool = False, + quant_config=None, + prefix: str = "", ) -> None: - self.vllm_config = vllm_config - self.device = device - - # Collect attention layers registered in this vllm_config. - kv_cache_spec_by_layer = attn_utils.get_kv_cache_spec(vllm_config) - if not kv_cache_spec_by_layer: - raise RuntimeError("Local predictor KVCache requires vLLM Attention layers to be registered.") - - # We only need enough blocks for a tiny per-frame sequence (<= max_seq_len). - any_spec = next(iter(kv_cache_spec_by_layer.values())) - block_size = int(any_spec.block_size) - blocks_per_seq = (int(max_seq_len) + block_size - 1) // block_size - num_blocks = max(1, int(max_batch_size) * int(blocks_per_seq)) - - # Allocate per-layer KV caches (small, independent). 
- kv_cache_tensors: list[KVCacheTensor] = [] - for layer_name, spec in kv_cache_spec_by_layer.items(): - kv_cache_tensors.append(KVCacheTensor(size=int(spec.page_size_bytes) * num_blocks, shared_by=[layer_name])) - - merged_spec: KVCacheSpec = KVCacheSpec.merge(list(kv_cache_spec_by_layer.values())) - self.kv_cache_config = KVCacheConfig( - num_blocks=num_blocks, - kv_cache_tensors=kv_cache_tensors, - kv_cache_groups=[ - KVCacheGroupSpec(layer_names=list(kv_cache_spec_by_layer.keys()), kv_cache_spec=merged_spec) - ], + super().__init__() + self.hidden_size = hidden_size + self.total_num_heads = num_heads + self.num_heads = num_heads + self.total_num_kv_heads = num_kv_heads + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim or hidden_size // num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + disable_tp=True, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + disable_tp=True, ) - # Init backend + bind KV cache tensors to attention modules. - self.attn_backends, self.attn_metadata_builders = attn_utils.init_attn_backend( - self.kv_cache_config, vllm_config, device + self.rotary_emb = get_rope( + self.head_dim, + max_position=max_position, + rope_parameters=rope_parameters, ) - self.runner_kv_caches: list[torch.Tensor] = [] - attn_utils.init_kv_cache( - self.runner_kv_caches, - vllm_config.compilation_config.static_forward_context, - self.kv_cache_config, - self.attn_backends, - device, + self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + seq_len: int, + ) -> torch.Tensor: + """ + Args: + positions: [B, qlen] position ids. + hidden_states: [B, qlen, hidden_size]. + kv_cache: (k_cache, v_cache) each [B, num_kv_heads, max_seq_len, head_dim]. + seq_len: total sequence length *after* this forward (past + current query). + + Returns: + output: [B, qlen, hidden_size]. + """ + bsz, qlen, _ = hidden_states.shape + k_cache, v_cache = kv_cache + + qkv, _ = self.qkv_proj(hidden_states.reshape(bsz * qlen, -1)) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + # QK-norm (per head). + q = self.q_norm(q.view(-1, self.num_heads, self.head_dim)).view(q.shape) + k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim)).view(k.shape) + + # RoPE. + q, k = self.rotary_emb(positions.reshape(-1), q, k) + + # Reshape to [B, heads, qlen, head_dim]. + q = q.view(bsz, qlen, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(bsz, qlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = v.view(bsz, qlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + + # Write new K/V into the dense cache at the correct positions. + start_pos = seq_len - qlen + k_cache[:bsz, :, start_pos:seq_len, :] = k + v_cache[:bsz, :, start_pos:seq_len, :] = v + + # Attend over the full sequence so far. 
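+ # NOTE: ``is_causal`` below relies on the only multi-token call being the
+ # initial prefill (qlen == seq_len); decode steps use qlen == 1, so the lone
+ # query token can attend to the whole cached prefix without a causal mask.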
+ k_full = k_cache[:bsz, :, :seq_len, :] + v_full = v_cache[:bsz, :, :seq_len, :] + + attn_out = F.scaled_dot_product_attention( + q, + k_full, + v_full, + scale=self.scaling, + is_causal=(qlen == seq_len), # True for prefill, False for decode + enable_gqa=(self.num_heads != self.num_kv_heads), ) + # [B, num_heads, qlen, head_dim] -> [B*qlen, num_heads * head_dim] + attn_out = attn_out.transpose(1, 2).reshape(bsz * qlen, -1) + output, _ = self.o_proj(attn_out) + return output.view(bsz, qlen, -1) + - # Precompute a fixed block table mapping for the maximum batch. - self.block_size = block_size - self.blocks_per_seq = blocks_per_seq - self.max_batch_size = int(max_batch_size) +class CodePredictorDecoderLayer(nn.Module): + """Standalone decoder layer for the code predictor. - bt = torch.full((self.max_batch_size, self.blocks_per_seq), -1, dtype=torch.int32, device=device) - for i in range(self.max_batch_size): - for j in range(self.blocks_per_seq): - bt[i, j] = i * self.blocks_per_seq + j - self._block_table = bt + Same architecture as ``Qwen3DecoderLayer`` (attention + MLP with + pre-norm residuals) but uses ``CodePredictorAttention`` instead of + vLLM's ``Attention`` backend. Weight names are identical so existing + checkpoints load without changes. + """ - def build_attn_metadata( + def __init__( self, - *, - num_reqs: int, - query_lens: torch.Tensor, # (num_reqs,) int32 on cpu - seq_lens: torch.Tensor, # (num_reqs,) int32 on cpu - ) -> tuple[dict[str, Any], torch.Tensor, dict[str, torch.Tensor]]: - """Build attention metadata, positions, and slot_mapping dict. + config: Qwen3TTSTalkerCodePredictorConfig, + quant_config=None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + set_default_rope_theta(config, default_theta=1000000) + + self.self_attn = CodePredictorAttention( + hidden_size=config.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + head_dim=getattr(config, "head_dim", None), + max_position=config.max_position_embeddings, + rms_norm_eps=config.rms_norm_eps, + qkv_bias=getattr(config, "attention_bias", False), + quant_config=quant_config, + rope_parameters=config.rope_parameters, + prefix=f"{prefix}.self_attn", + ) + self.mlp = Qwen3MLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - Returns: - (attn_metadata, positions, slot_mappings_by_layer) - - attn_metadata: per-layer attention metadata for attn backends. - - positions: (num_tokens,) position IDs on device. - - slot_mappings_by_layer: {layer_name: slot_mapping_tensor} for - set_forward_context so that unified_kv_cache_update can write - the KV cache correctly. - """ - num_reqs = int(num_reqs) - if num_reqs <= 0: - return {}, torch.empty((0,), dtype=torch.int64, device=self.device), {} - if num_reqs > self.max_batch_size: - raise ValueError(f"num_reqs={num_reqs} exceeds local predictor max_batch_size={self.max_batch_size}") - - query_lens_i32 = query_lens.to(dtype=torch.int32, device="cpu") - seq_lens_i32 = seq_lens.to(dtype=torch.int32, device="cpu") - - # query_start_loc: prefix sums of query_lens. 
- qsl = torch.zeros((num_reqs + 1,), dtype=torch.int32, device="cpu") - qsl[1:] = torch.cumsum(query_lens_i32, dim=0) - num_tokens = int(qsl[-1].item()) - if num_tokens <= 0: - return {}, torch.empty((0,), dtype=torch.int64, device=self.device), {} - - # positions: for each request i, emit positions [seq_len-query_len .. seq_len-1] - pos_list: list[torch.Tensor] = [] - for i in range(num_reqs): - ql = int(query_lens_i32[i].item()) - sl = int(seq_lens_i32[i].item()) - start = sl - ql - pos_list.append(torch.arange(start, sl, dtype=torch.int64)) - positions_cpu = torch.cat(pos_list, dim=0) - - # slot_mapping: map each query token to a physical slot in the paged KV cache. - # We allocate per-request contiguous blocks; slot = base + position. - slot_mapping = torch.empty((num_tokens,), dtype=torch.int64, device="cpu") - cursor = 0 - for i in range(num_reqs): - ql = int(query_lens_i32[i].item()) - sl = int(seq_lens_i32[i].item()) - start = sl - ql - for p in range(start, sl): - block_idx = p // self.block_size - offset = p % self.block_size - block_id = int(self._block_table[i, block_idx].item()) - slot_mapping[cursor] = block_id * self.block_size + offset - cursor += 1 - - max_seq_len = int(seq_lens_i32[:num_reqs].max().item()) - query_start_loc_gpu = qsl.to(device=self.device) - seq_lens_gpu = seq_lens_i32.to(device=self.device) - block_table = self._block_table[:num_reqs].contiguous() - slot_mapping_gpu = slot_mapping.to(device=self.device) - - # FIXME(gcanlin): Refactor build_attn_metadata to avoid special-casing NPU backends here. - if current_omni_platform.is_npu(): - # NPU requires AscendCommonAttentionMetadata with extra attributes - from vllm_ascend.worker.v2 import attn_utils as attn_utils_npu - - max_query_len = int(query_lens_i32[:num_reqs].max().item()) - # NPU version expects slot_mappings as a stacked tensor, not a list - slot_mappings_tensor = slot_mapping_gpu.unsqueeze(0) - attn_metadata = attn_utils_npu.build_attn_metadata( - attn_metadata_builders=self.attn_metadata_builders, - num_reqs=num_reqs, - num_tokens=num_tokens, - query_start_loc_gpu=query_start_loc_gpu, - query_start_loc_cpu=qsl, - max_query_len=max_query_len, - seq_lens=seq_lens_gpu, - max_seq_len=max_seq_len, - block_tables=[block_table], - slot_mappings=slot_mappings_tensor, - kv_cache_config=self.kv_cache_config, - ) + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor | None, + kv_cache: KVCache, + seq_len: int, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) else: - attn_metadata = attn_utils.build_attn_metadata( - self.attn_metadata_builders, - num_reqs=num_reqs, - num_tokens=num_tokens, - query_start_loc_gpu=query_start_loc_gpu, - query_start_loc_cpu=qsl, - seq_lens=seq_lens_gpu, - max_seq_len=max_seq_len, - block_tables=[block_table], - slot_mappings=[slot_mapping_gpu], - kv_cache_config=self.kv_cache_config, - ) + hidden_states, residual = self.input_layernorm(hidden_states, residual) - # Build slot_mappings_by_layer for set_forward_context. 
- # Fix for vllm 0.15.0 - slot_mappings_by_layer: dict[str, torch.Tensor] = {} - for kv_cache_group in self.kv_cache_config.kv_cache_groups: - for layer_name in kv_cache_group.layer_names: - slot_mappings_by_layer[layer_name] = slot_mapping_gpu + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + seq_len=seq_len, + ) - return attn_metadata, positions_cpu.to(device=self.device), slot_mappings_by_layer + # MLP + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual class Qwen3TTSTalkerCodePredictorModelVLLM(nn.Module): @@ -201,7 +217,6 @@ def __init__( config: Qwen3TTSTalkerCodePredictorConfig, *, talker_hidden_size: int | None = None, - cache_config=None, quant_config=None, prefix: str = "", ) -> None: @@ -211,9 +226,7 @@ def __init__( self.layers = nn.ModuleList( [ - Qwen3DecoderLayer( - config, cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.layers.{i}" - ) + CodePredictorDecoderLayer(config, quant_config=quant_config, prefix=f"{prefix}.layers.{i}") for i in range(config.num_hidden_layers) ] ) @@ -232,12 +245,24 @@ def __init__( def get_input_embeddings(self) -> nn.ModuleList: return self.codec_embedding - def forward(self, positions: torch.Tensor, inputs_embeds: torch.Tensor) -> torch.Tensor: - # Token-major: [num_tokens, hidden] + def forward( + self, + positions: torch.Tensor, + inputs_embeds: torch.Tensor, + kv_caches: list[KVCache], + seq_len: int, + ) -> torch.Tensor: + """ + Args: + positions: [B, qlen] position ids. + inputs_embeds: [B, qlen, hidden_size]. + kv_caches: list of (k_cache, v_cache) per layer. + seq_len: total sequence length after this forward. + """ hidden_states = inputs_embeds residual = None - for layer in self.layers: - hidden_states, residual = layer(positions, hidden_states, residual) + for layer, kv_cache in zip(self.layers, kv_caches): + hidden_states, residual = layer(positions, hidden_states, residual, kv_cache, seq_len) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -323,7 +348,6 @@ def __init__( self.model = Qwen3TTSTalkerCodePredictorModelVLLM( config, talker_hidden_size=int(talker_config.hidden_size), - cache_config=vllm_config.cache_config, quant_config=vllm_config.quant_config, prefix=f"{prefix}.model", ) @@ -338,14 +362,17 @@ def __init__( else: self.small_to_mtp_projection = nn.Identity() - self._kv_cache: _LocalPredictorKVCache | None = None + # Dense KV cache state (allocated lazily). + self._kv_caches: list[KVCache] | None = None + self._max_seq_len = int(getattr(config, "num_code_groups", 16) or 16) + self._num_layers = int(config.num_hidden_layers) + self._num_kv_heads = int(config.num_key_value_heads) + self._head_dim = int(getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads) def get_input_embeddings(self) -> nn.ModuleList: return self.model.get_input_embeddings() def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - # Ensure all vLLM custom layers consult the predictor vllm_config - # (esp. for Attention static_forward_context). 
with set_current_vllm_config(self._vllm_config): loaded: set[str] = set() model_weights: list[tuple[str, torch.Tensor]] = [] @@ -367,92 +394,77 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loaded.add(name) return loaded - def _maybe_init_kv_cache(self, device: torch.device) -> None: - if self._kv_cache is not None: - return - max_seq_len = int(getattr(self.config, "num_code_groups", 16) or 16) - # Upper bound on batch size: vLLM scheduler max_num_seqs (fallback 8). - max_batch = int(getattr(self._vllm_config.scheduler_config, "max_num_seqs", 8) or 8) - max_batch = max(1, max_batch) - self._kv_cache = _LocalPredictorKVCache( - vllm_config=self._vllm_config, - max_seq_len=max_seq_len, - max_batch_size=max_batch, - device=device, - ) + def _allocate_kv_caches(self, batch_size: int, device: torch.device) -> list[KVCache]: + """Allocate dense KV cache tensors for all layers.""" + caches: list[KVCache] = [] + for _ in range(self._num_layers): + k = torch.zeros( + batch_size, + self._num_kv_heads, + self._max_seq_len, + self._head_dim, + dtype=torch.bfloat16, + device=device, + ) + v = torch.zeros( + batch_size, + self._num_kv_heads, + self._max_seq_len, + self._head_dim, + dtype=torch.bfloat16, + device=device, + ) + caches.append((k, v)) + return caches @torch.inference_mode() def reset_cache(self) -> None: - # We reuse a fixed kv cache buffer and overwrite starting at slot 0. - # No action required here (seq_lens controls what is read). - return + if self._kv_caches is not None: + for k, v in self._kv_caches: + k.zero_() + v.zero_() @torch.inference_mode() def prefill_logits(self, inputs_embeds: torch.Tensor) -> torch.Tensor: """Prefill with 2 tokens: [past_hidden, layer0_embed]. Returns logits for residual group 0.""" - self._maybe_init_kv_cache(inputs_embeds.device) - assert self._kv_cache is not None - bsz = int(inputs_embeds.shape[0]) qlen = 2 - # Flatten to token-major. - hs = inputs_embeds.to(dtype=torch.bfloat16).reshape(bsz * qlen, -1) - hs = self.small_to_mtp_projection(hs) - - query_lens = torch.full((bsz,), qlen, dtype=torch.int32) - seq_lens = query_lens.clone() - attn_metadata, positions, slot_mappings = self._kv_cache.build_attn_metadata( - num_reqs=bsz, query_lens=query_lens, seq_lens=seq_lens - ) + device = inputs_embeds.device + + # Allocate / re-allocate KV caches if needed. + if self._kv_caches is None or self._kv_caches[0][0].shape[0] < bsz: + self._kv_caches = self._allocate_kv_caches(bsz, device) + + hs = inputs_embeds.to(dtype=torch.bfloat16) # [B, 2, H] + hs = self.small_to_mtp_projection(hs.reshape(bsz * qlen, -1)).view(bsz, qlen, -1) - with ( - set_current_vllm_config(self._vllm_config), - set_forward_context( - attn_metadata, - self._vllm_config, - num_tokens=int(hs.shape[0]), - slot_mapping=slot_mappings, - ), - ): - out = self.model(positions=positions, inputs_embeds=hs) + positions = torch.arange(qlen, dtype=torch.long, device=device).unsqueeze(0).expand(bsz, -1) + + out = self.model(positions=positions, inputs_embeds=hs, kv_caches=self._kv_caches, seq_len=qlen) # Gather last token per request. 
- last_idx = torch.arange(qlen - 1, bsz * qlen, step=qlen, device=out.device, dtype=torch.long) - last_h = out.index_select(0, last_idx) + last_h = out[:, -1, :] # [B, hidden] logits = self.lm_head[0](last_h) return logits @torch.inference_mode() def decode_logits(self, input_ids: torch.Tensor, *, generation_step: int, past_seq_len: int) -> torch.Tensor: """Decode one new token for residual group `generation_step` (1..Q-1).""" - self._maybe_init_kv_cache(input_ids.device) - assert self._kv_cache is not None + assert self._kv_caches is not None bsz = int(input_ids.shape[0]) if generation_step <= 0: raise ValueError("generation_step must be >= 1 for decode_logits") embed_idx = generation_step - 1 hs = self.model.get_input_embeddings()[embed_idx](input_ids.to(dtype=torch.long).reshape(bsz, 1)) - hs = self.small_to_mtp_projection(hs.reshape(bsz, -1)) + hs = self.small_to_mtp_projection(hs.reshape(bsz, -1)).view(bsz, 1, -1) - query_lens = torch.ones((bsz,), dtype=torch.int32) - seq_lens = torch.full((bsz,), int(past_seq_len) + 1, dtype=torch.int32) - attn_metadata, positions, slot_mappings = self._kv_cache.build_attn_metadata( - num_reqs=bsz, query_lens=query_lens, seq_lens=seq_lens - ) + seq_len = past_seq_len + 1 + positions = torch.full((bsz, 1), past_seq_len, dtype=torch.long, device=input_ids.device) + + out = self.model(positions=positions, inputs_embeds=hs, kv_caches=self._kv_caches, seq_len=seq_len) - with ( - set_current_vllm_config(self._vllm_config), - set_forward_context( - attn_metadata, - self._vllm_config, - num_tokens=int(hs.shape[0]), - slot_mapping=slot_mappings, - ), - ): - out = self.model(positions=positions, inputs_embeds=hs) - - logits = self.lm_head[generation_step](out) + logits = self.lm_head[generation_step](out[:, 0, :]) return logits @torch.inference_mode()