@@ -226,7 +226,7 @@ def test_get_num_unfinished_requests(self):
226
226
len (requests ) - i - 1 )
227
227
228
228
def test_schedule (self ):
229
- '''Test scheduling.
229
+ '''Test scheduling.
230
230
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
231
231
'''
232
232
scheduler = self .create_scheduler ()
@@ -251,6 +251,27 @@ def test_schedule(self):
251
251
for i , request in enumerate (requests ):
252
252
self .assertEqual (scheduler .running [i ], request )
253
253
254
+ def test_concurrent_partial_prefills_schedule (self ):
255
+ '''Test concurrent partial prefills scheduling.
256
+ total requests = 10, every request has 10 token.
257
+ while set long_prefill_token_threshold = 1, scheduler can
258
+ only schedule max_long_partial_prefills long request.
259
+ '''
260
+ scheduler = self .create_scheduler ()
261
+ scheduler .scheduler_config .chunked_prefill_enabled = False
262
+ scheduler .scheduler_config .max_long_partial_prefills = 2
263
+ scheduler .scheduler_config .long_prefill_token_threshold = 1
264
+ requests = create_requests (num_requests = 10 , num_tokens = 20 )
265
+ for request in requests :
266
+ scheduler .add_request (request )
267
+
268
+ # Test initial scheduling
269
+ output = scheduler .schedule ()
270
+ self .assertEqual (len (output .scheduled_new_reqs ),
271
+ scheduler .scheduler_config .max_long_partial_prefills )
272
+ self .assertEqual (output .scheduled_cached_reqs .num_reqs , 0 )
273
+ self .assertEqual (len (output .finished_req_ids ), 0 )
274
+
254
275
def test_schedule_enable_prefix_caching (self ):
255
276
'''Test scheduling.
256
277
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
0 commit comments