@@ -362,7 +362,7 @@ def test_cannot_schedule_after_recv():
362
362
BLOCK_SIZE = vllm_config .cache_config .block_size
363
363
# Prompt will use 2 blocks + 1 block after we schedule.
364
364
NUM_TOKENS_LOCAL = int (BLOCK_SIZE * NUM_PROMPT_BLOCKS )
365
- NUM_TOKENS_REMOTE = int (BLOCK_SIZE * ( NUM_PROMPT_BLOCKS + 0.5 ) )
365
+ NUM_TOKENS_REMOTE = int (BLOCK_SIZE * NUM_PROMPT_BLOCKS )
366
366
367
367
request_normal = create_request (request_id = 1 , num_tokens = NUM_TOKENS_LOCAL )
368
368
request_remote = create_request (request_id = 2 ,
@@ -393,30 +393,124 @@ def test_cannot_schedule_after_recv():
393
393
assert len (scheduler .running ) == 1
394
394
assert len (scheduler .waiting ) == 1
395
395
396
- # Step 4: try to schedule, not enough blocks.
396
+ # Step 4: try to schedule, remote request is put to running list
397
+ # because the transfer is completed.
398
+ scheduler_output = scheduler .schedule ()
399
+ model_runner_output = create_model_runner_output (
400
+ reqs = [request_normal , request_remote ])
401
+ scheduler .update_from_output (scheduler_output , model_runner_output )
402
+ assert len (scheduler .running ) == 2
403
+ assert len (scheduler .waiting ) == 0
404
+
405
+ # Step 5: Remote request will be put back to waiting list
406
+ # because it needs new block to hold generated token.
397
407
scheduler_output = scheduler .schedule ()
398
408
model_runner_output = create_model_runner_output (reqs = [request_normal ])
399
409
scheduler .update_from_output (scheduler_output , model_runner_output )
400
410
assert len (scheduler .running ) == 1
401
411
assert len (scheduler .waiting ) == 1
402
412
403
- # Step 5 : finish the request, free it.
413
+ # Step 6 : finish the request, free it.
404
414
scheduler_output = scheduler .schedule ()
405
415
model_runner_output = create_model_runner_output (reqs = [request_normal ],
406
416
use_eos = True )
407
417
scheduler .update_from_output (scheduler_output , model_runner_output )
408
418
assert len (scheduler .running ) == 0
409
419
assert len (scheduler .waiting ) == 1
410
420
411
- # Step 6: now we can schedule (with 2 blocks computed).
421
+ # Step 7: now we can schedule (with 2 blocks computed),
422
+ # request is retrieved from preempted list.
412
423
scheduler_output = scheduler .schedule ()
413
424
model_runner_output = create_model_runner_output (reqs = [request_remote ])
414
- assert (scheduler_output .scheduled_new_reqs [0 ]. num_computed_tokens ==
425
+ assert (scheduler_output .scheduled_cached_reqs . num_computed_tokens [0 ] ==
415
426
NUM_PROMPT_BLOCKS * BLOCK_SIZE )
416
427
scheduler .update_from_output (scheduler_output , model_runner_output )
417
428
assert len (scheduler .running ) == 1
418
429
assert len (scheduler .waiting ) == 0
419
430
431
+ # Step 8: free everything.
432
+ scheduler_output = scheduler .schedule ()
433
+ model_runner_output = create_model_runner_output (reqs = [request_remote ],
434
+ use_eos = True )
435
+ scheduler .update_from_output (scheduler_output , model_runner_output )
436
+ _ = scheduler .schedule ()
437
+ assert_scheduler_empty (scheduler )
438
+
439
+
440
+ def test_cannot_recv ():
441
+ """
442
+ Test that we can handle no schedule KV block transfer due to not
443
+ enough remaining KV blocks.
444
+ """
445
+
446
+ # NOTE: the KVCacheManager will use 1 null block.
447
+ # So there are 5 total working blocks.
448
+ TOTAL_NUM_BLOCKS = 6
449
+ vllm_config = create_vllm_config ()
450
+ scheduler = create_scheduler (vllm_config , num_blocks = TOTAL_NUM_BLOCKS )
451
+
452
+ # Prime the KVCache.
453
+ NUM_PROMPT_BLOCKS = 2
454
+ BLOCK_SIZE = vllm_config .cache_config .block_size
455
+ # Prompt will use 2 blocks + 1 block after we schedule.
456
+ NUM_TOKENS_LOCAL = int (BLOCK_SIZE * NUM_PROMPT_BLOCKS )
457
+ NUM_TOKENS_REMOTE = int (BLOCK_SIZE * (NUM_PROMPT_BLOCKS + 0.5 ))
458
+
459
+ request_normal = create_request (request_id = 1 , num_tokens = NUM_TOKENS_LOCAL )
460
+ request_remote = create_request (request_id = 2 ,
461
+ num_tokens = NUM_TOKENS_REMOTE ,
462
+ do_remote_prefill = True )
463
+
464
+ # STEP 1: 3 blocks are in use (2 for prompt, 1 for decode).
465
+ scheduler .add_request (request_normal )
466
+ scheduler_output = scheduler .schedule ()
467
+ model_runner_output = create_model_runner_output (reqs = [request_normal ])
468
+ scheduler .update_from_output (scheduler_output , model_runner_output )
469
+ assert len (scheduler .running ) == 1
470
+ assert len (scheduler .waiting ) == 0
471
+
472
+ # Step 2: 3 blocks are in use,
473
+ # need 3 new for remote blocks but only 2 are available.
474
+ scheduler .add_request (request_remote )
475
+ scheduler_output = scheduler .schedule ()
476
+ model_runner_output = create_model_runner_output (reqs = [request_normal ])
477
+ scheduler .update_from_output (scheduler_output , model_runner_output )
478
+ assert len (scheduler .running ) == 1
479
+ assert len (scheduler .waiting ) == 1
480
+ # Should not have KV transfer in progress.
481
+ assert (request_remote .status != RequestStatus .WAITING_FOR_REMOTE_KVS )
482
+
483
+ # Step 3: finish the request, free it.
484
+ scheduler_output = scheduler .schedule ()
485
+ model_runner_output = create_model_runner_output (reqs = [request_normal ],
486
+ use_eos = True )
487
+ scheduler .update_from_output (scheduler_output , model_runner_output )
488
+ assert len (scheduler .running ) == 0
489
+ assert len (scheduler .waiting ) == 1
490
+
491
+ # Step 4: now we can initiate KV transfer (with 2 blocks computed).
492
+ scheduler_output = scheduler .schedule ()
493
+ model_runner_output = create_model_runner_output (reqs = [])
494
+ scheduler .update_from_output (scheduler_output , model_runner_output )
495
+ assert len (scheduler .running ) == 0
496
+ assert len (scheduler .waiting ) == 1
497
+ assert (request_remote .status == RequestStatus .WAITING_FOR_REMOTE_KVS )
498
+
499
+ # Step 5: finish recving (5 blocks in use)
500
+ scheduler_output = scheduler .schedule ()
501
+ model_runner_output = create_model_runner_output (
502
+ reqs = [], finished_recving = [request_remote .request_id ])
503
+ scheduler .update_from_output (scheduler_output , model_runner_output )
504
+ assert len (scheduler .running ) == 0
505
+ assert len (scheduler .waiting ) == 1
506
+
507
+ # Step 6: schedule remote request
508
+ scheduler_output = scheduler .schedule ()
509
+ model_runner_output = create_model_runner_output (reqs = [request_remote ])
510
+ scheduler .update_from_output (scheduler_output , model_runner_output )
511
+ assert len (scheduler .running ) == 1
512
+ assert len (scheduler .waiting ) == 0
513
+
420
514
# Step 7: free everything.
421
515
scheduler_output = scheduler .schedule ()
422
516
model_runner_output = create_model_runner_output (reqs = [request_remote ],
0 commit comments