
Commit d2b9ab0

Enable LMCache for CPU offloading; add LMCache Docker support; enable LMCache
Signed-off-by: Harish Subramony <[email protected]>
1 parent ee2156a commit d2b9ab0
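For context: this change wires the Gaudi (HPU) model runner and worker into vLLM's v1 KV-connector interface so that a connector such as LMCache can save and load KV cache outside device memory (e.g. CPU offloading). The diff itself adds no configuration; the following is only a sketch of how LMCache is typically enabled, assuming the upstream vLLM KVTransferConfig API and LMCache's documented environment variables (these names are not part of this commit; verify them against the installed versions):

    # Sketch only: config fields and env vars are assumptions taken from
    # upstream vLLM / LMCache documentation, not from this commit.
    import os

    os.environ.setdefault("LMCACHE_LOCAL_CPU", "True")          # offload KV to CPU RAM
    os.environ.setdefault("LMCACHE_MAX_LOCAL_CPU_SIZE", "5.0")  # CPU cache budget in GiB

    from vllm import LLM
    from vllm.config import KVTransferConfig

    llm = LLM(
        model="meta-llama/Llama-3.1-8B-Instruct",  # example model
        kv_transfer_config=KVTransferConfig(
            kv_connector="LMCacheConnectorV1",  # LMCache's v1 connector
            kv_role="kv_both",                  # this instance both saves and loads KV
        ),
    )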

File tree

3 files changed: +117 -13 lines changed

    vllm_gaudi/v1/worker/hpu_input_batch.py
    vllm_gaudi/v1/worker/hpu_model_runner.py
    vllm_gaudi/v1/worker/hpu_worker.py

vllm_gaudi/v1/worker/hpu_input_batch.py

Lines changed: 2 additions & 0 deletions

@@ -236,6 +236,8 @@ def __init__(
         # This is updated each time the batch constituents change.
         self.sampling_metadata = self._make_sampling_metadata()

+        self.req_type: dict[str, str] = {}
+
     @property
     def req_ids(self) -> list[str]:
         # None elements should only be present transiently
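The req_type dict added above is keyed by request id and is filled in by the model runner (next file) from the KV-connector metadata: a request whose KV cache is loaded through the connector is tracked as "decode", everything else as "prefill". A minimal sketch of the intended contents, with hypothetical request ids (the real entries are written and removed in hpu_model_runner.py):

    # Hypothetical contents of req_type; illustration only.
    req_type: dict[str, str] = {}
    req_type["req-0"] = "prefill"   # no load_spec in the connector metadata
    req_type["req-1"] = "decode"    # KV cache arrives via the connector
    req_type.pop("req-1", None)     # dropped when the request leaves the batch
    assert req_type == {"req-0": "prefill"}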

vllm_gaudi/v1/worker/hpu_model_runner.py

Lines changed: 107 additions & 9 deletions

@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 import collections
 import contextlib
+import copy
 import functools
 import itertools
 import math
@@ -26,7 +27,10 @@
 from vllm.attention.layer import Attention
 from vllm.attention.selector import get_attn_backend
 from vllm.config import (VllmConfig, update_config)
-from vllm.forward_context import set_forward_context
+from vllm.distributed.kv_transfer import (get_kv_transfer_group,
+                                          has_kv_transfer_group)
+from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
+from vllm.forward_context import get_forward_context, set_forward_context
 from vllm.model_executor.layers.fused_moe.layer import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.sampler import get_sampler
@@ -407,6 +411,7 @@ def forward(self, *args, **kwargs):
         # kwargs['attn_metadata'].slot_mapping, compared to untrimmed metadata
         kwargs = kwargs.copy()
         # selected_token_indices = kwargs.pop('selected_token_indices')
+        is_warmup = kwargs.get('warmup_mode', False)
         if 'warmup_mode' in kwargs:
             kwargs.pop('warmup_mode')
         input_ids = kwargs['input_ids']
@@ -420,7 +425,12 @@
         if 'kv_caches' in kwargs:
             kwargs.pop('kv_caches')
         with set_forward_context(attn_meta, self.vllm_config):
+            if not is_warmup:
+                self.maybe_start_load_kv()
             hidden_states = self.model(*args, **kwargs)
+            if not is_warmup:
+                self.maybe_wait_for_kv_save()
+
         if self._rotary_prepare_cos_sin is not None:
             self._reset_rotary_cos_sin()
         return hidden_states
@@ -431,6 +441,22 @@ def compute_logits(self, *args, **kwargs):
     # def sample(self, *args, **kwargs):
     #     return self.sampler(*args, **kwargs)

+    @staticmethod
+    def maybe_start_load_kv():
+        if has_kv_transfer_group():
+            kv_connector = get_kv_transfer_group()
+
+            # Background KV cache transfers happen here.
+            # These transfers are designed to be async and the requests
+            # involved may be disjoint from the running requests.
+            # Do this here to save a collective_rpc.
+            kv_connector.start_load_kv(get_forward_context())
+
+    @staticmethod
+    def maybe_wait_for_kv_save() -> None:
+        if has_kv_transfer_group():
+            get_kv_transfer_group().wait_for_save()
+
     def generate_proposals(self, *args, **kwargs):
         return self.model.generate_proposals(*args, **kwargs)

@@ -716,6 +742,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
             req_index = self.input_batch.remove_request(req_id)
             if req_index is not None:
                 removed_req_indices.append(req_index)
+            if req_id in self.input_batch.req_type:
+                del self.input_batch.req_type[req_id]

         # Remove the unscheduled requests from the persistent batch.
         # NOTE(woosuk): The unscheduled requests are either preempted requests
@@ -862,6 +890,10 @@ def get_model(self) -> torch.nn.Module:
         assert self.model is not None
         return self.model

+    def is_decoder_only(self, req_id) -> bool:
+        return bool(req_id in self.input_batch.req_type and \
+                    self.input_batch.req_type[req_id] == "decode")
+
     def _get_prompts_and_decodes(
         self,
         scheduler_output: "SchedulerOutput",
@@ -871,24 +903,38 @@
         num_reqs = self.input_batch.num_reqs
         assert num_reqs > 0

+        if scheduler_output.kv_connector_metadata:
+            requests = scheduler_output.kv_connector_metadata.requests
+        else:
+            requests = None
+
         # Traverse decodes first
         decode_req_ids = []
         num_computed_tokens_decode = []
         for i in range(num_reqs):
             req_id = self.input_batch.req_ids[i]
             assert req_id is not None

+            if requests is not None and req_id not in self.input_batch.req_type:
+                for request in requests:
+                    if request.req_id == req_id:
+                        self.input_batch.req_type[req_id] = "prefill" \
+                            if request.load_spec is None else "decode"
+                        break
+
             num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i]
             num_prompt_tokens = self.input_batch.num_prompt_tokens[i]
             num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
                 req_id]

-            if num_computed_tokens < num_prompt_tokens:
+            if num_computed_tokens < num_prompt_tokens and \
+                not self.is_decoder_only(req_id):
                 # This is prompt
                 break

             # This is decode
-            assert num_scheduled_tokens == 1
+            if not self.is_decoder_only(req_id):
+                assert num_scheduled_tokens == 1
             decode_req_ids.append(req_id)
             num_computed_tokens_decode.append(int(num_computed_tokens + 1))

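This classification is the crux of the decoder-only path: a request that arrives with a load_spec in the KV-connector metadata gets its prompt KV cache from the connector (e.g. LMCache) rather than running prefill locally, so it is treated as a decode even though more than one token may be scheduled for it in a step. A small stand-alone illustration of the relaxed invariant (hypothetical helpers that mirror the checks above):

    # Illustration only; stand-in helpers mirroring the checks above.
    def classify(load_spec) -> str:
        return "prefill" if load_spec is None else "decode"

    def check_decode_tokens(num_scheduled_tokens: int, decoder_only: bool) -> None:
        # Ordinary decodes schedule exactly one token; connector-fed requests
        # may schedule many while their KV cache is being loaded.
        if not decoder_only:
            assert num_scheduled_tokens == 1

    assert classify(None) == "prefill"
    assert classify(object()) == "decode"
    check_decode_tokens(1, decoder_only=False)    # normal decode
    check_decode_tokens(16, decoder_only=True)    # connector-fed request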
@@ -1369,7 +1415,7 @@ def _prepare_inputs(
             num_scheduled_tokens.append(seq_num_scheduled_tokens)
             num_prompt_tokens.append(seq_num_prompt_tokens)
             # NOTE: assert that all the decodes are "decodes".
-            if idx < num_decodes:
+            if idx < num_decodes and not self.is_decoder_only(req_id):
                 assert seq_num_scheduled_tokens == 1
         return (self._prepare_prefill_inputs(num_prefills, num_decodes,
                                              num_scheduled_tokens),
@@ -1391,8 +1437,8 @@ def _check_config(self, batch_size, seq_len, num_blocks, attn_metadata,
         self.seen_configs.add(cfg)
         if not seen and not warmup_mode:
             logger.warning(
-                "Configuration: (%s, %s, %s, %s) was not warmed-up!", phase,
-                batch_size, seq_len, num_blocks)
+                "Configuration: rank (%s, %s, %s, %s, %s) was not warmed-up!",
+                os.getenv('RANK', '0'), phase, batch_size, seq_len, num_blocks)

     def _execute_model_generic(self,
                                token_ids,
@@ -1579,8 +1625,11 @@ def execute_model(

         batch_changed = self._update_states(scheduler_output)
         if not scheduler_output.total_num_scheduled_tokens:
-            # Return empty ModelRunnerOuptut if there's no work to do.
-            return EMPTY_MODEL_RUNNER_OUTPUT
+            if not has_kv_transfer_group():
+                # Return empty ModelRunnerOuptut if there's no work to do.
+                return EMPTY_MODEL_RUNNER_OUTPUT
+
+            return self.kv_connector_no_forward(scheduler_output)
         # If necessary, swap decodes/prompts to have all decodes on the start
         ensure_decodes_first(self.input_batch)
         # Prepare prompts/decodes info
@@ -1607,11 +1656,14 @@ def execute_model(
                 self.event_start = self.profiler.get_timestamp_us()
                 self.profiler.start("internal", "prefill")
             htorch.core.mark_step()
+            self.maybe_setup_kv_connector(scheduler_output)
             prefill_hidden_states_ts, logits_device = \
                 self._execute_model_generic(
                     token_ids, position_ids, attn_metadata, logits_indices,
                     self.kv_caches)
             htorch.core.mark_step()
+            finished_sending, finished_recving = (
+                self.get_finished_kv_transfers(scheduler_output))
             with self.profiler.record_event('internal', "sampler"):
                 sampling_metadata = self._prepare_sampling(
                     batch_changed, req_id, pad_to=logits_device.shape[0])
@@ -1645,11 +1697,15 @@ def execute_model(
                 self.profiler.start("internal", "decode")
             assert decode_data is not None
             htorch.core.mark_step()
-            _, logits_device = self._execute_model_generic(
+            self.maybe_setup_kv_connector(scheduler_output)
+            _, logits_device = \
+                self._execute_model_generic(
                 decode_data.token_ids, decode_data.position_ids,
                 decode_data.attn_metadata, decode_data.logits_indices,
                 self.kv_caches)
             htorch.core.mark_step()
+            finished_sending, finished_recving = (
+                self.get_finished_kv_transfers(scheduler_output))
             with self.profiler.record_event('internal', "sampler"):
                 sampling_metadata = self._prepare_sampling(
                     batch_changed,
@@ -1760,7 +1816,11 @@ def execute_model(
             spec_token_ids=None,
             prompt_logprobs_dict=prompt_logprobs_dict,  # type: ignore[arg-type]
             pooler_output=[],
+            finished_sending=finished_sending,
+            finished_recving=finished_recving,
         )
+        if has_kv_transfer_group():
+            get_kv_transfer_group().clear_connector_metadata()
         return model_runner_output

     def load_model(self) -> None:
@@ -2450,3 +2510,41 @@ def reload_weights(self) -> None:
         logger.info("Reloading weights inplace...")
         model_loader.load_weights(self.model, model_config=self.model_config)
         torch.hpu.synchronize()
+
+    @staticmethod
+    def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
+        # Update KVConnector with the KVConnector metadata forward().
+        if has_kv_transfer_group():
+            kv_connector = get_kv_transfer_group()
+            assert isinstance(kv_connector, KVConnectorBase_V1)
+            assert scheduler_output.kv_connector_metadata is not None
+            kv_connector.bind_connector_metadata(
+                scheduler_output.kv_connector_metadata)
+
+    @staticmethod
+    def get_finished_kv_transfers(
+        scheduler_output: "SchedulerOutput",
+    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
+        if has_kv_transfer_group():
+            return get_kv_transfer_group().get_finished(
+                scheduler_output.finished_req_ids)
+        return None, None
+
+    def kv_connector_no_forward(
+            self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
+        # KV send/recv even if no work to do.
+        with set_forward_context(None, self.vllm_config):
+            self.maybe_setup_kv_connector(scheduler_output)
+            if has_kv_transfer_group():
+                kv_connector = get_kv_transfer_group()
+                kv_connector.start_load_kv(get_forward_context())
+            finished_sending, finished_recving = (
+                self.get_finished_kv_transfers(scheduler_output))
+
+        if not finished_sending and not finished_recving:
+            return EMPTY_MODEL_RUNNER_OUTPUT
+
+        output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
+        output.finished_sending = finished_sending
+        output.finished_recving = finished_recving
+        return output
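kv_connector_no_forward() is why the early return in execute_model() now goes through the connector: even when nothing is scheduled, the step still has to poll for asynchronous KV sends and receives. A condensed, self-contained sketch of that pattern with stand-in types (the real code uses the KV transfer group and the shared EMPTY_MODEL_RUNNER_OUTPUT):

    # Stand-alone sketch; StepOutput and poll are stand-ins, not vLLM APIs.
    import copy
    from dataclasses import dataclass
    from typing import Callable, Optional, Tuple

    @dataclass
    class StepOutput:
        finished_sending: Optional[set] = None
        finished_recving: Optional[set] = None

    EMPTY_OUTPUT = StepOutput()

    def no_forward_step(
            poll: Callable[[], Tuple[Optional[set], Optional[set]]]) -> StepOutput:
        finished_sending, finished_recving = poll()
        if not finished_sending and not finished_recving:
            return EMPTY_OUTPUT
        # Copy before mutating so the shared empty output is never modified,
        # matching copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) in the diff above.
        out = copy.copy(EMPTY_OUTPUT)
        out.finished_sending = finished_sending
        out.finished_recving = finished_recving
        return out

    print(no_forward_step(lambda: ({"req-0"}, None)))   # transfers finished
    print(no_forward_step(lambda: (None, None)))        # nothing to report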

vllm_gaudi/v1/worker/hpu_worker.py

Lines changed: 8 additions & 4 deletions

@@ -14,9 +14,10 @@
 from vllm_gaudi.extension.profiler import HabanaMemoryProfiler, format_bytes

 import vllm.envs as envs
-from vllm.config import ParallelConfig, VllmConfig
+from vllm.config import VllmConfig
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
+from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
 from vllm.model_executor import set_random_seed
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
@@ -58,6 +59,7 @@ def __init__(
         self.speculative_config = vllm_config.speculative_config
         self.observability_config = vllm_config.observability_config

+        self.parallel_config.rank = rank
         self.local_rank = local_rank
         self.rank = rank
         self.distributed_init_method = distributed_init_method
@@ -121,7 +123,7 @@ def stop_profile(self):

     def init_device(self):
         # Initialize the distributed environment.
-        init_worker_distributed_environment(self.parallel_config, self.rank,
+        init_worker_distributed_environment(self.vllm_config, self.rank,
                                             self.distributed_init_method,
                                             self.local_rank)
         # Set random seed.
@@ -235,6 +237,7 @@ def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
         msg = (f"Usable num_blocks: {kv_cache_config.num_blocks}, "
               f"actual allocated num_blocks: "
               f"{self.model_runner.kv_caches[0][0].shape[0]} "
+              f"{self.model_runner.kv_caches[0][0].shape} "
               f"(_PAD_BLOCK_ID={self.model_runner._PAD_BLOCK_ID}, "
               f"_PAD_SLOT_ID={self.model_runner._PAD_SLOT_ID})")
         logger.info(msg)
@@ -275,12 +278,13 @@ def profile(self, is_start: bool = True):


 def init_worker_distributed_environment(
-    parallel_config: ParallelConfig,
+    vllm_config: VllmConfig,
     rank: int,
     distributed_init_method: Optional[str] = None,
     local_rank: int = -1,
 ) -> None:
     """Initialize the distributed environment."""
+    parallel_config = vllm_config.parallel_config
     init_distributed_environment(parallel_config.world_size,
                                  rank,
                                  distributed_init_method,
@@ -293,7 +297,7 @@
     assert dummy_tensor_hpu.item() == parallel_config.world_size
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                       parallel_config.pipeline_parallel_size)
-
+    ensure_kv_transfer_initialized(vllm_config)

 @contextmanager
 def track_graph_compile(name: str):
