Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions vllm/v1/core/sched/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(
vllm_config: VllmConfig,
kv_cache_config: KVCacheConfig,
structured_output_manager: StructuredOutputManager,
block_size: int,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
include_finished_set: bool = False,
log_stats: bool = False,
Expand Down Expand Up @@ -101,15 +102,8 @@ def __init__(
num_gpu_blocks = self.cache_config.num_gpu_blocks
assert num_gpu_blocks is not None and num_gpu_blocks > 0

self.block_size = self.cache_config.block_size

self.block_size = block_size
self.dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
# Note(hc): The scheduler’s block_size must be multiplied
# by dcp_world_size, since block hashes are computed on the
# original full token sequence at a granularity of
# original_block_size × dcp_world_size.
if self.dcp_world_size > 1:
self.block_size *= self.dcp_world_size

# req_id -> Request
self.requests: dict[str, Request] = {}
Expand Down
9 changes: 7 additions & 2 deletions vllm/v1/engine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,12 +142,18 @@ def __init__(
logger.info("Disabling chunked prefill for model without KVCache")
vllm_config.scheduler_config.chunked_prefill_enabled = False

scheduler_block_size = (
vllm_config.cache_config.block_size
* vllm_config.parallel_config.decode_context_parallel_size
)

self.scheduler: SchedulerInterface = Scheduler(
vllm_config=vllm_config,
kv_cache_config=kv_cache_config,
structured_output_manager=self.structured_output_manager,
include_finished_set=vllm_config.parallel_config.data_parallel_size > 1,
log_stats=self.log_stats,
block_size=scheduler_block_size,
)
self.use_spec_decode = vllm_config.speculative_config is not None
if self.scheduler.connector is not None: # type: ignore
Expand Down Expand Up @@ -177,14 +183,13 @@ def __init__(
self.vllm_config.cache_config.enable_prefix_caching
or self.scheduler.get_kv_connector() is not None
):
block_size = vllm_config.cache_config.block_size
caching_hash_fn = get_hash_fn_by_name(
vllm_config.cache_config.prefix_caching_hash_algo
)
init_none_hash(caching_hash_fn)

self.request_block_hasher = get_request_block_hasher(
block_size, caching_hash_fn
scheduler_block_size, caching_hash_fn
)

self.step_fn = (
Expand Down