5 changes: 5 additions & 0 deletions vllm_gaudi/attention/backends/hpu_attn.py
@@ -492,6 +492,11 @@ def forward(
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if attn_metadata is None:
output = torch.zeros(query.shape,
dtype=query.dtype,
device=query.device)
return output
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
if self.attn_type == AttentionType.ENCODER_DECODER:
return self.forward_encoder_decoder(
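For reference, a minimal standalone sketch of what the new `attn_metadata is None` guard does: when no metadata is supplied, the forward pass returns an all-zeros tensor matching the query. The stub function, shape, and sizes below are illustrative, not part of this PR:

```python
import torch

def forward_stub(query: torch.Tensor, attn_metadata) -> torch.Tensor:
    # Mirror the new early-return: with no metadata, emit zeros that
    # match the query's shape, dtype, and device.
    if attn_metadata is None:
        return torch.zeros(query.shape, dtype=query.dtype, device=query.device)
    raise NotImplementedError("real attention path omitted in this sketch")

q = torch.randn(4, 8 * 64)  # [num_tokens, num_heads * head_size]
out = forward_stub(q, attn_metadata=None)
assert out.shape == q.shape and bool((out == 0).all())
```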
4 changes: 3 additions & 1 deletion vllm_gaudi/ops/__init__.py
@@ -5,4 +5,6 @@
import vllm_gaudi.ops.hpu_compressed_tensors # noqa
import vllm_gaudi.ops.hpu_fp8 # noqa
import vllm_gaudi.ops.hpu_gptq # noqa
import vllm_gaudi.ops.hpu_awq # noqa
import vllm_gaudi.ops.hpu_pooling_metadata # noqa
import vllm_gaudi.ops.hpu_pooler # noqa
3 changes: 3 additions & 0 deletions vllm_gaudi/ops/hpu_pooler.py
@@ -0,0 +1,3 @@
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.pooler import PoolingMetadata

76 changes: 76 additions & 0 deletions vllm_gaudi/ops/hpu_pooling_metadata.py
@@ -0,0 +1,76 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from typing import Optional

import torch

from vllm.utils import is_pin_memory_available
from vllm.pooling_params import PoolingParams
from vllm.model_executor.custom_op import CustomOp
from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata
from vllm.model_executor.pooling_metadata import PoolingTensors

@CustomOp.register_oot(name='V1PoolingMetadata')
class HPUPoolingMetadata(V1PoolingMetadata):
"""Tensors for pooling."""

prompt_lens: torch.Tensor
prompt_token_ids: Optional[torch.Tensor]
pooling_params: list[PoolingParams]
prompt_offsets: Optional[list[int]] = None

def __getitem__(self, indices: slice):
return V1PoolingMetadata(
prompt_lens=self.prompt_lens[indices],
prompt_token_ids=None if self.prompt_token_ids is None else
self.prompt_token_ids[indices],
pooling_params=self.pooling_params[indices],
            prompt_offsets=None if self.prompt_offsets is None else
            self.prompt_offsets[indices],
        )

@CustomOp.register_oot(name='PoolingTensors')
class HPUPoolingTensors(PoolingTensors):
"""Tensors for pooling."""

prompt_lens: torch.Tensor
prompt_offsets: torch.Tensor

@classmethod
def from_pooling_metadata(
cls,
pooling_metadata: "V1PoolingMetadata",
device: torch.device,
) -> "PoolingTensors":
"""
Create PoolingTensors from PoolingMetadata.

Args:
pooling_metadata: PoolingMetadata instance to convert.
device: Device to store the tensors.
"""
# Convert prompt lengths to tensor
pin_memory = is_pin_memory_available()

prompt_lens_t = torch.tensor(
pooling_metadata.prompt_lens,
device="cpu",
dtype=torch.long,
pin_memory=pin_memory,
)
if pooling_metadata.prompt_offsets is not None:
prompt_offsets_t = torch.tensor(
pooling_metadata.prompt_offsets,
device="cpu",
dtype=torch.long,
pin_memory=pin_memory,
).to(device=device, non_blocking=True)
else:
prompt_offsets_t = None

        return cls(prompt_lens=prompt_lens_t.to(device=device,
                                                non_blocking=True),
                   prompt_offsets=prompt_offsets_t)

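For reference, a self-contained sketch of the staging pattern that `from_pooling_metadata` follows: build the tensors on the CPU (optionally pinned) and copy them to the target device with `non_blocking=True`. The helper name and example values below are illustrative only:

```python
import torch

def stage_and_copy(prompt_lens, prompt_offsets, device, pin_memory=False):
    # Same pattern as HPUPoolingTensors.from_pooling_metadata: stage on
    # the CPU (optionally pinned), then copy without blocking the host.
    lens_t = torch.tensor(prompt_lens, device="cpu", dtype=torch.long,
                          pin_memory=pin_memory).to(device, non_blocking=True)
    offsets_t = None
    if prompt_offsets is not None:
        offsets_t = torch.tensor(prompt_offsets, device="cpu", dtype=torch.long,
                                 pin_memory=pin_memory).to(device, non_blocking=True)
    return lens_t, offsets_t

lens_t, offsets_t = stage_and_copy([5, 7, 3], [0, 5, 12], device="cpu")
print(lens_t.tolist(), offsets_t.tolist())
```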
3 changes: 2 additions & 1 deletion vllm_gaudi/ops/hpu_rotary_embedding.py
@@ -60,7 +60,8 @@ def forward_oot(
if hasattr(self, "scaling_factors") or hasattr(
self, "scaling_factor") or self.sin is None:
self.prepare_cos_sin(positions, offsets)
num_tokens = positions.shape[0] * positions.shape[1]
positions = positions.flatten()
num_tokens = positions.shape[0]
# HPU RoPE kernel requires hidden dimension for cos and sin to be equal
# to query hidden dimension, so the original tensors need to be
# expanded
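For reference, a small sketch of why the flatten fixes the token count: the old expression assumed 2-D positions, while flattening handles both 1-D and 2-D inputs. The shapes below are made up:

```python
import torch

def num_tokens_after_flatten(positions: torch.Tensor) -> int:
    # Flattening first makes the count correct for both [num_tokens]
    # and [batch_size, seq_len] position tensors.
    return positions.flatten().shape[0]

pos_1d = torch.arange(6)                 # shape [6]
pos_2d = torch.arange(6).reshape(2, 3)   # shape [2, 3]
assert num_tokens_after_flatten(pos_1d) == 6
assert num_tokens_after_flatten(pos_2d) == 6
# The replaced expression, positions.shape[0] * positions.shape[1],
# raises IndexError for the 1-D case.
```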
167 changes: 99 additions & 68 deletions vllm_gaudi/v1/worker/hpu_input_batch.py
@@ -10,14 +10,15 @@

from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import swap_dict_values
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.block_table import MultiGroupBlockTable
from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
LogitsProcessors)

from vllm_gaudi.utils import async_h2d_copy

_SAMPLING_EPS = 1e-5
@@ -30,7 +31,8 @@ class CachedRequestState:
prompt_token_ids: list[int]
mm_kwargs: list[MultiModalKwargs]
mm_positions: list[PlaceholderRange]
sampling_params: SamplingParams
sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams]
generator: Optional[torch.Generator]

block_ids: list[list[int]]
@@ -233,6 +235,8 @@ def __init__(

# This is updated each time the batch constituents change.
self.sampling_metadata = self._make_sampling_metadata()

self.pooling_params: dict[str, PoolingParams] = {}

@property
def req_ids(self) -> list[str]:
@@ -277,73 +281,80 @@ def add_request(
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
self.block_table.add_row(request.block_ids, req_index)

sampling_params = request.sampling_params
assert sampling_params is not None, "pooling requests not supported yet"
if sampling_params.sampling_type == SamplingType.GREEDY:
# Avoid later division by zero.
self.temperature_cpu[req_index] = -1.0
self.greedy_reqs.add(req_id)
else:
self.temperature_cpu[req_index] = sampling_params.temperature
self.random_reqs.add(req_id)

self.top_p_cpu[req_index] = sampling_params.top_p
if sampling_params.top_p < 1:
self.top_p_reqs.add(req_id)
top_k = sampling_params.top_k
if 0 < top_k < self.vocab_size:
self.top_k_reqs.add(req_id)
else:
top_k = self.vocab_size
self.top_k_cpu[req_index] = top_k
self.min_p_cpu[req_index] = sampling_params.min_p
self.frequency_penalties_cpu[
req_index] = sampling_params.frequency_penalty
if sampling_params.min_p > _SAMPLING_EPS:
self.min_p_reqs.add(req_id)
if sampling_params.frequency_penalty != 0.0:
self.frequency_penalties_reqs.add(req_id)
self.presence_penalties_cpu[
req_index] = sampling_params.presence_penalty
if sampling_params.presence_penalty != 0.0:
self.presence_penalties_reqs.add(req_id)
self.repetition_penalties_cpu[
req_index] = sampling_params.repetition_penalty
if sampling_params.repetition_penalty != 1.0:
self.repetition_penalties_reqs.add(req_id)

# NOTE(woosuk): self.generators should not include the requests that
# do not have their own generator.
if request.generator is not None:
self.generators[req_index] = request.generator

if sampling_params.logprobs is not None:
self.num_logprobs[req_id] = sampling_params.logprobs
if sampling_params.prompt_logprobs is not None:
self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs

if sampling_params.allowed_token_ids:
self.has_allowed_token_ids.add(req_id)
if self.allowed_token_ids_mask_cpu_tensor is None:
# Lazy allocation for this tensor, which can be large.
if sampling_params := request.sampling_params:
if sampling_params.sampling_type == SamplingType.GREEDY:
# Avoid later division by zero.
self.temperature_cpu[req_index] = -1.0
self.greedy_reqs.add(req_id)
else:
self.temperature_cpu[req_index] = sampling_params.temperature
self.random_reqs.add(req_id)

self.top_p_cpu[req_index] = sampling_params.top_p
if sampling_params.top_p < 1:
self.top_p_reqs.add(req_id)
top_k = sampling_params.top_k
if 0 < top_k < self.vocab_size:
self.top_k_reqs.add(req_id)
else:
top_k = self.vocab_size
self.top_k_cpu[req_index] = top_k
self.min_p_cpu[req_index] = sampling_params.min_p
self.frequency_penalties_cpu[
req_index] = sampling_params.frequency_penalty
if sampling_params.min_p > _SAMPLING_EPS:
self.min_p_reqs.add(req_id)
if sampling_params.frequency_penalty != 0.0:
self.frequency_penalties_reqs.add(req_id)
self.presence_penalties_cpu[
req_index] = sampling_params.presence_penalty
if sampling_params.presence_penalty != 0.0:
self.presence_penalties_reqs.add(req_id)
self.repetition_penalties_cpu[
req_index] = sampling_params.repetition_penalty
if sampling_params.repetition_penalty != 1.0:
self.repetition_penalties_reqs.add(req_id)
if sampling_params.min_tokens:
self.min_tokens[req_index] = (
sampling_params.min_tokens,
sampling_params.all_stop_token_ids)

# NOTE(woosuk): self.generators should not include the requests that
# do not have their own generator.
if request.generator is not None:
self.generators[req_index] = request.generator

if sampling_params.logprobs is not None:
self.num_logprobs[req_id] = sampling_params.logprobs
if sampling_params.prompt_logprobs is not None:
self.num_prompt_logprobs[
req_id] = sampling_params.prompt_logprobs
if sampling_params.logit_bias is not None:
self.logit_bias[req_index] = sampling_params.logit_bias

if sampling_params.allowed_token_ids:
self.has_allowed_token_ids.add(req_id)
if self.allowed_token_ids_mask_cpu_tensor is None:
# Lazy allocation for this tensor, which can be large.
# False means we don't fill with -inf.
self.allowed_token_ids_mask = torch.zeros(
self.max_num_reqs,
self.vocab_size,
dtype=torch.bool,
device=self.device)
self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
self.max_num_reqs,
self.vocab_size,
dtype=torch.bool,
device="cpu")
self.allowed_token_ids_mask_cpu_tensor[req_index] = True
# False means we don't fill with -inf.
self.allowed_token_ids_mask = torch.zeros(self.max_num_reqs,
self.vocab_size,
dtype=torch.bool,
device=self.device)
self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
self.max_num_reqs,
self.vocab_size,
dtype=torch.bool,
device="cpu")
self.allowed_token_ids_mask_cpu_tensor[req_index] = True
# False means we don't fill with -inf.
self.allowed_token_ids_mask_cpu_tensor[req_index][
sampling_params.allowed_token_ids] = False

if sampling_params.bad_words_token_ids:
self.bad_words_token_ids[
req_index] = sampling_params.bad_words_token_ids
if sampling_params.bad_words_token_ids:
self.bad_words_token_ids[
req_index] = sampling_params.bad_words_token_ids
else:
assert request.pooling_params is not None
self.pooling_params[req_id] = request.pooling_params

# Add request lora ID
if request.lora_request:
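For reference, a stripped-down sketch of the new sampling-vs-pooling branch in add_request; the request and batch classes below are stand-ins, not the PR's actual types:

```python
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class FakeRequest:                       # stand-in for CachedRequestState
    req_id: str
    sampling_params: Optional[dict] = None
    pooling_params: Optional[dict] = None

@dataclass
class Batch:                             # stand-in for the input batch
    greedy_reqs: set = field(default_factory=set)
    pooling_params: dict = field(default_factory=dict)

    def add_request(self, request: FakeRequest) -> None:
        # Sampling path: configure per-request sampler state.
        if sampling_params := request.sampling_params:
            if sampling_params.get("temperature", 1.0) == 0.0:
                self.greedy_reqs.add(request.req_id)
        # Pooling path: exactly one of the two must be set.
        else:
            assert request.pooling_params is not None
            self.pooling_params[request.req_id] = request.pooling_params

batch = Batch()
batch.add_request(FakeRequest("r0", sampling_params={"temperature": 0.0}))
batch.add_request(FakeRequest("r1", pooling_params={"task": "embed"}))
```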
@@ -395,6 +406,7 @@ def remove_request(self, req_id: str) -> Optional[int]:
# False means we don't fill with -inf.
self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
self.bad_words_token_ids.pop(req_index, None)
self.pooling_params.pop(req_id, None)
return req_index

def swap_states(self, i1: int, i2: int) -> None:
@@ -675,6 +687,25 @@ def make_selective_sampling_metadata(
bad_words_token_ids=self.bad_words_token_ids,
logitsprocs=self.logitsprocs,
)

@property
def pooling_metadata(self) -> PoolingMetadata:
if len(self.pooling_params) == 0:
pooling_params = []
else:
            # Note: for now this assumes that all requests in the batch
            # are either sampling or pooling requests.
assert len(self.req_ids) == len(self.pooling_params)
pooling_params = [
self.pooling_params[req_id] for req_id in self.req_ids
]

return PoolingMetadata(
prompt_lens=torch.from_numpy(
self.num_prompt_tokens[:self.num_reqs]).to(self.device),
prompt_token_ids=self.sampling_metadata.prompt_token_ids,
pooling_params=pooling_params,
)

def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
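For reference, a standalone sketch of the gather pattern the new pooling_metadata property uses: pooling params are stored per request id and re-ordered to match the batch's req_ids. Names and values below are illustrative:

```python
import torch

def gather_pooling_params(req_ids, pooling_params_by_id):
    # Mirrors the property's assumption: the batch is either all-sampling
    # (empty dict -> empty list) or all-pooling (one entry per request id),
    # and the result follows batch order.
    if not pooling_params_by_id:
        return []
    assert len(req_ids) == len(pooling_params_by_id)
    return [pooling_params_by_id[req_id] for req_id in req_ids]

params = gather_pooling_params(
    ["req-0", "req-1"],
    {"req-1": {"task": "embed"}, "req-0": {"task": "embed"}},
)
prompt_lens = torch.tensor([7, 11])  # stand-in for num_prompt_tokens[:num_reqs]
print(params, prompt_lens.tolist())
```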