Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions vllm_gaudi/attention/backends/hpu_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,8 +405,13 @@ def _forward_decode( # type: ignore
# during each graph execution
def process_weights_after_loading(self, act_dtype: torch.dtype):
    """Finalize derived MLA projection tensors after checkpoint load.

    Calls the parent implementation (which derives ``W_UV`` / ``W_UK_T``),
    then makes both tensors contiguous and explicitly HPU-resident.

    Args:
        act_dtype: activation dtype forwarded to the parent hook.
    """
    super().process_weights_after_loading(act_dtype)
    # W_UV and W_UK_T are plain tensor attributes (not nn.Parameter or
    # registered buffers), so model.to('hpu') does not move them. When INC
    # CPU-first loading is active the source weights live on CPU, which
    # would leave these derived tensors CPU-resident and cause a device
    # mismatch at the bmm calls in forward. Explicitly place them on HPU.
    # NOTE: the diff paste also carried the stale pre-change assignments
    # (plain .contiguous() with no device move); only the corrected form
    # is kept here.
    self.W_UV: torch.Tensor = self.W_UV.contiguous().to("hpu")
    self.W_UK_T: torch.Tensor = self.W_UK_T.contiguous().to("hpu")

# NOTE(Chendi): PR25184 using output buffer as default, which can't be used in HPU Graph,
# so we override and always return a new tensor
def process_weights_after_loading(self, act_dtype: torch.dtype):
    """Finalize derived MLA projection tensors after checkpoint load.

    Parent MLACommonImpl extracts ``W_UV`` and ``W_UK_T`` from kv_b_proj
    weights; these projection matrices are used for latent <-> full space
    conversions.

    Args:
        act_dtype: activation dtype forwarded to the parent hook.
    """
    super().process_weights_after_loading(act_dtype)
    # W_UV and W_UK_T are plain tensor attributes (not nn.Parameter or
    # registered buffers), so model.to('hpu') does not move them. When INC
    # CPU-first loading is active the source weights live on CPU, which
    # would leave these derived tensors CPU-resident and cause a device
    # mismatch at the bmm calls in forward. Explicitly place them on HPU.
    # NOTE: the diff paste also carried the stale pre-change assignments
    # (plain .contiguous() with no device move); only the corrected form
    # is kept here.
    self.W_UV: torch.Tensor = self.W_UV.contiguous().to("hpu")
    self.W_UK_T: torch.Tensor = self.W_UK_T.contiguous().to("hpu")
34 changes: 33 additions & 1 deletion vllm_gaudi/v1/worker/hpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.vocab_parallel_embedding import (VocabParallelEmbedding)
from vllm.model_executor.model_loader import get_model, get_model_loader
from vllm.platforms import current_platform
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargsItem)
from vllm.multimodal.utils import group_mm_kwargs_by_modality
Expand Down Expand Up @@ -145,6 +146,25 @@
shutdown_inc_called = False


@contextlib.contextmanager
def _override_platform_device_type(device_type: str):
    """Temporarily point ``current_platform.device_type`` at *device_type*.

    When ``load_config.device`` is set (e.g. ``"cpu"`` for INC
    quantization), the model loader routes implicit tensor creation to
    that device via ``torch.set_default_device``, while other upstream
    code still creates tensors with an explicit
    ``device=current_platform.device_type`` (normally ``"hpu"``). Mixing
    a CPU default with explicit HPU placement raises RuntimeError; this
    manager keeps both paths on the same device and restores the original
    value on exit, even if the body raises.
    """
    saved = current_platform.device_type
    try:
        current_platform.device_type = device_type
        yield
    finally:
        current_platform.device_type = saved


class BucketingFailedException(Exception):
    """Raised when bucketing fails; see raising call sites for specifics."""

Expand Down Expand Up @@ -4035,7 +4055,19 @@ def load_model(self) -> None:
htcore.hpu_inference_set_env()
logger.info("Starting to load model %s...", self.model_config.model)
with HabanaMemoryProfiler() as m: # noqa: SIM117
self.model = get_model(vllm_config=self.vllm_config)
# When load_config.device differs from the platform device (e.g.
# "cpu" for INC quantization), upstream code that uses both
# torch.set_default_device (via the model loader context manager)
# and explicit device=current_platform.device_type creates a
# device mismatch. Temporarily aligning device_type with the
# load device makes both paths consistent, avoiding RuntimeError
# in modules like DeepseekScalingRotaryEmbedding.
load_device = self.vllm_config.load_config.device
ctx = _override_platform_device_type(load_device) \
if load_device and load_device != current_platform.device_type \
else contextlib.nullcontext()
with ctx:
self.model = get_model(vllm_config=self.vllm_config)
if self.lora_config:
self.model = self.load_lora_model(self.model, self.vllm_config, self.device)
self.model_memory_usage = m.consumed_device_memory
Expand Down
Loading