
Commit 0f8a914

[Core] Ignore infeasible swap requests. (#4557)
1 parent 9b5c9f9 commit 0f8a914

12 files changed: +187 −42 lines
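The heart of this commit: BlockSpaceManager.can_swap_in() now returns an AllocStatus instead of a bool, so the scheduler can distinguish "cannot swap in yet" from "can never swap in" and ignore the latter instead of retrying it forever. The tests below import the enum from vllm.core.interfaces; as a reading aid, here is a minimal sketch of the three values used throughout the hunks (the real definition may differ in detail):

import enum


class AllocStatus(enum.Enum):
    """Answer to "can this allocation / swap-in proceed?" (sketch)."""
    OK = enum.auto()     # enough free GPU blocks right now: go ahead
    LATER = enum.auto()  # not enough free blocks yet: retry on a later step
    NEVER = enum.auto()  # needs more blocks than the GPU has in total: give up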

tests/basic_correctness/test_preemption.py

Lines changed: 85 additions & 0 deletions
@@ -7,6 +7,7 @@
 """
 import pytest
 
+from vllm import SamplingParams
 from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT,
                                  ENABLE_ARTIFICIAL_PREEMPT)
 
@@ -136,3 +137,87 @@ def test_swap(
             assert hf_output_ids[j] == vllm_output_ids[j], (
                 f"Test{i} output{j}:\nHF: {hf_output_ids}\n"
                 f"vLLM: {vllm_output_ids}")
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("max_tokens", [96])
+@pytest.mark.parametrize("beam_width", [4])
+def test_swap_infeasible(
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+    beam_width: int,
+) -> None:
+    """Verify an infeasible swap request will be ignored."""
+    BLOCK_SIZE = 16
+    prefill_blocks = 2
+    decode_blocks = max_tokens // BLOCK_SIZE
+    example_prompts = example_prompts[:1]
+
+    vllm_model = vllm_runner(
+        model,
+        dtype=dtype,
+        swap_space=10,
+        block_size=BLOCK_SIZE,
+        # Since beam search has more than 1 sequence, prefill + decode blocks
+        # are not enough to finish.
+        num_gpu_blocks_override=prefill_blocks + decode_blocks,
+        max_model_len=(prefill_blocks + decode_blocks) * BLOCK_SIZE,
+    )
+    sampling_params = SamplingParams(n=beam_width,
+                                     use_beam_search=True,
+                                     temperature=0.0,
+                                     max_tokens=max_tokens,
+                                     ignore_eos=True)
+    req_outputs = vllm_model.model.generate(
+        example_prompts,
+        sampling_params=sampling_params,
+    )
+    assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
+            ARTIFICIAL_PREEMPTION_MAX_CNT)
+    del vllm_model
+    # Verify the request is ignored and does not hang.
+    assert req_outputs[0].outputs[0].finish_reason == "length"
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("max_tokens", [96])
+def test_preemption_infeasible(
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+) -> None:
+    """Verify an infeasible preemption request will be ignored."""
+    BLOCK_SIZE = 16
+    prefill_blocks = 2
+    decode_blocks = max_tokens // BLOCK_SIZE
+    vllm_model = vllm_runner(
+        model,
+        dtype=dtype,
+        block_size=BLOCK_SIZE,
+        # Not enough GPU blocks to complete a single sequence.
+        # Preemption should happen, and the sequence should be
+        # ignored instead of hanging forever.
+        num_gpu_blocks_override=prefill_blocks + decode_blocks // 2,
+        max_model_len=((prefill_blocks + decode_blocks // 2) * BLOCK_SIZE),
+    )
+    sampling_params = SamplingParams(max_tokens=max_tokens, ignore_eos=True)
+    req_outputs = vllm_model.model.generate(
+        example_prompts,
+        sampling_params=sampling_params,
+    )
+
+    assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
+            ARTIFICIAL_PREEMPTION_MAX_CNT)
+    del vllm_model
+    # Verify the request is ignored and does not hang.
+    for req_output in req_outputs:
+        outputs = req_output.outputs
+        assert len(outputs) == 1
+        assert outputs[0].finish_reason == "length"
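As a quick sanity check on why the swap in test_swap_infeasible can never succeed, here is the block arithmetic with the constants used above (a standalone snippet; the block manager's real accounting differs slightly, e.g. prompt blocks can be shared across beams):

# test_swap_infeasible gives the engine only prefill + decode GPU blocks,
# but beam search keeps beam_width sequences alive, so swapping the group
# back in would need several times the total number of GPU blocks.
BLOCK_SIZE = 16
max_tokens = 96
prefill_blocks = 2
decode_blocks = max_tokens // BLOCK_SIZE            # 6
total_gpu_blocks = prefill_blocks + decode_blocks   # 8 (num_gpu_blocks_override)
beam_width = 4

blocks_needed_roughly = beam_width * (prefill_blocks + decode_blocks)  # ~32
assert blocks_needed_roughly > total_gpu_blocks     # hence AllocStatus.NEVER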

tests/core/test_block_manager.py

Lines changed: 1 addition & 1 deletion
@@ -224,7 +224,7 @@ def test_swap():
 
     # Swap seq group from CPU -> GPU.
     cpu_blocks = block_manager.get_block_table(prompt)
-    assert block_manager.can_swap_in(seq_group)
+    assert block_manager.can_swap_in(seq_group) == AllocStatus.OK
     before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
     before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
     mapping = block_manager.swap_in(seq_group)

tests/core/test_chunked_prefill_scheduler.py

Lines changed: 3 additions & 2 deletions
@@ -4,6 +4,7 @@
 import pytest  # noqa
 
 from vllm.config import CacheConfig, SchedulerConfig
+from vllm.core.interfaces import AllocStatus
 from vllm.core.scheduler import Scheduler
 from vllm.sequence import Logprob, SequenceGroup
 
@@ -410,7 +411,7 @@ def cannot_append_second_group(seq_group, num_lookahead_slots):
 
     # Add 1 more task. Swap is not possible, so prefill is running.
     scheduler.block_manager.can_swap_in = MagicMock()
-    scheduler.block_manager.can_swap_in.return_value = False
+    scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER
 
     _, seq_group2 = create_dummy_prompt("2", prompt_length=60)
     scheduler.add_seq_group(seq_group2)
@@ -423,7 +424,7 @@ def cannot_append_second_group(seq_group, num_lookahead_slots):
     assert out.scheduled_seq_groups[0].seq_group == seq_group2
 
     # Now although swap is possible, running prefill is prioritized.
-    scheduler.block_manager.can_swap_in.return_value = True
+    scheduler.block_manager.can_swap_in.return_value = AllocStatus.OK
     _, out = schedule_and_update_computed_tokens(scheduler)
     assert len(out.scheduled_seq_groups) == 1
     # 3 decodes. It is swapped in.

tests/core/test_scheduler.py

Lines changed: 29 additions & 1 deletion
@@ -791,7 +791,7 @@ def test_schedule_swapped_cannot_swap_in():
 
     # The last request should be swapped out.
     scheduler.block_manager.can_swap_in = MagicMock()
-    scheduler.block_manager.can_swap_in.return_value = False
+    scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER
     # Since we cannot swap in, none of the requests are swapped in.
     budget = create_token_budget()
     remaining_swapped, output = scheduler._schedule_swapped(
@@ -803,6 +803,34 @@
     assert len(output.prefill_seq_groups) == 0
 
 
+def test_infeasible_swap():
+    scheduler = initialize_scheduler()
+    swapped = deque()
+    policy = PolicyFactory.get_policy(policy_name="fcfs")
+    curr_loras = None
+    blocks_to_swap_out = {}
+    for _ in range(2):
+        _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
+        scheduler._allocate_and_set_running(seq_group)
+        append_new_token_seq_group(60, seq_group, 1)
+        scheduler._swap_out(seq_group, blocks_to_swap_out)
+        swapped.append(seq_group)
+
+    # The last request should be swapped out.
+    scheduler.block_manager.can_swap_in = MagicMock()
+    scheduler.block_manager.can_swap_in.return_value = AllocStatus.NEVER
+    # Since we cannot swap in, none of the requests are swapped in.
+    budget = create_token_budget()
+    remaining_swapped, output = scheduler._schedule_swapped(
+        swapped, budget, curr_loras, policy)
+    assert len(remaining_swapped) == 0
+    assert len(output.infeasible_seq_groups) == 2
+    assert budget.num_batched_tokens == 0
+    assert budget.num_curr_seqs == 0
+    assert len(output.decode_seq_groups) == 0
+    assert len(output.prefill_seq_groups) == 0
+
+
 def test_schedule_swapped_blocks_to_copy():
     scheduler = initialize_scheduler()
     swapped = deque()

vllm/core/block/cpu_gpu_block_allocator.py

Lines changed: 8 additions & 11 deletions
@@ -110,9 +110,8 @@ def __init__(
         for block_id in allocator.all_block_ids:
             self._block_ids_to_allocator[block_id] = allocator
 
-    def allocate_mutable(self,
-                         prev_block: Optional[Block],
-                         device: Optional[Device] = None) -> Block:
+    def allocate_mutable(self, prev_block: Optional[Block],
+                         device: Device) -> Block:
         """Allocates a new mutable block on the specified device.
 
         Args:
@@ -123,13 +122,10 @@ def allocate_mutable(self,
         Returns:
             Block: The newly allocated mutable block.
         """
-        assert device is not None
         return self._allocators[device].allocate_mutable(prev_block)
 
-    def allocate_immutable(self,
-                           prev_block: Optional[Block],
-                           token_ids: List[int],
-                           device: Optional[Device] = None) -> Block:
+    def allocate_immutable(self, prev_block: Optional[Block],
+                           token_ids: List[int], device: Device) -> Block:
         """Allocates a new immutable block with the provided token IDs on the
         specified device.
 
@@ -144,7 +140,6 @@ def allocate_immutable(self,
             Block: The newly allocated immutable block containing the provided
                 token IDs.
         """
-        assert device is not None
         return self._allocators[device].allocate_immutable(
             prev_block, token_ids)
 
@@ -175,7 +170,7 @@ def fork(self, last_block: Block) -> List[Block]:
         allocator = self._block_ids_to_allocator[block_id]
         return allocator.fork(last_block)
 
-    def get_num_free_blocks(self, device: Optional[Device] = None) -> int:
+    def get_num_free_blocks(self, device: Device) -> int:
         """Returns the number of free blocks available on the specified device.
 
         Args:
@@ -185,9 +180,11 @@ def get_num_free_blocks(self, device: Optional[Device] = None) -> int:
         Returns:
            int: The number of free blocks available on the specified device.
         """
-        assert device is not None
         return self._allocators[device].get_num_free_blocks()
 
+    def get_num_total_blocks(self, device: Device) -> int:
+        return self._allocators[device].get_num_total_blocks()
+
     def clear_copy_on_writes(self) -> Dict[int, List[int]]:
         """Clears the copy-on-write (CoW) state and returns the mapping of
         source to destination block IDs.
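Note that device is now a required argument rather than an Optional defaulting to None, so call sites must name the pool explicitly. A hypothetical helper illustrating the updated signatures (the allocator argument is assumed to be a CpuGpuBlockAllocator from the file above, and the Device import path is an assumption, not part of this diff):

from vllm.utils import Device  # assumption: location of the Device enum


def gpu_block_usage(allocator) -> int:
    """Return how many GPU blocks of the given CpuGpuBlockAllocator are in use
    (hypothetical helper, not part of this commit)."""
    free_gpu = allocator.get_num_free_blocks(Device.GPU)
    total_gpu = allocator.get_num_total_blocks(Device.GPU)  # new in this commit
    return total_gpu - free_gpu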

vllm/core/block/interfaces.py

Lines changed: 13 additions & 8 deletions
@@ -108,6 +108,10 @@ def free(self, block: Block) -> None:
     def fork(self, last_block: Block) -> List[Block]:
         pass
 
+    @abstractmethod
+    def get_num_total_blocks(self) -> int:
+        pass
+
     @abstractmethod
     def get_num_free_blocks(self) -> int:
         pass
@@ -152,20 +156,21 @@ class NoFreeBlocksError(ValueError):
 class DeviceAwareBlockAllocator(ABC):
 
     @abstractmethod
-    def allocate_mutable(self,
-                         prev_block: Optional[Block],
-                         device: Optional[Device] = None) -> Block:
+    def allocate_mutable(self, prev_block: Optional[Block],
+                         device: Device) -> Block:
+        pass
+
+    @abstractmethod
+    def allocate_immutable(self, prev_block: Optional[Block],
+                           token_ids: List[int], device: Device) -> Block:
         pass
 
     @abstractmethod
-    def allocate_immutable(self,
-                           prev_block: Optional[Block],
-                           token_ids: List[int],
-                           device: Optional[Device] = None) -> Block:
+    def get_num_free_blocks(self, device: Device) -> int:
         pass
 
     @abstractmethod
-    def get_num_free_blocks(self, device: Optional[Device] = None) -> int:
+    def get_num_total_blocks(self, device: Device) -> int:
         pass
 
     @abstractmethod

vllm/core/block/naive_block.py

Lines changed: 4 additions & 2 deletions
@@ -133,10 +133,12 @@ def fork(self, last_block: Block) -> List[Block]:
 
         return forked_blocks
 
-    def get_num_free_blocks(self, device: Optional[Device] = None) -> int:
-        assert device is None
+    def get_num_free_blocks(self) -> int:
         return len(self._free_block_indices)
 
+    def get_num_total_blocks(self) -> int:
+        return len(self._all_block_indices)
+
     def _allocate_new_block_id(self) -> BlockId:
         if not self._free_block_indices:
             raise BlockAllocator.NoFreeBlocksError()

vllm/core/block/prefix_caching_block.py

Lines changed: 3 additions & 0 deletions
@@ -285,6 +285,9 @@ def get_num_free_blocks(self, device: Optional[Device] = None) -> int:
         return self._hashless_allocator.get_num_free_blocks(
         ) + self.evictor.num_blocks
 
+    def get_num_total_blocks(self) -> int:
+        return self._hashless_allocator.get_num_total_blocks()
+
     @property
     def all_block_ids(self) -> FrozenSet[int]:
         return self._hashless_allocator.all_block_ids

vllm/core/block_manager_v1.py

Lines changed: 17 additions & 2 deletions
@@ -47,6 +47,10 @@ def free(self, block: PhysicalTokenBlock) -> None:
     def get_num_free_blocks(self) -> int:
         pass
 
+    @abstractmethod
+    def get_num_total_blocks(self) -> int:
+        pass
+
     @abstractmethod
     def contains_block(self, block_hash: int) -> bool:
         pass
@@ -131,6 +135,9 @@ def get_num_free_blocks(self) -> int:
         return (self.num_blocks - self.current_num_blocks +
                 self.evictor.num_blocks)
 
+    def get_num_total_blocks(self) -> int:
+        return self.num_blocks
+
     def contains_block(self, block_hash: int) -> bool:
         return block_hash in self.cached_blocks or block_hash in self.evictor
 
@@ -190,6 +197,9 @@ def free(self, block: PhysicalTokenBlock) -> None:
     def get_num_free_blocks(self) -> int:
         return len(self.free_blocks)
 
+    def get_num_total_blocks(self) -> int:
+        return self.num_blocks
+
     def contains_block(self, block_hash: int) -> bool:
         raise NotImplementedError(
             "Invalid codepath for uncached block allocator.")
@@ -444,7 +454,7 @@ def _get_physical_blocks(
 
     def can_swap_in(self,
                     seq_group: SequenceGroup,
-                    num_lookahead_slots: int = 0) -> bool:
+                    num_lookahead_slots: int = 0) -> AllocStatus:
         assert (num_lookahead_slots == 0
                 ), "BlockSpaceManagerV1 does not support lookahead allocation"
         blocks = self._get_physical_blocks(seq_group)
@@ -454,7 +464,12 @@ def can_swap_in(self,
         # at least one free block right after the swap-in.
         # NOTE: This should match the logic in can_append_slot().
         num_required_blocks = len(blocks) + num_swapped_seqs
-        return num_free_blocks - num_required_blocks >= self.watermark_blocks
+        if self.gpu_allocator.get_num_total_blocks() < num_required_blocks:
+            return AllocStatus.NEVER
+        elif num_free_blocks - num_required_blocks >= self.watermark_blocks:
+            return AllocStatus.OK
+        else:
+            return AllocStatus.LATER
 
     def swap_in(self,
                 seq_group: SequenceGroup,
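The scheduler-side consumer of this three-way result (vllm/core/scheduler.py) is among the 12 changed files but not shown in this excerpt. Judging from the tests above (output.infeasible_seq_groups, requests finishing instead of hanging), the swap-in loop branches roughly as in this sketch; the function and variable names are illustrative, not the real scheduler API:

from collections import deque

from vllm.core.interfaces import AllocStatus


def schedule_swapped_sketch(block_manager, swapped_queue: deque):
    """Hedged sketch of the swap-in decision this commit enables; returns the
    groups to swap in now and the groups that can never fit (to be ignored)."""
    swap_in, infeasible = [], []
    while swapped_queue:
        seq_group = swapped_queue[0]
        status = block_manager.can_swap_in(seq_group)
        if status == AllocStatus.LATER:
            # Not enough free GPU blocks yet; leave the rest swapped out and
            # try again on a later scheduling step.
            break
        if status == AllocStatus.NEVER:
            # The group exceeds total GPU capacity; drop it as infeasible so
            # the engine can finish/ignore the request instead of hanging.
            infeasible.append(swapped_queue.popleft())
            continue
        # AllocStatus.OK: safe to swap the group back in now.
        swap_in.append(swapped_queue.popleft())
    return swap_in, infeasible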

vllm/core/block_manager_v2.py

Lines changed: 2 additions & 2 deletions
@@ -238,8 +238,8 @@ def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
         self.block_tables[child_seq.seq_id] = src_block_table.fork()
 
     def can_swap_in(self, seq_group: SequenceGroup,
-                    num_lookahead_slots: int) -> bool:
-        return False
+                    num_lookahead_slots: int) -> AllocStatus:
+        return AllocStatus.LATER
 
     def swap_in(self, seq_group: SequenceGroup,
                 num_lookahead_slots: int) -> Dict[int, int]:
